diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h index 6e85b8f4ddf8..0684ad0f926e 100644 --- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h +++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h @@ -513,6 +513,12 @@ genTypeAndKindConvert(mlir::Location loc, fir::FirOpBuilder &builder, Entity loadElementAt(mlir::Location loc, fir::FirOpBuilder &builder, Entity entity, mlir::ValueRange oneBasedIndices); +/// Return a vector of extents for the given entity. +/// The function creates new operations, but tries to clean-up +/// after itself. +llvm::SmallVector +genExtentsVector(mlir::Location loc, fir::FirOpBuilder &builder, Entity entity); + } // namespace hlfir #endif // FORTRAN_OPTIMIZER_BUILDER_HLFIRTOOLS_H diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp index 5e5d0bbd6813..f71adf123511 100644 --- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp +++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp @@ -1421,3 +1421,15 @@ hlfir::Entity hlfir::loadElementAt(mlir::Location loc, return loadTrivialScalar(loc, builder, getElementAt(loc, builder, entity, oneBasedIndices)); } + +llvm::SmallVector +hlfir::genExtentsVector(mlir::Location loc, fir::FirOpBuilder &builder, + hlfir::Entity entity) { + entity = hlfir::derefPointersAndAllocatables(loc, builder, entity); + mlir::Value shape = hlfir::genShape(loc, builder, entity); + llvm::SmallVector extents = + hlfir::getExplicitExtentsFromShape(shape, builder); + if (shape.getUses().empty()) + shape.getDefiningOp()->erase(); + return extents; +} diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp index 0fe3620b7f1a..fe7ae0eeed3c 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp @@ -37,6 +37,79 @@ static llvm::cl::opt forceMatmulAsElemental( namespace { +// Helper class to generate operations related to computing +// product of values. +class ProductFactory { +public: + ProductFactory(mlir::Location loc, fir::FirOpBuilder &builder) + : loc(loc), builder(builder) {} + + // Generate an update of the inner product value: + // acc += v1 * v2, OR + // acc += CONJ(v1) * v2, OR + // acc ||= v1 && v2 + // + // CONJ parameter specifies whether the first complex product argument + // needs to be conjugated. + template + mlir::Value genAccumulateProduct(mlir::Value acc, mlir::Value v1, + mlir::Value v2) { + mlir::Type resultType = acc.getType(); + acc = castToProductType(acc, resultType); + v1 = castToProductType(v1, resultType); + v2 = castToProductType(v2, resultType); + mlir::Value result; + if (mlir::isa(resultType)) { + result = builder.create( + loc, acc, builder.create(loc, v1, v2)); + } else if (mlir::isa(resultType)) { + if constexpr (CONJ) + result = fir::IntrinsicLibrary{builder, loc}.genConjg(resultType, v1); + else + result = v1; + + result = builder.create( + loc, acc, builder.create(loc, result, v2)); + } else if (mlir::isa(resultType)) { + result = builder.create( + loc, acc, builder.create(loc, v1, v2)); + } else if (mlir::isa(resultType)) { + result = builder.create( + loc, acc, builder.create(loc, v1, v2)); + } else { + llvm_unreachable("unsupported type"); + } + + return builder.createConvert(loc, resultType, result); + } + +private: + mlir::Location loc; + fir::FirOpBuilder &builder; + + mlir::Value castToProductType(mlir::Value value, mlir::Type type) { + if (mlir::isa(type)) + return builder.createConvert(loc, builder.getIntegerType(1), value); + + // TODO: the multiplications/additions by/of zero resulting from + // complex * real are optimized by LLVM under -fno-signed-zeros + // -fno-honor-nans. + // We can make them disappear by default if we: + // * either expand the complex multiplication into real + // operations, OR + // * set nnan nsz fast-math flags to the complex operations. + if (fir::isa_complex(type) && !fir::isa_complex(value.getType())) { + mlir::Value zeroCmplx = fir::factory::createZeroValue(builder, loc, type); + fir::factory::Complex helper(builder, loc); + mlir::Type partType = helper.getComplexPartType(type); + return helper.insertComplexPart(zeroCmplx, + castToProductType(value, partType), + /*isImagPart=*/false); + } + return builder.createConvert(loc, type, value); + } +}; + class TransposeAsElementalConversion : public mlir::OpRewritePattern { public: @@ -90,11 +163,8 @@ private: static mlir::Value genResultShape(mlir::Location loc, fir::FirOpBuilder &builder, hlfir::Entity array) { - mlir::Value inShape = hlfir::genShape(loc, builder, array); - llvm::SmallVector inExtents = - hlfir::getExplicitExtentsFromShape(inShape, builder); - if (inShape.getUses().empty()) - inShape.getDefiningOp()->erase(); + llvm::SmallVector inExtents = + hlfir::genExtentsVector(loc, builder, array); // transpose indices assert(inExtents.size() == 2 && "checked in TransposeOp::validate"); @@ -137,7 +207,7 @@ public: mlir::Value resultShape, dimExtent; llvm::SmallVector arrayExtents; if (isTotalReduction) - arrayExtents = genArrayExtents(loc, builder, array); + arrayExtents = hlfir::genExtentsVector(loc, builder, array); else std::tie(resultShape, dimExtent) = genResultShapeForPartialReduction(loc, builder, array, dimVal); @@ -163,7 +233,8 @@ public: // If DIM is not present, do total reduction. // Initial value for the reduction. - mlir::Value reductionInitValue = genInitValue(loc, builder, elementType); + mlir::Value reductionInitValue = + fir::factory::createZeroValue(builder, loc, elementType); // The reduction loop may be unordered if FastMathFlags::reassoc // transformations are allowed. The integer reduction is always @@ -264,17 +335,6 @@ public: } private: - static llvm::SmallVector - genArrayExtents(mlir::Location loc, fir::FirOpBuilder &builder, - hlfir::Entity array) { - mlir::Value inShape = hlfir::genShape(loc, builder, array); - llvm::SmallVector inExtents = - hlfir::getExplicitExtentsFromShape(inShape, builder); - if (inShape.getUses().empty()) - inShape.getDefiningOp()->erase(); - return inExtents; - } - // Return fir.shape specifying the shape of the result // of a SUM reduction with DIM=dimVal. The second return value // is the extent of the DIM dimension. @@ -283,7 +343,7 @@ private: fir::FirOpBuilder &builder, hlfir::Entity array, int64_t dimVal) { llvm::SmallVector inExtents = - genArrayExtents(loc, builder, array); + hlfir::genExtentsVector(loc, builder, array); assert(dimVal > 0 && dimVal <= static_cast(inExtents.size()) && "DIM must be present and a positive constant not exceeding " "the array's rank"); @@ -293,26 +353,6 @@ private: return {builder.create(loc, inExtents), dimExtent}; } - // Generate the initial value for a SUM reduction with the given - // data type. - static mlir::Value genInitValue(mlir::Location loc, - fir::FirOpBuilder &builder, - mlir::Type elementType) { - if (auto ty = mlir::dyn_cast(elementType)) { - const llvm::fltSemantics &sem = ty.getFloatSemantics(); - return builder.createRealConstant(loc, elementType, - llvm::APFloat::getZero(sem)); - } else if (auto ty = mlir::dyn_cast(elementType)) { - mlir::Value initValue = genInitValue(loc, builder, ty.getElementType()); - return fir::factory::Complex{builder, loc}.createComplex(ty, initValue, - initValue); - } else if (mlir::isa(elementType)) { - return builder.createIntegerConstant(loc, elementType, 0); - } - - llvm_unreachable("unsupported SUM reduction type"); - } - // Generate scalar addition of the two values (of the same data type). static mlir::Value genScalarAdd(mlir::Location loc, fir::FirOpBuilder &builder, @@ -570,16 +610,10 @@ private: static std::tuple genResultShape(mlir::Location loc, fir::FirOpBuilder &builder, hlfir::Entity input1, hlfir::Entity input2) { - mlir::Value input1Shape = hlfir::genShape(loc, builder, input1); - llvm::SmallVector input1Extents = - hlfir::getExplicitExtentsFromShape(input1Shape, builder); - if (input1Shape.getUses().empty()) - input1Shape.getDefiningOp()->erase(); - mlir::Value input2Shape = hlfir::genShape(loc, builder, input2); - llvm::SmallVector input2Extents = - hlfir::getExplicitExtentsFromShape(input2Shape, builder); - if (input2Shape.getUses().empty()) - input2Shape.getDefiningOp()->erase(); + llvm::SmallVector input1Extents = + hlfir::genExtentsVector(loc, builder, input1); + llvm::SmallVector input2Extents = + hlfir::genExtentsVector(loc, builder, input2); llvm::SmallVector newExtents; mlir::Value innerProduct1Extent, innerProduct2Extent; @@ -627,60 +661,6 @@ private: innerProductExtent[0]}; } - static mlir::Value castToProductType(mlir::Location loc, - fir::FirOpBuilder &builder, - mlir::Value value, mlir::Type type) { - if (mlir::isa(type)) - return builder.createConvert(loc, builder.getIntegerType(1), value); - - // TODO: the multiplications/additions by/of zero resulting from - // complex * real are optimized by LLVM under -fno-signed-zeros - // -fno-honor-nans. - // We can make them disappear by default if we: - // * either expand the complex multiplication into real - // operations, OR - // * set nnan nsz fast-math flags to the complex operations. - if (fir::isa_complex(type) && !fir::isa_complex(value.getType())) { - mlir::Value zeroCmplx = fir::factory::createZeroValue(builder, loc, type); - fir::factory::Complex helper(builder, loc); - mlir::Type partType = helper.getComplexPartType(type); - return helper.insertComplexPart( - zeroCmplx, castToProductType(loc, builder, value, partType), - /*isImagPart=*/false); - } - return builder.createConvert(loc, type, value); - } - - // Generate an update of the inner product value: - // acc += v1 * v2, OR - // acc ||= v1 && v2 - static mlir::Value genAccumulateProduct(mlir::Location loc, - fir::FirOpBuilder &builder, - mlir::Type resultType, - mlir::Value acc, mlir::Value v1, - mlir::Value v2) { - acc = castToProductType(loc, builder, acc, resultType); - v1 = castToProductType(loc, builder, v1, resultType); - v2 = castToProductType(loc, builder, v2, resultType); - mlir::Value result; - if (mlir::isa(resultType)) - result = builder.create( - loc, acc, builder.create(loc, v1, v2)); - else if (mlir::isa(resultType)) - result = builder.create( - loc, acc, builder.create(loc, v1, v2)); - else if (mlir::isa(resultType)) - result = builder.create( - loc, acc, builder.create(loc, v1, v2)); - else if (mlir::isa(resultType)) - result = builder.create( - loc, acc, builder.create(loc, v1, v2)); - else - llvm_unreachable("unsupported type"); - - return builder.createConvert(loc, resultType, result); - } - static mlir::LogicalResult genContiguousMatmul(mlir::Location loc, fir::FirOpBuilder &builder, hlfir::Entity result, mlir::Value resultShape, @@ -748,9 +728,9 @@ private: hlfir::loadElementAt(loc, builder, lhs, {I, K}); hlfir::Entity rhsElementValue = hlfir::loadElementAt(loc, builder, rhs, {K, J}); - mlir::Value productValue = genAccumulateProduct( - loc, builder, resultElementType, resultElementValue, - lhsElementValue, rhsElementValue); + mlir::Value productValue = + ProductFactory{loc, builder}.genAccumulateProduct( + resultElementValue, lhsElementValue, rhsElementValue); builder.create(loc, productValue, resultElement); return {}; }; @@ -785,9 +765,9 @@ private: hlfir::loadElementAt(loc, builder, lhs, {J, K}); hlfir::Entity rhsElementValue = hlfir::loadElementAt(loc, builder, rhs, {K}); - mlir::Value productValue = genAccumulateProduct( - loc, builder, resultElementType, resultElementValue, - lhsElementValue, rhsElementValue); + mlir::Value productValue = + ProductFactory{loc, builder}.genAccumulateProduct( + resultElementValue, lhsElementValue, rhsElementValue); builder.create(loc, productValue, resultElement); return {}; }; @@ -817,9 +797,9 @@ private: hlfir::loadElementAt(loc, builder, lhs, {K}); hlfir::Entity rhsElementValue = hlfir::loadElementAt(loc, builder, rhs, {K, J}); - mlir::Value productValue = genAccumulateProduct( - loc, builder, resultElementType, resultElementValue, - lhsElementValue, rhsElementValue); + mlir::Value productValue = + ProductFactory{loc, builder}.genAccumulateProduct( + resultElementValue, lhsElementValue, rhsElementValue); builder.create(loc, productValue, resultElement); return {}; }; @@ -885,9 +865,9 @@ private: hlfir::loadElementAt(loc, builder, lhs, lhsIndices); hlfir::Entity rhsElementValue = hlfir::loadElementAt(loc, builder, rhs, rhsIndices); - mlir::Value productValue = genAccumulateProduct( - loc, builder, resultElementType, reductionArgs[0], lhsElementValue, - rhsElementValue); + mlir::Value productValue = + ProductFactory{loc, builder}.genAccumulateProduct( + reductionArgs[0], lhsElementValue, rhsElementValue); return {productValue}; }; llvm::SmallVector innerProductValue = @@ -904,6 +884,73 @@ private: } }; +class DotProductConversion + : public mlir::OpRewritePattern { +public: + using mlir::OpRewritePattern::OpRewritePattern; + + llvm::LogicalResult + matchAndRewrite(hlfir::DotProductOp product, + mlir::PatternRewriter &rewriter) const override { + hlfir::Entity op = hlfir::Entity{product}; + if (!op.isScalar()) + return rewriter.notifyMatchFailure(product, "produces non-scalar result"); + + mlir::Location loc = product.getLoc(); + fir::FirOpBuilder builder{rewriter, product.getOperation()}; + hlfir::Entity lhs = hlfir::Entity{product.getLhs()}; + hlfir::Entity rhs = hlfir::Entity{product.getRhs()}; + mlir::Type resultElementType = product.getType(); + bool isUnordered = mlir::isa(resultElementType) || + mlir::isa(resultElementType) || + static_cast(builder.getFastMathFlags() & + mlir::arith::FastMathFlags::reassoc); + + mlir::Value extent = genProductExtent(loc, builder, lhs, rhs); + + auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder, + mlir::ValueRange oneBasedIndices, + mlir::ValueRange reductionArgs) + -> llvm::SmallVector { + hlfir::Entity lhsElementValue = + hlfir::loadElementAt(loc, builder, lhs, oneBasedIndices); + hlfir::Entity rhsElementValue = + hlfir::loadElementAt(loc, builder, rhs, oneBasedIndices); + mlir::Value productValue = + ProductFactory{loc, builder}.genAccumulateProduct( + reductionArgs[0], lhsElementValue, rhsElementValue); + return {productValue}; + }; + + mlir::Value initValue = + fir::factory::createZeroValue(builder, loc, resultElementType); + + llvm::SmallVector result = hlfir::genLoopNestWithReductions( + loc, builder, {extent}, + /*reductionInits=*/{initValue}, genBody, isUnordered); + + rewriter.replaceOp(product, result[0]); + return mlir::success(); + } + +private: + static mlir::Value genProductExtent(mlir::Location loc, + fir::FirOpBuilder &builder, + hlfir::Entity input1, + hlfir::Entity input2) { + llvm::SmallVector input1Extents = + hlfir::genExtentsVector(loc, builder, input1); + llvm::SmallVector input2Extents = + hlfir::genExtentsVector(loc, builder, input2); + + assert(input1Extents.size() == 1 && input2Extents.size() == 1 && + "hlfir.dot_product arguments must be vectors"); + llvm::SmallVector extent = + fir::factory::deduceOptimalExtents(input1Extents, input2Extents); + return extent[0]; + } +}; + class SimplifyHLFIRIntrinsics : public hlfir::impl::SimplifyHLFIRIntrinsicsBase { public: @@ -939,6 +986,8 @@ public: if (forceMatmulAsElemental || this->allowNewSideEffects) patterns.insert>(context); + patterns.insert(context); + if (mlir::failed(mlir::applyPatternsGreedily( getOperation(), std::move(patterns), config))) { mlir::emitError(getOperation()->getLoc(), diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-dotproduct.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-dotproduct.fir new file mode 100644 index 000000000000..f59b1422dbc8 --- /dev/null +++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-dotproduct.fir @@ -0,0 +1,144 @@ +// Test hlfir.dot_product simplification to a reduction loop: +// RUN: fir-opt --simplify-hlfir-intrinsics %s | FileCheck %s + +func.func @dot_product_integer(%arg0: !hlfir.expr, %arg1: !hlfir.expr) -> i32 { + %res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr, !hlfir.expr) -> i32 + return %res : i32 +} +// CHECK-LABEL: func.func @dot_product_integer( +// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr, +// CHECK-SAME: %[[VAL_1:.*]]: !hlfir.expr) -> i32 { +// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_3:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr) -> !fir.shape<1> +// CHECK: %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index +// CHECK: %[[VAL_6:.*]] = fir.do_loop %[[VAL_7:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] unordered iter_args(%[[VAL_8:.*]] = %[[VAL_3]]) -> (i32) { +// CHECK: %[[VAL_9:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_7]] : (!hlfir.expr, index) -> i16 +// CHECK: %[[VAL_10:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_7]] : (!hlfir.expr, index) -> i32 +// CHECK: %[[VAL_11:.*]] = fir.convert %[[VAL_9]] : (i16) -> i32 +// CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_11]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_8]], %[[VAL_12]] : i32 +// CHECK: fir.result %[[VAL_13]] : i32 +// CHECK: } +// CHECK: return %[[VAL_6]] : i32 +// CHECK: } + +func.func @dot_product_real(%arg0: !hlfir.expr, %arg1: !hlfir.expr) -> f32 { + %res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr, !hlfir.expr) -> f32 + return %res : f32 +} +// CHECK-LABEL: func.func @dot_product_real( +// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr, +// CHECK-SAME: %[[VAL_1:.*]]: !hlfir.expr) -> f32 { +// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr) -> !fir.shape<1> +// CHECK: %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index +// CHECK: %[[VAL_6:.*]] = fir.do_loop %[[VAL_7:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_8:.*]] = %[[VAL_3]]) -> (f32) { +// CHECK: %[[VAL_9:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_7]] : (!hlfir.expr, index) -> f32 +// CHECK: %[[VAL_10:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_7]] : (!hlfir.expr, index) -> f16 +// CHECK: %[[VAL_11:.*]] = fir.convert %[[VAL_10]] : (f16) -> f32 +// CHECK: %[[VAL_12:.*]] = arith.mulf %[[VAL_9]], %[[VAL_11]] : f32 +// CHECK: %[[VAL_13:.*]] = arith.addf %[[VAL_8]], %[[VAL_12]] : f32 +// CHECK: fir.result %[[VAL_13]] : f32 +// CHECK: } +// CHECK: return %[[VAL_6]] : f32 +// CHECK: } + +func.func @dot_product_complex(%arg0: !hlfir.expr>, %arg1: !hlfir.expr>) -> complex { + %res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr>, !hlfir.expr>) -> complex + return %res : complex +} +// CHECK-LABEL: func.func @dot_product_complex( +// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr>, +// CHECK-SAME: %[[VAL_1:.*]]: !hlfir.expr>) -> complex { +// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr>) -> !fir.shape<1> +// CHECK: %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index +// CHECK: %[[VAL_6:.*]] = fir.undefined complex +// CHECK: %[[VAL_7:.*]] = fir.insert_value %[[VAL_6]], %[[VAL_3]], [0 : index] : (complex, f32) -> complex +// CHECK: %[[VAL_8:.*]] = fir.insert_value %[[VAL_7]], %[[VAL_3]], [1 : index] : (complex, f32) -> complex +// CHECK: %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (complex) { +// CHECK: %[[VAL_12:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_10]] : (!hlfir.expr>, index) -> complex +// CHECK: %[[VAL_13:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_10]] : (!hlfir.expr>, index) -> complex +// CHECK: %[[VAL_14:.*]] = fir.convert %[[VAL_13]] : (complex) -> complex +// CHECK: %[[VAL_15:.*]] = fir.extract_value %[[VAL_12]], [1 : index] : (complex) -> f32 +// CHECK: %[[VAL_16:.*]] = arith.negf %[[VAL_15]] : f32 +// CHECK: %[[VAL_17:.*]] = fir.insert_value %[[VAL_12]], %[[VAL_16]], [1 : index] : (complex, f32) -> complex +// CHECK: %[[VAL_18:.*]] = fir.mulc %[[VAL_17]], %[[VAL_14]] : complex +// CHECK: %[[VAL_19:.*]] = fir.addc %[[VAL_11]], %[[VAL_18]] : complex +// CHECK: fir.result %[[VAL_19]] : complex +// CHECK: } +// CHECK: return %[[VAL_9]] : complex +// CHECK: } + +func.func @dot_product_real_complex(%arg0: !hlfir.expr, %arg1: !hlfir.expr>) -> complex { + %res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr, !hlfir.expr>) -> complex + return %res : complex +} +// CHECK-LABEL: func.func @dot_product_real_complex( +// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr, +// CHECK-SAME: %[[VAL_1:.*]]: !hlfir.expr>) -> complex { +// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr) -> !fir.shape<1> +// CHECK: %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index +// CHECK: %[[VAL_6:.*]] = fir.undefined complex +// CHECK: %[[VAL_7:.*]] = fir.insert_value %[[VAL_6]], %[[VAL_3]], [0 : index] : (complex, f32) -> complex +// CHECK: %[[VAL_8:.*]] = fir.insert_value %[[VAL_7]], %[[VAL_3]], [1 : index] : (complex, f32) -> complex +// CHECK: %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (complex) { +// CHECK: %[[VAL_12:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_10]] : (!hlfir.expr, index) -> f32 +// CHECK: %[[VAL_13:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_10]] : (!hlfir.expr>, index) -> complex +// CHECK: %[[VAL_14:.*]] = fir.undefined complex +// CHECK: %[[VAL_15:.*]] = fir.insert_value %[[VAL_14]], %[[VAL_3]], [0 : index] : (complex, f32) -> complex +// CHECK: %[[VAL_16:.*]] = fir.insert_value %[[VAL_15]], %[[VAL_3]], [1 : index] : (complex, f32) -> complex +// CHECK: %[[VAL_17:.*]] = fir.insert_value %[[VAL_16]], %[[VAL_12]], [0 : index] : (complex, f32) -> complex +// CHECK: %[[VAL_18:.*]] = fir.convert %[[VAL_13]] : (complex) -> complex +// CHECK: %[[VAL_19:.*]] = fir.extract_value %[[VAL_17]], [1 : index] : (complex) -> f32 +// CHECK: %[[VAL_20:.*]] = arith.negf %[[VAL_19]] : f32 +// CHECK: %[[VAL_21:.*]] = fir.insert_value %[[VAL_17]], %[[VAL_20]], [1 : index] : (complex, f32) -> complex +// CHECK: %[[VAL_22:.*]] = fir.mulc %[[VAL_21]], %[[VAL_18]] : complex +// CHECK: %[[VAL_23:.*]] = fir.addc %[[VAL_11]], %[[VAL_22]] : complex +// CHECK: fir.result %[[VAL_23]] : complex +// CHECK: } +// CHECK: return %[[VAL_9]] : complex +// CHECK: } + +func.func @dot_product_logical(%arg0: !hlfir.expr>, %arg1: !hlfir.expr>) -> !fir.logical<4> { + %res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr>, !hlfir.expr>) -> !fir.logical<4> + return %res : !fir.logical<4> +} +// CHECK-LABEL: func.func @dot_product_logical( +// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr>, +// CHECK-SAME: %[[VAL_1:.*]]: !hlfir.expr>) -> !fir.logical<4> { +// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_3:.*]] = arith.constant false +// CHECK: %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr>) -> !fir.shape<1> +// CHECK: %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index +// CHECK: %[[VAL_6:.*]] = fir.convert %[[VAL_3]] : (i1) -> !fir.logical<4> +// CHECK: %[[VAL_7:.*]] = fir.do_loop %[[VAL_8:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] unordered iter_args(%[[VAL_9:.*]] = %[[VAL_6]]) -> (!fir.logical<4>) { +// CHECK: %[[VAL_10:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_8]] : (!hlfir.expr>, index) -> !fir.logical<1> +// CHECK: %[[VAL_11:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_8]] : (!hlfir.expr>, index) -> !fir.logical<4> +// CHECK: %[[VAL_12:.*]] = fir.convert %[[VAL_9]] : (!fir.logical<4>) -> i1 +// CHECK: %[[VAL_13:.*]] = fir.convert %[[VAL_10]] : (!fir.logical<1>) -> i1 +// CHECK: %[[VAL_14:.*]] = fir.convert %[[VAL_11]] : (!fir.logical<4>) -> i1 +// CHECK: %[[VAL_15:.*]] = arith.andi %[[VAL_13]], %[[VAL_14]] : i1 +// CHECK: %[[VAL_16:.*]] = arith.ori %[[VAL_12]], %[[VAL_15]] : i1 +// CHECK: %[[VAL_17:.*]] = fir.convert %[[VAL_16]] : (i1) -> !fir.logical<4> +// CHECK: fir.result %[[VAL_17]] : !fir.logical<4> +// CHECK: } +// CHECK: return %[[VAL_7]] : !fir.logical<4> +// CHECK: } + +func.func @dot_product_known_dim(%arg0: !hlfir.expr<10xf32>, %arg1: !hlfir.expr) -> f32 { + %res1 = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<10xf32>, !hlfir.expr) -> f32 + %res2 = hlfir.dot_product %arg1 %arg0 : (!hlfir.expr, !hlfir.expr<10xf32>) -> f32 + %res = arith.addf %res1, %res2 : f32 + return %res : f32 +} +// CHECK-LABEL: func.func @dot_product_known_dim( +// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_4:.*]] = arith.constant 10 : index +// CHECK: fir.do_loop %{{.*}} = %[[VAL_2]] to %[[VAL_4]] step %[[VAL_2]] +// CHECK: fir.do_loop %{{.*}} = %[[VAL_2]] to %[[VAL_4]] step %[[VAL_2]]