[mlir][nvgpu] Add a nvgpu.rewrite_copy_as_tma transform operation.
This revision adds support for directly lowering a linalg.copy on buffers between global and shared memory to a TMA async load plus the required synchronization operations. It uses the recently introduced Hopper NVVM and NVGPU abstractions to connect things end to end.

Differential Revision: https://reviews.llvm.org/D157087
committed by Nicolas Vasilache
parent b6d994de0f
commit a3cd2eeb2d
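In broad strokes, the rewrite is driven from the transform dialect: each matched linalg.copy from global to workgroup (shared) memory is replaced by a TMA descriptor created on the host before the gpu.launch, plus mbarrier and TMA async-load ops inside it. A minimal usage sketch, distilled from the lit tests added below (buffer names and shapes are illustrative, not part of the API):

// Payload: a copy into shared memory, nested inside a gpu.launch.
%out = memref.get_global @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
linalg.copy ins(%memref : memref<64x8xf32>)
            outs(%out : memref<64x8xf32, #gpu.address_space<workgroup>>)

// Transform script that triggers the rewrite.
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
  %copy = transform.structured.match ops{["linalg.copy"]} in %arg1
    : (!transform.any_op) -> !transform.any_op
  transform.nvgpu.rewrite_copy_as_tma %copy : (!transform.any_op) -> ()
}

// Rough shape of the result: nvgpu.tma.create.descriptor (host side, before the
// launch), nvgpu.mbarrier.create/init, an scf.if on thread 0 that issues
// nvgpu.tma.async.load and nvgpu.mbarrier.arrive.expect_tx, and a trailing
// nvgpu.mbarrier.try_wait.parity that all threads spin on.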
@@ -164,4 +164,33 @@ def RewriteMatmulAsMmaSyncOp :
  }];
}

//===----------------------------------------------------------------------===//
// RewriteCopyAsTmaOp
//===----------------------------------------------------------------------===//

def RewriteCopyAsTmaOp :
  Op<Transform_Dialect, "nvgpu.rewrite_copy_as_tma",
    [FunctionalStyleTransformOpTrait,
     MemoryEffectsOpInterface,
     TransformEachOpTrait,
     TransformOpInterface,
     ReportTrackingListenerFailuresOpTrait]> {
  let description = [{
    Rewrite a copy operation on memref to tma operations that transit through
    shared memory.
  }];

  let arguments = (ins TransformHandleTypeInterface:$target);
  let results = (outs);

  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";

  let extraClassDeclaration = [{
    ::mlir::DiagnosedSilenceableFailure apply(
        ::mlir::transform::TransformRewriter &rewriter,
        ::mlir::transform::TransformResults &transformResults,
        ::mlir::transform::TransformState &state);
  }];
}

#endif // NVGPU_TRANSFORM_OPS

@@ -13,6 +13,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
@@ -20,20 +21,17 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::nvgpu;
using namespace mlir::NVVM;
using namespace mlir::transform;

#define DEBUG_TYPE "nvgpu-transforms"

@@ -517,7 +515,7 @@ private:
  /// Build a list of memref.load operations indexed at `(row, col)` indices
  /// that make sense for a particular MMA instruction and specified via the
  /// IndexCalculator callback.
-  SmallVector<Value> buildMemrefLoads(OpBuilder &b, Location loc,
+  SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,
                                      OpFoldResult laneId, Value memref,
                                      IndexCalculator indexFn);

@@ -527,7 +525,7 @@
  /// data that makes sense for the particular MMA operation.
  /// The `vectorShape` matches existing NVGPU dialect op specification but
  /// could also be flattened in the future if needed for simplification.
-  Value buildMmaSyncMemrefLoadOperand(OpBuilder &b, Location loc,
+  Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,
                                      OpFoldResult laneId, Value memref,
                                      IndexCalculator indexFn,
                                      ArrayRef<int64_t> vectorShape);

@@ -535,7 +533,7 @@
  /// Build a list of memref.store operations indexed at `(row, col)` indices
  /// that make sense for a particular MMA instruction and specified via the
  /// IndexCalculator callback.
-  SmallVector<Operation *> buildMemrefStores(OpBuilder &b, Location loc,
+  SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,
                                             ValueRange toStore,
                                             OpFoldResult laneId, Value memref,
                                             IndexCalculator indexFn);

@@ -546,7 +544,7 @@
  /// data that makes sense for the particular MMA operation.
  /// The `vectorShape` matches existing NVGPU dialect op specification but
  /// could also be flattened in the future if needed for simplification.
-  SmallVector<Operation *> buildMmaSyncMemrefStoreOperand(
+  SmallVector<Operation *> buildMmaSyncMemRefStoreOperand(
      OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
      Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);

@@ -573,7 +571,7 @@ static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
  }
}

-SmallVector<Value> MmaSyncBuilder::buildMemrefLoads(OpBuilder &b, Location loc,
+SmallVector<Value> MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
                                                    OpFoldResult laneId,
                                                    Value memref,
                                                    IndexCalculator indexFn) {

@@ -591,10 +589,10 @@ SmallVector<Value> MmaSyncBuilder::buildMemrefLoads(OpBuilder &b, Location loc,
  return res;
}

-Value MmaSyncBuilder::buildMmaSyncMemrefLoadOperand(
+Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
    IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
-  auto loads = buildMemrefLoads(b, loc, laneId, memref, indexFn);
+  auto loads = buildMemRefLoads(b, loc, laneId, memref, indexFn);

  Type elementType = getElementTypeOrSelf(memref.getType());
  auto vt = VectorType::get(vectorShape, elementType);

@@ -614,7 +612,7 @@ Value MmaSyncBuilder::buildMmaSyncMemrefLoadOperand(
}

SmallVector<Operation *>
-MmaSyncBuilder::buildMemrefStores(OpBuilder &b, Location loc,
+MmaSyncBuilder::buildMemRefStores(OpBuilder &b, Location loc,
                                  ValueRange toStore, OpFoldResult laneId,
                                  Value memref, IndexCalculator indexFn) {
  auto aff = [&](AffineExpr e) {

@@ -632,7 +630,7 @@ MmaSyncBuilder::buildMemrefStores(OpBuilder &b, Location loc,
  return res;
}

-SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemrefStoreOperand(
+SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
    OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
    Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  SmallVector<Value> toStore;

@@ -647,7 +645,7 @@ SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemrefStoreOperand(
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        toStore.push_back(v);
      });
-  return buildMemrefStores(b, loc, toStore, laneId, memref, indexFn);
+  return buildMemRefStores(b, loc, toStore, laneId, memref, indexFn);
}

static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,

@@ -690,22 +688,22 @@ MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
}

FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
-  Value lhsMemref = linalgOp.getDpsInputOperand(0)->get();
-  Value rhsMemref = linalgOp.getDpsInputOperand(1)->get();
-  Value resMemref = linalgOp.getDpsInitOperand(0)->get();
-  assert(lhsMemref.getType().cast<MemRefType>().getRank() == 2 &&
+  Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
+  Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
+  Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
+  assert(lhsMemRef.getType().cast<MemRefType>().getRank() == 2 &&
         "expected lhs to be a 2D memref");
-  assert(rhsMemref.getType().cast<MemRefType>().getRank() == 2 &&
+  assert(rhsMemRef.getType().cast<MemRefType>().getRank() == 2 &&
         "expected rhs to be a 2D memref");
-  assert(resMemref.getType().cast<MemRefType>().getRank() == 2 &&
+  assert(resMemRef.getType().cast<MemRefType>().getRank() == 2 &&
         "expected res to be a 2D memref");

-  int64_t m = cast<MemRefType>(lhsMemref.getType()).getShape()[0];
-  int64_t n = cast<MemRefType>(rhsMemref.getType()).getShape()[1];
-  int64_t k = cast<MemRefType>(lhsMemref.getType()).getShape()[1];
-  Type lhsType = getElementTypeOrSelf(lhsMemref.getType());
-  Type rhsType = getElementTypeOrSelf(rhsMemref.getType());
-  Type resType = getElementTypeOrSelf(resMemref.getType());
+  int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
+  int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
+  int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];
+  Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());
+  Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());
+  Type resType = getElementTypeOrSelf(resMemRef.getType());

  FailureOr<MmaSyncInfo> maybeInfo =
      getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});

@@ -715,15 +713,15 @@ FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
  MmaSyncInfo info = *maybeInfo;
  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
-  Value lhs = buildMmaSyncMemrefLoadOperand(b, loc, laneId, lhsMemref,
+  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
                                            lhsIndexFn, lhsShape);
-  Value rhs = buildMmaSyncMemrefLoadOperand(b, loc, laneId, rhsMemref,
+  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
                                            rhsIndexFn, rhsShape);
-  Value res = buildMmaSyncMemrefLoadOperand(b, loc, laneId, resMemref,
+  Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
                                            resIndexFn, resShape);
  res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape,
                                   info.tf32Enabled);
-  buildMmaSyncMemrefStoreOperand(b, loc, res, laneId, resMemref, resIndexFn,
+  buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
                                 resShape);
  return res.getDefiningOp();
}

@@ -754,6 +752,284 @@ DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// Hopper builders.
//===----------------------------------------------------------------------===//

/// Helper to create the base Hopper-specific operations that are reused in
/// various other places.
struct HopperBuilder {
  HopperBuilder(RewriterBase &rewriter, Location loc)
      : rewriter(rewriter), loc(loc) {}

  TypedValue<nvgpu::MBarrierType>
  buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);

  /// Create tma descriptor op to initiate transfer from global to shared
  /// memory. This must be done before the launch op, on the host.
  TypedValue<nvgpu::TensorMapDescriptorType>
  buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                              gpu::LaunchOp launchOp);

  /// Build a tma load from global memory to shared memory using `barrier` to
  /// synchronize. Return the number of bytes that will be transferred.
  OpFoldResult
  buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
                    TypedValue<MemRefType> sharedMemref,
                    TypedValue<nvgpu::MBarrierType> barrier,
                    SmallVectorImpl<Operation *> &loadOps);
  void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierType> barrier,
                            ArrayRef<OpFoldResult> sizes);

  /// If threadIdx.x == 0 does TMA request + wait, else just wait.
  /// Return the operation that performs the transfer on thread0.
  // TODO: In the future, don't hardcode to thread 0 but elect a leader.
  SmallVector<Operation *> buildPredicateLoadsOnThread0(
      ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
      ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
      TypedValue<nvgpu::MBarrierType> barrier);

  void buildTryWaitParity(TypedValue<nvgpu::MBarrierType> barrier);

  RewriterBase &rewriter;
  Location loc;
};

SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
    ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
    ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
    TypedValue<nvgpu::MBarrierType> barrier) {
  SmallVector<Operation *> loadOps;
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value tidx = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
  Value cond =
      rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, tidx, zero);
  // clang-format off
  rewriter.create<scf::IfOp>(
    /*location=*/loc,
    /*conditional=*/cond,
    /*thenBuilder=*/
    [&](OpBuilder &lb, Location loc) {
      SmallVector<OpFoldResult> sizes;
      sizes.reserve(globalDescriptors.size());
      for (auto [desc, shmem] : llvm::zip_equal(
              globalDescriptors, sharedMemBuffers)) {
        OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
        sizes.push_back(sz);
      }
      // TODO: Note that cutlass predeclares the barrier arrive tx before the tma.async.load.
      // This may or may not have perf implications.
      buildBarrierArriveTx(barrier, sizes);
      rewriter.create<scf::YieldOp>(loc);
    },
    /*elseBuilder=*/
    [&](OpBuilder &lb, Location loc) {
      // TODO: is this for no-thread divergence?
      // Should we just yield the size and hoist?
      buildBarrierArriveTx(barrier, getAsIndexOpFoldResult(rewriter.getContext(), 0));
      rewriter.create<scf::YieldOp>(loc);
    });
  // clang-format on
  return loadOps;
}

static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
  return gpu::AddressSpaceAttr::get(
      b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
  // return b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace));
}

TypedValue<nvgpu::MBarrierType>
HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
  Value barrier = rewriter.create<nvgpu::MBarrierCreateOp>(
      loc, nvgpu::MBarrierType::get(rewriter.getContext(), sharedMemorySpace));
  rewriter.create<nvgpu::MBarrierInitOp>(
      loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads));
  rewriter.create<gpu::BarrierOp>(loc);
  return cast<TypedValue<nvgpu::MBarrierType>>(barrier);
}

TypedValue<nvgpu::TensorMapDescriptorType>
HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                                           gpu::LaunchOp launchOp) {
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(launchOp);
  Value unrankedMemRef = rewriter.create<memref::CastOp>(
      loc,
      UnrankedMemRefType::get(memref.getType().getElementType(),
                              memref.getType().getMemorySpace()),
      memref);
  SmallVector<OpFoldResult> mixedSizes =
      memref::getMixedSizes(rewriter, loc, memref);
  SmallVector<Value> sizes =
      getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);

  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
  Value desc = rewriter.create<nvgpu::TmaCreateDescriptorOp>(
      loc,
      nvgpu::TensorMapDescriptorType::get(
          rewriter.getContext(),
          MemRefType::Builder(memref.getType())
              .setMemorySpace(sharedMemorySpace),
          TensorMapSwizzleKind::SWIZZLE_NONE,
          TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
          TensorMapInterleaveKind::INTERLEAVE_NONE),
      unrankedMemRef, sizes);
  return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
}

OpFoldResult HopperBuilder::buildTmaAsyncLoad(
    TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
    TypedValue<MemRefType> sharedMemref,
    TypedValue<nvgpu::MBarrierType> barrier,
    SmallVectorImpl<Operation *> &loadOps) {
  MLIRContext *ctx = rewriter.getContext();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Operation *loadOp = rewriter.create<nvgpu::TmaAsyncLoadOp>(
      loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero});
  loadOps.push_back(loadOp);
  auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr prodExprInBytes =
      computeProduct(ctx, symbols) *
      (sharedMemref.getType().getElementTypeBitWidth() / 8);
  auto res = affine::makeComposedFoldedAffineApply(rewriter, loc,
                                                   prodExprInBytes, mixedSizes);
  return res;
}

void HopperBuilder::buildBarrierArriveTx(
    TypedValue<nvgpu::MBarrierType> barrier,
    ArrayRef<OpFoldResult> mixedSizes) {
  assert(!mixedSizes.empty() && "expecte non-empty sizes");
  MLIRContext *ctx = rewriter.getContext();
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr sumExpr = computeSum(ctx, symbols);
  OpFoldResult size =
      affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
  Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
  rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal);
}

void HopperBuilder::buildTryWaitParity(
    TypedValue<nvgpu::MBarrierType> barrier) {
  Value parity = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  // 10M is an arbitrary, not too small or too big number to specify the number
  // of ticks before retry.
  // TODO: hoist this in a default dialect constant.
  Value ticksBeforeRetry =
      rewriter.create<arith::ConstantIndexOp>(loc, 10000000);
  rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity,
                                                  ticksBeforeRetry);
}

//===----------------------------------------------------------------------===//
// RewriteCopyAsTmaOp
//===----------------------------------------------------------------------===//

/// Helper to create the tma operations corresponding to `linalg::CopyOp`.
struct CopyBuilder : public HopperBuilder {
  CopyBuilder(RewriterBase &rewriter, Location loc)
      : HopperBuilder(rewriter, loc) {}

  SmallVector<Operation *> rewrite(ArrayRef<Operation *> copyOps);
};

SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
  MLIRContext *ctx = rewriter.getContext();
  if (copyOps.empty())
    return SmallVector<Operation *>();

  auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
  assert(launchOp && "expected launch op");

  // 1. Init a barrier object in shared memory.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(copyOps.front());
  AffineExpr bx, by, bz;
  bindSymbols(ctx, bx, by, bz);
  AffineExpr prod = computeProduct(ctx, ArrayRef<AffineExpr>{bx, by, bz});
  OpFoldResult numThreads = affine::makeComposedFoldedAffineApply(
      rewriter, loc, prod,
      ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
                             launchOp.getBlockSizeZ()});

  TypedValue<nvgpu::MBarrierType> barrier =
      buildAndInitBarrierInSharedMemory(numThreads);

  SmallVector<TypedValue<MemRefType>> shmems;
  SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs;
  for (Operation *op : copyOps) {
    auto copyOp = cast<linalg::CopyOp>(op);
    auto inMemRef =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
    MemRefType inMemRefType = inMemRef.getType();
    assert(inMemRefType.getRank() == 2 && "expected in to be a 2D memref");

    // 2. Build global memory descriptor.
    TypedValue<nvgpu::TensorMapDescriptorType> globalDesc =
        buildGlobalMemRefDescriptor(inMemRef, launchOp);
    globalDescs.push_back(globalDesc);

    // 3. Shared memory and descriptor for the tmp array.
    auto shmem =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
    shmems.push_back(shmem);
  }

  // 4. Load in from global memory to shared memory using tma.
  OpBuilder::InsertionGuard g2(rewriter);
  rewriter.setInsertionPoint(copyOps.front());
  SmallVector<Operation *> results =
      buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);

  // 5. Spin-loop until data is ready.
  buildTryWaitParity(barrier);

  // 6. Erase the ops that have now been rewritten.
  for (Operation *op : copyOps)
    rewriter.eraseOp(op);

  return results;
}

DiagnosedSilenceableFailure
transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
                                     transform::TransformResults &results,
                                     transform::TransformState &state) {
  auto payloadOps = state.getPayloadOps(getTarget());
  gpu::LaunchOp commonLaunchOp;
  Operation *firstOp, *failingOp;
  if (llvm::any_of(payloadOps, [&](Operation *op) {
        if (!commonLaunchOp) {
          commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
          firstOp = op;
        }
        auto fail = !op->getParentOfType<gpu::LaunchOp>() ||
                    commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||
                    !isa<linalg::CopyOp>(op);
        if (fail)
          failingOp = op;
        return fail;
      })) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "target ops must be linalg::CopyOp nested under a common "
           "gpu.LaunchOp to be rewritten because the tma descriptors need to "
           "be created on the host.\nBut got: "
        << *firstOp << "\nand " << *failingOp;
    return diag;
  }

  // TODO: more robust detection of copy, with transposes etc.
  CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps));

  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

@@ -767,6 +1043,7 @@ public:
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<affine::AffineDialect>();
    declareGeneratedDialect<nvgpu::NVGPUDialect>();
    declareGeneratedDialect<NVVM::NVVMDialect>();
    declareGeneratedDialect<vector::VectorDialect>();
    registerTransformOps<
#define GET_OP_LIST

@@ -161,13 +161,13 @@ AffineExpr mlir::computeSum(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
  if (basis.empty())
    return getAffineConstantExpr(0, ctx);
  return std::accumulate(basis.begin(), basis.end(),
-                         getAffineConstantExpr(1, ctx),
+                         getAffineConstantExpr(0, ctx),
                         std::plus<AffineExpr>());
}

AffineExpr mlir::computeProduct(MLIRContext *ctx, ArrayRef<AffineExpr> basis) {
  if (basis.empty())
-    return getAffineConstantExpr(0, ctx);
+    return getAffineConstantExpr(1, ctx);
  return std::accumulate(basis.begin(), basis.end(),
                         getAffineConstantExpr(1, ctx),
                         std::multiplies<AffineExpr>());

mlir/test/Dialect/NVGPU/tmaload-transform.mlir (new file, 84 lines)
@@ -0,0 +1,84 @@
// RUN: mlir-opt %s \
// RUN:   -test-transform-dialect-interpreter \
// RUN:   -test-transform-dialect-erase-schedule \
// RUN: | FileCheck %s

memref.global "private" @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
memref.global "private" @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>

// CHECK-LABEL: func.func @main()
func.func @main() {
  %c1 = arith.constant 1 : index
  %c128 = arith.constant 128 : index

  %0 = gpu.wait async
  %memref, %asyncToken = gpu.alloc async [%0] () : memref<64x8xf32>
  %memref_1, %asyncToken_2 = gpu.alloc async [%0] () : memref<8x128xf32>

  // CHECK: %[[M1:.*]] = memref.cast %{{.*}} : memref<64x8xf32> to memref<*xf32>
  // CHECK: %[[c64:.*]] = arith.constant 64 : index
  // CHECK: %[[c8:.*]] = arith.constant 8 : index
  // CHECK: %[[D1:.*]] = nvgpu.tma.create.descriptor %[[M1]] box[%[[c64]], %[[c8]]]
  // CHECK-SAME: : memref<*xf32> -> <tensor = memref<64x8xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
  // CHECK: %[[cast_2:.*]] = memref.cast %memref_0 : memref<8x128xf32> to memref<*xf32>
  // CHECK: %[[c8_2:.*]] = arith.constant 8 : index
  // CHECK: %[[c128_2:.*]] = arith.constant 128 : index
  // CHECK: %[[D2:.*]] = nvgpu.tma.create.descriptor %cast_2 box[%[[c8_2]], %[[c128_2]]]
  // CHECK-SAME: : memref<*xf32> -> <tensor = memref<8x128xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
  // CHECK: gpu.launch
  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
             threads(%tx, %ty, %tz) in (%block_x = %c128, %block_y = %c1, %block_z = %c1) {
    // CHECK: %[[G1:.*]] = memref.get_global @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
    // CHECK: %[[G2:.*]] = memref.get_global @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
    %out = memref.get_global @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
    %out_1 = memref.get_global @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>

    // CHECK: %[[B:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>
    // CHECK: nvgpu.mbarrier.init %[[B]], %{{.*}} : <memorySpace = #gpu.address_space<workgroup>
    // CHECK: gpu.barrier
    //
    // CHECK: %[[c0:.*]] = arith.constant 0 : index
    // CHECK: %[[TIDX:.*]] = gpu.thread_id x
    // CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[TIDX]], %[[c0]] : index
    //
    // CHECK: scf.if %[[CMP]] {
    //
    // CHECK: %[[c0_7:.*]] = arith.constant 0 : index
    // CHECK: nvgpu.tma.async.load %[[D1]][%[[c0_7]], %[[c0_7]]], %[[B]] to %[[G1]]
    // CHECK-SAME: : <tensor = memref<64x8xf32, #gpu.address_space<workgroup>>,
    // CHECK-SAME: swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>
    // CHECK-SAME: -> memref<64x8xf32, #gpu.address_space<workgroup>>
    //
    // CHECK: %[[c0_8:.*]] = arith.constant 0 : index
    // CHECK: nvgpu.tma.async.load %[[D2]][%[[c0_8]], %[[c0_8]]], %[[B]] to %[[G2]]
    // CHECK-SAME: : <tensor = memref<8x128xf32, #gpu.address_space<workgroup>>,
    // CHECK-SAME: swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>
    // CHECK-SAME: -> memref<8x128xf32, #gpu.address_space<workgroup>>
    //
    // CHECK: %[[c6144:.*]] = arith.constant 6144 : index
    // CHECK: nvgpu.mbarrier.arrive.expect_tx %[[B]], %[[c6144]] : <memorySpace = #gpu.address_space<workgroup>
    // CHECK: } else {
    // CHECK: %[[c0_7:.*]] = arith.constant 0 : index
    // CHECK: nvgpu.mbarrier.arrive.expect_tx %[[B]], %[[c0_7]] : <memorySpace = #gpu.address_space<workgroup>
    // CHECK: }
    //
    // CHECK: %[[c0_6:.*]] = arith.constant 0 : index
    // CHECK: %[[c10000000:.*]] = arith.constant 10000000 : index
    // CHECK: nvgpu.mbarrier.try_wait.parity %[[B]], %[[c0_6]], %[[c10000000]] : <memorySpace = #gpu.address_space<workgroup>

    /// Both copies are matched and end up in the same async group.
    linalg.copy ins(%memref: memref<64x8xf32>) outs(%out: memref<64x8xf32, #gpu.address_space<workgroup>>)
    linalg.copy ins(%memref_1: memref<8x128xf32>) outs(%out_1: memref<8x128xf32, #gpu.address_space<workgroup>>)

    gpu.terminator
  }

  return
}

transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
  %copy = transform.structured.match ops{["linalg.copy"]} in %arg1
    : (!transform.any_op) -> !transform.any_op
  transform.nvgpu.rewrite_copy_as_tma %copy : (!transform.any_op) -> ()
}

mlir/test/Integration/GPU/CUDA/sm90/tmaload-transform.mlir (new file, 109 lines)
@@ -0,0 +1,109 @@
// RUN: mlir-opt %s \
// RUN:   -test-transform-dialect-interpreter \
// RUN:   -test-transform-dialect-erase-schedule \
// RUN:   -convert-nvgpu-to-nvvm -gpu-kernel-outlining \
// RUN:   -convert-scf-to-cf -convert-nvvm-to-llvm \
// RUN:   -convert-vector-to-llvm \
// RUN:   -convert-math-to-llvm \
// RUN:   -expand-strided-metadata \
// RUN:   -lower-affine \
// RUN:   -convert-index-to-llvm=index-bitwidth=32 \
// RUN:   -convert-arith-to-llvm \
// RUN:   -finalize-memref-to-llvm \
// RUN:   -convert-func-to-llvm \
// RUN:   -canonicalize \
// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,convert-nvgpu-to-nvvm{use-opaque-pointers=1},lower-affine,convert-scf-to-cf,convert-vector-to-llvm,convert-math-to-llvm,expand-strided-metadata,lower-affine,convert-index-to-llvm{index-bitwidth=32},convert-arith-to-llvm,reconcile-unrealized-casts,gpu-to-cubin{chip=sm_90 features=+ptx80 dump-ptx}))' \
// RUN: 2&>1 | FileCheck %s --check-prefixes=CHECK-PTX

// CHECK-PTX: mbarrier.init.shared {{.*}} !llvm.ptr<3>, i32
/// If branch
// CHECK-PTX: cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes
// CHECK-PTX: cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes
// CHECK-PTX: mbarrier.arrive.expect_tx.shared
/// Else branch
// CHECK-PTX: mbarrier.arrive.expect_tx.shared
// CHECK-PTX: mbarrier.try_wait.parity.shared

// TODO: GPU layering does not currently work end-to-end. Activate the following
// when fixed.
// R-UN: | mlir-opt -convert-index-to-llvm=index-bitwidth=32 \
// R-UN:   -gpu-to-llvm \
// R-UN:   -convert-func-to-llvm \
// R-UN:   -cse \
// R-UN:   -canonicalize \
// R-UN:   -reconcile-unrealized-casts \
// R-UN: | mlir-cpu-runner \
// R-UN:   --shared-libs=%mlir_cuda_runtime \
// R-UN:   --shared-libs=%mlir_runner_utils \
// R-UN:   --entry-point-result=void \
// R-UN: | FileCheck %s

// C-HECK: [GPU] TMA BEFORE lhs[45][7] 0.000000
// C-HECK: [GPU] TMA BEFORE rhs[7][0] 0.000000
// C-HECK: [GPU] TMA LOADED lhs[45][7] 7.000000
// C-HECK: [GPU] TMA LOADED rhs[7][0] 3.000000

module @mymod {
  memref.global "private" @bufferLhsGlobal : memref<64x8xf32, 3>
  memref.global "private" @bufferRhsGlobal : memref<8x128xf32, 3>
  func.func @main() {
    %c10000000 = arith.constant 10000000 : index
    %c6144 = arith.constant 6144 : index
    %c45 = arith.constant 45 : index
    %c7 = arith.constant 7 : index
    %c64 = arith.constant 64 : index
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %c8 = arith.constant 8 : index
    %c128 = arith.constant 128 : index
    %cst = arith.constant 3.000000e+00 : f32
    %alloc = memref.alloc() : memref<64x8xf32>
    %alloc_0 = memref.alloc() : memref<8x128xf32>
    scf.for %arg0 = %c0 to %c8 step %c1 {
      scf.for %arg1 = %c0 to %c128 step %c1 {
        memref.store %cst, %alloc_0[%arg0, %arg1] : memref<8x128xf32>
      }
    }
    scf.for %arg0 = %c0 to %c64 step %c1 {
      scf.for %arg1 = %c0 to %c8 step %c1 {
        %5 = arith.index_cast %arg1 : index to i64
        %6 = arith.uitofp %5 : i64 to f32
        memref.store %6, %alloc[%arg0, %arg1] : memref<64x8xf32>
      }
    }
    %0 = gpu.wait async
    %memref, %asyncToken = gpu.alloc async [%0] () : memref<64x8xf32>
    %memref_1, %asyncToken_2 = gpu.alloc async [%0] () : memref<8x128xf32>
    %1 = gpu.memcpy async [%0] %memref, %alloc : memref<64x8xf32>, memref<64x8xf32>
    %2 = gpu.memcpy async [%0] %memref_1, %alloc_0 : memref<8x128xf32>, memref<8x128xf32>

    gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
               threads(%tx, %ty, %tz) in (%block_x = %c128, %block_y = %c1, %block_z = %c1) {
      %out = memref.get_global @bufferLhsGlobal : memref<64x8xf32, 3>
      %out_1 = memref.get_global @bufferRhsGlobal : memref<8x128xf32, 3>
      linalg.copy ins(%memref: memref<64x8xf32>) outs(%out: memref<64x8xf32, 3>)
      linalg.copy ins(%memref_1: memref<8x128xf32>) outs(%out_1: memref<8x128xf32, 3>)

      %6 = gpu.thread_id x
      %10 = arith.cmpi eq, %6, %c0 : index
      scf.if %10 {
        %11 = memref.load %out[%c45, %c7] : memref<64x8xf32, 3>
        %12 = memref.load %out_1[%c7, %c0] : memref<8x128xf32, 3>
        gpu.printf "[GPU] TMA LOADED lhs[45][7] %f\0A" %11 : f32
        gpu.printf "[GPU] TMA LOADED rhs[7][0] %f\0A" %12 : f32
      }
      gpu.terminator
    }

    return
  }
}

transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
  %copy = transform.structured.match ops{["linalg.copy"]} in %arg1
    : (!transform.any_op) -> !transform.any_op
  transform.nvgpu.rewrite_copy_as_tma %copy
    : (!transform.any_op) -> ()
}