[mlir][Vector] Support vector.extract(xfer_read) folding with dynamic indices (#143269)
This PR is part of the last step to remove the `vector.extractelement` and `vector.insertelement` ops.
RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops
It adds support for folding `vector.extract(vector.transfer_read) -> memref.load` with dynamic indices, a folding that is already supported for `vector.extractelement`.
This commit is contained in:
@@ -886,17 +886,31 @@ class RewriteScalarExtractOfTransferRead
|
||||
SmallVector<Value> newIndices(xferOp.getIndices().begin(),
|
||||
xferOp.getIndices().end());
|
||||
for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
|
||||
assert(isa<Attribute>(pos) && "Unexpected non-constant index");
|
||||
int64_t offset = cast<IntegerAttr>(cast<Attribute>(pos)).getInt();
|
||||
int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
|
||||
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
|
||||
rewriter, extractOp.getLoc(),
|
||||
rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
|
||||
if (auto value = dyn_cast<Value>(ofr)) {
|
||||
|
||||
// Compute affine expression `newIndices[idx] + pos` where `pos` can be
|
||||
// either a constant or a value.
|
||||
OpFoldResult composedIdx;
|
||||
if (auto attr = dyn_cast<Attribute>(pos)) {
|
||||
int64_t offset = cast<IntegerAttr>(attr).getInt();
|
||||
composedIdx = affine::makeComposedFoldedAffineApply(
|
||||
rewriter, extractOp.getLoc(),
|
||||
rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
|
||||
} else {
|
||||
Value dynamicOffset = cast<Value>(pos);
|
||||
AffineExpr sym0, sym1;
|
||||
bindSymbols(rewriter.getContext(), sym0, sym1);
|
||||
composedIdx = affine::makeComposedFoldedAffineApply(
|
||||
rewriter, extractOp.getLoc(), sym0 + sym1,
|
||||
{newIndices[idx], dynamicOffset});
|
||||
}
|
||||
|
||||
// Update the corresponding index with the folded result.
|
||||
if (auto value = dyn_cast<Value>(composedIdx)) {
|
||||
newIndices[idx] = value;
|
||||
} else {
|
||||
newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
|
||||
extractOp.getLoc(), *getConstantIntValue(ofr));
|
||||
extractOp.getLoc(), *getConstantIntValue(composedIdx));
|
||||
}
|
||||
}
|
||||
if (isa<MemRefType>(xferOp.getBase().getType())) {
|
||||
|
||||
@@ -148,3 +148,33 @@ func.func @subvector_extract(%m: memref<?x?xf32>, %idx: index) -> vector<16xf32>
|
||||
return %1 : vector<16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
|
||||
// CHECK-LABEL: func @transfer_read_1d_extract_dynamic(
|
||||
// CHECK-SAME: %[[MEMREF:.*]]: memref<?xf32>, %[[M_IDX:.*]]: index, %[[E_IDX:.*]]: index
|
||||
// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()[%[[M_IDX]], %[[E_IDX]]]
|
||||
// CHECK: %[[RES:.*]] = memref.load %[[MEMREF]][%[[APPLY]]]
|
||||
// Test input: reads a vector<5xf32> from %m starting at %idx (marked
// in-bounds) and extracts one scalar at the dynamic position %offset.
// The FileCheck directives above expect a memref.load from %m at the
// affine sum %idx + %offset.
func.func @transfer_read_1d_extract_dynamic(%m: memref<?xf32>, %idx: index,
|
||||
%offset: index) -> f32 {
|
||||
%cst = arith.constant 0.0 : f32
|
||||
%vec = vector.transfer_read %m[%idx], %cst {in_bounds = [true]} : memref<?xf32>, vector<5xf32>
|
||||
%elem = vector.extract %vec[%offset] : f32 from vector<5xf32>
|
||||
return %elem : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
|
||||
// CHECK-LABEL: func @transfer_read_2d_extract_dynamic(
|
||||
// CHECK-SAME: %[[MEMREF:.*]]: memref<?x?xf32>, %[[ROW_IDX:.*]]: index, %[[COL_IDX:.*]]: index, %[[ROW_OFFSET:.*]]: index, %[[COL_OFFSET:.*]]: index
|
||||
// CHECK: %[[ROW_APPLY:.*]] = affine.apply #[[$MAP]]()[%[[ROW_IDX]], %[[ROW_OFFSET]]]
|
||||
// CHECK: %[[COL_APPLY:.*]] = affine.apply #[[$MAP]]()[%[[COL_IDX]], %[[COL_OFFSET]]]
|
||||
// CHECK: %[[RES:.*]] = memref.load %[[MEMREF]][%[[ROW_APPLY]], %[[COL_APPLY]]]
|
||||
// Test input: reads a vector<10x5xf32> from %m starting at
// [%row_idx, %col_idx] (both dims marked in-bounds) and extracts one scalar
// at the dynamic position [%row_offset, %col_offset]. The FileCheck
// directives above expect a memref.load from %m at
// [%row_idx + %row_offset, %col_idx + %col_offset].
func.func @transfer_read_2d_extract_dynamic(%m: memref<?x?xf32>, %row_idx: index, %col_idx: index,
|
||||
%row_offset: index, %col_offset: index) -> f32 {
|
||||
%cst = arith.constant 0.0 : f32
|
||||
%vec = vector.transfer_read %m[%row_idx, %col_idx], %cst {in_bounds = [true, true]} : memref<?x?xf32>, vector<10x5xf32>
|
||||
%elem = vector.extract %vec[%row_offset, %col_offset] : f32 from vector<10x5xf32>
|
||||
return %elem : f32
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user