clang-p2996/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
Aart Bik ee42e23614 [mlir][sparse][gpu] first implementation of the GPU libgen approach
The sparse compiler now has two prototype strategies for GPU acceleration:

* CUDA codegen: this converts sparsified code to CUDA threads
* CUDA libgen: this converts pre-sparsified code to cuSPARSE library calls

This revision introduces the first steps required for the second approach.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D150170
2023-05-15 08:49:38 -07:00


//===- SparseGPUCodegen.cpp - Generates GPU code --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is a prototype GPU code generator for the sparse compiler.
// The objective is to eventually use the right combination of
// direct code generation and library calls into vendor-specific
// highly optimized sparse libraries (e.g. cuSparse for CUDA).
//
//===----------------------------------------------------------------------===//
#include "CodegenUtils.h"
#include "LoopEmitter.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
using namespace mlir;
using namespace mlir::sparse_tensor;
namespace {
//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//
/// Marks the given top module as a GPU container module.
static void markAsGPUContainer(ModuleOp topModule) {
topModule->setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
UnitAttr::get(topModule->getContext()));
}
/// Constructs a new GPU module (for GPU kernels) inside the given top module,
/// or returns an existing GPU module if one was built previously.
static gpu::GPUModuleOp genGPUModule(OpBuilder &builder, ModuleOp topModule) {
for (auto op : topModule.getBodyRegion().getOps<gpu::GPUModuleOp>())
return op; // existing
markAsGPUContainer(topModule);
builder.setInsertionPointToStart(&topModule.getBodyRegion().front());
return builder.create<gpu::GPUModuleOp>(topModule->getLoc(),
"sparse_kernels");
}
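// For illustration only (a hedged sketch, not IR emitted verbatim by this
// helper): genGPUModule/genGPUFunc arrange the surrounding IR roughly as
//
//   module attributes {gpu.container_module} {
//     gpu.module @sparse_kernels {
//       gpu.func @kernel0(...) kernel { ... }
//     }
//     func.func @host_code(...) { ... }   // hypothetical host function
//   }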
/// Constructs a new GPU kernel in the given GPU module.
static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule,
SmallVectorImpl<Value> &args) {
// Get a unique kernel name. Not very creative,
// but we simply try kernel0, kernel1, etc.
unsigned kernelNumber = 0;
SmallString<16> kernelName;
do {
kernelName.clear();
("kernel" + Twine(kernelNumber++)).toStringRef(kernelName);
} while (gpuModule.lookupSymbol(kernelName));
// Then we insert a new kernel with given arguments into the module.
builder.setInsertionPointToStart(&gpuModule.getBodyRegion().front());
SmallVector<Type> argsTp;
for (unsigned i = 0, e = args.size(); i < e; i++)
argsTp.push_back(args[i].getType());
FunctionType type = FunctionType::get(gpuModule->getContext(), argsTp, {});
auto gpuFunc =
builder.create<gpu::GPUFuncOp>(gpuModule->getLoc(), kernelName, type);
gpuFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
builder.getUnitAttr());
return gpuFunc;
}
/// Constructs code to launch GPU kernel.
static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc,
SmallVectorImpl<Value> &args,
SmallVectorImpl<Value> &tokens,
unsigned numThreads) {
Location loc = gpuFunc->getLoc();
Value none = TypedValue<::mlir::IntegerType>{};
Value one = constantIndex(builder, loc, 1);
Value numT = constantIndex(builder, loc, numThreads);
gpu::KernelDim3 gridSize = {one, one, one};
gpu::KernelDim3 blckSize = {numT, one, one};
return builder
.create<gpu::LaunchFuncOp>(loc, gpuFunc, gridSize, blckSize,
/*dynSharedMemSz*/ none, args,
builder.getType<gpu::AsyncTokenType>(), tokens)
.getAsyncToken();
}
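// A hedged sketch of the launch built above (operand names are illustrative;
// the exact printed form depends on the gpu dialect version):
//
//   %t = gpu.launch_func async [%tokens] @sparse_kernels::@kernel0
//          blocks in (%c1, %c1, %c1) threads in (%numT, %c1, %c1)
//          args(%arg0 : index, %arg1 : memref<?xf64>, ...)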
/// Maps the provided ranked host buffer into the device address space.
/// Writes from the host are guaranteed to be visible to device kernels
/// that are launched afterwards. Writes from the device are guaranteed
/// to be visible on the host after synchronizing with the device kernel
/// completion. Needs to cast the buffer to an unranked buffer.
static Value genHostRegisterMemref(OpBuilder &builder, Location loc,
Value mem) {
MemRefType memTp = cast<MemRefType>(mem.getType());
UnrankedMemRefType resTp =
UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0);
Value cast = builder.create<memref::CastOp>(loc, resTp, mem);
builder.create<gpu::HostRegisterOp>(loc, cast);
return cast;
}
/// Unmaps the provided buffer, expecting the previously cast (unranked) buffer.
static void genHostUnregisterMemref(OpBuilder &builder, Location loc,
Value cast) {
builder.create<gpu::HostUnregisterOp>(loc, cast);
}
/// Generates first wait in an asynchronous chain.
static Value genFirstWait(OpBuilder &builder, Location loc) {
Type tokenType = builder.getType<gpu::AsyncTokenType>();
return builder.create<gpu::WaitOp>(loc, tokenType, ValueRange())
.getAsyncToken();
}
/// Generates last, blocking wait in an asynchronous chain.
static void genBlockingWait(OpBuilder &builder, Location loc,
ValueRange operands) {
builder.create<gpu::WaitOp>(loc, Type(), operands);
}
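// Usage sketch for the two wait forms above (illustrative only): an async
// chain starts with an async wait that yields a fresh token, and ends with
// a blocking wait over all outstanding tokens, e.g.
//
//   %t0 = gpu.wait async
//   ... ops consuming %t0 and producing %t1, %t2, ...
//   gpu.wait [%t1, %t2]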
/// Allocates memory on the device.
/// TODO: A `host_shared` attribute could be used to indicate that
/// the buffer is visible by both host and device, but lowering
/// that feature does not seem to be fully supported yet.
static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem,
Value token) {
auto tp = cast<ShapedType>(mem.getType());
auto elemTp = tp.getElementType();
auto shape = tp.getShape();
auto memTp = MemRefType::get(shape, elemTp);
SmallVector<Value> dynamicSizes;
for (unsigned r = 0, rank = tp.getRank(); r < rank; r++) {
if (shape[r] == ShapedType::kDynamic) {
Value dimOp = linalg::createOrFoldDimOp(builder, loc, mem, r);
dynamicSizes.push_back(dimOp);
}
}
return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
token, dynamicSizes, ValueRange());
}
/// Allocates a void buffer on the device with the given size.
static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size,
Value token) {
const auto memTp =
MemRefType::get({ShapedType::kDynamic}, builder.getI8Type());
return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
token, size, ValueRange());
}
/// Deallocates memory from the device.
static Value genDeallocMemRef(OpBuilder &builder, Location loc, Value mem,
Value token) {
return builder.create<gpu::DeallocOp>(loc, token.getType(), token, mem)
.getAsyncToken();
}
/// Copies memory between host and device (direction is implicit).
static Value genCopyMemRef(OpBuilder &builder, Location loc, Value dst,
Value src, Value token) {
return builder.create<gpu::MemcpyOp>(loc, token.getType(), token, dst, src)
.getAsyncToken();
}
/// Generates an alloc/copy pair.
static Value genAllocCopy(OpBuilder &builder, Location loc, Value b,
SmallVectorImpl<Value> &tokens) {
Value firstToken = genFirstWait(builder, loc);
auto alloc = genAllocMemRef(builder, loc, b, firstToken);
Value devMem = alloc.getResult(0);
Value depToken = alloc.getAsyncToken(); // copy-after-alloc
tokens.push_back(genCopyMemRef(builder, loc, devMem, b, depToken));
return devMem;
}
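// The pair generated above forms a small async chain, roughly (types and
// names are illustrative):
//
//   %t0       = gpu.wait async
//   %dev, %t1 = gpu.alloc  async [%t0] (%dynSizes) : memref<?xf64>
//   %t2       = gpu.memcpy async [%t1] %dev, %host : memref<?xf64>, memref<?xf64>
//
// so the copy waits on the allocation, and %t2 is appended to `tokens` for
// a later blocking join.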
/// Generates a memref from the given tensor via a bufferization.to_memref op.
static Value genTensorToMemref(PatternRewriter &rewriter, Location loc,
Value tensor) {
auto tensorType = tensor.getType().cast<ShapedType>();
auto memrefType =
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor);
}
/// Prepares the outlined arguments, passing scalars and buffers in. Here we
/// assume that the first buffer is the one allocated for output. We create
/// a set of properly chained asynchronous allocation/copy pairs to increase
/// overlap before launching the kernel.
/// TODO: the output assumption may be a bit too brittle
static Value genParametersIn(OpBuilder &builder, Location loc,
SmallVectorImpl<Value> &scalars,
SmallVectorImpl<Value> &buffers,
SmallVectorImpl<Value> &args,
SmallVectorImpl<Value> &tokens,
bool useHostRegistrationForOut) {
Value out;
// Scalars are passed by value.
for (Value s : scalars)
args.push_back(s);
// Buffers need to be made visible on the device.
for (Value b : buffers) {
if (useHostRegistrationForOut) {
out = genHostRegisterMemref(builder, loc, b);
args.push_back(b);
useHostRegistrationForOut = false;
continue;
}
args.push_back(genAllocCopy(builder, loc, b, tokens));
}
return out;
}
/// Finalizes the outlined arguments. The output buffer is copied depending
/// on the kernel token and then deallocated. All other buffers are simply
/// deallocated. Then we wait for all operations to complete.
static void genParametersOut(OpBuilder &builder, Location loc, Value out,
Value kernelToken, SmallVectorImpl<Value> &scalars,
SmallVectorImpl<Value> &buffers,
SmallVectorImpl<Value> &args,
SmallVectorImpl<Value> &tokens) {
unsigned base = scalars.size();
for (unsigned i = base, e = args.size(); i < e; i++) {
Value firstToken;
if (i == base) {
// Assumed output parameter: unregister or copy-out.
if (out) {
genHostUnregisterMemref(builder, loc, out);
out = Value();
continue;
}
firstToken =
genCopyMemRef(builder, loc, buffers[0], args[i], kernelToken);
} else {
firstToken = genFirstWait(builder, loc);
}
tokens.push_back(genDeallocMemRef(builder, loc, args[i], firstToken));
}
}
/// Constructs code for new GPU kernel.
static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
scf::ParallelOp forallOp,
SmallVectorImpl<Value> &constants,
SmallVectorImpl<Value> &scalars,
SmallVectorImpl<Value> &buffers) {
Location loc = gpuFunc->getLoc();
Block &block = gpuFunc.getBody().front();
rewriter.setInsertionPointToStart(&block);
// Re-generate the constants, recapture all arguments.
unsigned arg = 0;
IRMapping irMap;
for (Value c : constants)
irMap.map(c, rewriter.clone(*c.getDefiningOp())->getResult(0));
for (Value s : scalars)
irMap.map(s, block.getArgument(arg++));
for (Value b : buffers)
irMap.map(b, block.getArgument(arg++));
// Assume 1-dimensional grid/block configuration (only x dimension),
// so that:
// row = blockIdx.x * blockDim.x + threadIdx.x
// inc = blockDim.x * gridDim.x
Value bid = rewriter.create<gpu::BlockIdOp>(loc, gpu::Dimension::x);
Value bsz = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
Value tid = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
Value gsz = rewriter.create<gpu::GridDimOp>(loc, gpu::Dimension::x);
Value mul = rewriter.create<arith::MulIOp>(loc, bid, bsz);
Value row = rewriter.create<arith::AddIOp>(loc, mul, tid);
Value inc = rewriter.create<arith::MulIOp>(loc, bsz, gsz);
// Construct the iteration over the computational space that
// accounts for the fact that the total number of threads and
// the amount of work to be done usually do not match precisely.
// for (r = row; r < N; r += inc) {
// <loop-body>
// }
Value upper = irMap.lookup(forallOp.getUpperBound()[0]);
scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, row, upper, inc);
rewriter.cloneRegionBefore(forallOp.getLoopBody(), forOp.getLoopBody(),
forOp.getLoopBody().begin(), irMap);
// Done.
rewriter.setInsertionPointAfter(forOp);
rewriter.create<gpu::ReturnOp>(gpuFunc->getLoc());
}
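// As a hedged sketch, the kernel body constructed above is the usual
// grid-stride loop, roughly:
//
//   gpu.func @kernelN(%scalars ..., %buffers ...) kernel {
//     %bid = gpu.block_id x
//     %bsz = gpu.block_dim x
//     %tid = gpu.thread_id x
//     %gsz = gpu.grid_dim x
//     %mul = arith.muli %bid, %bsz : index
//     %row = arith.addi %mul, %tid : index
//     %inc = arith.muli %bsz, %gsz : index
//     scf.for %r = %row to %upper step %inc { <cloned loop body> }
//     gpu.return
//   }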
//===----------------------------------------------------------------------===//
// Library helper methods.
//===----------------------------------------------------------------------===//
/// Helper to detect a * b.
static bool matchMulOfArgs(linalg::GenericOp op, Value val) {
if (auto *def = val.getDefiningOp()) {
if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
Value a = op.getBlock()->getArguments()[0];
Value b = op.getBlock()->getArguments()[1];
return (def->getOperand(0) == a && def->getOperand(1) == b) ||
(def->getOperand(0) == b && def->getOperand(1) == a);
}
}
return false;
}
/// Helper to detect x = x + a * b
static bool matchSumOfMultOfArgs(linalg::GenericOp op) {
auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
Value x = op.getBlock()->getArguments()[2];
return (def->getOperand(0) == x &&
matchMulOfArgs(op, def->getOperand(1))) ||
(def->getOperand(1) == x &&
matchMulOfArgs(op, def->getOperand(0)));
}
}
return false;
}
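// For example, the following linalg.generic region is matched (block
// arguments %a, %b are the inputs and %x is the output; f64 chosen just
// for illustration):
//
//   ^bb0(%a: f64, %b: f64, %x: f64):
//     %0 = arith.mulf %a, %b : f64
//     %1 = arith.addf %x, %0 : f64
//     linalg.yield %1 : f64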
/// Test for sorted COO with suitable data and coordinate types.
static bool isAdmissibleCOO(SparseTensorType &aTp) {
return aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
aTp.isSingletonLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
(aTp.getElementType().isF64() || aTp.getElementType().isF32()) &&
(aTp.getCrdWidth() == 0 || aTp.getCrdWidth() == 32 ||
aTp.getCrdWidth() == 64);
}
/// Test for CSR with suitable data and coordinate types.
static bool isAdmissibleCSR(SparseTensorType &aTp) {
return aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) &&
aTp.isUniqueLvl(1) &&
(aTp.getElementType().isF64() || aTp.getElementType().isF32()) &&
(aTp.getCrdWidth() == 0 || aTp.getCrdWidth() == 32 ||
aTp.getCrdWidth() == 64);
}
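// As a rough illustration only (the sparse_tensor.encoding attribute syntax
// differs between MLIR versions), the admissible formats correspond to
// encodings along the lines of
//
//   #SortedCOO = #sparse_tensor.encoding<{
//     dimLevelType = [ "compressed-nu", "singleton" ]
//   }>
//   #CSR = #sparse_tensor.encoding<{
//     dimLevelType = [ "dense", "compressed" ]
//   }>
//
// with f32/f64 values and 0/32/64-bit coordinate widths.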
/// Generates the first positions/coordinates of a sparse matrix.
static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
bool isCOO, bool enableRT) {
if (isCOO) {
// Library uses SoA COO, direct IR uses AoS COO.
if (enableRT)
return genToCoordinates(builder, loc, a, 0, /*cooStart=*/0);
return genToCoordinatesBuffer(builder, loc, a);
}
// CSR uses positions.
return genToPositions(builder, loc, a, 1);
}
/// Generates the second coordinates of a sparse matrix.
static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
bool isCOO, bool enableRT) {
if (isCOO && !enableRT)
return Value(); // nothing needed
return genToCoordinates(builder, loc, a, 1, /*cooStart=*/0);
}
/// Generates the sparse matrix handle (COO or CSR).
static Operation *genSpMat(OpBuilder &builder, Location loc, Type handleTp,
Type tokenTp, Value token, Value szY, Value szX,
Value nnzA, Value rowA, Value colA, Value valA,
bool isCOO, bool enableRT) {
if (isCOO) {
// Library uses SoA COO, direct IR uses AoS COO.
if (enableRT)
return builder.create<gpu::CreateCooOp>(loc, handleTp, tokenTp, token,
szY, szX, nnzA, rowA, colA, valA);
llvm_unreachable("gpu::CreateCooAoSOp is deprecated");
}
return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, szY,
szX, nnzA, rowA, colA, valA);
}
/// Match and rewrite SpMV kernel.
static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
linalg::GenericOp op, bool enableRT) {
Location loc = op.getLoc();
Value a = op.getOperand(0);
Value x = op.getOperand(1);
Value y = op.getOperand(2); // we have y = Ax
SmallVector<Value> tokens;
// Only admissible sparse matrix format and dense vectors for now.
bool isCOO = false;
SparseTensorType aTp = getSparseTensorType(a);
SparseTensorType xTp = getSparseTensorType(x);
SparseTensorType yTp = getSparseTensorType(y);
if (xTp.hasEncoding() || yTp.hasEncoding())
return failure();
if (isAdmissibleCOO(aTp)) {
isCOO = true;
// TODO: CreateCooAoSOp was deprecated, find another way
if (!enableRT)
return failure();
} else if (isAdmissibleCSR(aTp)) {
isCOO = false;
} else {
return failure();
}
// Start sparse kernel and copy data from host to device.
// a : memR/memC/memV -> rowA,colA,valA
// x : memX -> vecX
// y : memY -> vecY
Value nnzA = rewriter.create<NumberOfEntriesOp>(loc, a);
Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
Value memR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
Value memC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
Value memV = genToValues(rewriter, loc, a);
Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
Value valA = genAllocCopy(rewriter, loc, memV, tokens);
Value memX = genTensorToMemref(rewriter, loc, x);
Value vecX = genAllocCopy(rewriter, loc, memX, tokens);
Value memY = genTensorToMemref(rewriter, loc, y);
Value vecY = genAllocCopy(rewriter, loc, memY, tokens);
genBlockingWait(rewriter, loc, tokens);
tokens.clear();
// Create sparse environment and sparse matrix/dense vector handles.
Type indexTp = rewriter.getIndexType();
Type handleTp = rewriter.getType<gpu::SparseHandleType>();
Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
Value token = genFirstWait(rewriter, loc);
auto env =
rewriter.create<gpu::CreateSparseEnvOp>(loc, handleTp, tokenTp, token);
Value handle = env.getResult(0);
token = env.getAsyncToken();
Operation *spGenA = genSpMat(rewriter, loc, handleTp, tokenTp, token, szY,
szX, nnzA, rowA, colA, valA, isCOO, enableRT);
Value spMatA = spGenA->getResult(0);
token = spGenA->getResult(1);
auto dvecX = rewriter.create<gpu::CreateDnVecOp>(loc, handleTp, tokenTp,
token, vecX, szX);
Value dnX = dvecX.getResult(0);
token = dvecX.getAsyncToken();
auto dvecY = rewriter.create<gpu::CreateDnVecOp>(loc, handleTp, tokenTp,
token, vecY, szY);
Value dnY = dvecY.getResult(0);
token = dvecY.getAsyncToken();
// Precompute the buffer size for SpMV.
auto bufferComp = rewriter.create<gpu::SpMVBufferSizeOp>(
loc, indexTp, tokenTp, token, handle, spMatA, dnX, dnY);
Value bufferSz = bufferComp.getResult(0);
token = bufferComp.getAsyncToken();
auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
Value buffer = buf.getResult(0);
token = buf.getAsyncToken();
// Perform the SpMV.
auto spmvComp = rewriter.create<gpu::SpMVOp>(loc, tokenTp, token, handle,
spMatA, dnX, dnY, buffer);
token = spmvComp.getAsyncToken();
// Copy data back to host and free all the resources.
token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
.getAsyncToken();
token = rewriter.create<gpu::DestroyDnVecOp>(loc, tokenTp, token, dnX)
.getAsyncToken();
token = rewriter.create<gpu::DestroyDnVecOp>(loc, tokenTp, token, dnY)
.getAsyncToken();
token = rewriter.create<gpu::DestroySparseEnvOp>(loc, tokenTp, token, handle)
.getAsyncToken();
tokens.push_back(token);
genBlockingWait(rewriter, loc, tokens);
tokens.clear();
token = genFirstWait(rewriter, loc);
token = genCopyMemRef(rewriter, loc, memY, vecY, token);
token = genDeallocMemRef(rewriter, loc, rowA, token);
if (colA)
token = genDeallocMemRef(rewriter, loc, colA, token);
token = genDeallocMemRef(rewriter, loc, valA, token);
token = genDeallocMemRef(rewriter, loc, buffer, token);
token = genDeallocMemRef(rewriter, loc, vecX, token);
token = genDeallocMemRef(rewriter, loc, vecY, token);
tokens.push_back(token);
genBlockingWait(rewriter, loc, tokens);
tokens.clear();
// Done.
rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
return success();
}
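// In outline, the host code emitted by rewriteSpMV above has three phases,
// all chained through async tokens (a sketch, not exact IR):
//   (1) copy-in:  device alloc/copy of the positions/coordinates/values of A
//                 and of the dense vectors x and y, then a blocking wait;
//   (2) library:  create the sparse environment, sparse matrix and dense
//                 vector handles, query the SpMV buffer size, allocate the
//                 temporary buffer, run SpMV, destroy the handles, wait;
//   (3) copy-out: copy y back to the host, deallocate all device buffers,
//                 and a final blocking wait before replacing the linalg op.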
/// Match and rewrite SpMM kernel.
static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
linalg::GenericOp op, bool enableRT) {
return failure(); // TODO: implement
}
//===----------------------------------------------------------------------===//
// Rewriting rules for direct code generation.
//===----------------------------------------------------------------------===//
/// Proof-of-concept rewriter. This rule generates a GPU implementation
/// for each outermost forall loop generated by the sparse compiler.
/// TODO: right now this works with parallelization-strategy=dense-outer-loop,
/// but give this its own flag in the future.
struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
ForallRewriter(MLIRContext *context, unsigned nT)
: OpRewritePattern(context), numThreads(nT) {}
LogicalResult matchAndRewrite(scf::ParallelOp forallOp,
PatternRewriter &rewriter) const override {
// Reject inadmissible loop form.
// Essentially only accept a loop, generated by the sparse compiler,
// of the form
// forall (i = 0; i < N; i++)
// so that cyclic scheduling over the threads is easy.
if (!forallOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()) ||
forallOp.getNumReductions() != 0 || forallOp.getNumLoops() != 1 ||
!matchPattern(forallOp.getLowerBound()[0], m_Zero()) ||
!matchPattern(forallOp.getStep()[0], m_One()))
return failure();
// Collect every value that is computed outside the parallel loop.
SetVector<Value> invariants; // stable iteration!
forallOp->walk([&](Operation *op) {
// Collect all values of admissible ops.
for (OpOperand &o : op->getOpOperands()) {
Value val = o.get();
Block *block;
if (auto arg = dyn_cast<BlockArgument>(val))
block = arg.getOwner();
else
block = val.getDefiningOp()->getBlock();
if (!isNestedIn(block, forallOp))
invariants.insert(val);
}
});
// Outline the outside values as proper parameters. Fail when sharing
// a value between host and device is not straightforward.
SmallVector<Value> constants;
SmallVector<Value> scalars;
SmallVector<Value> buffers;
for (Value val : invariants) {
Type tp = val.getType();
if (val.getDefiningOp<arith::ConstantOp>())
constants.push_back(val);
else if (isa<FloatType>(tp) || tp.isIntOrIndex())
scalars.push_back(val);
else if (isa<MemRefType>(tp))
buffers.push_back(val);
else
return failure(); // don't know how to share
}
// Pass outlined non-constant values.
// TODO: Experiment with `useHostRegistrationForOut` to see if we want to
// keep the feature at all (either through a heuristic or compiler
// option for gpu codegen).
Location loc = forallOp->getLoc();
SmallVector<Value> args;
SmallVector<Value> tokens;
Value out = genParametersIn(rewriter, loc, scalars, buffers, args, tokens,
/*useHostRegistrationForOut=*/false);
// Set up GPU module and construct GPU function.
auto saveIp = rewriter.saveInsertionPoint();
ModuleOp topModule = forallOp->getParentOfType<ModuleOp>();
auto gpuModule = genGPUModule(rewriter, topModule);
auto gpuFunc = genGPUFunc(rewriter, gpuModule, args);
genGPUCode(rewriter, gpuFunc, forallOp, constants, scalars, buffers);
// Generate code that launches the kernel asynchronously, blocking on all
// open tokens and yielding a new token for the output.
// TODO: Passing in tokens to the kernel launch does not seem to be properly
// lowered to cubin yet, hence the current blocking wait.
rewriter.restoreInsertionPoint(saveIp);
genBlockingWait(rewriter, loc, tokens);
tokens.clear();
Value kernelToken =
genLaunchGPUFunc(rewriter, gpuFunc, args, tokens, numThreads);
// Finalize the outlined arguments.
genParametersOut(rewriter, loc, out, kernelToken, scalars, buffers, args,
tokens);
genBlockingWait(rewriter, loc, tokens);
rewriter.eraseOp(forallOp);
return success();
}
private:
// Helper method to see if the block appears in the given loop.
static bool isNestedIn(Block *block, scf::ParallelOp forallOp) {
for (Operation *o = block->getParentOp(); o; o = o->getParentOp()) {
if (o == forallOp)
return true;
}
return false;
}
unsigned numThreads;
};
//===----------------------------------------------------------------------===//
// Rewriting rules for library recognition and code generation.
//===----------------------------------------------------------------------===//
/// Proof-of-concept rewriter. This rule recognizes certain math kernels
/// and replaces these with corresponding calls into the sparse library.
struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
LinalgOpRewriter(MLIRContext *context, bool rt)
: OpRewritePattern(context), enableRT(rt) {}
LogicalResult matchAndRewrite(linalg::GenericOp op,
PatternRewriter &rewriter) const override {
if (op.getNumDpsInits() != 1)
return failure(); // reject multi-output
const unsigned numLoops = op.getNumLoops();
const unsigned numTensors = op->getNumOperands();
const auto iteratorTypes = op.getIteratorTypesArray();
SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
AffineExpr i, j, k;
bindDims(getContext(), i, j, k);
// TODO: more robust patterns, transposed versions, more kernels...
// Recognize a SpMV kernel.
if (numLoops == 2 && numTensors == 3 &&
linalg::isParallelIterator(iteratorTypes[0]) &&
linalg::isReductionIterator(iteratorTypes[1]) &&
maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
return rewriteSpMV(rewriter, op, enableRT);
}
// Recognize a SpMM kernel.
if (numLoops == 3 && numTensors == 3 &&
linalg::isParallelIterator(iteratorTypes[0]) &&
linalg::isParallelIterator(iteratorTypes[1]) &&
linalg::isReductionIterator(iteratorTypes[2]) &&
maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
return rewriteSpMM(rewriter, op, enableRT);
}
return failure();
}
private:
bool enableRT;
};
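// For illustration, an SpMV kernel recognized by the rewriter above looks
// roughly like this (types and the #CSR encoding are placeholders):
//
//   %0 = linalg.generic {
//          indexing_maps = [affine_map<(i, j) -> (i, j)>,
//                           affine_map<(i, j) -> (j)>,
//                           affine_map<(i, j) -> (i)>],
//          iterator_types = ["parallel", "reduction"]}
//        ins(%A, %x : tensor<?x?xf64, #CSR>, tensor<?xf64>)
//        outs(%y : tensor<?xf64>) {
//      ^bb0(%a: f64, %b: f64, %c: f64):
//        %m = arith.mulf %a, %b : f64
//        %s = arith.addf %c, %m : f64
//        linalg.yield %s : f64
//   } -> tensor<?xf64>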
} // namespace
//===----------------------------------------------------------------------===//
// Public method for populating GPU rewriting rules.
//
// Currently two sets of rewriting rules are made available. The first set
// implements direct code generation, currently by means of converting the
// outermost parallel loop into GPU threads. The second set implements
// library recognition of a set of sparse operations. Eventually, the right
// combination of these two approaches has to be found.
//===----------------------------------------------------------------------===//
void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
unsigned numThreads) {
patterns.add<ForallRewriter>(patterns.getContext(), numThreads);
}
void mlir::populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
bool enableRT) {
patterns.add<LinalgOpRewriter>(patterns.getContext(), enableRT);
}
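// Usage sketch (hypothetical pass body, not part of this file): a driver
// populates one of the two pattern sets and applies it greedily, e.g.
//
//   RewritePatternSet patterns(ctx);
//   populateSparseGPUCodegenPatterns(patterns, /*numThreads=*/1024);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));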