Files
clang-p2996/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
Guray Ozen 63389326f5 [mlir][nvvm] Support predicates in BasicPtxBuilder (#67102)
This PR enhances `BasicPtxBuilder` to support predicates in PTX code
generation. The `BasicPtxBuilder` interface was initially introduced for
generating PTX code automatically for Ops that aren't supported by LLVM
core. Predicates, which are typically not supported in LLVM core, are
now supported using the same mechanism.

In PTX programming, instructions can be guarded by predicates, as shown
below. Here `@p` is a predicate register that guards the execution of the
instruction.

```
@p ptx.code op1, op2, op3
```

This PR introduces the `getPredicate` function in the `BasicPtxBuilder`
interface to set an optional predicate. When a predicate is provided,
the instruction is generated with the predicate and is guarded by it;
otherwise, no predicate is generated. Note that the predicate value must
always appear as the last argument in the Op definition.

Additionally, this PR implements predicate usage for the following ops:

- mbarrier.init
- mbarrier.init.shared
- mbarrier.arrive.expect_tx
- mbarrier.arrive.expect_tx.shared
- cp.async.bulk.tensor.shared.cluster.global
- cp.async.bulk.tensor.global.shared.cta

See the PTX programming model documentation for more detail:

https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#ptx-instructions
2023-10-17 12:42:36 +02:00

121 lines
4.0 KiB
C++

//===- NVVMToLLVM.cpp - NVVM to LLVM dialect conversion -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a translation of NVVM ops which are not supported in
// LLVM core.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
// Name spliced into the DBGS() prefix below (presumably also the tag that
// LLVM_DEBUG filters on -- confirm against LLVM's debug facilities).
#define DEBUG_TYPE "nvvm-to-llvm"
// Writes the "[nvvm-to-llvm]: " prefix to llvm::dbgs() and yields the stream so
// that further output can be chained with <<.
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// Writes a bare newline (no prefix) to llvm::dbgs() and yields the stream.
#define DBGSNL() (llvm::dbgs() << "\n")
namespace mlir {
#define GEN_PASS_DEF_CONVERTNVVMTOLLVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
using namespace NVVM;
namespace {
/// Rewrite pattern matching every op that implements
/// `BasicPtxBuilderInterface`. Ops that report `hasIntrinsic()` are rejected;
/// all others are handed to a `PtxBuilder`, which receives the op's asm values
/// and then builds the replacement.
struct PtxLowering
    : public OpInterfaceRewritePattern<BasicPtxBuilderInterface> {
  using OpInterfaceRewritePattern<
      BasicPtxBuilderInterface>::OpInterfaceRewritePattern;

  /// Creates the pattern in `context`; `benefit` defaults to 2.
  PtxLowering(MLIRContext *context, PatternBenefit benefit = 2)
      : OpInterfaceRewritePattern(context, benefit) {}

  /// Returns failure (leaving `op` untouched) when `op.hasIntrinsic()` is
  /// true; otherwise replaces `op` through `PtxBuilder` and returns success.
  LogicalResult matchAndRewrite(BasicPtxBuilderInterface op,
                                PatternRewriter &rewriter) const override {
    if (op.hasIntrinsic()) {
      LLVM_DEBUG(DBGS() << "Ptx Builder does not lower \n\t" << op << "\n");
      return failure();
    }

    SmallVector<std::pair<Value, PTXRegisterMod>> asmValues;
    LLVM_DEBUG(DBGS() << op.getPtx() << "\n");
    PtxBuilder generator(op, rewriter);

    // Let the op enumerate its (value, register modifier) pairs and forward
    // each of them to the builder in the order the op produced them.
    op.getAsmValues(rewriter, asmValues);
    for (auto &[asmValue, modifier] : asmValues) {
      // Log the modifier's numeric value; streaming `&modifier` would only
      // print an address, which says nothing about the modifier itself.
      // NOTE(review): assumes PTXRegisterMod is an enumeration -- confirm.
      LLVM_DEBUG(DBGSNL() << asmValue << "\t Modifier : "
                          << static_cast<int>(modifier));
      generator.insertValue(asmValue, modifier);
    }

    generator.buildAndReplaceOp();
    return success();
  }
};
/// Pass that applies the NVVM-to-LLVM patterns as a partial conversion in
/// which the LLVM dialect is the only dialect explicitly declared legal.
struct ConvertNVVMToLLVMPass
    : public impl::ConvertNVVMToLLVMPassBase<ConvertNVVMToLLVMPass> {
  using Base::Base;

  /// Registers the LLVM and NVVM dialects as dependencies of this pass.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<LLVM::LLVMDialect, NVVM::NVVMDialect>();
  }

  /// Collects the conversion patterns, runs the partial conversion on the
  /// current operation, and signals pass failure if the conversion fails.
  void runOnOperation() override {
    MLIRContext &ctx = getContext();

    RewritePatternSet nvvmPatterns(&ctx);
    mlir::populateNVVMToLLVMConversionPatterns(nvvmPatterns);

    ConversionTarget legality(ctx);
    legality.addLegalDialect<LLVM::LLVMDialect>();

    const LogicalResult status = applyPartialConversion(
        getOperation(), legality, std::move(nvvmPatterns));
    if (failed(status))
      signalPassFailure();
  }
};
/// Implement the interface to convert NVVM to LLVM.
struct NVVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
void loadDependentDialects(MLIRContext *context) const final {
context->loadDialect<NVVMDialect>();
}
/// Hook for derived dialect interface to provide conversion patterns
/// and mark dialect legal for the conversion target.
void populateConvertToLLVMConversionPatterns(
ConversionTarget &target, LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns) const final {
populateNVVMToLLVMConversionPatterns(patterns);
}
};
} // namespace
/// Appends the `PtxLowering` pattern to `patterns`, constructing it in the
/// context that the pattern set already belongs to.
void mlir::populateNVVMToLLVMConversionPatterns(RewritePatternSet &patterns) {
  MLIRContext *owningContext = patterns.getContext();
  patterns.add<PtxLowering>(owningContext);
}
/// Adds an extension to `registry` that attaches
/// `NVVMToLLVMDialectInterface` to the NVVM dialect.
void mlir::registerConvertNVVMToLLVMInterface(DialectRegistry &registry) {
  // A plain function pointer (captureless lambda) is what gets registered.
  void (*attachInterface)(MLIRContext *, NVVMDialect *) =
      [](MLIRContext * /*ctx*/, NVVMDialect *nvvmDialect) {
        nvvmDialect->addInterfaces<NVVMToLLVMDialectInterface>();
      };
  registry.addExtension(attachInterface);
}