[MLIR][XeGPU] Add support for elementwise ops in Wg to Sg distribute pass [1/N] (#142797)
This PR adds support for lowering elementwise operations (unary and binary) from workgroup to subgroup distribution.
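At workgroup level an elementwise op produces one large vector; after distribution each subgroup computes an sg_data-sized tile, and when sg_layout * sg_data does not cover the whole shape, tiles are assigned round-robin, producing several subgroup-level ops per workgroup-level op. A minimal standalone sketch of that arithmetic (hypothetical helper with plain containers; the pass's real logic lives in getSgShapeAndCount, and this sketch assumes sg_layout * sg_data evenly divides the workgroup shape):

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical sketch: each subgroup's tile shape is sgData, and the number
// of round-robin tiles per subgroup is the product over dimensions of
// wgShape[i] / (sgLayout[i] * sgData[i]), assuming even divisibility.
static std::pair<std::vector<int64_t>, int64_t>
sgShapeAndCountSketch(const std::vector<int64_t> &wgShape,
                      const std::vector<int64_t> &sgLayout,
                      const std::vector<int64_t> &sgData) {
  int64_t count = 1;
  for (std::size_t i = 0; i < wgShape.size(); ++i)
    count *= wgShape[i] / (sgLayout[i] * sgData[i]);
  return {sgData, count};
}

// E.g. {24, 32} with sgLayout {2, 4}, sgData {12, 8} -> tile 12x8, count 1;
//      {24, 32} with sgLayout {4, 4}, sgData {2, 2}  -> tile 2x2,  count 12.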
@@ -8,10 +8,12 @@
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -19,6 +21,7 @@
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Transforms/DialectConversion.h"
#include <optional>

namespace mlir {
namespace xegpu {
@@ -328,6 +331,65 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
  }
};

// This pattern transforms elementwise ops to work at subgroup level.
struct WgToSgElementwiseOp : public ConversionPattern {
  WgToSgElementwiseOp(MLIRContext *ctx)
      : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only match ops with elementwise trait and single result.
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();

    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
    assert(resultType && "Expected result to be a VectorType");

    ArrayRef<int64_t> wgShape = resultType.getShape();

    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
    if (!layout || !layout.getSgLayout())
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;

    size_t numVariants = operands.empty() ? 0 : operands.front().size();

    if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
          return operandVec.size() != numVariants;
        }))
      return failure();

    SmallVector<Value> newResults;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    for (size_t i = 0; i < numVariants; ++i) {
      SmallVector<Value> opOperands;
      for (auto &operandVec : operands)
        opOperands.push_back(operandVec[i]);

      OperationState state(op->getLoc(), op->getName());
      state.addOperands(opOperands);
      state.addTypes(newResultType);
      // Copy all attributes, but update "layout_result_0" to drop
      // sgLayout/sgData.
      for (auto attr : op->getAttrs()) {
        if (auto layout = dyn_cast<xegpu::LayoutAttr>(attr.getValue()))
          state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
        else
          state.addAttribute(attr.getName(), attr.getValue());
      }
      Operation *newOp = rewriter.create(state);
      newResults.push_back(newOp->getResult(0));
    }

    rewriter.replaceOpWithMultiple(op, {newResults});
    return success();
  }
};

// Handles UnrealizedConversionCastOp generated during
// SCFStructuralTypeConversions (step 1). This op may appear as either a
// target or source materialization for Vector values, e.g.:
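Note the adaptor signature above: in a 1:N dialect conversion, each original operand is remapped to a ValueRange of converted values, so operands.front().size() is the number of per-subgroup variants, and the inner loop regroups the values variant by variant before rebuilding the op. A toy stand-in for that regrouping, using plain containers instead of MLIR types (illustrative only, not part of the patch):

#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative only: given one list of converted values per original operand,
// build, for each variant i, the operand list made of the i-th converted
// value of every operand -- the same transposition matchAndRewrite performs.
template <typename Value>
std::vector<std::vector<Value>>
regroupByVariant(const std::vector<std::vector<Value>> &operands) {
  std::size_t numVariants = operands.empty() ? 0 : operands.front().size();
  for (const auto &operandVec : operands)
    assert(operandVec.size() == numVariants && "ragged 1:N operand mapping");
  std::vector<std::vector<Value>> byVariant(numVariants);
  for (std::size_t i = 0; i < numVariants; ++i)
    for (const auto &operandVec : operands)
      byVariant[i].push_back(operandVec[i]);
  return byVariant;
}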
@@ -411,7 +473,8 @@ namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
  patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
-              UnrealizedConversionCastOpPattern>(patterns.getContext());
+              UnrealizedConversionCastOpPattern, WgToSgElementwiseOp>(
+      patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
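For context, these patterns are driven by the usual dialect-conversion flow. A condensed sketch of that driver shape (illustrative; the actual pass, excerpted below, also sets up SCF structural type conversions, materializations, and several legality rules):

// Illustrative driver sketch, not the pass's actual runOnOperation().
void runDistributionSketch(Operation *root, MLIRContext *ctx) {
  ConversionTarget target(*ctx);
  // ... dynamic legality rules, as in the excerpt below ...
  RewritePatternSet patterns(ctx);
  xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
  if (failed(applyPartialConversion(root, target, std::move(patterns))))
    llvm::errs() << "wg-to-sg distribution failed\n";
}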
@@ -518,6 +581,30 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
        return isLegal(layout);
      });

  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
      [=](Operation *op) -> std::optional<bool> {
        // Only handle elementwise mappable ops.
        if (!OpTrait::hasElementwiseMappableTraits(op))
          return true;

        VectorType resultType =
            dyn_cast<VectorType>(op->getResult(0).getType());
        if (!resultType)
          return true;

        // Check if all operands are vectors of the same shape.
        // TODO: Support other types.
        for (Value operand : op->getOperands()) {
          VectorType operandType = dyn_cast<VectorType>(operand.getType());
          if (!operandType || operandType.getShape() != resultType.getShape()) {
            return true;
          }
        }

        xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
        return isLegal(layout);
      });
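  // Note on the std::optional<bool> legality callback: returning true marks
  // an op legal (left untouched), false marks it illegal (must be rewritten),
  // and std::nullopt defers to other legality rules. The lambda above returns
  // true early for anything it does not handle, so only layout-carrying
  // elementwise ops on vectors are funneled to WgToSgElementwiseOp.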

  target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
      [=](UnrealizedConversionCastOp op) {
        return llvm::is_contained(existingCastOps, op.getOperation());
mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir (new file, 164 lines)
@@ -0,0 +1,164 @@
// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s

gpu.module @test_elementwise_ops {
  // CHECK-LABEL: unary_ops
  gpu.func @unary_ops(%a: memref<24x32xf32>) {
    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
    %load_a = xegpu.load_nd %tdesc_a
      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
      -> vector<24x32xf32>
    // CHECK: math.exp {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
    %exp = math.exp %load_a
      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
      : vector<24x32xf32>
    // CHECK: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
    %negf = arith.negf %load_a
      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
      : vector<24x32xf32>
    gpu.return
  }

  // CHECK-LABEL: binary_ops
  gpu.func @binary_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>) {
    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
    %load_a = xegpu.load_nd %tdesc_a
      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
      -> vector<24x32xf32>
    %load_b = xegpu.load_nd %tdesc_b
      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
      -> vector<24x32xf32>
    // CHECK: arith.addf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
    // CHECK-SAME: : vector<12x8xf32>
    %addf = arith.addf %load_a, %load_b
      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
      : vector<24x32xf32>
    // CHECK: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
    // CHECK-SAME: : vector<12x8xf32>
    %powf = math.powf %load_a, %load_b
      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
      : vector<24x32xf32>
    gpu.return
  }

  // CHECK-LABEL: ternary_ops
  gpu.func @ternary_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi1>) {
    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
    %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi1>
      -> !xegpu.tensor_desc<24x32xi1, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
    %load_a = xegpu.load_nd %tdesc_a
      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
      -> vector<24x32xf32>
    %load_b = xegpu.load_nd %tdesc_b
      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
      -> vector<24x32xf32>
    %load_c = xegpu.load_nd %tdesc_c
      : !xegpu.tensor_desc<24x32xi1, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
      -> vector<24x32xi1>
    // CHECK: arith.select {{.*}}, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
    // CHECK-SAME: : vector<12x8xi1>, vector<12x8xf32>
    %select = arith.select %load_c, %load_a, %load_b
      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
      : vector<24x32xi1>, vector<24x32xf32>
    // CHECK: math.fma {{.*}}, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
    // CHECK-SAME: : vector<12x8xf32>
    %fma = math.fma %load_a, %load_b, %load_a
      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
      : vector<24x32xf32>
    gpu.return
  }

  // CHECK-LABEL: type_conversion_ops
  gpu.func @type_conversion_ops(%a: memref<24x32xf32>, %b: memref<24x32xi32>) {
    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xi32>
      -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
    %load_a = xegpu.load_nd %tdesc_a
      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
      -> vector<24x32xf32>
    %load_b = xegpu.load_nd %tdesc_b
      : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
      -> vector<24x32xi32>
    // CHECK: arith.truncf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
    // CHECK-SAME: : vector<12x8xf32> to vector<12x8xf16>
    %truncf = arith.truncf %load_a
      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
      : vector<24x32xf32> to vector<24x32xf16>
    // CHECK: arith.bitcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
    // CHECK-SAME: : vector<12x8xi32> to vector<12x8xf32>
    %bitcast = arith.bitcast %load_b
      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
      : vector<24x32xi32> to vector<24x32xf32>
    gpu.return
  }

  // CHECK-LABEL: comparison_ops
  gpu.func @comparison_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
    %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32>
      -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
    %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32>
      -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
    %load_a = xegpu.load_nd %tdesc_a
      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
      -> vector<24x32xf32>
    %load_b = xegpu.load_nd %tdesc_b
      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
      -> vector<24x32xf32>
    %load_c = xegpu.load_nd %tdesc_c
      : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
      -> vector<24x32xi32>
    %load_d = xegpu.load_nd %tdesc_d
      : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
      -> vector<24x32xi32>
    // CHECK: arith.cmpf ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
    // CHECK-SAME: : vector<12x8xf32>
    %cmpf = arith.cmpf ult, %load_a, %load_b
      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
      : vector<24x32xf32>
    // CHECK: arith.cmpi eq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
    // CHECK-SAME: : vector<12x8xi32>
    %cmpi = arith.cmpi eq, %load_c, %load_d
      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
      : vector<24x32xi32>
    gpu.return
  }

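  // Note for the test below: with sg_layout = [4, 4] and sg_data = [2, 2],
  // one round of subgroups covers an 8x8 block of the 24x32 shape, so each
  // subgroup is assigned (24/8) * (32/8) = 12 tiles round-robin; every
  // workgroup-level op therefore becomes 12 subgroup-level ops
  // (CHECK-COUNT-12 below).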
  // 1 to N decomposition of elementwise operations
  // CHECK-LABEL: elementwise_ops_rr_assignment
  gpu.func @elementwise_ops_rr_assignment(%a: memref<24x32xf32>, %b: memref<24x32xf32>) {
    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
    %load_a = xegpu.load_nd %tdesc_a
      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
      -> vector<24x32xf32>
    %load_b = xegpu.load_nd %tdesc_b
      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
      -> vector<24x32xf32>
    // CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
    // CHECK-SAME-COUNT-12: : vector<2x2xf32>
    // CHECK-NOT: arith.negf
    %negf = arith.negf %load_a
      {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
      : vector<24x32xf32>
    // CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
    // CHECK-SAME-COUNT-12: : vector<2x2xf32>
    // CHECK-NOT: math.powf
    %powf = math.powf %load_a, %load_b
      {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
      : vector<24x32xf32>
    gpu.return
  }
}