[mlir][MemRef] Changed AssumeAlignment into a Pure ViewLikeOp (#139521)

Made AssumeAlignment a ViewLikeOp that returns a new SSA memref equal
to its memref argument, and gave it the Pure trait. This gives the op a
defined memory effect that matches what it does in practice and lets it
interact well with optimizations: they will only remove it if its result
is unused.
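
As a rough before/after illustration of the change (value names are illustrative, mirroring the test updates below), consumers of the memref now go through the op's result:

  // Before: no result; the op was a free-standing annotation on %buf.
  memref.assume_alignment %buf, 64 : memref<128xf32>
  %v = memref.load %buf[%i] : memref<128xf32>

  // After: the op is Pure and view-like; the load uses its result.
  %aligned = memref.assume_alignment %buf, 64 : memref<128xf32>
  %v = memref.load %aligned[%i] : memref<128xf32>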
Author: Shay Kleiman
Date: 2025-05-18 13:50:29 +03:00
Committed by: GitHub
Parent: a0a2a1e095
Commit: ffb9bbfd07
11 changed files with 87 additions and 41 deletions


@@ -142,22 +142,37 @@ class AllocLikeOp<string mnemonic,
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//
def AssumeAlignmentOp : MemRef_Op<"assume_alignment"> {
def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
Pure,
ViewLikeOpInterface,
SameOperandsAndResultType
]> {
let summary =
"assertion that gives alignment information to the input memref";
"assumption that gives alignment information to the input memref";
let description = [{
The `assume_alignment` operation takes a memref and an integer of alignment
value, and internally annotates the buffer with the given alignment. If
the buffer isn't aligned to the given alignment, the behavior is undefined.
The `assume_alignment` operation takes a memref and an integer alignment
value. It returns a new SSA value of the same memref type, but associated
with the assumption that the underlying buffer is aligned to the given
alignment.
This operation doesn't affect the semantics of a correct program. It's for
optimization only, and the optimization is best-effort.
If the buffer isn't aligned to the given alignment, its result is poison.
This operation doesn't affect the semantics of a program where the
alignment assumption holds true. It is intended for optimization purposes,
allowing the compiler to generate more efficient code based on the
alignment assumption. The optimization is best-effort.
}];
let arguments = (ins AnyMemRef:$memref,
ConfinedAttr<I32Attr, [IntPositive]>:$alignment);
let results = (outs);
let results = (outs AnyMemRef:$result);
let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
let extraClassDeclaration = [{
MemRefType getType() { return ::llvm::cast<MemRefType>(getResult().getType()); }
Value getViewSource() { return getMemref(); }
}];
let hasVerifier = 1;
}
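
A minimal sketch of the updated op in use (values illustrative): with SameOperandsAndResultType the type is printed once and applies to both operand and result, and the result is poison if the buffer is in fact misaligned.

  %alloc = memref.alloc() : memref<4x4xf16>
  // %aligned has the same type as %alloc; if %alloc is not 16-byte aligned,
  // %aligned is poison.
  %aligned = memref.assume_alignment %alloc, 16 : memref<4x4xf16>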


@@ -432,8 +432,7 @@ struct AssumeAlignmentOpLowering
createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr,
alignmentConst);
rewriter.eraseOp(op);
rewriter.replaceOp(op, memref);
return success();
}
};
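
For context, the lowering above emits an llvm.intr.assume carrying an "align" operand bundle on the aligned pointer extracted from the memref descriptor, and then replaces the op with its (already converted) source memref. A hedged sketch of the emitted IR; the exact printed form of the operand bundle may differ:

  // %ptr is the aligned pointer extracted from the memref descriptor.
  %true = llvm.mlir.constant(true) : i1
  %align = llvm.mlir.constant(16 : i64) : i64
  // Schematic; operand-bundle assembly shown approximately.
  llvm.intr.assume %true ["align"(%ptr, %align : !llvm.ptr, i64)] : i1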


@@ -44,13 +44,6 @@ using namespace mlir::gpu;
// The functions below provide interface-like verification, but are too specific
// to barrier elimination to become interfaces.
/// Implement the MemoryEffectsOpInterface in the suitable way.
static bool isKnownNoEffectsOpWithoutInterface(Operation *op) {
// memref::AssumeAlignment is conceptually pure, but marking it as such would
// make DCE immediately remove it.
return isa<memref::AssumeAlignmentOp>(op);
}
/// Returns `true` if the op defines the parallel region that is subject to
/// barrier synchronization.
static bool isParallelRegionBoundary(Operation *op) {
@@ -101,10 +94,6 @@ collectEffects(Operation *op,
if (ignoreBarriers && isa<BarrierOp>(op))
return true;
// Skip over ops that we know have no effects.
if (isKnownNoEffectsOpWithoutInterface(op))
return true;
// Collect effect instances of the operation. Note that the implementation of
// getEffects erases all effect instances that have a type other than the
// template parameter, so we collect them first in a local buffer and then
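
The special case deleted above existed only because the op declared no effects yet could not be marked Pure without DCE deleting it immediately. Now that it is Pure and returns a value, generic effect collection and DCE handle it correctly, e.g. (illustrative):

  // Kept: the Pure op's result is used.
  %a = memref.assume_alignment %m, 64 : memref<128xi8>
  %v = memref.load %a[%i] : memref<128xi8>

  // Removed by DCE: a Pure op whose result has no uses.
  %dead = memref.assume_alignment %m, 64 : memref<128xi8>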


@@ -527,6 +527,11 @@ LogicalResult AssumeAlignmentOp::verify() {
return success();
}
void AssumeAlignmentOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "assume_align");
}
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
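
With the getAsmResultNames implementation above, the printer suggests assume_align as the SSA name for the result, e.g. (illustrative):

  %assume_align = memref.assume_alignment %alloc, 64 : memref<188xi8>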


@@ -229,7 +229,7 @@ struct ConvertMemRefAssumeAlignment final
}
rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
op, adaptor.getMemref(), adaptor.getAlignmentAttr());
op, newTy, adaptor.getMemref(), adaptor.getAlignmentAttr());
return success();
}
};
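
The narrow-type emulation pattern above now has to pass the converted result type explicitly. Roughly, the rewrite carries an i4 memref over to its i8-backed emulated form (shapes borrowed from the tests below; names illustrative):

  // Before emulation:
  %a = memref.assume_alignment %alloc, 64 : memref<3x125xi4>
  // After emulation, operand and result both use the emulated type:
  %a_emulated = memref.assume_alignment %alloc_i8, 64 : memref<188xi8>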


@@ -919,6 +919,35 @@ public:
}
};
/// Pattern to replace `extract_strided_metadata(assume_alignment)`
///
/// With
/// \verbatim
/// extract_strided_metadata(memref)
/// \endverbatim
///
/// Since `assume_alignment` is a view-like op that does not modify the
/// underlying buffer, offset, sizes, or strides, extracting strided metadata
/// from its result is equivalent to extracting it from its source. This
/// canonicalization removes the unnecessary indirection.
struct ExtractStridedMetadataOpAssumeAlignmentFolder
: public OpRewritePattern<memref::ExtractStridedMetadataOp> {
public:
using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
PatternRewriter &rewriter) const override {
auto assumeAlignmentOp =
op.getSource().getDefiningOp<memref::AssumeAlignmentOp>();
if (!assumeAlignmentOp)
return failure();
rewriter.replaceOpWithNewOp<memref::ExtractStridedMetadataOp>(
op, assumeAlignmentOp.getViewSource());
return success();
}
};
/// Rewrite memref.extract_aligned_pointer_as_index of a ViewLikeOp to the
/// source of the ViewLikeOp.
class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
@@ -1185,6 +1214,7 @@ void memref::populateExpandStridedMetadataPatterns(
ExtractStridedMetadataOpSubviewFolder,
ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpMemorySpaceCastFolder,
ExtractStridedMetadataOpAssumeAlignmentFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
patterns.getContext());
}
@@ -1201,6 +1231,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
ExtractStridedMetadataOpReinterpretCastFolder,
ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpMemorySpaceCastFolder,
ExtractStridedMetadataOpAssumeAlignmentFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
patterns.getContext());
}
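
The folder registered above lets extract_strided_metadata look through the alignment view, e.g. (a sketch; names illustrative):

  // Before folding:
  %a = memref.assume_alignment %m, 16 : memref<8x16xf32>
  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %a
      : memref<8x16xf32> -> memref<f32>, index, index, index, index, index

  // After folding, the metadata is taken directly from %m; %a is left for DCE
  // if it has no other uses.
  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %m
      : memref<8x16xf32> -> memref<f32>, index, index, index, index, index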


@@ -683,7 +683,7 @@ func.func @load_and_assume(
%arg0: memref<?x?xf32, strided<[?, ?], offset: ?>>,
%i0: index, %i1: index)
-> f32 {
memref.assume_alignment %arg0, 16 : memref<?x?xf32, strided<[?, ?], offset: ?>>
%2 = memref.load %arg0[%i0, %i1] : memref<?x?xf32, strided<[?, ?], offset: ?>>
%arg0_align = memref.assume_alignment %arg0, 16 : memref<?x?xf32, strided<[?, ?], offset: ?>>
%2 = memref.load %arg0_align[%i0, %i1] : memref<?x?xf32, strided<[?, ?], offset: ?>>
func.return %2 : f32
}


@@ -10,4 +10,11 @@ func.func @func_with_assert(%arg0: index, %arg1: index) {
%0 = arith.cmpi slt, %arg0, %arg1 : index
cf.assert %0, "%arg0 must be less than %arg1"
return
}
// CHECK-LABEL: func @func_with_assume_alignment(
// CHECK: %0 = memref.assume_alignment %arg0, 64 : memref<128xi8>
func.func @func_with_assume_alignment(%arg0: memref<128xi8>) {
%0 = memref.assume_alignment %arg0, 64 : memref<128xi8>
return
}


@@ -63,8 +63,8 @@ func.func @memref_load_i4(%arg0: index) -> i4 {
func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
%0 = memref.alloc() : memref<3x125xi4>
memref.assume_alignment %0, 64 : memref<3x125xi4>
%1 = memref.load %0[%arg0,%arg1] : memref<3x125xi4>
%align0 = memref.assume_alignment %0, 64 : memref<3x125xi4>
%1 = memref.load %align0[%arg0,%arg1] : memref<3x125xi4>
return %1 : i4
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
@@ -73,9 +73,9 @@ func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
// CHECK: memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
// CHECK: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
// CHECK: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
// CHECK: %[[LOAD:.+]] = memref.load %[[ASSUME]][%[[INDEX]]]
// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
@@ -88,9 +88,9 @@ func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
// CHECK32: memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
// CHECK32: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
// CHECK32: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
// CHECK32: %[[LOAD:.+]] = memref.load %[[ASSUME]][%[[INDEX]]]
// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
@@ -350,8 +350,8 @@ func.func @memref_store_i4(%arg0: index, %arg1: i4) -> () {
func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
%0 = memref.alloc() : memref<3x125xi4>
memref.assume_alignment %0, 64 : memref<3x125xi4>
memref.store %arg2, %0[%arg0,%arg1] : memref<3x125xi4>
%align0 = memref.assume_alignment %0, 64 : memref<3x125xi4>
memref.store %arg2, %align0[%arg0,%arg1] : memref<3x125xi4>
return
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
@@ -359,7 +359,7 @@ func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
// CHECK: func @memref_store_i4_rank2(
// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
// CHECK-DAG: memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
// CHECK-DAG: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i8
// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
@@ -369,8 +369,8 @@ func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ASSUME]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ASSUME]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
// CHECK: return
// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 8)>
@@ -378,7 +378,7 @@ func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
// CHECK32: func @memref_store_i4_rank2(
// CHECK32-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
// CHECK32-DAG: memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
// CHECK32-DAG: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
// CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i32
// CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
// CHECK32-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
@@ -388,8 +388,8 @@ func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
// CHECK32-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
// CHECK32-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
// CHECK32-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ASSUME]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ASSUME]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
// CHECK32: return
// -----


@@ -878,7 +878,7 @@ func.func @invalid_memref_cast() {
// alignment is not power of 2.
func.func @assume_alignment(%0: memref<4x4xf16>) {
// expected-error@+1 {{alignment must be power of 2}}
memref.assume_alignment %0, 12 : memref<4x4xf16>
%1 = memref.assume_alignment %0, 12 : memref<4x4xf16>
return
}
@@ -887,7 +887,7 @@ func.func @assume_alignment(%0: memref<4x4xf16>) {
// 0 alignment value.
func.func @assume_alignment(%0: memref<4x4xf16>) {
// expected-error@+1 {{attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
memref.assume_alignment %0, 0 : memref<4x4xf16>
%1 = memref.assume_alignment %0, 0 : memref<4x4xf16>
return
}


@@ -284,7 +284,7 @@ func.func @memref_view(%arg0 : index, %arg1 : index, %arg2 : index) {
// CHECK-SAME: %[[MEMREF:.*]]: memref<4x4xf16>
func.func @assume_alignment(%0: memref<4x4xf16>) {
// CHECK: memref.assume_alignment %[[MEMREF]], 16 : memref<4x4xf16>
memref.assume_alignment %0, 16 : memref<4x4xf16>
%1 = memref.assume_alignment %0, 16 : memref<4x4xf16>
return
}