[mlir][gpu] Add gpu.rotate operation (#142796)

Add gpu.rotate operation and a pattern to convert gpu.rotate to SPIR-V
OpGroupNonUniformRotateKHR.
This commit is contained in:
Hsiangkai Wang
2025-07-01 11:32:25 +01:00
committed by GitHub
parent a97826a13b
commit f581ef5b66
6 changed files with 323 additions and 3 deletions

View File

@@ -1304,8 +1304,8 @@ def GPU_ShuffleOp : GPU_Op<
Results<(outs AnyIntegerOrFloatOr1DVector:$shuffleResult, I1:$valid)> {
let summary = "Shuffles values within a subgroup.";
let description = [{
The "shuffle" op moves values across lanes in a subgroup (a.k.a., local
invocation) within the same subgroup. The `width` argument specifies the
number of lanes that participate in the shuffle, and must be uniform
across all lanes. Further, the first `width` lanes of the subgroup must
be active.
@@ -1366,6 +1366,54 @@ def GPU_ShuffleOp : GPU_Op<
];
}
def GPU_RotateOp : GPU_Op<
    "rotate", [Pure, AllTypesMatch<["value", "rotateResult"]>]>,
    Arguments<(ins AnyIntegerOrFloatOr1DVector:$value, I32:$offset, I32:$width)>,
    Results<(outs AnyIntegerOrFloatOr1DVector:$rotateResult, I1:$valid)> {
  let summary = "Rotate values within a subgroup.";
  let description = [{
    The "rotate" op moves values across lanes in a subgroup (a.k.a., local
    invocations) within the same subgroup. The `width` argument specifies the
    number of lanes that participate in the rotation, and must be uniform across
    all participating lanes. Further, the first `width` lanes of the subgroup
    must be active.

    `width` must be a power of two, and `offset` must be in the range
    `[0, width)`.

    Return the `rotateResult` of the invocation whose id within the group is
    calculated as follows:

    ```mlir
    Invocation ID = ((LaneId + offset) & (width - 1)) + (LaneId & ~(width - 1))
    ```

    Returns the `rotateResult` and `true` if the current lane id is smaller than
    `width`, and poison value and `false` otherwise.

    Example:

    ```mlir
    %offset = arith.constant 1 : i32
    %width = arith.constant 16 : i32
    %1, %2 = gpu.rotate %0, %offset, %width : f32
    ```

    For lane `k`, returns the value from lane `(k + offset) % width`.
  }];

  let assemblyFormat = [{
    $value `,` $offset `,` $width attr-dict `:` type($value)
  }];

  let builders = [
    // Helper function that creates a rotate with constant offset/width.
    OpBuilder<(ins "Value":$value, "int32_t":$offset, "int32_t":$width)>
  ];

  let hasVerifier = 1;
}
def GPU_BarrierOp : GPU_Op<"barrier"> {
let summary = "Synchronizes all work items of a workgroup.";
let description = [{

View File

@@ -122,6 +122,16 @@ public:
ConversionPatternRewriter &rewriter) const override;
};
/// Pattern to convert a gpu.rotate op into a spirv.GroupNonUniformRotateKHROp.
/// The rotate width must be a compile-time constant no larger than the target
/// subgroup size; otherwise the pattern fails to match (see matchAndRewrite).
class GPURotateConversion final : public OpConversionPattern<gpu::RotateOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::RotateOp rotateOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -488,6 +498,41 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
return success();
}
//===----------------------------------------------------------------------===//
// Rotate
//===----------------------------------------------------------------------===//

/// Lowers gpu.rotate to spirv.GroupNonUniformRotateKHR at Subgroup scope,
/// plus a computation of the `valid` result (lane id < width).
LogicalResult GPURotateConversion::matchAndRewrite(
    gpu::RotateOp rotateOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // The subgroup size comes from the target environment's resource limits.
  const spirv::TargetEnv &targetEnv =
      getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
  unsigned subgroupSize =
      targetEnv.getAttr().getResourceLimits().getSubgroupSize();

  // The width must be a known constant and must fit in the subgroup;
  // otherwise the op cannot be mapped onto GroupNonUniformRotateKHR.
  IntegerAttr widthAttr;
  if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)) ||
      widthAttr.getValue().getZExtValue() > subgroupSize)
    return rewriter.notifyMatchFailure(
        rotateOp,
        "rotate width is not a constant or larger than target subgroup size");

  Location loc = rotateOp.getLoc();
  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
  // The rotate itself: width becomes the cluster size of the SPIR-V op.
  Value rotateResult = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
      loc, scope, adaptor.getValue(), adaptor.getOffset(), adaptor.getWidth());

  Value validVal;
  if (widthAttr.getValue().getZExtValue() == subgroupSize) {
    // Every lane of the subgroup participates, so `valid` is trivially true.
    validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter);
  } else {
    // Otherwise compare the lane id against the rotation width.
    // NOTE(review): widthAttr is passed as LaneIdOp's upper_bound attribute
    // here — confirm that is the intended use of that attribute.
    Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
    validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                              laneId, adaptor.getWidth());
  }

  rewriter.replaceOp(rotateOp, {rotateResult, validVal});
  return success();
}
//===----------------------------------------------------------------------===//
// Group ops
//===----------------------------------------------------------------------===//
@@ -776,7 +821,7 @@ void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<
GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
GPUReturnOpConversion, GPUShuffleConversion,
GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,

View File

@@ -1331,6 +1331,49 @@ void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
mode);
}
//===----------------------------------------------------------------------===//
// RotateOp
//===----------------------------------------------------------------------===//

/// Convenience builder: materializes `offset` and `width` as i32
/// arith.constant ops and forwards to the generated value-based builder.
void RotateOp::build(OpBuilder &builder, OperationState &result, Value value,
                     int32_t offset, int32_t width) {
  Location loc = result.location;
  Value offsetVal = builder.create<arith::ConstantOp>(
      loc, builder.getI32IntegerAttr(offset));
  Value widthVal = builder.create<arith::ConstantOp>(
      loc, builder.getI32IntegerAttr(width));
  build(builder, result, value, offsetVal, widthVal);
}
/// Verifies that `offset` and `width` are compile-time constants, that
/// `width` is a power of two, and that `offset` lies in `[0, width)`.
LogicalResult RotateOp::verify() {
  auto offsetConstOp = getOffset().getDefiningOp<arith::ConstantOp>();
  if (!offsetConstOp)
    return emitOpError() << "offset is not a constant value";

  auto widthConstOp = getWidth().getDefiningOp<arith::ConstantOp>();
  if (!widthConstOp)
    return emitOpError() << "width is not a constant value";

  // Both operands are declared I32 in ODS, so an arith.constant feeding them
  // necessarily carries an IntegerAttr; cast<> asserts that invariant instead
  // of dereferencing an unchecked dyn_cast result.
  llvm::APInt offsetValue =
      llvm::cast<mlir::IntegerAttr>(offsetConstOp.getValue()).getValue();
  llvm::APInt widthValue =
      llvm::cast<mlir::IntegerAttr>(widthConstOp.getValue()).getValue();

  if (!widthValue.isPowerOf2())
    return emitOpError() << "width must be a power of two";

  if (offsetValue.slt(0) || offsetValue.sge(widthValue))
    return emitOpError() << "offset must be in the range [0, "
                         << widthValue.getSExtValue() << ")";

  return success();
}
//===----------------------------------------------------------------------===//
// BarrierOp
//===----------------------------------------------------------------------===//

View File

@@ -0,0 +1,102 @@
// RUN: mlir-opt -split-input-file -convert-gpu-to-spirv -verify-diagnostics %s -o - | FileCheck %s
module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>,
#spirv.resource_limits<subgroup_size = 16>>
} {
gpu.module @kernels {
// CHECK-LABEL: spirv.func @rotate()
gpu.func @rotate() kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
%offset = arith.constant 4 : i32
%width = arith.constant 16 : i32
%val = arith.constant 42.0 : f32
// CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
// CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
// CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
// CHECK: %{{.+}} = spirv.Constant true
%result, %valid = gpu.rotate %val, %offset, %width : f32
gpu.return
}
}
}
// -----
module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>,
#spirv.resource_limits<subgroup_size = 16>>
} {
gpu.module @kernels {
// CHECK-LABEL: spirv.func @rotate_width_less_than_subgroup_size()
gpu.func @rotate_width_less_than_subgroup_size() kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
%offset = arith.constant 4 : i32
%width = arith.constant 8 : i32
%val = arith.constant 42.0 : f32
// CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
// CHECK: %[[WIDTH:.+]] = spirv.Constant 8 : i32
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
// CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
// CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__
// CHECK: %[[INVOCATION_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]]
// CHECK: %{{.+}} = spirv.ULessThan %[[INVOCATION_ID]], %[[WIDTH]]
%result, %valid = gpu.rotate %val, %offset, %width : f32
gpu.return
}
}
}
// -----
module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>,
#spirv.resource_limits<subgroup_size = 16>>
} {
gpu.module @kernels {
gpu.func @rotate_with_bigger_than_subgroup_size() kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
%offset = arith.constant 4 : i32
%width = arith.constant 32 : i32
%val = arith.constant 42.0 : f32
// expected-error @+1 {{failed to legalize operation 'gpu.rotate'}}
%result, %valid = gpu.rotate %val, %offset, %width : f32
gpu.return
}
}
}
// -----
module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>,
#spirv.resource_limits<subgroup_size = 16>>
} {
gpu.module @kernels {
gpu.func @rotate_non_const_width(%width: i32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
%offset = arith.constant 4 : i32
%val = arith.constant 42.0 : f32
// expected-error @+1 {{'gpu.rotate' op width is not a constant value}}
%result, %valid = gpu.rotate %val, %offset, %width : f32
gpu.return
}
}
}

View File

@@ -478,6 +478,84 @@ func.func @shuffle_unsupported_type_vec(%arg0 : vector<[4]xf32>, %arg1 : i32, %a
// -----
func.func @rotate_mismatching_type(%arg0 : f32) {
%offset = arith.constant 4 : i32
%width = arith.constant 16 : i32
// expected-error@+1 {{op failed to verify that all of {value, rotateResult} have same type}}
%rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (i32, i1)
return
}
// -----
func.func @rotate_unsupported_type(%arg0 : index) {
%offset = arith.constant 4 : i32
%width = arith.constant 16 : i32
// expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'index'}}
%rotate, %valid = gpu.rotate %arg0, %offset, %width : index
return
}
// -----
func.func @rotate_unsupported_type_vec(%arg0 : vector<[4]xf32>) {
%offset = arith.constant 4 : i32
%width = arith.constant 16 : i32
// expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'vector<[4]xf32>'}}
%rotate, %valid = gpu.rotate %arg0, %offset, %width : vector<[4]xf32>
return
}
// -----
func.func @rotate_unsupported_width(%arg0 : f32) {
%offset = arith.constant 4 : i32
%width = arith.constant 15 : i32
// expected-error@+1 {{op width must be a power of two}}
%rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
return
}
// -----
func.func @rotate_unsupported_offset(%arg0 : f32) {
%offset = arith.constant 16 : i32
%width = arith.constant 16 : i32
// expected-error@+1 {{op offset must be in the range [0, 16)}}
%rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
return
}
// -----
func.func @rotate_unsupported_offset_minus(%arg0 : f32) {
%offset = arith.constant -1 : i32
%width = arith.constant 16 : i32
// expected-error@+1 {{op offset must be in the range [0, 16)}}
%rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
return
}
// -----
func.func @rotate_offset_non_constant(%arg0 : f32, %offset : i32) {
%width = arith.constant 16 : i32
// expected-error@+1 {{op offset is not a constant value}}
%rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
return
}
// -----
func.func @rotate_width_non_constant(%arg0 : f32, %width : i32) {
%offset = arith.constant 0 : i32
// expected-error@+1 {{op width is not a constant value}}
%rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
return
}
// -----
module {
gpu.module @gpu_funcs {
// expected-error @+1 {{custom op 'gpu.func' gpu.func requires named arguments}}

View File

@@ -140,6 +140,10 @@ module attributes {gpu.container_module} {
// CHECK: gpu.shuffle idx %{{.*}}, %{{.*}}, %{{.*}} : f32
%shfl3, %pred3 = gpu.shuffle idx %arg0, %offset, %width : f32
// CHECK: gpu.rotate %{{.*}}, %{{.*}}, %{{.*}} : f32
%rotate_width = arith.constant 16 : i32
%rotate, %pred4 = gpu.rotate %arg0, %offset, %rotate_width : f32
"gpu.barrier"() : () -> ()
"some_op"(%bIdX, %tIdX) : (index, index) -> ()