//===- MathToFuncs.cpp - Math to outlined implementation conversion -------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/MathToFuncs/MathToFuncs.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/TypeSwitch.h" namespace mlir { #define GEN_PASS_DEF_CONVERTMATHTOFUNCS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; namespace { // Pattern to convert vector operations to scalar operations. template struct VecOpToScalarOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; }; // Callback type for getting pre-generated FuncOp implementing // a power operation of the given type. using GetPowerFuncCallbackTy = function_ref; // Pattern to convert scalar IPowIOp into a call of outlined // software implementation. struct IPowIOpLowering : public OpRewritePattern { private: GetPowerFuncCallbackTy getFuncOpCallback; public: IPowIOpLowering(MLIRContext *context, GetPowerFuncCallbackTy cb) : OpRewritePattern(context), getFuncOpCallback(cb) {} /// Convert IPowI into a call to a local function implementing /// the power operation. The local function computes a scalar result, /// so vector forms of IPowI are linearized. LogicalResult matchAndRewrite(math::IPowIOp op, PatternRewriter &rewriter) const final; }; } // namespace template LogicalResult VecOpToScalarOp::matchAndRewrite(Op op, PatternRewriter &rewriter) const { Type opType = op.getType(); Location loc = op.getLoc(); auto vecType = opType.template dyn_cast(); if (!vecType) return rewriter.notifyMatchFailure(op, "not a vector operation"); if (!vecType.hasRank()) return rewriter.notifyMatchFailure(op, "unknown vector rank"); ArrayRef shape = vecType.getShape(); int64_t numElements = vecType.getNumElements(); Value result = rewriter.create( loc, DenseElementsAttr::get( vecType, IntegerAttr::get(vecType.getElementType(), 0))); SmallVector strides = computeStrides(shape); for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) { SmallVector positions = delinearize(strides, linearIndex); SmallVector operands; for (Value input : op->getOperands()) operands.push_back( rewriter.create(loc, input, positions)); Value scalarOp = rewriter.create(loc, vecType.getElementType(), operands); result = rewriter.create(loc, scalarOp, result, positions); } rewriter.replaceOp(op, result); return success(); } /// Create linkonce_odr function to implement the power function with /// the given \p funcType type inside \p module. \p funcType must be /// 'IntegerType (*)(IntegerType, IntegerType)' function type. /// /// template /// T __mlir_math_ipowi_*(T b, T p) { /// if (p == T(0)) /// return T(1); /// if (p < T(0)) { /// if (b == T(0)) /// return T(1) / T(0); // trigger div-by-zero /// if (b == T(1)) /// return T(1); /// if (b == T(-1)) { /// if (p & T(1)) /// return T(-1); /// return T(1); /// } /// return T(0); /// } /// T result = T(1); /// while (true) { /// if (p & T(1)) /// result *= b; /// p >>= T(1); /// if (p == T(0)) /// return result; /// b *= b; /// } /// } static func::FuncOp createElementIPowIFunc(ModuleOp *module, Type elementType) { assert(elementType.isa() && "non-integer element type for IPowIOp"); // IntegerType elementType = funcType.getInput(0).cast(); ImplicitLocOpBuilder builder = ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody()); std::string funcName("__mlir_math_ipowi"); llvm::raw_string_ostream nameOS(funcName); nameOS << '_' << elementType; FunctionType funcType = FunctionType::get( builder.getContext(), {elementType, elementType}, elementType); auto funcOp = builder.create(funcName, funcType); LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR; Attribute linkage = LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage); funcOp->setAttr("llvm.linkage", linkage); funcOp.setPrivate(); Block *entryBlock = funcOp.addEntryBlock(); Region *funcBody = entryBlock->getParent(); Value bArg = funcOp.getArgument(0); Value pArg = funcOp.getArgument(1); builder.setInsertionPointToEnd(entryBlock); Value zeroValue = builder.create( elementType, builder.getIntegerAttr(elementType, 0)); Value oneValue = builder.create( elementType, builder.getIntegerAttr(elementType, 1)); Value minusOneValue = builder.create( elementType, builder.getIntegerAttr(elementType, APInt(elementType.getIntOrFloatBitWidth(), -1ULL, /*isSigned=*/true))); // if (p == T(0)) // return T(1); auto pIsZero = builder.create(arith::CmpIPredicate::eq, pArg, zeroValue); Block *thenBlock = builder.createBlock(funcBody); builder.create(oneValue); Block *fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (p == T(0)). builder.setInsertionPointToEnd(pIsZero->getBlock()); builder.create(pIsZero, thenBlock, fallthroughBlock); // if (p < T(0)) { builder.setInsertionPointToEnd(fallthroughBlock); auto pIsNeg = builder.create(arith::CmpIPredicate::sle, pArg, zeroValue); // if (b == T(0)) builder.createBlock(funcBody); auto bIsZero = builder.create(arith::CmpIPredicate::eq, bArg, zeroValue); // return T(1) / T(0); thenBlock = builder.createBlock(funcBody); builder.create( builder.create(oneValue, zeroValue).getResult()); fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (b == T(0)). builder.setInsertionPointToEnd(bIsZero->getBlock()); builder.create(bIsZero, thenBlock, fallthroughBlock); // if (b == T(1)) builder.setInsertionPointToEnd(fallthroughBlock); auto bIsOne = builder.create(arith::CmpIPredicate::eq, bArg, oneValue); // return T(1); thenBlock = builder.createBlock(funcBody); builder.create(oneValue); fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (b == T(1)). builder.setInsertionPointToEnd(bIsOne->getBlock()); builder.create(bIsOne, thenBlock, fallthroughBlock); // if (b == T(-1)) { builder.setInsertionPointToEnd(fallthroughBlock); auto bIsMinusOne = builder.create(arith::CmpIPredicate::eq, bArg, minusOneValue); // if (p & T(1)) builder.createBlock(funcBody); auto pIsOdd = builder.create( arith::CmpIPredicate::ne, builder.create(pArg, oneValue), zeroValue); // return T(-1); thenBlock = builder.createBlock(funcBody); builder.create(minusOneValue); fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (p & T(1)). builder.setInsertionPointToEnd(pIsOdd->getBlock()); builder.create(pIsOdd, thenBlock, fallthroughBlock); // return T(1); // } // b == T(-1) builder.setInsertionPointToEnd(fallthroughBlock); builder.create(oneValue); fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (b == T(-1)). builder.setInsertionPointToEnd(bIsMinusOne->getBlock()); builder.create(bIsMinusOne, pIsOdd->getBlock(), fallthroughBlock); // return T(0); // } // (p < T(0)) builder.setInsertionPointToEnd(fallthroughBlock); builder.create(zeroValue); Block *loopHeader = builder.createBlock( funcBody, funcBody->end(), {elementType, elementType, elementType}, {builder.getLoc(), builder.getLoc(), builder.getLoc()}); // Set up conditional branch for (p < T(0)). builder.setInsertionPointToEnd(pIsNeg->getBlock()); // Set initial values of 'result', 'b' and 'p' for the loop. builder.create(pIsNeg, bIsZero->getBlock(), loopHeader, ValueRange{oneValue, bArg, pArg}); // T result = T(1); // while (true) { // if (p & T(1)) // result *= b; // p >>= T(1); // if (p == T(0)) // return result; // b *= b; // } Value resultTmp = loopHeader->getArgument(0); Value baseTmp = loopHeader->getArgument(1); Value powerTmp = loopHeader->getArgument(2); builder.setInsertionPointToEnd(loopHeader); // if (p & T(1)) auto powerTmpIsOdd = builder.create( arith::CmpIPredicate::ne, builder.create(powerTmp, oneValue), zeroValue); thenBlock = builder.createBlock(funcBody); // result *= b; Value newResultTmp = builder.create(resultTmp, baseTmp); fallthroughBlock = builder.createBlock(funcBody, funcBody->end(), elementType, builder.getLoc()); builder.setInsertionPointToEnd(thenBlock); builder.create(newResultTmp, fallthroughBlock); // Set up conditional branch for (p & T(1)). builder.setInsertionPointToEnd(powerTmpIsOdd->getBlock()); builder.create(powerTmpIsOdd, thenBlock, fallthroughBlock, resultTmp); // Merged 'result'. newResultTmp = fallthroughBlock->getArgument(0); // p >>= T(1); builder.setInsertionPointToEnd(fallthroughBlock); Value newPowerTmp = builder.create(powerTmp, oneValue); // if (p == T(0)) auto newPowerIsZero = builder.create(arith::CmpIPredicate::eq, newPowerTmp, zeroValue); // return result; thenBlock = builder.createBlock(funcBody); builder.create(newResultTmp); fallthroughBlock = builder.createBlock(funcBody); // Set up conditional branch for (p == T(0)). builder.setInsertionPointToEnd(newPowerIsZero->getBlock()); builder.create(newPowerIsZero, thenBlock, fallthroughBlock); // b *= b; // } builder.setInsertionPointToEnd(fallthroughBlock); Value newBaseTmp = builder.create(baseTmp, baseTmp); // Pass new values for 'result', 'b' and 'p' to the loop header. builder.create( ValueRange{newResultTmp, newBaseTmp, newPowerTmp}, loopHeader); return funcOp; } /// Convert IPowI into a call to a local function implementing /// the power operation. The local function computes a scalar result, /// so vector forms of IPowI are linearized. LogicalResult IPowIOpLowering::matchAndRewrite(math::IPowIOp op, PatternRewriter &rewriter) const { auto baseType = op.getOperands()[0].getType().dyn_cast(); if (!baseType) return rewriter.notifyMatchFailure(op, "non-integer base operand"); // The outlined software implementation must have been already // generated. func::FuncOp elementFunc = getFuncOpCallback(baseType); if (!elementFunc) return rewriter.notifyMatchFailure(op, "missing software implementation"); rewriter.replaceOpWithNewOp(op, elementFunc, op.getOperands()); return success(); } namespace { struct ConvertMathToFuncsPass : public impl::ConvertMathToFuncsBase { ConvertMathToFuncsPass() = default; void runOnOperation() override; private: // Generate outlined implementations for power operations // and store them in powerFuncs map. void preprocessPowOperations(); // A map between function types deduced from power operations // and the corresponding outlined software implementations // of these operations. DenseMap powerFuncs; }; } // namespace void ConvertMathToFuncsPass::preprocessPowOperations() { ModuleOp module = getOperation(); module.walk([&](Operation *op) { TypeSwitch(op).Case([&](math::IPowIOp op) { Type resultType = getElementTypeOrSelf(op.getResult().getType()); // Generate the software implementation of this operation, // if it has not been generated yet. auto entry = powerFuncs.try_emplace(resultType, func::FuncOp{}); if (entry.second) entry.first->second = createElementIPowIFunc(&module, resultType); }); }); } void ConvertMathToFuncsPass::runOnOperation() { ModuleOp module = getOperation(); // Create outlined implementations for power operations. preprocessPowOperations(); RewritePatternSet patterns(&getContext()); patterns.add>(patterns.getContext()); // For the given Type Returns FuncOp stored in powerFuncs map. auto getPowerFuncOpByType = [&](Type type) -> func::FuncOp { auto it = powerFuncs.find(type); if (it == powerFuncs.end()) return {}; return it->second; }; patterns.add(patterns.getContext(), getPowerFuncOpByType); ConversionTarget target(getContext()); target.addLegalDialect(); target.addIllegalOp(); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } std::unique_ptr mlir::createConvertMathToFuncsPass() { return std::make_unique(); }