[MLIR][LLVM] Import dereferenceable metadata from LLVM IR (#130974)

Add support for importing `dereferenceable` and `dereferenceable_or_null` metadata into LLVM dialect. Add a new attribute which models these two metadata nodes and a new OpInterface.
This commit is contained in:
mihailo-stojanovic
2025-03-14 09:30:47 +01:00
committed by GitHub
parent bddf24ddbd
commit fc8b2bf2f8
15 changed files with 261 additions and 3 deletions

View File

@@ -1267,4 +1267,28 @@ def WorkgroupAttributionAttr
let assemblyFormat = "`<` $num_elements `,` $element_type `>`";
}
//===----------------------------------------------------------------------===//
// DereferenceableAttr
//===----------------------------------------------------------------------===//
def LLVM_DereferenceableAttr : LLVM_Attr<"Dereferenceable", "dereferenceable"> {
let summary = "LLVM dereferenceable attribute";
let description = [{
Defines `dereferenceable` or `dereferenceable_or_null` metadata that can
be set via the `DereferenceableOpInterface` on an `inttoptr` operation or
on a `load` operation which loads a pointer. The attribute is used to
denote that the result of these operations is dereferenceable up to a
certain number of bytes, represented by `$bytes`. The optional `$mayBeNull`
parameter is set to true if the attribute defines `dereferenceable_or_null`
metadata.
See the following links for more details:
https://llvm.org/docs/LangRef.html#dereferenceable-metadata
https://llvm.org/docs/LangRef.html#dereferenceable-or-null-metadata
}];
let parameters = (ins "uint64_t":$bytes,
DefaultValuedParameter<"bool", "false">:$mayBeNull);
let assemblyFormat = "`<` struct(params) `>`";
}
#endif // LLVMIR_ATTRDEFS

View File

@@ -27,6 +27,10 @@ LogicalResult verifyAccessGroupOpInterface(Operation *op);
/// the alias analysis interface.
LogicalResult verifyAliasAnalysisOpInterface(Operation *op);
/// Verifies that the operation implementing the dereferenceable interface has
/// exactly one result of LLVM pointer type.
LogicalResult verifyDereferenceableOpInterface(Operation *op);
} // namespace detail
} // namespace LLVM
} // namespace mlir

View File

@@ -330,6 +330,43 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> {
];
}
def DereferenceableOpInterface : OpInterface<"DereferenceableOpInterface"> {
let description = [{
An interface for memory operations that can carry dereferenceable metadata.
It provides setters and getters for the operation's dereferenceable
attributes. The default implementations of the interface methods expect
the operation to have an attribute of type DereferenceableAttr.
}];
let cppNamespace = "::mlir::LLVM";
let verify = [{ return detail::verifyDereferenceableOpInterface($_op); }];
let methods = [
InterfaceMethod<
/*desc=*/ "Returns the dereferenceable attribute or nullptr",
/*returnType=*/ "::mlir::LLVM::DereferenceableAttr",
/*methodName=*/ "getDereferenceableOrNull",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
return op.getDereferenceableAttr();
}]
>,
InterfaceMethod<
/*desc=*/ "Sets the dereferenceable attribute",
/*returnType=*/ "void",
/*methodName=*/ "setDereferenceable",
/*args=*/ (ins "::mlir::LLVM::DereferenceableAttr":$attr),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
op.setDereferenceableAttr(attr);
}]
>
];
}
def FPExceptionBehaviorOpInterface : OpInterface<"FPExceptionBehaviorOpInterface"> {
let description = [{
An interface for operations receiving an exception behavior attribute

View File

@@ -364,7 +364,8 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
[DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>]> {
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>,
DeclareOpInterfaceMethods<DereferenceableOpInterface>]> {
dag args = (ins LLVM_AnyPointer:$addr,
OptionalAttr<I64Attr>:$alignment,
UnitAttr:$volatile_,
@@ -373,7 +374,8 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
UnitAttr:$invariantGroup,
DefaultValuedAttr<
AtomicOrdering, "AtomicOrdering::not_atomic">:$ordering,
OptionalAttr<StrAttr>:$syncscope);
OptionalAttr<StrAttr>:$syncscope,
OptionalAttr<LLVM_DereferenceableAttr>:$dereferenceable);
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
let arguments = !con(args, aliasAttrs);
let results = (outs LLVM_LoadableType:$res);
@@ -407,6 +409,7 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
(`atomic` (`syncscope` `(` $syncscope^ `)`)? $ordering^)?
(`invariant` $invariant^)?
(`invariant_group` $invariantGroup^)?
(`dereferenceable` `` $dereferenceable^)?
attr-dict `:` qualified(type($addr)) `->` type($res)
}];
string llvmBuilder = [{
@@ -416,6 +419,8 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
llvm::MDNode *metadata = llvm::MDNode::get(inst->getContext(), std::nullopt);
inst->setMetadata(llvm::LLVMContext::MD_invariant_load, metadata);
}
if ($dereferenceable)
moduleTranslation.setDereferenceableMetadata(op, inst);
}] # setOrderingCode
# setSyncScopeCode
# setAlignmentCode
@@ -571,6 +576,29 @@ class LLVM_CastOpWithOverflowFlag<string mnemonic, string instName, Type type,
}];
}
class LLVM_DereferenceableCastOp<string mnemonic, string instName, Type type,
Type resultType, list<Trait> traits = []> :
LLVM_Op<mnemonic, !listconcat([Pure], [DeclareOpInterfaceMethods<DereferenceableOpInterface>], traits)> {
let arguments = (ins type:$arg, OptionalAttr<LLVM_DereferenceableAttr>:$dereferenceable);
let results = (outs resultType:$res);
let builders = [LLVM_OneResultOpBuilder];
let assemblyFormat = "$arg (`dereferenceable` `` $dereferenceable^)? attr-dict `:` type($arg) `to` type($res)";
string llvmInstName = instName;
string llvmBuilder = [{
auto *val = builder.Create}] # instName # [{($arg, $_resultType);
$res = val;
if ($dereferenceable) {
llvm::Instruction *inst = dyn_cast<llvm::Instruction>(val);
moduleTranslation.setDereferenceableMetadata(op, inst);
}
}];
string mlirBuilder = [{
auto op = $_builder.create<$_qualCppClassName>(
$_location, $_resultType, $arg);
$res = op;
}];
}
def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "BitCast", LLVM_AnyNonAggregate,
LLVM_AnyNonAggregate, [DeclareOpInterfaceMethods<PromotableOpInterface>]> {
let hasFolder = 1;
@@ -583,7 +611,7 @@ def LLVM_AddrSpaceCastOp : LLVM_CastOp<"addrspacecast", "AddrSpaceCast",
DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
let hasFolder = 1;
}
def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "IntToPtr",
def LLVM_IntToPtrOp : LLVM_DereferenceableCastOp<"inttoptr", "IntToPtr",
LLVM_ScalarOrVectorOf<AnySignlessInteger>,
LLVM_ScalarOrVectorOf<LLVM_AnyPointer>>;
def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "PtrToInt",

View File

@@ -248,6 +248,13 @@ public:
LoopAnnotationAttr translateLoopAnnotationAttr(const llvm::MDNode *node,
Location loc) const;
/// Returns the dereferenceable attribute that corresponds to the given LLVM
/// dereferenceable or dereferenceable_or_null metadata `node`. `kindID`
/// specifies the kind of the metadata node (dereferenceable or
/// dereferenceable_or_null).
FailureOr<DereferenceableAttr>
translateDereferenceableAttr(const llvm::MDNode *node, unsigned kindID);
/// Returns the alias scope attributes that map to the alias scope nodes
/// starting from the metadata `node`. Returns failure, if any of the
/// attributes cannot be found.

View File

@@ -161,6 +161,11 @@ public:
/// Sets LLVM TBAA metadata for memory operations that have TBAA attributes.
void setTBAAMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst);
/// Sets LLVM dereferenceable metadata for operations that have
/// dereferenceable attributes.
void setDereferenceableMetadata(DereferenceableOpInterface op,
llvm::Instruction *inst);
/// Sets LLVM profiling metadata for operations that have branch weights.
void setBranchWeightsMetadata(BranchWeightOpInterface op);

View File

@@ -940,6 +940,7 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
isNonTemporal, isInvariant, isInvariantGroup, ordering,
syncscope.empty() ? nullptr : builder.getStringAttr(syncscope),
/*dereferenceable=*/nullptr,
/*access_groups=*/nullptr,
/*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr,
/*tbaa=*/nullptr);

View File

@@ -62,6 +62,23 @@ mlir::LLVM::detail::verifyAliasAnalysisOpInterface(Operation *op) {
return isArrayOf<TBAATagAttr>(op, tags);
}
//===----------------------------------------------------------------------===//
// DereferenceableOpInterface
//===----------------------------------------------------------------------===//
LogicalResult
mlir::LLVM::detail::verifyDereferenceableOpInterface(Operation *op) {
auto iface = cast<DereferenceableOpInterface>(op);
if (auto derefAttr = iface.getDereferenceableOrNull())
if (op->getNumResults() != 1 ||
!mlir::isa<LLVMPointerType>(op->getResult(0).getType()))
return op->emitOpError(
"expected op to return a single LLVM pointer type");
return success();
}
SmallVector<Value> mlir::LLVM::AtomicCmpXchgOp::getAccessedOperands() {
return {getPtr()};
}

View File

@@ -90,6 +90,8 @@ static ArrayRef<unsigned> getSupportedMetadataImpl(llvm::LLVMContext &context) {
llvm::LLVMContext::MD_loop,
llvm::LLVMContext::MD_noalias,
llvm::LLVMContext::MD_alias_scope,
llvm::LLVMContext::MD_dereferenceable,
llvm::LLVMContext::MD_dereferenceable_or_null,
context.getMDKindID(vecTypeHintMDName),
context.getMDKindID(workGroupSizeHintMDName),
context.getMDKindID(reqdWorkGroupSizeMDName),
@@ -188,6 +190,25 @@ static LogicalResult setAccessGroupsAttr(const llvm::MDNode *node,
return success();
}
/// Converts the given dereferenceable metadata node to a dereferenceable
/// attribute, and attaches it to the imported operation if the translation
/// succeeds. Returns failure if the LLVM IR metadata node is ill-formed.
static LogicalResult setDereferenceableAttr(const llvm::MDNode *node,
unsigned kindID, Operation *op,
LLVM::ModuleImport &moduleImport) {
auto dereferenceable =
moduleImport.translateDereferenceableAttr(node, kindID);
if (failed(dereferenceable))
return failure();
auto iface = dyn_cast<DereferenceableOpInterface>(op);
if (!iface)
return failure();
iface.setDereferenceable(*dereferenceable);
return success();
}
/// Converts the given loop metadata node to an MLIR loop annotation attribute
/// and attaches it to the imported operation if the translation succeeds.
/// Returns failure otherwise.
@@ -401,6 +422,13 @@ public:
return setAliasScopesAttr(node, op, moduleImport);
if (kind == llvm::LLVMContext::MD_noalias)
return setNoaliasScopesAttr(node, op, moduleImport);
if (kind == llvm::LLVMContext::MD_dereferenceable)
return setDereferenceableAttr(node, llvm::LLVMContext::MD_dereferenceable,
op, moduleImport);
if (kind == llvm::LLVMContext::MD_dereferenceable_or_null)
return setDereferenceableAttr(
node, llvm::LLVMContext::MD_dereferenceable_or_null, op,
moduleImport);
llvm::LLVMContext &context = node->getContext();
if (kind == context.getMDKindID(vecTypeHintMDName))

View File

@@ -2527,6 +2527,31 @@ ModuleImport::translateLoopAnnotationAttr(const llvm::MDNode *node,
return loopAnnotationImporter->translateLoopAnnotation(node, loc);
}
FailureOr<DereferenceableAttr>
ModuleImport::translateDereferenceableAttr(const llvm::MDNode *node,
unsigned kindID) {
Location loc = mlirModule.getLoc();
// The only operand should be a constant integer representing the number of
// dereferenceable bytes.
if (node->getNumOperands() != 1)
return emitError(loc) << "dereferenceable metadata must have one operand: "
<< diagMD(node, llvmModule.get());
auto *numBytesMD = dyn_cast<llvm::ConstantAsMetadata>(node->getOperand(0));
auto *numBytesCst = dyn_cast<llvm::ConstantInt>(numBytesMD->getValue());
if (!numBytesCst || !numBytesCst->getValue().isNonNegative())
return emitError(loc) << "dereferenceable metadata operand must be a "
"non-negative constant integer: "
<< diagMD(node, llvmModule.get());
bool mayBeNull = kindID == llvm::LLVMContext::MD_dereferenceable_or_null;
auto derefAttr = builder.getAttr<DereferenceableAttr>(
numBytesCst->getZExtValue(), mayBeNull);
return derefAttr;
}
OwningOpRef<ModuleOp>
mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
MLIRContext *context, bool emitExpensiveWarnings,

View File

@@ -1925,6 +1925,22 @@ void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op,
inst->setMetadata(llvm::LLVMContext::MD_tbaa, node);
}
void ModuleTranslation::setDereferenceableMetadata(
DereferenceableOpInterface op, llvm::Instruction *inst) {
DereferenceableAttr derefAttr = op.getDereferenceableOrNull();
if (!derefAttr)
return;
llvm::MDNode *derefSizeNode = llvm::MDNode::get(
getLLVMContext(),
llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(
llvm::IntegerType::get(getLLVMContext(), 64), derefAttr.getBytes())));
unsigned kindId = derefAttr.getMayBeNull()
? llvm::LLVMContext::MD_dereferenceable_or_null
: llvm::LLVMContext::MD_dereferenceable;
inst->setMetadata(kindId, derefSizeNode);
}
void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) {
DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull();
if (!weightsAttr)

View File

@@ -0,0 +1,7 @@
// RUN: mlir-opt --allow-unregistered-dialect -split-input-file -verify-diagnostics %s
llvm.func @deref(%arg0: !llvm.ptr) {
// expected-error @below {{op expected op to return a single LLVM pointer type}}
%0 = llvm.load %arg0 dereferenceable<bytes = 8> {alignment = 8 : i64} : !llvm.ptr -> i64
llvm.return
}

View File

@@ -338,6 +338,17 @@ declare void @llvm.experimental.noalias.scope.decl(metadata)
; // -----
; CHECK: import-failure.ll
; CHECK-SAME: dereferenceable metadata operand must be a non-negative constant integer
define void @deref(i64 %0) {
%2 = inttoptr i64 %0 to ptr, !dereferenceable !0
ret void
}
!0 = !{i64 -4}
; // -----
; CHECK: import-failure.ll
; CHECK-SAME: warning: unhandled data layout token: ni:42
target datalayout = "e-ni:42-i64:64"

View File

@@ -0,0 +1,24 @@
; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
define void @deref(i64 %0, ptr %1) {
; CHECK: llvm.inttoptr
; CHECK-SAME: dereferenceable<bytes = 4>
%3 = inttoptr i64 %0 to ptr, !dereferenceable !0
; CHECK: llvm.load
; CHECK-SAME: dereferenceable<bytes = 8>
%4 = load ptr, ptr %1, align 8, !dereferenceable !1
ret void
}
define void @deref_or_null(i64 %0, ptr %1) {
; CHECK: llvm.inttoptr
; CHECK-SAME: dereferenceable<bytes = 4, mayBeNull = true>
%3 = inttoptr i64 %0 to ptr, !dereferenceable_or_null !0
; CHECK: llvm.load
; CHECK-SAME: dereferenceable<bytes = 8, mayBeNull = true>
%4 = load ptr, ptr %1, align 8, !dereferenceable_or_null !1
ret void
}
!0 = !{i64 4}
!1 = !{i64 8}

View File

@@ -0,0 +1,24 @@
// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
llvm.func @deref(%arg0: i64, %arg1: !llvm.ptr) {
// CHECK: inttoptr {{.*}} !dereferenceable [[D0:![0-9]+]]
%0 = llvm.inttoptr %arg0 dereferenceable<bytes = 4> : i64 to !llvm.ptr
%1 = llvm.load %0 {alignment = 4 : i64} : !llvm.ptr -> i32
// CHECK: load {{.*}} !dereferenceable [[D1:![0-9]+]]
%2 = llvm.load %arg1 dereferenceable<bytes = 8> {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
llvm.store %1, %2 {alignment = 4 : i64} : i32, !llvm.ptr
llvm.return
}
llvm.func @deref_or_null(%arg0: i64, %arg1: !llvm.ptr) {
// CHECK: inttoptr {{.*}} !dereferenceable_or_null [[D0]]
%0 = llvm.inttoptr %arg0 dereferenceable<bytes = 4, mayBeNull = true> : i64 to !llvm.ptr
%1 = llvm.load %0 {alignment = 4 : i64} : !llvm.ptr -> i32
// CHECK: load {{.*}} !dereferenceable_or_null [[D1]]
%2 = llvm.load %arg1 dereferenceable<bytes = 8, mayBeNull = true> {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
llvm.store %1, %2 {alignment = 4 : i64} : i32, !llvm.ptr
llvm.return
}
// CHECK: [[D0]] = !{i64 4}
// CHECK: [[D1]] = !{i64 8}