[VectorOps] Add [insert/extract]element definition together with lowering to LLVM

Similar to insert/extract vector instructions but
(1) work on 1-D vectors only
(2) allow for a dynamic index

  %c3 = constant 3 : index
  %0 = vector.insertelement %arg0, %arg1[%c : index] : vector<4xf32>
  %1 = vector.extractelement %arg0[%c3 : index] : vector<4xf32>

PiperOrigin-RevId: 285792205
This commit is contained in:
Aart Bik
2019-12-16 09:52:13 -08:00
committed by A. Unique TensorFlower
parent 73ec37c8bb
commit cd5dab8ad7
6 changed files with 248 additions and 2 deletions

View File

@@ -300,6 +300,31 @@ public:
}
};
class VectorExtractElementOpConversion : public LLVMOpLowering {
public:
explicit VectorExtractElementOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context,
typeConverter) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
auto extractEltOp = cast<vector::ExtractElementOp>(op);
auto vectorType = extractEltOp.getVectorType();
auto llvmType = lowering.convertType(vectorType.getElementType());
// Bail if result type cannot be lowered.
if (!llvmType)
return matchFailure();
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
op, llvmType, adaptor.vector(), adaptor.position());
return matchSuccess();
}
};
class VectorExtractOpConversion : public LLVMOpLowering {
public:
explicit VectorExtractOpConversion(MLIRContext *context,
@@ -355,6 +380,31 @@ public:
}
};
class VectorInsertElementOpConversion : public LLVMOpLowering {
public:
explicit VectorInsertElementOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: LLVMOpLowering(vector::InsertElementOp::getOperationName(), context,
typeConverter) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto adaptor = vector::InsertElementOpOperandAdaptor(operands);
auto insertEltOp = cast<vector::InsertElementOp>(op);
auto vectorType = insertEltOp.getDestVectorType();
auto llvmType = lowering.convertType(vectorType);
// Bail if result type cannot be lowered.
if (!llvmType)
return matchFailure();
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
return matchSuccess();
}
};
class VectorInsertOpConversion : public LLVMOpLowering {
public:
explicit VectorInsertOpConversion(MLIRContext *context,
@@ -566,7 +616,8 @@ public:
void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
VectorExtractOpConversion, VectorInsertOpConversion,
VectorExtractElementOpConversion, VectorExtractOpConversion,
VectorInsertElementOpConversion, VectorInsertOpConversion,
VectorOuterProductOpConversion, VectorTypeCastOpConversion>(
converter.getDialect()->getContext(), converter);
}