diff --git a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp index f0908564d957..ad8b768a2d3a 100644 --- a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp @@ -422,7 +422,7 @@ void linalg::convertToLLVM(mlir::Module &module) { populateLinalg1ToLLVMConversionPatterns(patterns, module.getContext()); ConversionTarget target(*module.getContext()); - target.addLegalDialects(); + target.addLegalDialect(); auto r = applyConversionPatterns(module, target, converter, std::move(patterns)); (void)r; diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp index 26d6af8b6cc4..7a4edc4e439b 100644 --- a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp @@ -163,7 +163,7 @@ void linalg::convertLinalg3ToLLVM(Module &module) { populateLinalg3ToLLVMConversionPatterns(patterns, module.getContext()); ConversionTarget target(*module.getContext()); - target.addLegalDialects(); + target.addLegalDialect(); auto r = applyConversionPatterns(module, target, converter, std::move(patterns)); (void)r; diff --git a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp index 82541f80c3e4..e8ce2a4ed52a 100644 --- a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp @@ -127,7 +127,7 @@ public: struct EarlyLoweringPass : public FunctionPass { void runOnFunction() override { ConversionTarget target(getContext()); - target.addLegalDialects(); + target.addLegalDialect(); target.addLegalOp(); OwningRewritePatternList patterns; diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index 4434e1bde58b..b2160898f520 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -345,8 +345,8 @@ struct LateLoweringPass : public ModulePass { // Perform Toy specific lowering. ConversionTarget target(getContext()); - target.addLegalDialects(); + target.addLegalDialect(); target.addLegalOp(); if (failed(applyConversionPatterns(getModule(), target, typeConverter, std::move(toyPatterns)))) { diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index e95842652d98..363f96d66c23 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -180,9 +180,7 @@ public: //===--------------------------------------------------------------------===// /// Register a legality action for the given operation. - void setOpAction(OperationName op, LegalizationAction action) { - legalOperations[op] = action; - } + void setOpAction(OperationName op, LegalizationAction action); template void setOpAction(LegalizationAction action) { setOpAction(OperationName(OpT::getOperationName(), &ctx), action); } @@ -196,21 +194,6 @@ public: addLegalOp(); } - /// Register the operations of the given dialects as legal. - void addLegalDialects(ArrayRef dialectNames) { - for (auto &dialect : dialectNames) - legalDialects[dialect] = LegalizationAction::Legal; - } - template - void addLegalDialects(StringRef name, Names... names) { - SmallVector dialectNames({name, names...}); - addLegalDialects(dialectNames); - } - template void addLegalDialects() { - SmallVector dialectNames({Args::getDialectNamespace()...}); - addLegalDialects(dialectNames); - } - /// Register the given operation as dynamically legal, i.e. requiring custom /// handling by the target via 'isLegal'. template void addDynamicallyLegalOp() { @@ -222,22 +205,39 @@ public: addDynamicallyLegalOp(); } + /// Register a legality action for the given dialects. + void setDialectAction(ArrayRef dialectNames, + LegalizationAction action); + + /// Register the operations of the given dialects as legal. + template + void addLegalDialect(StringRef name, Names... names) { + SmallVector dialectNames({name, names...}); + setDialectAction(dialectNames, LegalizationAction::Legal); + } + template void addLegalDialect() { + SmallVector dialectNames({Args::getDialectNamespace()...}); + setDialectAction(dialectNames, LegalizationAction::Legal); + } + + /// Register the operations of the given dialects as dynamically legal, i.e. + /// requiring custom handling by the target via 'isLegal'. + template + void addDynamicallyLegalDialect(StringRef name, Names... names) { + SmallVector dialectNames({name, names...}); + setDialectAction(dialectNames, LegalizationAction::Dynamic); + } + template void addDynamicallyLegalDialect() { + SmallVector dialectNames({Args::getDialectNamespace()...}); + setDialectAction(dialectNames, LegalizationAction::Dynamic); + } + //===--------------------------------------------------------------------===// // Legality Querying //===--------------------------------------------------------------------===// /// Get the legality action for the given operation. - llvm::Optional getOpAction(OperationName op) const { - // Check for an action for this specific operation. - auto it = legalOperations.find(op); - if (it != legalOperations.end()) - return it->second; - // Otherwise, default to checking for an action on the parent dialect. - auto dialectIt = legalDialects.find(op.getDialect()); - if (dialectIt != legalDialects.end()) - return dialectIt->second; - return llvm::None; - } + llvm::Optional getOpAction(OperationName op) const; private: /// A deterministic mapping of operation name to the specific legality action diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp index dd91f069849b..6fbdba38428a 100644 --- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp +++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp @@ -994,7 +994,7 @@ struct LLVMLoweringPass : public ModulePass { populateStdToLLVMConversionPatterns(converter, patterns); ConversionTarget target(getContext()); - target.addLegalDialects(); + target.addLegalDialect(); if (failed( applyConversionPatterns(m, target, converter, std::move(patterns)))) signalPassFailure(); diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index 60c16d9eab16..2093acfdf775 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -689,7 +689,7 @@ void LowerLinalgToLLVMPass::runOnModule() { populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext()); ConversionTarget target(getContext()); - target.addLegalDialects(); + target.addLegalDialect(); if (failed(applyConversionPatterns(module, target, converter, std::move(patterns)))) signalPassFailure(); diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 5d2a889c4946..3a19692f190a 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -724,6 +724,37 @@ FunctionType TypeConverter::convertFunctionSignatureType( return FunctionType::get(arguments, results, type.getContext()); } +//===----------------------------------------------------------------------===// +// ConversionTarget +//===----------------------------------------------------------------------===// + +/// Register a legality action for the given operation. +void ConversionTarget::setOpAction(OperationName op, + LegalizationAction action) { + legalOperations[op] = action; +} + +/// Register a legality action for the given dialects. +void ConversionTarget::setDialectAction(ArrayRef dialectNames, + LegalizationAction action) { + for (StringRef dialect : dialectNames) + legalDialects[dialect] = action; +} + +/// Get the legality action for the given operation. +auto ConversionTarget::getOpAction(OperationName op) const + -> llvm::Optional { + // Check for an action for this specific operation. + auto it = legalOperations.find(op); + if (it != legalOperations.end()) + return it->second; + // Otherwise, default to checking for an action on the parent dialect. + auto dialectIt = legalDialects.find(op.getDialect()); + if (dialectIt != legalDialects.end()) + return dialectIt->second; + return llvm::None; +} + //===----------------------------------------------------------------------===// // applyConversionPatterns //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index ac2227dbdb8b..4d9df82b27ae 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -597,7 +597,7 @@ LogicalResult mlir::lowerAffineConstructs(Function &function) { AffineTerminatorLowering>::build(patterns, function.getContext()); ConversionTarget target(*function.getContext()); - target.addLegalDialects(); + target.addLegalDialect(); return applyConversionPatterns(function, target, std::move(patterns)); }