//===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements patterns to convert the Arith dialect to the EmitC // dialect. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; //===----------------------------------------------------------------------===// // Conversion Patterns //===----------------------------------------------------------------------===// namespace { class ArithConstantOpConversionPattern : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::ConstantOp arithConst, arith::ConstantOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( arithConst, arithConst.getType(), adaptor.getValue()); return success(); } }; template class ArithOpConversion final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.template replaceOpWithNewOp(arithOp, arithOp.getType(), adaptor.getOperands()); return success(); } }; template class IntegerOpConversion final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type type = this->getTypeConverter()->convertType(op.getType()); if (!isa_and_nonnull(type)) { return rewriter.notifyMatchFailure(op, "expected integer type"); } if (type.isInteger(1)) { // arith expects wrap-around arithmethic, which doesn't happen on `bool`. return rewriter.notifyMatchFailure(op, "i1 type is not implemented"); } Value lhs = adaptor.getLhs(); Value rhs = adaptor.getRhs(); Type arithmeticType = type; if ((type.isSignlessInteger() || type.isSignedInteger()) && !bitEnumContainsAll(op.getOverflowFlags(), arith::IntegerOverflowFlags::nsw)) { // If the C type is signed and the op doesn't guarantee "No Signed Wrap", // we compute in unsigned integers to avoid UB. arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(), /*isSigned=*/false); } if (arithmeticType != type) { lhs = rewriter.template create(op.getLoc(), arithmeticType, lhs); rhs = rewriter.template create(op.getLoc(), arithmeticType, rhs); } Value result = rewriter.template create(op.getLoc(), arithmeticType, lhs, rhs); if (arithmeticType != type) { result = rewriter.template create(op.getLoc(), type, result); } rewriter.replaceOp(op, result); return success(); } }; class SelectOpConversion : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type dstType = getTypeConverter()->convertType(selectOp.getType()); if (!dstType) return rewriter.notifyMatchFailure(selectOp, "type conversion failed"); if (!adaptor.getCondition().getType().isInteger(1)) return rewriter.notifyMatchFailure( selectOp, "can only be converted if condition is a scalar of type i1"); rewriter.replaceOpWithNewOp(selectOp, dstType, adaptor.getOperands()); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Pattern population //===----------------------------------------------------------------------===// void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns) { MLIRContext *ctx = patterns.getContext(); // clang-format off patterns.add< ArithConstantOpConversionPattern, ArithOpConversion, ArithOpConversion, ArithOpConversion, ArithOpConversion, IntegerOpConversion, IntegerOpConversion, IntegerOpConversion, SelectOpConversion >(typeConverter, ctx); // clang-format on }