[mlir][vector] Use DenseI64ArrayAttr for ExtractOp/InsertOp positions

`DenseI64ArrayAttr` provides a better API than `I64ArrayAttr`. E.g., accessors returning `ArrayRef<int64_t>` (instead of `ArrayAttr`) are generated.

Differential Revision: https://reviews.llvm.org/D156684
This commit is contained in:
Matthias Springer
2023-07-31 15:21:29 +02:00
parent aba0ef7059
commit 16b75cd2bb
14 changed files with 100 additions and 163 deletions

View File

@@ -1025,44 +1025,37 @@ public:
auto loc = extractOp->getLoc();
auto resultType = extractOp.getResult().getType();
auto llvmResultType = typeConverter->convertType(resultType);
auto positionArrayAttr = extractOp.getPosition();
ArrayRef<int64_t> positionArray = extractOp.getPosition();
// Bail if result type cannot be lowered.
if (!llvmResultType)
return failure();
// Extract entire vector. Should be handled by folder, but just to be safe.
if (positionArrayAttr.empty()) {
if (positionArray.empty()) {
rewriter.replaceOp(extractOp, adaptor.getVector());
return success();
}
// One-shot extraction of vector from array (only requires extractvalue).
if (isa<VectorType>(resultType)) {
SmallVector<int64_t> indices;
for (auto idx : positionArrayAttr.getAsRange<IntegerAttr>())
indices.push_back(idx.getInt());
Value extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, adaptor.getVector(), indices);
loc, adaptor.getVector(), positionArray);
rewriter.replaceOp(extractOp, extracted);
return success();
}
// Potential extraction of 1-D vector from array.
Value extracted = adaptor.getVector();
auto positionAttrs = positionArrayAttr.getValue();
if (positionAttrs.size() > 1) {
SmallVector<int64_t> nMinusOnePosition;
for (auto idx : positionAttrs.drop_back())
nMinusOnePosition.push_back(cast<IntegerAttr>(idx).getInt());
extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
nMinusOnePosition);
if (positionArray.size() > 1) {
extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, extracted, positionArray.drop_back());
}
// Remaining extraction of element from 1-D LLVM vector
auto position = cast<IntegerAttr>(positionAttrs.back());
auto i64Type = IntegerType::get(rewriter.getContext(), 64);
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
auto constant =
rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
extracted =
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
rewriter.replaceOp(extractOp, extracted);
@@ -1147,7 +1140,7 @@ public:
auto sourceType = insertOp.getSourceType();
auto destVectorType = insertOp.getDestVectorType();
auto llvmResultType = typeConverter->convertType(destVectorType);
auto positionArrayAttr = insertOp.getPosition();
ArrayRef<int64_t> positionArray = insertOp.getPosition();
// Bail if result type cannot be lowered.
if (!llvmResultType)
@@ -1155,7 +1148,7 @@ public:
// Overwrite entire vector with value. Should be handled by folder, but
// just to be safe.
if (positionArrayAttr.empty()) {
if (positionArray.empty()) {
rewriter.replaceOp(insertOp, adaptor.getSource());
return success();
}
@@ -1163,36 +1156,32 @@ public:
// One-shot insertion of a vector into an array (only requires insertvalue).
if (isa<VectorType>(sourceType)) {
Value inserted = rewriter.create<LLVM::InsertValueOp>(
loc, adaptor.getDest(), adaptor.getSource(),
LLVM::convertArrayToIndices(positionArrayAttr));
loc, adaptor.getDest(), adaptor.getSource(), positionArray);
rewriter.replaceOp(insertOp, inserted);
return success();
}
// Potential extraction of 1-D vector from array.
Value extracted = adaptor.getDest();
auto positionAttrs = positionArrayAttr.getValue();
auto position = cast<IntegerAttr>(positionAttrs.back());
auto oneDVectorType = destVectorType;
if (positionAttrs.size() > 1) {
if (positionArray.size() > 1) {
oneDVectorType = reducedVectorTypeBack(destVectorType);
extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, extracted,
LLVM::convertArrayToIndices(positionAttrs.drop_back()));
loc, extracted, positionArray.drop_back());
}
// Insertion of an element into a 1-D LLVM vector.
auto i64Type = IntegerType::get(rewriter.getContext(), 64);
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
auto constant =
rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
Value inserted = rewriter.create<LLVM::InsertElementOp>(
loc, typeConverter->convertType(oneDVectorType), extracted,
adaptor.getSource(), constant);
// Potential insertion of resulting 1-D vector into array.
if (positionAttrs.size() > 1) {
if (positionArray.size() > 1) {
inserted = rewriter.create<LLVM::InsertValueOp>(
loc, adaptor.getDest(), inserted,
LLVM::convertArrayToIndices(positionAttrs.drop_back()));
loc, adaptor.getDest(), inserted, positionArray.drop_back());
}
rewriter.replaceOp(insertOp, inserted);