[AMDGPU] Adding AMDGPU dialect wrapper for ROCDL transpose loads. (#145395)

* 1-to-1 mapping wrapper op.
* Direct lowering from AMDGPU wrapper to ROCDL intrinsics.
This commit is contained in:
Alan Li
2025-06-25 22:58:14 -04:00
committed by GitHub
parent e2cc82b177
commit 3f3282cee8
7 changed files with 281 additions and 2 deletions

View File

@@ -898,6 +898,40 @@ def AMDGPU_GatherToLDSOp :
let hasVerifier = 1;
}
def AMDGPU_TransposeLoadOp :
AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>,
Arguments<(ins Arg<AnyMemRef, "buffer to transpose load from", [MemRead]>:$src, Variadic<Index>:$srcIndices)>,
Results<(outs AnyTypeOf<[AnyVectorOfNonZeroRank]>:$result)> {
let summary = "MLIR wrapper for CDNA Transpose Load instructions";
let description = [{
The `amdgpu.transpose_load` op is a wrapper around the `ds_read_tr` instructions.
The transpose load op represents a subgroup load from LDS memory,
where the subgroup of threads collectively reads a matrix from the source
memref, with each thread reading a vector of the matrix, and gets a transposed matrix
in as the result. That is, each thread reads a vector of the col-major matrix at different
indices, and the thread's read result is a vector of the corresponding row of the transposed
matrix.
This op is a direct wrapper around the ROCDL `ds_read_tr` family intrinsics. Please refer
to the CDNA4 ISA documentation for more details about its exact semantics.
Format example:
```
%0 = amdgpu.transpose_load %src[%srcIndices] : memref<128x256xf16> -> vector<4xf16>
```
Operands:
* `$src`: LDS memref to read from.
* `$srcIndices`: indices into `$src` to read from for this thread.
* `$result`: target register this transpose load instruction will write to.
Note: Lowering is only supported on gfx950 and up.
}];
let assemblyFormat = [{
$src `[` $srcIndices `]` attr-dict `:` type($src) `->` type($result)
}];
let hasVerifier = 1;
}
def AMDGPU_ScaledMFMAOp :
AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>,
Pure]>,

View File

@@ -1100,6 +1100,81 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
}
};
struct TransposeLoadOpLowering
: public ConvertOpToLLVMPattern<TransposeLoadOp> {
TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
Chipset chipset;
LogicalResult
matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (chipset != kGfx950)
return op.emitOpError("Non-gfx950 chipset not supported");
Location loc = op.getLoc();
auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
// Elements in subbyte memrefs are stored non-contiguously,
// reject if source is sub-byte memref. Use emulated memrefs instead.
size_t srcElementSize =
srcMemRefType.getElementType().getIntOrFloatBitWidth();
if (srcElementSize < 8)
return op.emitOpError("Expect source memref to have at least 8 bits "
"element size, got ")
<< srcElementSize;
auto resultType = cast<VectorType>(op.getResult().getType());
Value srcPtr =
getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
(adaptor.getSrcIndices()));
size_t numElements = resultType.getNumElements();
size_t elementTypeSize =
resultType.getElementType().getIntOrFloatBitWidth();
// ROCDL transpose load intrinsics return vectors of 32-bit integers, if
// the element size is smaller than 16 bits.
Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
rewriter.getIntegerType(32));
Type llvmResultType = typeConverter->convertType(resultType);
switch (elementTypeSize) {
case 4: {
assert(numElements == 16);
auto rocdlOp =
rewriter.create<ROCDL::ds_read_tr4_b64>(loc, rocdlResultType, srcPtr);
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
break;
}
case 6: {
assert(numElements == 16);
auto rocdlOp =
rewriter.create<ROCDL::ds_read_tr6_b96>(loc, rocdlResultType, srcPtr);
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
break;
}
case 8: {
assert(numElements == 8);
auto rocdlOp =
rewriter.create<ROCDL::ds_read_tr8_b64>(loc, rocdlResultType, srcPtr);
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
break;
}
case 16: {
assert(numElements == 4);
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
srcPtr);
break;
}
default:
return op.emitOpError("Unsupported element size for transpose load");
}
return success();
}
};
struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
@@ -1749,7 +1824,7 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
chipset);
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
TransposeLoadOpLowering>(converter, chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}

View File

@@ -24,6 +24,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include <limits>
@@ -524,6 +525,39 @@ LogicalResult GatherToLDSOp::verify() {
return success();
}
LogicalResult TransposeLoadOp::verify() {
MemRefType srcType = cast<MemRefType>(getSrc().getType());
if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
return emitOpError("source memory address space must be Workgroup");
auto transferType = cast<VectorType>(getType());
size_t numElements = transferType.getNumElements();
size_t elementTypeSize =
transferType.getElementType().getIntOrFloatBitWidth();
// ElementSize -> NumElements
const llvm::SmallDenseMap<size_t, size_t> KValidLoadSizeMap = {
{4, 16},
{6, 16},
{8, 8},
{16, 4},
};
auto validNumElems = KValidLoadSizeMap.find(elementTypeSize);
if (validNumElems == KValidLoadSizeMap.end()) {
return emitOpError("Unsupported element type size for transpose load: ")
<< elementTypeSize << " bits";
}
if (numElements != validNumElems->second) {
return emitOpError(
"Transferring type size mismatch: expected num of elements: ")
<< validNumElems->second;
}
return success();
}
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
#define GET_ATTRDEF_CLASSES

View File

@@ -0,0 +1,56 @@
// RUN: mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck %s
// RUN: not mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx945 2>&1 | FileCheck %s --check-prefix=CHECK-OLD
// CHECK-LABEL: func @transpose_load_to_rocdl_4xf16
func.func @transpose_load_to_rocdl_4xf16(%idx1 : index, %idx2 : index, %wgmem : memref<128x72xf16, 3>) -> vector<4xf16> {
// CHECK: rocdl.ds.read.tr16.b64
// CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x72xf16, 3> -> vector<4xf16>
return %0 : vector<4xf16>
}
// -----
// CHECK-LABEL: func @transpose_load_to_rocdl_8xi8
func.func @transpose_load_to_rocdl_8xi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x128xi8, 3>) -> vector<8xi8> {
// CHECK: %[[RES:.*]] = rocdl.ds.read.tr8.b64
// CHECK-SAME: -> vector<2xi32>
// CHECK-NEXT: llvm.bitcast %[[RES]] : vector<2xi32> to vector<8xi8>
// CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x128xi8, 3> -> vector<8xi8>
return %0 : vector<8xi8>
}
// -----
// CHECK-LABEL: func @transpose_load_to_rocdl_i4_memrefxi8
func.func @transpose_load_to_rocdl_i4_memrefxi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<16xi4> {
// CHECK: %[[RES:.*]] = rocdl.ds.read.tr4.b64
// CHECK-SAME: -> vector<2xi32>
// CHECK-NEXT: llvm.bitcast %[[RES]] : vector<2xi32> to vector<16xi4>
// CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<16xi4>
return %0 : vector<16xi4>
}
// -----
// CHECK-LABEL: func @transpose_load_to_rocdl_i6_memrefxi8
func.func @transpose_load_to_rocdl_i6_memrefxi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<16xi6> {
// CHECK: %[[RES:.*]] = rocdl.ds.read.tr6.b96
// CHECK-SAME: -> vector<3xi32>
// CHECK-NEXT: llvm.bitcast %[[RES]] : vector<3xi32> to vector<16xi6>
// CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<16xi6>
return %0 : vector<16xi6>
}
// -----
// CHECK-LABEL: func @transpose_load_to_rocdl_i16_memrefxi8
func.func @transpose_load_to_rocdl_i16_memrefxi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<4xi16> {
// CHECK: rocdl.ds.read.tr16.b64
// CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<4xi16>
return %0 : vector<4xi16>
}

View File

@@ -0,0 +1,17 @@
// RUN: not mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx950 2>&1 | FileCheck %s
// -----
func.func @transpose_load_to_rocdl_16xi4(%idx1 : index, %idx2 : index, %wgmem : memref<128x16xi4, 3>) -> vector<16xi4> {
// CHECK: memref to have at least 8 bits element size, got 4
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x16xi4, 3> -> vector<16xi4>
return %0 : vector<16xi4>
}
// -----
func.func @transpose_load_to_rocdl_16xi6(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi6, 3>) -> vector<16xi6> {
// CHECK: memref to have at least 8 bits element size, got 6
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi6, 3> -> vector<16xi6>
return %0 : vector<16xi6>
}

View File

@@ -166,3 +166,59 @@ func.func @swizzle_scalable_vec(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<[4]xf32>
func.return %0 : vector<[4]xf32>
}
// -----
func.func @transpose_load_addrspace(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16, 1>) -> vector<4xf16> {
// expected-error@+1 {{'amdgpu.transpose_load' op source memory address space must be Workgroup}}
%0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf16, 1> -> vector<4xf16>
func.return %0 : vector<4xf16>
}
// -----
func.func @transpose_load_addrspace(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16, 1>) -> vector<4xf16> {
// expected-error@+1 {{'amdgpu.transpose_load' op source memory address space must be Workgroup}}
%0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf16, 1> -> vector<4xf16>
func.return %0 : vector<4xf16>
}
// -----
func.func @transpose_load_elem_f32(%idx1 : index, %idx2 : index, %mem : memref<128x32xf32, 3>) -> vector<4xf32> {
// expected-error@+1 {{'amdgpu.transpose_load' op Unsupported element type size for transpose load: 32 bits}}
%0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf32, 3> -> vector<4xf32>
func.return %0 : vector<4xf32>
}
// -----
func.func @transpose_load_vector_size_f16(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16, 3>) -> vector<2xf16> {
// expected-error@+1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 4}}
%0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf16, 3> -> vector<2xf16>
func.return %0 : vector<2xf16>
}
// -----
func.func @transpose_load_vector_size_i4(%idx1 : index, %idx2 : index, %mem : memref<128x32xi4, 3>) -> vector<20xi4> {
// expected-error@+1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 16}}
%0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xi4, 3> -> vector<20xi4>
func.return %0 : vector<20xi4>
}
// -----
func.func @transpose_load_vector_size_i8(%idx1 : index, %idx2 : index, %mem : memref<128x32xi8, 3>) -> vector<20xi8> {
// expected-error@+1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 8}}
%0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<20xi8>
func.return %0 : vector<20xi8>
}
// -----
func.func @transpose_load_vector_size_i8(%idx1 : index, %idx2 : index, %mem : memref<128x32xi6, 3>) -> vector<8xi6> {
// expected-error@+1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 16}}
%0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xi6, 3> -> vector<8xi6>
func.return %0 : vector<8xi6>
}

View File

@@ -486,3 +486,10 @@ func.func @scaled_mfma(%arg0 : f8E8M0FNU, %arg1 : vector<32xf6E2M3FN>, %arg2 : v
%0 = amdgpu.scaled_mfma(%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : f8E8M0FNU, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>
func.return %0 : vector<16xf32>
}
// CHECK-LABEL: func @transpose_load
func.func @transpose_load(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16, 3>) -> vector<4xf16> {
// CHECK: amdgpu.transpose_load
%0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf16, 3> -> vector<4xf16>
func.return %0 : vector<4xf16>
}