This commit introduces a new MathToEmitC conversion pass that lowers selected math operations from the Math dialect to the `emitc.call_opaque` operation in the EmitC dialect.

**Supported Math Operations:**
The following operations are converted (only when all operands and results are f32 or f64):
- math.floor -> emitc.call_opaque<"floor">
- math.round -> emitc.call_opaque<"round">
- math.exp -> emitc.call_opaque<"exp">
- math.cos -> emitc.call_opaque<"cos">
- math.sin -> emitc.call_opaque<"sin">
- math.acos -> emitc.call_opaque<"acos">
- math.asin -> emitc.call_opaque<"asin">
- math.atan2 -> emitc.call_opaque<"atan2">
- math.ceil -> emitc.call_opaque<"ceil">
- math.absf -> emitc.call_opaque<"fabs">
- math.powf -> emitc.call_opaque<"pow">

**Target Language Standards:**
The pass supports targeting different language standards:
- C99: Appends an `f` suffix for single-precision floats (e.g., floorf, fabsf).
- CPP11: Prepends `std::` to the function name (e.g., std::floor, std::fabs).

**Design Decisions:**
The pass uses `emitc.call_opaque` instead of `emitc.call` to better emulate C-style function overloading. `emitc.call_opaque` does not require a unique type signature, which makes it better suited to <math.h> functions that may be overloaded for different types. This design choice ensures compatibility with C/C++ conventions.
86 lines
3.9 KiB
C++
86 lines
3.9 KiB
C++
//===- MathToEmitC.cpp - Math to EmitC Patterns -----------------*- C++ -*-===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/MathToEmitC/MathToEmitC.h"
|
|
|
|
#include "mlir/Dialect/EmitC/IR/EmitC.h"
|
|
#include "mlir/Dialect/Math/IR/Math.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
template <typename OpType>
|
|
class LowerToEmitCCallOpaque : public OpRewritePattern<OpType> {
|
|
std::string calleeStr;
|
|
emitc::LanguageTarget languageTarget;
|
|
|
|
public:
|
|
LowerToEmitCCallOpaque(MLIRContext *context, std::string calleeStr,
|
|
emitc::LanguageTarget languageTarget)
|
|
: OpRewritePattern<OpType>(context), calleeStr(std::move(calleeStr)),
|
|
languageTarget(languageTarget) {}
|
|
|
|
LogicalResult matchAndRewrite(OpType op,
|
|
PatternRewriter &rewriter) const override;
|
|
};
|
|
|
|
template <typename OpType>
|
|
LogicalResult LowerToEmitCCallOpaque<OpType>::matchAndRewrite(
|
|
OpType op, PatternRewriter &rewriter) const {
|
|
if (!llvm::all_of(op->getOperandTypes(),
|
|
llvm::IsaPred<Float32Type, Float64Type>) ||
|
|
!llvm::all_of(op->getResultTypes(),
|
|
llvm::IsaPred<Float32Type, Float64Type>))
|
|
return rewriter.notifyMatchFailure(
|
|
op.getLoc(),
|
|
"expected all operands and results to be of type f32 or f64");
|
|
std::string modifiedCalleeStr = calleeStr;
|
|
if (languageTarget == emitc::LanguageTarget::cpp11) {
|
|
modifiedCalleeStr = "std::" + calleeStr;
|
|
} else if (languageTarget == emitc::LanguageTarget::c99) {
|
|
auto operandType = op->getOperandTypes()[0];
|
|
if (operandType.isF32())
|
|
modifiedCalleeStr = calleeStr + "f";
|
|
}
|
|
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
|
|
op, op.getType(), modifiedCalleeStr, op->getOperands());
|
|
return success();
|
|
}
|
|
|
|
} // namespace
|
|
|
|
// Populates patterns to replace `math` operations with `emitc.call_opaque`,
|
|
// using function names consistent with those in <math.h>.
|
|
void mlir::populateConvertMathToEmitCPatterns(
|
|
RewritePatternSet &patterns, emitc::LanguageTarget languageTarget) {
|
|
auto *context = patterns.getContext();
|
|
patterns.insert<LowerToEmitCCallOpaque<math::FloorOp>>(context, "floor",
|
|
languageTarget);
|
|
patterns.insert<LowerToEmitCCallOpaque<math::RoundOp>>(context, "round",
|
|
languageTarget);
|
|
patterns.insert<LowerToEmitCCallOpaque<math::ExpOp>>(context, "exp",
|
|
languageTarget);
|
|
patterns.insert<LowerToEmitCCallOpaque<math::CosOp>>(context, "cos",
|
|
languageTarget);
|
|
patterns.insert<LowerToEmitCCallOpaque<math::SinOp>>(context, "sin",
|
|
languageTarget);
|
|
patterns.insert<LowerToEmitCCallOpaque<math::AcosOp>>(context, "acos",
|
|
languageTarget);
|
|
patterns.insert<LowerToEmitCCallOpaque<math::AsinOp>>(context, "asin",
|
|
languageTarget);
|
|
patterns.insert<LowerToEmitCCallOpaque<math::Atan2Op>>(context, "atan2",
|
|
languageTarget);
|
|
patterns.insert<LowerToEmitCCallOpaque<math::CeilOp>>(context, "ceil",
|
|
languageTarget);
|
|
patterns.insert<LowerToEmitCCallOpaque<math::AbsFOp>>(context, "fabs",
|
|
languageTarget);
|
|
patterns.insert<LowerToEmitCCallOpaque<math::PowFOp>>(context, "pow",
|
|
languageTarget);
|
|
}
|