//===- MlirOptMain.cpp - MLIR Optimizer Driver ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is a utility that runs an optimization pass and prints the result back
// out. It is designed to support unit testing.
//
//===----------------------------------------------------------------------===//

#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/DebugCounter.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/Timing.h"
#include "mlir/Support/ToolUtilities.h"
#include "mlir/Tools/ParseUtilities.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/StringSaver.h"
#include "llvm/Support/ThreadPool.h"
#include "llvm/Support/ToolOutputFile.h"

using namespace mlir;
using namespace llvm;

/// Perform the actions on the input file indicated by the command line flags
/// within the specified context.
///
/// This typically parses the main source file, runs zero or more optimization
/// passes, then prints the output.
/// static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses, const std::shared_ptr &sourceMgr, MLIRContext *context, PassPipelineFn passManagerSetupFn, bool emitBytecode, bool implicitModule) { DefaultTimingManager tm; applyDefaultTimingManagerCLOptions(tm); TimingScope timing = tm.getRootScope(); // Disable multi-threading when parsing the input file. This removes the // unnecessary/costly context synchronization when parsing. bool wasThreadingEnabled = context->isMultithreadingEnabled(); context->disableMultithreading(); // Prepare the parser config, and attach any useful/necessary resource // handlers. Unhandled external resources are treated as passthrough, i.e. // they are not processed and will be emitted directly to the output // untouched. PassReproducerOptions reproOptions; FallbackAsmResourceMap fallbackResourceMap; ParserConfig config(context, /*verifyAfterParse=*/true, &fallbackResourceMap); reproOptions.attachResourceParser(config); // Parse the input file and reset the context threading state. TimingScope parserTiming = timing.nest("Parser"); OwningOpRef op = parseSourceFileForTool(sourceMgr, config, implicitModule); context->enableMultithreading(wasThreadingEnabled); if (!op) return failure(); parserTiming.stop(); // Prepare the pass manager, applying command-line and reproducer options. PassManager pm(context, OpPassManager::Nesting::Implicit, op.get()->getName().getStringRef()); pm.enableVerifier(verifyPasses); applyPassManagerCLOptions(pm); pm.enableTiming(timing); if (failed(reproOptions.apply(pm)) || failed(passManagerSetupFn(pm))) return failure(); // Run the pipeline. if (failed(pm.run(*op))) return failure(); // Print the output. 
TimingScope outputTiming = timing.nest("Output"); if (emitBytecode) { BytecodeWriterConfig writerConfig(fallbackResourceMap); writeBytecodeToFile(op.get(), os, writerConfig); } else { AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr, &fallbackResourceMap); op.get()->print(os, asmState); os << '\n'; } return success(); } /// Parses the memory buffer. If successfully, run a series of passes against /// it and print the result. static LogicalResult processBuffer(raw_ostream &os, std::unique_ptr ownedBuffer, bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, bool preloadDialectsInContext, bool emitBytecode, bool implicitModule, PassPipelineFn passManagerSetupFn, DialectRegistry ®istry, llvm::ThreadPool *threadPool) { // Tell sourceMgr about this buffer, which is what the parser will pick up. auto sourceMgr = std::make_shared(); sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc()); // Create a context just for the current buffer. Disable threading on creation // since we'll inject the thread-pool separately. MLIRContext context(registry, MLIRContext::Threading::DISABLED); if (threadPool) context.setThreadPool(*threadPool); // Parse the input file. if (preloadDialectsInContext) context.loadAllAvailableDialects(); context.allowUnregisteredDialects(allowUnregisteredDialects); if (verifyDiagnostics) context.printOpOnDiagnostic(false); context.getDebugActionManager().registerActionHandler(); // If we are in verify diagnostics mode then we have a lot of work to do, // otherwise just perform the actions without worrying about it. if (!verifyDiagnostics) { SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context); return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context, passManagerSetupFn, emitBytecode, implicitModule); } SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr, &context); // Do any processing requested by command line flags. 
We don't care whether // these actions succeed or fail, we only care what diagnostics they produce // and whether they match our expectations. (void)performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context, passManagerSetupFn, emitBytecode, implicitModule); // Verify the diagnostic handler to make sure that each of the diagnostics // matched. return sourceMgrHandler.verify(); } LogicalResult mlir::MlirOptMain(raw_ostream &outputStream, std::unique_ptr buffer, PassPipelineFn passManagerSetupFn, DialectRegistry ®istry, bool splitInputFile, bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, bool preloadDialectsInContext, bool emitBytecode, bool implicitModule) { // The split-input-file mode is a very specific mode that slices the file // up into small pieces and checks each independently. // We use an explicit threadpool to avoid creating and joining/destroying // threads for each of the split. ThreadPool *threadPool = nullptr; // Create a temporary context for the sake of checking if // --mlir-disable-threading was passed on the command line. // We use the thread-pool this context is creating, and avoid // creating any thread when disabled. 
MLIRContext threadPoolCtx; if (threadPoolCtx.isMultithreadingEnabled()) threadPool = &threadPoolCtx.getThreadPool(); auto chunkFn = [&](std::unique_ptr chunkBuffer, raw_ostream &os) { return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics, verifyPasses, allowUnregisteredDialects, preloadDialectsInContext, emitBytecode, implicitModule, passManagerSetupFn, registry, threadPool); }; return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream, splitInputFile, /*insertMarkerInOutput=*/true); } LogicalResult mlir::MlirOptMain( raw_ostream &outputStream, std::unique_ptr buffer, const PassPipelineCLParser &passPipeline, DialectRegistry ®istry, bool splitInputFile, bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, bool preloadDialectsInContext, bool emitBytecode, bool implicitModule, bool dumpPassPipeline) { auto passManagerSetupFn = [&](PassManager &pm) { auto errorHandler = [&](const Twine &msg) { emitError(UnknownLoc::get(pm.getContext())) << msg; return failure(); }; if (failed(passPipeline.addToPipeline(pm, errorHandler))) return failure(); if (dumpPassPipeline) { pm.dump(); llvm::errs() << "\n"; } return success(); }; return MlirOptMain(outputStream, std::move(buffer), passManagerSetupFn, registry, splitInputFile, verifyDiagnostics, verifyPasses, allowUnregisteredDialects, preloadDialectsInContext, emitBytecode, implicitModule); } LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName, DialectRegistry ®istry, bool preloadDialectsInContext) { static cl::opt inputFilename( cl::Positional, cl::desc(""), cl::init("-")); static cl::opt outputFilename("o", cl::desc("Output filename"), cl::value_desc("filename"), cl::init("-")); static cl::opt splitInputFile( "split-input-file", cl::desc("Split the input file into pieces and process each " "chunk independently"), cl::init(false)); static cl::opt verifyDiagnostics( "verify-diagnostics", cl::desc("Check that emitted diagnostics match " "expected-* lines on 
the corresponding line"), cl::init(false)); static cl::opt verifyPasses( "verify-each", cl::desc("Run the verifier after each transformation pass"), cl::init(true)); static cl::opt allowUnregisteredDialects( "allow-unregistered-dialect", cl::desc("Allow operation with no registered dialects"), cl::init(false)); static cl::opt showDialects( "show-dialects", cl::desc("Print the list of registered dialects"), cl::init(false)); static cl::opt emitBytecode( "emit-bytecode", cl::desc("Emit bytecode when generating output"), cl::init(false)); static cl::opt noImplicitModule{ "no-implicit-module", cl::desc( "Disable implicit addition of a top-level module op during parsing"), cl::init(false)}; static cl::opt dumpPassPipeline{ "dump-pass-pipeline", cl::desc("Print the pipeline that will be run"), cl::init(false)}; InitLLVM y(argc, argv); // Register any command line options. registerAsmPrinterCLOptions(); registerMLIRContextCLOptions(); registerPassManagerCLOptions(); registerDefaultTimingManagerCLOptions(); DebugCounter::registerCLOptions(); PassPipelineCLParser passPipeline("", "Compiler passes to run", "p"); // Build the list of dialects as a header for the --help message. std::string helpHeader = (toolName + "\nAvailable Dialects: ").str(); { llvm::raw_string_ostream os(helpHeader); interleaveComma(registry.getDialectNames(), os, [&](auto name) { os << name; }); } // Parse pass names in main to ensure static initialization completed. cl::ParseCommandLineOptions(argc, argv, helpHeader); if (showDialects) { llvm::outs() << "Available Dialects:\n"; interleave( registry.getDialectNames(), llvm::outs(), [](auto name) { llvm::outs() << name; }, "\n"); return success(); } // Set up the input file. 
std::string errorMessage; auto file = openInputFile(inputFilename, &errorMessage); if (!file) { llvm::errs() << errorMessage << "\n"; return failure(); } auto output = openOutputFile(outputFilename, &errorMessage); if (!output) { llvm::errs() << errorMessage << "\n"; return failure(); } if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry, splitInputFile, verifyDiagnostics, verifyPasses, allowUnregisteredDialects, preloadDialectsInContext, emitBytecode, /*implicitModule=*/!noImplicitModule, dumpPassPipeline))) return failure(); // Keep the output file if the invocation of MlirOptMain was successful. output->keep(); return success(); }