[mlir][Vector] Add support for Value indices to vector.extract/insert
`vector.extract/insert` ops only support constant indices. This PR is extending them so that arbitrary values can be used instead. This work is part of the RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops Differential Revision: https://reviews.llvm.org/D155034
This commit is contained in:
@@ -126,6 +126,18 @@ static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
|
||||
return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
|
||||
}
|
||||
|
||||
/// Convert `foldResult` into a Value. Integer attribute is converted to
|
||||
/// an LLVM constant op.
|
||||
static Value getAsLLVMValue(OpBuilder &builder, Location loc,
|
||||
OpFoldResult foldResult) {
|
||||
if (auto attr = foldResult.dyn_cast<Attribute>()) {
|
||||
auto intAttr = cast<IntegerAttr>(attr);
|
||||
return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult();
|
||||
}
|
||||
|
||||
return foldResult.get<Value>();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
/// Trivial Vector to LLVM conversions
|
||||
@@ -1079,41 +1091,53 @@ public:
|
||||
auto loc = extractOp->getLoc();
|
||||
auto resultType = extractOp.getResult().getType();
|
||||
auto llvmResultType = typeConverter->convertType(resultType);
|
||||
ArrayRef<int64_t> positionArray = extractOp.getPosition();
|
||||
|
||||
// Bail if result type cannot be lowered.
|
||||
if (!llvmResultType)
|
||||
return failure();
|
||||
|
||||
SmallVector<OpFoldResult> positionVec;
|
||||
for (auto [idx, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
|
||||
if (pos.is<Value>())
|
||||
// Make sure we use the value that has been already converted to LLVM.
|
||||
positionVec.push_back(adaptor.getDynamicPosition()[idx]);
|
||||
else
|
||||
positionVec.push_back(pos);
|
||||
}
|
||||
|
||||
// Extract entire vector. Should be handled by folder, but just to be safe.
|
||||
if (positionArray.empty()) {
|
||||
ArrayRef<OpFoldResult> position(positionVec);
|
||||
if (position.empty()) {
|
||||
rewriter.replaceOp(extractOp, adaptor.getVector());
|
||||
return success();
|
||||
}
|
||||
|
||||
// One-shot extraction of vector from array (only requires extractvalue).
|
||||
if (isa<VectorType>(resultType)) {
|
||||
if (extractOp.hasDynamicPosition())
|
||||
return failure();
|
||||
|
||||
Value extracted = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, adaptor.getVector(), positionArray);
|
||||
loc, adaptor.getVector(), getAsIntegers(position));
|
||||
rewriter.replaceOp(extractOp, extracted);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Potential extraction of 1-D vector from array.
|
||||
Value extracted = adaptor.getVector();
|
||||
if (positionArray.size() > 1) {
|
||||
extracted = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, extracted, positionArray.drop_back());
|
||||
if (position.size() > 1) {
|
||||
if (extractOp.hasDynamicPosition())
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t> nMinusOnePosition =
|
||||
getAsIntegers(position.drop_back());
|
||||
extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
|
||||
nMinusOnePosition);
|
||||
}
|
||||
|
||||
// Remaining extraction of element from 1-D LLVM vector
|
||||
auto i64Type = IntegerType::get(rewriter.getContext(), 64);
|
||||
auto constant =
|
||||
rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
|
||||
extracted =
|
||||
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
|
||||
rewriter.replaceOp(extractOp, extracted);
|
||||
|
||||
Value lastPosition = getAsLLVMValue(rewriter, loc, position.back());
|
||||
// Remaining extraction of element from 1-D LLVM vector.
|
||||
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(extractOp, extracted,
|
||||
lastPosition);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -1194,23 +1218,34 @@ public:
|
||||
auto sourceType = insertOp.getSourceType();
|
||||
auto destVectorType = insertOp.getDestVectorType();
|
||||
auto llvmResultType = typeConverter->convertType(destVectorType);
|
||||
ArrayRef<int64_t> positionArray = insertOp.getPosition();
|
||||
|
||||
// Bail if result type cannot be lowered.
|
||||
if (!llvmResultType)
|
||||
return failure();
|
||||
|
||||
SmallVector<OpFoldResult> positionVec;
|
||||
for (auto [idx, pos] : llvm::enumerate(insertOp.getMixedPosition())) {
|
||||
if (pos.is<Value>())
|
||||
// Make sure we use the value that has been already converted to LLVM.
|
||||
positionVec.push_back(adaptor.getDynamicPosition()[idx]);
|
||||
else
|
||||
positionVec.push_back(pos);
|
||||
}
|
||||
|
||||
// Overwrite entire vector with value. Should be handled by folder, but
|
||||
// just to be safe.
|
||||
if (positionArray.empty()) {
|
||||
ArrayRef<OpFoldResult> position(positionVec);
|
||||
if (position.empty()) {
|
||||
rewriter.replaceOp(insertOp, adaptor.getSource());
|
||||
return success();
|
||||
}
|
||||
|
||||
// One-shot insertion of a vector into an array (only requires insertvalue).
|
||||
if (isa<VectorType>(sourceType)) {
|
||||
if (insertOp.hasDynamicPosition())
|
||||
return failure();
|
||||
|
||||
Value inserted = rewriter.create<LLVM::InsertValueOp>(
|
||||
loc, adaptor.getDest(), adaptor.getSource(), positionArray);
|
||||
loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
|
||||
rewriter.replaceOp(insertOp, inserted);
|
||||
return success();
|
||||
}
|
||||
@@ -1218,24 +1253,28 @@ public:
|
||||
// Potential extraction of 1-D vector from array.
|
||||
Value extracted = adaptor.getDest();
|
||||
auto oneDVectorType = destVectorType;
|
||||
if (positionArray.size() > 1) {
|
||||
if (position.size() > 1) {
|
||||
if (insertOp.hasDynamicPosition())
|
||||
return failure();
|
||||
|
||||
oneDVectorType = reducedVectorTypeBack(destVectorType);
|
||||
extracted = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, extracted, positionArray.drop_back());
|
||||
loc, extracted, getAsIntegers(position.drop_back()));
|
||||
}
|
||||
|
||||
// Insertion of an element into a 1-D LLVM vector.
|
||||
auto i64Type = IntegerType::get(rewriter.getContext(), 64);
|
||||
auto constant =
|
||||
rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
|
||||
Value inserted = rewriter.create<LLVM::InsertElementOp>(
|
||||
loc, typeConverter->convertType(oneDVectorType), extracted,
|
||||
adaptor.getSource(), constant);
|
||||
adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));
|
||||
|
||||
// Potential insertion of resulting 1-D vector into array.
|
||||
if (positionArray.size() > 1) {
|
||||
if (position.size() > 1) {
|
||||
if (insertOp.hasDynamicPosition())
|
||||
return failure();
|
||||
|
||||
inserted = rewriter.create<LLVM::InsertValueOp>(
|
||||
loc, adaptor.getDest(), inserted, positionArray.drop_back());
|
||||
loc, adaptor.getDest(), inserted,
|
||||
getAsIntegers(position.drop_back()));
|
||||
}
|
||||
|
||||
rewriter.replaceOp(insertOp, inserted);
|
||||
|
||||
Reference in New Issue
Block a user