[mlir][vector] Use DenseI64ArrayAttr for ExtractOp/InsertOp positions
`DenseI64ArrayAttr` provides a better API than `I64ArrayAttr`. E.g., accessors returning `ArrayRef<int64_t>` (instead of `ArrayAttr`) are generated. Differential Revision: https://reviews.llvm.org/D156684
This commit is contained in:
@@ -1025,44 +1025,37 @@ public:
|
||||
auto loc = extractOp->getLoc();
|
||||
auto resultType = extractOp.getResult().getType();
|
||||
auto llvmResultType = typeConverter->convertType(resultType);
|
||||
auto positionArrayAttr = extractOp.getPosition();
|
||||
ArrayRef<int64_t> positionArray = extractOp.getPosition();
|
||||
|
||||
// Bail if result type cannot be lowered.
|
||||
if (!llvmResultType)
|
||||
return failure();
|
||||
|
||||
// Extract entire vector. Should be handled by folder, but just to be safe.
|
||||
if (positionArrayAttr.empty()) {
|
||||
if (positionArray.empty()) {
|
||||
rewriter.replaceOp(extractOp, adaptor.getVector());
|
||||
return success();
|
||||
}
|
||||
|
||||
// One-shot extraction of vector from array (only requires extractvalue).
|
||||
if (isa<VectorType>(resultType)) {
|
||||
SmallVector<int64_t> indices;
|
||||
for (auto idx : positionArrayAttr.getAsRange<IntegerAttr>())
|
||||
indices.push_back(idx.getInt());
|
||||
Value extracted = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, adaptor.getVector(), indices);
|
||||
loc, adaptor.getVector(), positionArray);
|
||||
rewriter.replaceOp(extractOp, extracted);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Potential extraction of 1-D vector from array.
|
||||
Value extracted = adaptor.getVector();
|
||||
auto positionAttrs = positionArrayAttr.getValue();
|
||||
if (positionAttrs.size() > 1) {
|
||||
SmallVector<int64_t> nMinusOnePosition;
|
||||
for (auto idx : positionAttrs.drop_back())
|
||||
nMinusOnePosition.push_back(cast<IntegerAttr>(idx).getInt());
|
||||
extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
|
||||
nMinusOnePosition);
|
||||
if (positionArray.size() > 1) {
|
||||
extracted = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, extracted, positionArray.drop_back());
|
||||
}
|
||||
|
||||
// Remaining extraction of element from 1-D LLVM vector
|
||||
auto position = cast<IntegerAttr>(positionAttrs.back());
|
||||
auto i64Type = IntegerType::get(rewriter.getContext(), 64);
|
||||
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
|
||||
auto constant =
|
||||
rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
|
||||
extracted =
|
||||
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
|
||||
rewriter.replaceOp(extractOp, extracted);
|
||||
@@ -1147,7 +1140,7 @@ public:
|
||||
auto sourceType = insertOp.getSourceType();
|
||||
auto destVectorType = insertOp.getDestVectorType();
|
||||
auto llvmResultType = typeConverter->convertType(destVectorType);
|
||||
auto positionArrayAttr = insertOp.getPosition();
|
||||
ArrayRef<int64_t> positionArray = insertOp.getPosition();
|
||||
|
||||
// Bail if result type cannot be lowered.
|
||||
if (!llvmResultType)
|
||||
@@ -1155,7 +1148,7 @@ public:
|
||||
|
||||
// Overwrite entire vector with value. Should be handled by folder, but
|
||||
// just to be safe.
|
||||
if (positionArrayAttr.empty()) {
|
||||
if (positionArray.empty()) {
|
||||
rewriter.replaceOp(insertOp, adaptor.getSource());
|
||||
return success();
|
||||
}
|
||||
@@ -1163,36 +1156,32 @@ public:
|
||||
// One-shot insertion of a vector into an array (only requires insertvalue).
|
||||
if (isa<VectorType>(sourceType)) {
|
||||
Value inserted = rewriter.create<LLVM::InsertValueOp>(
|
||||
loc, adaptor.getDest(), adaptor.getSource(),
|
||||
LLVM::convertArrayToIndices(positionArrayAttr));
|
||||
loc, adaptor.getDest(), adaptor.getSource(), positionArray);
|
||||
rewriter.replaceOp(insertOp, inserted);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Potential extraction of 1-D vector from array.
|
||||
Value extracted = adaptor.getDest();
|
||||
auto positionAttrs = positionArrayAttr.getValue();
|
||||
auto position = cast<IntegerAttr>(positionAttrs.back());
|
||||
auto oneDVectorType = destVectorType;
|
||||
if (positionAttrs.size() > 1) {
|
||||
if (positionArray.size() > 1) {
|
||||
oneDVectorType = reducedVectorTypeBack(destVectorType);
|
||||
extracted = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, extracted,
|
||||
LLVM::convertArrayToIndices(positionAttrs.drop_back()));
|
||||
loc, extracted, positionArray.drop_back());
|
||||
}
|
||||
|
||||
// Insertion of an element into a 1-D LLVM vector.
|
||||
auto i64Type = IntegerType::get(rewriter.getContext(), 64);
|
||||
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
|
||||
auto constant =
|
||||
rewriter.create<LLVM::ConstantOp>(loc, i64Type, positionArray.back());
|
||||
Value inserted = rewriter.create<LLVM::InsertElementOp>(
|
||||
loc, typeConverter->convertType(oneDVectorType), extracted,
|
||||
adaptor.getSource(), constant);
|
||||
|
||||
// Potential insertion of resulting 1-D vector into array.
|
||||
if (positionAttrs.size() > 1) {
|
||||
if (positionArray.size() > 1) {
|
||||
inserted = rewriter.create<LLVM::InsertValueOp>(
|
||||
loc, adaptor.getDest(), inserted,
|
||||
LLVM::convertArrayToIndices(positionAttrs.drop_back()));
|
||||
loc, adaptor.getDest(), inserted, positionArray.drop_back());
|
||||
}
|
||||
|
||||
rewriter.replaceOp(insertOp, inserted);
|
||||
|
||||
Reference in New Issue
Block a user