[mlir] Attach InferTypeOpInterface on SameOperandsAndResultType operations when possible

This allows for inferring the result types of operations in certain situations by using the type of
an operand. This commit allowed for automatically supporting type inference for many more
operations with no additional effort, e.g. nearly all Arithmetic operations now support
result type inferrence with no additional changes.

Differential Revision: https://reviews.llvm.org/D124581
This commit is contained in:
River Riddle
2022-04-26 11:12:45 -07:00
parent 1bd1edaf40
commit 92a836da07
13 changed files with 47 additions and 30 deletions

View File

@@ -12,6 +12,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#define GET_OP_CLASSES

View File

@@ -10,4 +10,5 @@ add_mlir_dialect_library(MLIRStandalone
LINK_LIBS PUBLIC
MLIRIR
MLIRInferTypeOpInterface
)

View File

@@ -13,6 +13,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"

View File

@@ -15,6 +15,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/Support/MathExtras.h"

View File

@@ -14,6 +14,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TensorEncoding.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#define GET_ATTRDEF_CLASSES

View File

@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRQuant
LINK_LIBS PUBLIC
MLIRIR
MLIRInferTypeOpInterface
MLIRSideEffectInterfaces
MLIRSupport
)

View File

@@ -11,5 +11,6 @@ add_mlir_dialect_library(MLIRSparseTensor
LINK_LIBS PUBLIC
MLIRDialect
MLIRIR
MLIRInferTypeOpInterface
MLIRSupport
)

View File

@@ -333,8 +333,25 @@ void Operator::populateTypeInferenceInfo(
// Skip cases currently being custom generated.
// TODO: Remove special cases.
if (getTrait("::mlir::OpTrait::SameOperandsAndResultType"))
if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
// Check for a non-variable length operand to use as the type anchor.
auto *operandI = llvm::find_if(arguments, [](const Argument &arg) {
NamedTypeConstraint *operand = arg.dyn_cast<NamedTypeConstraint *>();
return operand && !operand->isVariableLength();
});
if (operandI == arguments.end())
return;
// Map each of the result types to the anchor operation.
int operandIdx = operandI - arguments.begin();
resultTypeMapping.resize(getNumResults());
for (int i = 0; i < getNumResults(); ++i)
resultTypeMapping[i].emplace_back(operandIdx);
allResultsHaveKnownTypes = true;
traits.push_back(Trait::create(inferTrait->getDefInit()));
return;
}
// We create equivalence classes of argument/result types where arguments
// and results are mapped into the same index space and indices corresponding

View File

@@ -5,9 +5,9 @@ module attributes {shape.lib = [@shape_lib]} {
// expected-remark@+1 {{associated shape function: same_result_shape}}
func.func @tanh(%arg: tensor<10x20xf32>) -> tensor<10x20xf32>
attributes {shape.function = @shape_lib::@same_result_shape} {
// expected-remark@+1 {{no associated way}}
// expected-remark@+1 {{implements InferType op interface}}
%0 = math.tanh %arg : tensor<10x20xf32>
// expected-remark@+1 {{associated shape function: same_result_shape}}
// expected-remark@+1 {{implements InferType op interface}}
%1 = "test.same_operand_result_type"(%0) : (tensor<10x20xf32>) -> tensor<10x20xf32>
return %1 : tensor<10x20xf32>
}

View File

@@ -2608,15 +2608,9 @@ class TableGenBuildInferReturnTypeBaseOp<string mnemonic,
}];
}
// Single variadic arg with SameOperandsAndResultType and InferTypeOpInterface.
// Tests suppression of ambiguous build methods for operations with
// SameOperandsAndResultType and InferTypeOpInterface.
def TableGenBuildOp5 : TableGenBuildInferReturnTypeBaseOp<
"tblgen_build_5", [SameOperandsAndResultType]>;
// Op with InferTypeOpInterface and regions.
def TableGenBuildOp6 : TableGenBuildInferReturnTypeBaseOp<
"tblgen_build_6", [InferTypeOpInterface]> {
def TableGenBuildOp5 : TableGenBuildInferReturnTypeBaseOp<
"tblgen_build_5", [InferTypeOpInterface]> {
let regions = (region AnyRegion:$body);
}

View File

@@ -199,7 +199,7 @@ def NS_HCollectiveParamsOp : NS_Op<"op_collective_params", []> {
let results = (outs AnyType:$b);
}
// CHECK_LABEL: class NS_HCollectiveParamsOp :
// CHECK_LABEL: class HCollectiveParamsOp :
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type b, ::mlir::Value a);
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a);
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {})
@@ -212,7 +212,7 @@ def NS_HCollectiveParamsSuppress0Op : NS_Op<"op_collective_suppress0", []> {
let results = (outs Variadic<I32>:$b);
}
// CHECK_LABEL: class NS_HCollectiveParamsSuppress0Op :
// CHECK_LABEL: class HCollectiveParamsSuppress0Op :
// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a);
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
@@ -224,7 +224,7 @@ def NS_HCollectiveParamsSuppress1Op : NS_Op<"op_collective_suppress1", []> {
let results = (outs I32:$b);
}
// CHECK_LABEL: class NS_HCollectiveParamsSuppress1Op :
// CHECK_LABEL: class HCollectiveParamsSuppress1Op :
// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a);
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
@@ -237,7 +237,7 @@ def NS_HCollectiveParamsSuppress2Op : NS_Op<"op_collective_suppress2", [SameVari
let arguments = (ins Variadic<I32>:$a);
let results = (outs Variadic<I32>:$b, Variadic<F32>:$c);
}
// CHECK_LABEL: class NS_HCollectiveParamsSuppress2Op :
// CHECK_LABEL: class HCollectiveParamsSuppress2Op :
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::TypeRange c, ::mlir::ValueRange a);
// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a);
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
@@ -247,11 +247,11 @@ def NS_IOp : NS_Op<"op_with_same_operands_and_result_types_trait", [SameOperands
let arguments = (ins AnyType:$a, AnyType:$b);
let results = (outs AnyType:$r);
}
// CHECK_LABEL: class NS_IOp :
// CHECK_LABEL: class IOp :
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b);
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b);
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b);
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b);
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
// Check default value of `attributes` for the `genInferredTypeCollectiveParamBuilder` builder
@@ -259,7 +259,7 @@ def NS_JOp : NS_Op<"op_with_InferTypeOpInterface_interface", [DeclareOpInterface
let arguments = (ins AnyType:$a, AnyType:$b);
let results = (outs AnyType:$r);
}
// CHECK_LABEL: class NS_JOp :
// CHECK_LABEL: class JOp :
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b);
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b);
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b);
@@ -292,14 +292,14 @@ def NS_LOp : NS_Op<"op_with_same_operands_and_result_types_unwrapped_attr", [Sam
let arguments = (ins AnyType:$a, AnyType:$b, I32Attr:$attr1);
let results = (outs AnyType:$r);
}
// CHECK_LABEL: class NS_LOp :
// CHECK_LABEL: class LOp :
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b, ::mlir::IntegerAttr attr1);
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b, ::mlir::IntegerAttr attr1);
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b, ::mlir::IntegerAttr attr1);
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b, uint32_t attr1);
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b, uint32_t attr1);
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b, uint32_t attr1);
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b, ::mlir::IntegerAttr attr1);
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b, uint32_t attr1);
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});

View File

@@ -27,7 +27,12 @@ def OpB : NS_Op<"same_input_output_type_op", [SameOperandsAndResultType]> {
// CHECK: void OpB::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type y, ::mlir::Value x)
// CHECK: odsState.addTypes(y);
// CHECK: void OpB::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value x)
// CHECK: odsState.addTypes({x.getType()});
// CHECK: ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
// CHECK: if (::mlir::succeeded(OpB::inferReturnTypes(odsBuilder.getContext(),
// CHECK: odsState.location, odsState.operands,
// CHECK: odsState.attributes.getDictionary(odsState.getContext()),
// CHECK: /*regions=*/{}, inferredReturnTypes)))
// CHECK: odsState.addTypes(inferredReturnTypes);
def OpC : NS_Op<"three_normal_result_op", []> {
let results = (outs I32:$x, /*unnamed*/I32, I32:$z);

View File

@@ -204,7 +204,7 @@ TEST_F(OpBuildGenTest,
verifyOp(op, {i32Ty, f32Ty}, {*cstI32}, attrs);
}
// The next 2 tests test supression of ambiguous build methods for ops that
// The next test checks supression of ambiguous build methods for ops that
// have a single variadic input, and single non-variadic result, and which
// support the SameOperandsAndResultType trait and and optionally the
// InferOpTypeInterface interface. For such ops, the ODS framework generates
@@ -213,14 +213,8 @@ TEST_F(OpBuildGenTest, BuildMethodsSameOperandsAndResultTypeSuppression) {
testSingleVariadicInputInferredType<test::TableGenBuildOp4>();
}
TEST_F(
OpBuildGenTest,
BuildMethodsSameOperandsAndResultTypeAndInferOpTypeInterfaceSuppression) {
testSingleVariadicInputInferredType<test::TableGenBuildOp5>();
}
TEST_F(OpBuildGenTest, BuildMethodsRegionsAndInferredType) {
auto op = builder.create<test::TableGenBuildOp6>(
auto op = builder.create<test::TableGenBuildOp5>(
loc, ValueRange{*cstI32, *cstF32}, /*attributes=*/noAttrs);
ASSERT_EQ(op->getNumRegions(), 1u);
verifyOp(op, {i32Ty}, {*cstI32, *cstF32}, noAttrs);