# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# This file contains the Sparsifier class.

from mlir import execution_engine
from mlir import ir
from mlir import passmanager
from typing import Sequence


class Sparsifier:
    """Sparsifier class for compiling and building MLIR modules.

    Holds a textual pass pipeline built around the `sparsifier` pass, plus the
    settings later handed to the JIT execution engine.
    """

    def __init__(
        self,
        extras: str,
        options: str,
        opt_level: int,
        shared_libs: Sequence[str],
    ):
        """Builds the pipeline string and records the JIT settings.

        Args:
            extras: Text spliced verbatim in front of `sparsifier` inside
                `builtin.module(...)`. No separator is inserted, so a
                non-empty value must end with its own separator (e.g. a
                trailing comma).
            options: Text spliced verbatim into `sparsifier{...}`, ahead of
                the fixed `reassociate-fp-reductions=1` and
                `enable-index-optimizations=1` options.
            opt_level: Optimization level passed to the execution engine.
            shared_libs: Shared library paths passed to the execution engine.
        """
        # Always-on sparsifier options are appended after the caller's options.
        pipeline = (
            f"builtin.module({extras}sparsifier{{{options} reassociate-fp-reductions=1"
            " enable-index-optimizations=1})"
        )
        self.pipeline = pipeline
        self.opt_level = opt_level
        self.shared_libs = shared_libs

    def __call__(self, module: ir.Module):
        """Convenience application method; compiles `module` in place."""
        self.compile(module)

    def compile(self, module: ir.Module):
        """Compiles the module in place by invoking the sparsifier pipeline.

        The pass manager is parsed in the module's own context. This means
        compilation does not depend on an enclosing `with ir.Context():`
        block. It also cannot accidentally pick up a different ambient
        context than the one owning `module`.
        """
        pm = passmanager.PassManager.parse(self.pipeline, context=module.context)
        pm.run(module.operation)

    def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
        """Wraps the module in a JIT execution engine.

        Does not compile the module first; use `compile_and_jit` for that.
        """
        return execution_engine.ExecutionEngine(
            module, opt_level=self.opt_level, shared_libs=self.shared_libs
        )

    def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
        """Compiles the module in place, then wraps it in a JIT engine."""
        self.compile(module)
        return self.jit(module)