[mlir][spirv] Add instruction OpGroupNonUniformRotateKHR (#133428)
Add an instruction under the extension SPV_KHR_subgroup_rotate. The specification for the extension is here: https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_subgroup_rotate.html
This commit is contained in:
@@ -4489,6 +4489,7 @@ def SPIRV_OC_OpGroupNonUniformBitwiseXor : I32EnumAttrCase<"OpGroupNonUnifo
|
||||
def SPIRV_OC_OpGroupNonUniformLogicalAnd : I32EnumAttrCase<"OpGroupNonUniformLogicalAnd", 362>;
|
||||
def SPIRV_OC_OpGroupNonUniformLogicalOr : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
|
||||
def SPIRV_OC_OpGroupNonUniformLogicalXor : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
|
||||
def SPIRV_OC_OpGroupNonUniformRotateKHR : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
|
||||
def SPIRV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
|
||||
def SPIRV_OC_OpSDot : I32EnumAttrCase<"OpSDot", 4450>;
|
||||
def SPIRV_OC_OpUDot : I32EnumAttrCase<"OpUDot", 4451>;
|
||||
@@ -4598,7 +4599,8 @@ def SPIRV_OpcodeAttr :
|
||||
SPIRV_OC_OpGroupNonUniformFMax, SPIRV_OC_OpGroupNonUniformBitwiseAnd,
|
||||
SPIRV_OC_OpGroupNonUniformBitwiseOr, SPIRV_OC_OpGroupNonUniformBitwiseXor,
|
||||
SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
|
||||
SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpSubgroupBallotKHR,
|
||||
SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpGroupNonUniformRotateKHR,
|
||||
SPIRV_OC_OpSubgroupBallotKHR,
|
||||
SPIRV_OC_OpSDot, SPIRV_OC_OpUDot, SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat,
|
||||
SPIRV_OC_OpUDotAccSat, SPIRV_OC_OpSUDotAccSat,
|
||||
SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR,
|
||||
|
||||
@@ -1361,4 +1361,78 @@ def SPIRV_GroupNonUniformBallotBitCountOp : SPIRV_Op<"GroupNonUniformBallotBitCo
|
||||
|
||||
// -----
|
||||
|
||||
def SPIRV_GroupNonUniformRotateKHROp : SPIRV_Op<"GroupNonUniformRotateKHR", [
|
||||
Pure, AllTypesMatch<["value", "result"]>]> {
|
||||
let summary = [{
|
||||
Rotate values across invocations within a subgroup.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
Return the Value of the invocation whose id within the group is calculated
|
||||
as follows:
|
||||
|
||||
LocalId = SubgroupLocalInvocationId if Execution is Subgroup or
|
||||
LocalInvocationId if Execution is Workgroup
|
||||
RotationGroupSize = ClusterSize when ClusterSize is present, otherwise
|
||||
RotationGroupSize = SubgroupMaxSize if the Kernel capability is declared
|
||||
and SubgroupSize if not.
|
||||
Invocation ID = ( (LocalId + Delta) & (RotationGroupSize - 1) ) +
|
||||
(LocalId & ~(RotationGroupSize - 1))
|
||||
|
||||
Result Type must be a scalar or vector of floating-point type, integer
|
||||
type, or Boolean type.
|
||||
|
||||
Execution is a Scope. It must be either Workgroup or Subgroup.
|
||||
|
||||
The type of Value must be the same as Result Type.
|
||||
|
||||
Delta must be a scalar of integer type, whose Signedness operand is 0.
|
||||
Delta must be dynamically uniform within Execution.
|
||||
|
||||
Delta is treated as unsigned and the resulting value is undefined if the
|
||||
selected lane is inactive.
|
||||
|
||||
ClusterSize is the size of cluster to use. ClusterSize must be a scalar of
|
||||
integer type, whose Signedness operand is 0. ClusterSize must come from a
|
||||
constant instruction. Behavior is undefined unless ClusterSize is at least
|
||||
1 and a power of 2. If ClusterSize is greater than the declared
|
||||
SubGroupSize, executing this instruction results in undefined behavior.
|
||||
|
||||
<!-- End of AutoGen section -->
|
||||
|
||||
#### Example:
|
||||
|
||||
```mlir
|
||||
%four = spirv.Constant 4 : i32
|
||||
%0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %value, %delta : f32, i32 -> f32
|
||||
%1 = spirv.GroupNonUniformRotateKHR <Workgroup>, %value, %delta,
|
||||
cluster_size(%four) : f32, i32, i32 -> f32
|
||||
```
|
||||
}];
|
||||
|
||||
let availability = [
|
||||
MinVersion<SPIRV_V_1_3>,
|
||||
MaxVersion<SPIRV_V_1_6>,
|
||||
Extension<[]>,
|
||||
Capability<[SPIRV_C_GroupNonUniformRotateKHR]>
|
||||
];
|
||||
|
||||
let arguments = (ins
|
||||
SPIRV_ScopeAttr:$execution_scope,
|
||||
AnyTypeOf<[SPIRV_ScalarOrVectorOf<SPIRV_Float>, SPIRV_ScalarOrVectorOf<SPIRV_Integer>, SPIRV_ScalarOrVectorOf<SPIRV_Bool>]>:$value,
|
||||
SPIRV_SignlessOrUnsignedInt:$delta,
|
||||
Optional<SPIRV_SignlessOrUnsignedInt>:$cluster_size
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
AnyTypeOf<[SPIRV_ScalarOrVectorOf<SPIRV_Float>, SPIRV_ScalarOrVectorOf<SPIRV_Integer>, SPIRV_ScalarOrVectorOf<SPIRV_Bool>]>:$result
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$execution_scope `,` $value `,` $delta (`,` `cluster_size` `(` $cluster_size^ `)`)? attr-dict `:` type($value) `,` type($delta) (`,` type($cluster_size)^)? `->` type(results)
|
||||
}];
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#endif // MLIR_DIALECT_SPIRV_IR_NON_UNIFORM_OPS
|
||||
|
||||
@@ -304,6 +304,29 @@ LogicalResult GroupNonUniformLogicalXorOp::verify() {
|
||||
return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(*this);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.GroupNonUniformRotateKHR
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Verifier for spirv.GroupNonUniformRotateKHR.
//
// Checks performed, in order:
//   1. The execution scope attribute is either Workgroup or Subgroup.
//   2. If the optional cluster size operand is present, its value can be
//      extracted by extractValueFromConstOp and is a power of two.
// Returns success() when every check passes; otherwise emits an op error
// and returns the resulting failure.
LogicalResult GroupNonUniformRotateKHROp::verify() {
|
||||
spirv::Scope scope = getExecutionScope();
|
||||
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
|
||||
return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
|
||||
|
||||
// The cluster size operand is optional; the checks below only run when it
// was supplied.
if (Value clusterSizeVal = getClusterSize()) {
|
||||
// NOTE(review): the defining op is passed to the helper unchecked; the
// helper is defined elsewhere and presumably tolerates a value with no
// defining op (e.g. a function argument) -- confirm.
mlir::Operation *defOp = clusterSizeVal.getDefiningOp();
|
||||
int32_t clusterSize = 0;
|
||||
|
||||
// Any extraction failure is reported as "not a constant".
if (failed(extractValueFromConstOp(defOp, clusterSize)))
|
||||
return emitOpError("cluster size operand must come from a constant op");
|
||||
|
||||
// The extracted value is a signed 32-bit integer handed to a 32-bit
// power-of-two predicate; zero is expected to be rejected here
// (NOTE(review): confirm against llvm::isPowerOf2_32).
if (!llvm::isPowerOf2_32(clusterSize))
|
||||
return emitOpError("cluster size operand must be a power of two");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Group op verification
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -604,3 +604,70 @@ func.func @group_non_uniform_logical_xor(%val: i32) -> i32 {
|
||||
%0 = spirv.GroupNonUniformLogicalXor <Workgroup> <Reduce> %val : i32 -> i32
|
||||
return %0: i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.GroupNonUniformRotateKHR
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: @group_non_uniform_rotate_khr
|
||||
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
|
||||
// CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup>, %{{.+}} : f32, i32 -> f32
|
||||
%0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta : f32, i32 -> f32
|
||||
return %0: f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @group_non_uniform_rotate_khr
|
||||
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
|
||||
// CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Workgroup>, %{{.+}} : f32, i32, i32 -> f32
|
||||
%four = spirv.Constant 4 : i32
|
||||
%0 = spirv.GroupNonUniformRotateKHR <Workgroup>, %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
|
||||
return %0: f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
|
||||
%four = spirv.Constant 4 : i32
|
||||
// expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
|
||||
%0 = spirv.GroupNonUniformRotateKHR <Device>, %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
|
||||
return %0: f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: si32) -> f32 {
|
||||
%four = spirv.Constant 4 : i32
|
||||
// expected-error @+1 {{op operand #1 must be 8/16/32/64-bit signless/unsigned integer, but got 'si32'}}
|
||||
%0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta, cluster_size(%four) : f32, si32, i32 -> f32
|
||||
return %0: f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
|
||||
%four = spirv.Constant 4 : si32
|
||||
// expected-error @+1 {{op operand #2 must be 8/16/32/64-bit signless/unsigned integer, but got 'si32'}}
|
||||
%0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta, cluster_size(%four) : f32, i32, si32 -> f32
|
||||
return %0: f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32, %four: i32) -> f32 {
|
||||
// expected-error @+1 {{cluster size operand must come from a constant op}}
|
||||
%0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
|
||||
return %0: f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
|
||||
%five = spirv.Constant 5 : i32
|
||||
// expected-error @+1 {{cluster size operand must be a power of two}}
|
||||
%0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta, cluster_size(%five) : f32, i32, i32 -> f32
|
||||
return %0: f32
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user