[mlir] [VectorOps,LinAlg] Remove direct LLVM lowering for vector.broadcast

Summary:
The direct lowering of vector.broadcast into LLVM has been replaced by
progressive lowering into elementary vector ops. This also required a
small refactoring of a llvm.mlir test that used a direct vector.broadcast
operator (just to define a matmul).

Reviewers: nicolasvasilache, andydavis1, rriddle

Reviewed By: nicolasvasilache

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D76143
This commit is contained in:
aartbik
2020-03-13 09:35:29 -07:00
parent 89b19e8959
commit a213ece30b
2 changed files with 19 additions and 64 deletions

View File

@@ -817,59 +817,6 @@ public:
}
};
// TODO(ajcbik): remove this rule once LinAlg tests are cleaned up
class VectorOuterProductOpConversion : public ConvertToLLVMPattern {
public:
explicit VectorOuterProductOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(vector::OuterProductOp::getOperationName(),
context, typeConverter) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto adaptor = vector::OuterProductOpOperandAdaptor(operands);
auto *ctx = op->getContext();
auto vLHS = adaptor.lhs().getType().cast<LLVM::LLVMType>();
auto vRHS = adaptor.rhs().getType().cast<LLVM::LLVMType>();
auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements();
auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements();
auto llvmArrayOfVectType = typeConverter.convertType(
cast<vector::OuterProductOp>(op).getResult().getType());
Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
Value a = adaptor.lhs(), b = adaptor.rhs();
Value acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front();
SmallVector<Value, 8> lhs, accs;
lhs.reserve(rankLHS);
accs.reserve(rankLHS);
for (unsigned d = 0, e = rankLHS; d < e; ++d) {
// shufflevector explicitly requires i32.
auto attr = rewriter.getI32IntegerAttr(d);
SmallVector<Attribute, 4> bcastAttr(rankRHS, attr);
auto bcastArrayAttr = ArrayAttr::get(bcastAttr, ctx);
Value aD = nullptr, accD = nullptr;
// 1. Broadcast the element a[d] into vector aD.
aD = rewriter.create<LLVM::ShuffleVectorOp>(loc, a, a, bcastArrayAttr);
// 2. If acc is present, extract 1-d vector acc[d] into accD.
if (acc)
accD = rewriter.create<LLVM::ExtractValueOp>(
loc, vRHS, acc, rewriter.getI64ArrayAttr(d));
// 3. Compute aD outer b (plus accD, if relevant).
Value aOuterbD =
accD
? rewriter.create<LLVM::FMAOp>(loc, vRHS, aD, b, accD).getResult()
: rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult();
// 4. Insert as value `d` in the descriptor.
desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType,
desc, aOuterbD,
rewriter.getI64ArrayAttr(d));
}
rewriter.replaceOp(op, desc);
return matchSuccess();
}
};
class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
public:
explicit VectorTypeCastOpConversion(MLIRContext *context,
@@ -1160,8 +1107,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorShuffleOpConversion, VectorExtractElementOpConversion,
VectorExtractOpConversion, VectorFMAOp1DConversion,
VectorInsertElementOpConversion, VectorInsertOpConversion,
VectorOuterProductOpConversion, VectorTypeCastOpConversion,
VectorPrintOpConversion>(ctx, converter);
VectorTypeCastOpConversion, VectorPrintOpConversion>(
ctx, converter);
}
void mlir::populateVectorToLLVMMatrixConversionPatterns(