//===- BufferizeHLFIR.cpp - Bufferize HLFIR ------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // This file defines a pass that bufferize hlfir.expr. It translates operations // producing or consuming hlfir.expr into operations operating on memory. // An hlfir.expr is translated to a tuple // where cleanupflag is set to true if storage for the expression was allocated // on the heap. //===----------------------------------------------------------------------===// #include "flang/Optimizer/Builder/Character.h" #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Builder/HLFIRTools.h" #include "flang/Optimizer/Builder/MutableBox.h" #include "flang/Optimizer/Builder/Runtime/Assign.h" #include "flang/Optimizer/Builder/Todo.h" #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" #include "flang/Optimizer/HLFIR/Passes.h" #include "flang/Optimizer/Support/FIRContext.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" #include namespace hlfir { #define GEN_PASS_DEF_BUFFERIZEHLFIR #include "flang/Optimizer/HLFIR/Passes.h.inc" } // namespace hlfir namespace { /// Helper to create tuple from a bufferized expr storage and clean up /// instruction flag. 
static mlir::Value packageBufferizedExpr(mlir::Location loc,
                                         fir::FirOpBuilder &builder,
                                         mlir::Value storage,
                                         mlir::Value mustFree) {
  // Single-element position attribute addressing field `pos` of the tuple in
  // a fir.insert_value.
  auto fieldPosition = [&builder](int64_t pos) {
    return builder.getArrayAttr(
        {builder.getIntegerAttr(builder.getIndexType(), pos)});
  };
  mlir::TupleType pairType = mlir::TupleType::get(
      builder.getContext(),
      mlir::TypeRange{storage.getType(), mustFree.getType()});
  mlir::Value packed = builder.create<fir::UndefOp>(loc, pairType);
  // The flag (field 1) is inserted before the storage (field 0). The
  // extraction helpers in this file walk the fir.insert_value chain relying
  // on exactly this order, so do not swap the two insertions.
  packed = builder.create<fir::InsertValueOp>(loc, pairType, packed, mustFree,
                                              fieldPosition(1));
  return builder.create<fir::InsertValueOp>(loc, pairType, packed, storage,
                                            fieldPosition(0));
}

/// Convenience overload of packageBufferizedExpr taking the clean-up flag as
/// a compile-time boolean; it is materialized with builder.createBool.
static mlir::Value packageBufferizedExpr(mlir::Location loc,
                                         fir::FirOpBuilder &builder,
                                         mlir::Value storage, bool mustFree) {
  return packageBufferizedExpr(loc, builder, storage,
                               builder.createBool(loc, mustFree));
}

/// Return the storage part of a value produced by packageBufferizedExpr.
/// Values that are not tuples are returned unchanged: HLFIR operation
/// operands are never tuples (the verifiers only accept HLFIR value or
/// variable types), so a non-tuple operand is already plain storage.
/// Tuples not built by packageBufferizedExpr hit a TODO.
static mlir::Value getBufferizedExprStorage(mlir::Value bufferizedExpr) {
  auto tupleType = bufferizedExpr.getType().dyn_cast<mlir::TupleType>();
  if (!tupleType)
    return bufferizedExpr;
  assert(tupleType.size() == 2 && "unexpected tuple type");
  // The storage is the value of the outermost fir.insert_value.
  auto outerInsert = bufferizedExpr.getDefiningOp<fir::InsertValueOp>();
  if (outerInsert && outerInsert.getVal().getType() == tupleType.getType(0))
    return outerInsert.getVal();
  TODO(bufferizedExpr.getLoc(), "general extract storage case");
}

/// Helper to extract the clean-up flag from a tuple created by
/// packageBufferizedExpr.
static mlir::Value getBufferizedExprMustFreeFlag(mlir::Value bufferizedExpr) { auto tupleType = bufferizedExpr.getType().dyn_cast(); if (!tupleType) return bufferizedExpr; assert(tupleType.size() == 2 && "unexpected tuple type"); if (auto insert = bufferizedExpr.getDefiningOp()) if (auto insert0 = insert.getAdt().getDefiningOp()) if (insert0.getVal().getType() == tupleType.getType(1)) return insert0.getVal(); TODO(bufferizedExpr.getLoc(), "general extract storage case"); } static std::pair createTempFromMold(mlir::Location loc, fir::FirOpBuilder &builder, hlfir::Entity mold) { llvm::SmallVector lenParams; hlfir::genLengthParameters(loc, builder, mold, lenParams); llvm::StringRef tmpName{".tmp"}; mlir::Value alloc; mlir::Value isHeapAlloc; mlir::Value shape{}; if (mold.isArray()) { mlir::Type sequenceType = hlfir::getFortranElementOrSequenceType(mold.getType()); shape = hlfir::genShape(loc, builder, mold); auto extents = hlfir::getIndexExtents(loc, builder, shape); alloc = builder.createHeapTemporary(loc, sequenceType, tmpName, extents, lenParams); isHeapAlloc = builder.createBool(loc, true); } else { alloc = builder.createTemporary(loc, mold.getFortranElementType(), tmpName, /*shape*/ std::nullopt, lenParams); isHeapAlloc = builder.createBool(loc, false); } auto declareOp = builder.create( loc, alloc, tmpName, shape, lenParams, fir::FortranVariableFlagsAttr{}); return {hlfir::Entity{declareOp.getBase()}, isHeapAlloc}; } static std::pair createArrayTemp(mlir::Location loc, fir::FirOpBuilder &builder, mlir::Type exprType, mlir::Value shape, mlir::ValueRange extents, mlir::ValueRange lenParams) { mlir::Type sequenceType = hlfir::getFortranElementOrSequenceType(exprType); llvm::StringRef tmpName{".tmp.array"}; mlir::Value allocmem = builder.createHeapTemporary(loc, sequenceType, tmpName, extents, lenParams); auto declareOp = builder.create(loc, allocmem, tmpName, shape, lenParams, fir::FortranVariableFlagsAttr{}); mlir::Value trueVal = builder.createBool(loc, true); 
return {hlfir::Entity{declareOp.getBase()}, trueVal}; } struct AsExprOpConversion : public mlir::OpConversionPattern { using mlir::OpConversionPattern::OpConversionPattern; explicit AsExprOpConversion(mlir::MLIRContext *ctx) : mlir::OpConversionPattern{ctx} {} mlir::LogicalResult matchAndRewrite(hlfir::AsExprOp asExpr, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::Location loc = asExpr->getLoc(); auto module = asExpr->getParentOfType(); fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module)); hlfir::Entity source = hlfir::Entity{adaptor.getVar()}; auto [temp, cleanup] = createTempFromMold(loc, builder, source); builder.create(loc, source, temp); mlir::Value bufferizedExpr = packageBufferizedExpr(loc, builder, temp, cleanup); rewriter.replaceOp(asExpr, bufferizedExpr); return mlir::success(); } }; struct ApplyOpConversion : public mlir::OpConversionPattern { using mlir::OpConversionPattern::OpConversionPattern; explicit ApplyOpConversion(mlir::MLIRContext *ctx) : mlir::OpConversionPattern{ctx} {} mlir::LogicalResult matchAndRewrite(hlfir::ApplyOp apply, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::Location loc = apply->getLoc(); hlfir::Entity bufferizedExpr{getBufferizedExprStorage(adaptor.getExpr())}; mlir::Type resultType = hlfir::getVariableElementType(bufferizedExpr); mlir::Value result = rewriter.create( loc, resultType, bufferizedExpr, adaptor.getIndices(), adaptor.getTypeparams()); if (fir::isa_trivial(apply.getType())) result = rewriter.create(loc, result); rewriter.replaceOp(apply, result); return mlir::success(); } }; struct AssignOpConversion : public mlir::OpConversionPattern { using mlir::OpConversionPattern::OpConversionPattern; explicit AssignOpConversion(mlir::MLIRContext *ctx) : mlir::OpConversionPattern{ctx} {} mlir::LogicalResult matchAndRewrite(hlfir::AssignOp assign, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { 
rewriter.replaceOpWithNewOp( assign, getBufferizedExprStorage(adaptor.getOperands()[0]), getBufferizedExprStorage(adaptor.getOperands()[1])); return mlir::success(); } }; struct ConcatOpConversion : public mlir::OpConversionPattern { using mlir::OpConversionPattern::OpConversionPattern; explicit ConcatOpConversion(mlir::MLIRContext *ctx) : mlir::OpConversionPattern{ctx} {} mlir::LogicalResult matchAndRewrite(hlfir::ConcatOp concat, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::Location loc = concat->getLoc(); auto module = concat->getParentOfType(); fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module)); assert(adaptor.getStrings().size() >= 2 && "must have at least two strings operands"); if (adaptor.getStrings().size() > 2) TODO(loc, "codegen of optimized chained concatenation of more than two " "strings"); hlfir::Entity lhs{getBufferizedExprStorage(adaptor.getStrings()[0])}; hlfir::Entity rhs{getBufferizedExprStorage(adaptor.getStrings()[1])}; auto [lhsExv, c1] = hlfir::translateToExtendedValue(loc, builder, lhs); auto [rhsExv, c2] = hlfir::translateToExtendedValue(loc, builder, rhs); assert(!c1 && !c2 && "expected variables"); fir::ExtendedValue res = fir::factory::CharacterExprHelper{builder, loc}.createConcatenate( *lhsExv.getCharBox(), *rhsExv.getCharBox()); // Ensure the memory type is the same as the result type. 
mlir::Type addrType = fir::ReferenceType::get( hlfir::getFortranElementType(concat.getResult().getType())); mlir::Value cast = builder.createConvert(loc, addrType, fir::getBase(res)); res = fir::substBase(res, cast); auto hlfirTempRes = hlfir::genDeclare(loc, builder, res, "tmp", fir::FortranVariableFlagsAttr{}) .getBase(); mlir::Value bufferizedExpr = packageBufferizedExpr(loc, builder, hlfirTempRes, false); rewriter.replaceOp(concat, bufferizedExpr); return mlir::success(); } }; struct SetLengthOpConversion : public mlir::OpConversionPattern { using mlir::OpConversionPattern::OpConversionPattern; explicit SetLengthOpConversion(mlir::MLIRContext *ctx) : mlir::OpConversionPattern{ctx} {} mlir::LogicalResult matchAndRewrite(hlfir::SetLengthOp setLength, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::Location loc = setLength->getLoc(); auto module = setLength->getParentOfType(); fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module)); // Create a temp with the new length. hlfir::Entity string{getBufferizedExprStorage(adaptor.getString())}; auto charType = hlfir::getFortranElementType(setLength.getType()); llvm::StringRef tmpName{".tmp"}; llvm::SmallVector lenParams{adaptor.getLength()}; auto alloca = builder.createTemporary(loc, charType, tmpName, /*shape=*/std::nullopt, lenParams); auto declareOp = builder.create( loc, alloca, tmpName, /*shape=*/mlir::Value{}, lenParams, fir::FortranVariableFlagsAttr{}); // Assign string value to the created temp. 
builder.create(loc, string, declareOp.getBase()); mlir::Value bufferizedExpr = packageBufferizedExpr(loc, builder, alloca, false); rewriter.replaceOp(setLength, bufferizedExpr); return mlir::success(); } }; struct AssociateOpConversion : public mlir::OpConversionPattern { using mlir::OpConversionPattern::OpConversionPattern; explicit AssociateOpConversion(mlir::MLIRContext *ctx) : mlir::OpConversionPattern{ctx} {} mlir::LogicalResult matchAndRewrite(hlfir::AssociateOp associate, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::Location loc = associate->getLoc(); // If this is the last use of the expression value and this is an hlfir.expr // that was bufferized, re-use the storage. // Otherwise, create a temp and assign the storage to it. mlir::Value bufferizedExpr = getBufferizedExprStorage(adaptor.getSource()); const bool isTrivialValue = fir::isa_trivial(bufferizedExpr.getType()); auto replaceWith = [&](mlir::Value hlfirVar, mlir::Value firVar, mlir::Value flag) { associate.getResult(0).replaceAllUsesWith(hlfirVar); associate.getResult(1).replaceAllUsesWith(firVar); associate.getResult(2).replaceAllUsesWith(flag); rewriter.replaceOp(associate, {hlfirVar, firVar, flag}); }; if (!isTrivialValue && associate.getSource().hasOneUse()) { mlir::Value mustFree = getBufferizedExprMustFreeFlag(adaptor.getSource()); mlir::Value firBase = hlfir::Entity{bufferizedExpr}.getFirBase(); replaceWith(bufferizedExpr, firBase, mustFree); return mlir::success(); } if (isTrivialValue) { auto module = associate->getParentOfType(); fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module)); auto temp = builder.createTemporary(loc, bufferizedExpr.getType(), associate.getUniqName()); builder.create(loc, bufferizedExpr, temp); mlir::Value mustFree = builder.createBool(loc, false); replaceWith(temp, temp, mustFree); return mlir::success(); } TODO(loc, "hlfir.associate of hlfir.expr with more than one use"); } }; struct EndAssociateOpConversion : 
public mlir::OpConversionPattern { using mlir::OpConversionPattern::OpConversionPattern; explicit EndAssociateOpConversion(mlir::MLIRContext *ctx) : mlir::OpConversionPattern{ctx} {} mlir::LogicalResult matchAndRewrite(hlfir::EndAssociateOp endAssociate, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::Value mustFree = adaptor.getMustFree(); mlir::Location loc = endAssociate->getLoc(); rewriter.eraseOp(endAssociate); auto genFree = [&]() { mlir::Value var = adaptor.getVar(); if (var.getType().isa()) TODO(loc, "unbox"); rewriter.create(loc, var); }; if (auto cstMustFree = fir::getIntIfConstant(mustFree)) { if (*cstMustFree != 0) genFree(); // else, nothing to do. return mlir::success(); } TODO(endAssociate.getLoc(), "conditional free"); } }; struct NoReassocOpConversion : public mlir::OpConversionPattern { using mlir::OpConversionPattern::OpConversionPattern; explicit NoReassocOpConversion(mlir::MLIRContext *ctx) : mlir::OpConversionPattern{ctx} {} mlir::LogicalResult matchAndRewrite(hlfir::NoReassocOp noreassoc, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( noreassoc, getBufferizedExprStorage(adaptor.getVal())); return mlir::success(); } }; /// This Listener allows setting both the builder and the rewriter as /// listeners. This is required when a pattern uses a firBuilder helper that /// may create illegal operations that will need to be translated and requires /// notifying the rewriter. 
// NOTE(review): in this copy of the source the template argument lists
// (e.g. of mlir::OpConversionPattern, getDefiningOp, getParentOfType,
// builder.create, patterns.insert, target.addIllegalOp, ty.isa,
// std::unique_ptr and std::make_unique) appear to have been stripped; the
// code below will not compile as-is — restore them from upstream.
struct HLFIRListener : public mlir::OpBuilder::Listener {
  HLFIRListener(fir::FirOpBuilder &builder,
                mlir::ConversionPatternRewriter &rewriter)
      : builder{builder}, rewriter{rewriter} {}
  // Forward every inserted operation to both the builder and the rewriter.
  void notifyOperationInserted(mlir::Operation *op) override {
    builder.notifyOperationInserted(op);
    rewriter.notifyOperationInserted(op);
  }
  // Forward every created block to both the builder and the rewriter.
  virtual void notifyBlockCreated(mlir::Block *block) override {
    builder.notifyBlockCreated(block);
    rewriter.notifyBlockCreated(block);
  }
  fir::FirOpBuilder &builder;
  mlir::ConversionPatternRewriter &rewriter;
};

/// Bufferize hlfir.elemental: heap-allocate an array temporary, generate a
/// loop nest over the elemental shape, inline the elemental body in the inner
/// loop, and assign each yielded element into the temporary.
struct ElementalOpConversion
    : public mlir::OpConversionPattern {
  using mlir::OpConversionPattern::OpConversionPattern;
  explicit ElementalOpConversion(mlir::MLIRContext *ctx)
      : mlir::OpConversionPattern{ctx} {}
  mlir::LogicalResult
  matchAndRewrite(hlfir::ElementalOp elemental, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Location loc = elemental->getLoc();
    auto module = elemental->getParentOfType();
    fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
    // The body of the elemental op may contain operations that will need to
    // be translated. Notify the rewriter about the cloned operations.
    HLFIRListener listener{builder, rewriter};
    builder.setListener(&listener);

    mlir::Value shape = adaptor.getShape();
    auto extents = hlfir::getIndexExtents(loc, builder, shape);
    auto [temp, cleanup] =
        createArrayTemp(loc, builder, elemental.getType(), shape, extents,
                        adaptor.getTypeparams());
    // Generate a loop nest looping around the fir.elemental shape and clone
    // fir.elemental region inside the inner loop.
    auto [innerLoop, oneBasedLoopIndices] =
        hlfir::genLoopNest(loc, builder, extents);
    auto insPt = builder.saveInsertionPoint();
    builder.setInsertionPointToStart(innerLoop.getBody());
    auto yield =
        hlfir::inlineElementalOp(loc, builder, elemental, oneBasedLoopIndices);
    hlfir::Entity elementValue(yield.getElementValue());
    // Skip final AsExpr if any. It would create an element temporary,
    // which is not needed since the element will be assigned right away in
    // the array temporary. An hlfir.as_expr may have been added if the
    // elemental is a "view" over a variable (e.g parentheses or transpose).
    if (auto asExpr = elementValue.getDefiningOp()) {
      elementValue = hlfir::Entity{asExpr.getVar()};
      // Only erase the as_expr when nothing other than its single user
      // refers to it.
      if (asExpr->hasOneUse())
        rewriter.eraseOp(asExpr);
    }
    rewriter.eraseOp(yield);
    // Assign the element value to the temp element for this iteration.
    auto tempElement =
        hlfir::getElementAt(loc, builder, temp, oneBasedLoopIndices);
    builder.create(loc, elementValue, tempElement);
    builder.restoreInsertionPoint(insPt);

    // Replace the elemental with the array temporary and its clean-up flag.
    mlir::Value bufferizedExpr =
        packageBufferizedExpr(loc, builder, temp, cleanup);
    rewriter.replaceOp(elemental, bufferizedExpr);
    return mlir::success();
  }
};

/// Pass driver: applies the HLFIR bufferization patterns with a full dialect
/// conversion. Unknown operations are legal only when none of their operand
/// or result types satisfy the (stripped) isa check — presumably
/// hlfir::ExprType; confirm against upstream.
class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase {
public:
  void runOnOperation() override {
    // TODO: make this a pass operating on FuncOp. The issue is that
    // FirOpBuilder helpers may generate new FuncOp because of runtime/llvm
    // intrinsics calls creation. This may create race conflict if the pass is
    // scheduled on FuncOp. A solution could be to provide an optional mutex
    // when building a FirOpBuilder and locking around FuncOp and GlobalOp
    // creation, but this needs a bit more thinking, so at this point the pass
    // is scheduled on the moduleOp.
    auto module = this->getOperation();
    auto *context = &getContext();
    mlir::RewritePatternSet patterns(context);
    patterns.insert(context);
    mlir::ConversionTarget target(*context);
    target.addIllegalOp();
    target.markUnknownOpDynamicallyLegal([](mlir::Operation *op) {
      return llvm::all_of(
                 op->getResultTypes(),
                 [](mlir::Type ty) { return !ty.isa(); }) &&
             llvm::all_of(op->getOperandTypes(),
                          [](mlir::Type ty) { return !ty.isa(); });
    });
    if (mlir::failed(
            mlir::applyFullConversion(module, target, std::move(patterns)))) {
      mlir::emitError(mlir::UnknownLoc::get(context),
                      "failure in HLFIR bufferization pass");
      signalPassFailure();
    }
  }
};
} // namespace

/// Factory for the HLFIR bufferization pass.
std::unique_ptr hlfir::createBufferizeHLFIRPass() {
  return std::make_unique();
}