diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h index ea658fb16a36..f2252d04079c 100644 --- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h +++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h @@ -813,6 +813,18 @@ uint64_t getAllocaAddressSpace(mlir::DataLayout *dataLayout); llvm::SmallVector deduceOptimalExtents(mlir::ValueRange extents1, mlir::ValueRange extents2); +/// Given array extents generate code that sets them all to zeroes, +/// if the array is empty, e.g.: +/// %false = arith.constant false +/// %c0 = arith.constant 0 : index +/// %p1 = arith.cmpi eq, %e0, %c0 : index +/// %p2 = arith.ori %false, %p1 : i1 +/// %p3 = arith.cmpi eq, %e1, %c0 : index +/// %p4 = arith.ori %p1, %p2 : i1 +/// %result0 = arith.select %p4, %c0, %e0 : index +/// %result1 = arith.select %p4, %c0, %e1 : index +llvm::SmallVector updateRuntimeExtentsForEmptyArrays( + fir::FirOpBuilder &builder, mlir::Location loc, mlir::ValueRange extents); } // namespace fir::factory #endif // FORTRAN_OPTIMIZER_BUILDER_FIRBUILDER_H diff --git a/flang/include/flang/Optimizer/HLFIR/Passes.td b/flang/include/flang/Optimizer/HLFIR/Passes.td index 90cf6e74241b..eae0c9ca2e36 100644 --- a/flang/include/flang/Optimizer/HLFIR/Passes.td +++ b/flang/include/flang/Optimizer/HLFIR/Passes.td @@ -19,6 +19,11 @@ def ConvertHLFIRtoFIR : Pass<"convert-hlfir-to-fir", "::mlir::ModuleOp"> { def BufferizeHLFIR : Pass<"bufferize-hlfir", "::mlir::ModuleOp"> { let summary = "Convert HLFIR operations operating on hlfir.expr into operations on memory"; + let options = [Option<"optimizeEmptyElementals", "opt-empty-elementals", + "bool", /*default=*/"false", + "When converting hlfir.elemental into a loop nest, " + "check if the resulting expression is an empty array, " + "and make sure none of the loops is executed.">]; } def OptimizedBufferization : Pass<"opt-bufferization"> { diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp index 35dc9a2abd69..af350d1331e5 100644 --- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp +++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp @@ -1759,3 +1759,29 @@ fir::factory::deduceOptimalExtents(mlir::ValueRange extents1, } return extents; } + +llvm::SmallVector fir::factory::updateRuntimeExtentsForEmptyArrays( + fir::FirOpBuilder &builder, mlir::Location loc, mlir::ValueRange extents) { + if (extents.size() <= 1) + return extents; + + mlir::Type i1Type = builder.getI1Type(); + mlir::Value isEmpty = createZeroValue(builder, loc, i1Type); + + llvm::SmallVector zeroes; + for (mlir::Value extent : extents) { + mlir::Type type = extent.getType(); + mlir::Value zero = createZeroValue(builder, loc, type); + zeroes.push_back(zero); + mlir::Value isZero = builder.create( + loc, mlir::arith::CmpIPredicate::eq, extent, zero); + isEmpty = builder.create(loc, isEmpty, isZero); + } + + llvm::SmallVector newExtents; + for (auto [zero, extent] : llvm::zip_equal(zeroes, extents)) { + newExtents.push_back( + builder.create(loc, isEmpty, zero, extent)); + } + return newExtents; +} diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp index e664834d31d3..30e7ef789095 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp @@ -761,8 +761,10 @@ struct HLFIRListener : public mlir::OpBuilder::Listener { struct ElementalOpConversion : public mlir::OpConversionPattern { using mlir::OpConversionPattern::OpConversionPattern; - explicit ElementalOpConversion(mlir::MLIRContext *ctx) - : mlir::OpConversionPattern{ctx} { + explicit ElementalOpConversion(mlir::MLIRContext *ctx, + bool optimizeEmptyElementals = false) + : mlir::OpConversionPattern{ctx}, + optimizeEmptyElementals(optimizeEmptyElementals) { // This pattern recursively converts nested ElementalOp's // by cloning and then converting them, so we have to allow // for recursive pattern application. The recursion is bounded @@ -791,6 +793,10 @@ struct ElementalOpConversion // of the loop nest. temp = derefPointersAndAllocatables(loc, builder, temp); + if (optimizeEmptyElementals) + extents = fir::factory::updateRuntimeExtentsForEmptyArrays(builder, loc, + extents); + // Generate a loop nest looping around the fir.elemental shape and clone // fir.elemental region inside the inner loop. hlfir::LoopNest loopNest = @@ -861,6 +867,9 @@ struct ElementalOpConversion rewriter.replaceOp(elemental, bufferizedExpr); return mlir::success(); } + +private: + bool optimizeEmptyElementals = false; }; struct CharExtremumOpConversion : public mlir::OpConversionPattern { @@ -932,6 +941,8 @@ struct EvaluateInMemoryOpConversion class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase { public: + using BufferizeHLFIRBase::BufferizeHLFIRBase; + void runOnOperation() override { // TODO: make this a pass operating on FuncOp. The issue is that // FirOpBuilder helpers may generate new FuncOp because of runtime/llvm @@ -943,13 +954,13 @@ public: auto module = this->getOperation(); auto *context = &getContext(); mlir::RewritePatternSet patterns(context); - patterns - .insert(context); + patterns.insert(context); + patterns.insert(context, optimizeEmptyElementals); mlir::ConversionTarget target(*context); // Note that YieldElementOp is not marked as an illegal operation. // It must be erased by its parent converter and there is no explicit diff --git a/flang/test/HLFIR/elemental-with-empty-check-codegen.fir b/flang/test/HLFIR/elemental-with-empty-check-codegen.fir new file mode 100644 index 000000000000..051127cf67ac --- /dev/null +++ b/flang/test/HLFIR/elemental-with-empty-check-codegen.fir @@ -0,0 +1,56 @@ +// Test hlfir.elemental code generation with a dynamic check +// for empty result array +// RUN: fir-opt %s --bufferize-hlfir=opt-empty-elementals=true | FileCheck %s + +func.func @test(%v: i32, %e0: i32, %e1: i32, %e2: i64, %e3: i64) { + %shape = fir.shape %e0, %e1, %e2, %e3 : (i32, i32, i64, i64) -> !fir.shape<4> + %result = hlfir.elemental %shape : (!fir.shape<4>) -> !hlfir.expr { + ^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index): + hlfir.yield_element %v : i32 + } + return +} +// CHECK-LABEL: func.func @test( +// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, +// CHECK-SAME: %[[VAL_3:.*]]: i64, %[[VAL_4:.*]]: i64) { +// CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[VAL_4]] : (i32, i32, i64, i64) -> !fir.shape<4> +// CHECK: %[[VAL_6:.*]] = fir.convert %[[VAL_1]] : (i32) -> index +// CHECK: %[[VAL_7:.*]] = fir.convert %[[VAL_2]] : (i32) -> index +// CHECK: %[[VAL_8:.*]] = fir.convert %[[VAL_3]] : (i64) -> index +// CHECK: %[[VAL_9:.*]] = fir.convert %[[VAL_4]] : (i64) -> index +// CHECK: %[[VAL_10:.*]] = fir.allocmem !fir.array, %[[VAL_6]], %[[VAL_7]], %[[VAL_8]], %[[VAL_9]] {bindc_name = ".tmp.array", uniq_name = ""} +// CHECK: %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_10]](%[[VAL_5]]) {uniq_name = ".tmp.array"} : (!fir.heap>, !fir.shape<4>) -> (!fir.box>, !fir.heap>) +// CHECK: %[[VAL_12:.*]] = arith.constant true +// CHECK: %[[VAL_13:.*]] = arith.constant false +// CHECK: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_15:.*]] = arith.cmpi eq, %[[VAL_6]], %[[C0_1]] : index +// CHECK: %[[VAL_16:.*]] = arith.ori %[[VAL_13]], %[[VAL_15]] : i1 +// CHECK: %[[C0_2:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_17:.*]] = arith.cmpi eq, %[[VAL_7]], %[[C0_2]] : index +// CHECK: %[[VAL_18:.*]] = arith.ori %[[VAL_16]], %[[VAL_17]] : i1 +// CHECK: %[[C0_3:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_19:.*]] = arith.cmpi eq, %[[VAL_8]], %[[C0_3]] : index +// CHECK: %[[VAL_20:.*]] = arith.ori %[[VAL_18]], %[[VAL_19]] : i1 +// CHECK: %[[C0_4:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_21:.*]] = arith.cmpi eq, %[[VAL_9]], %[[C0_4]] : index +// CHECK: %[[VAL_22:.*]] = arith.ori %[[VAL_20]], %[[VAL_21]] : i1 +// CHECK: %[[VAL_23:.*]] = arith.select %[[VAL_22]], %[[C0_1]], %[[VAL_6]] : index +// CHECK: %[[VAL_24:.*]] = arith.select %[[VAL_22]], %[[C0_2]], %[[VAL_7]] : index +// CHECK: %[[VAL_25:.*]] = arith.select %[[VAL_22]], %[[C0_3]], %[[VAL_8]] : index +// CHECK: %[[VAL_26:.*]] = arith.select %[[VAL_22]], %[[C0_4]], %[[VAL_9]] : index +// CHECK: %[[VAL_27:.*]] = arith.constant 1 : index +// CHECK: fir.do_loop %[[VAL_28:.*]] = %[[VAL_27]] to %[[VAL_26]] step %[[VAL_27]] { +// CHECK: fir.do_loop %[[VAL_29:.*]] = %[[VAL_27]] to %[[VAL_25]] step %[[VAL_27]] { +// CHECK: fir.do_loop %[[VAL_30:.*]] = %[[VAL_27]] to %[[VAL_24]] step %[[VAL_27]] { +// CHECK: fir.do_loop %[[VAL_31:.*]] = %[[VAL_27]] to %[[VAL_23]] step %[[VAL_27]] { +// CHECK: %[[VAL_32:.*]] = hlfir.designate %[[VAL_11]]#0 (%[[VAL_31]], %[[VAL_30]], %[[VAL_29]], %[[VAL_28]]) : (!fir.box>, index, index, index, index) -> !fir.ref +// CHECK: hlfir.assign %[[VAL_0]] to %[[VAL_32]] temporary_lhs : i32, !fir.ref +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_33:.*]] = fir.undefined tuple>, i1> +// CHECK: %[[VAL_34:.*]] = fir.insert_value %[[VAL_33]], %[[VAL_12]], [1 : index] : (tuple>, i1>, i1) -> tuple>, i1> +// CHECK: %[[VAL_35:.*]] = fir.insert_value %[[VAL_34]], %[[VAL_11]]#0, [0 : index] : (tuple>, i1>, !fir.box>) -> tuple>, i1> +// CHECK: return +// CHECK: }