//===- OptimizeForNVVM.cpp - Optimize LLVM IR for NVVM ---------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/IR/Builders.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { namespace NVVM { #define GEN_PASS_DEF_NVVMOPTIMIZEFORTARGET #include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc" } // namespace NVVM } // namespace mlir using namespace mlir; namespace { // Replaces fdiv on fp16 with fp32 multiplication with reciprocal plus one // (conditional) Newton iteration. // // This as accurate as promoting the division to fp32 in the NVPTX backend, but // faster because it performs less Newton iterations, avoids the slow path // for e.g. denormals, and allows reuse of the reciprocal for multiple divisions // by the same divisor. struct ExpandDivF16 : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; private: LogicalResult matchAndRewrite(LLVM::FDivOp op, PatternRewriter &rewriter) const override; }; struct NVVMOptimizeForTarget : public NVVM::impl::NVVMOptimizeForTargetBase { void runOnOperation() override; void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } }; } // namespace LogicalResult ExpandDivF16::matchAndRewrite(LLVM::FDivOp op, PatternRewriter &rewriter) const { if (!op.getType().isF16()) return rewriter.notifyMatchFailure(op, "not f16"); Location loc = op.getLoc(); Type f32Type = rewriter.getF32Type(); Type i32Type = rewriter.getI32Type(); // Extend lhs and rhs to fp32. Value lhs = rewriter.create(loc, f32Type, op.getLhs()); Value rhs = rewriter.create(loc, f32Type, op.getRhs()); // float rcp = rcp.approx.ftz.f32(rhs), approx = lhs * rcp. Value rcp = rewriter.create(loc, f32Type, rhs); Value approx = rewriter.create(loc, lhs, rcp); // Refine the approximation with one Newton iteration: // float refined = approx + (lhs - approx * rhs) * rcp; Value err = rewriter.create( loc, approx, rewriter.create(loc, rhs), lhs); Value refined = rewriter.create(loc, err, rcp, approx); // Use refined value if approx is normal (exponent neither all 0 or all 1). Value mask = rewriter.create( loc, i32Type, rewriter.getUI32IntegerAttr(0x7f800000)); Value cast = rewriter.create(loc, i32Type, approx); Value exp = rewriter.create(loc, i32Type, cast, mask); Value zero = rewriter.create( loc, i32Type, rewriter.getUI32IntegerAttr(0)); Value pred = rewriter.create( loc, rewriter.create(loc, LLVM::ICmpPredicate::eq, exp, zero), rewriter.create(loc, LLVM::ICmpPredicate::eq, exp, mask)); Value result = rewriter.create(loc, f32Type, pred, approx, refined); // Replace with trucation back to fp16. rewriter.replaceOpWithNewOp(op, op.getType(), result); return success(); } void NVVMOptimizeForTarget::runOnOperation() { MLIRContext *ctx = getOperation()->getContext(); RewritePatternSet patterns(ctx); patterns.add(ctx); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } std::unique_ptr NVVM::createOptimizeForTargetPass() { return std::make_unique(); }