diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp index 74face429135..af9be4cccecf 100644 --- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp @@ -432,44 +432,50 @@ createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op, /*bound_ctrl=*/false); res = vector::makeArithReduction( rewriter, loc, gpu::convertReductionKind(mode), res, dpp); - if (ci.subgroupSize == 32) { - Value lane0 = rewriter.create( - loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); - res = - rewriter.create(loc, res.getType(), res, lane0); - } } else { return rewriter.notifyMatchFailure( op, "Subgroup reduce lowering to DPP not currently supported for " "this device."); } + if (ci.subgroupSize == 32) { + Value lane31 = rewriter.create( + loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(31)); + res = rewriter.create(loc, res.getType(), res, lane31); + } } if (ci.clusterSize >= 64) { if (chipset.majorVersion <= 9) { // Broadcast 31st lane value to rows 2 and 3. - // Use row mask to avoid polluting rows 0 and 1. dpp = rewriter.create( loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_31, - rewriter.getUnitAttr(), 0xc, allBanks, - /*bound_ctrl*/ false); + rewriter.getUnitAttr(), 0xf, allBanks, + /*bound_ctrl*/ true); + res = vector::makeArithReduction( + rewriter, loc, gpu::convertReductionKind(mode), dpp, res); + // Obtain reduction from last rows, the previous rows are polluted. + Value lane63 = rewriter.create( + loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63)); + res = rewriter.create(loc, res.getType(), res, lane63); } else if (chipset.majorVersion <= 12) { // Assume reduction across 32 lanes has been done. - // Perform final reduction manually by summing values in lane 0 and - // lane 32. + // Perform final reduction manually by combining the values in lane 31 + // and lane 63. 
- Value lane0 = rewriter.create( - loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); - Value lane32 = rewriter.create( - loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(32)); - dpp = rewriter.create(loc, res.getType(), res, lane32); - res = rewriter.create(loc, res.getType(), res, lane0); + Value lane31 = rewriter.create( + loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(31)); + Value lane63 = rewriter.create( + loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63)); + lane31 = + rewriter.create(loc, res.getType(), res, lane31); + lane63 = + rewriter.create(loc, res.getType(), res, lane63); + res = vector::makeArithReduction( + rewriter, loc, gpu::convertReductionKind(mode), lane31, lane63); } else { return rewriter.notifyMatchFailure( op, "Subgroup reduce lowering to DPP not currently supported for " "this device."); } - res = vector::makeArithReduction(rewriter, loc, - gpu::convertReductionKind(mode), res, dpp); } assert(res.getType() == input.getType()); return res; diff --git a/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir b/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir index 098145ade2ae..87a31ca20eb7 100644 --- a/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir +++ b/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir @@ -349,7 +349,7 @@ gpu.module @kernels { // CHECK-GFX10: %[[A4:.+]] = arith.addi %[[A3]], %[[P0]] : i16 // CHECK-GFX10: %[[R0:.+]] = rocdl.readlane %[[A4]], %{{.+}} : (i16, i32) -> i16 // CHECK-GFX10: %[[R1:.+]] = rocdl.readlane %[[A4]], %{{.+}} : (i16, i32) -> i16 - // CHECK-GFX10: %[[A5:.+]] = arith.addi %[[R1]], %[[R0]] : i16 + // CHECK-GFX10: %[[A5:.+]] = arith.addi %[[R0]], %[[R1]] : i16 // CHECK-GFX10: "test.consume"(%[[A5]]) : (i16) -> () %sum0 = gpu.subgroup_reduce add %arg0 : (i16) -> i16 "test.consume"(%sum0) : (i16) -> ()