[CIR] Refactor type interfaces (#146044)

- Generalizes CIRFPTypeInterface files to CIRTypeInterfaces for future type interfaces additions.
- Renames CIRFPTypeInterface to FPTypeInterface.
- Fixes FPTypeInterface tablegen prefix.

This mirrors incubator changes from https://github.com/llvm/clangir/pull/1713
This commit is contained in:
Henrich Lauko
2025-06-27 16:47:58 +02:00
committed by GitHub
parent dc6d2b841f
commit 61c0a94a90
18 changed files with 54 additions and 58 deletions

View File

@@ -160,18 +160,17 @@ def FPAttr : CIR_Attr<"FP", "fp", [TypedAttrInterface]> {
value of the specified floating-point type. Supporting only CIR FP types.
}];
let parameters = (ins
AttributeSelfTypeParameter<"", "::cir::CIRFPTypeInterface">:$type,
AttributeSelfTypeParameter<"", "::cir::FPTypeInterface">:$type,
APFloatParameter<"">:$value
);
let builders = [
AttrBuilderWithInferredContext<(ins "mlir::Type":$type,
"const llvm::APFloat &":$value), [{
return $_get(type.getContext(), mlir::cast<CIRFPTypeInterface>(type),
value);
return $_get(type.getContext(), mlir::cast<FPTypeInterface>(type), value);
}]>,
AttrBuilder<(ins "mlir::Type":$type,
"const llvm::APFloat &":$value), [{
return $_get($_ctxt, mlir::cast<CIRFPTypeInterface>(type), value);
return $_get($_ctxt, mlir::cast<FPTypeInterface>(type), value);
}]>,
];
let extraClassDeclaration = [{

View File

@@ -16,7 +16,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "clang/CIR/Interfaces/CIRFPTypeInterface.h"
#include "clang/CIR/Interfaces/CIRTypeInterfaces.h"
namespace cir {

View File

@@ -15,7 +15,7 @@
include "clang/CIR/Dialect/IR/CIRDialect.td"
include "clang/CIR/Dialect/IR/CIRTypeConstraints.td"
include "clang/CIR/Interfaces/CIRFPTypeInterface.td"
include "clang/CIR/Interfaces/CIRTypeInterfaces.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
@@ -82,7 +82,7 @@ def CIR_IntType : CIR_Type<"Int", "int",
class CIR_FloatType<string name, string mnemonic> : CIR_Type<name, mnemonic, [
DeclareTypeInterfaceMethods<DataLayoutTypeInterface>,
DeclareTypeInterfaceMethods<CIRFPTypeInterface>
DeclareTypeInterfaceMethods<CIR_FPTypeInterface>
]>;
def CIR_Single : CIR_FloatType<"Single", "float"> {

View File

@@ -6,17 +6,17 @@
//
//===---------------------------------------------------------------------===//
//
// Defines the interface to generically handle CIR floating-point types.
// Defines cir type interfaces.
//
//===----------------------------------------------------------------------===//
#ifndef LLVM_CLANG_INCLUDE_CLANG_CIR_INTERFACES_CIRFPTYPEINTERFACE_H
#define LLVM_CLANG_INCLUDE_CLANG_CIR_INTERFACES_CIRFPTYPEINTERFACE_H
#ifndef CLANG_CIR_INTERFACES_CIRTYPEINTERFACES_H
#define CLANG_CIR_INTERFACES_CIRTYPEINTERFACES_H
#include "mlir/IR/Types.h"
#include "llvm/ADT/APFloat.h"
/// Include the tablegen'd interface declarations.
#include "clang/CIR/Interfaces/CIRFPTypeInterface.h.inc"
#include "clang/CIR/Interfaces/CIRTypeInterfaces.h.inc"
#endif // LLVM_CLANG_INCLUDE_CLANG_CIR_INTERFACES_CIRFPTYPEINTERFACE_H
#endif // CLANG_CIR_INTERFACES_CIRTYPEINTERFACES_H

View File

@@ -6,16 +6,16 @@
//
//===----------------------------------------------------------------------===//
//
// Defines the interface to generically handle CIR floating-point types.
// Defines cir type interfaces.
//
//===----------------------------------------------------------------------===//
#ifndef LLVM_CLANG_INCLUDE_CLANG_CIR_INTERFACES_CIRFPTYPEINTERFACE_TD
#define LLVM_CLANG_INCLUDE_CLANG_CIR_INTERFACES_CIRFPTYPEINTERFACE_TD
#ifndef CLANG_CIR_INTERFACES_CIRTYPEINTERFACES_TD
#define CLANG_CIR_INTERFACES_CIRTYPEINTERFACES_TD
include "mlir/IR/OpBase.td"
def CIRFPTypeInterface : TypeInterface<"CIRFPTypeInterface"> {
def CIR_FPTypeInterface : TypeInterface<"FPTypeInterface"> {
let description = [{
Contains helper functions to query properties about a floating-point type.
}];
@@ -53,4 +53,4 @@ def CIRFPTypeInterface : TypeInterface<"CIRFPTypeInterface"> {
];
}
#endif // LLVM_CLANG_INCLUDE_CLANG_CIR_INTERFACES_CIRFPTYPEINTERFACE_TD
#endif // CLANG_CIR_INTERFACES_CIRTYPEINTERFACES_TD

View File

@@ -21,4 +21,4 @@ endfunction()
add_clang_mlir_op_interface(CIROpInterfaces)
add_clang_mlir_op_interface(CIRLoopOpInterface)
add_clang_mlir_type_interface(CIRFPTypeInterface)
add_clang_mlir_type_interface(CIRTypeInterfaces)

View File

@@ -62,8 +62,7 @@ cir::ConstantOp CIRGenBuilderTy::getConstInt(mlir::Location loc, mlir::Type t,
cir::ConstantOp
clang::CIRGen::CIRGenBuilderTy::getConstFP(mlir::Location loc, mlir::Type t,
llvm::APFloat fpVal) {
assert(mlir::isa<cir::CIRFPTypeInterface>(t) &&
"expected floating point type");
assert(mlir::isa<cir::FPTypeInterface>(t) && "expected floating point type");
return create<cir::ConstantOp>(loc, getAttr<cir::FPAttr>(t, fpVal));
}

View File

@@ -11,7 +11,7 @@
#include "Address.h"
#include "CIRGenTypeCache.h"
#include "clang/CIR/Interfaces/CIRFPTypeInterface.h"
#include "clang/CIR/Interfaces/CIRTypeInterfaces.h"
#include "clang/CIR/MissingFeatures.h"
#include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h"
@@ -141,8 +141,7 @@ public:
bool isSized(mlir::Type ty) {
if (mlir::isa<cir::PointerType, cir::ArrayType, cir::BoolType, cir::IntType,
cir::CIRFPTypeInterface, cir::ComplexType, cir::RecordType>(
ty))
cir::FPTypeInterface, cir::ComplexType, cir::RecordType>(ty))
return true;
if (const auto vt = mlir::dyn_cast<cir::VectorType>(ty))

View File

@@ -195,7 +195,7 @@ ComplexExprEmitter::VisitImaginaryLiteral(const ImaginaryLiteral *il) {
realValueAttr = cir::IntAttr::get(elementTy, 0);
imagValueAttr = cir::IntAttr::get(elementTy, imagValue);
} else {
assert(mlir::isa<cir::CIRFPTypeInterface>(elementTy) &&
assert(mlir::isa<cir::FPTypeInterface>(elementTy) &&
"Expected complex element type to be floating-point");
llvm::APFloat imagValue =

View File

@@ -696,7 +696,7 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &value,
}
mlir::Type ty = cgm.convertType(destType);
assert(mlir::isa<cir::CIRFPTypeInterface>(ty) &&
assert(mlir::isa<cir::FPTypeInterface>(ty) &&
"expected floating-point type");
return cgm.getBuilder().getAttr<cir::FPAttr>(ty, init);
}
@@ -793,7 +793,7 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &value,
builder.getAttr<cir::IntAttr>(complexElemTy, imag));
}
assert(isa<cir::CIRFPTypeInterface>(complexElemTy) &&
assert(isa<cir::FPTypeInterface>(complexElemTy) &&
"expected floating-point type");
llvm::APFloat real = value.getComplexFloatReal();
llvm::APFloat imag = value.getComplexFloatImag();

View File

@@ -155,7 +155,7 @@ public:
mlir::Value VisitFloatingLiteral(const FloatingLiteral *e) {
mlir::Type type = cgf.convertType(e->getType());
assert(mlir::isa<cir::CIRFPTypeInterface>(type) &&
assert(mlir::isa<cir::FPTypeInterface>(type) &&
"expect floating-point type");
return builder.create<cir::ConstantOp>(
cgf.getLoc(e->getExprLoc()),
@@ -331,18 +331,18 @@ public:
cgf.getCIRGenModule().errorNYI("signed bool");
if (cgf.getBuilder().isInt(dstTy))
castKind = cir::CastKind::bool_to_int;
else if (mlir::isa<cir::CIRFPTypeInterface>(dstTy))
else if (mlir::isa<cir::FPTypeInterface>(dstTy))
castKind = cir::CastKind::bool_to_float;
else
llvm_unreachable("Internal error: Cast to unexpected type");
} else if (cgf.getBuilder().isInt(srcTy)) {
if (cgf.getBuilder().isInt(dstTy))
castKind = cir::CastKind::integral;
else if (mlir::isa<cir::CIRFPTypeInterface>(dstTy))
else if (mlir::isa<cir::FPTypeInterface>(dstTy))
castKind = cir::CastKind::int_to_float;
else
llvm_unreachable("Internal error: Cast to unexpected type");
} else if (mlir::isa<cir::CIRFPTypeInterface>(srcTy)) {
} else if (mlir::isa<cir::FPTypeInterface>(srcTy)) {
if (cgf.getBuilder().isInt(dstTy)) {
// If we can't recognize overflow as undefined behavior, assume that
// overflow saturates. This protects against normal optimizations if we
@@ -351,7 +351,7 @@ public:
cgf.getCIRGenModule().errorNYI("strict float cast overflow");
assert(!cir::MissingFeatures::fpConstraints());
castKind = cir::CastKind::float_to_int;
} else if (mlir::isa<cir::CIRFPTypeInterface>(dstTy)) {
} else if (mlir::isa<cir::FPTypeInterface>(dstTy)) {
// TODO: split this to createFPExt/createFPTrunc
return builder.createFloatingCast(src, fullDstTy);
} else {
@@ -654,7 +654,7 @@ public:
if (srcType->isHalfType() &&
!cgf.getContext().getLangOpts().NativeHalfType) {
// Cast to FP using the intrinsic if the half type itself isn't supported.
if (mlir::isa<cir::CIRFPTypeInterface>(mlirDstType)) {
if (mlir::isa<cir::FPTypeInterface>(mlirDstType)) {
if (cgf.getContext().getTargetInfo().useFP16ConversionIntrinsics())
cgf.getCIRGenModule().errorNYI(loc,
"cast via llvm.convert.from.fp16");

View File

@@ -20,7 +20,7 @@ static void printFloatLiteral(mlir::AsmPrinter &p, llvm::APFloat value,
static mlir::ParseResult
parseFloatLiteral(mlir::AsmParser &parser,
mlir::FailureOr<llvm::APFloat> &value,
cir::CIRFPTypeInterface fpType);
cir::FPTypeInterface fpType);
static mlir::ParseResult parseConstPtr(mlir::AsmParser &parser,
mlir::IntegerAttr &value);
@@ -158,7 +158,7 @@ static void printFloatLiteral(AsmPrinter &p, APFloat value, Type ty) {
static ParseResult parseFloatLiteral(AsmParser &parser,
FailureOr<APFloat> &value,
CIRFPTypeInterface fpType) {
cir::FPTypeInterface fpType) {
APFloat parsedValue(0.0);
if (parser.parseFloat(fpType.getFloatSemantics(), parsedValue))
@@ -171,11 +171,11 @@ static ParseResult parseFloatLiteral(AsmParser &parser,
FPAttr FPAttr::getZero(Type type) {
return get(type,
APFloat::getZero(
mlir::cast<CIRFPTypeInterface>(type).getFloatSemantics()));
mlir::cast<cir::FPTypeInterface>(type).getFloatSemantics()));
}
LogicalResult FPAttr::verify(function_ref<InFlightDiagnostic()> emitError,
CIRFPTypeInterface fpType, APFloat value) {
cir::FPTypeInterface fpType, APFloat value) {
if (APFloat::SemanticsToEnum(fpType.getFloatSemantics()) !=
APFloat::SemanticsToEnum(value.getSemantics()))
return emitError() << "floating-point semantics mismatch";

View File

@@ -421,13 +421,13 @@ LogicalResult cir::CastOp::verify() {
return success();
}
case cir::CastKind::floating: {
if (!mlir::isa<cir::CIRFPTypeInterface>(srcType) ||
!mlir::isa<cir::CIRFPTypeInterface>(resType))
if (!mlir::isa<cir::FPTypeInterface>(srcType) ||
!mlir::isa<cir::FPTypeInterface>(resType))
return emitOpError() << "requires !cir.float type for source and result";
return success();
}
case cir::CastKind::float_to_int: {
if (!mlir::isa<cir::CIRFPTypeInterface>(srcType))
if (!mlir::isa<cir::FPTypeInterface>(srcType))
return emitOpError() << "requires !cir.float type for source";
if (!mlir::dyn_cast<cir::IntType>(resType))
return emitOpError() << "requires !cir.int type for result";
@@ -448,7 +448,7 @@ LogicalResult cir::CastOp::verify() {
return success();
}
case cir::CastKind::float_to_bool: {
if (!mlir::isa<cir::CIRFPTypeInterface>(srcType))
if (!mlir::isa<cir::FPTypeInterface>(srcType))
return emitOpError() << "requires !cir.float type for source";
if (!mlir::isa<cir::BoolType>(resType))
return emitOpError() << "requires !cir.bool type for result";
@@ -464,14 +464,14 @@ LogicalResult cir::CastOp::verify() {
case cir::CastKind::int_to_float: {
if (!mlir::isa<cir::IntType>(srcType))
return emitOpError() << "requires !cir.int type for source";
if (!mlir::isa<cir::CIRFPTypeInterface>(resType))
if (!mlir::isa<cir::FPTypeInterface>(resType))
return emitOpError() << "requires !cir.float type for result";
return success();
}
case cir::CastKind::bool_to_float: {
if (!mlir::isa<cir::BoolType>(srcType))
return emitOpError() << "requires !cir.bool type for source";
if (!mlir::isa<cir::CIRFPTypeInterface>(resType))
if (!mlir::isa<cir::FPTypeInterface>(resType))
return emitOpError() << "requires !cir.float type for result";
return success();
}

View File

@@ -539,8 +539,7 @@ uint64_t FP128Type::getABIAlignment(const mlir::DataLayout &dataLayout,
}
const llvm::fltSemantics &LongDoubleType::getFloatSemantics() const {
return mlir::cast<cir::CIRFPTypeInterface>(getUnderlying())
.getFloatSemantics();
return mlir::cast<cir::FPTypeInterface>(getUnderlying()).getFloatSemantics();
}
llvm::TypeSize

View File

@@ -6,13 +6,13 @@
//
//===----------------------------------------------------------------------===//
//
// Defines the interface to generically handle CIR floating-point types.
// Defines cir type interfaces.
//
//===----------------------------------------------------------------------===//
#include "clang/CIR/Interfaces/CIRFPTypeInterface.h"
#include "clang/CIR/Interfaces/CIRTypeInterfaces.h"
using namespace cir;
/// Include the generated interfaces.
#include "clang/CIR/Interfaces/CIRFPTypeInterface.cpp.inc"
#include "clang/CIR/Interfaces/CIRTypeInterfaces.cpp.inc"

View File

@@ -1,14 +1,14 @@
add_clang_library(MLIRCIRInterfaces
CIROpInterfaces.cpp
CIRLoopOpInterface.cpp
CIRFPTypeInterface.cpp
CIRTypeInterfaces.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces
DEPENDS
MLIRCIREnumsGen
MLIRCIRFPTypeInterfaceIncGen
MLIRCIRTypeInterfacesIncGen
MLIRCIRLoopOpInterfaceIncGen
MLIRCIROpInterfacesIncGen

View File

@@ -529,12 +529,12 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite(
mlir::Type srcTy = elementTypeIfVector(castOp.getSrc().getType());
mlir::Type dstTy = elementTypeIfVector(castOp.getType());
if (!mlir::isa<cir::CIRFPTypeInterface>(dstTy) ||
!mlir::isa<cir::CIRFPTypeInterface>(srcTy))
if (!mlir::isa<cir::FPTypeInterface>(dstTy) ||
!mlir::isa<cir::FPTypeInterface>(srcTy))
return castOp.emitError() << "NYI cast from " << srcTy << " to " << dstTy;
auto getFloatWidth = [](mlir::Type ty) -> unsigned {
return mlir::cast<cir::CIRFPTypeInterface>(ty).getWidth();
return mlir::cast<cir::FPTypeInterface>(ty).getWidth();
};
if (getFloatWidth(srcTy) > getFloatWidth(dstTy))
@@ -928,7 +928,7 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
attr = rewriter.getIntegerAttr(
typeConverter->convertType(op.getType()),
mlir::cast<cir::IntAttr>(op.getValue()).getValue());
} else if (mlir::isa<cir::CIRFPTypeInterface>(op.getType())) {
} else if (mlir::isa<cir::FPTypeInterface>(op.getType())) {
attr = rewriter.getFloatAttr(
typeConverter->convertType(op.getType()),
mlir::cast<cir::FPAttr>(op.getValue()).getValue());
@@ -1349,7 +1349,7 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
}
// Floating point unary operations: + - ++ --
if (mlir::isa<cir::CIRFPTypeInterface>(elementType)) {
if (mlir::isa<cir::FPTypeInterface>(elementType)) {
switch (op.getKind()) {
case cir::UnaryOpKind::Inc: {
assert(!isVector && "++ not allowed on vector types");
@@ -1438,7 +1438,7 @@ mlir::LogicalResult CIRToLLVMBinOpLowering::matchAndRewrite(
return op.emitError() << "inconsistent operands' types not supported yet";
mlir::Type type = op.getRhs().getType();
if (!mlir::isa<cir::IntType, cir::BoolType, cir::CIRFPTypeInterface,
if (!mlir::isa<cir::IntType, cir::BoolType, cir::FPTypeInterface,
mlir::IntegerType, cir::VectorType>(type))
return op.emitError() << "operand type not supported yet";
@@ -1601,7 +1601,7 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
/* isSigned=*/false);
rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
} else if (mlir::isa<cir::CIRFPTypeInterface>(type)) {
} else if (mlir::isa<cir::FPTypeInterface>(type)) {
mlir::LLVM::FCmpPredicate kind =
convertCmpKindToFCmpPredicate(cmpOp.getKind());
rewriter.replaceOpWithNewOp<mlir::LLVM::FCmpOp>(
@@ -2059,7 +2059,7 @@ mlir::LogicalResult CIRToLLVMVecCmpOpLowering::matchAndRewrite(
op.getLoc(),
convertCmpKindToICmpPredicate(op.getKind(), intType.isSigned()),
adaptor.getLhs(), adaptor.getRhs());
} else if (mlir::isa<cir::CIRFPTypeInterface>(elementType)) {
} else if (mlir::isa<cir::FPTypeInterface>(elementType)) {
bitResult = rewriter.create<mlir::LLVM::FCmpOp>(
op.getLoc(), convertCmpKindToFCmpPredicate(op.getKind()),
adaptor.getLhs(), adaptor.getRhs());

View File

@@ -138,7 +138,7 @@ lowerConstArrayAttr(cir::ConstArrayAttr constArr,
return convertToDenseElementsAttr<cir::IntAttr, mlir::APInt>(
constArr, dims, type, converter->convertType(type));
if (mlir::isa<cir::CIRFPTypeInterface>(type))
if (mlir::isa<cir::FPTypeInterface>(type))
return convertToDenseElementsAttr<cir::FPAttr, mlir::APFloat>(
constArr, dims, type, converter->convertType(type));