[mlir][xegpu] Add support for distributing gpu.barrier (#145434)
This commit is contained in:
@@ -455,6 +455,14 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
|
||||
if (!operand)
|
||||
return rewriter.notifyMatchFailure(
|
||||
subgroupOp, "warp result is not a xegpu::LoadNd op");
|
||||
// Make sure the load op is the last operation in the warp op body. This
|
||||
// ensure that load op is not sinked earlier violating any barrier
|
||||
// synchronizations.
|
||||
auto yield = cast<gpu::YieldOp>(
|
||||
subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
|
||||
Operation *lastNode = yield->getPrevNode();
|
||||
if (!dyn_cast_or_null<xegpu::LoadNdOp>(lastNode))
|
||||
return failure();
|
||||
|
||||
auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
|
||||
xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
|
||||
@@ -782,6 +790,29 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
|
||||
}
|
||||
};
|
||||
|
||||
/// Sink a gpu::BarrierOp at the end of enclosing `gpu.warp_execute_on_lane_0`
|
||||
/// region. This will simply move the barrier op outside of the warp op.
|
||||
struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
|
||||
using gpu::WarpDistributionPattern::WarpDistributionPattern;
|
||||
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto yield = cast<gpu::YieldOp>(
|
||||
subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
|
||||
Operation *lastNode = yield->getPrevNode();
|
||||
// The last node must be a gpu::BarrierOp.
|
||||
auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);
|
||||
if (!barrierOp)
|
||||
return failure();
|
||||
// Move the barrier op outside of the warp op.
|
||||
rewriter.setInsertionPointAfter(subgroupOp);
|
||||
rewriter.create<gpu::BarrierOp>(
|
||||
barrierOp.getLoc(), barrierOp->getResultTypes(),
|
||||
barrierOp->getOperands(), barrierOp->getAttrs());
|
||||
rewriter.eraseOp(barrierOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
@@ -796,7 +827,8 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<CreateNdDescDistribution, StoreNdDistribution,
|
||||
LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
|
||||
UpdateNdOffsetDistribution>(patterns.getContext());
|
||||
UpdateNdOffsetDistribution, GpuBarrierDistribution>(
|
||||
patterns.getContext());
|
||||
}
|
||||
|
||||
void XeGPUSubgroupDistributePass::runOnOperation() {
|
||||
|
||||
@@ -278,3 +278,22 @@ gpu.module @test {
|
||||
gpu.return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
// Verify that a gpu.barrier trailing the warp region is hoisted out intact,
// preserving program order: load -> barrier -> store.
// CHECK-LABEL: gpu.func @gpu_barrier({{.*}}) {
// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
// CHECK-NEXT: %[[T1:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<16xf16> -> vector<1xf16>
// CHECK-NEXT: gpu.barrier
// CHECK-NEXT: %[[T2:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
// CHECK-NEXT: xegpu.store_nd %[[T1]], %[[T2]] : vector<1xf16>, !xegpu.tensor_desc<16xf16>
gpu.module @test {
  gpu.func @gpu_barrier(%arg0: memref<256xf16>, %arg1: memref<256xf16>) {
    %c0 = arith.constant 0 : index
    %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
    %1 = xegpu.load_nd %0 {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>> -> vector<16xf16>
    gpu.barrier
    %2 = xegpu.create_nd_tdesc %arg1[%c0] : memref<256xf16> -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
    xegpu.store_nd %1, %2 : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
    gpu.return
  }
}
|
||||
|
||||
Reference in New Issue
Block a user