[mlir][LLVM] Add nneg flag (#115498)

This implementation is based on the existing one for the exact flag.

If the nneg flag is set and the argument is negative, the result is a
poison value.
This commit is contained in:
lfrenot
2024-11-11 13:01:50 +00:00
committed by GitHub
parent 5e7662efec
commit 89aaf2cf68
7 changed files with 90 additions and 2 deletions

View File

@@ -114,6 +114,33 @@ def ExactFlagInterface : OpInterface<"ExactFlagInterface"> {
];
}
def NonNegFlagInterface : OpInterface<"NonNegFlagInterface"> {
let description = [{
This interface defines an LLVM operation with an nneg flag and
provides a uniform API for accessing it.
}];
let cppNamespace = "::mlir::LLVM";
let methods = [
InterfaceMethod<[{
Get the nneg flag for the operation.
}], "bool", "getNonNeg", (ins), [{}], [{
return $_op.getProperties().nonNeg;
}]>,
InterfaceMethod<[{
Set the nneg flag for the operation.
}], "void", "setNonNeg", (ins "bool":$nonNeg), [{}], [{
$_op.getProperties().nonNeg = nonNeg;
}]>,
StaticInterfaceMethod<[{
Get the attribute name of the nonNeg property.
}], "StringRef", "getNonNegName", (ins), [{}], [{
return "nonNeg";
}]>,
];
}
def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
let description = [{
An interface for operations that can carry branch weights metadata. It

View File

@@ -508,6 +508,23 @@ class LLVM_CastOp<string mnemonic, string instName, Type type,
$_location, $_resultType, $arg);
}];
}
class LLVM_CastOpWithNNegFlag<string mnemonic, string instName, Type type,
Type resultType, list<Trait> traits = []> :
LLVM_Op<mnemonic, !listconcat([Pure], [DeclareOpInterfaceMethods<NonNegFlagInterface>], traits)>,
LLVM_Builder<"$res = builder.Create" # instName # "($arg, $_resultType, /*Name=*/\"\", op.getNonNeg());"> {
let arguments = (ins type:$arg, UnitAttr:$nonNeg);
let results = (outs resultType:$res);
let builders = [LLVM_OneResultOpBuilder];
let assemblyFormat = "(`nneg` $nonNeg^)? $arg attr-dict `:` type($arg) `to` type($res)";
string llvmInstName = instName;
string mlirBuilder = [{
auto op = $_builder.create<$_qualCppClassName>(
$_location, $_resultType, $arg);
moduleImport.setNonNegFlag(inst, op);
$res = op;
}];
}
def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "BitCast", LLVM_AnyNonAggregate,
LLVM_AnyNonAggregate, [DeclareOpInterfaceMethods<PromotableOpInterface>]> {
let hasFolder = 1;
@@ -531,7 +548,7 @@ def LLVM_SExtOp : LLVM_CastOp<"sext", "SExt",
LLVM_ScalarOrVectorOf<AnySignlessInteger>> {
let hasVerifier = 1;
}
def LLVM_ZExtOp : LLVM_CastOp<"zext", "ZExt",
def LLVM_ZExtOp : LLVM_CastOpWithNNegFlag<"zext", "ZExt",
LLVM_ScalarOrVectorOf<AnySignlessInteger>,
LLVM_ScalarOrVectorOf<AnySignlessInteger>> {
let hasFolder = 1;
@@ -543,7 +560,7 @@ def LLVM_TruncOp : LLVM_CastOp<"trunc", "Trunc",
def LLVM_SIToFPOp : LLVM_CastOp<"sitofp", "SIToFP",
LLVM_ScalarOrVectorOf<AnySignlessInteger>,
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
def LLVM_UIToFPOp : LLVM_CastOp<"uitofp", "UIToFP",
def LLVM_UIToFPOp : LLVM_CastOpWithNNegFlag<"uitofp", "UIToFP",
LLVM_ScalarOrVectorOf<AnySignlessInteger>,
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
def LLVM_FPToSIOp : LLVM_CastOp<"fptosi", "FPToSI",

View File

@@ -192,6 +192,11 @@ public:
/// implement the exact flag interface.
void setExactFlag(llvm::Instruction *inst, Operation *op) const;
/// Sets the nneg flag attribute for the imported operation `op` given
/// the original instruction `inst`. Asserts if the operation does not
/// implement the nneg flag interface.
void setNonNegFlag(llvm::Instruction *inst, Operation *op) const;
/// Sets the fastmath flags attribute for the imported operation `op` given
/// the original instruction `inst`. Asserts if the operation does not
/// implement the fastmath interface.

View File

@@ -689,6 +689,12 @@ void ModuleImport::setExactFlag(llvm::Instruction *inst, Operation *op) const {
iface.setIsExact(inst->isExact());
}
void ModuleImport::setNonNegFlag(llvm::Instruction *inst, Operation *op) const {
auto iface = cast<NonNegFlagInterface>(op);
iface.setNonNeg(inst->hasNonNeg());
}
void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
Operation *op) const {
auto iface = cast<FastmathFlagsInterface>(op);

View File

@@ -325,6 +325,19 @@ func.func @casts(%arg0: i32, %arg1: i64, %arg2: vector<4xi32>,
llvm.return
}
// CHECK-LABEL: @nneg_casts
// CHECK-SAME: (%[[I32:.*]]: i32, %[[I64:.*]]: i64, %[[V4I32:.*]]: vector<4xi32>, %[[V4I64:.*]]: vector<4xi64>, %[[PTR:.*]]: !llvm.ptr)
func.func @nneg_casts(%arg0: i32, %arg1: i64, %arg2: vector<4xi32>,
%arg3: vector<4xi64>, %arg4: !llvm.ptr) {
// CHECK: = llvm.zext nneg %[[I32]] : i32 to i64
%0 = llvm.zext nneg %arg0 : i32 to i64
// CHECK: = llvm.zext nneg %[[V4I32]] : vector<4xi32> to vector<4xi64>
%4 = llvm.zext nneg %arg2 : vector<4xi32> to vector<4xi64>
// CHECK: = llvm.uitofp nneg %[[I32]] : i32 to f32
%7 = llvm.uitofp nneg %arg0 : i32 to f32
llvm.return
}
// CHECK-LABEL: @vect
func.func @vect(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32, %arg3: !llvm.vec<2 x ptr>) {
// CHECK: = llvm.extractelement {{.*}} : vector<4xf32>

View File

@@ -0,0 +1,10 @@
; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
; CHECK-LABEL: @nnegflag_inst
define void @nnegflag_inst(i32 %arg1) {
; CHECK: llvm.zext nneg %{{.*}} : i32 to i64
%1 = zext nneg i32 %arg1 to i64
; CHECK: llvm.uitofp nneg %{{.*}} : i32 to f32
%2 = uitofp nneg i32 %arg1 to float
ret void
}

View File

@@ -0,0 +1,10 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
// CHECK-LABEL: define void @nnegflag_func
llvm.func @nnegflag_func(%arg0: i32) {
// CHECK: %{{.*}} = zext nneg i32 %{{.*}} to i64
%0 = llvm.zext nneg %arg0 : i32 to i64
// CHECK: %{{.*}} = uitofp nneg i32 %{{.*}} to float
%1 = llvm.uitofp nneg %arg0 : i32 to f32
llvm.return
}