[mlir][VectorOps] Introduce a vector.fma op that works on n-D vectors and lowers to llvm.intrin.fmuladd
Summary: The `vector.fma` operation is portable enough across targets that we do not want to keep it wrapped under `vector.outerproduct` and `llvm.intrin.fmuladd`. This revision lifts the op into the vector dialect and implements the lowering to LLVM by using two patterns: 1. a pattern that lowers from n-D to (n-1)-D by unrolling when n >= 2; 2. a pattern that converts from 1-D to the proper LLVM representation. Reviewers: ftynse, stellaraccident, aartbik, dcaballe, jsetoain, tetuante Reviewed By: aartbik Subscribers: fhahn, dcaballe, merge_guards_bot, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, Joonsoo, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D74075
This commit is contained in:
@@ -410,6 +410,41 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Conversion pattern that turns a vector.fma on a 1-D vector
|
||||
/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
|
||||
/// This does not match vectors of n >= 2 rank.
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// vector.fma %a, %a, %a : vector<8xf32>
|
||||
/// ```
|
||||
/// is converted to:
|
||||
/// ```
|
||||
/// llvm.intr.fma %va, %va, %va:
|
||||
/// (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
|
||||
/// -> !llvm<"<8 x float>">
|
||||
/// ```
|
||||
class VectorFMAOp1DConversion : public LLVMOpLowering {
|
||||
public:
|
||||
explicit VectorFMAOp1DConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: LLVMOpLowering(vector::FMAOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto adaptor = vector::FMAOpOperandAdaptor(operands);
|
||||
vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
|
||||
VectorType vType = fmaOp.getVectorType();
|
||||
if (vType.getRank() != 1)
|
||||
return matchFailure();
|
||||
rewriter.replaceOpWithNewOp<LLVM::FMAOp>(op, adaptor.lhs(), adaptor.rhs(),
|
||||
adaptor.acc());
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
class VectorInsertElementOpConversion : public LLVMOpLowering {
|
||||
public:
|
||||
explicit VectorInsertElementOpConversion(MLIRContext *context,
|
||||
@@ -502,6 +537,54 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// %d = vector.fma %a, %b, %c : vector<2x4xf32>
|
||||
/// ```
|
||||
/// is rewritten into:
|
||||
/// ```
|
||||
/// %r = splat %f0: vector<2x4xf32>
|
||||
/// %va = vector.extractvalue %a[0] : vector<2x4xf32>
|
||||
/// %vb = vector.extractvalue %b[0] : vector<2x4xf32>
|
||||
/// %vc = vector.extractvalue %c[0] : vector<2x4xf32>
|
||||
/// %vd = vector.fma %va, %vb, %vc : vector<4xf32>
|
||||
/// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
|
||||
/// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
|
||||
/// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
|
||||
/// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
|
||||
/// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
|
||||
/// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
|
||||
/// // %r3 holds the final value.
|
||||
/// ```
|
||||
class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
|
||||
public:
|
||||
using OpRewritePattern<FMAOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(FMAOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto vType = op.getVectorType();
|
||||
if (vType.getRank() < 2)
|
||||
return matchFailure();
|
||||
|
||||
auto loc = op.getLoc();
|
||||
auto elemType = vType.getElementType();
|
||||
Value zero = rewriter.create<ConstantOp>(loc, elemType,
|
||||
rewriter.getZeroAttr(elemType));
|
||||
Value desc = rewriter.create<SplatOp>(loc, vType, zero);
|
||||
for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
|
||||
Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
|
||||
Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
|
||||
Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
|
||||
Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
|
||||
desc = rewriter.create<InsertOp>(loc, fma, desc, i);
|
||||
}
|
||||
rewriter.replaceOp(op, desc);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
// When ranks are different, InsertStridedSlice needs to extract a properly
|
||||
// ranked vector from the destination vector into which to insert. This pattern
|
||||
// only takes care of this part and forwards the rest of the conversion to
|
||||
@@ -969,14 +1052,16 @@ public:
|
||||
// Adds the vector lowering patterns defined in this file to `patterns`.
// The MLIRContext is obtained from the dialect owned by `converter`. Two
// groups are inserted: one constructed from the context alone (it lists
// VectorFMAOpNDRewritePattern) and one constructed from the context plus the
// LLVMTypeConverter (it lists VectorFMAOp1DConversion).
// NOTE(review): this hunk appears to show both the pre-change and post-change
// lines of the two `patterns.insert<...>` calls without +/- markers; confirm
// against the real file before editing the code below.
void mlir::populateVectorToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
||||
MLIRContext *ctx = converter.getDialect()->getContext();
|
||||
patterns.insert<VectorInsertStridedSliceOpDifferentRankRewritePattern,
|
||||
patterns.insert<VectorFMAOpNDRewritePattern,
|
||||
VectorInsertStridedSliceOpDifferentRankRewritePattern,
|
||||
VectorInsertStridedSliceOpSameRankRewritePattern,
|
||||
VectorStridedSliceOpConversion>(ctx);
|
||||
patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
|
||||
VectorExtractElementOpConversion, VectorExtractOpConversion,
|
||||
VectorInsertElementOpConversion, VectorInsertOpConversion,
|
||||
VectorOuterProductOpConversion, VectorTypeCastOpConversion,
|
||||
VectorPrintOpConversion>(ctx, converter);
|
||||
VectorFMAOp1DConversion, VectorInsertElementOpConversion,
|
||||
VectorInsertOpConversion, VectorOuterProductOpConversion,
|
||||
VectorTypeCastOpConversion, VectorPrintOpConversion>(
|
||||
ctx, converter);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
Reference in New Issue
Block a user