[mlir][Vector] Add support for Value indices to vector.extract/insert

`vector.extract/insert` ops only support constant indices. This PR is
extending them so that arbitrary values can be used instead.

This work is part of the RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops

Differential Revision: https://reviews.llvm.org/D155034
This commit is contained in:
Diego Caballero
2023-07-11 17:07:11 +00:00
parent 6ebc179978
commit 98f6289a34
19 changed files with 535 additions and 197 deletions

View File

@@ -126,6 +126,18 @@ static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
}
/// Convert `foldResult` into a Value. Integer attribute is converted to
/// an LLVM constant op.
static Value getAsLLVMValue(OpBuilder &builder, Location loc,
OpFoldResult foldResult) {
if (auto attr = foldResult.dyn_cast<Attribute>()) {
auto intAttr = cast<IntegerAttr>(attr);
return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult();
}
return foldResult.get<Value>();
}
namespace {
/// Trivial Vector to LLVM conversions
@@ -1079,41 +1091,53 @@ public:
auto loc = extractOp->getLoc();
auto resultType = extractOp.getResult().getType();
auto llvmResultType = typeConverter->convertType(resultType);
ArrayRef<int64_t> positionArray = extractOp.getPosition();
// Bail if result type cannot be lowered.
if (!llvmResultType)
return failure();
SmallVector<OpFoldResult> positionVec;
for (auto [idx, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
if (pos.is<Value>())
// Make sure we use the value that has been already converted to LLVM.
positionVec.push_back(adaptor.getDynamicPosition()[idx]);
else
positionVec.push_back(pos);
}
// Extract entire vector. Should be handled by folder, but just to be safe.
if (positionArray.empty()) {
ArrayRef<OpFoldResult> position(positionVec);
if (position.empty()) {
rewriter.replaceOp(extractOp, adaptor.getVector());
return success();
}
// One-shot extraction of vector from array (only requires extractvalue).
if (isa<VectorType>(resultType)) {
if (extractOp.hasDynamicPosition())
return failure();
Value extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, adaptor.getVector(), positionArray);
loc, adaptor.getVector(), getAsIntegers(position));
rewriter.replaceOp(extractOp, extracted);
return success();
}
// Potential extraction of 1-D vector from array.
Value extracted = adaptor.getVector();
if (positionArray.size() > 1) {
extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, extracted, positionArray.drop_back());
if (position.size() > 1) {
if (extractOp.hasDynamicPosition())
return failure();
SmallVector<int64_t> nMinusOnePosition =
getAsIntegers(position.drop_back());
extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
nMinusOnePosition);
}
// Remaining extraction of element from 1-D LLVM vector
auto i64Type = IntegerType::get(rewriter.getContext(), 64);
auto constant =
rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
extracted =
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
rewriter.replaceOp(extractOp, extracted);
Value lastPosition = getAsLLVMValue(rewriter, loc, position.back());
// Remaining extraction of element from 1-D LLVM vector.
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(extractOp, extracted,
lastPosition);
return success();
}
};
@@ -1194,23 +1218,34 @@ public:
auto sourceType = insertOp.getSourceType();
auto destVectorType = insertOp.getDestVectorType();
auto llvmResultType = typeConverter->convertType(destVectorType);
ArrayRef<int64_t> positionArray = insertOp.getPosition();
// Bail if result type cannot be lowered.
if (!llvmResultType)
return failure();
SmallVector<OpFoldResult> positionVec;
for (auto [idx, pos] : llvm::enumerate(insertOp.getMixedPosition())) {
if (pos.is<Value>())
// Make sure we use the value that has been already converted to LLVM.
positionVec.push_back(adaptor.getDynamicPosition()[idx]);
else
positionVec.push_back(pos);
}
// Overwrite entire vector with value. Should be handled by folder, but
// just to be safe.
if (positionArray.empty()) {
ArrayRef<OpFoldResult> position(positionVec);
if (position.empty()) {
rewriter.replaceOp(insertOp, adaptor.getSource());
return success();
}
// One-shot insertion of a vector into an array (only requires insertvalue).
if (isa<VectorType>(sourceType)) {
if (insertOp.hasDynamicPosition())
return failure();
Value inserted = rewriter.create<LLVM::InsertValueOp>(
loc, adaptor.getDest(), adaptor.getSource(), positionArray);
loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
rewriter.replaceOp(insertOp, inserted);
return success();
}
@@ -1218,24 +1253,28 @@ public:
// Potential extraction of 1-D vector from array.
Value extracted = adaptor.getDest();
auto oneDVectorType = destVectorType;
if (positionArray.size() > 1) {
if (position.size() > 1) {
if (insertOp.hasDynamicPosition())
return failure();
oneDVectorType = reducedVectorTypeBack(destVectorType);
extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, extracted, positionArray.drop_back());
loc, extracted, getAsIntegers(position.drop_back()));
}
// Insertion of an element into a 1-D LLVM vector.
auto i64Type = IntegerType::get(rewriter.getContext(), 64);
auto constant =
rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
Value inserted = rewriter.create<LLVM::InsertElementOp>(
loc, typeConverter->convertType(oneDVectorType), extracted,
adaptor.getSource(), constant);
adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));
// Potential insertion of resulting 1-D vector into array.
if (positionArray.size() > 1) {
if (position.size() > 1) {
if (insertOp.hasDynamicPosition())
return failure();
inserted = rewriter.create<LLVM::InsertValueOp>(
loc, adaptor.getDest(), inserted, positionArray.drop_back());
loc, adaptor.getDest(), inserted,
getAsIntegers(position.drop_back()));
}
rewriter.replaceOp(insertOp, inserted);