[mlir] Attach InferTypeOpInterface on SameOperandsAndResultType operations when possible
This allows for inferring the result types of operations in certain situations by using the type of an operand. This commit allowed for automatically supporting type inference for many more operations with no additional effort, e.g. nearly all Arithmetic operations now support result type inferrence with no additional changes. Differential Revision: https://reviews.llvm.org/D124581
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
|
||||
@@ -10,4 +10,5 @@ add_mlir_dialect_library(MLIRStandalone
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRInferTypeOpInterface
|
||||
)
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Interfaces/VectorInterfaces.h"
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "llvm/Support/MathExtras.h"
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/TensorEncoding.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
|
||||
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRQuant
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRInferTypeOpInterface
|
||||
MLIRSideEffectInterfaces
|
||||
MLIRSupport
|
||||
)
|
||||
|
||||
@@ -11,5 +11,6 @@ add_mlir_dialect_library(MLIRSparseTensor
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRDialect
|
||||
MLIRIR
|
||||
MLIRInferTypeOpInterface
|
||||
MLIRSupport
|
||||
)
|
||||
|
||||
@@ -333,8 +333,25 @@ void Operator::populateTypeInferenceInfo(
|
||||
|
||||
// Skip cases currently being custom generated.
|
||||
// TODO: Remove special cases.
|
||||
if (getTrait("::mlir::OpTrait::SameOperandsAndResultType"))
|
||||
if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
|
||||
// Check for a non-variable length operand to use as the type anchor.
|
||||
auto *operandI = llvm::find_if(arguments, [](const Argument &arg) {
|
||||
NamedTypeConstraint *operand = arg.dyn_cast<NamedTypeConstraint *>();
|
||||
return operand && !operand->isVariableLength();
|
||||
});
|
||||
if (operandI == arguments.end())
|
||||
return;
|
||||
|
||||
// Map each of the result types to the anchor operation.
|
||||
int operandIdx = operandI - arguments.begin();
|
||||
resultTypeMapping.resize(getNumResults());
|
||||
for (int i = 0; i < getNumResults(); ++i)
|
||||
resultTypeMapping[i].emplace_back(operandIdx);
|
||||
|
||||
allResultsHaveKnownTypes = true;
|
||||
traits.push_back(Trait::create(inferTrait->getDefInit()));
|
||||
return;
|
||||
}
|
||||
|
||||
// We create equivalence classes of argument/result types where arguments
|
||||
// and results are mapped into the same index space and indices corresponding
|
||||
|
||||
@@ -5,9 +5,9 @@ module attributes {shape.lib = [@shape_lib]} {
|
||||
// expected-remark@+1 {{associated shape function: same_result_shape}}
|
||||
func.func @tanh(%arg: tensor<10x20xf32>) -> tensor<10x20xf32>
|
||||
attributes {shape.function = @shape_lib::@same_result_shape} {
|
||||
// expected-remark@+1 {{no associated way}}
|
||||
// expected-remark@+1 {{implements InferType op interface}}
|
||||
%0 = math.tanh %arg : tensor<10x20xf32>
|
||||
// expected-remark@+1 {{associated shape function: same_result_shape}}
|
||||
// expected-remark@+1 {{implements InferType op interface}}
|
||||
%1 = "test.same_operand_result_type"(%0) : (tensor<10x20xf32>) -> tensor<10x20xf32>
|
||||
return %1 : tensor<10x20xf32>
|
||||
}
|
||||
|
||||
@@ -2608,15 +2608,9 @@ class TableGenBuildInferReturnTypeBaseOp<string mnemonic,
|
||||
}];
|
||||
}
|
||||
|
||||
// Single variadic arg with SameOperandsAndResultType and InferTypeOpInterface.
|
||||
// Tests suppression of ambiguous build methods for operations with
|
||||
// SameOperandsAndResultType and InferTypeOpInterface.
|
||||
def TableGenBuildOp5 : TableGenBuildInferReturnTypeBaseOp<
|
||||
"tblgen_build_5", [SameOperandsAndResultType]>;
|
||||
|
||||
// Op with InferTypeOpInterface and regions.
|
||||
def TableGenBuildOp6 : TableGenBuildInferReturnTypeBaseOp<
|
||||
"tblgen_build_6", [InferTypeOpInterface]> {
|
||||
def TableGenBuildOp5 : TableGenBuildInferReturnTypeBaseOp<
|
||||
"tblgen_build_5", [InferTypeOpInterface]> {
|
||||
let regions = (region AnyRegion:$body);
|
||||
}
|
||||
|
||||
|
||||
@@ -199,7 +199,7 @@ def NS_HCollectiveParamsOp : NS_Op<"op_collective_params", []> {
|
||||
let results = (outs AnyType:$b);
|
||||
}
|
||||
|
||||
// CHECK_LABEL: class NS_HCollectiveParamsOp :
|
||||
// CHECK_LABEL: class HCollectiveParamsOp :
|
||||
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type b, ::mlir::Value a);
|
||||
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a);
|
||||
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {})
|
||||
@@ -212,7 +212,7 @@ def NS_HCollectiveParamsSuppress0Op : NS_Op<"op_collective_suppress0", []> {
|
||||
let results = (outs Variadic<I32>:$b);
|
||||
}
|
||||
|
||||
// CHECK_LABEL: class NS_HCollectiveParamsSuppress0Op :
|
||||
// CHECK_LABEL: class HCollectiveParamsSuppress0Op :
|
||||
// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a);
|
||||
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
|
||||
|
||||
@@ -224,7 +224,7 @@ def NS_HCollectiveParamsSuppress1Op : NS_Op<"op_collective_suppress1", []> {
|
||||
let results = (outs I32:$b);
|
||||
}
|
||||
|
||||
// CHECK_LABEL: class NS_HCollectiveParamsSuppress1Op :
|
||||
// CHECK_LABEL: class HCollectiveParamsSuppress1Op :
|
||||
// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a);
|
||||
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
|
||||
|
||||
@@ -237,7 +237,7 @@ def NS_HCollectiveParamsSuppress2Op : NS_Op<"op_collective_suppress2", [SameVari
|
||||
let arguments = (ins Variadic<I32>:$a);
|
||||
let results = (outs Variadic<I32>:$b, Variadic<F32>:$c);
|
||||
}
|
||||
// CHECK_LABEL: class NS_HCollectiveParamsSuppress2Op :
|
||||
// CHECK_LABEL: class HCollectiveParamsSuppress2Op :
|
||||
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::TypeRange c, ::mlir::ValueRange a);
|
||||
// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a);
|
||||
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
|
||||
@@ -247,11 +247,11 @@ def NS_IOp : NS_Op<"op_with_same_operands_and_result_types_trait", [SameOperands
|
||||
let arguments = (ins AnyType:$a, AnyType:$b);
|
||||
let results = (outs AnyType:$r);
|
||||
}
|
||||
// CHECK_LABEL: class NS_IOp :
|
||||
// CHECK_LABEL: class IOp :
|
||||
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b);
|
||||
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b);
|
||||
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b);
|
||||
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
|
||||
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b);
|
||||
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
|
||||
|
||||
// Check default value of `attributes` for the `genInferredTypeCollectiveParamBuilder` builder
|
||||
@@ -259,7 +259,7 @@ def NS_JOp : NS_Op<"op_with_InferTypeOpInterface_interface", [DeclareOpInterface
|
||||
let arguments = (ins AnyType:$a, AnyType:$b);
|
||||
let results = (outs AnyType:$r);
|
||||
}
|
||||
// CHECK_LABEL: class NS_JOp :
|
||||
// CHECK_LABEL: class JOp :
|
||||
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b);
|
||||
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b);
|
||||
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b);
|
||||
@@ -292,14 +292,14 @@ def NS_LOp : NS_Op<"op_with_same_operands_and_result_types_unwrapped_attr", [Sam
|
||||
let arguments = (ins AnyType:$a, AnyType:$b, I32Attr:$attr1);
|
||||
let results = (outs AnyType:$r);
|
||||
}
|
||||
// CHECK_LABEL: class NS_LOp :
|
||||
// CHECK_LABEL: class LOp :
|
||||
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b, ::mlir::IntegerAttr attr1);
|
||||
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b, ::mlir::IntegerAttr attr1);
|
||||
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b, ::mlir::IntegerAttr attr1);
|
||||
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b, uint32_t attr1);
|
||||
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b, uint32_t attr1);
|
||||
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b, uint32_t attr1);
|
||||
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
|
||||
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b, ::mlir::IntegerAttr attr1);
|
||||
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b, uint32_t attr1);
|
||||
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
|
||||
|
||||
|
||||
|
||||
@@ -27,7 +27,12 @@ def OpB : NS_Op<"same_input_output_type_op", [SameOperandsAndResultType]> {
|
||||
// CHECK: void OpB::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type y, ::mlir::Value x)
|
||||
// CHECK: odsState.addTypes(y);
|
||||
// CHECK: void OpB::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value x)
|
||||
// CHECK: odsState.addTypes({x.getType()});
|
||||
// CHECK: ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
|
||||
// CHECK: if (::mlir::succeeded(OpB::inferReturnTypes(odsBuilder.getContext(),
|
||||
// CHECK: odsState.location, odsState.operands,
|
||||
// CHECK: odsState.attributes.getDictionary(odsState.getContext()),
|
||||
// CHECK: /*regions=*/{}, inferredReturnTypes)))
|
||||
// CHECK: odsState.addTypes(inferredReturnTypes);
|
||||
|
||||
def OpC : NS_Op<"three_normal_result_op", []> {
|
||||
let results = (outs I32:$x, /*unnamed*/I32, I32:$z);
|
||||
|
||||
@@ -204,7 +204,7 @@ TEST_F(OpBuildGenTest,
|
||||
verifyOp(op, {i32Ty, f32Ty}, {*cstI32}, attrs);
|
||||
}
|
||||
|
||||
// The next 2 tests test supression of ambiguous build methods for ops that
|
||||
// The next test checks supression of ambiguous build methods for ops that
|
||||
// have a single variadic input, and single non-variadic result, and which
|
||||
// support the SameOperandsAndResultType trait and and optionally the
|
||||
// InferOpTypeInterface interface. For such ops, the ODS framework generates
|
||||
@@ -213,14 +213,8 @@ TEST_F(OpBuildGenTest, BuildMethodsSameOperandsAndResultTypeSuppression) {
|
||||
testSingleVariadicInputInferredType<test::TableGenBuildOp4>();
|
||||
}
|
||||
|
||||
TEST_F(
|
||||
OpBuildGenTest,
|
||||
BuildMethodsSameOperandsAndResultTypeAndInferOpTypeInterfaceSuppression) {
|
||||
testSingleVariadicInputInferredType<test::TableGenBuildOp5>();
|
||||
}
|
||||
|
||||
TEST_F(OpBuildGenTest, BuildMethodsRegionsAndInferredType) {
|
||||
auto op = builder.create<test::TableGenBuildOp6>(
|
||||
auto op = builder.create<test::TableGenBuildOp5>(
|
||||
loc, ValueRange{*cstI32, *cstF32}, /*attributes=*/noAttrs);
|
||||
ASSERT_EQ(op->getNumRegions(), 1u);
|
||||
verifyOp(op, {i32Ty}, {*cstI32, *cstF32}, noAttrs);
|
||||
|
||||
Reference in New Issue
Block a user