diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 7dbb7a334fe6..384717aeca66 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -886,17 +886,31 @@ class RewriteScalarExtractOfTransferRead
     SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                   xferOp.getIndices().end());
     for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
-      assert(isa<Attribute>(pos) && "Unexpected non-constant index");
-      int64_t offset = cast<IntegerAttr>(cast<Attribute>(pos)).getInt();
       int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
-      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-          rewriter, extractOp.getLoc(),
-          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
-      if (auto value = dyn_cast<Value>(ofr)) {
+
+      // Compute affine expression `newIndices[idx] + pos` where `pos` can be
+      // either a constant or a value.
+      OpFoldResult composedIdx;
+      if (auto attr = dyn_cast<Attribute>(pos)) {
+        int64_t offset = cast<IntegerAttr>(attr).getInt();
+        composedIdx = affine::makeComposedFoldedAffineApply(
+            rewriter, extractOp.getLoc(),
+            rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
+      } else {
+        Value dynamicOffset = cast<Value>(pos);
+        AffineExpr sym0, sym1;
+        bindSymbols(rewriter.getContext(), sym0, sym1);
+        composedIdx = affine::makeComposedFoldedAffineApply(
+            rewriter, extractOp.getLoc(), sym0 + sym1,
+            {newIndices[idx], dynamicOffset});
+      }
+
+      // Update the corresponding index with the folded result.
+      if (auto value = dyn_cast<Value>(composedIdx)) {
         newIndices[idx] = value;
       } else {
         newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
-            extractOp.getLoc(), *getConstantIntValue(ofr));
+            extractOp.getLoc(), *getConstantIntValue(composedIdx));
       }
     }
     if (isa<MemRefType>(xferOp.getBase().getType())) {
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index 52b0fdee184f..7a1d6b3a8344 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -148,3 +148,33 @@ func.func @subvector_extract(%m: memref<?x?xf32>, %idx: index) -> vector<16xf32>
   return %1 : vector<16xf32>
 }
 
+// -----
+
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @transfer_read_1d_extract_dynamic(
+// CHECK-SAME: %[[MEMREF:.*]]: memref<?xf32>, %[[M_IDX:.*]]: index, %[[E_IDX:.*]]: index
+// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()[%[[M_IDX]], %[[E_IDX]]]
+// CHECK: %[[RES:.*]] = memref.load %[[MEMREF]][%[[APPLY]]]
+func.func @transfer_read_1d_extract_dynamic(%m: memref<?xf32>, %idx: index,
+    %offset: index) -> f32 {
+  %cst = arith.constant 0.0 : f32
+  %vec = vector.transfer_read %m[%idx], %cst {in_bounds = [true]} : memref<?xf32>, vector<5xf32>
+  %elem = vector.extract %vec[%offset] : f32 from vector<5xf32>
+  return %elem : f32
+}
+
+// -----
+
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @transfer_read_2d_extract_dynamic(
+// CHECK-SAME: %[[MEMREF:.*]]: memref<?x?xf32>, %[[ROW_IDX:.*]]: index, %[[COL_IDX:.*]]: index, %[[ROW_OFFSET:.*]]: index, %[[COL_OFFSET:.*]]: index
+// CHECK: %[[ROW_APPLY:.*]] = affine.apply #[[$MAP]]()[%[[ROW_IDX]], %[[ROW_OFFSET]]]
+// CHECK: %[[COL_APPLY:.*]] = affine.apply #[[$MAP]]()[%[[COL_IDX]], %[[COL_OFFSET]]]
+// CHECK: %[[RES:.*]] = memref.load %[[MEMREF]][%[[ROW_APPLY]], %[[COL_APPLY]]]
+func.func @transfer_read_2d_extract_dynamic(%m: memref<?x?xf32>, %row_idx: index, %col_idx: index,
+    %row_offset: index, %col_offset: index) -> f32 {
+  %cst = arith.constant 0.0 : f32
+  %vec = vector.transfer_read %m[%row_idx, %col_idx], %cst {in_bounds = [true, true]} : memref<?x?xf32>, vector<10x5xf32>
+  %elem = vector.extract %vec[%row_offset, %col_offset] : f32 from vector<10x5xf32>
+  return %elem : f32
+}