[mlir] [VectorOps,LinAlg] Remove direct LLVM lowering for vector.broadcast
Summary: The direct lowering of vector.broadcast into LLVM has been replaced by progressive lowering into elementary vector ops. This also required a small refactoring of a llvm.mlir test that used a direct vector.broadcast operator (just to define a matmul). Reviewers: nicolasvasilache, andydavis1, rriddle Reviewed By: nicolasvasilache Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D76143
This commit is contained in:
@@ -817,59 +817,6 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(ajcbik): remove this rule once LinAlg tests are cleaned up
|
||||
class VectorOuterProductOpConversion : public ConvertToLLVMPattern {
|
||||
public:
|
||||
explicit VectorOuterProductOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(vector::OuterProductOp::getOperationName(),
|
||||
context, typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
auto adaptor = vector::OuterProductOpOperandAdaptor(operands);
|
||||
auto *ctx = op->getContext();
|
||||
auto vLHS = adaptor.lhs().getType().cast<LLVM::LLVMType>();
|
||||
auto vRHS = adaptor.rhs().getType().cast<LLVM::LLVMType>();
|
||||
auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements();
|
||||
auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements();
|
||||
auto llvmArrayOfVectType = typeConverter.convertType(
|
||||
cast<vector::OuterProductOp>(op).getResult().getType());
|
||||
Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
|
||||
Value a = adaptor.lhs(), b = adaptor.rhs();
|
||||
Value acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front();
|
||||
SmallVector<Value, 8> lhs, accs;
|
||||
lhs.reserve(rankLHS);
|
||||
accs.reserve(rankLHS);
|
||||
for (unsigned d = 0, e = rankLHS; d < e; ++d) {
|
||||
// shufflevector explicitly requires i32.
|
||||
auto attr = rewriter.getI32IntegerAttr(d);
|
||||
SmallVector<Attribute, 4> bcastAttr(rankRHS, attr);
|
||||
auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx);
|
||||
Value aD = nullptr, accD = nullptr;
|
||||
// 1. Broadcast the element a[d] into vector aD.
|
||||
aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr);
|
||||
// 2. If acc is present, extract 1-d vector acc[d] into accD.
|
||||
if (acc)
|
||||
accD = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, vRHS, acc, rewriter.getI64ArrayAttr(d));
|
||||
// 3. Compute aD outer b (plus accD, if relevant).
|
||||
Value aOuterbD =
|
||||
accD
|
||||
? rewriter.create<LLVM::FMAOp>(loc, vRHS, aD, b, accD).getResult()
|
||||
: rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult();
|
||||
// 4. Insert as value `d` in the descriptor.
|
||||
desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType,
|
||||
desc, aOuterbD,
|
||||
rewriter.getI64ArrayAttr(d));
|
||||
}
|
||||
rewriter.replaceOp(op, desc);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
|
||||
public:
|
||||
explicit VectorTypeCastOpConversion(MLIRContext *context,
|
||||
@@ -1160,8 +1107,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
|
||||
VectorShuffleOpConversion, VectorExtractElementOpConversion,
|
||||
VectorExtractOpConversion, VectorFMAOp1DConversion,
|
||||
VectorInsertElementOpConversion, VectorInsertOpConversion,
|
||||
VectorOuterProductOpConversion, VectorTypeCastOpConversion,
|
||||
VectorPrintOpConversion>(ctx, converter);
|
||||
VectorTypeCastOpConversion, VectorPrintOpConversion>(
|
||||
ctx, converter);
|
||||
}
|
||||
|
||||
void mlir::populateVectorToLLVMMatrixConversionPatterns(
|
||||
|
||||
Reference in New Issue
Block a user