[VectorOps] Add [insert/extract]element definition together with lowering to LLVM
Similar to insert/extract vector instructions but (1) work on 1-D vectors only (2) allow for a dynamic index %c3 = constant 3 : index %0 = vector.insertelement %arg0, %arg1[%c : index] : vector<4xf32> %1 = vector.extractelement %arg0[%c3 : index] : vector<4xf32> PiperOrigin-RevId: 285792205
This commit is contained in:
committed by
A. Unique TensorFlower
parent
73ec37c8bb
commit
cd5dab8ad7
@@ -300,6 +300,31 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class VectorExtractElementOpConversion : public LLVMOpLowering {
|
||||
public:
|
||||
explicit VectorExtractElementOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: LLVMOpLowering(vector::ExtractElementOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
|
||||
auto extractEltOp = cast<vector::ExtractElementOp>(op);
|
||||
auto vectorType = extractEltOp.getVectorType();
|
||||
auto llvmType = lowering.convertType(vectorType.getElementType());
|
||||
|
||||
// Bail if result type cannot be lowered.
|
||||
if (!llvmType)
|
||||
return matchFailure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
|
||||
op, llvmType, adaptor.vector(), adaptor.position());
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
class VectorExtractOpConversion : public LLVMOpLowering {
|
||||
public:
|
||||
explicit VectorExtractOpConversion(MLIRContext *context,
|
||||
@@ -355,6 +380,31 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class VectorInsertElementOpConversion : public LLVMOpLowering {
|
||||
public:
|
||||
explicit VectorInsertElementOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: LLVMOpLowering(vector::InsertElementOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto adaptor = vector::InsertElementOpOperandAdaptor(operands);
|
||||
auto insertEltOp = cast<vector::InsertElementOp>(op);
|
||||
auto vectorType = insertEltOp.getDestVectorType();
|
||||
auto llvmType = lowering.convertType(vectorType);
|
||||
|
||||
// Bail if result type cannot be lowered.
|
||||
if (!llvmType)
|
||||
return matchFailure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
|
||||
op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position());
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
class VectorInsertOpConversion : public LLVMOpLowering {
|
||||
public:
|
||||
explicit VectorInsertOpConversion(MLIRContext *context,
|
||||
@@ -566,7 +616,8 @@ public:
|
||||
void mlir::populateVectorToLLVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
||||
patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
|
||||
VectorExtractOpConversion, VectorInsertOpConversion,
|
||||
VectorExtractElementOpConversion, VectorExtractOpConversion,
|
||||
VectorInsertElementOpConversion, VectorInsertOpConversion,
|
||||
VectorOuterProductOpConversion, VectorTypeCastOpConversion>(
|
||||
converter.getDialect()->getContext(), converter);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user