//===- mlir-vulkan-runner.cpp - MLIR Vulkan Execution Driver --------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This is a command line utility that executes an MLIR file on the Vulkan by // translating MLIR GPU module to SPIR-V and host part to LLVM IR before // JIT-compiling and executing the latter. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h" #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/Passes.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/Passes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/ExecutionEngine/JitRunner.h" #include "mlir/ExecutionEngine/OptUtils.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Export.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/TargetSelect.h" using namespace mlir; static LogicalResult runMLIRPasses(ModuleOp module) { PassManager passManager(module.getContext()); applyPassManagerCLOptions(passManager); passManager.addPass(createGpuKernelOutliningPass()); passManager.addPass(memref::createFoldSubViewOpsPass()); passManager.addPass(createConvertGPUToSPIRVPass()); OpPassManager &modulePM = passManager.nest(); modulePM.addPass(spirv::createLowerABIAttributesPass()); modulePM.addPass(spirv::createUpdateVersionCapabilityExtensionPass()); passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass()); LowerToLLVMOptions llvmOptions(module.getContext(), DataLayout(module)); llvmOptions.emitCWrappers = true; passManager.addPass(createMemRefToLLVMPass()); passManager.addPass(createLowerToLLVMPass(llvmOptions)); passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass()); return passManager.run(module); } int main(int argc, char **argv) { llvm::llvm_shutdown_obj x; registerPassManagerCLOptions(); llvm::InitLLVM y(argc, argv); llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); mlir::initializeLLVMPasses(); mlir::JitRunnerConfig jitRunnerConfig; jitRunnerConfig.mlirTransformer = runMLIRPasses; mlir::DialectRegistry registry; registry.insert(); mlir::registerLLVMDialectTranslation(registry); return mlir::JitRunnerMain(argc, argv, registry, jitRunnerConfig); }