[mlir][gpu][spirv] Remove rotation semantics of gpu.shuffle up/down (#139105)
From the description of gpu.shuffle operation, shuffle up/down rotates values in the subgroup because it applies modulo on the shifted value to calculate the result lane ID. It is inconsistent with the definition of SPIR-V shuffle up/down and NVVM data movement definitions within subgroup. In NVVM, it says "If the computed source lane index j is in range, the returned i32 value will be the value of %a from lane j; otherwise, it will be the the value of %a from the current thread." It will keep the original value if the result land ID is out of range. In SPIR-V OpGroupNonUniformShuffleUp and OpGroupNonUniformShuffleDown, it says "The resulting value is undefined if Delta is greater than the current invocation’s id within the scope or if the identified invocation is not in scope restricted tangle." It's an undefined value if the result land ID is out of range. Anyway, there is no circular movement in shuffle up/down from these 2 specifications. This patch removes the circular movement in gpu.shuffle up/down and lower gpu.shuffle up/down to SPIR-V OpGroupNonUniformShuffleUp and OpGroupNonUniformShuffleDown directly. Reference: https://docs.nvidia.com/cuda/archive/12.2.1/nvvm-ir-spec/index.html#data-movement https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpGroupNonUniformShuffleUp https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpGroupNonUniformShuffleDown
This commit is contained in:
@@ -1332,7 +1332,8 @@ def GPU_ShuffleOp : GPU_Op<
|
||||
%3, %4 = gpu.shuffle down %0, %cst1, %width : f32
|
||||
```
|
||||
|
||||
For lane `k`, returns the value from lane `(k + 1) % width`.
|
||||
For lane `k`, returns the value from lane `(k + cst1)`. If `(k + cst1)` is
|
||||
bigger than or equal to `width`, the value is poison and `valid` is `false`.
|
||||
|
||||
`up` example:
|
||||
|
||||
@@ -1341,7 +1342,8 @@ def GPU_ShuffleOp : GPU_Op<
|
||||
%5, %6 = gpu.shuffle up %0, %cst1, %width : f32
|
||||
```
|
||||
|
||||
For lane `k`, returns the value from lane `(k - 1) % width`.
|
||||
For lane `k`, returns the value from lane `(k - cst1)`. If `(k - cst1)` is
|
||||
smaller than `0`, the value is poison and `valid` is `false`.
|
||||
|
||||
`idx` example:
|
||||
|
||||
|
||||
@@ -435,26 +435,57 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
|
||||
return rewriter.notifyMatchFailure(
|
||||
shuffleOp, "shuffle width and target subgroup size mismatch");
|
||||
|
||||
assert(!adaptor.getOffset().getType().isSignedInteger() &&
|
||||
"shuffle offset must be a signless/unsigned integer");
|
||||
|
||||
Location loc = shuffleOp.getLoc();
|
||||
Value trueVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
|
||||
shuffleOp.getLoc(), rewriter);
|
||||
auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
|
||||
Value result;
|
||||
Value validVal;
|
||||
|
||||
switch (shuffleOp.getMode()) {
|
||||
case gpu::ShuffleMode::XOR:
|
||||
case gpu::ShuffleMode::XOR: {
|
||||
result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>(
|
||||
loc, scope, adaptor.getValue(), adaptor.getOffset());
|
||||
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
|
||||
shuffleOp.getLoc(), rewriter);
|
||||
break;
|
||||
case gpu::ShuffleMode::IDX:
|
||||
}
|
||||
case gpu::ShuffleMode::IDX: {
|
||||
result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
|
||||
loc, scope, adaptor.getValue(), adaptor.getOffset());
|
||||
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
|
||||
shuffleOp.getLoc(), rewriter);
|
||||
break;
|
||||
default:
|
||||
return rewriter.notifyMatchFailure(shuffleOp, "unimplemented shuffle mode");
|
||||
}
|
||||
case gpu::ShuffleMode::DOWN: {
|
||||
result = rewriter.create<spirv::GroupNonUniformShuffleDownOp>(
|
||||
loc, scope, adaptor.getValue(), adaptor.getOffset());
|
||||
|
||||
Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
|
||||
Value resultLaneId =
|
||||
rewriter.create<arith::AddIOp>(loc, laneId, adaptor.getOffset());
|
||||
validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
|
||||
resultLaneId, adaptor.getWidth());
|
||||
break;
|
||||
}
|
||||
case gpu::ShuffleMode::UP: {
|
||||
result = rewriter.create<spirv::GroupNonUniformShuffleUpOp>(
|
||||
loc, scope, adaptor.getValue(), adaptor.getOffset());
|
||||
|
||||
Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
|
||||
Value resultLaneId =
|
||||
rewriter.create<arith::SubIOp>(loc, laneId, adaptor.getOffset());
|
||||
auto i32Type = rewriter.getIntegerType(32);
|
||||
validVal = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sge, resultLaneId,
|
||||
rewriter.create<arith::ConstantOp>(
|
||||
loc, i32Type, rewriter.getIntegerAttr(i32Type, 0)));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.replaceOp(shuffleOp, {result, trueVal});
|
||||
rewriter.replaceOp(shuffleOp, {result, validVal});
|
||||
return success();
|
||||
}
|
||||
|
||||
|
||||
@@ -15,8 +15,8 @@ gpu.module @kernels {
|
||||
|
||||
// CHECK: %[[MASK:.+]] = spirv.Constant 8 : i32
|
||||
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
|
||||
// CHECK: %{{.+}} = spirv.Constant true
|
||||
// CHECK: %{{.+}} = spirv.GroupNonUniformShuffleXor <Subgroup> %[[VAL]], %[[MASK]] : f32, i32
|
||||
// CHECK: %{{.+}} = spirv.Constant true
|
||||
%result, %valid = gpu.shuffle xor %val, %mask, %width : f32
|
||||
gpu.return
|
||||
}
|
||||
@@ -64,11 +64,78 @@ gpu.module @kernels {
|
||||
|
||||
// CHECK: %[[MASK:.+]] = spirv.Constant 8 : i32
|
||||
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
|
||||
// CHECK: %{{.+}} = spirv.Constant true
|
||||
// CHECK: %{{.+}} = spirv.GroupNonUniformShuffle <Subgroup> %[[VAL]], %[[MASK]] : f32, i32
|
||||
// CHECK: %{{.+}} = spirv.Constant true
|
||||
%result, %valid = gpu.shuffle idx %val, %mask, %width : f32
|
||||
gpu.return
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {
|
||||
gpu.container_module,
|
||||
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle, GroupNonUniformShuffleRelative], []>,
|
||||
#spirv.resource_limits<subgroup_size = 16>>
|
||||
} {
|
||||
|
||||
gpu.module @kernels {
|
||||
// CHECK-LABEL: spirv.func @shuffle_down()
|
||||
gpu.func @shuffle_down() kernel
|
||||
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
|
||||
%offset = arith.constant 4 : i32
|
||||
%width = arith.constant 16 : i32
|
||||
%val = arith.constant 42.0 : f32
|
||||
|
||||
// CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
|
||||
// CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
|
||||
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
|
||||
// CHECK: %{{.+}} = spirv.GroupNonUniformShuffleDown <Subgroup> %[[VAL]], %[[OFFSET]] : f32, i32
|
||||
|
||||
// CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__ : !spirv.ptr<i32, Input>
|
||||
// CHECK: %[[LANE_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]] : i32
|
||||
// CHECK: %[[VAL_LANE_ID:.+]] = spirv.IAdd %[[LANE_ID]], %[[OFFSET]] : i32
|
||||
// CHECK: %[[VALID:.+]] = spirv.ULessThan %[[VAL_LANE_ID]], %[[WIDTH]] : i32
|
||||
|
||||
%result, %valid = gpu.shuffle down %val, %offset, %width : f32
|
||||
gpu.return
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {
|
||||
gpu.container_module,
|
||||
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle, GroupNonUniformShuffleRelative], []>,
|
||||
#spirv.resource_limits<subgroup_size = 16>>
|
||||
} {
|
||||
|
||||
gpu.module @kernels {
|
||||
// CHECK-LABEL: spirv.func @shuffle_up()
|
||||
gpu.func @shuffle_up() kernel
|
||||
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
|
||||
%offset = arith.constant 4 : i32
|
||||
%width = arith.constant 16 : i32
|
||||
%val = arith.constant 42.0 : f32
|
||||
|
||||
// CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
|
||||
// CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
|
||||
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
|
||||
// CHECK: %{{.+}} = spirv.GroupNonUniformShuffleUp <Subgroup> %[[VAL]], %[[OFFSET]] : f32, i32
|
||||
|
||||
// CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__ : !spirv.ptr<i32, Input>
|
||||
// CHECK: %[[LANE_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]] : i32
|
||||
// CHECK: %[[VAL_LANE_ID:.+]] = spirv.ISub %[[LANE_ID]], %[[OFFSET]] : i32
|
||||
// CHECK: %[[CST0:.+]] = spirv.Constant 0 : i32
|
||||
// CHECK: %[[VALID:.+]] = spirv.SGreaterThanEqual %[[VAL_LANE_ID]], %[[CST0]] : i32
|
||||
|
||||
%result, %valid = gpu.shuffle up %val, %offset, %width : f32
|
||||
gpu.return
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user