[CIR] Implement EqualOp for ComplexType (#145769)
This change adds support for equal operation for ComplexType https://github.com/llvm/llvm-project/issues/141365
This commit is contained in:
@@ -2455,6 +2455,31 @@ def ComplexImagOp : CIR_Op<"complex.imag", [Pure]> {
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ComplexEqualOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ComplexEqualOp : CIR_Op<"complex.eq", [Pure, SameTypeOperands]> {
|
||||
|
||||
let summary = "Computes whether two complex values are equal";
|
||||
let description = [{
|
||||
The `complex.equal` op takes two complex numbers and returns whether
|
||||
they are equal.
|
||||
|
||||
```mlir
|
||||
%r = cir.complex.eq %a, %b : !cir.complex<!cir.float>
|
||||
```
|
||||
}];
|
||||
|
||||
let results = (outs CIR_BoolType:$result);
|
||||
let arguments = (ins CIR_ComplexType:$lhs, CIR_ComplexType:$rhs);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$lhs `,` $rhs
|
||||
`:` qualified(type($lhs)) attr-dict
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Assume Operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -894,9 +894,17 @@ public:
|
||||
}
|
||||
} else {
|
||||
// Complex Comparison: can only be an equality comparison.
|
||||
assert(!cir::MissingFeatures::complexType());
|
||||
cgf.cgm.errorNYI(loc, "complex comparison");
|
||||
result = builder.getBool(false, loc);
|
||||
assert(e->getOpcode() == BO_EQ || e->getOpcode() == BO_NE);
|
||||
|
||||
BinOpInfo boInfo = emitBinOps(e);
|
||||
if (e->getOpcode() == BO_EQ) {
|
||||
result =
|
||||
builder.create<cir::ComplexEqualOp>(loc, boInfo.lhs, boInfo.rhs);
|
||||
} else {
|
||||
assert(!cir::MissingFeatures::complexType());
|
||||
cgf.cgm.errorNYI(loc, "complex not equal");
|
||||
result = builder.getBool(false, loc);
|
||||
}
|
||||
}
|
||||
|
||||
return emitScalarConversion(result, cgf.getContext().BoolTy, e->getType(),
|
||||
|
||||
@@ -1900,29 +1900,30 @@ void ConvertCIRToLLVMPass::runOnOperation() {
|
||||
CIRToLLVMBrOpLowering,
|
||||
CIRToLLVMCallOpLowering,
|
||||
CIRToLLVMCmpOpLowering,
|
||||
CIRToLLVMComplexCreateOpLowering,
|
||||
CIRToLLVMComplexEqualOpLowering,
|
||||
CIRToLLVMComplexImagOpLowering,
|
||||
CIRToLLVMComplexRealOpLowering,
|
||||
CIRToLLVMConstantOpLowering,
|
||||
CIRToLLVMExpectOpLowering,
|
||||
CIRToLLVMFuncOpLowering,
|
||||
CIRToLLVMGetGlobalOpLowering,
|
||||
CIRToLLVMGetMemberOpLowering,
|
||||
CIRToLLVMSelectOpLowering,
|
||||
CIRToLLVMSwitchFlatOpLowering,
|
||||
CIRToLLVMShiftOpLowering,
|
||||
CIRToLLVMStackSaveOpLowering,
|
||||
CIRToLLVMStackRestoreOpLowering,
|
||||
CIRToLLVMStackSaveOpLowering,
|
||||
CIRToLLVMSwitchFlatOpLowering,
|
||||
CIRToLLVMTrapOpLowering,
|
||||
CIRToLLVMUnaryOpLowering,
|
||||
CIRToLLVMVecCmpOpLowering,
|
||||
CIRToLLVMVecCreateOpLowering,
|
||||
CIRToLLVMVecExtractOpLowering,
|
||||
CIRToLLVMVecInsertOpLowering,
|
||||
CIRToLLVMVecCmpOpLowering,
|
||||
CIRToLLVMVecSplatOpLowering,
|
||||
CIRToLLVMVecShuffleOpLowering,
|
||||
CIRToLLVMVecShuffleDynamicOpLowering,
|
||||
CIRToLLVMVecTernaryOpLowering,
|
||||
CIRToLLVMComplexCreateOpLowering,
|
||||
CIRToLLVMComplexRealOpLowering,
|
||||
CIRToLLVMComplexImagOpLowering
|
||||
CIRToLLVMVecShuffleOpLowering,
|
||||
CIRToLLVMVecSplatOpLowering,
|
||||
CIRToLLVMVecTernaryOpLowering
|
||||
// clang-format on
|
||||
>(converter, patterns.getContext());
|
||||
|
||||
@@ -2244,6 +2245,43 @@ mlir::LogicalResult CIRToLLVMComplexImagOpLowering::matchAndRewrite(
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
mlir::LogicalResult CIRToLLVMComplexEqualOpLowering::matchAndRewrite(
|
||||
cir::ComplexEqualOp op, OpAdaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const {
|
||||
mlir::Value lhs = adaptor.getLhs();
|
||||
mlir::Value rhs = adaptor.getRhs();
|
||||
|
||||
auto complexType = mlir::cast<cir::ComplexType>(op.getLhs().getType());
|
||||
mlir::Type complexElemTy =
|
||||
getTypeConverter()->convertType(complexType.getElementType());
|
||||
|
||||
mlir::Location loc = op.getLoc();
|
||||
auto lhsReal =
|
||||
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 0);
|
||||
auto lhsImag =
|
||||
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 1);
|
||||
auto rhsReal =
|
||||
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 0);
|
||||
auto rhsImag =
|
||||
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 1);
|
||||
|
||||
if (complexElemTy.isInteger()) {
|
||||
auto realCmp = rewriter.create<mlir::LLVM::ICmpOp>(
|
||||
loc, mlir::LLVM::ICmpPredicate::eq, lhsReal, rhsReal);
|
||||
auto imagCmp = rewriter.create<mlir::LLVM::ICmpOp>(
|
||||
loc, mlir::LLVM::ICmpPredicate::eq, lhsImag, rhsImag);
|
||||
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, realCmp, imagCmp);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
auto realCmp = rewriter.create<mlir::LLVM::FCmpOp>(
|
||||
loc, mlir::LLVM::FCmpPredicate::oeq, lhsReal, rhsReal);
|
||||
auto imagCmp = rewriter.create<mlir::LLVM::FCmpOp>(
|
||||
loc, mlir::LLVM::FCmpPredicate::oeq, lhsImag, rhsImag);
|
||||
rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, realCmp, imagCmp);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
|
||||
return std::make_unique<ConvertCIRToLLVMPass>();
|
||||
}
|
||||
|
||||
@@ -463,6 +463,16 @@ public:
|
||||
mlir::ConversionPatternRewriter &) const override;
|
||||
};
|
||||
|
||||
class CIRToLLVMComplexEqualOpLowering
|
||||
: public mlir::OpConversionPattern<cir::ComplexEqualOp> {
|
||||
public:
|
||||
using mlir::OpConversionPattern<cir::ComplexEqualOp>::OpConversionPattern;
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(cir::ComplexEqualOp op, OpAdaptor,
|
||||
mlir::ConversionPatternRewriter &) const override;
|
||||
};
|
||||
|
||||
} // namespace direct
|
||||
} // namespace cir
|
||||
|
||||
|
||||
@@ -368,4 +368,77 @@ int foo17(int _Complex a, int _Complex b) {
|
||||
// OGCG: %[[B_REAL:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 0
|
||||
// OGCG: %[[TMP_B:.*]] = load i32, ptr %[[B_REAL]], align 4
|
||||
// OGCG: %[[ADD:.*]] = add nsw i32 %[[TMP_A]], %[[TMP_B]]
|
||||
// OGCG: ret i32 %[[ADD]]
|
||||
// OGCG: ret i32 %[[ADD]]
|
||||
|
||||
bool foo18(int _Complex a, int _Complex b) {
|
||||
return a == b;
|
||||
}
|
||||
|
||||
// CIR: %[[COMPLEX_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!s32i>>, !cir.complex<!s32i>
|
||||
// CIR: %[[COMPLEX_B:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!s32i>>, !cir.complex<!s32i>
|
||||
// CIR: %[[RESULT:.*]] = cir.complex.eq %[[COMPLEX_A]], %[[COMPLEX_B]] : !cir.complex<!s32i>
|
||||
|
||||
// LLVM: %[[COMPLEX_A:.*]] = load { i32, i32 }, ptr {{.*}}, align 4
|
||||
// LLVM: %[[COMPLEX_B:.*]] = load { i32, i32 }, ptr {{.*}}, align 4
|
||||
// LLVM: %[[A_REAL:.*]] = extractvalue { i32, i32 } %[[COMPLEX_A]], 0
|
||||
// LLVM: %[[A_IMAG:.*]] = extractvalue { i32, i32 } %[[COMPLEX_A]], 1
|
||||
// LLVM: %[[B_REAL:.*]] = extractvalue { i32, i32 } %[[COMPLEX_B]], 0
|
||||
// LLVM: %[[B_IMAG:.*]] = extractvalue { i32, i32 } %[[COMPLEX_B]], 1
|
||||
// LLVM: %[[CMP_REAL:.*]] = icmp eq i32 %[[A_REAL]], %[[B_REAL]]
|
||||
// LLVM: %[[CMP_IMAG:.*]] = icmp eq i32 %[[A_IMAG]], %[[B_IMAG]]
|
||||
// LLVM: %[[RESULT:.*]] = and i1 %[[CMP_REAL]], %[[CMP_IMAG]]
|
||||
|
||||
// OGCG: %[[COMPLEX_A:.*]] = alloca { i32, i32 }, align 4
|
||||
// OGCG: %[[COMPLEX_B:.*]] = alloca { i32, i32 }, align 4
|
||||
// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_A]], i32 0, i32 0
|
||||
// OGCG: %[[A_REAL:.*]] = load i32, ptr %[[A_REAL_PTR]], align 4
|
||||
// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_A]], i32 0, i32 1
|
||||
// OGCG: %[[A_IMAG:.*]] = load i32, ptr %[[A_IMAG_PTR]], align 4
|
||||
// OGCG: %[[B_REAL_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 0
|
||||
// OGCG: %[[B_REAL:.*]] = load i32, ptr %[[B_REAL_PTR]], align 4
|
||||
// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 1
|
||||
// OGCG: %[[B_IMAG:.*]] = load i32, ptr %[[B_IMAG_PTR]], align 4
|
||||
// OGCG: %[[CMP_REAL:.*]] = icmp eq i32 %[[A_REAL]], %[[B_REAL]]
|
||||
// OGCG: %[[CMP_IMAG:.*]] = icmp eq i32 %[[A_IMAG]], %[[B_IMAG]]
|
||||
// OGCG: %[[RESULT:.*]] = and i1 %[[CMP_REAL]], %[[CMP_IMAG]]
|
||||
|
||||
bool foo19(double _Complex a, double _Complex b) {
|
||||
return a == b;
|
||||
}
|
||||
|
||||
// CIR: %[[COMPLEX_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!cir.double>>, !cir.complex<!cir.double>
|
||||
// CIR: %[[COMPLEX_B:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!cir.double>>, !cir.complex<!cir.double>
|
||||
// CIR: %[[RESULT:.*]] = cir.complex.eq %[[COMPLEX_A]], %[[COMPLEX_B]] : !cir.complex<!cir.double>
|
||||
|
||||
// LLVM: %[[COMPLEX_A:.*]] = load { double, double }, ptr {{.*}}, align 8
|
||||
// LLVM: %[[COMPLEX_B:.*]] = load { double, double }, ptr {{.*}}, align 8
|
||||
// LLVM: %[[A_REAL:.*]] = extractvalue { double, double } %[[COMPLEX_A]], 0
|
||||
// LLVM: %[[A_IMAG:.*]] = extractvalue { double, double } %[[COMPLEX_A]], 1
|
||||
// LLVM: %[[B_REAL:.*]] = extractvalue { double, double } %[[COMPLEX_B]], 0
|
||||
// LLVM: %[[B_IMAG:.*]] = extractvalue { double, double } %[[COMPLEX_B]], 1
|
||||
// LLVM: %[[CMP_REAL:.*]] = fcmp oeq double %[[A_REAL]], %[[B_REAL]]
|
||||
// LLVM: %[[CMP_IMAG:.*]] = fcmp oeq double %[[A_IMAG]], %[[B_IMAG]]
|
||||
// LLVM: %[[RESULT:.*]] = and i1 %[[CMP_REAL]], %[[CMP_IMAG]]
|
||||
|
||||
// OGCG: %[[COMPLEX_A:.*]] = alloca { double, double }, align 8
|
||||
// OGCG: %[[COMPLEX_B:.*]] = alloca { double, double }, align 8
|
||||
// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 0
|
||||
// OGCG: store double {{.*}}, ptr %[[A_REAL_PTR]], align 8
|
||||
// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 1
|
||||
// OGCG: store double {{.*}}, ptr %[[A_IMAG_PTR]], align 8
|
||||
// OGCG: %[[B_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 0
|
||||
// OGCG: store double {{.*}}, ptr %[[B_REAL_PTR]], align 8
|
||||
// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 1
|
||||
// OGCG: store double {{.*}}, ptr %[[B_IMAG_PTR]], align 8
|
||||
// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 0
|
||||
// OGCG: %[[A_REAL:.*]] = load double, ptr %[[A_REAL_PTR]], align 8
|
||||
// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 1
|
||||
// OGCG: %[[A_IMAG:.*]] = load double, ptr %[[A_IMAG_PTR]], align 8
|
||||
// OGCG: %[[B_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 0
|
||||
// OGCG: %[[B_REAL:.*]] = load double, ptr %[[B_REAL_PTR]], align 8
|
||||
// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 1
|
||||
// OGCG: %[[B_IMAG:.*]] = load double, ptr %[[B_IMAG_PTR]], align 8
|
||||
// OGCG: %[[CMP_REAL:.*]] = fcmp oeq double %[[A_REAL]], %[[B_REAL]]
|
||||
// OGCG: %[[CMP_IMAG:.*]] = fcmp oeq double %[[A_IMAG]], %[[B_IMAG]]
|
||||
// OGCG: %[[RESULT:.*]] = and i1 %[[CMP_REAL]], %[[CMP_IMAG]]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user