[mlir][vector] Add unroll patterns for vector.load and vector.store (#143420)
This PR adds unroll patterns for vector.load and vector.store. It is a follow-up to #137558.
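For illustration, a sketch of what the rewrite does to a 2-D load with a target (native) shape of 2x2, condensed from the `@vector_load_2D` test added below; the SSA names here are illustrative:

```mlir
// Before: a single 4x4 load.
%0 = vector.load %mem[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>

// After unrolling (abridged): four 2x2 loads, each inserted into an
// accumulator initialized to zero; the indices advance by the tile offsets.
%cst = arith.constant dense<0.000000e+00> : vector<4x4xf16>
%v0 = vector.load %mem[%c0, %c0] : memref<4x4xf16>, vector<2x2xf16>
%r0 = vector.insert_strided_slice %v0, %cst {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16>
%v1 = vector.load %mem[%c0, %c2] : memref<4x4xf16>, vector<2x2xf16>
%r1 = vector.insert_strided_slice %v1, %r0 {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16>
// ... likewise at row offset 2 for the remaining two tiles ...
```

vector.store is unrolled symmetrically: each 2x2 tile is carved out of the stored vector with vector.extract_strided_slice and written with its own vector.store.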
@@ -1736,7 +1736,9 @@ def Vector_TransferWriteOp :
   let hasVerifier = 1;
 }
 
-def Vector_LoadOp : Vector_Op<"load"> {
+def Vector_LoadOp : Vector_Op<"load", [
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+  ]> {
   let summary = "reads an n-D slice of memory into an n-D vector";
   let description = [{
     The 'vector.load' operation reads an n-D slice of memory into an n-D
@@ -1822,7 +1824,9 @@ def Vector_LoadOp : Vector_Op<"load"> {
       "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
 }
 
-def Vector_StoreOp : Vector_Op<"store"> {
+def Vector_StoreOp : Vector_Op<"store", [
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+  ]> {
   let summary = "writes an n-D vector to an n-D slice of memory";
   let description = [{
     The 'vector.store' operation writes an n-D vector to an n-D slice of memory.

@@ -5371,6 +5371,10 @@ OpFoldResult LoadOp::fold(FoldAdaptor) {
   return OpFoldResult();
 }
 
+std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
+  return llvm::to_vector<4>(getVectorType().getShape());
+}
+
 //===----------------------------------------------------------------------===//
 // StoreOp
 //===----------------------------------------------------------------------===//
@@ -5406,6 +5410,10 @@ LogicalResult StoreOp::fold(FoldAdaptor adaptor,
   return memref::foldMemRefCast(*this);
 }
 
+std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
+  return llvm::to_vector<4>(getVectorType().getShape());
+}
+
 //===----------------------------------------------------------------------===//
 // MaskedLoadOp
 //===----------------------------------------------------------------------===//

@@ -54,6 +54,28 @@ static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
   return slicedIndices;
 }
 
+// Compute the new indices by adding `offsets` to `originalIndices`.
+// If m < n (m = offsets.size(), n = originalIndices.size()),
+// then only the trailing m values in `originalIndices` are updated.
+static SmallVector<Value> sliceLoadStoreIndices(PatternRewriter &rewriter,
+                                                Location loc,
+                                                OperandRange originalIndices,
+                                                ArrayRef<int64_t> offsets) {
+  assert(offsets.size() <= originalIndices.size() &&
+         "Offsets should not exceed the number of original indices");
+  SmallVector<Value> indices(originalIndices);
+
+  auto start = indices.size() - offsets.size();
+  for (auto [i, offset] : llvm::enumerate(offsets)) {
+    if (offset != 0) {
+      indices[start + i] = rewriter.create<arith::AddIOp>(
+          loc, originalIndices[start + i],
+          rewriter.create<arith::ConstantIndexOp>(loc, offset));
+    }
+  }
+  return indices;
+}
+
 // Clones `op` into a new operation that takes `operands` and returns
 // `resultTypes`.
 static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,

@@ -631,6 +653,90 @@ private:
   vector::UnrollVectorOptions options;
 };
 
+struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
+  UnrollLoadPattern(MLIRContext *context,
+                    const vector::UnrollVectorOptions &options,
+                    PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}
+
+  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType vecType = loadOp.getVectorType();
+
+    auto targetShape = getTargetShape(options, loadOp);
+    if (!targetShape)
+      return failure();
+
+    Location loc = loadOp.getLoc();
+    ArrayRef<int64_t> originalShape = vecType.getShape();
+    SmallVector<int64_t> strides(targetShape->size(), 1);
+
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, vecType, rewriter.getZeroAttr(vecType));
+
+    SmallVector<int64_t> loopOrder =
+        getUnrollOrder(originalShape.size(), loadOp, options);
+
+    auto targetVecType =
+        VectorType::get(*targetShape, vecType.getElementType());
+
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
+      SmallVector<Value> indices =
+          sliceLoadStoreIndices(rewriter, loc, loadOp.getIndices(), offsets);
+      Value slicedLoad = rewriter.create<vector::LoadOp>(
+          loc, targetVecType, loadOp.getBase(), indices);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, slicedLoad, result, offsets, strides);
+    }
+    rewriter.replaceOp(loadOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
+struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
+  UnrollStorePattern(MLIRContext *context,
+                     const vector::UnrollVectorOptions &options,
+                     PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}
+
+  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType vecType = storeOp.getVectorType();
+
+    auto targetShape = getTargetShape(options, storeOp);
+    if (!targetShape)
+      return failure();
+
+    Location loc = storeOp.getLoc();
+    ArrayRef<int64_t> originalShape = vecType.getShape();
+    SmallVector<int64_t> strides(targetShape->size(), 1);
+
+    Value base = storeOp.getBase();
+    Value vector = storeOp.getValueToStore();
+
+    SmallVector<int64_t> loopOrder =
+        getUnrollOrder(originalShape.size(), storeOp, options);
+
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
+      SmallVector<Value> indices =
+          sliceLoadStoreIndices(rewriter, loc, storeOp.getIndices(), offsets);
+      Value slice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+          loc, vector, offsets, *targetShape, strides);
+      rewriter.create<vector::StoreOp>(loc, slice, base, indices);
+    }
+    rewriter.eraseOp(storeOp);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
 struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
   UnrollBroadcastPattern(MLIRContext *context,
                          const vector::UnrollVectorOptions &options,
@@ -699,10 +805,10 @@ private:
 void mlir::vector::populateVectorUnrollPatterns(
     RewritePatternSet &patterns, const UnrollVectorOptions &options,
     PatternBenefit benefit) {
-  patterns
-      .add<UnrollTransferReadPattern, UnrollTransferWritePattern,
-           UnrollContractionPattern, UnrollElementwisePattern,
-           UnrollReductionPattern, UnrollMultiReductionPattern,
-           UnrollTransposePattern, UnrollGatherPattern, UnrollBroadcastPattern>(
-          patterns.getContext(), options, benefit);
+  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
+               UnrollContractionPattern, UnrollElementwisePattern,
+               UnrollReductionPattern, UnrollMultiReductionPattern,
+               UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
+               UnrollStorePattern, UnrollBroadcastPattern>(
+      patterns.getContext(), options, benefit);
 }

@@ -378,3 +378,45 @@ func.func @vector_broadcast_with_tailing_unit_dim(%v: vector<4x1xf32>) -> vector
 // CHECK: [[b3:%.+]] = vector.broadcast [[s3]] : vector<2x1xf32> to vector<2x2xf32>
 // CHECK: [[r3:%.+]] = vector.insert_strided_slice [[b3]], [[r2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
 // CHECK: return [[r3]] : vector<4x4xf32>
+
+
+func.func @vector_load_2D(%mem: memref<4x4xf16>) -> vector<4x4xf16> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.load %mem[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
+  return %0 : vector<4x4xf16>
+}
+
+// CHECK-LABEL: func.func @vector_load_2D(
+// CHECK-SAME: %[[ARG:.*]]: memref<4x4xf16>) -> vector<4x4xf16> {
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf16>
+// CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
+// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16>
+// CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C0]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
+// CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16>
+// CHECK: %[[V4:.*]] = vector.load %[[ARG]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
+// CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[V4]], %[[V3]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16>
+// CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
+// CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16>
+// CHECK: return %[[V7]] : vector<4x4xf16>
+
+
+func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) {
+  %c0 = arith.constant 0 : index
+  vector.store %v, %mem[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
+  return
+}
+
+// CHECK-LABEL: func.func @vector_store_2D(
+// CHECK-SAME: %[[ARG0:.*]]: memref<4x4xf16>, %[[ARG1:.*]]: vector<4x4xf16>) {
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[V0:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
+// CHECK: vector.store %[[V0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
+// CHECK: %[[V1:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
+// CHECK: vector.store %[[V1]], %[[ARG0]][%[[C0]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
+// CHECK: %[[V2:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
+// CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
+// CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
+// CHECK: vector.store %[[V3]], %[[ARG0]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>

@@ -163,7 +163,8 @@ struct TestVectorUnrollingPatterns
         .setFilterConstraint([](Operation *op) {
           return success(
               isa<arith::AddFOp, vector::FMAOp, vector::MultiDimReductionOp,
-                  vector::BroadcastOp>(op));
+                  vector::BroadcastOp, vector::LoadOp, vector::StoreOp>(
+                  op));
         }));
     populateVectorUnrollPatterns(
         patterns, UnrollVectorOptions()