//===- SimplifyFIROperations.cpp -- simplify complex FIR operations ------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===// /// \file /// This pass transforms some FIR operations into their equivalent /// implementations using other FIR operations. The transformation /// can legally use SCF dialect and generate Fortran runtime calls. //===----------------------------------------------------------------------===// #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Builder/Runtime/Inquiry.h" #include "flang/Optimizer/Builder/Todo.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Transforms/Passes.h" #include "mlir/IR/IRMapping.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include namespace fir { #define GEN_PASS_DEF_SIMPLIFYFIROPERATIONS #include "flang/Optimizer/Transforms/Passes.h.inc" } // namespace fir #define DEBUG_TYPE "flang-simplify-fir-operations" namespace { /// Pass runner. class SimplifyFIROperationsPass : public fir::impl::SimplifyFIROperationsBase { public: using fir::impl::SimplifyFIROperationsBase< SimplifyFIROperationsPass>::SimplifyFIROperationsBase; void runOnOperation() override final; }; /// Base class for all conversions holding the pass options. template class ConversionBase : public mlir::OpRewritePattern { public: using mlir::OpRewritePattern::OpRewritePattern; template ConversionBase(mlir::MLIRContext *context, Args &&...args) : mlir::OpRewritePattern(context), options{std::forward(args)...} {} mlir::LogicalResult matchAndRewrite(Op, mlir::PatternRewriter &) const override; protected: fir::SimplifyFIROperationsOptions options; }; /// fir::IsContiguousBoxOp converter. using IsContiguousBoxCoversion = ConversionBase; /// fir::BoxTotalElementsOp converter. using BoxTotalElementsConversion = ConversionBase; } // namespace /// Generate a call to IsContiguous/IsContiguousUpTo function or an inline /// sequence reading extents/strides from the box and checking them. /// This conversion may produce fir.box_elesize and a loop (for assumed /// rank). template <> mlir::LogicalResult IsContiguousBoxCoversion::matchAndRewrite( fir::IsContiguousBoxOp op, mlir::PatternRewriter &rewriter) const { mlir::Location loc = op.getLoc(); fir::FirOpBuilder builder(rewriter, op.getOperation()); mlir::Value box = op.getBox(); if (options.preferInlineImplementation) { auto boxType = mlir::cast(box.getType()); unsigned rank = fir::getBoxRank(boxType); // If rank is one, or 'innermost' attribute is set and // it is not a scalar, then generate a simple comparison // for the leading dimension: (stride == elem_size || extent == 0). // // The scalar cases are supposed to be optimized by the canonicalization. if (rank == 1 || (op.getInnermost() && rank > 0)) { mlir::Type idxTy = builder.getIndexType(); auto eleSize = builder.create(loc, idxTy, box); mlir::Value zero = fir::factory::createZeroValue(builder, loc, idxTy); auto dimInfo = builder.create(loc, idxTy, idxTy, idxTy, box, zero); mlir::Value stride = dimInfo.getByteStride(); mlir::Value pred1 = builder.create( loc, mlir::arith::CmpIPredicate::eq, eleSize, stride); mlir::Value extent = dimInfo.getExtent(); mlir::Value pred2 = builder.create( loc, mlir::arith::CmpIPredicate::eq, extent, zero); mlir::Value result = builder.create(loc, pred1, pred2); result = builder.createConvert(loc, op.getType(), result); rewriter.replaceOp(op, result); return mlir::success(); } // TODO: support arrays with multiple dimensions. } // Generate Fortran runtime call. mlir::Value result; if (op.getInnermost()) { mlir::Value one = builder.createIntegerConstant(loc, builder.getI32Type(), 1); result = fir::runtime::genIsContiguousUpTo(builder, loc, box, one); } else { result = fir::runtime::genIsContiguous(builder, loc, box); } result = builder.createConvert(loc, op.getType(), result); rewriter.replaceOp(op, result); return mlir::success(); } /// Generate a call to Size runtime function or an inline /// sequence reading extents from the box an multiplying them. /// This conversion may produce a loop (for assumed rank). template <> mlir::LogicalResult BoxTotalElementsConversion::matchAndRewrite( fir::BoxTotalElementsOp op, mlir::PatternRewriter &rewriter) const { mlir::Location loc = op.getLoc(); fir::FirOpBuilder builder(rewriter, op.getOperation()); // TODO: support preferInlineImplementation. // Reading the extent from the box for 1D arrays probably // results in less code than the call, so we can always // inline it. bool doInline = options.preferInlineImplementation && false; if (!doInline) { // Generate Fortran runtime call. mlir::Value result = fir::runtime::genSize(builder, loc, op.getBox()); result = builder.createConvert(loc, op.getType(), result); rewriter.replaceOp(op, result); return mlir::success(); } // Generate inline implementation. TODO(loc, "inline BoxTotalElementsOp"); return mlir::failure(); } class DoConcurrentConversion : public mlir::OpRewritePattern { public: using mlir::OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite(fir::DoConcurrentOp doConcurentOp, mlir::PatternRewriter &rewriter) const override { assert(doConcurentOp.getRegion().hasOneBlock()); mlir::Block &wrapperBlock = doConcurentOp.getRegion().getBlocks().front(); auto loop = mlir::cast(wrapperBlock.getTerminator()); assert(loop.getRegion().hasOneBlock()); mlir::Block &loopBlock = loop.getRegion().getBlocks().front(); // Collect iteration variable(s) allocations do that we can move them // outside the `fir.do_concurrent` wrapper. llvm::SmallVector opsToMove; for (mlir::Operation &op : llvm::drop_end(wrapperBlock)) opsToMove.push_back(&op); fir::FirOpBuilder firBuilder( rewriter, doConcurentOp->getParentOfType()); auto *allocIt = firBuilder.getAllocaBlock(); for (mlir::Operation *op : llvm::reverse(opsToMove)) rewriter.moveOpBefore(op, allocIt, allocIt->begin()); rewriter.setInsertionPointAfter(doConcurentOp); fir::DoLoopOp innermostUnorderdLoop; mlir::SmallVector ivArgs; for (auto [lb, ub, st, iv] : llvm::zip_equal(loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), *loop.getLoopInductionVars())) { innermostUnorderdLoop = rewriter.create( doConcurentOp.getLoc(), lb, ub, st, /*unordred=*/true, /*finalCountValue=*/false, /*iterArgs=*/std::nullopt, loop.getReduceOperands(), loop.getReduceAttrsAttr()); ivArgs.push_back(innermostUnorderdLoop.getInductionVar()); rewriter.setInsertionPointToStart(innermostUnorderdLoop.getBody()); } rewriter.inlineBlockBefore( &loopBlock, innermostUnorderdLoop.getBody()->getTerminator(), ivArgs); rewriter.eraseOp(doConcurentOp); return mlir::success(); } }; void SimplifyFIROperationsPass::runOnOperation() { mlir::ModuleOp module = getOperation(); mlir::MLIRContext &context = getContext(); mlir::RewritePatternSet patterns(&context); fir::populateSimplifyFIROperationsPatterns(patterns, preferInlineImplementation); mlir::GreedyRewriteConfig config; config.setRegionSimplificationLevel( mlir::GreedySimplifyRegionLevel::Disabled); if (mlir::failed( mlir::applyPatternsGreedily(module, std::move(patterns), config))) { mlir::emitError(module.getLoc(), DEBUG_TYPE " pass failed"); signalPassFailure(); } } void fir::populateSimplifyFIROperationsPatterns( mlir::RewritePatternSet &patterns, bool preferInlineImplementation) { patterns.insert( patterns.getContext(), preferInlineImplementation); patterns.insert(patterns.getContext()); }