[CIR] Clean up IntAttr (#146661)

- Add common CIR_ prefix
- Simplify printing/parsing
- Make it use IntTypeInterface

This mirrors incubator changes from https://github.com/llvm/clangir/pull/1725
This commit is contained in:
Henrich Lauko
2025-07-02 16:36:09 +02:00
committed by GitHub
parent 38ad6b1983
commit 8dcdc0ff1f
5 changed files with 103 additions and 77 deletions

View File

@@ -63,7 +63,7 @@ public:
mlir::Value getConstAPInt(mlir::Location loc, mlir::Type typ,
const llvm::APInt &val) {
return create<cir::ConstantOp>(loc, getAttr<cir::IntAttr>(typ, val));
return create<cir::ConstantOp>(loc, cir::IntAttr::get(typ, val));
}
cir::ConstantOp getConstant(mlir::Location loc, mlir::TypedAttr attr) {

View File

@@ -117,36 +117,67 @@ def UndefAttr : CIR_TypedAttr<"Undef", "undef"> {
// IntegerAttr
//===----------------------------------------------------------------------===//
def IntAttr : CIR_Attr<"Int", "int", [TypedAttrInterface]> {
def CIR_IntAttr : CIR_Attr<"Int", "int", [TypedAttrInterface]> {
let summary = "An attribute containing an integer value";
let description = [{
An integer attribute is a literal attribute that represents an integral
value of the specified integer type.
}];
let parameters = (ins AttributeSelfTypeParameter<"">:$type,
APIntParameter<"">:$value);
let parameters = (ins
AttributeSelfTypeParameter<"", "cir::IntTypeInterface">:$type,
APIntParameter<"">:$value
);
let builders = [
AttrBuilderWithInferredContext<(ins "mlir::Type":$type,
"const llvm::APInt &":$value), [{
return $_get(type.getContext(), type, value);
auto intType = mlir::cast<cir::IntTypeInterface>(type);
return $_get(type.getContext(), intType, value);
}]>,
AttrBuilderWithInferredContext<(ins "mlir::Type":$type,
"int64_t":$value), [{
IntType intType = mlir::cast<IntType>(type);
auto intType = mlir::cast<cir::IntTypeInterface>(type);
mlir::APInt apValue(intType.getWidth(), value, intType.isSigned());
return $_get(intType.getContext(), intType, apValue);
}]>,
];
let extraClassDeclaration = [{
int64_t getSInt() const { return getValue().getSExtValue(); }
uint64_t getUInt() const { return getValue().getZExtValue(); }
bool isNullValue() const { return getValue() == 0; }
uint64_t getBitWidth() const {
return mlir::cast<IntType>(getType()).getWidth();
int64_t getSInt() const;
uint64_t getUInt() const;
bool isNullValue() const;
bool isSigned() const;
bool isUnsigned() const;
uint64_t getBitWidth() const;
}];
let extraClassDefinition = [{
int64_t $cppClass::getSInt() const {
return getValue().getSExtValue();
}
uint64_t $cppClass::getUInt() const {
return getValue().getZExtValue();
}
bool $cppClass::isNullValue() const {
return getValue() == 0;
}
bool $cppClass::isSigned() const {
return mlir::cast<IntTypeInterface>(getType()).isSigned();
}
bool $cppClass::isUnsigned() const {
return mlir::cast<IntTypeInterface>(getType()).isUnsigned();
}
uint64_t $cppClass::getBitWidth() const {
return mlir::cast<IntTypeInterface>(getType()).getWidth();
}
}];
let assemblyFormat = [{
`<` custom<IntLiteral>($value, ref($type)) `>`
}];
let genVerifyDecl = 1;
let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//

View File

@@ -684,7 +684,7 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &value,
if (mlir::isa<cir::BoolType>(ty))
return builder.getCIRBoolAttr(value.getInt().getZExtValue());
assert(mlir::isa<cir::IntType>(ty) && "expected integral type");
return cgm.getBuilder().getAttr<cir::IntAttr>(ty, value.getInt());
return cir::IntAttr::get(ty, value.getInt());
}
case APValue::Float: {
const llvm::APFloat &init = value.getFloat();
@@ -789,8 +789,8 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &value,
llvm::APSInt real = value.getComplexIntReal();
llvm::APSInt imag = value.getComplexIntImag();
return builder.getAttr<cir::ConstComplexAttr>(
complexType, builder.getAttr<cir::IntAttr>(complexElemTy, real),
builder.getAttr<cir::IntAttr>(complexElemTy, imag));
complexType, cir::IntAttr::get(complexElemTy, real),
cir::IntAttr::get(complexElemTy, imag));
}
assert(isa<cir::FPTypeInterface>(complexElemTy) &&

View File

@@ -157,8 +157,7 @@ public:
mlir::Value VisitIntegerLiteral(const IntegerLiteral *e) {
mlir::Type type = cgf.convertType(e->getType());
return builder.create<cir::ConstantOp>(
cgf.getLoc(e->getExprLoc()),
builder.getAttr<cir::IntAttr>(type, e->getValue()));
cgf.getLoc(e->getExprLoc()), cir::IntAttr::get(type, e->getValue()));
}
mlir::Value VisitFloatingLiteral(const FloatingLiteral *e) {
@@ -1970,21 +1969,21 @@ mlir::Value ScalarExprEmitter::VisitUnaryExprOrTypeTraitExpr(
"sizeof operator for VariableArrayType",
e->getStmtClassName());
return builder.getConstant(
loc, builder.getAttr<cir::IntAttr>(
cgf.cgm.UInt64Ty, llvm::APSInt(llvm::APInt(64, 1), true)));
loc, cir::IntAttr::get(cgf.cgm.UInt64Ty,
llvm::APSInt(llvm::APInt(64, 1), true)));
}
} else if (e->getKind() == UETT_OpenMPRequiredSimdAlign) {
cgf.getCIRGenModule().errorNYI(
e->getSourceRange(), "sizeof operator for OpenMpRequiredSimdAlign",
e->getStmtClassName());
return builder.getConstant(
loc, builder.getAttr<cir::IntAttr>(
cgf.cgm.UInt64Ty, llvm::APSInt(llvm::APInt(64, 1), true)));
loc, cir::IntAttr::get(cgf.cgm.UInt64Ty,
llvm::APSInt(llvm::APInt(64, 1), true)));
}
return builder.getConstant(
loc, builder.getAttr<cir::IntAttr>(
cgf.cgm.UInt64Ty, e->EvaluateKnownConstInt(cgf.getContext())));
loc, cir::IntAttr::get(cgf.cgm.UInt64Ty,
e->EvaluateKnownConstInt(cgf.getContext())));
}
/// Return true if the specified expression is cheap enough and side-effect-free

View File

@@ -15,6 +15,19 @@
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
//===-----------------------------------------------------------------===//
// IntLiteral
//===-----------------------------------------------------------------===//
static void printIntLiteral(mlir::AsmPrinter &p, llvm::APInt value,
cir::IntTypeInterface ty);
static mlir::ParseResult parseIntLiteral(mlir::AsmParser &parser,
llvm::APInt &value,
cir::IntTypeInterface ty);
//===-----------------------------------------------------------------===//
// FloatLiteral
//===-----------------------------------------------------------------===//
static void printFloatLiteral(mlir::AsmPrinter &p, llvm::APFloat value,
mlir::Type ty);
static mlir::ParseResult
@@ -82,69 +95,52 @@ static void printConstPtr(AsmPrinter &p, mlir::IntegerAttr value) {
// IntAttr definitions
//===----------------------------------------------------------------------===//
Attribute IntAttr::parse(AsmParser &parser, Type odsType) {
mlir::APInt apValue;
if (!mlir::isa<IntType>(odsType))
return {};
auto type = mlir::cast<IntType>(odsType);
// Consume the '<' symbol.
if (parser.parseLess())
return {};
// Fetch arbitrary precision integer value.
if (type.isSigned()) {
int64_t value = 0;
if (parser.parseInteger(value)) {
parser.emitError(parser.getCurrentLocation(), "expected integer value");
} else {
apValue = mlir::APInt(type.getWidth(), value, type.isSigned(),
/*implicitTrunc=*/true);
if (apValue.getSExtValue() != value)
parser.emitError(parser.getCurrentLocation(),
"integer value too large for the given type");
}
template <typename IntT>
static bool isTooLargeForType(const mlir::APInt &value, IntT expectedValue) {
if constexpr (std::is_signed_v<IntT>) {
return value.getSExtValue() != expectedValue;
} else {
uint64_t value = 0;
if (parser.parseInteger(value)) {
parser.emitError(parser.getCurrentLocation(), "expected integer value");
} else {
apValue = mlir::APInt(type.getWidth(), value, type.isSigned(),
/*implicitTrunc=*/true);
if (apValue.getZExtValue() != value)
parser.emitError(parser.getCurrentLocation(),
"integer value too large for the given type");
}
return value.getZExtValue() != expectedValue;
}
// Consume the '>' symbol.
if (parser.parseGreater())
return {};
return IntAttr::get(type, apValue);
}
void IntAttr::print(AsmPrinter &printer) const {
auto type = mlir::cast<IntType>(getType());
printer << '<';
if (type.isSigned())
printer << getSInt();
template <typename IntT>
static mlir::ParseResult parseIntLiteralImpl(mlir::AsmParser &p,
llvm::APInt &value,
cir::IntTypeInterface ty) {
IntT ivalue;
const bool isSigned = ty.isSigned();
if (p.parseInteger(ivalue))
return p.emitError(p.getCurrentLocation(), "expected integer value");
value = mlir::APInt(ty.getWidth(), ivalue, isSigned, /*implicitTrunc=*/true);
if (isTooLargeForType(value, ivalue))
return p.emitError(p.getCurrentLocation(),
"integer value too large for the given type");
return success();
}
mlir::ParseResult parseIntLiteral(mlir::AsmParser &parser, llvm::APInt &value,
cir::IntTypeInterface ty) {
if (ty.isSigned())
return parseIntLiteralImpl<int64_t>(parser, value, ty);
return parseIntLiteralImpl<uint64_t>(parser, value, ty);
}
void printIntLiteral(mlir::AsmPrinter &p, llvm::APInt value,
cir::IntTypeInterface ty) {
if (ty.isSigned())
p << value.getSExtValue();
else
printer << getUInt();
printer << '>';
p << value.getZExtValue();
}
LogicalResult IntAttr::verify(function_ref<InFlightDiagnostic()> emitError,
Type type, APInt value) {
if (!mlir::isa<IntType>(type))
return emitError() << "expected 'simple.int' type";
auto intType = mlir::cast<IntType>(type);
if (value.getBitWidth() != intType.getWidth())
cir::IntTypeInterface type, llvm::APInt value) {
if (value.getBitWidth() != type.getWidth())
return emitError() << "type and value bitwidth mismatch: "
<< intType.getWidth() << " != " << value.getBitWidth();
<< type.getWidth() << " != " << value.getBitWidth();
return success();
}