//===- TestAvailability.cpp - Pass to test SPIR-V op availability ---------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" using namespace mlir; //===----------------------------------------------------------------------===// // Printing op availability pass //===----------------------------------------------------------------------===// namespace { /// A pass for testing SPIR-V op availability. struct PrintOpAvailability : public PassWrapper { void runOnFunction() override; StringRef getArgument() const final { return "test-spirv-op-availability"; } StringRef getDescription() const final { return "Test SPIR-V op availability"; } }; } // end anonymous namespace void PrintOpAvailability::runOnFunction() { auto f = getFunction(); llvm::outs() << f.getName() << "\n"; Dialect *spvDialect = getContext().getLoadedDialect("spv"); f->walk([&](Operation *op) { if (op->getDialect() != spvDialect) return WalkResult::advance(); auto opName = op->getName(); auto &os = llvm::outs(); if (auto minVersion = dyn_cast(op)) os << opName << " min version: " << spirv::stringifyVersion(minVersion.getMinVersion()) << "\n"; if (auto maxVersion = dyn_cast(op)) os << opName << " max version: " << spirv::stringifyVersion(maxVersion.getMaxVersion()) << "\n"; if (auto extension = dyn_cast(op)) { os << opName << " extensions: ["; for (const auto &exts : extension.getExtensions()) { os << " ["; llvm::interleaveComma(exts, os, [&](spirv::Extension ext) { os << spirv::stringifyExtension(ext); }); os << "]"; } os << " ]\n"; } if (auto capability = dyn_cast(op)) { os << opName << " capabilities: ["; for (const auto &caps : capability.getCapabilities()) { os << " ["; llvm::interleaveComma(caps, os, [&](spirv::Capability cap) { os << spirv::stringifyCapability(cap); }); os << "]"; } os << " ]\n"; } os.flush(); return WalkResult::advance(); }); } namespace mlir { void registerPrintOpAvailabilityPass() { PassRegistration(); } } // namespace mlir //===----------------------------------------------------------------------===// // Converting target environment pass //===----------------------------------------------------------------------===// namespace { /// A pass for testing SPIR-V op availability. struct ConvertToTargetEnv : public PassWrapper { StringRef getArgument() const override { return "test-spirv-target-env"; } StringRef getDescription() const override { return "Test SPIR-V target environment"; } void runOnFunction() override; }; struct ConvertToAtomCmpExchangeWeak : public RewritePattern { ConvertToAtomCmpExchangeWeak(MLIRContext *context); LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override; }; struct ConvertToBitReverse : public RewritePattern { ConvertToBitReverse(MLIRContext *context); LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override; }; struct ConvertToGroupNonUniformBallot : public RewritePattern { ConvertToGroupNonUniformBallot(MLIRContext *context); LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override; }; struct ConvertToModule : public RewritePattern { ConvertToModule(MLIRContext *context); LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override; }; struct ConvertToSubgroupBallot : public RewritePattern { ConvertToSubgroupBallot(MLIRContext *context); LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override; }; } // end anonymous namespace void ConvertToTargetEnv::runOnFunction() { MLIRContext *context = &getContext(); FuncOp fn = getFunction(); auto targetEnv = fn.getOperation() ->getAttr(spirv::getTargetEnvAttrName()) .cast(); if (!targetEnv) { fn.emitError("missing 'spv.target_env' attribute"); return signalPassFailure(); } auto target = SPIRVConversionTarget::get(targetEnv); RewritePatternSet patterns(context); patterns.add(context); if (failed(applyPartialConversion(fn, *target, std::move(patterns)))) return signalPassFailure(); } ConvertToAtomCmpExchangeWeak::ConvertToAtomCmpExchangeWeak(MLIRContext *context) : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op", 1, context, {"spv.AtomicCompareExchangeWeak"}) {} LogicalResult ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { Value ptr = op->getOperand(0); Value value = op->getOperand(1); Value comparator = op->getOperand(2); // Create a spv.AtomicCompareExchangeWeak op with AtomicCounterMemory bits in // memory semantics to additionally require AtomicStorage capability. rewriter.replaceOpWithNewOp( op, value.getType(), ptr, spirv::Scope::Workgroup, spirv::MemorySemantics::AcquireRelease | spirv::MemorySemantics::AtomicCounterMemory, spirv::MemorySemantics::Acquire, value, comparator); return success(); } ConvertToBitReverse::ConvertToBitReverse(MLIRContext *context) : RewritePattern("test.convert_to_bit_reverse_op", 1, context, {"spv.BitReverse"}) {} LogicalResult ConvertToBitReverse::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { Value predicate = op->getOperand(0); rewriter.replaceOpWithNewOp( op, op->getResult(0).getType(), predicate); return success(); } ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot( MLIRContext *context) : RewritePattern("test.convert_to_group_non_uniform_ballot_op", 1, context, {"spv.GroupNonUniformBallot"}) {} LogicalResult ConvertToGroupNonUniformBallot::matchAndRewrite( Operation *op, PatternRewriter &rewriter) const { Value predicate = op->getOperand(0); rewriter.replaceOpWithNewOp( op, op->getResult(0).getType(), spirv::Scope::Workgroup, predicate); return success(); } ConvertToModule::ConvertToModule(MLIRContext *context) : RewritePattern("test.convert_to_module_op", 1, context, {"spv.module"}) {} LogicalResult ConvertToModule::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { rewriter.replaceOpWithNewOp( op, spirv::AddressingModel::PhysicalStorageBuffer64, spirv::MemoryModel::Vulkan); return success(); } ConvertToSubgroupBallot::ConvertToSubgroupBallot(MLIRContext *context) : RewritePattern("test.convert_to_subgroup_ballot_op", 1, context, {"spv.SubgroupBallotKHR"}) {} LogicalResult ConvertToSubgroupBallot::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { Value predicate = op->getOperand(0); rewriter.replaceOpWithNewOp( op, op->getResult(0).getType(), predicate); return success(); } namespace mlir { void registerConvertToTargetEnvPass() { PassRegistration(); } } // namespace mlir