# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# This file contains the NvgpuCompiler class.

import errno  # NOTE(review): not referenced in this file; kept in case other code relies on it.
import os
import sys
from typing import Sequence

from mlir import execution_engine
from mlir import ir
from mlir import passmanager

# Put this script's own directory on the import path.
# NOTE(review): presumably so sibling helper modules in this directory can be
# imported by users of this file; confirm before removing.
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)


class NvgpuCompiler:
    """Lower MLIR modules with the NVVM GPU pipeline and wrap them for JIT execution.

    Args:
        options: Option string inserted verbatim between the braces of
            ``gpu-lower-to-nvvm-pipeline{...}``.
        opt_level: Optimization level forwarded to ``ExecutionEngine``.
        shared_libs: Shared-library paths forwarded to ``ExecutionEngine``.
    """

    def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
        # The doubled braces render as literal "{" and "}" around the options.
        self.pipeline = f"builtin.module(gpu-lower-to-nvvm-pipeline{{{options}}})"
        self.shared_libs = shared_libs
        self.opt_level = opt_level

    def __call__(self, module: ir.Module):
        """Shorthand for ``compile(module)``; returns nothing."""
        self.compile(module)

    def compile(self, module: ir.Module):
        """Run the configured pass pipeline over ``module``.

        The module's operation is handed straight to the pass manager and no
        result is returned, so the module is transformed in place.
        """
        pass_manager = passmanager.PassManager.parse(self.pipeline)
        pass_manager.run(module.operation)

    def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
        """Create an ``ExecutionEngine`` for ``module`` as it currently stands.

        This does not run the pass pipeline; call ``compile`` first (or use
        ``compile_and_jit``) when the module still needs lowering.
        """
        engine = execution_engine.ExecutionEngine(
            module,
            opt_level=self.opt_level,
            shared_libs=self.shared_libs,
        )
        return engine

    def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
        """Lower ``module`` in place, then return an ``ExecutionEngine`` for it."""
        self.compile(module)
        engine = self.jit(module)
        return engine