[mlir][gpu][spirv] Remove rotation semantics of gpu.shuffle up/down (#139105)

According to the description of the gpu.shuffle operation, shuffle up/down rotates values within the subgroup: it applies a modulo to the shifted value when computing the result lane ID. This is inconsistent with both the SPIR-V shuffle up/down definitions and NVVM's data-movement semantics within a subgroup.

The NVVM specification says:

"If the computed source lane index j is in range, the returned i32 value
will be the value of %a from lane j; otherwise, it will be the the value
of %a from the current thread."

That is, NVVM keeps the original value if the result lane ID is out of range.

For SPIR-V OpGroupNonUniformShuffleUp and OpGroupNonUniformShuffleDown,
the specification says:

"The resulting value is undefined if Delta is greater than the current
invocation’s id within the scope or if the identified invocation is not
in scope restricted tangle."

That is, SPIR-V yields an undefined value if the result lane ID is out of range.
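As a concrete illustration (numbers chosen here for exposition, not taken from the patch): with subgroup width 16 and a shuffle-down offset of 4, lane 14's result lane ID is 14 + 4 = 18, which is out of range. The old gpu.shuffle semantics wrapped this to lane 18 % 16 = 2, whereas NVVM returns lane 14's own value and SPIR-V leaves the result undefined.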

In short, neither specification gives shuffle up/down circular semantics. This
patch removes the circular movement from gpu.shuffle up/down and lowers
gpu.shuffle up/down directly to SPIR-V OpGroupNonUniformShuffleUp and
OpGroupNonUniformShuffleDown.
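Under the new semantics, the op behaves as follows (an illustrative snippet mirroring the updated op documentation; `%val` is assumed to be an `f32` value already in scope):

```mlir
%offset = arith.constant 4 : i32
%width = arith.constant 16 : i32
// For lane k, reads the value from lane k + 4. Lanes 12..15 would read past
// the end of the subgroup, so their result is a poison value and %d_valid is
// false (previously they wrapped around to lanes 0..3).
%d_res, %d_valid = gpu.shuffle down %val, %offset, %width : f32
// For lane k, reads the value from lane k - 4. Lanes 0..3 would read before
// lane 0, so their result is a poison value and %u_valid is false.
%u_res, %u_valid = gpu.shuffle up %val, %offset, %width : f32
```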

References:

https://docs.nvidia.com/cuda/archive/12.2.1/nvvm-ir-spec/index.html#data-movement

https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpGroupNonUniformShuffleUp

https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpGroupNonUniformShuffleDown
Author: Hsiangkai Wang
Date: 2025-06-19 07:56:30 +01:00
Committed by: GitHub
Parent: 590066bee7
Commit: 03461c9c6e
3 changed files with 111 additions and 11 deletions


@@ -1332,7 +1332,8 @@ def GPU_ShuffleOp : GPU_Op<
     %3, %4 = gpu.shuffle down %0, %cst1, %width : f32
     ```
-    For lane `k`, returns the value from lane `(k + 1) % width`.
+    For lane `k`, returns the value from lane `(k + cst1)`. If `(k + cst1)` is
+    bigger than or equal to `width`, the value is poison and `valid` is `false`.
 
     `up` example:
@@ -1341,7 +1342,8 @@ def GPU_ShuffleOp : GPU_Op<
     %5, %6 = gpu.shuffle up %0, %cst1, %width : f32
     ```
-    For lane `k`, returns the value from lane `(k - 1) % width`.
+    For lane `k`, returns the value from lane `(k - cst1)`. If `(k - cst1)` is
+    smaller than `0`, the value is poison and `valid` is `false`.
 
     `idx` example:


@@ -435,26 +435,57 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         shuffleOp, "shuffle width and target subgroup size mismatch");
 
+  assert(!adaptor.getOffset().getType().isSignedInteger() &&
+         "shuffle offset must be a signless/unsigned integer");
+
   Location loc = shuffleOp.getLoc();
-  Value trueVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
-                                            shuffleOp.getLoc(), rewriter);
   auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
   Value result;
+  Value validVal;
 
   switch (shuffleOp.getMode()) {
-  case gpu::ShuffleMode::XOR:
+  case gpu::ShuffleMode::XOR: {
     result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>(
         loc, scope, adaptor.getValue(), adaptor.getOffset());
+    validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
+                                         shuffleOp.getLoc(), rewriter);
     break;
-  case gpu::ShuffleMode::IDX:
+  }
+  case gpu::ShuffleMode::IDX: {
     result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
         loc, scope, adaptor.getValue(), adaptor.getOffset());
+    validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
+                                         shuffleOp.getLoc(), rewriter);
     break;
-  default:
-    return rewriter.notifyMatchFailure(shuffleOp, "unimplemented shuffle mode");
+  }
+  case gpu::ShuffleMode::DOWN: {
+    result = rewriter.create<spirv::GroupNonUniformShuffleDownOp>(
+        loc, scope, adaptor.getValue(), adaptor.getOffset());
+    Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
+    Value resultLaneId =
+        rewriter.create<arith::AddIOp>(loc, laneId, adaptor.getOffset());
+    validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
+                                              resultLaneId, adaptor.getWidth());
+    break;
+  }
+  case gpu::ShuffleMode::UP: {
+    result = rewriter.create<spirv::GroupNonUniformShuffleUpOp>(
+        loc, scope, adaptor.getValue(), adaptor.getOffset());
+    Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
+    Value resultLaneId =
+        rewriter.create<arith::SubIOp>(loc, laneId, adaptor.getOffset());
+    auto i32Type = rewriter.getIntegerType(32);
+    validVal = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sge, resultLaneId,
+        rewriter.create<arith::ConstantOp>(
+            loc, i32Type, rewriter.getIntegerAttr(i32Type, 0)));
+    break;
+  }
   }
 
-  rewriter.replaceOp(shuffleOp, {result, trueVal});
+  rewriter.replaceOp(shuffleOp, {result, validVal});
   return success();
 }
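For the down mode, the fully converted output looks roughly like this (a sketch distilled from the test expectations below, not verbatim pass output; `%val`, `%offset`, and `%width` are assumed to be values already in scope):

```mlir
%result = spirv.GroupNonUniformShuffleDown <Subgroup> %val, %offset : f32, i32
// Recompute the source lane ID and derive the validity bit from it.
%addr = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__ : !spirv.ptr<i32, Input>
%lane_id = spirv.Load "Input" %addr : i32
%result_lane = spirv.IAdd %lane_id, %offset : i32
%valid = spirv.ULessThan %result_lane, %width : i32
```

For the up mode, the subtraction can wrap below zero, so the pattern compares the signed result against 0 (arith::CmpIPredicate::sge) rather than using an unsigned comparison; the assert at the top of the pattern guarantees the offset itself is a signless or unsigned integer.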


@@ -15,8 +15,8 @@ gpu.module @kernels {
     // CHECK: %[[MASK:.+]] = spirv.Constant 8 : i32
     // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
-    // CHECK: %{{.+}} = spirv.Constant true
     // CHECK: %{{.+}} = spirv.GroupNonUniformShuffleXor <Subgroup> %[[VAL]], %[[MASK]] : f32, i32
+    // CHECK: %{{.+}} = spirv.Constant true
     %result, %valid = gpu.shuffle xor %val, %mask, %width : f32
     gpu.return
   }
@@ -64,11 +64,78 @@ gpu.module @kernels {
     // CHECK: %[[MASK:.+]] = spirv.Constant 8 : i32
     // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
-    // CHECK: %{{.+}} = spirv.Constant true
     // CHECK: %{{.+}} = spirv.GroupNonUniformShuffle <Subgroup> %[[VAL]], %[[MASK]] : f32, i32
+    // CHECK: %{{.+}} = spirv.Constant true
     %result, %valid = gpu.shuffle idx %val, %mask, %width : f32
     gpu.return
   }
 }
 }
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle, GroupNonUniformShuffleRelative], []>,
+    #spirv.resource_limits<subgroup_size = 16>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL: spirv.func @shuffle_down()
+  gpu.func @shuffle_down() kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    %offset = arith.constant 4 : i32
+    %width = arith.constant 16 : i32
+    %val = arith.constant 42.0 : f32
+
+    // CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
+    // CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
+    // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
+    // CHECK: %{{.+}} = spirv.GroupNonUniformShuffleDown <Subgroup> %[[VAL]], %[[OFFSET]] : f32, i32
+    // CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__ : !spirv.ptr<i32, Input>
+    // CHECK: %[[LANE_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]] : i32
+    // CHECK: %[[VAL_LANE_ID:.+]] = spirv.IAdd %[[LANE_ID]], %[[OFFSET]] : i32
+    // CHECK: %[[VALID:.+]] = spirv.ULessThan %[[VAL_LANE_ID]], %[[WIDTH]] : i32
+    %result, %valid = gpu.shuffle down %val, %offset, %width : f32
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle, GroupNonUniformShuffleRelative], []>,
+    #spirv.resource_limits<subgroup_size = 16>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL: spirv.func @shuffle_up()
+  gpu.func @shuffle_up() kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    %offset = arith.constant 4 : i32
+    %width = arith.constant 16 : i32
+    %val = arith.constant 42.0 : f32
+
+    // CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
+    // CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
+    // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
+    // CHECK: %{{.+}} = spirv.GroupNonUniformShuffleUp <Subgroup> %[[VAL]], %[[OFFSET]] : f32, i32
+    // CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__ : !spirv.ptr<i32, Input>
+    // CHECK: %[[LANE_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]] : i32
+    // CHECK: %[[VAL_LANE_ID:.+]] = spirv.ISub %[[LANE_ID]], %[[OFFSET]] : i32
+    // CHECK: %[[CST0:.+]] = spirv.Constant 0 : i32
+    // CHECK: %[[VALID:.+]] = spirv.SGreaterThanEqual %[[VAL_LANE_ID]], %[[CST0]] : i32
+    %result, %valid = gpu.shuffle up %val, %offset, %width : f32
+    gpu.return
+  }
+}
+
+}