[mlir][gpu] Add gpu.rotate operation (#142796)
Add gpu.rotate operation and a pattern to convert gpu.rotate to SPIR-V OpGroupNonUniformRotateKHR.
@@ -1304,8 +1304,8 @@ def GPU_ShuffleOp : GPU_Op<
    Results<(outs AnyIntegerOrFloatOr1DVector:$shuffleResult, I1:$valid)> {
  let summary = "Shuffles values within a subgroup.";
  let description = [{
-    The "shuffle" op moves values to a across lanes (a.k.a., invocations,
-    work items) within the same subgroup. The `width` argument specifies the
+    The "shuffle" op moves values across lanes in a subgroup (a.k.a., local
+    invocation) within the same subgroup. The `width` argument specifies the
    number of lanes that participate in the shuffle, and must be uniform
    across all lanes. Further, the first `width` lanes of the subgroup must
    be active.

@@ -1366,6 +1366,54 @@ def GPU_ShuffleOp : GPU_Op<
  ];
}

def GPU_RotateOp : GPU_Op<
    "rotate", [Pure, AllTypesMatch<["value", "rotateResult"]>]>,
    Arguments<(ins AnyIntegerOrFloatOr1DVector:$value, I32:$offset, I32:$width)>,
    Results<(outs AnyIntegerOrFloatOr1DVector:$rotateResult, I1:$valid)> {
  let summary = "Rotate values within a subgroup.";
  let description = [{
    The "rotate" op moves values across lanes in a subgroup (a.k.a., local
    invocations) within the same subgroup. The `width` argument specifies the
    number of lanes that participate in the rotation, and must be uniform across
    all participating lanes. Further, the first `width` lanes of the subgroup
    must be active.

    `width` must be a power of two, and `offset` must be in the range
    `[0, width)`.

    Returns the `rotateResult` of the invocation whose id within the group is
    calculated as follows:

    ```mlir
    Invocation ID = ((LaneId + offset) & (width - 1)) + (LaneId & ~(width - 1))
    ```

    Returns the `rotateResult` and `true` if the current lane id is smaller than
    `width`, and a poison value and `false` otherwise.

    Example:

    ```mlir
    %offset = arith.constant 1 : i32
    %width = arith.constant 16 : i32
    %1, %2 = gpu.rotate %0, %offset, %width : f32
    ```

    For lane `k`, returns the value from lane `(k + offset) % width`.
  }];

  let assemblyFormat = [{
    $value `,` $offset `,` $width attr-dict `:` type($value)
  }];

  let builders = [
    // Helper function that creates a rotate with constant offset/width.
    OpBuilder<(ins "Value":$value, "int32_t":$offset, "int32_t":$width)>
  ];

  let hasVerifier = 1;
}

def GPU_BarrierOp : GPU_Op<"barrier"> {
  let summary = "Synchronizes all work items of a workgroup.";
  let description = [{

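To make the lane mapping above concrete, here is a small standalone C++ sketch (illustrative only, not part of this patch) that evaluates the invocation-id formula for a few lanes, assuming `width` is a power of two and `offset` is in `[0, width)` as the op requires:

```cpp
#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Source lane read by gpu.rotate for a given lane, per the formula above:
// ((LaneId + offset) & (width - 1)) + (LaneId & ~(width - 1)).
static uint32_t rotateSourceLane(uint32_t laneId, uint32_t offset,
                                 uint32_t width) {
  return ((laneId + offset) & (width - 1)) + (laneId & ~(width - 1));
}

int main() {
  // With offset = 4 and width = 16: lane 3 reads lane 7, lane 13 wraps around
  // to lane 1, and lane 15 wraps around to lane 3.
  for (uint32_t lane : {3u, 13u, 15u})
    std::printf("lane %u <- lane %u\n", lane, rotateSourceLane(lane, 4, 16));
  return 0;
}
```

When `width` equals the subgroup size this is a rotation across the whole subgroup; for a smaller power of two, the `LaneId & ~(width - 1)` term keeps the rotation inside each `width`-sized cluster of lanes.
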
@@ -122,6 +122,16 @@ public:
  ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a gpu.rotate op into a spirv.GroupNonUniformRotateKHROp.
class GPURotateConversion final : public OpConversionPattern<gpu::RotateOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::RotateOp rotateOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
public:
  using OpConversionPattern::OpConversionPattern;

@@ -488,6 +498,41 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
  return success();
}

//===----------------------------------------------------------------------===//
// Rotate
//===----------------------------------------------------------------------===//

LogicalResult GPURotateConversion::matchAndRewrite(
    gpu::RotateOp rotateOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  const spirv::TargetEnv &targetEnv =
      getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
  unsigned subgroupSize =
      targetEnv.getAttr().getResourceLimits().getSubgroupSize();
  IntegerAttr widthAttr;
  if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)) ||
      widthAttr.getValue().getZExtValue() > subgroupSize)
    return rewriter.notifyMatchFailure(
        rotateOp,
        "rotate width is not a constant or larger than target subgroup size");

  Location loc = rotateOp.getLoc();
  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
  Value rotateResult = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
      loc, scope, adaptor.getValue(), adaptor.getOffset(), adaptor.getWidth());
  Value validVal;
  if (widthAttr.getValue().getZExtValue() == subgroupSize) {
    validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter);
  } else {
    Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
    validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                              laneId, adaptor.getWidth());
  }

  rewriter.replaceOp(rotateOp, {rotateResult, validVal});
  return success();
}

//===----------------------------------------------------------------------===//
// Group ops
//===----------------------------------------------------------------------===//

@@ -776,7 +821,7 @@ void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                      RewritePatternSet &patterns) {
  patterns.add<
      GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
-      GPUReturnOpConversion, GPUShuffleConversion,
+      GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
      LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
      LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
      LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,

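For context, these patterns are normally driven by a dialect-conversion pass such as `-convert-gpu-to-spirv` (exercised by the new test below). The following is a minimal sketch, not code from this patch, of how a downstream pass might apply the populated patterns, including the new `GPURotateConversion`; the helper name `lowerGPUModuleToSPIRV` is hypothetical, and the real pass additionally converts types, functions, and control flow.

```cpp
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Hypothetical helper, for illustration only: applies the GPU-to-SPIR-V
// patterns to `gpuModule`, using the spirv.target_env attribute found on or
// above the op. GPURotateConversion reads the subgroup size from this
// environment to decide how to materialize the `valid` result.
static LogicalResult lowerGPUModuleToSPIRV(Operation *gpuModule) {
  spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(gpuModule);
  auto target = SPIRVConversionTarget::get(targetAttr);

  SPIRVTypeConverter typeConverter(targetAttr);
  RewritePatternSet patterns(gpuModule->getContext());
  populateGPUToSPIRVPatterns(typeConverter, patterns);

  return applyFullConversion(gpuModule, *target, std::move(patterns));
}
```
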
@@ -1331,6 +1331,49 @@ void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
        mode);
}

//===----------------------------------------------------------------------===//
// RotateOp
//===----------------------------------------------------------------------===//

void RotateOp::build(OpBuilder &builder, OperationState &result, Value value,
                     int32_t offset, int32_t width) {
  build(builder, result, value,
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(offset)),
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(width)));
}

LogicalResult RotateOp::verify() {
  auto offsetConstOp = getOffset().getDefiningOp<arith::ConstantOp>();
  if (!offsetConstOp)
    return emitOpError() << "offset is not a constant value";

  auto offsetIntAttr =
      llvm::dyn_cast<mlir::IntegerAttr>(offsetConstOp.getValue());

  auto widthConstOp = getWidth().getDefiningOp<arith::ConstantOp>();
  if (!widthConstOp)
    return emitOpError() << "width is not a constant value";

  auto widthIntAttr =
      llvm::dyn_cast<mlir::IntegerAttr>(widthConstOp.getValue());

  llvm::APInt offsetValue = offsetIntAttr.getValue();
  llvm::APInt widthValue = widthIntAttr.getValue();

  if (!widthValue.isPowerOf2())
    return emitOpError() << "width must be a power of two";

  if (offsetValue.sge(widthValue) || offsetValue.slt(0)) {
    int64_t widthValueInt = widthValue.getSExtValue();
    return emitOpError() << "offset must be in the range [0, " << widthValueInt
                         << ")";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// BarrierOp
//===----------------------------------------------------------------------===//

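As a usage note, here is a minimal sketch, not part of this patch, of exercising the new constant-offset/width builder from C++; the `demo` function and the module setup are made up for illustration.

```cpp
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

int main() {
  MLIRContext ctx;
  // The helper builder materializes offset/width as arith.constant ops, so
  // the arith dialect must be loaded alongside gpu.
  ctx.loadDialect<arith::ArithDialect, func::FuncDialect, gpu::GPUDialect>();

  OpBuilder b(&ctx);
  Location loc = b.getUnknownLoc();
  auto module = ModuleOp::create(loc);
  b.setInsertionPointToStart(module.getBody());

  // A trivial function body to hold the ops (hypothetical, only for the demo).
  auto func = b.create<func::FuncOp>(loc, "demo", b.getFunctionType({}, {}));
  b.setInsertionPointToStart(func.addEntryBlock());

  // A value to rotate, and the new builder: offset = 1 and width = 16 are
  // emitted as arith.constant operands.
  Value val = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(42.0f));
  auto rotate = b.create<gpu::RotateOp>(loc, val, /*offset=*/1, /*width=*/16);
  (void)rotate.getRotateResult();
  (void)rotate.getValid();
  b.create<func::ReturnOp>(loc);

  module.print(llvm::outs());
  return 0;
}
```

The printed body should contain a line of the form `gpu.rotate %..., %..., %... : f32`, matching the assembly format declared above.
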
mlir/test/Conversion/GPUToSPIRV/rotate.mlir (new file, 102 lines)
@@ -0,0 +1,102 @@
// RUN: mlir-opt -split-input-file -convert-gpu-to-spirv -verify-diagnostics %s -o - | FileCheck %s

module attributes {
  gpu.container_module,
  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>,
    #spirv.resource_limits<subgroup_size = 16>>
} {

gpu.module @kernels {
  // CHECK-LABEL: spirv.func @rotate()
  gpu.func @rotate() kernel
    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
    %offset = arith.constant 4 : i32
    %width = arith.constant 16 : i32
    %val = arith.constant 42.0 : f32

    // CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
    // CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
    // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
    // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
    // CHECK: %{{.+}} = spirv.Constant true
    %result, %valid = gpu.rotate %val, %offset, %width : f32
    gpu.return
  }
}

}

// -----

module attributes {
  gpu.container_module,
  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>,
    #spirv.resource_limits<subgroup_size = 16>>
} {

gpu.module @kernels {
  // CHECK-LABEL: spirv.func @rotate_width_less_than_subgroup_size()
  gpu.func @rotate_width_less_than_subgroup_size() kernel
    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
    %offset = arith.constant 4 : i32
    %width = arith.constant 8 : i32
    %val = arith.constant 42.0 : f32

    // CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
    // CHECK: %[[WIDTH:.+]] = spirv.Constant 8 : i32
    // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
    // CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
    // CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__
    // CHECK: %[[INVOCATION_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]]
    // CHECK: %{{.+}} = spirv.ULessThan %[[INVOCATION_ID]], %[[WIDTH]]
    %result, %valid = gpu.rotate %val, %offset, %width : f32
    gpu.return
  }
}

}

// -----

module attributes {
  gpu.container_module,
  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>,
    #spirv.resource_limits<subgroup_size = 16>>
} {

gpu.module @kernels {
  gpu.func @rotate_with_bigger_than_subgroup_size() kernel
    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
    %offset = arith.constant 4 : i32
    %width = arith.constant 32 : i32
    %val = arith.constant 42.0 : f32

    // expected-error @+1 {{failed to legalize operation 'gpu.rotate'}}
    %result, %valid = gpu.rotate %val, %offset, %width : f32
    gpu.return
  }
}

}

// -----

module attributes {
  gpu.container_module,
  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>,
    #spirv.resource_limits<subgroup_size = 16>>
} {

gpu.module @kernels {
  gpu.func @rotate_non_const_width(%width: i32) kernel
    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
    %offset = arith.constant 4 : i32
    %val = arith.constant 42.0 : f32

    // expected-error @+1 {{'gpu.rotate' op width is not a constant value}}
    %result, %valid = gpu.rotate %val, %offset, %width : f32
    gpu.return
  }
}

}

@@ -478,6 +478,84 @@ func.func @shuffle_unsupported_type_vec(%arg0 : vector<[4]xf32>, %arg1 : i32, %a

// -----

func.func @rotate_mismatching_type(%arg0 : f32) {
  %offset = arith.constant 4 : i32
  %width = arith.constant 16 : i32
  // expected-error@+1 {{op failed to verify that all of {value, rotateResult} have same type}}
  %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (i32, i1)
  return
}

// -----

func.func @rotate_unsupported_type(%arg0 : index) {
  %offset = arith.constant 4 : i32
  %width = arith.constant 16 : i32
  // expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'index'}}
  %rotate, %valid = gpu.rotate %arg0, %offset, %width : index
  return
}

// -----

func.func @rotate_unsupported_type_vec(%arg0 : vector<[4]xf32>) {
  %offset = arith.constant 4 : i32
  %width = arith.constant 16 : i32
  // expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'vector<[4]xf32>'}}
  %rotate, %valid = gpu.rotate %arg0, %offset, %width : vector<[4]xf32>
  return
}

// -----

func.func @rotate_unsupported_width(%arg0 : f32) {
  %offset = arith.constant 4 : i32
  %width = arith.constant 15 : i32
  // expected-error@+1 {{op width must be a power of two}}
  %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
  return
}

// -----

func.func @rotate_unsupported_offset(%arg0 : f32) {
  %offset = arith.constant 16 : i32
  %width = arith.constant 16 : i32
  // expected-error@+1 {{op offset must be in the range [0, 16)}}
  %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
  return
}

// -----

func.func @rotate_unsupported_offset_minus(%arg0 : f32) {
  %offset = arith.constant -1 : i32
  %width = arith.constant 16 : i32
  // expected-error@+1 {{op offset must be in the range [0, 16)}}
  %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
  return
}

// -----

func.func @rotate_offset_non_constant(%arg0 : f32, %offset : i32) {
  %width = arith.constant 16 : i32
  // expected-error@+1 {{op offset is not a constant value}}
  %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
  return
}

// -----

func.func @rotate_width_non_constant(%arg0 : f32, %width : i32) {
  %offset = arith.constant 0 : i32
  // expected-error@+1 {{op width is not a constant value}}
  %rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
  return
}

// -----

module {
  gpu.module @gpu_funcs {
    // expected-error @+1 {{custom op 'gpu.func' gpu.func requires named arguments}}

@@ -140,6 +140,10 @@ module attributes {gpu.container_module} {
      // CHECK: gpu.shuffle idx %{{.*}}, %{{.*}}, %{{.*}} : f32
      %shfl3, %pred3 = gpu.shuffle idx %arg0, %offset, %width : f32

      // CHECK: gpu.rotate %{{.*}}, %{{.*}}, %{{.*}} : f32
      %rotate_width = arith.constant 16 : i32
      %rotate, %pred4 = gpu.rotate %arg0, %offset, %rotate_width : f32

      "gpu.barrier"() : () -> ()

      "some_op"(%bIdX, %tIdX) : (index, index) -> ()