[MLIR][LLVM] Import dereferenceable metadata from LLVM IR (#130974)
Add support for importing `dereferenceable` and `dereferenceable_or_null` metadata into LLVM dialect. Add a new attribute which models these two metadata nodes and a new OpInterface.
This commit is contained in:
committed by
GitHub
parent
bddf24ddbd
commit
fc8b2bf2f8
@@ -1267,4 +1267,28 @@ def WorkgroupAttributionAttr
|
||||
let assemblyFormat = "`<` $num_elements `,` $element_type `>`";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DereferenceableAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def LLVM_DereferenceableAttr : LLVM_Attr<"Dereferenceable", "dereferenceable"> {
|
||||
let summary = "LLVM dereferenceable attribute";
|
||||
let description = [{
|
||||
Defines `dereferenceable` or `dereferenceable_or_null` metadata that can
|
||||
be set via the `DereferenceableOpInterface` on an `inttoptr` operation or
|
||||
on a `load` operation which loads a pointer. The attribute is used to
|
||||
denote that the result of these operations is dereferenceable up to a
|
||||
certain number of bytes, represented by `$bytes`. The optional `$mayBeNull`
|
||||
parameter is set to true if the attribute defines `dereferenceable_or_null`
|
||||
metadata.
|
||||
|
||||
See the following links for more details:
|
||||
https://llvm.org/docs/LangRef.html#dereferenceable-metadata
|
||||
https://llvm.org/docs/LangRef.html#dereferenceable-or-null-metadata
|
||||
}];
|
||||
let parameters = (ins "uint64_t":$bytes,
|
||||
DefaultValuedParameter<"bool", "false">:$mayBeNull);
|
||||
let assemblyFormat = "`<` struct(params) `>`";
|
||||
}
|
||||
|
||||
#endif // LLVMIR_ATTRDEFS
|
||||
|
||||
@@ -27,6 +27,10 @@ LogicalResult verifyAccessGroupOpInterface(Operation *op);
|
||||
/// the alias analysis interface.
|
||||
LogicalResult verifyAliasAnalysisOpInterface(Operation *op);
|
||||
|
||||
/// Verifies that the operation implementing the dereferenceable interface has
|
||||
/// exactly one result of LLVM pointer type.
|
||||
LogicalResult verifyDereferenceableOpInterface(Operation *op);
|
||||
|
||||
} // namespace detail
|
||||
} // namespace LLVM
|
||||
} // namespace mlir
|
||||
|
||||
@@ -330,6 +330,43 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> {
|
||||
];
|
||||
}
|
||||
|
||||
def DereferenceableOpInterface : OpInterface<"DereferenceableOpInterface"> {
|
||||
let description = [{
|
||||
An interface for memory operations that can carry dereferenceable metadata.
|
||||
It provides setters and getters for the operation's dereferenceable
|
||||
attributes. The default implementations of the interface methods expect
|
||||
the operation to have an attribute of type DereferenceableAttr.
|
||||
}];
|
||||
|
||||
let cppNamespace = "::mlir::LLVM";
|
||||
let verify = [{ return detail::verifyDereferenceableOpInterface($_op); }];
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<
|
||||
/*desc=*/ "Returns the dereferenceable attribute or nullptr",
|
||||
/*returnType=*/ "::mlir::LLVM::DereferenceableAttr",
|
||||
/*methodName=*/ "getDereferenceableOrNull",
|
||||
/*args=*/ (ins),
|
||||
/*methodBody=*/ [{}],
|
||||
/*defaultImpl=*/ [{
|
||||
auto op = cast<ConcreteOp>(this->getOperation());
|
||||
return op.getDereferenceableAttr();
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/ "Sets the dereferenceable attribute",
|
||||
/*returnType=*/ "void",
|
||||
/*methodName=*/ "setDereferenceable",
|
||||
/*args=*/ (ins "::mlir::LLVM::DereferenceableAttr":$attr),
|
||||
/*methodBody=*/ [{}],
|
||||
/*defaultImpl=*/ [{
|
||||
auto op = cast<ConcreteOp>(this->getOperation());
|
||||
op.setDereferenceableAttr(attr);
|
||||
}]
|
||||
>
|
||||
];
|
||||
}
|
||||
|
||||
def FPExceptionBehaviorOpInterface : OpInterface<"FPExceptionBehaviorOpInterface"> {
|
||||
let description = [{
|
||||
An interface for operations receiving an exception behavior attribute
|
||||
|
||||
@@ -364,7 +364,8 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
|
||||
[DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
|
||||
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>]> {
|
||||
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>,
|
||||
DeclareOpInterfaceMethods<DereferenceableOpInterface>]> {
|
||||
dag args = (ins LLVM_AnyPointer:$addr,
|
||||
OptionalAttr<I64Attr>:$alignment,
|
||||
UnitAttr:$volatile_,
|
||||
@@ -373,7 +374,8 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
|
||||
UnitAttr:$invariantGroup,
|
||||
DefaultValuedAttr<
|
||||
AtomicOrdering, "AtomicOrdering::not_atomic">:$ordering,
|
||||
OptionalAttr<StrAttr>:$syncscope);
|
||||
OptionalAttr<StrAttr>:$syncscope,
|
||||
OptionalAttr<LLVM_DereferenceableAttr>:$dereferenceable);
|
||||
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
|
||||
let arguments = !con(args, aliasAttrs);
|
||||
let results = (outs LLVM_LoadableType:$res);
|
||||
@@ -407,6 +409,7 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
|
||||
(`atomic` (`syncscope` `(` $syncscope^ `)`)? $ordering^)?
|
||||
(`invariant` $invariant^)?
|
||||
(`invariant_group` $invariantGroup^)?
|
||||
(`dereferenceable` `` $dereferenceable^)?
|
||||
attr-dict `:` qualified(type($addr)) `->` type($res)
|
||||
}];
|
||||
string llvmBuilder = [{
|
||||
@@ -416,6 +419,8 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
|
||||
llvm::MDNode *metadata = llvm::MDNode::get(inst->getContext(), std::nullopt);
|
||||
inst->setMetadata(llvm::LLVMContext::MD_invariant_load, metadata);
|
||||
}
|
||||
if ($dereferenceable)
|
||||
moduleTranslation.setDereferenceableMetadata(op, inst);
|
||||
}] # setOrderingCode
|
||||
# setSyncScopeCode
|
||||
# setAlignmentCode
|
||||
@@ -571,6 +576,29 @@ class LLVM_CastOpWithOverflowFlag<string mnemonic, string instName, Type type,
|
||||
}];
|
||||
}
|
||||
|
||||
class LLVM_DereferenceableCastOp<string mnemonic, string instName, Type type,
|
||||
Type resultType, list<Trait> traits = []> :
|
||||
LLVM_Op<mnemonic, !listconcat([Pure], [DeclareOpInterfaceMethods<DereferenceableOpInterface>], traits)> {
|
||||
let arguments = (ins type:$arg, OptionalAttr<LLVM_DereferenceableAttr>:$dereferenceable);
|
||||
let results = (outs resultType:$res);
|
||||
let builders = [LLVM_OneResultOpBuilder];
|
||||
let assemblyFormat = "$arg (`dereferenceable` `` $dereferenceable^)? attr-dict `:` type($arg) `to` type($res)";
|
||||
string llvmInstName = instName;
|
||||
string llvmBuilder = [{
|
||||
auto *val = builder.Create}] # instName # [{($arg, $_resultType);
|
||||
$res = val;
|
||||
if ($dereferenceable) {
|
||||
llvm::Instruction *inst = dyn_cast<llvm::Instruction>(val);
|
||||
moduleTranslation.setDereferenceableMetadata(op, inst);
|
||||
}
|
||||
}];
|
||||
string mlirBuilder = [{
|
||||
auto op = $_builder.create<$_qualCppClassName>(
|
||||
$_location, $_resultType, $arg);
|
||||
$res = op;
|
||||
}];
|
||||
}
|
||||
|
||||
def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "BitCast", LLVM_AnyNonAggregate,
|
||||
LLVM_AnyNonAggregate, [DeclareOpInterfaceMethods<PromotableOpInterface>]> {
|
||||
let hasFolder = 1;
|
||||
@@ -583,7 +611,7 @@ def LLVM_AddrSpaceCastOp : LLVM_CastOp<"addrspacecast", "AddrSpaceCast",
|
||||
DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
|
||||
let hasFolder = 1;
|
||||
}
|
||||
def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "IntToPtr",
|
||||
def LLVM_IntToPtrOp : LLVM_DereferenceableCastOp<"inttoptr", "IntToPtr",
|
||||
LLVM_ScalarOrVectorOf<AnySignlessInteger>,
|
||||
LLVM_ScalarOrVectorOf<LLVM_AnyPointer>>;
|
||||
def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "PtrToInt",
|
||||
|
||||
@@ -248,6 +248,13 @@ public:
|
||||
LoopAnnotationAttr translateLoopAnnotationAttr(const llvm::MDNode *node,
|
||||
Location loc) const;
|
||||
|
||||
/// Returns the dereferenceable attribute that corresponds to the given LLVM
|
||||
/// dereferenceable or dereferenceable_or_null metadata `node`. `kindID`
|
||||
/// specifies the kind of the metadata node (dereferenceable or
|
||||
/// dereferenceable_or_null).
|
||||
FailureOr<DereferenceableAttr>
|
||||
translateDereferenceableAttr(const llvm::MDNode *node, unsigned kindID);
|
||||
|
||||
/// Returns the alias scope attributes that map to the alias scope nodes
|
||||
/// starting from the metadata `node`. Returns failure, if any of the
|
||||
/// attributes cannot be found.
|
||||
|
||||
@@ -161,6 +161,11 @@ public:
|
||||
/// Sets LLVM TBAA metadata for memory operations that have TBAA attributes.
|
||||
void setTBAAMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst);
|
||||
|
||||
/// Sets LLVM dereferenceable metadata for operations that have
|
||||
/// dereferenceable attributes.
|
||||
void setDereferenceableMetadata(DereferenceableOpInterface op,
|
||||
llvm::Instruction *inst);
|
||||
|
||||
/// Sets LLVM profiling metadata for operations that have branch weights.
|
||||
void setBranchWeightsMetadata(BranchWeightOpInterface op);
|
||||
|
||||
|
||||
@@ -940,6 +940,7 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
|
||||
alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
|
||||
isNonTemporal, isInvariant, isInvariantGroup, ordering,
|
||||
syncscope.empty() ? nullptr : builder.getStringAttr(syncscope),
|
||||
/*dereferenceable=*/nullptr,
|
||||
/*access_groups=*/nullptr,
|
||||
/*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr,
|
||||
/*tbaa=*/nullptr);
|
||||
|
||||
@@ -62,6 +62,23 @@ mlir::LLVM::detail::verifyAliasAnalysisOpInterface(Operation *op) {
|
||||
return isArrayOf<TBAATagAttr>(op, tags);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DereferenceableOpInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult
|
||||
mlir::LLVM::detail::verifyDereferenceableOpInterface(Operation *op) {
|
||||
auto iface = cast<DereferenceableOpInterface>(op);
|
||||
|
||||
if (auto derefAttr = iface.getDereferenceableOrNull())
|
||||
if (op->getNumResults() != 1 ||
|
||||
!mlir::isa<LLVMPointerType>(op->getResult(0).getType()))
|
||||
return op->emitOpError(
|
||||
"expected op to return a single LLVM pointer type");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<Value> mlir::LLVM::AtomicCmpXchgOp::getAccessedOperands() {
|
||||
return {getPtr()};
|
||||
}
|
||||
|
||||
@@ -90,6 +90,8 @@ static ArrayRef<unsigned> getSupportedMetadataImpl(llvm::LLVMContext &context) {
|
||||
llvm::LLVMContext::MD_loop,
|
||||
llvm::LLVMContext::MD_noalias,
|
||||
llvm::LLVMContext::MD_alias_scope,
|
||||
llvm::LLVMContext::MD_dereferenceable,
|
||||
llvm::LLVMContext::MD_dereferenceable_or_null,
|
||||
context.getMDKindID(vecTypeHintMDName),
|
||||
context.getMDKindID(workGroupSizeHintMDName),
|
||||
context.getMDKindID(reqdWorkGroupSizeMDName),
|
||||
@@ -188,6 +190,25 @@ static LogicalResult setAccessGroupsAttr(const llvm::MDNode *node,
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Converts the given dereferenceable metadata node to a dereferenceable
|
||||
/// attribute, and attaches it to the imported operation if the translation
|
||||
/// succeeds. Returns failure if the LLVM IR metadata node is ill-formed.
|
||||
static LogicalResult setDereferenceableAttr(const llvm::MDNode *node,
|
||||
unsigned kindID, Operation *op,
|
||||
LLVM::ModuleImport &moduleImport) {
|
||||
auto dereferenceable =
|
||||
moduleImport.translateDereferenceableAttr(node, kindID);
|
||||
if (failed(dereferenceable))
|
||||
return failure();
|
||||
|
||||
auto iface = dyn_cast<DereferenceableOpInterface>(op);
|
||||
if (!iface)
|
||||
return failure();
|
||||
|
||||
iface.setDereferenceable(*dereferenceable);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Converts the given loop metadata node to an MLIR loop annotation attribute
|
||||
/// and attaches it to the imported operation if the translation succeeds.
|
||||
/// Returns failure otherwise.
|
||||
@@ -401,6 +422,13 @@ public:
|
||||
return setAliasScopesAttr(node, op, moduleImport);
|
||||
if (kind == llvm::LLVMContext::MD_noalias)
|
||||
return setNoaliasScopesAttr(node, op, moduleImport);
|
||||
if (kind == llvm::LLVMContext::MD_dereferenceable)
|
||||
return setDereferenceableAttr(node, llvm::LLVMContext::MD_dereferenceable,
|
||||
op, moduleImport);
|
||||
if (kind == llvm::LLVMContext::MD_dereferenceable_or_null)
|
||||
return setDereferenceableAttr(
|
||||
node, llvm::LLVMContext::MD_dereferenceable_or_null, op,
|
||||
moduleImport);
|
||||
|
||||
llvm::LLVMContext &context = node->getContext();
|
||||
if (kind == context.getMDKindID(vecTypeHintMDName))
|
||||
|
||||
@@ -2527,6 +2527,31 @@ ModuleImport::translateLoopAnnotationAttr(const llvm::MDNode *node,
|
||||
return loopAnnotationImporter->translateLoopAnnotation(node, loc);
|
||||
}
|
||||
|
||||
FailureOr<DereferenceableAttr>
|
||||
ModuleImport::translateDereferenceableAttr(const llvm::MDNode *node,
|
||||
unsigned kindID) {
|
||||
Location loc = mlirModule.getLoc();
|
||||
|
||||
// The only operand should be a constant integer representing the number of
|
||||
// dereferenceable bytes.
|
||||
if (node->getNumOperands() != 1)
|
||||
return emitError(loc) << "dereferenceable metadata must have one operand: "
|
||||
<< diagMD(node, llvmModule.get());
|
||||
|
||||
auto *numBytesMD = dyn_cast<llvm::ConstantAsMetadata>(node->getOperand(0));
|
||||
auto *numBytesCst = dyn_cast<llvm::ConstantInt>(numBytesMD->getValue());
|
||||
if (!numBytesCst || !numBytesCst->getValue().isNonNegative())
|
||||
return emitError(loc) << "dereferenceable metadata operand must be a "
|
||||
"non-negative constant integer: "
|
||||
<< diagMD(node, llvmModule.get());
|
||||
|
||||
bool mayBeNull = kindID == llvm::LLVMContext::MD_dereferenceable_or_null;
|
||||
auto derefAttr = builder.getAttr<DereferenceableAttr>(
|
||||
numBytesCst->getZExtValue(), mayBeNull);
|
||||
|
||||
return derefAttr;
|
||||
}
|
||||
|
||||
OwningOpRef<ModuleOp>
|
||||
mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
|
||||
MLIRContext *context, bool emitExpensiveWarnings,
|
||||
|
||||
@@ -1925,6 +1925,22 @@ void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op,
|
||||
inst->setMetadata(llvm::LLVMContext::MD_tbaa, node);
|
||||
}
|
||||
|
||||
void ModuleTranslation::setDereferenceableMetadata(
|
||||
DereferenceableOpInterface op, llvm::Instruction *inst) {
|
||||
DereferenceableAttr derefAttr = op.getDereferenceableOrNull();
|
||||
if (!derefAttr)
|
||||
return;
|
||||
|
||||
llvm::MDNode *derefSizeNode = llvm::MDNode::get(
|
||||
getLLVMContext(),
|
||||
llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(
|
||||
llvm::IntegerType::get(getLLVMContext(), 64), derefAttr.getBytes())));
|
||||
unsigned kindId = derefAttr.getMayBeNull()
|
||||
? llvm::LLVMContext::MD_dereferenceable_or_null
|
||||
: llvm::LLVMContext::MD_dereferenceable;
|
||||
inst->setMetadata(kindId, derefSizeNode);
|
||||
}
|
||||
|
||||
void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) {
|
||||
DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull();
|
||||
if (!weightsAttr)
|
||||
|
||||
7
mlir/test/Dialect/LLVMIR/dereferenceable-invalid.mlir
Normal file
7
mlir/test/Dialect/LLVMIR/dereferenceable-invalid.mlir
Normal file
@@ -0,0 +1,7 @@
|
||||
// RUN: mlir-opt --allow-unregistered-dialect -split-input-file -verify-diagnostics %s
|
||||
|
||||
llvm.func @deref(%arg0: !llvm.ptr) {
|
||||
// expected-error @below {{op expected op to return a single LLVM pointer type}}
|
||||
%0 = llvm.load %arg0 dereferenceable<bytes = 8> {alignment = 8 : i64} : !llvm.ptr -> i64
|
||||
llvm.return
|
||||
}
|
||||
@@ -338,6 +338,17 @@ declare void @llvm.experimental.noalias.scope.decl(metadata)
|
||||
|
||||
; // -----
|
||||
|
||||
; CHECK: import-failure.ll
|
||||
; CHECK-SAME: dereferenceable metadata operand must be a non-negative constant integer
|
||||
define void @deref(i64 %0) {
|
||||
%2 = inttoptr i64 %0 to ptr, !dereferenceable !0
|
||||
ret void
|
||||
}
|
||||
|
||||
!0 = !{i64 -4}
|
||||
|
||||
; // -----
|
||||
|
||||
; CHECK: import-failure.ll
|
||||
; CHECK-SAME: warning: unhandled data layout token: ni:42
|
||||
target datalayout = "e-ni:42-i64:64"
|
||||
|
||||
24
mlir/test/Target/LLVMIR/Import/metadata-dereferenceable.ll
Normal file
24
mlir/test/Target/LLVMIR/Import/metadata-dereferenceable.ll
Normal file
@@ -0,0 +1,24 @@
|
||||
; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
|
||||
|
||||
define void @deref(i64 %0, ptr %1) {
|
||||
; CHECK: llvm.inttoptr
|
||||
; CHECK-SAME: dereferenceable<bytes = 4>
|
||||
%3 = inttoptr i64 %0 to ptr, !dereferenceable !0
|
||||
; CHECK: llvm.load
|
||||
; CHECK-SAME: dereferenceable<bytes = 8>
|
||||
%4 = load ptr, ptr %1, align 8, !dereferenceable !1
|
||||
ret void
|
||||
}
|
||||
|
||||
define void @deref_or_null(i64 %0, ptr %1) {
|
||||
; CHECK: llvm.inttoptr
|
||||
; CHECK-SAME: dereferenceable<bytes = 4, mayBeNull = true>
|
||||
%3 = inttoptr i64 %0 to ptr, !dereferenceable_or_null !0
|
||||
; CHECK: llvm.load
|
||||
; CHECK-SAME: dereferenceable<bytes = 8, mayBeNull = true>
|
||||
%4 = load ptr, ptr %1, align 8, !dereferenceable_or_null !1
|
||||
ret void
|
||||
}
|
||||
|
||||
!0 = !{i64 4}
|
||||
!1 = !{i64 8}
|
||||
24
mlir/test/Target/LLVMIR/attribute-dereferenceable.mlir
Normal file
24
mlir/test/Target/LLVMIR/attribute-dereferenceable.mlir
Normal file
@@ -0,0 +1,24 @@
|
||||
// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
|
||||
|
||||
llvm.func @deref(%arg0: i64, %arg1: !llvm.ptr) {
|
||||
// CHECK: inttoptr {{.*}} !dereferenceable [[D0:![0-9]+]]
|
||||
%0 = llvm.inttoptr %arg0 dereferenceable<bytes = 4> : i64 to !llvm.ptr
|
||||
%1 = llvm.load %0 {alignment = 4 : i64} : !llvm.ptr -> i32
|
||||
// CHECK: load {{.*}} !dereferenceable [[D1:![0-9]+]]
|
||||
%2 = llvm.load %arg1 dereferenceable<bytes = 8> {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
|
||||
llvm.store %1, %2 {alignment = 4 : i64} : i32, !llvm.ptr
|
||||
llvm.return
|
||||
}
|
||||
|
||||
llvm.func @deref_or_null(%arg0: i64, %arg1: !llvm.ptr) {
|
||||
// CHECK: inttoptr {{.*}} !dereferenceable_or_null [[D0]]
|
||||
%0 = llvm.inttoptr %arg0 dereferenceable<bytes = 4, mayBeNull = true> : i64 to !llvm.ptr
|
||||
%1 = llvm.load %0 {alignment = 4 : i64} : !llvm.ptr -> i32
|
||||
// CHECK: load {{.*}} !dereferenceable_or_null [[D1]]
|
||||
%2 = llvm.load %arg1 dereferenceable<bytes = 8, mayBeNull = true> {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
|
||||
llvm.store %1, %2 {alignment = 4 : i64} : i32, !llvm.ptr
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// CHECK: [[D0]] = !{i64 4}
|
||||
// CHECK: [[D1]] = !{i64 8}
|
||||
Reference in New Issue
Block a user