[mlir][MemRef] Changed AssumeAlignment into a Pure ViewLikeOp (#139521)
Made AssumeAlignment a ViewLikeOp that returns a new SSA memref equal to its memref argument, and gave it the Pure trait. This gives the op a defined memory effect that matches what it does in practice, and lets it interact well with optimizations: they will no longer remove it unless its result is unused.
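For illustration only (not part of the diff), the shape of the change in IR, using placeholder names:

    // Before: no result; the op only annotated %buf in place.
    memref.assume_alignment %buf, 64 : memref<128xi8>

    // After: the op returns a new SSA memref carrying the alignment assumption,
    // and downstream uses are expected to go through that result.
    %buf_aligned = memref.assume_alignment %buf, 64 : memref<128xi8>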
@@ -142,22 +142,37 @@ class AllocLikeOp<string mnemonic,
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

def AssumeAlignmentOp : MemRef_Op<"assume_alignment"> {
def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
Pure,
ViewLikeOpInterface,
SameOperandsAndResultType
]> {
let summary =
"assertion that gives alignment information to the input memref";
"assumption that gives alignment information to the input memref";
let description = [{
The `assume_alignment` operation takes a memref and an integer of alignment
value, and internally annotates the buffer with the given alignment. If
the buffer isn't aligned to the given alignment, the behavior is undefined.
The `assume_alignment` operation takes a memref and an integer alignment
value. It returns a new SSA value of the same memref type, but associated
with the assumption that the underlying buffer is aligned to the given
alignment.

This operation doesn't affect the semantics of a correct program. It's for
optimization only, and the optimization is best-effort.
If the buffer isn't aligned to the given alignment, its result is poison.
This operation doesn't affect the semantics of a program where the
alignment assumption holds true. It is intended for optimization purposes,
allowing the compiler to generate more efficient code based on the
alignment assumption. The optimization is best-effort.
}];
let arguments = (ins AnyMemRef:$memref,
ConfinedAttr<I32Attr, [IntPositive]>:$alignment);
let results = (outs);
let results = (outs AnyMemRef:$result);

let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
let extraClassDeclaration = [{
MemRefType getType() { return ::llvm::cast<MemRefType>(getResult().getType()); }

Value getViewSource() { return getMemref(); }
}];

let hasVerifier = 1;
}
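As described above, the op now forwards its operand through a new SSA result. A minimal usage sketch with hypothetical names, mirroring the updated `load_and_assume` test further down:

    func.func @aligned_load(%m: memref<?xf32>, %i: index) -> f32 {
      // Operand and result have the same memref type (SameOperandsAndResultType);
      // only the alignment assumption is attached to the result.
      %m_aligned = memref.assume_alignment %m, 64 : memref<?xf32>
      // If %m is not actually 64-byte aligned, %m_aligned is poison; otherwise
      // the program behaves exactly as if the op were absent.
      %v = memref.load %m_aligned[%i] : memref<?xf32>
      return %v : f32
    }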
@@ -432,8 +432,7 @@ struct AssumeAlignmentOpLowering
createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr,
alignmentConst);

rewriter.eraseOp(op);
rewriter.replaceOp(op, memref);
return success();
}
};
@@ -44,13 +44,6 @@ using namespace mlir::gpu;
// The functions below provide interface-like verification, but are too specific
// to barrier elimination to become interfaces.

/// Implement the MemoryEffectsOpInterface in the suitable way.
static bool isKnownNoEffectsOpWithoutInterface(Operation *op) {
// memref::AssumeAlignment is conceptually pure, but marking it as such would
// make DCE immediately remove it.
return isa<memref::AssumeAlignmentOp>(op);
}

/// Returns `true` if the op is defines the parallel region that is subject to
/// barrier synchronization.
static bool isParallelRegionBoundary(Operation *op) {
@@ -101,10 +94,6 @@ collectEffects(Operation *op,
if (ignoreBarriers && isa<BarrierOp>(op))
return true;

// Skip over ops that we know have no effects.
if (isKnownNoEffectsOpWithoutInterface(op))
return true;

// Collect effect instances the operation. Note that the implementation of
// getEffects erases all effect instances that have the type other than the
// template parameter so we collect them first in a local buffer and then
@@ -527,6 +527,11 @@ LogicalResult AssumeAlignmentOp::verify() {
return success();
}

void AssumeAlignmentOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "assume_align");
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
@@ -229,7 +229,7 @@ struct ConvertMemRefAssumeAlignment final
}

rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
op, adaptor.getMemref(), adaptor.getAlignmentAttr());
op, newTy, adaptor.getMemref(), adaptor.getAlignmentAttr());
return success();
}
};
@@ -919,6 +919,35 @@ public:
}
};

/// Pattern to replace `extract_strided_metadata(assume_alignment)`
///
/// With
/// \verbatim
/// extract_strided_metadata(memref)
/// \endverbatim
///
/// Since `assume_alignment` is a view-like op that does not modify the
/// underlying buffer, offset, sizes, or strides, extracting strided metadata
/// from its result is equivalent to extracting it from its source. This
/// canonicalization removes the unnecessary indirection.
struct ExtractStridedMetadataOpAssumeAlignmentFolder
: public OpRewritePattern<memref::ExtractStridedMetadataOp> {
public:
using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;

LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
PatternRewriter &rewriter) const override {
auto assumeAlignmentOp =
op.getSource().getDefiningOp<memref::AssumeAlignmentOp>();
if (!assumeAlignmentOp)
return failure();

rewriter.replaceOpWithNewOp<memref::ExtractStridedMetadataOp>(
op, assumeAlignmentOp.getViewSource());
return success();
}
};

/// Rewrite memref.extract_aligned_pointer_as_index of a ViewLikeOp to the
/// source of the ViewLikeOp.
class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
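For reference, a sketch of what this folder does at the IR level (names and types are illustrative, not taken from the tests):

    // Before the fold:
    %aligned = memref.assume_alignment %m, 64 : memref<?x?xf32>
    %base, %off, %sizes:2, %strides:2 = memref.extract_strided_metadata %aligned
        : memref<?x?xf32> -> memref<f32>, index, index, index, index, index

    // After the fold: the metadata is taken directly from the source, because
    // assume_alignment changes neither the buffer nor offset, sizes, or strides.
    %base, %off, %sizes:2, %strides:2 = memref.extract_strided_metadata %m
        : memref<?x?xf32> -> memref<f32>, index, index, index, index, index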
@@ -1185,6 +1214,7 @@ void memref::populateExpandStridedMetadataPatterns(
ExtractStridedMetadataOpSubviewFolder,
ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpMemorySpaceCastFolder,
ExtractStridedMetadataOpAssumeAlignmentFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
patterns.getContext());
}
@@ -1201,6 +1231,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
ExtractStridedMetadataOpReinterpretCastFolder,
ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpMemorySpaceCastFolder,
ExtractStridedMetadataOpAssumeAlignmentFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
patterns.getContext());
}
@@ -683,7 +683,7 @@ func.func @load_and_assume(
%arg0: memref<?x?xf32, strided<[?, ?], offset: ?>>,
%i0: index, %i1: index)
-> f32 {
memref.assume_alignment %arg0, 16 : memref<?x?xf32, strided<[?, ?], offset: ?>>
%2 = memref.load %arg0[%i0, %i1] : memref<?x?xf32, strided<[?, ?], offset: ?>>
%arg0_align = memref.assume_alignment %arg0, 16 : memref<?x?xf32, strided<[?, ?], offset: ?>>
%2 = memref.load %arg0_align[%i0, %i1] : memref<?x?xf32, strided<[?, ?], offset: ?>>
func.return %2 : f32
}
@@ -10,4 +10,11 @@ func.func @func_with_assert(%arg0: index, %arg1: index) {
%0 = arith.cmpi slt, %arg0, %arg1 : index
cf.assert %0, "%arg0 must be less than %arg1"
return
}

// CHECK-LABEL: func @func_with_assume_alignment(
// CHECK: %0 = memref.assume_alignment %arg0, 64 : memref<128xi8>
func.func @func_with_assume_alignment(%arg0: memref<128xi8>) {
%0 = memref.assume_alignment %arg0, 64 : memref<128xi8>
return
}
@@ -63,8 +63,8 @@ func.func @memref_load_i4(%arg0: index) -> i4 {
func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
%0 = memref.alloc() : memref<3x125xi4>
memref.assume_alignment %0, 64 : memref<3x125xi4>
%1 = memref.load %0[%arg0,%arg1] : memref<3x125xi4>
%align0 =memref.assume_alignment %0, 64 : memref<3x125xi4>
%1 = memref.load %align0[%arg0,%arg1] : memref<3x125xi4>
return %1 : i4
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
@@ -73,9 +73,9 @@ func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
// CHECK: memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
// CHECK: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
// CHECK: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
// CHECK: %[[LOAD:.+]] = memref.load %[[ASSUME]][%[[INDEX]]]
// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
@@ -88,9 +88,9 @@ func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
// CHECK32: memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
// CHECK32: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
// CHECK32: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
// CHECK32: %[[LOAD:.+]] = memref.load %[[ASSUME]][%[[INDEX]]]
// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]]
@@ -350,8 +350,8 @@ func.func @memref_store_i4(%arg0: index, %arg1: i4) -> () {
func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
%0 = memref.alloc() : memref<3x125xi4>
memref.assume_alignment %0, 64 : memref<3x125xi4>
memref.store %arg2, %0[%arg0,%arg1] : memref<3x125xi4>
%align0 = memref.assume_alignment %0, 64 : memref<3x125xi4>
memref.store %arg2, %align0[%arg0,%arg1] : memref<3x125xi4>
return
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
@@ -359,7 +359,7 @@ func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
// CHECK: func @memref_store_i4_rank2(
// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
// CHECK-DAG: memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
// CHECK-DAG: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i8
// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
@@ -369,8 +369,8 @@ func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ASSUME]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ASSUME]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
// CHECK: return

// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 8)>
@@ -378,7 +378,7 @@ func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
// CHECK32: func @memref_store_i4_rank2(
// CHECK32-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
// CHECK32-DAG: memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
// CHECK32-DAG: %[[ASSUME:.+]] = memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
// CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i32
// CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
// CHECK32-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
@@ -388,8 +388,8 @@ func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
// CHECK32-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
// CHECK32-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
// CHECK32-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ASSUME]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ASSUME]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
// CHECK32: return

// -----
@@ -878,7 +878,7 @@ func.func @invalid_memref_cast() {
// alignment is not power of 2.
func.func @assume_alignment(%0: memref<4x4xf16>) {
// expected-error@+1 {{alignment must be power of 2}}
memref.assume_alignment %0, 12 : memref<4x4xf16>
%1 = memref.assume_alignment %0, 12 : memref<4x4xf16>
return
}
@@ -887,7 +887,7 @@ func.func @assume_alignment(%0: memref<4x4xf16>) {
// 0 alignment value.
func.func @assume_alignment(%0: memref<4x4xf16>) {
// expected-error@+1 {{attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
memref.assume_alignment %0, 0 : memref<4x4xf16>
%1 = memref.assume_alignment %0, 0 : memref<4x4xf16>
return
}
@@ -284,7 +284,7 @@ func.func @memref_view(%arg0 : index, %arg1 : index, %arg2 : index) {
// CHECK-SAME: %[[MEMREF:.*]]: memref<4x4xf16>
func.func @assume_alignment(%0: memref<4x4xf16>) {
// CHECK: memref.assume_alignment %[[MEMREF]], 16 : memref<4x4xf16>
memref.assume_alignment %0, 16 : memref<4x4xf16>
%1 = memref.assume_alignment %0, 16 : memref<4x4xf16>
return
}