[AMDGPU] Adding AMDGPU dialect wrapper for ROCDL transpose loads. (#145395)
* 1-to-1 mapping wrapper op. * Direct lowering from AMDGPU wrapper to ROCDL intrinsics.
This commit is contained in:
@@ -898,6 +898,40 @@ def AMDGPU_GatherToLDSOp :
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def AMDGPU_TransposeLoadOp :
|
||||
AMDGPU_Op<"transpose_load", [SameVariadicOperandSize]>,
|
||||
Arguments<(ins Arg<AnyMemRef, "buffer to transpose load from", [MemRead]>:$src, Variadic<Index>:$srcIndices)>,
|
||||
Results<(outs AnyTypeOf<[AnyVectorOfNonZeroRank]>:$result)> {
|
||||
let summary = "MLIR wrapper for CDNA Transpose Load instructions";
|
||||
let description = [{
|
||||
The `amdgpu.transpose_load` op is a wrapper around the `ds_read_tr` instructions.
|
||||
The transpose load op represents a subgroup load from LDS memory,
|
||||
where the subgroup of threads collectively reads a matrix from the source
|
||||
memref, with each thread reading a vector of the matrix, and gets a transposed matrix
|
||||
in as the result. That is, each thread reads a vector of the col-major matrix at different
|
||||
indices, and the thread's read result is a vector of the corresponding row of the transposed
|
||||
matrix.
|
||||
|
||||
This op is a direct wrapper around the ROCDL `ds_read_tr` family intrinsics. Please refer
|
||||
to the CDNA4 ISA documentation for more details about its exact semantics.
|
||||
|
||||
Format example:
|
||||
```
|
||||
%0 = amdgpu.transpose_load %src[%srcIndices] : memref<128x256xf16> -> vector<4xf16>
|
||||
```
|
||||
Operands:
|
||||
* `$src`: LDS memref to read from.
|
||||
* `$srcIndices`: indices into `$src` to read from for this thread.
|
||||
* `$result`: target register this transpose load instruction will write to.
|
||||
|
||||
Note: Lowering is only supported on gfx950 and up.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
$src `[` $srcIndices `]` attr-dict `:` type($src) `->` type($result)
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def AMDGPU_ScaledMFMAOp :
|
||||
AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>,
|
||||
Pure]>,
|
||||
|
||||
@@ -1100,6 +1100,81 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
|
||||
}
|
||||
};
|
||||
|
||||
struct TransposeLoadOpLowering
|
||||
: public ConvertOpToLLVMPattern<TransposeLoadOp> {
|
||||
TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
|
||||
: ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
|
||||
|
||||
Chipset chipset;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (chipset != kGfx950)
|
||||
return op.emitOpError("Non-gfx950 chipset not supported");
|
||||
|
||||
Location loc = op.getLoc();
|
||||
auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
|
||||
|
||||
// Elements in subbyte memrefs are stored non-contiguously,
|
||||
// reject if source is sub-byte memref. Use emulated memrefs instead.
|
||||
size_t srcElementSize =
|
||||
srcMemRefType.getElementType().getIntOrFloatBitWidth();
|
||||
if (srcElementSize < 8)
|
||||
return op.emitOpError("Expect source memref to have at least 8 bits "
|
||||
"element size, got ")
|
||||
<< srcElementSize;
|
||||
|
||||
auto resultType = cast<VectorType>(op.getResult().getType());
|
||||
Value srcPtr =
|
||||
getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
|
||||
(adaptor.getSrcIndices()));
|
||||
|
||||
size_t numElements = resultType.getNumElements();
|
||||
size_t elementTypeSize =
|
||||
resultType.getElementType().getIntOrFloatBitWidth();
|
||||
|
||||
// ROCDL transpose load intrinsics return vectors of 32-bit integers, if
|
||||
// the element size is smaller than 16 bits.
|
||||
Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
|
||||
rewriter.getIntegerType(32));
|
||||
Type llvmResultType = typeConverter->convertType(resultType);
|
||||
|
||||
switch (elementTypeSize) {
|
||||
case 4: {
|
||||
assert(numElements == 16);
|
||||
auto rocdlOp =
|
||||
rewriter.create<ROCDL::ds_read_tr4_b64>(loc, rocdlResultType, srcPtr);
|
||||
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
|
||||
break;
|
||||
}
|
||||
case 6: {
|
||||
assert(numElements == 16);
|
||||
auto rocdlOp =
|
||||
rewriter.create<ROCDL::ds_read_tr6_b96>(loc, rocdlResultType, srcPtr);
|
||||
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
|
||||
break;
|
||||
}
|
||||
case 8: {
|
||||
assert(numElements == 8);
|
||||
auto rocdlOp =
|
||||
rewriter.create<ROCDL::ds_read_tr8_b64>(loc, rocdlResultType, srcPtr);
|
||||
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
|
||||
break;
|
||||
}
|
||||
case 16: {
|
||||
assert(numElements == 4);
|
||||
rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
|
||||
srcPtr);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return op.emitOpError("Unsupported element size for transpose load");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
|
||||
GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
|
||||
: ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
|
||||
@@ -1749,7 +1824,7 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
|
||||
MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
|
||||
ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
|
||||
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
|
||||
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
|
||||
chipset);
|
||||
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
|
||||
TransposeLoadOpLowering>(converter, chipset);
|
||||
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
#include <limits>
|
||||
@@ -524,6 +525,39 @@ LogicalResult GatherToLDSOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult TransposeLoadOp::verify() {
|
||||
MemRefType srcType = cast<MemRefType>(getSrc().getType());
|
||||
|
||||
if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
|
||||
return emitOpError("source memory address space must be Workgroup");
|
||||
|
||||
auto transferType = cast<VectorType>(getType());
|
||||
size_t numElements = transferType.getNumElements();
|
||||
size_t elementTypeSize =
|
||||
transferType.getElementType().getIntOrFloatBitWidth();
|
||||
|
||||
// ElementSize -> NumElements
|
||||
const llvm::SmallDenseMap<size_t, size_t> KValidLoadSizeMap = {
|
||||
{4, 16},
|
||||
{6, 16},
|
||||
{8, 8},
|
||||
{16, 4},
|
||||
};
|
||||
|
||||
auto validNumElems = KValidLoadSizeMap.find(elementTypeSize);
|
||||
if (validNumElems == KValidLoadSizeMap.end()) {
|
||||
return emitOpError("Unsupported element type size for transpose load: ")
|
||||
<< elementTypeSize << " bits";
|
||||
}
|
||||
if (numElements != validNumElems->second) {
|
||||
return emitOpError(
|
||||
"Transferring type size mismatch: expected num of elements: ")
|
||||
<< validNumElems->second;
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
|
||||
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
|
||||
56
mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
Normal file
56
mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir
Normal file
@@ -0,0 +1,56 @@
|
||||
// RUN: mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck %s
|
||||
// RUN: not mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx945 2>&1 | FileCheck %s --check-prefix=CHECK-OLD
|
||||
|
||||
// CHECK-LABEL: func @transpose_load_to_rocdl_4xf16
|
||||
func.func @transpose_load_to_rocdl_4xf16(%idx1 : index, %idx2 : index, %wgmem : memref<128x72xf16, 3>) -> vector<4xf16> {
|
||||
// CHECK: rocdl.ds.read.tr16.b64
|
||||
// CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
|
||||
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x72xf16, 3> -> vector<4xf16>
|
||||
return %0 : vector<4xf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @transpose_load_to_rocdl_8xi8
|
||||
func.func @transpose_load_to_rocdl_8xi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x128xi8, 3>) -> vector<8xi8> {
|
||||
// CHECK: %[[RES:.*]] = rocdl.ds.read.tr8.b64
|
||||
// CHECK-SAME: -> vector<2xi32>
|
||||
// CHECK-NEXT: llvm.bitcast %[[RES]] : vector<2xi32> to vector<8xi8>
|
||||
// CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
|
||||
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x128xi8, 3> -> vector<8xi8>
|
||||
return %0 : vector<8xi8>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @transpose_load_to_rocdl_i4_memrefxi8
|
||||
func.func @transpose_load_to_rocdl_i4_memrefxi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<16xi4> {
|
||||
// CHECK: %[[RES:.*]] = rocdl.ds.read.tr4.b64
|
||||
// CHECK-SAME: -> vector<2xi32>
|
||||
// CHECK-NEXT: llvm.bitcast %[[RES]] : vector<2xi32> to vector<16xi4>
|
||||
// CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
|
||||
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<16xi4>
|
||||
return %0 : vector<16xi4>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @transpose_load_to_rocdl_i6_memrefxi8
|
||||
func.func @transpose_load_to_rocdl_i6_memrefxi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<16xi6> {
|
||||
// CHECK: %[[RES:.*]] = rocdl.ds.read.tr6.b96
|
||||
// CHECK-SAME: -> vector<3xi32>
|
||||
// CHECK-NEXT: llvm.bitcast %[[RES]] : vector<3xi32> to vector<16xi6>
|
||||
// CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
|
||||
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<16xi6>
|
||||
return %0 : vector<16xi6>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @transpose_load_to_rocdl_i16_memrefxi8
|
||||
func.func @transpose_load_to_rocdl_i16_memrefxi8(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi8, 3>) -> vector<4xi16> {
|
||||
// CHECK: rocdl.ds.read.tr16.b64
|
||||
// CHECK-OLD: error: 'amdgpu.transpose_load' op Non-gfx950 chipset not supported
|
||||
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<4xi16>
|
||||
return %0 : vector<4xi16>
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
// RUN: not mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx950 2>&1 | FileCheck %s
|
||||
|
||||
// -----
|
||||
|
||||
func.func @transpose_load_to_rocdl_16xi4(%idx1 : index, %idx2 : index, %wgmem : memref<128x16xi4, 3>) -> vector<16xi4> {
|
||||
// CHECK: memref to have at least 8 bits element size, got 4
|
||||
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x16xi4, 3> -> vector<16xi4>
|
||||
return %0 : vector<16xi4>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @transpose_load_to_rocdl_16xi6(%idx1 : index, %idx2 : index, %wgmem : memref<128x32xi6, 3>) -> vector<16xi6> {
|
||||
// CHECK: memref to have at least 8 bits element size, got 6
|
||||
%0 = amdgpu.transpose_load %wgmem[%idx1, %idx2] : memref<128x32xi6, 3> -> vector<16xi6>
|
||||
return %0 : vector<16xi6>
|
||||
}
|
||||
@@ -166,3 +166,59 @@ func.func @swizzle_scalable_vec(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
|
||||
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<[4]xf32>
|
||||
func.return %0 : vector<[4]xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @transpose_load_addrspace(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16, 1>) -> vector<4xf16> {
|
||||
// expected-error@+1 {{'amdgpu.transpose_load' op source memory address space must be Workgroup}}
|
||||
%0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf16, 1> -> vector<4xf16>
|
||||
func.return %0 : vector<4xf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @transpose_load_addrspace(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16, 1>) -> vector<4xf16> {
|
||||
// expected-error@+1 {{'amdgpu.transpose_load' op source memory address space must be Workgroup}}
|
||||
%0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf16, 1> -> vector<4xf16>
|
||||
func.return %0 : vector<4xf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @transpose_load_elem_f32(%idx1 : index, %idx2 : index, %mem : memref<128x32xf32, 3>) -> vector<4xf32> {
|
||||
// expected-error@+1 {{'amdgpu.transpose_load' op Unsupported element type size for transpose load: 32 bits}}
|
||||
%0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf32, 3> -> vector<4xf32>
|
||||
func.return %0 : vector<4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @transpose_load_vector_size_f16(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16, 3>) -> vector<2xf16> {
|
||||
// expected-error@+1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 4}}
|
||||
%0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf16, 3> -> vector<2xf16>
|
||||
func.return %0 : vector<2xf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @transpose_load_vector_size_i4(%idx1 : index, %idx2 : index, %mem : memref<128x32xi4, 3>) -> vector<20xi4> {
|
||||
// expected-error@+1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 16}}
|
||||
%0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xi4, 3> -> vector<20xi4>
|
||||
func.return %0 : vector<20xi4>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @transpose_load_vector_size_i8(%idx1 : index, %idx2 : index, %mem : memref<128x32xi8, 3>) -> vector<20xi8> {
|
||||
// expected-error@+1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 8}}
|
||||
%0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xi8, 3> -> vector<20xi8>
|
||||
func.return %0 : vector<20xi8>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @transpose_load_vector_size_i8(%idx1 : index, %idx2 : index, %mem : memref<128x32xi6, 3>) -> vector<8xi6> {
|
||||
// expected-error@+1 {{'amdgpu.transpose_load' op Transferring type size mismatch: expected num of elements: 16}}
|
||||
%0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xi6, 3> -> vector<8xi6>
|
||||
func.return %0 : vector<8xi6>
|
||||
}
|
||||
|
||||
@@ -486,3 +486,10 @@ func.func @scaled_mfma(%arg0 : f8E8M0FNU, %arg1 : vector<32xf6E2M3FN>, %arg2 : v
|
||||
%0 = amdgpu.scaled_mfma(%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : f8E8M0FNU, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>
|
||||
func.return %0 : vector<16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @transpose_load
|
||||
func.func @transpose_load(%idx1 : index, %idx2 : index, %mem : memref<128x32xf16, 3>) -> vector<4xf16> {
|
||||
// CHECK: amdgpu.transpose_load
|
||||
%0 = amdgpu.transpose_load %mem[%idx1, %idx2] : memref<128x32xf16, 3> -> vector<4xf16>
|
||||
func.return %0 : vector<4xf16>
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user