[CIR] Upstream __real__ for ComplexType (#144261)

This change adds support for __real__ for ComplexType

https://github.com/llvm/llvm-project/issues/141365
This commit is contained in:
Amr Hesham
2025-06-25 18:06:57 +02:00
committed by GitHub
parent 3e337bc308
commit 9a7720ad2f
10 changed files with 190 additions and 3 deletions

View File

@@ -2371,6 +2371,35 @@ def ComplexCreateOp : CIR_Op<"complex.create", [Pure, SameTypeOperands]> {
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// ComplexRealOp
//===----------------------------------------------------------------------===//
def ComplexRealOp : CIR_Op<"complex.real", [Pure]> {
let summary = "Extract the real part of a complex value";
let description = [{
`cir.complex.real` operation takes an operand of `!cir.complex` type and
yields the real part of it.
Example:
```mlir
%1 = cir.complex.real %0 : !cir.complex<!cir.float> -> !cir.float
```
}];
let results = (outs CIR_AnyIntOrFloatType:$result);
let arguments = (ins CIR_ComplexType:$operand);
let assemblyFormat = [{
$operand `:` qualified(type($operand)) `->` qualified(type($result))
attr-dict
}];
let hasVerifier = 1;
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// Assume Operations
//===----------------------------------------------------------------------===//

View File

@@ -366,6 +366,11 @@ public:
return create<cir::ComplexCreateOp>(loc, resultComplexTy, real, imag);
}
mlir::Value createComplexReal(mlir::Location loc, mlir::Value operand) {
auto operandTy = mlir::cast<cir::ComplexType>(operand.getType());
return create<cir::ComplexRealOp>(loc, operandTy.getElementType(), operand);
}
/// Create a cir.ptr_stride operation to get access to an array element.
/// \p idx is the index of the element to access, \p shouldDecay is true if
/// the result should decay to a pointer to the element type.

View File

@@ -603,6 +603,8 @@ public:
mlir::Value VisitUnaryLNot(const UnaryOperator *e);
mlir::Value VisitUnaryReal(const UnaryOperator *e);
mlir::Value VisitCXXThisExpr(CXXThisExpr *te) { return cgf.loadCXXThis(); }
/// Emit a conversion from the specified type to the specified destination
@@ -1891,6 +1893,27 @@ mlir::Value ScalarExprEmitter::VisitUnaryLNot(const UnaryOperator *e) {
return maybePromoteBoolResult(boolVal, cgf.convertType(e->getType()));
}
mlir::Value ScalarExprEmitter::VisitUnaryReal(const UnaryOperator *e) {
// TODO(cir): handle scalar promotion.
Expr *op = e->getSubExpr();
if (op->getType()->isAnyComplexType()) {
// If it's an l-value, load through the appropriate subobject l-value.
// Note that we have to ask `e` because `op` might be an l-value that
// this won't work for, e.g. an Obj-C property.
if (e->isGLValue()) {
mlir::Location loc = cgf.getLoc(e->getExprLoc());
mlir::Value complex = cgf.emitComplexExpr(op);
return cgf.builder.createComplexReal(loc, complex);
}
// Otherwise, calculate and project.
cgf.cgm.errorNYI(e->getSourceRange(),
"VisitUnaryReal calculate and project");
}
return Visit(op);
}
/// Return the size or alignment of the type of argument of the sizeof
/// expression as an integer.
mlir::Value ScalarExprEmitter::VisitUnaryExprOrTypeTraitExpr(

View File

@@ -1914,6 +1914,24 @@ OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) {
return cir::ConstComplexAttr::get(realAttr, imagAttr);
}
//===----------------------------------------------------------------------===//
// ComplexRealOp
//===----------------------------------------------------------------------===//
LogicalResult cir::ComplexRealOp::verify() {
if (getType() != getOperand().getType().getElementType()) {
emitOpError() << ": result type does not match operand type";
return failure();
}
return success();
}
OpFoldResult cir::ComplexRealOp::fold(FoldAdaptor adaptor) {
auto complex =
mlir::cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand());
return complex ? complex.getReal() : nullptr;
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

View File

@@ -141,8 +141,8 @@ void CIRCanonicalizePass::runOnOperation() {
// Many operations are here to perform a manual `fold` in
// applyOpPatternsGreedily.
if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp,
ComplexCreateOp, VecCmpOp, VecCreateOp, VecExtractOp, VecShuffleOp,
VecShuffleDynamicOp, VecTernaryOp>(op))
ComplexCreateOp, ComplexRealOp, VecCmpOp, VecCreateOp, VecExtractOp,
VecShuffleOp, VecShuffleDynamicOp, VecTernaryOp>(op))
ops.push_back(op);
});

View File

@@ -1903,7 +1903,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
CIRToLLVMVecShuffleOpLowering,
CIRToLLVMVecShuffleDynamicOpLowering,
CIRToLLVMVecTernaryOpLowering,
CIRToLLVMComplexCreateOpLowering
CIRToLLVMComplexCreateOpLowering,
CIRToLLVMComplexRealOpLowering
// clang-format on
>(converter, patterns.getContext());
@@ -2207,6 +2208,15 @@ mlir::LogicalResult CIRToLLVMComplexCreateOpLowering::matchAndRewrite(
return mlir::success();
}
mlir::LogicalResult CIRToLLVMComplexRealOpLowering::matchAndRewrite(
cir::ComplexRealOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Type resultLLVMTy = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>(
op, resultLLVMTy, adaptor.getOperand(), llvm::ArrayRef<std::int64_t>{0});
return mlir::success();
}
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
return std::make_unique<ConvertCIRToLLVMPass>();
}

View File

@@ -443,6 +443,16 @@ public:
mlir::ConversionPatternRewriter &) const override;
};
class CIRToLLVMComplexRealOpLowering
: public mlir::OpConversionPattern<cir::ComplexRealOp> {
public:
using mlir::OpConversionPattern<cir::ComplexRealOp>::OpConversionPattern;
mlir::LogicalResult
matchAndRewrite(cir::ComplexRealOp op, OpAdaptor,
mlir::ConversionPatternRewriter &) const override;
};
} // namespace direct
} // namespace cir

View File

@@ -216,6 +216,30 @@ void foo9(double a, double b) {
// OGCG: store double %[[TMP_A]], ptr %[[C_REAL_PTR]], align 8
// OGCG: store double %[[TMP_B]], ptr %[[C_IMAG_PTR]], align 8
void foo13() {
double _Complex c;
double real = __real__ c;
}
// CIR: %[[COMPLEX:.*]] = cir.alloca !cir.complex<!cir.double>, !cir.ptr<!cir.complex<!cir.double>>, ["c"]
// CIR: %[[INIT:.*]] = cir.alloca !cir.double, !cir.ptr<!cir.double>, ["real", init]
// CIR: %[[TMP:.*]] = cir.load{{.*}} %[[COMPLEX]] : !cir.ptr<!cir.complex<!cir.double>>, !cir.complex<!cir.double>
// CIR: %[[REAL:.*]] = cir.complex.real %[[TMP]] : !cir.complex<!cir.double> -> !cir.double
// CIR: cir.store{{.*}} %[[REAL]], %[[INIT]] : !cir.double, !cir.ptr<!cir.double>
// LLVM: %[[COMPLEX:.*]] = alloca { double, double }, i64 1, align 8
// LLVM: %[[INIT:.*]] = alloca double, i64 1, align 8
// LLVM: %[[TMP:.*]] = load { double, double }, ptr %[[COMPLEX]], align 8
// LLVM: %[[REAL:.*]] = extractvalue { double, double } %[[TMP]], 0
// LLVM: store double %[[REAL]], ptr %[[INIT]], align 8
// OGCG: %[[COMPLEX:.*]] = alloca { double, double }, align 8
// OGCG: %[[INIT:.*]] = alloca double, align 8
// OGCG: %[[REAL:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX]], i32 0, i32 0
// OGCG: %[[TMP:.*]] = load double, ptr %[[REAL]], align 8
// OGCG: store double %[[TMP]], ptr %[[INIT]], align 8
void foo14() {
int _Complex c = 2i;
}
@@ -256,3 +280,36 @@ void foo15() {
// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 1
// OGCG: store i32 %[[A_REAL]], ptr %[[B_REAL_PTR]], align 4
// OGCG: store i32 %[[A_IMAG]], ptr %[[B_IMAG_PTR]], align 4
int foo17(int _Complex a, int _Complex b) {
return __real__ a + __real__ b;
}
// CIR: %[[RET:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"]
// CIR: %[[COMPLEX_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!s32i>>, !cir.complex<!s32i>
// CIR: %[[A_REAL:.*]] = cir.complex.real %[[COMPLEX_A]] : !cir.complex<!s32i> -> !s32i
// CIR: %[[COMPLEX_B:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!s32i>>, !cir.complex<!s32i>
// CIR: %[[B_REAL:.*]] = cir.complex.real %[[COMPLEX_B]] : !cir.complex<!s32i> -> !s32i
// CIR: %[[ADD:.*]] = cir.binop(add, %[[A_REAL]], %[[B_REAL]]) nsw : !s32i
// CIR: cir.store %[[ADD]], %[[RET]] : !s32i, !cir.ptr<!s32i>
// CIR: %[[TMP:.*]] = cir.load %[[RET]] : !cir.ptr<!s32i>, !s32i
// CIR: cir.return %[[TMP]] : !s32i
// LLVM: %[[RET:.*]] = alloca i32, i64 1, align 4
// LLVM: %[[COMPLEX_A:.*]] = load { i32, i32 }, ptr {{.*}}, align 4
// LLVM: %[[A_REAL:.*]] = extractvalue { i32, i32 } %[[COMPLEX_A]], 0
// LLVM: %[[COMPLEX_B:.*]] = load { i32, i32 }, ptr {{.*}}, align 4
// LLVM: %[[B_REAL:.*]] = extractvalue { i32, i32 } %[[COMPLEX_B]], 0
// LLVM: %[[ADD:.*]] = add nsw i32 %[[A_REAL]], %[[B_REAL]]
// LLVM: store i32 %[[ADD]], ptr %[[RET]], align 4
// LLVM: %[[TMP:.*]] = load i32, ptr %[[RET]], align 4
// LLVM: ret i32 %[[TMP]]
// OGCG: %[[COMPLEX_A:.*]] = alloca { i32, i32 }, align 4
// OGCG: %[[COMPLEX_B:.*]] = alloca { i32, i32 }, align 4
// OGCG: %[[A_REAL:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_A]], i32 0, i32 0
// OGCG: %[[TMP_A:.*]] = load i32, ptr %[[A_REAL]], align 4
// OGCG: %[[B_REAL:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 0
// OGCG: %[[TMP_B:.*]] = load i32, ptr %[[B_REAL]], align 4
// OGCG: %[[ADD:.*]] = add nsw i32 %[[TMP_A]], %[[TMP_B]]
// OGCG: ret i32 %[[ADD]]

View File

@@ -21,3 +21,15 @@ module {
cir.global external @ci2 = #cir.const_complex<#cir.int<1> : !s32i, #cir.int<2> : !s64i> : !cir.complex<!s32i>
}
// -----
module {
cir.func @complex_real_invalid_result_type() -> !cir.double {
%0 = cir.alloca !cir.complex<!cir.double>, !cir.ptr<!cir.complex<!cir.double>>, ["c"]
%2 = cir.load align(8) %0 : !cir.ptr<!cir.complex<!cir.double>>, !cir.complex<!cir.double>
// expected-error @below {{result type does not match operand type}}
%3 = cir.complex.real %2 : !cir.complex<!cir.double> -> !cir.float
cir.return
}
}

View File

@@ -0,0 +1,23 @@
// RUN: cir-opt %s -cir-canonicalize -o - | FileCheck %s
!s32i = !cir.int<s, 32>
module {
cir.func @fold_complex_real_test() -> !s32i {
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"]
%2 = cir.const #cir.const_complex<#cir.int<1> : !s32i, #cir.int<2> : !s32i> : !cir.complex<!s32i>
%4 = cir.complex.real %2 : !cir.complex<!s32i> -> !s32i
cir.store %4, %0 : !s32i, !cir.ptr<!s32i>
%5 = cir.load %0 : !cir.ptr<!s32i>, !s32i
cir.return %5 : !s32i
}
// CHECK: cir.func @fold_complex_real_test() -> !s32i {
// CHECK: %[[RET:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"]
// CHECK: %[[REAL:.*]] = cir.const #cir.int<1> : !s32i
// CHECK: cir.store %[[REAL]], %[[RET]] : !s32i, !cir.ptr<!s32i>
// CHECK: %[[TMP:.]] = cir.load %[[RET]] : !cir.ptr<!s32i>, !s32i
// CHECK: cir.return %[[TMP]] : !s32i
// CHECK: }
}