[mlir][xegpu] Fix seg-fault caused by setting a null attribute (#146002)

This commit is contained in:
Chao Chen
2025-07-01 15:42:52 -05:00
committed by GitHub
parent 829f2f2448
commit 5d849d3a90
2 changed files with 28 additions and 5 deletions

View File

@@ -376,10 +376,12 @@ struct WgToSgElementwiseOp : public ConversionPattern {
// Copy all attributes, but update "layout_result_0" to drop
// sgLayout/sgData
for (auto attr : op->getAttrs()) {
if (auto layout = dyn_cast<xegpu::LayoutAttr>(attr.getValue()))
state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
else
if (auto layout = dyn_cast<xegpu::LayoutAttr>(attr.getValue())) {
if (auto newLayout = layout.dropSgLayoutAndData())
state.addAttribute(attr.getName(), newLayout);
} else {
state.addAttribute(attr.getName(), attr.getValue());
}
}
Operation *newOp = rewriter.create(state);
newResults.push_back(newOp->getResult(0));
@@ -629,8 +631,10 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
std::string name = xegpu::getLayoutName(result);
if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
op->removeAttr(name);
if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op))
op->setAttr(name, layout.dropSgLayoutAndData());
if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op)) {
if (auto newLayout = layout.dropSgLayoutAndData())
op->setAttr(name, newLayout);
}
}
}
});

View File

@@ -1,6 +1,25 @@
// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
gpu.module @test_elementwise_ops {
// CHECK-LABEL: unary_ops_sg_layout_only
gpu.func @unary_ops_sg_layout_only(%a: memref<24x32xf32>) {
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8]>>
%load_a = xegpu.load_nd %tdesc_a
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8]>>
-> vector<24x32xf32>
// CHECK: math.exp {{.*}} : vector<12x8xf32>
%exp = math.exp %load_a
{layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8]>}
: vector<24x32xf32>
// CHECK: arith.negf {{.*}} : vector<12x8xf32>
%negf = arith.negf %load_a
{layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8]>}
: vector<24x32xf32>
gpu.return
}
// CHECK-LABEL: unary_ops
gpu.func @unary_ops(%a: memref<24x32xf32>) {
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>