[mlir][xegpu] Convert Vector contraction to XeGPU (#122115)
Adds a pattern to lower vector.contract to an XeGPU operation.
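For reference, a minimal before/after sketch of the new lowering (not part of the commit message; shapes follow the f16 test case added below, and the exact printed form of xegpu.dpas may differ):

// Before: canonical row-major matmul contraction with an add combining kind.
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
%res = vector.contract
  {indexing_maps = [#map, #map1, #map2],
   iterator_types = ["parallel", "parallel", "reduction"],
   kind = #vector.kind<add>} %lhs, %rhs, %acc
  : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>

// After -convert-vector-to-xegpu, roughly:
%res = xegpu.dpas %lhs, %rhs, %acc
  : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>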
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
@@ -312,6 +313,48 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  }
};

struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    Location loc = contractOp.getLoc();

    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects add combining kind");

    TypedValue<Type> acc = contractOp.getAcc();
    VectorType accType = dyn_cast<VectorType>(acc.getType());
    if (!accType || accType.getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");

    // Accept only plain 2D data layout.
    // VNNI packing is applied to DPAS as a separate lowering step.
    TypedValue<VectorType> lhs = contractOp.getLhs();
    TypedValue<VectorType> rhs = contractOp.getRhs();
    if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects lhs and rhs 2D vectors");

    if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
      return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");

    // TODO: Update shape validation to be target aware.
    auto accShape = accType.getShape();
    int64_t dimN = accShape[1];
    if (dimN != 8 && dimN != 16)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Invalid operand dimensions");

    auto dpasOp = rewriter.create<xegpu::DpasOp>(
        loc, TypeRange{contractOp.getResultType()}, ValueRange{lhs, rhs, acc});
    rewriter.replaceOp(contractOp, dpasOp);

    return success();
  }
};

struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
@@ -327,5 +370,5 @@ struct ConvertVectorToXeGPUPass
void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
-               StoreLowering>(patterns.getContext());
+               StoreLowering, ContractionLowering>(patterns.getContext());
}
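For context (not part of the diff), the isRowMajorMatmul check above only accepts the canonical row-major matmul indexing maps; a sketch of the accepted map set, matching the positive tests in the new file below:

// Accepted layout: C(m, n) += A(m, k) * B(k, n), with (d0, d1, d2) = (m, n, k).
#map  = affine_map<(d0, d1, d2) -> (d0, d2)>   // LHS reads A(m, k)
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>   // RHS reads B(k, n)
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>   // accumulator/result C(m, n)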
mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir (new file, 158 lines)
@@ -0,0 +1,158 @@
// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @dpas_gemm_f16(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
  %3 = vector.contract
    {indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "reduction"],
     kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>
  return %3 : vector<8x16xf32>
}

// CHECK-LABEL: @dpas_gemm_f16(
// CHECK-SAME:  %[[LHS:.+]]: vector<8x16xf16>,
// CHECK-SAME:  %[[RHS:.+]]: vector<16x16xf16>,
// CHECK-SAME:  %[[ACC:.+]]: vector<8x16xf32>
// CHECK:       %[[DPAS:.+]] = xegpu.dpas
// CHECK-SAME:  %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME:  {{.*}}-> vector<8x16xf32>
// CHECK:       return %[[DPAS]]

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @dpas_gemm_i8(%lhs: vector<8x32xi8>, %rhs: vector<32x16xi8>,
    %acc: vector<8x16xi32>) -> vector<8x16xi32> {
  %3 = vector.contract
    {indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "reduction"],
     kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x32xi8>, vector<32x16xi8> into vector<8x16xi32>
  return %3 : vector<8x16xi32>
}

// CHECK-LABEL: @dpas_gemm_i8(
// CHECK-SAME:  %[[LHS:.+]]: vector<8x32xi8>,
// CHECK-SAME:  %[[RHS:.+]]: vector<32x16xi8>,
// CHECK-SAME:  %[[ACC:.+]]: vector<8x16xi32>
// CHECK:       %[[DPAS:.+]] = xegpu.dpas
// CHECK-SAME:  %[[LHS]], %[[RHS]], %[[ACC]]
// CHECK-SAME:  {{.*}}-> vector<8x16xi32>
// CHECK:       return %[[DPAS]]

// -----

// For simplicity, only plain data layouts are currently supported.
// VNNI packing is applied later as a separate lowering step.

#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
func.func @negative_vnni_packed(%lhs: vector<8x8x2xf16>, %rhs: vector<8x16x2xf16>,
    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
  %3 = vector.contract
    {indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "reduction", "reduction"],
     kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x8x2xf16>, vector<8x16x2xf16> into vector<8x16xf32>
  return %3 : vector<8x16xf32>
}

// CHECK-LABEL: @negative_vnni_packed(
// CHECK:       vector.contract

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @negative_combining_kind(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
  %3 = vector.contract
    {indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "reduction"],
     kind = #vector.kind<mul>} %lhs, %rhs, %acc
    : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>
  return %3 : vector<8x16xf32>
}

// CHECK-LABEL: @negative_combining_kind(
// CHECK:       vector.contract

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> ()>
func.func @negative_accumulator_shape(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
    %acc: vector<f32>) -> vector<f32> {
  %3 = vector.contract
    {indexing_maps = [#map, #map1, #map2],
     iterator_types = ["reduction", "reduction", "reduction"],
     kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x16xf16>, vector<16x16xf16> into vector<f32>
  return %3 : vector<f32>
}

// CHECK-LABEL: @negative_accumulator_shape(
// CHECK:       vector.contract

// -----

#map = affine_map<(d0, d1, d2) -> (d2, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @negative_gemm_transpose_a(%lhs: vector<16x8xf16>, %rhs: vector<16x16xf16>,
    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
  %3 = vector.contract
    {indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "reduction"],
     kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<16x8xf16>, vector<16x16xf16> into vector<8x16xf32>
  return %3 : vector<8x16xf32>
}

// CHECK-LABEL: @negative_gemm_transpose_a(
// CHECK:       vector.contract

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @negative_gemm_transpose_b(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
  %3 = vector.contract
    {indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "reduction"],
     kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>
  return %3 : vector<8x16xf32>
}

// CHECK-LABEL: @negative_gemm_transpose_b(
// CHECK:       vector.contract

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @negative_n_dim_size(%lhs: vector<8x16xf16>, %rhs: vector<16x32xf16>,
    %acc: vector<8x32xf32>) -> vector<8x32xf32> {
  %3 = vector.contract
    {indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "reduction"],
     kind = #vector.kind<add>} %lhs, %rhs, %acc
    : vector<8x16xf16>, vector<16x32xf16> into vector<8x32xf32>
  return %3 : vector<8x32xf32>
}

// CHECK-LABEL: @negative_n_dim_size(
// CHECK:       vector.contract