Introduce MLIR Op Properties

This new features enabled to dedicate custom storage inline within operations.
This storage can be used as an alternative to attributes to store data that is
specific to an operation. Attribute can also be stored inside the properties
storage if desired, but any kind of data can be present as well. This offers
a way to store and mutate data without uniquing in the Context like Attribute.
See the OpPropertiesTest.cpp for an example where a struct with a
std::vector<> is attached to an operation and mutated in-place:

struct TestProperties {
  int a = -1;
  float b = -1.;
  std::vector<int64_t> array = {-33};
};

More complex scheme (including reference-counting) are also possible.

The only constraint to enable storing a C++ object as "properties" on an
operation is to implement three functions:

- convert from the candidate object to an Attribute
- convert from the Attribute to the candidate object
- hash the object

Optional the parsing and printing can also be customized with 2 extra
functions.

A new options is introduced to ODS to allow dialects to specify:

  let usePropertiesForAttributes = 1;

When set to true, the inherent attributes for all the ops in this dialect
will be using properties instead of being stored alongside discardable
attributes.
The TestDialect showcases this feature.

Another change is that we introduce new APIs on the Operation class
to access separately the inherent attributes from the discardable ones.
We envision deprecating and removing the `getAttr()`, `getAttrsDictionary()`,
and other similar method which don't make the distinction explicit, leading
to an entirely separate namespace for discardable attributes.

Differential Revision: https://reviews.llvm.org/D141742
This commit is contained in:
Mehdi Amini
2023-02-26 10:46:01 -05:00
parent 04fc02e583
commit d572cd1b06
84 changed files with 3080 additions and 477 deletions

View File

@@ -290,12 +290,14 @@ Syntax:
operation ::= op-result-list? (generic-operation | custom-operation)
trailing-location?
generic-operation ::= string-literal `(` value-use-list? `)` successor-list?
region-list? dictionary-attribute? `:` function-type
dictionary-properties? region-list? dictionary-attribute?
`:` function-type
custom-operation ::= bare-id custom-operation-format
op-result-list ::= op-result (`,` op-result)* `=`
op-result ::= value-id (`:` integer-literal)
successor-list ::= `[` successor (`,` successor)* `]`
successor ::= caret-id (`:` block-arg-list)?
dictionary-propertes ::= `<` dictionary-attribute `>`
region-list ::= `(` region (`,` region)* `)`
dictionary-attribute ::= `{` (attribute-entry (`,` attribute-entry)*)? `}`
trailing-location ::= (`loc` `(` location `)`)?
@@ -312,9 +314,10 @@ semantics. For example, MLIR supports
The internal representation of an operation is simple: an operation is
identified by a unique string (e.g. `dim`, `tf.Conv2d`, `x86.repmovsb`,
`ppc.eieio`, etc), can return zero or more results, take zero or more operands,
has a dictionary of [attributes](#attributes), has zero or more successors, and
zero or more enclosed [regions](#regions). The generic printing form includes
all these elements literally, with a function type to indicate the types of the
has storage for [properties](#properties), has a dictionary of
[attributes](#attributes), has zero or more successors, and zero or more
enclosed [regions](#regions). The generic printing form includes all these
elements literally, with a function type to indicate the types of the
results and operands.
Example:
@@ -328,8 +331,11 @@ Example:
%foo, %bar = "foo_div"() : () -> (f32, i32)
// Invoke a TensorFlow function called tf.scramble with two inputs
// and an attribute "fruit".
%2 = "tf.scramble"(%result#0, %bar) {fruit = "banana"} : (f32, i32) -> f32
// and an attribute "fruit" stored in properties.
%2 = "tf.scramble"(%result#0, %bar) <{fruit = "banana"}> : (f32, i32) -> f32
// Invoke an operation with some discardable attributes
%foo, %bar = "foo_div"() {some_attr = "value", other_attr = 42 : i64} : () -> (f32, i32)
```
In addition to the basic syntax above, dialects may register known operations.
@@ -733,6 +739,15 @@ The [builtin dialect](Dialects/Builtin.md) defines a set of types that are
directly usable by any other dialect in MLIR. These types cover a range from
primitive integer and floating-point types, function types, and more.
## Properties
Properties are extra data members stored directly on an Operation class. They
provide a way to store [inherent attributes](#attributes) and other arbitrary
data. The semantics of the data is specific to a given operation, and may be
exposed through [Interfaces](Interfaces.md) accessors and other methods.
Properties can always be serialized to Attribute in order to be printed
generically.
## Attributes
Syntax:
@@ -751,9 +766,10 @@ values. MLIR's builtin dialect provides a rich set of
arrays, dictionaries, strings, etc.). Additionally, dialects can define their
own [dialect attribute values](#dialect-attribute-values).
The top-level attribute dictionary attached to an operation has special
semantics. The attribute entries are considered to be of two different kinds
based on whether their dictionary key has a dialect prefix:
For dialects which haven't adopted properties yet, the top-level attribute
dictionary attached to an operation has special semantics. The attribute
entries are considered to be of two different kinds based on whether their
dictionary key has a dialect prefix:
- *inherent attributes* are inherent to the definition of an operation's
semantics. The operation itself is expected to verify the consistency of
@@ -771,6 +787,10 @@ Note that attribute values are allowed to themselves be dictionary attributes,
but only the top-level dictionary attribute attached to the operation is subject
to the classification above.
When properties are adopted, only discardable attributes are stored in the
top-level dictionary, while inherent attributes are stored in the properties
storage.
### Attribute Value Aliases
```

View File

@@ -145,6 +145,11 @@ public:
/// Wrappers around the RewritePattern methods that pass the derived op type.
void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
if constexpr (SourceOp::hasProperties())
rewrite(cast<SourceOp>(op),
OpAdaptor(operands, op->getAttrDictionary(),
cast<SourceOp>(op).getProperties()),
rewriter);
rewrite(cast<SourceOp>(op), OpAdaptor(operands, op->getAttrDictionary()),
rewriter);
}
@@ -154,6 +159,11 @@ public:
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
if constexpr (SourceOp::hasProperties())
return matchAndRewrite(cast<SourceOp>(op),
OpAdaptor(operands, op->getAttrDictionary(),
cast<SourceOp>(op).getProperties()),
rewriter);
return matchAndRewrite(cast<SourceOp>(op),
OpAdaptor(operands, op->getAttrDictionary()),
rewriter);

View File

@@ -103,6 +103,9 @@ class Dialect {
// If this dialect can be extended at runtime with new operations or types.
bit isExtensible = 0;
// Whether inherent Attributes defined in ODS will be stored as Properties.
bit usePropertiesForAttributes = 0;
}
#endif // DIALECTBASE_TD

View File

@@ -26,6 +26,8 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/ErrorHandling.h"
#include <optional>
namespace mlir {
class AsmParser;
@@ -462,6 +464,35 @@ public:
return verifyRegionFn(op);
}
/// Implementation for properties (unsupported right now here).
std::optional<Attribute> getInherentAttr(Operation *op,
StringRef name) final {
llvm::report_fatal_error("Unsupported getInherentAttr on Dynamic dialects");
}
void setInherentAttr(Operation *op, StringAttr name, Attribute value) final {
llvm::report_fatal_error("Unsupported setInherentAttr on Dynamic dialects");
}
void populateInherentAttrs(Operation *op, NamedAttrList &attrs) final {}
LogicalResult
verifyInherentAttrs(OperationName opName, NamedAttrList &attributes,
function_ref<InFlightDiagnostic()> getDiag) final {
return success();
}
int getOpPropertyByteSize() final { return 0; }
void initProperties(OperationName opName, OpaqueProperties storage,
OpaqueProperties init) final {}
void deleteProperties(OpaqueProperties prop) final {}
void populateDefaultProperties(OperationName opName,
OpaqueProperties properties) final {}
LogicalResult setPropertiesFromAttr(Operation *op, Attribute attr,
InFlightDiagnostic *diag) final {
return failure();
}
Attribute getPropertiesAsAttr(Operation *op) final { return {}; }
void copyProperties(OpaqueProperties lhs, OpaqueProperties rhs) final {}
llvm::hash_code hashProperties(OpaqueProperties prop) final { return {}; }
private:
DynamicOpDefinition(
StringRef name, ExtensibleDialect *dialect,

View File

@@ -0,0 +1,45 @@
//===- ODSSupport.h ---------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a number of support method for ODS generated code.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_ODSSUPPORT_H
#define MLIR_IR_ODSSUPPORT_H
#include "mlir/IR/Attributes.h"
namespace mlir {
//===----------------------------------------------------------------------===//
// Support for properties
//===----------------------------------------------------------------------===//
/// Convert an IntegerAttr attribute to an int64_t, or return an error if the
/// attribute isn't an IntegerAttr. If the optional diagnostic is provided an
/// error message is also emitted.
LogicalResult convertFromAttribute(int64_t &storage, Attribute attr,
InFlightDiagnostic *diag);
/// Convert the provided int64_t to an IntegerAttr attribute.
Attribute convertToAttribute(MLIRContext *ctx, int64_t storage);
/// Convert a DenseI64ArrayAttr to the provided storage. It is expected that the
/// storage has the same size as the array. An error is returned if the
/// attribute isn't a DenseI64ArrayAttr or it does not have the same size. If
/// the optional diagnostic is provided an error message is also emitted.
LogicalResult convertFromAttribute(MutableArrayRef<int64_t> storage,
Attribute attr, InFlightDiagnostic *diag);
/// Convert the provided ArrayRef<int64_t> to a DenseI64ArrayAttr attribute.
Attribute convertToAttribute(MLIRContext *ctx, ArrayRef<int64_t> storage);
} // namespace mlir
#endif // MLIR_IR_ODSSUPPORT_H

View File

@@ -179,6 +179,69 @@ class TypeConstraint<Pred predicate, string summary = "",
string cppClassName = cppClassNameParam;
}
// Base class for defining properties.
class Property<string storageTypeParam = "", string desc = ""> {
// User-readable one line summary used in error reporting messages. If empty,
// a generic message will be used.
string summary = desc;
// The full description of this property.
string description = "";
code storageType = storageTypeParam;
code interfaceType = storageTypeParam;
// The expression to convert from the storage type to the Interface
// type. For example, an enum can be stored as an int but returned as an
// enum class.
//
// Format:
// - `$_storage` will contain the property in the storage type.
// - `$_ctxt` will contain an `MLIRContext *`.
code convertFromStorage = "$_storage";
// The call expression to build a property storage from the interface type.
//
// Format:
// - `$_storage` will contain the property in the storage type.
// - `$_value` will contain the property in the user interface type.
code assignToStorage = "$_storage = $_value";
// The call expression to convert from the storage type to an attribute.
//
// Format:
// - `$_storage` is the storage type value.
// - `$_ctxt` is a `MLIRContext *`.
//
// The expression must result in an Attribute.
code convertToAttribute = [{
convertToAttribute($_ctxt, $_storage)
}];
// The call expression to convert from an Attribute to the storage type.
//
// Format:
// - `$_storage` is the storage type value.
// - `$_attr` is the attribute.
// - `$_diag` is an optional Diagnostic pointer to emit error.
//
// The expression must return a LogicalResult
code convertFromAttribute = [{
return convertFromAttribute($_storage, $_attr, $_diag);
}];
// The call expression to hash the property.
//
// Format:
// - `$_storage` is the variable to hash.
//
// The expression should define a llvm::hash_code.
code hashProperty = [{
llvm::hash_value($_storage);
}];
// Default value for the property.
string defaultValue = ?;
}
// Subclass for constraints on an attribute.
class AttrConstraint<Pred predicate, string summary = ""> :
Constraint<predicate, summary>;
@@ -1090,6 +1153,16 @@ class DefaultValuedStrAttr<Attr attr, string val>
class DefaultValuedOptionalStrAttr<Attr attr, string val>
: DefaultValuedOptionalAttr<attr, "\"" # val # "\"">;
//===----------------------------------------------------------------------===//
// Primitive property kinds
class ArrayProperty<string storageTypeParam = "", int n, string desc = ""> :
Property<storageTypeParam # "[" # n # "]", desc> {
let interfaceType = "::llvm::ArrayRef<" # storageTypeParam # ">";
let convertFromStorage = "$_storage";
let assignToStorage = "::llvm::copy($_value, $_storage)";
}
//===----------------------------------------------------------------------===//
// Primitive attribute kinds

View File

@@ -71,6 +71,21 @@ void ensureRegionTerminator(
} // namespace impl
/// Structure used by default as a "marker" when no "Properties" are set on an
/// Operation.
struct EmptyProperties {};
/// Traits to detect whether an Operation defined a `Properties` type, otherwise
/// it'll default to `EmptyProperties`.
template <class Op, class = void>
struct PropertiesSelector {
using type = EmptyProperties;
};
template <class Op>
struct PropertiesSelector<Op, std::void_t<typename Op::Properties>> {
using type = typename Op::Properties;
};
/// This is the concrete base class that holds the operation pointer and has
/// non-generic methods that only depend on State (to avoid having them
/// instantiated on template types that don't affect them.
@@ -206,6 +221,13 @@ protected:
/// in generic form.
static void print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect);
/// Parse properties as a Attribute.
static ParseResult genericParseProperties(OpAsmParser &parser,
Attribute &result);
/// Print the properties as a Attribute.
static void genericPrintProperties(OpAsmPrinter &p, Attribute properties);
/// Print an operation name, eliding the dialect prefix if necessary.
static void printOpName(Operation *op, OpAsmPrinter &p,
StringRef defaultDialect);
@@ -214,6 +236,14 @@ protected:
/// so we can cast it away here.
explicit OpState(Operation *state) : state(state) {}
/// For all op which don't have properties, we keep a single instance of
/// `EmptyProperties` to be used where a reference to a properties is needed:
/// this allow to bind a pointer to the reference without triggering UB.
static EmptyProperties &getEmptyProperties() {
static EmptyProperties emptyProperties;
return emptyProperties;
}
private:
Operation *state;
@@ -1471,13 +1501,17 @@ namespace op_definition_impl {
/// Returns true if this given Trait ID matches the IDs of any of the provided
/// trait types `Traits`.
template <template <typename T> class... Traits>
static bool hasTrait(TypeID traitID) {
inline bool hasTrait(TypeID traitID) {
TypeID traitIDs[] = {TypeID::get<Traits>()...};
for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i)
if (traitIDs[i] == traitID)
return true;
return false;
}
template <>
inline bool hasTrait<>(TypeID traitID) {
return false;
}
//===----------------------------------------------------------------------===//
// Trait Folding
@@ -1693,6 +1727,33 @@ public:
(checkInterfaceTarget<Models>(), ...);
info->attachInterface<Models...>();
}
/// Convert the provided attribute to a property and assigned it to the
/// provided properties. This default implementation forwards to a free
/// function `setPropertiesFromAttribute` that can be looked up with ADL in
/// the namespace where the properties are defined. It can also be overridden
/// in the derived ConcreteOp.
template <typename PropertiesTy>
static LogicalResult setPropertiesFromAttr(PropertiesTy &prop, Attribute attr,
InFlightDiagnostic *diag) {
return setPropertiesFromAttribute(prop, attr, diag);
}
/// Convert the provided properties to an attribute. This default
/// implementation forwards to a free function `getPropertiesAsAttribute` that
/// can be looked up with ADL in the namespace where the properties are
/// defined. It can also be overridden in the derived ConcreteOp.
template <typename PropertiesTy>
static Attribute getPropertiesAsAttr(MLIRContext *ctx,
const PropertiesTy &prop) {
return getPropertiesAsAttribute(ctx, prop);
}
/// Hash the provided properties. This default implementation forwards to a
/// free function `computeHash` that can be looked up with ADL in the
/// namespace where the properties are defined. It can also be overridden in
/// the derived ConcreteOp.
template <typename PropertiesTy>
static llvm::hash_code computePropertiesHash(const PropertiesTy &prop) {
return computeHash(prop);
}
private:
/// Trait to check if T provides a 'fold' method for a single result op.
@@ -1733,10 +1794,35 @@ private:
template <typename T>
using detect_has_print = llvm::is_detected<has_print, T>;
/// Trait to check if printProperties(OpAsmPrinter, T) exist
template <typename T, typename... Args>
using has_print_properties = decltype(printProperties(
std::declval<OpAsmPrinter &>(), std::declval<T>()));
template <typename T>
using detect_has_print_properties =
llvm::is_detected<has_print_properties, T>;
/// Trait to check if parseProperties(OpAsmParser, T) exist
template <typename T, typename... Args>
using has_parse_properties = decltype(parseProperties(
std::declval<OpAsmParser &>(), std::declval<T &>()));
template <typename T>
using detect_has_parse_properties =
llvm::is_detected<has_parse_properties, T>;
/// Trait to check if T provides a 'ConcreteEntity' type alias.
template <typename T>
using has_concrete_entity_t = typename T::ConcreteEntity;
public:
/// Returns true if this operation defines a `Properties` inner type.
static constexpr bool hasProperties() {
return !std::is_same_v<
typename ConcreteType::template InferredProperties<ConcreteType>,
EmptyProperties>;
}
private:
/// A struct-wrapped type alias to T::ConcreteEntity if provided and to
/// ConcreteType otherwise. This is akin to std::conditional but doesn't fail
/// on the missing typedef. Useful for checking if the interface is targeting
@@ -1801,11 +1887,18 @@ private:
foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
OpFoldResult result;
if constexpr (has_fold_adaptor_single_result_v<ConcreteOpT>)
result = cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
operands, op->getAttrDictionary(), op->getRegions()));
else
if constexpr (has_fold_adaptor_single_result_v<ConcreteOpT>) {
if constexpr (hasProperties()) {
result = cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
operands, op->getAttrDictionary(),
cast<ConcreteOpT>(op).getProperties(), op->getRegions()));
} else {
result = cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
operands, op->getAttrDictionary(), {}, op->getRegions()));
}
} else {
result = cast<ConcreteOpT>(op).fold(operands);
}
// If the fold failed or was in-place, try to fold the traits of the
// operation.
@@ -1824,10 +1917,18 @@ private:
SmallVectorImpl<OpFoldResult> &results) {
auto result = LogicalResult::failure();
if constexpr (has_fold_adaptor_v<ConcreteOpT>) {
result = cast<ConcreteOpT>(op).fold(
typename ConcreteOpT::FoldAdaptor(operands, op->getAttrDictionary(),
op->getRegions()),
results);
if constexpr (hasProperties()) {
result = cast<ConcreteOpT>(op).fold(
typename ConcreteOpT::FoldAdaptor(
operands, op->getAttrDictionary(),
cast<ConcreteOpT>(op).getProperties(), op->getRegions()),
results);
} else {
result = cast<ConcreteOpT>(op).fold(
typename ConcreteOpT::FoldAdaptor(operands, op->getAttrDictionary(),
{}, op->getRegions()),
results);
}
} else {
result = cast<ConcreteOpT>(op).fold(operands, results);
}
@@ -1859,6 +1960,48 @@ private:
};
}
public:
template <typename T>
using InferredProperties = typename PropertiesSelector<T>::type;
template <typename T = ConcreteType>
InferredProperties<T> &getProperties() {
if constexpr (!hasProperties())
return getEmptyProperties();
return *getOperation()
->getPropertiesStorage()
.template as<InferredProperties<T> *>();
}
/// This hook populates any unset default attrs when mapped to properties.
template <typename T = ConcreteType>
static void populateDefaultProperties(OperationName opName,
InferredProperties<T> &properties) {}
/// Print the operation properties. Unless overridden, this method will try to
/// dispatch to a `printProperties` free-function if it exists, and otherwise
/// by converting the properties to an Attribute.
template <typename T>
static void printProperties(MLIRContext *ctx, OpAsmPrinter &p,
const T &properties) {
if constexpr (detect_has_print_properties<T>::value)
return printProperties(p, properties);
genericPrintProperties(p,
ConcreteType::getPropertiesAsAttr(ctx, properties));
}
/// Parser the properties. Unless overridden, this method will print by
/// converting the properties to an Attribute.
template <typename T = ConcreteType>
static ParseResult parseProperties(OpAsmParser &parser,
OperationState &result) {
if constexpr (detect_has_parse_properties<InferredProperties<T>>::value) {
return parseProperties(
parser, result.getOrAddProperties<InferredProperties<T>>());
}
return genericParseProperties(parser, result.propertiesAttr);
}
private:
/// Implementation of `PopulateDefaultAttrsFn` OperationName hook.
static OperationName::PopulateDefaultAttrsFn getPopulateDefaultAttrsFn() {
return ConcreteType::populateDefaultAttrs;

View File

@@ -957,13 +957,13 @@ public:
/// populated in `result`.
template <typename AttrType>
std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
parseCustomAttributeWithFallback(AttrType &result) {
parseCustomAttributeWithFallback(AttrType &result, Type type = {}) {
SMLoc loc = getCurrentLocation();
// Parse any kind of attribute.
Attribute attr;
if (parseCustomAttributeWithFallback(
attr, {}, [&](Attribute &result, Type type) -> ParseResult {
attr, type, [&](Attribute &result, Type type) -> ParseResult {
result = AttrType::parse(*this, type);
return success(!!result);
}))
@@ -979,8 +979,8 @@ public:
/// SFINAE parsing method for Attribute that don't implement a parse method.
template <typename AttrType>
std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
parseCustomAttributeWithFallback(AttrType &result) {
return parseAttribute(result);
parseCustomAttributeWithFallback(AttrType &result, Type type = {}) {
return parseAttribute(result, type);
}
/// Parse an arbitrary optional attribute of a given type and return it in
@@ -1368,6 +1368,7 @@ public:
std::optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions =
std::nullopt,
std::optional<ArrayRef<NamedAttribute>> parsedAttributes = std::nullopt,
std::optional<Attribute> parsedPropertiesAttribute = std::nullopt,
std::optional<FunctionType> parsedFnType = std::nullopt) = 0;
/// Parse a single SSA value operand name along with a result number if

View File

@@ -22,6 +22,12 @@
#include <optional>
namespace mlir {
namespace detail {
/// This is a "tag" used for mapping the properties storage in
/// llvm::TrailingObjects.
enum class OpProperties : char {};
} // namespace detail
/// Operation is the basic unit of execution within MLIR.
///
/// The following documentation are recommended to understand this class:
@@ -67,26 +73,35 @@ namespace mlir {
/// Some operations like branches also refer to other Block, in which case they
/// would have an array of `BlockOperand`.
///
/// An Operation may contain optionally a "Properties" object: this is a
/// pre-defined C++ object with a fixed size. This object is owned by the
/// operation and deleted with the operation. It can be converted to an
/// Attribute on demand, or loaded from an Attribute.
///
///
/// Finally an Operation also contain an optional `DictionaryAttr`, a Location,
/// and a pointer to its parent Block (if any).
class alignas(8) Operation final
: public llvm::ilist_node_with_parent<Operation, Block>,
private llvm::TrailingObjects<Operation, detail::OperandStorage,
BlockOperand, Region, OpOperand> {
detail::OpProperties, BlockOperand, Region,
OpOperand> {
public:
/// Create a new Operation with the specific fields. This constructor
/// populates the provided attribute list with default attributes if
/// necessary.
static Operation *create(Location location, OperationName name,
TypeRange resultTypes, ValueRange operands,
NamedAttrList &&attributes, BlockRange successors,
NamedAttrList &&attributes,
OpaqueProperties properties, BlockRange successors,
unsigned numRegions);
/// Create a new Operation with the specific fields. This constructor uses an
/// existing attribute dictionary to avoid uniquing a list of attributes.
static Operation *create(Location location, OperationName name,
TypeRange resultTypes, ValueRange operands,
DictionaryAttr attributes, BlockRange successors,
DictionaryAttr attributes,
OpaqueProperties properties, BlockRange successors,
unsigned numRegions);
/// Create a new Operation from the fields stored in `state`.
@@ -96,6 +111,7 @@ public:
static Operation *create(Location location, OperationName name,
TypeRange resultTypes, ValueRange operands,
NamedAttrList &&attributes,
OpaqueProperties properties,
BlockRange successors = {},
RegionRange regions = {});
@@ -414,24 +430,82 @@ public:
// constants to names. Attributes may be dynamically added and removed over
// the lifetime of an operation.
/// Access an inherent attribute by name: returns an empty optional if there
/// is no inherent attribute with this name.
///
/// This method is available as a transient facility in the migration process
/// to use Properties instead.
std::optional<Attribute> getInherentAttr(StringRef name);
/// Set an inherent attribute by name.
///
/// This method is available as a transient facility in the migration process
/// to use Properties instead.
void setInherentAttr(StringAttr name, Attribute value);
/// Access a discardable attribute by name, returns an null Attribute if the
/// discardable attribute does not exist.
Attribute getDiscardableAttr(StringRef name) { return attrs.get(name); }
/// Access a discardable attribute by name, returns an null Attribute if the
/// discardable attribute does not exist.
Attribute getDiscardableAttr(StringAttr name) { return attrs.get(name); }
/// Set a discardable attribute by name.
void setDiscardableAttr(StringAttr name, Attribute value) {
NamedAttrList attributes(attrs);
if (attributes.set(name, value) != value)
attrs = attributes.getDictionary(getContext());
}
/// Return all of the discardable attributes on this operation.
ArrayRef<NamedAttribute> getDiscardableAttrs() { return attrs.getValue(); }
/// Return all of the discardable attributes on this operation as a
/// DictionaryAttr.
DictionaryAttr getDiscardableAttrDictionary() { return attrs; }
/// Return all of the attributes on this operation.
ArrayRef<NamedAttribute> getAttrs() { return attrs.getValue(); }
ArrayRef<NamedAttribute> getAttrs() {
if (!getPropertiesStorage())
return getDiscardableAttrs();
return getAttrDictionary().getValue();
}
/// Return all of the attributes on this operation as a DictionaryAttr.
DictionaryAttr getAttrDictionary() { return attrs; }
DictionaryAttr getAttrDictionary();
/// Set the attribute dictionary on this operation.
void setAttrs(DictionaryAttr newAttrs) {
/// Set the attributes from a dictionary on this operation.
/// These methods are expensive: if the dictionnary only contains discardable
/// attributes, `setDiscardableAttrs` is more efficient.
void setAttrs(DictionaryAttr newAttrs);
void setAttrs(ArrayRef<NamedAttribute> newAttrs);
/// Set the discardable attribute dictionary on this operation.
void setDiscardableAttrs(DictionaryAttr newAttrs) {
assert(newAttrs && "expected valid attribute dictionary");
attrs = newAttrs;
}
void setAttrs(ArrayRef<NamedAttribute> newAttrs) {
setAttrs(DictionaryAttr::get(getContext(), newAttrs));
void setDiscardableAttrs(ArrayRef<NamedAttribute> newAttrs) {
setDiscardableAttrs(DictionaryAttr::get(getContext(), newAttrs));
}
/// Return the specified attribute if present, null otherwise.
Attribute getAttr(StringAttr name) { return attrs.get(name); }
Attribute getAttr(StringRef name) { return attrs.get(name); }
/// These methods are expensive: if the dictionnary only contains discardable
/// attributes, `getDiscardableAttr` is more efficient.
Attribute getAttr(StringAttr name) {
if (getPropertiesStorageSize()) {
if (std::optional<Attribute> inherentAttr = getInherentAttr(name))
return *inherentAttr;
}
return attrs.get(name);
}
Attribute getAttr(StringRef name) {
if (getPropertiesStorageSize()) {
if (std::optional<Attribute> inherentAttr = getInherentAttr(name))
return *inherentAttr;
}
return attrs.get(name);
}
template <typename AttrClass>
AttrClass getAttrOfType(StringAttr name) {
@@ -444,8 +518,20 @@ public:
/// Return true if the operation has an attribute with the provided name,
/// false otherwise.
bool hasAttr(StringAttr name) { return attrs.contains(name); }
bool hasAttr(StringRef name) { return attrs.contains(name); }
bool hasAttr(StringAttr name) {
if (getPropertiesStorageSize()) {
if (std::optional<Attribute> inherentAttr = getInherentAttr(name))
return (bool)*inherentAttr;
}
return attrs.contains(name);
}
bool hasAttr(StringRef name) {
if (getPropertiesStorageSize()) {
if (std::optional<Attribute> inherentAttr = getInherentAttr(name))
return (bool)*inherentAttr;
}
return attrs.contains(name);
}
template <typename AttrClass, typename NameT>
bool hasAttrOfType(NameT &&name) {
return static_cast<bool>(
@@ -455,6 +541,12 @@ public:
/// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
void setAttr(StringAttr name, Attribute value) {
if (getPropertiesStorageSize()) {
if (std::optional<Attribute> inherentAttr = getInherentAttr(name)) {
setInherentAttr(name, value);
return;
}
}
NamedAttrList attributes(attrs);
if (attributes.set(name, value) != value)
attrs = attributes.getDictionary(getContext());
@@ -467,6 +559,12 @@ public:
/// attribute that was erased, or nullptr if there was no attribute with such
/// name.
Attribute removeAttr(StringAttr name) {
if (getPropertiesStorageSize()) {
if (std::optional<Attribute> inherentAttr = getInherentAttr(name)) {
setInherentAttr(name, {});
return *inherentAttr;
}
}
NamedAttrList attributes(attrs);
Attribute removedAttr = attributes.erase(name);
if (removedAttr)
@@ -511,7 +609,7 @@ public:
return dialect_attr_iterator(attrs.end(), attrs.end());
}
/// Set the dialect attributes for this operation, and preserve all dependent.
/// Set the dialect attributes for this operation, and preserve all inherent.
template <typename DialectAttrT>
void setDialectAttrs(DialectAttrT &&dialectAttrs) {
NamedAttrList attrs;
@@ -735,6 +833,44 @@ public:
/// handlers that may be listening.
InFlightDiagnostic emitRemark(const Twine &message = {});
/// Returns the properties storage size.
int getPropertiesStorageSize() const {
return ((int)propertiesStorageSize) * 8;
}
/// Returns the properties storage.
OpaqueProperties getPropertiesStorage() {
if (propertiesStorageSize)
return {
reinterpret_cast<void *>(getTrailingObjects<detail::OpProperties>())};
return {nullptr};
}
OpaqueProperties getPropertiesStorage() const {
if (propertiesStorageSize)
return {reinterpret_cast<void *>(const_cast<detail::OpProperties *>(
getTrailingObjects<detail::OpProperties>()))};
return {nullptr};
}
/// Return the properties converted to an attribute.
/// This is expensive, and mostly useful when dealing with unregistered
/// operation. Returns an empty attribute if no properties are present.
Attribute getPropertiesAsAttribute();
/// Set the properties from the provided attribute.
/// This is an expensive operation that can fail if the attribute is not
/// matching the expectations of the properties for this operation. This is
/// mostly useful for unregistered operations or used when parsing the
/// generic format. An optional diagnostic can be passed in for richer errors.
LogicalResult setPropertiesFromAttribute(Attribute attr,
InFlightDiagnostic *diagnostic);
/// Copy properties from an existing other properties object. The two objects
/// must be the same type.
void copyProperties(OpaqueProperties rhs);
/// Compute a hash for the op properties (if any).
llvm::hash_code hashProperties();
private:
//===--------------------------------------------------------------------===//
// Ordering
@@ -758,7 +894,8 @@ private:
private:
Operation(Location location, OperationName name, unsigned numResults,
unsigned numSuccessors, unsigned numRegions,
DictionaryAttr attributes, bool hasOperandStorage);
int propertiesStorageSize, DictionaryAttr attributes,
OpaqueProperties properties, bool hasOperandStorage);
// Operations are deleted through the destroy() member because they are
// allocated with malloc.
@@ -845,13 +982,21 @@ private:
const unsigned numResults;
const unsigned numSuccs;
const unsigned numRegions : 31;
const unsigned numRegions : 23;
/// This bit signals whether this operation has an operand storage or not. The
/// operand storage may be elided for operations that are known to never have
/// operands.
bool hasOperandStorage : 1;
/// The size of the storage for properties (if any), divided by 8: since the
/// Properties storage will always be rounded up to the next multiple of 8 we
/// save some bits here.
unsigned char propertiesStorageSize : 8;
/// This is the maximum size we support to allocate properties inline with an
/// operation: this must match the bitwidth above.
static constexpr int64_t propertiesCapacity = 8 * 256;
/// This holds the name of the operation.
OperationName name;
@@ -871,8 +1016,9 @@ private:
friend class llvm::ilist_node_with_parent<Operation, Block>;
// This stuff is used by the TrailingObjects template.
friend llvm::TrailingObjects<Operation, detail::OperandStorage, BlockOperand,
Region, OpOperand>;
friend llvm::TrailingObjects<Operation, detail::OperandStorage,
detail::OpProperties, BlockOperand, Region,
OpOperand>;
size_t numTrailingObjects(OverloadToken<detail::OperandStorage>) const {
return hasOperandStorage ? 1 : 0;
}
@@ -880,6 +1026,9 @@ private:
return numSuccs;
}
size_t numTrailingObjects(OverloadToken<Region>) const { return numRegions; }
size_t numTrailingObjects(OverloadToken<detail::OpProperties>) const {
return getPropertiesStorageSize();
}
};
inline raw_ostream &operator<<(raw_ostream &os, const Operation &op) {

View File

@@ -14,8 +14,10 @@
#ifndef MLIR_IR_OPERATIONSUPPORT_H
#define MLIR_IR_OPERATIONSUPPORT_H
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockSupport.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Types.h"
@@ -24,6 +26,7 @@
#include "llvm/ADT/BitmaskEnum.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
#include "llvm/Support/TrailingObjects.h"
#include <memory>
@@ -37,6 +40,7 @@ namespace mlir {
class Dialect;
class DictionaryAttr;
class ElementsAttr;
struct EmptyProperties;
class MutableOperandRangeRange;
class NamedAttrList;
class Operation;
@@ -59,6 +63,25 @@ class ValueRange;
template <typename ValueRangeT>
class ValueTypeRange;
//===----------------------------------------------------------------------===//
// OpaqueProperties
//===----------------------------------------------------------------------===//
/// Simple wrapper around a void* in order to express generically how to pass
/// in op properties through APIs.
class OpaqueProperties {
public:
OpaqueProperties(void *prop) : properties(prop) {}
operator bool() const { return properties != nullptr; }
template <typename Dest>
Dest as() const {
return static_cast<Dest>(const_cast<void *>(properties));
}
private:
void *properties;
};
//===----------------------------------------------------------------------===//
// OperationName
//===----------------------------------------------------------------------===//
@@ -98,6 +121,26 @@ public:
virtual void printAssembly(Operation *, OpAsmPrinter &, StringRef) = 0;
virtual LogicalResult verifyInvariants(Operation *) = 0;
virtual LogicalResult verifyRegionInvariants(Operation *) = 0;
/// Implementation for properties
virtual std::optional<Attribute> getInherentAttr(Operation *,
StringRef name) = 0;
virtual void setInherentAttr(Operation *op, StringAttr name,
Attribute value) = 0;
virtual void populateInherentAttrs(Operation *op, NamedAttrList &attrs) = 0;
virtual LogicalResult
verifyInherentAttrs(OperationName opName, NamedAttrList &attributes,
function_ref<InFlightDiagnostic()> getDiag) = 0;
virtual int getOpPropertyByteSize() = 0;
virtual void initProperties(OperationName opName, OpaqueProperties storage,
OpaqueProperties init) = 0;
virtual void deleteProperties(OpaqueProperties) = 0;
virtual void populateDefaultProperties(OperationName opName,
OpaqueProperties properties) = 0;
virtual LogicalResult setPropertiesFromAttr(Operation *, Attribute,
InFlightDiagnostic *) = 0;
virtual Attribute getPropertiesAsAttr(Operation *) = 0;
virtual void copyProperties(OpaqueProperties, OpaqueProperties) = 0;
virtual llvm::hash_code hashProperties(OpaqueProperties) = 0;
};
public:
@@ -158,6 +201,25 @@ protected:
void printAssembly(Operation *, OpAsmPrinter &, StringRef) final;
LogicalResult verifyInvariants(Operation *) final;
LogicalResult verifyRegionInvariants(Operation *) final;
/// Implementation for properties
std::optional<Attribute> getInherentAttr(Operation *op,
StringRef name) final;
void setInherentAttr(Operation *op, StringAttr name, Attribute value) final;
void populateInherentAttrs(Operation *op, NamedAttrList &attrs) final;
LogicalResult
verifyInherentAttrs(OperationName opName, NamedAttrList &attributes,
function_ref<InFlightDiagnostic()> getDiag) final;
int getOpPropertyByteSize() final;
void initProperties(OperationName opName, OpaqueProperties storage,
OpaqueProperties init) final;
void deleteProperties(OpaqueProperties) final;
void populateDefaultProperties(OperationName opName,
OpaqueProperties properties) final;
LogicalResult setPropertiesFromAttr(Operation *, Attribute,
InFlightDiagnostic *) final;
Attribute getPropertiesAsAttr(Operation *) final;
void copyProperties(OpaqueProperties, OpaqueProperties) final;
llvm::hash_code hashProperties(OpaqueProperties) final;
};
public:
@@ -309,6 +371,68 @@ public:
return !isRegistered() || hasInterface(interfaceID);
}
/// Lookup an inherent attribute by name, this method isn't recommended
/// and may be removed in the future.
std::optional<Attribute> getInherentAttr(Operation *op,
StringRef name) const {
return getImpl()->getInherentAttr(op, name);
}
void setInherentAttr(Operation *op, StringAttr name, Attribute value) const {
return getImpl()->setInherentAttr(op, name, value);
}
void populateInherentAttrs(Operation *op, NamedAttrList &attrs) const {
return getImpl()->populateInherentAttrs(op, attrs);
}
/// This method exists for backward compatibility purpose when using
/// properties to store inherent attributes, it enables validating the
/// attributes when parsed from the older generic syntax pre-Properties.
LogicalResult
verifyInherentAttrs(NamedAttrList &attributes,
function_ref<InFlightDiagnostic()> getDiag) const {
return getImpl()->verifyInherentAttrs(*this, attributes, getDiag);
}
/// This hooks return the number of bytes to allocate for the op properties.
int getOpPropertyByteSize() const {
return getImpl()->getOpPropertyByteSize();
}
/// This hooks destroy the op properties.
void destroyOpProperties(OpaqueProperties properties) const {
getImpl()->deleteProperties(properties);
}
/// Initialize the op properties.
void initOpProperties(OpaqueProperties storage, OpaqueProperties init) const {
getImpl()->initProperties(*this, storage, init);
}
/// Set the default values on the ODS attribute in the properties.
void populateDefaultProperties(OpaqueProperties properties) const {
getImpl()->populateDefaultProperties(*this, properties);
}
/// Return the op properties converted to an Attribute.
Attribute getOpPropertiesAsAttribute(Operation *op) const {
return getImpl()->getPropertiesAsAttr(op);
}
/// Define the op properties from the provided Attribute.
LogicalResult
setOpPropertiesFromAttribute(Operation *op, Attribute properties,
InFlightDiagnostic *diagnostic) const {
return getImpl()->setPropertiesFromAttr(op, properties, diagnostic);
}
void copyOpProperties(OpaqueProperties lhs, OpaqueProperties rhs) const {
return getImpl()->copyProperties(lhs, rhs);
}
llvm::hash_code hashOpProperties(OpaqueProperties properties) const {
return getImpl()->hashProperties(properties);
}
/// Return the dialect this operation is registered to if the dialect is
/// loaded in the context, or nullptr if the dialect isn't loaded.
Dialect *getDialect() const {
@@ -413,6 +537,104 @@ public:
LogicalResult verifyRegionInvariants(Operation *op) final {
return ConcreteOp::getVerifyRegionInvariantsFn()(op);
}
/// Implementation for "Properties"
using Properties = std::remove_reference_t<
decltype(std::declval<ConcreteOp>().getProperties())>;
std::optional<Attribute> getInherentAttr(Operation *op,
StringRef name) final {
if constexpr (hasProperties) {
auto concreteOp = cast<ConcreteOp>(op);
return ConcreteOp::getInherentAttr(concreteOp.getProperties(), name);
}
// If the op does not have support for properties, we dispatch back to the
// dictionnary of discardable attributes for now.
return cast<ConcreteOp>(op)->getDiscardableAttr(name);
}
void setInherentAttr(Operation *op, StringAttr name,
Attribute value) final {
if constexpr (hasProperties) {
auto concreteOp = cast<ConcreteOp>(op);
return ConcreteOp::setInherentAttr(concreteOp.getProperties(), name,
value);
}
// If the op does not have support for properties, we dispatch back to the
// dictionnary of discardable attributes for now.
return cast<ConcreteOp>(op)->setDiscardableAttr(name, value);
}
void populateInherentAttrs(Operation *op, NamedAttrList &attrs) final {
if constexpr (hasProperties) {
auto concreteOp = cast<ConcreteOp>(op);
ConcreteOp::populateInherentAttrs(concreteOp.getProperties(), attrs);
}
}
LogicalResult
verifyInherentAttrs(OperationName opName, NamedAttrList &attributes,
function_ref<InFlightDiagnostic()> getDiag) final {
if constexpr (hasProperties)
return ConcreteOp::verifyInherentAttrs(opName, attributes, getDiag);
return success();
}
// Detect if the concrete operation defined properties.
static constexpr bool hasProperties = !std::is_same_v<
typename ConcreteOp::template InferredProperties<ConcreteOp>,
EmptyProperties>;
int getOpPropertyByteSize() final {
if constexpr (hasProperties)
return sizeof(Properties);
return 0;
}
void initProperties(OperationName opName, OpaqueProperties storage,
OpaqueProperties init) final {
using Properties =
typename ConcreteOp::template InferredProperties<ConcreteOp>;
if (init)
new (storage.as<Properties *>()) Properties(*init.as<Properties *>());
else
new (storage.as<Properties *>()) Properties();
if constexpr (hasProperties)
ConcreteOp::populateDefaultProperties(opName,
*storage.as<Properties *>());
}
void deleteProperties(OpaqueProperties prop) final {
prop.as<Properties *>()->~Properties();
}
void populateDefaultProperties(OperationName opName,
OpaqueProperties properties) final {
if constexpr (hasProperties)
ConcreteOp::populateDefaultProperties(opName,
*properties.as<Properties *>());
}
LogicalResult setPropertiesFromAttr(Operation *op, Attribute attr,
InFlightDiagnostic *diag) final {
if constexpr (hasProperties)
return ConcreteOp::setPropertiesFromAttr(
cast<ConcreteOp>(op).getProperties(), attr, diag);
if (diag)
*diag << "This operation does not support properties";
return failure();
}
Attribute getPropertiesAsAttr(Operation *op) final {
if constexpr (hasProperties) {
auto concreteOp = cast<ConcreteOp>(op);
return ConcreteOp::getPropertiesAsAttr(concreteOp->getContext(),
concreteOp.getProperties());
}
return {};
}
void copyProperties(OpaqueProperties lhs, OpaqueProperties rhs) final {
*lhs.as<Properties *>() = *rhs.as<Properties *>();
}
llvm::hash_code hashProperties(OpaqueProperties prop) final {
if constexpr (hasProperties)
return ConcreteOp::computePropertiesHash(*prop.as<Properties *>());
return {};
}
};
/// Lookup the registered operation information for the given operation.
@@ -600,6 +822,11 @@ public:
assign(range.begin(), range.end());
}
void clear() {
attrs.clear();
dictionarySorted.setPointerAndInt(nullptr, false);
}
bool empty() const { return attrs.empty(); }
void reserve(size_type N) { attrs.reserve(N); }
@@ -694,6 +921,19 @@ struct OperationState {
/// Regions that the op will hold.
SmallVector<std::unique_ptr<Region>, 1> regions;
// If we're creating an unregistered operation, this Attribute is used to
// build the properties. Otherwise it is ignored. For registered operations
// see the `getOrAddProperties` method.
Attribute propertiesAttr;
private:
OpaqueProperties properties = nullptr;
TypeID propertiesId;
llvm::function_ref<void(OpaqueProperties)> propertiesDeleter;
llvm::function_ref<void(OpaqueProperties, const OpaqueProperties)>
propertiesSetter;
friend class Operation;
public:
OperationState(Location location, StringRef name);
OperationState(Location location, OperationName name);
@@ -706,6 +946,37 @@ public:
TypeRange types, ArrayRef<NamedAttribute> attributes = {},
BlockRange successors = {},
MutableArrayRef<std::unique_ptr<Region>> regions = {});
OperationState(OperationState &&other) = default;
OperationState(const OperationState &other) = default;
OperationState &operator=(OperationState &&other) = default;
OperationState &operator=(const OperationState &other) = default;
~OperationState();
/// Get (or create) a properties of the provided type to be set on the
/// operation on creation.
template <typename T>
T &getOrAddProperties() {
if (!properties) {
T *p = new T{};
properties = p;
propertiesDeleter = [](OpaqueProperties prop) {
delete prop.as<const T *>();
};
propertiesSetter = [](OpaqueProperties new_prop,
const OpaqueProperties prop) {
*new_prop.as<T *>() = *prop.as<const T *>();
};
propertiesId = TypeID::get<T>();
}
assert(propertiesId == TypeID::get<T>() && "Inconsistent properties");
return *properties.as<T *>();
}
OpaqueProperties getRawProperties() { return properties; }
// Set the properties defined on this OpState on the given operation,
// optionally emit diagnostics on error through the provided diagnostic.
LogicalResult setProperties(Operation *op,
InFlightDiagnostic *diagnostic) const;
void addOperands(ValueRange newOperands);

View File

@@ -244,11 +244,11 @@ LogicalResult inferReturnTensorTypes(
function_ref<
LogicalResult(MLIRContext *, std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
RegionRange regions,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &retComponents)>
componentTypeFn,
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes);
/// Verifies that the inferred result types match the actual result types for
@@ -281,7 +281,7 @@ public:
static LogicalResult
inferReturnTypes(MLIRContext *context, std::optional<Location> location,
ValueRange operands, DictionaryAttr attributes,
RegionRange regions,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
static_assert(
ConcreteType::template hasTrait<InferShapedTypeOpInterface::Trait>(),
@@ -291,7 +291,7 @@ public:
"requires InferTypeOpInterface to ensure succesful invocation");
return ::mlir::detail::inferReturnTensorTypes(
ConcreteType::inferReturnTypeComponents, context, location, operands,
attributes, regions, inferredReturnTypes);
attributes, properties, regions, inferredReturnTypes);
}
};

View File

@@ -44,6 +44,7 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
"::std::optional<::mlir::Location>":$location,
"::mlir::ValueRange":$operands,
"::mlir::DictionaryAttr":$attributes,
"::mlir::OpaqueProperties":$properties,
"::mlir::RegionRange":$regions,
"::llvm::SmallVectorImpl<::mlir::Type>&":$inferredReturnTypes)
>,
@@ -75,13 +76,14 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
"::std::optional<::mlir::Location>":$location,
"::mlir::ValueRange":$operands,
"::mlir::DictionaryAttr":$attributes,
"::mlir::OpaqueProperties":$properties,
"::mlir::RegionRange":$regions,
"::llvm::SmallVectorImpl<::mlir::Type>&":$returnTypes),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
llvm::SmallVector<Type, 4> inferredReturnTypes;
if (failed(ConcreteOp::inferReturnTypes(context, location, operands,
attributes, regions,
attributes, properties, regions,
inferredReturnTypes)))
return failure();
if (!ConcreteOp::isCompatibleReturnTypes(inferredReturnTypes,
@@ -147,6 +149,7 @@ def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
"::std::optional<::mlir::Location>":$location,
"::mlir::ValueShapeRange":$operands,
"::mlir::DictionaryAttr":$attributes,
"::mlir::OpaqueProperties":$properties,
"::mlir::RegionRange":$regions,
"::llvm::SmallVectorImpl<::mlir::ShapedTypeComponents>&":
$inferredReturnShapes),

View File

@@ -22,6 +22,7 @@
#define MLIR_TABLEGEN_ARGUMENT_H_
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Property.h"
#include "mlir/TableGen/Type.h"
#include "llvm/ADT/PointerUnion.h"
#include <string>
@@ -58,8 +59,9 @@ struct NamedTypeConstraint {
TypeConstraint constraint;
};
// Operation argument: either attribute or operand
using Argument = llvm::PointerUnion<NamedAttribute *, NamedTypeConstraint *>;
// Operation argument: either attribute, property, or operand
using Argument = llvm::PointerUnion<NamedAttribute *, NamedProperty *,
NamedTypeConstraint *>;
} // namespace tblgen
} // namespace mlir

View File

@@ -576,7 +576,12 @@ class ExtraClassDeclaration
public:
/// Create an extra class declaration.
ExtraClassDeclaration(StringRef extraClassDeclaration,
StringRef extraClassDefinition = "")
std::string extraClassDefinition = "")
: ExtraClassDeclaration(extraClassDeclaration.str(),
std::move(extraClassDefinition)) {}
ExtraClassDeclaration(std::string extraClassDeclaration,
std::string extraClassDefinition = "")
: extraClassDeclaration(extraClassDeclaration),
extraClassDefinition(extraClassDefinition) {}
@@ -590,7 +595,7 @@ public:
private:
/// The string of the extra class declarations. It is re-indented before
/// printed.
StringRef extraClassDeclaration;
std::string extraClassDeclaration;
/// The string of the extra class definitions. It is re-indented before
/// printed.
std::string extraClassDefinition;

View File

@@ -86,6 +86,10 @@ public:
/// operations or types.
bool isExtensible() const;
/// Default to use properties for storing Attributes for operations in this
/// dialect.
bool usePropertiesForAttributes() const;
// Returns whether two dialects are equal by checking the equality of the
// underlying record.
bool operator==(const Dialect &other) const;

View File

@@ -18,6 +18,7 @@
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Builder.h"
#include "mlir/TableGen/Dialect.h"
#include "mlir/TableGen/Property.h"
#include "mlir/TableGen/Region.h"
#include "mlir/TableGen/Successor.h"
#include "mlir/TableGen/Trait.h"
@@ -166,10 +167,14 @@ public:
unsigned getNumVariableLengthResults() const;
/// Op attribute iterators.
using attribute_iterator = const NamedAttribute *;
attribute_iterator attribute_begin() const;
attribute_iterator attribute_end() const;
llvm::iterator_range<attribute_iterator> getAttributes() const;
using const_attribute_iterator = const NamedAttribute *;
const_attribute_iterator attribute_begin() const;
const_attribute_iterator attribute_end() const;
llvm::iterator_range<const_attribute_iterator> getAttributes() const;
using attribute_iterator = NamedAttribute *;
attribute_iterator attribute_begin();
attribute_iterator attribute_end();
llvm::iterator_range<attribute_iterator> getAttributes();
int getNumAttributes() const { return attributes.size(); }
int getNumNativeAttributes() const { return numNativeAttributes; }
@@ -185,6 +190,27 @@ public:
const_value_iterator operand_end() const;
const_value_range getOperands() const;
// Op properties iterators.
using const_property_iterator = const NamedProperty *;
const_property_iterator properties_begin() const {
return properties.begin();
}
const_property_iterator properties_end() const { return properties.end(); }
llvm::iterator_range<const_property_iterator> getProperties() const {
return properties;
}
using property_iterator = NamedProperty *;
property_iterator properties_begin() { return properties.begin(); }
property_iterator properties_end() { return properties.end(); }
llvm::iterator_range<property_iterator> getProperties() { return properties; }
int getNumCoreAttributes() const { return properties.size(); }
// Op properties accessors.
NamedProperty &getProperty(int index) { return properties[index]; }
const NamedProperty &getProperty(int index) const {
return properties[index];
}
int getNumOperands() const { return operands.size(); }
NamedTypeConstraint &getOperand(int index) { return operands[index]; }
const NamedTypeConstraint &getOperand(int index) const {
@@ -353,6 +379,9 @@ private:
/// computed upon request).
SmallVector<NamedAttribute, 4> attributes;
/// The properties of the op.
SmallVector<NamedProperty> properties;
/// The arguments of the op (operands and native attributes).
SmallVector<Argument, 4> arguments;

View File

@@ -0,0 +1,86 @@
//===- Property.h - Property wrapper class --------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Property wrapper to simplify using TableGen Record defining a MLIR
// Property.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_PROPERTY_H_
#define MLIR_TABLEGEN_PROPERTY_H_
#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Constraint.h"
#include "llvm/ADT/StringRef.h"
namespace llvm {
class DefInit;
class Record;
} // namespace llvm
namespace mlir {
namespace tblgen {
class Dialect;
class Type;
// Wrapper class providing helper methods for accessing MLIR Property defined
// in TableGen. This class should closely reflect what is defined as class
// `Property` in TableGen.
class Property {
public:
explicit Property(const llvm::Record *record);
explicit Property(const llvm::DefInit *init);
// Returns the storage type.
StringRef getStorageType() const;
// Returns the interface type for this property.
StringRef getInterfaceType() const;
// Returns the template getter method call which reads this property's
// storage and returns the value as of the desired return type.
StringRef getConvertFromStorageCall() const;
// Returns the template setter method call which reads this property's
// in the provided interface type and assign it to the storage.
StringRef getAssignToStorageCall() const;
// Returns the conversion method call which reads this property's
// in the storage type and builds an attribute.
StringRef getConvertToAttributeCall() const;
// Returns the setter method call which reads this property's
// in the provided interface type and assign it to the storage.
StringRef getConvertFromAttributeCall() const;
// Returns the code to compute the hash for this property.
StringRef getHashPropertyCall() const;
// Returns whether this Property has a default value.
bool hasDefaultValue() const;
// Returns the default value for this Property.
StringRef getDefaultValue() const;
// Returns the TableGen definition this Property was constructed from.
const llvm::Record &getDef() const;
private:
// The TableGen definition of this constraint.
const llvm::Record *def;
};
// A struct wrapping an op property and its name together
struct NamedProperty {
llvm::StringRef name;
Property prop;
};
} // namespace tblgen
} // namespace mlir
#endif // MLIR_TABLEGEN_PROPERTY_H_

View File

@@ -526,9 +526,14 @@ public:
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op),
OpAdaptor(operands, op->getAttrDictionary()),
rewriter);
auto sourceOp = cast<SourceOp>(op);
if constexpr (SourceOp::hasProperties())
return matchAndRewrite(sourceOp,
OpAdaptor(operands, op->getAttrDictionary(),
sourceOp.getProperties()),
rewriter);
return matchAndRewrite(
sourceOp, OpAdaptor(operands, op->getAttrDictionary()), rewriter);
}
/// Rewrite and Match methods that operate on the SourceOp type. These must be

View File

@@ -225,13 +225,16 @@ public:
public:
using RangeT = ArrayRef<ValueRange>;
using BaseT = typename SourceOp::template GenericAdaptor<RangeT>;
using Properties = typename SourceOp::template InferredProperties<SourceOp>;
OpAdaptor(const OneToNTypeMapping *operandMapping,
const OneToNTypeMapping *resultMapping,
const ValueRange *convertedOperands, RangeT values,
DictionaryAttr attrs = nullptr, RegionRange regions = {})
: BaseT(values, attrs, regions), operandMapping(operandMapping),
resultMapping(resultMapping), convertedOperands(convertedOperands) {}
DictionaryAttr attrs = nullptr, Properties &properties = {},
RegionRange regions = {})
: BaseT(values, attrs, properties, regions),
operandMapping(operandMapping), resultMapping(resultMapping),
convertedOperands(convertedOperands) {}
/// Get the type mapping of the original operands to the converted operands.
const OneToNTypeMapping &getOperandMapping() const {
@@ -271,7 +274,8 @@ public:
valueRanges.push_back(values);
}
OpAdaptor adaptor(&operandMapping, &resultMapping, &convertedOperands,
valueRanges, op->getAttrDictionary(), op->getRegions());
valueRanges, op->getAttrDictionary(),
cast<SourceOp>(op).getProperties(), op->getRegions());
// Call overload implemented by the derived class.
return matchAndRewrite(cast<SourceOp>(op), adaptor, rewriter);

View File

@@ -540,6 +540,7 @@ public:
std::optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions =
std::nullopt,
std::optional<ArrayRef<NamedAttribute>> parsedAttributes = std::nullopt,
std::optional<Attribute> propertiesAttribute = std::nullopt,
std::optional<FunctionType> parsedFnType = std::nullopt);
/// Parse an operation instance that is in the generic form and insert it at
@@ -1075,7 +1076,8 @@ Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) {
auto name = OperationName("builtin.unrealized_conversion_cast", getContext());
auto *op = Operation::create(
getEncodedSourceLocation(loc), name, type, /*operands=*/{},
/*attributes=*/std::nullopt, /*successors=*/{}, /*numRegions=*/0);
/*attributes=*/std::nullopt, /*properties=*/nullptr, /*successors=*/{},
/*numRegions=*/0);
forwardRefPlaceholders[op->getResult(0)] = loc;
return op->getResult(0);
}
@@ -1255,6 +1257,7 @@ ParseResult OperationParser::parseGenericOperationAfterOpName(
std::optional<ArrayRef<Block *>> parsedSuccessors,
std::optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions,
std::optional<ArrayRef<NamedAttribute>> parsedAttributes,
std::optional<Attribute> propertiesAttribute,
std::optional<FunctionType> parsedFnType) {
// Parse the operand list, if not explicitly provided.
@@ -1284,6 +1287,16 @@ ParseResult OperationParser::parseGenericOperationAfterOpName(
result.addSuccessors(*parsedSuccessors);
}
// Parse the properties, if not explicitly provided.
if (propertiesAttribute) {
result.propertiesAttr = *propertiesAttribute;
} else if (consumeIf(Token::less)) {
result.propertiesAttr = parseAttribute();
if (!result.propertiesAttr)
return failure();
if (parseToken(Token::greater, "expected '>' to close properties"))
return failure();
}
// Parse the region list, if not explicitly provided.
if (!parsedRegions) {
if (consumeIf(Token::l_paren)) {
@@ -1390,10 +1403,52 @@ Operation *OperationParser::parseGenericOperation() {
if (parseGenericOperationAfterOpName(result))
return nullptr;
// Operation::create() is not allowed to fail, however setting the properties
// from an attribute is a failable operation. So we save the attribute here
// and set it on the operation post-parsing.
Attribute properties;
std::swap(properties, result.propertiesAttr);
// If we don't have properties in the textual IR, but the operation now has
// support for properties, we support some backward-compatible generic syntax
// for the operation and as such we accept inherent attributes mixed in the
// dictionary of discardable attributes. We pre-validate these here because
// invalid attributes can't be casted to the properties storage and will be
// silently dropped. For example an attribute { foo = 0 : i32 } that is
// declared as F32Attr in ODS would have a C++ type of FloatAttr in the
// properties array. When setting it we would do something like:
//
// properties.foo = dyn_cast<FloatAttr>(fooAttr);
//
// which would end up with a null Attribute. The diagnostic from the verifier
// would be "missing foo attribute" instead of something like "expects a 32
// bits float attribute but got a 32 bits integer attribute".
if (!properties && !result.getRawProperties()) {
Optional<RegisteredOperationName> info = result.name.getRegisteredInfo();
if (info) {
if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
return mlir::emitError(srcLocation) << "'" << name << "' op ";
})))
return nullptr;
}
}
// Create the operation and try to parse a location for it.
Operation *op = opBuilder.create(result);
if (parseTrailingLocationSpecifier(op))
return nullptr;
// Try setting the properties for the operation, using a diagnostic to print
// errors.
if (properties) {
InFlightDiagnostic diagnostic =
mlir::emitError(srcLocation, "invalid properties ")
<< properties << " for op " << name << ": ";
if (failed(op->setPropertiesFromAttribute(properties, &diagnostic)))
return nullptr;
diagnostic.abandon();
}
return op;
}
@@ -1461,10 +1516,11 @@ public:
std::optional<ArrayRef<Block *>> parsedSuccessors,
std::optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions,
std::optional<ArrayRef<NamedAttribute>> parsedAttributes,
std::optional<Attribute> parsedPropertiesAttribute,
std::optional<FunctionType> parsedFnType) final {
return parser.parseGenericOperationAfterOpName(
result, parsedUnresolvedOperands, parsedSuccessors, parsedRegions,
parsedAttributes, parsedFnType);
parsedAttributes, parsedPropertiesAttribute, parsedFnType);
}
//===--------------------------------------------------------------------===//
// Utilities
@@ -1933,10 +1989,23 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
if (opAsmParser.didEmitError())
return nullptr;
Attribute properties = opState.propertiesAttr;
opState.propertiesAttr = Attribute{};
// Otherwise, create the operation and try to parse a location for it.
Operation *op = opBuilder.create(opState);
if (parseTrailingLocationSpecifier(op))
return nullptr;
// Try setting the properties for the operation.
if (properties) {
InFlightDiagnostic diagnostic =
mlir::emitError(srcLocation, "invalid properties ")
<< properties << " for op " << op->getName().getStringRef() << ": ";
if (failed(op->setPropertiesFromAttribute(properties, &diagnostic)))
return nullptr;
diagnostic.abandon();
}
return op;
}

View File

@@ -340,7 +340,8 @@ static LogicalResult inferOperationTypes(OperationState &state) {
if (succeeded(inferInterface->inferReturnTypes(
context, state.location, state.operands,
state.attributes.getDictionary(context), state.regions, state.types)))
state.attributes.getDictionary(context), state.getRawProperties(),
state.regions, state.types)))
return success();
// Diagnostic emitted by interface.

View File

@@ -113,6 +113,7 @@ private:
resultTypes.push_back(tokenType);
auto *newOp = Operation::create(op->getLoc(), op->getName(), resultTypes,
op->getOperands(), op->getAttrDictionary(),
op->getPropertiesStorage(),
op->getSuccessors(), op->getNumRegions());
// Clone regions into new op.

View File

@@ -15,6 +15,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
@@ -1354,9 +1355,10 @@ void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
/// shape of the source.
LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
ExtractStridedMetadataOpAdaptor extractAdaptor(operands, attributes, regions);
ExtractStridedMetadataOpAdaptor extractAdaptor(
operands, attributes, *properties.as<EmptyProperties *>(), regions);
auto sourceType = extractAdaptor.getSource().getType().dyn_cast<MemRefType>();
if (!sourceType)
return failure();

View File

@@ -833,7 +833,8 @@ class ExtractStridedMetadataOpReinterpretCastFolder
SmallVector<Type> inferredReturnTypes;
if (failed(extractStridedMetadataOp.inferReturnTypes(
rewriter.getContext(), loc, {reinterpretCastOp.getSource()},
/*attributes=*/{}, /*regions=*/{}, inferredReturnTypes)))
/*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
inferredReturnTypes)))
return rewriter.notifyMatchFailure(
reinterpretCastOp, "reinterpret_cast source's type is incompatible");

View File

@@ -1779,7 +1779,7 @@ bool mlir::scf::insideMutuallyExclusiveBranches(Operation *a, Operation *b) {
LogicalResult
IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
ValueRange operands, DictionaryAttr attrs,
RegionRange regions,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (regions.empty())
return failure();
@@ -1872,7 +1872,8 @@ void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
MLIRContext *ctx = builder.getContext();
auto attrDict = DictionaryAttr::get(ctx, result.attributes);
if (succeeded(inferReturnTypes(ctx, std::nullopt, result.operands, attrDict,
result.regions, inferredReturnTypes))) {
/*properties=*/nullptr, result.regions,
inferredReturnTypes))) {
result.addTypes(inferredReturnTypes);
}
}

View File

@@ -393,7 +393,7 @@ void AssumingOp::build(
LogicalResult mlir::shape::AddOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands[0].getType().isa<SizeType>() ||
operands[1].getType().isa<SizeType>())
@@ -911,7 +911,7 @@ void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
Builder b(context);
auto shape = attributes.getAs<DenseIntElementsAttr>("shape");
@@ -1092,7 +1092,7 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
LogicalResult mlir::shape::DimOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
DimOpAdaptor dimOp(operands);
inferredReturnTypes.assign({dimOp.getIndex().getType()});
@@ -1140,7 +1140,7 @@ OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
LogicalResult mlir::shape::DivOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands[0].getType().isa<SizeType>() ||
operands[1].getType().isa<SizeType>())
@@ -1361,7 +1361,7 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.assign({IndexType::get(context)});
return success();
@@ -1399,7 +1399,7 @@ OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
LogicalResult mlir::shape::MeetOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands.empty())
return failure();
@@ -1535,7 +1535,7 @@ void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
LogicalResult mlir::shape::RankOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands[0].getType().isa<ShapeType>())
inferredReturnTypes.assign({SizeType::get(context)});
@@ -1571,7 +1571,7 @@ OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands[0].getType().isa<ShapeType>())
inferredReturnTypes.assign({SizeType::get(context)});
@@ -1603,7 +1603,7 @@ OpFoldResult MaxOp::fold(FoldAdaptor adaptor) {
LogicalResult mlir::shape::MaxOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands[0].getType() == operands[1].getType())
inferredReturnTypes.assign({operands[0].getType()});
@@ -1635,7 +1635,7 @@ OpFoldResult MinOp::fold(FoldAdaptor adaptor) {
LogicalResult mlir::shape::MinOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands[0].getType() == operands[1].getType())
inferredReturnTypes.assign({operands[0].getType()});
@@ -1672,7 +1672,7 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
LogicalResult mlir::shape::MulOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands[0].getType().isa<SizeType>() ||
operands[1].getType().isa<SizeType>())
@@ -1759,7 +1759,7 @@ void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands[0].getType().isa<ValueShapeType>())
inferredReturnTypes.assign({ShapeType::get(context)});

View File

@@ -364,7 +364,8 @@ static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape = operands.getShape(0);
IntegerAttr axis = attributes.get("axis").cast<IntegerAttr>();
@@ -389,7 +390,8 @@ LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape = operands.getShape(0);
@@ -415,7 +417,8 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
inferredReturnShapes.push_back(ShapedTypeComponents(operands.getShape(0)));
inferredReturnShapes.push_back(ShapedTypeComponents(operands.getShape(1)));
@@ -424,7 +427,8 @@ LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
// Infer all dimension sizes by reducing based on inputs.
int32_t axis =
@@ -484,7 +488,8 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
LogicalResult tosa::EqualOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outShape;
if (resolveBroadcastShape(operands, outShape).failed()) {
@@ -505,7 +510,8 @@ bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape = operands.getShape(0);
ShapeAdaptor weightShape = operands.getShape(1);
@@ -536,7 +542,8 @@ LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); }
LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor lhsShape = operands.getShape(0);
ShapeAdaptor rhsShape = operands.getShape(1);
@@ -562,7 +569,8 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
LogicalResult tosa::PadOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape = operands.getShape(0);
ShapeAdaptor paddingShape = operands.getShape(1);
@@ -624,7 +632,8 @@ static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
LogicalResult tosa::SliceOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
inferredReturnShapes.push_back(ShapedTypeComponents(
convertToMlirShape(SliceOpAdaptor(operands, attributes).getSize())));
@@ -633,7 +642,8 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
LogicalResult tosa::TableOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape = operands.getShape(0);
@@ -649,7 +659,8 @@ LogicalResult tosa::TableOp::inferReturnTypeComponents(
LogicalResult tosa::TileOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
TileOpAdaptor adaptor(operands, attributes);
ArrayRef<int64_t> multiples = adaptor.getMultiples();
@@ -682,7 +693,8 @@ bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ReshapeOpAdaptor adaptor(operands, attributes);
ShapeAdaptor inputShape = operands.getShape(0);
@@ -751,7 +763,8 @@ LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int64_t> &perms) {
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape = operands.getShape(0);
ShapeAdaptor permsShape = operands.getShape(1);
@@ -818,7 +831,8 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
LogicalResult tosa::GatherOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outputShape;
outputShape.resize(3, ShapedType::kDynamic);
@@ -843,7 +857,8 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ResizeOpAdaptor adaptor(operands, attributes);
llvm::SmallVector<int64_t, 4> outputShape;
@@ -883,7 +898,8 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outputShape;
outputShape.resize(3, ShapedType::kDynamic);
@@ -942,7 +958,7 @@ static LogicalResult ReduceInferReturnTypes(
LogicalResult OP::inferReturnTypeComponents( \
MLIRContext *context, ::std::optional<Location> location, \
ValueShapeRange operands, DictionaryAttr attributes, \
RegionRange regions, \
OpaqueProperties properties, RegionRange regions, \
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
Type inputType = \
operands.getType()[0].cast<TensorType>().getElementType(); \
@@ -978,7 +994,7 @@ static LogicalResult NAryInferReturnTypes(
LogicalResult OP::inferReturnTypeComponents( \
MLIRContext *context, ::std::optional<Location> location, \
ValueShapeRange operands, DictionaryAttr attributes, \
RegionRange regions, \
OpaqueProperties properties, RegionRange regions, \
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
return NAryInferReturnTypes(operands, inferredReturnShapes); \
}
@@ -1062,7 +1078,8 @@ static LogicalResult poolingInferReturnTypes(
LogicalResult Conv2DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
Conv2DOp::Adaptor adaptor(operands.getValues(), attributes);
@@ -1125,7 +1142,8 @@ LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); }
LogicalResult Conv3DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
Conv3DOp::Adaptor adaptor(operands.getValues(), attributes);
@@ -1198,21 +1216,24 @@ LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); }
LogicalResult AvgPool2dOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
}
LogicalResult MaxPool2dOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
}
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
DepthwiseConv2DOp::Adaptor adaptor(operands.getValues(), attributes);
@@ -1288,7 +1309,8 @@ LogicalResult DepthwiseConv2DOp::verify() { return verifyConvOp(*this); }
LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
TransposeConv2DOp::Adaptor adaptor(operands.getValues(), attributes);
// outputShape is mutable.
@@ -1353,7 +1375,8 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
LogicalResult IfOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<tosa::YieldOp> yieldOps;
for (Region *region : regions) {
@@ -1397,7 +1420,8 @@ LogicalResult IfOp::inferReturnTypeComponents(
LogicalResult WhileOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<tosa::YieldOp> yieldOps;
for (auto &block : *regions[1])

View File

@@ -39,6 +39,7 @@ TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc, Type resultTy,
if (shapeInterface
.inferReturnTypeComponents(op.getContext(), op.getLoc(),
op->getOperands(), op->getAttrDictionary(),
op->getPropertiesStorage(),
op->getRegions(), returnedShapes)
.failed())
return op;

View File

@@ -218,9 +218,9 @@ void propagateShapesInRegion(Region &region) {
ValueShapeRange range(op.getOperands(), operandShape);
if (shapeInterface
.inferReturnTypeComponents(op.getContext(), op.getLoc(), range,
op.getAttrDictionary(),
op.getRegions(), returnedShapes)
.inferReturnTypeComponents(
op.getContext(), op.getLoc(), range, op.getAttrDictionary(),
op.getPropertiesStorage(), op.getRegions(), returnedShapes)
.succeeded()) {
for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
Value result = std::get<0>(it);

View File

@@ -1152,7 +1152,7 @@ void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
LogicalResult
ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
ValueRange operands, DictionaryAttr attributes,
RegionRange,
OpaqueProperties properties, RegionRange,
SmallVectorImpl<Type> &inferredReturnTypes) {
ExtractOp::Adaptor op(operands, attributes);
auto vectorType = op.getVector().getType().cast<VectorType>();
@@ -2084,7 +2084,7 @@ LogicalResult ShuffleOp::verify() {
LogicalResult
ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
ValueRange operands, DictionaryAttr attributes,
RegionRange,
OpaqueProperties properties, RegionRange,
SmallVectorImpl<Type> &inferredReturnTypes) {
ShuffleOp::Adaptor op(operands, attributes);
auto v1Type = op.getV1().getType().cast<VectorType>();

View File

@@ -2575,7 +2575,6 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
if (!filteredAttrs.empty())
printFilteredAttributesFn(filteredAttrs);
}
void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
// Print the name without quotes if possible.
::printKeywordOrString(attr.getName().strref(), os);
@@ -3355,6 +3354,10 @@ void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
os << ']';
}
// Print the properties.
if (Attribute prop = op->getPropertiesAsAttribute())
os << " <" << prop << '>';
// Print regions.
if (op->getNumRegions() != 0) {
os << " (";
@@ -3365,7 +3368,7 @@ void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
os << ')';
}
auto attrs = op->getAttrs();
auto attrs = op->getDiscardableAttrs();
printOptionalAttrDict(attrs);
// Print the type signature of the operation.
@@ -3509,6 +3512,10 @@ void OperationPrinter::printRegion(Region &region, bool printEntryBlockArgs,
void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
ValueRange operands) {
if (!mapAttr) {
os << "<<NULL AFFINE MAP>>";
return;
}
AffineMap map = mapAttr.getValue();
unsigned numDims = map.getNumDims();
auto printValueName = [&](unsigned pos, bool isSymbol) {

View File

@@ -22,6 +22,7 @@ add_mlir_library(MLIRIR
IntegerSet.cpp
Location.cpp
MLIRContext.cpp
ODSSupport.cpp
Operation.cpp
OperationSupport.cpp
PatternMatch.cpp

View File

@@ -16,6 +16,7 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
@@ -801,6 +802,64 @@ OperationName::UnregisteredOpModel::verifyRegionInvariants(Operation *) {
return success();
}
Optional<Attribute>
OperationName::UnregisteredOpModel::getInherentAttr(Operation *op,
StringRef name) {
auto dict = dyn_cast_or_null<DictionaryAttr>(getPropertiesAsAttr(op));
if (!dict)
return std::nullopt;
if (Attribute attr = dict.get(name))
return attr;
return std::nullopt;
}
void OperationName::UnregisteredOpModel::setInherentAttr(Operation *op,
StringAttr name,
Attribute value) {
auto dict = dyn_cast_or_null<DictionaryAttr>(getPropertiesAsAttr(op));
assert(dict);
NamedAttrList attrs(dict);
attrs.set(name, value);
*op->getPropertiesStorage().as<Attribute *>() =
attrs.getDictionary(op->getContext());
}
void OperationName::UnregisteredOpModel::populateInherentAttrs(
Operation *op, NamedAttrList &attrs) {}
LogicalResult OperationName::UnregisteredOpModel::verifyInherentAttrs(
OperationName opName, NamedAttrList &attributes,
function_ref<InFlightDiagnostic()> getDiag) {
return success();
}
int OperationName::UnregisteredOpModel::getOpPropertyByteSize() {
return sizeof(Attribute);
}
void OperationName::UnregisteredOpModel::initProperties(
OperationName opName, OpaqueProperties storage, OpaqueProperties init) {
new (storage.as<Attribute *>()) Attribute();
}
void OperationName::UnregisteredOpModel::deleteProperties(
OpaqueProperties prop) {
prop.as<Attribute *>()->~Attribute();
}
void OperationName::UnregisteredOpModel::populateDefaultProperties(
OperationName opName, OpaqueProperties properties) {}
LogicalResult OperationName::UnregisteredOpModel::setPropertiesFromAttr(
Operation *op, Attribute attr, InFlightDiagnostic *diag) {
*op->getPropertiesStorage().as<Attribute *>() = attr;
return success();
}
Attribute
OperationName::UnregisteredOpModel::getPropertiesAsAttr(Operation *op) {
return *op->getPropertiesStorage().as<Attribute *>();
}
void OperationName::UnregisteredOpModel::copyProperties(OpaqueProperties lhs,
OpaqueProperties rhs) {
*lhs.as<Attribute *>() = *rhs.as<Attribute *>();
}
llvm::hash_code
OperationName::UnregisteredOpModel::hashProperties(OpaqueProperties prop) {
return llvm::hash_combine(*prop.as<Attribute *>());
}
//===----------------------------------------------------------------------===//
// RegisteredOperationName
//===----------------------------------------------------------------------===//

View File

@@ -0,0 +1,57 @@
//===- ODSSupport.cpp -----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains out-of-line implementations of the support types that
// Operation and related classes build on top of.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/ODSSupport.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
using namespace mlir;
LogicalResult mlir::convertFromAttribute(int64_t &storage,
::mlir::Attribute attr,
::mlir::InFlightDiagnostic *diag) {
auto valueAttr = dyn_cast<IntegerAttr>(attr);
if (!valueAttr) {
if (diag)
*diag << "expected IntegerAttr for key `value`";
return failure();
}
storage = valueAttr.getValue().getSExtValue();
return success();
}
Attribute mlir::convertToAttribute(MLIRContext *ctx, int64_t storage) {
return IntegerAttr::get(IntegerType::get(ctx, 64), storage);
}
LogicalResult mlir::convertFromAttribute(MutableArrayRef<int64_t> storage,
::mlir::Attribute attr,
::mlir::InFlightDiagnostic *diag) {
auto valueAttr = dyn_cast<DenseI64ArrayAttr>(attr);
if (!valueAttr) {
if (diag)
*diag << "expected DenseI64ArrayAttr for key `value`";
return failure();
}
if (valueAttr.size() != static_cast<int64_t>(storage.size())) {
if (diag)
*diag << "Size mismatch in attribute conversion: " << valueAttr.size()
<< " vs " << storage.size();
return failure();
}
llvm::copy(valueAttr.asArrayRef(), storage.begin());
return success();
}
Attribute mlir::convertToAttribute(MLIRContext *ctx,
ArrayRef<int64_t> storage) {
return DenseI64ArrayAttr::get(ctx, storage);
}

View File

@@ -7,13 +7,17 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/Operation.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FoldInterfaces.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include <numeric>
@@ -25,19 +29,31 @@ using namespace mlir;
/// Create a new Operation from operation state.
Operation *Operation::create(const OperationState &state) {
return create(state.location, state.name, state.types, state.operands,
state.attributes.getDictionary(state.getContext()),
state.successors, state.regions);
Operation *op =
create(state.location, state.name, state.types, state.operands,
state.attributes.getDictionary(state.getContext()),
state.properties, state.successors, state.regions);
if (LLVM_UNLIKELY(state.propertiesAttr)) {
assert(!state.properties);
LogicalResult result =
op->setPropertiesFromAttribute(state.propertiesAttr,
/*diagnostic=*/nullptr);
assert(result.succeeded() && "invalid properties in op creation");
(void)result;
}
return op;
}
/// Create a new Operation with the specific fields.
Operation *Operation::create(Location location, OperationName name,
TypeRange resultTypes, ValueRange operands,
NamedAttrList &&attributes, BlockRange successors,
NamedAttrList &&attributes,
OpaqueProperties properties, BlockRange successors,
RegionRange regions) {
unsigned numRegions = regions.size();
Operation *op = create(location, name, resultTypes, operands,
std::move(attributes), successors, numRegions);
Operation *op =
create(location, name, resultTypes, operands, std::move(attributes),
properties, successors, numRegions);
for (unsigned i = 0; i < numRegions; ++i)
if (regions[i])
op->getRegion(i).takeBody(*regions[i]);
@@ -47,21 +63,23 @@ Operation *Operation::create(Location location, OperationName name,
/// Create a new Operation with the specific fields.
Operation *Operation::create(Location location, OperationName name,
TypeRange resultTypes, ValueRange operands,
NamedAttrList &&attributes, BlockRange successors,
NamedAttrList &&attributes,
OpaqueProperties properties, BlockRange successors,
unsigned numRegions) {
// Populate default attributes.
name.populateDefaultAttrs(attributes);
return create(location, name, resultTypes, operands,
attributes.getDictionary(location.getContext()), successors,
numRegions);
attributes.getDictionary(location.getContext()), properties,
successors, numRegions);
}
/// Overload of create that takes an existing DictionaryAttr to avoid
/// unnecessarily uniquing a list of attributes.
Operation *Operation::create(Location location, OperationName name,
TypeRange resultTypes, ValueRange operands,
DictionaryAttr attributes, BlockRange successors,
DictionaryAttr attributes,
OpaqueProperties properties, BlockRange successors,
unsigned numRegions) {
assert(llvm::all_of(resultTypes, [](Type t) { return t; }) &&
"unexpected null result type");
@@ -72,6 +90,7 @@ Operation *Operation::create(Location location, OperationName name,
unsigned numSuccessors = successors.size();
unsigned numOperands = operands.size();
unsigned numResults = resultTypes.size();
int opPropertiesAllocSize = name.getOpPropertyByteSize();
// If the operation is known to have no operands, don't allocate an operand
// storage.
@@ -82,8 +101,10 @@ Operation *Operation::create(Location location, OperationName name,
// into account the size of the operation, its trailing objects, and its
// prefixed objects.
size_t byteSize =
totalSizeToAlloc<detail::OperandStorage, BlockOperand, Region, OpOperand>(
needsOperandStorage ? 1 : 0, numSuccessors, numRegions, numOperands);
totalSizeToAlloc<detail::OperandStorage, detail::OpProperties,
BlockOperand, Region, OpOperand>(
needsOperandStorage ? 1 : 0, opPropertiesAllocSize, numSuccessors,
numRegions, numOperands);
size_t prefixByteSize = llvm::alignTo(
Operation::prefixAllocSize(numTrailingResults, numInlineResults),
alignof(Operation));
@@ -91,9 +112,9 @@ Operation *Operation::create(Location location, OperationName name,
void *rawMem = mallocMem + prefixByteSize;
// Create the new Operation.
Operation *op =
::new (rawMem) Operation(location, name, numResults, numSuccessors,
numRegions, attributes, needsOperandStorage);
Operation *op = ::new (rawMem) Operation(
location, name, numResults, numSuccessors, numRegions,
opPropertiesAllocSize, attributes, properties, needsOperandStorage);
assert((numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) &&
"unexpected successors in a non-terminator operation");
@@ -122,16 +143,22 @@ Operation *Operation::create(Location location, OperationName name,
for (unsigned i = 0; i != numSuccessors; ++i)
new (&blockOperands[i]) BlockOperand(op, successors[i]);
// This must be done after properties are initalized.
op->setAttrs(attributes);
return op;
}
Operation::Operation(Location location, OperationName name, unsigned numResults,
unsigned numSuccessors, unsigned numRegions,
DictionaryAttr attributes, bool hasOperandStorage)
int fullPropertiesStorageSize, DictionaryAttr attributes,
OpaqueProperties properties, bool hasOperandStorage)
: location(location), numResults(numResults), numSuccs(numSuccessors),
numRegions(numRegions), hasOperandStorage(hasOperandStorage), name(name),
attrs(attributes) {
numRegions(numRegions), hasOperandStorage(hasOperandStorage),
propertiesStorageSize((fullPropertiesStorageSize + 7) / 8), name(name) {
assert(attributes && "unexpected null attribute dictionary");
assert(fullPropertiesStorageSize <= propertiesCapacity &&
"Properties size overflow");
#ifndef NDEBUG
if (!getDialect() && !getContext()->allowsUnregisteredDialects())
llvm::report_fatal_error(
@@ -140,6 +167,8 @@ Operation::Operation(Location location, OperationName name, unsigned numResults,
"allowUnregisteredDialects() on the MLIRContext, or use "
"-allow-unregistered-dialect with the MLIR tool used.");
#endif
if (fullPropertiesStorageSize)
name.initOpProperties(getPropertiesStorage(), properties);
}
// Operations are deleted through the destroy() member because they are
@@ -168,6 +197,8 @@ Operation::~Operation() {
// Explicitly destroy the regions.
for (auto &region : getRegions())
region.~Region();
if (propertiesStorageSize)
name.destroyOpProperties(getPropertiesStorage());
}
/// Destroy this operation or one of its subclasses.
@@ -259,6 +290,68 @@ InFlightDiagnostic Operation::emitRemark(const Twine &message) {
return diag;
}
DictionaryAttr Operation::getAttrDictionary() {
if (getPropertiesStorageSize()) {
NamedAttrList attrsList = attrs;
getName().populateInherentAttrs(this, attrsList);
return attrsList.getDictionary(getContext());
}
return attrs;
}
void Operation::setAttrs(DictionaryAttr newAttrs) {
assert(newAttrs && "expected valid attribute dictionary");
if (getPropertiesStorageSize()) {
attrs = DictionaryAttr::get(getContext(), {});
for (const NamedAttribute &attr : newAttrs)
setAttr(attr.getName(), attr.getValue());
return;
}
attrs = newAttrs;
}
void Operation::setAttrs(ArrayRef<NamedAttribute> newAttrs) {
if (getPropertiesStorageSize()) {
setAttrs(DictionaryAttr::get(getContext(), {}));
for (const NamedAttribute &attr : newAttrs)
setAttr(attr.getName(), attr.getValue());
return;
}
attrs = DictionaryAttr::get(getContext(), newAttrs);
}
std::optional<Attribute> Operation::getInherentAttr(StringRef name) {
return getName().getInherentAttr(this, name);
}
void Operation::setInherentAttr(StringAttr name, Attribute value) {
getName().setInherentAttr(this, name, value);
}
Attribute Operation::getPropertiesAsAttribute() {
Optional<RegisteredOperationName> info = getRegisteredInfo();
if (LLVM_UNLIKELY(!info))
return *getPropertiesStorage().as<Attribute *>();
return info->getOpPropertiesAsAttribute(this);
}
LogicalResult
Operation::setPropertiesFromAttribute(Attribute attr,
InFlightDiagnostic *diagnostic) {
Optional<RegisteredOperationName> info = getRegisteredInfo();
if (LLVM_UNLIKELY(!info)) {
*getPropertiesStorage().as<Attribute *>() = attr;
return success();
}
return info->setOpPropertiesFromAttribute(this, attr, diagnostic);
}
void Operation::copyProperties(OpaqueProperties rhs) {
name.copyOpProperties(getPropertiesStorage(), rhs);
}
llvm::hash_code Operation::hashProperties() {
return name.hashOpProperties(getPropertiesStorage());
}
//===----------------------------------------------------------------------===//
// Operation Ordering
//===----------------------------------------------------------------------===//
@@ -581,7 +674,7 @@ Operation *Operation::clone(IRMapping &mapper, CloneOptions options) {
// Create the new operation.
auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs,
successors, getNumRegions());
getPropertiesStorage(), successors, getNumRegions());
mapper.map(this, newOp);
// Clone the regions.
@@ -636,6 +729,20 @@ void OpState::printOpName(Operation *op, OpAsmPrinter &p,
p.getStream() << name;
}
/// Parse properties as a Attribute.
ParseResult OpState::genericParseProperties(OpAsmParser &parser,
Attribute &result) {
if (parser.parseLess() || parser.parseAttribute(result) ||
parser.parseGreater())
return failure();
return success();
}
/// Print the properties as a Attribute.
void OpState::genericPrintProperties(OpAsmPrinter &p, Attribute properties) {
p << "<" << properties << ">";
}
/// Emit an error about fatal conditions with this operation, reporting up to
/// any diagnostic handlers that may be listening.
InFlightDiagnostic OpState::emitError(const Twine &message) {

View File

@@ -193,6 +193,23 @@ OperationState::OperationState(Location location, StringRef name,
: OperationState(location, OperationName(name, location.getContext()),
operands, types, attributes, successors, regions) {}
OperationState::~OperationState() {
if (properties)
propertiesDeleter(properties);
}
LogicalResult
OperationState::setProperties(Operation *op,
InFlightDiagnostic *diagnostic) const {
if (LLVM_UNLIKELY(propertiesAttr)) {
assert(!properties);
return op->setPropertiesFromAttribute(propertiesAttr, diagnostic);
}
if (properties)
propertiesSetter(op->getPropertiesStorage(), properties);
return success();
}
void OperationState::addOperands(ValueRange newOperands) {
operands.append(newOperands.begin(), newOperands.end());
}
@@ -633,8 +650,9 @@ llvm::hash_code OperationEquivalence::computeHash(
// - Operation Name
// - Attributes
// - Result Types
llvm::hash_code hash = llvm::hash_combine(
op->getName(), op->getAttrDictionary(), op->getResultTypes());
llvm::hash_code hash =
llvm::hash_combine(op->getName(), op->getAttrDictionary(),
op->getResultTypes(), op->hashProperties());
// - Operands
ValueRange operands = op->getOperands();

View File

@@ -220,15 +220,15 @@ LogicalResult mlir::detail::inferReturnTensorTypes(
function_ref<
LogicalResult(MLIRContext *, std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
RegionRange regions,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &retComponents)>
componentTypeFn,
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
SmallVector<ShapedTypeComponents, 2> retComponents;
if (failed(componentTypeFn(context, location, operands, attributes, regions,
retComponents)))
if (failed(componentTypeFn(context, location, operands, attributes,
properties, regions, retComponents)))
return failure();
for (const auto &shapeAndType : retComponents) {
Type elementTy = shapeAndType.getElementType();
@@ -249,7 +249,12 @@ LogicalResult mlir::detail::inferReturnTensorTypes(
LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
SmallVector<Type, 4> inferredReturnTypes(op->getResultTypes());
auto retTypeFn = cast<InferTypeOpInterface>(op);
return retTypeFn.refineReturnTypes(op->getContext(), op->getLoc(),
op->getOperands(), op->getAttrDictionary(),
op->getRegions(), inferredReturnTypes);
auto result = retTypeFn.refineReturnTypes(
op->getContext(), op->getLoc(), op->getOperands(),
op->getAttrDictionary(), op->getPropertiesStorage(), op->getRegions(),
inferredReturnTypes);
if (failed(result))
op->emitOpError() << "failed to infer returned types";
return result;
}

View File

@@ -1613,8 +1613,8 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
// TODO: Handle failure.
if (failed(inferInterface->inferReturnTypes(
state.getContext(), state.location, state.operands,
state.attributes.getDictionary(state.getContext()), state.regions,
state.types)))
state.attributes.getDictionary(state.getContext()),
state.getRawProperties(), state.regions, state.types)))
return;
} else {
// Otherwise, this is a fixed number of results.

View File

@@ -24,6 +24,7 @@ llvm_add_library(MLIRTableGen STATIC
Pass.cpp
Pattern.cpp
Predicate.cpp
Property.cpp
Region.cpp
SideEffects.cpp
Successor.cpp

View File

@@ -133,13 +133,18 @@ static ::mlir::LogicalResult {0}(
/// functions are stripped anyways.
static const char *const attrConstraintCode = R"(
static ::mlir::LogicalResult {0}(
::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) {
if (attr && !({1})) {
return op->emitOpError("attribute '") << attrName
::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> getDiag) {{
if (attr && !({1}))
return getDiag() << "attribute '" << attrName
<< "' failed to satisfy constraint: {2}";
}
return ::mlir::success();
}
static ::mlir::LogicalResult {0}(
::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) {{
return {0}(attr, attrName, [op]() {{
return op->emitOpError();
});
}
)";
/// Code for a successor constraint.

View File

@@ -103,6 +103,10 @@ bool Dialect::isExtensible() const {
return def->getValueAsBit("isExtensible");
}
bool Dialect::usePropertiesForAttributes() const {
return def->getValueAsBit("usePropertiesForAttributes");
}
bool Dialect::operator==(const Dialect &other) const {
return def == other.def;
}

View File

@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/Argument.h"
#include "mlir/TableGen/Predicate.h"
#include "mlir/TableGen/Trait.h"
#include "mlir/TableGen/Type.h"
@@ -322,14 +323,23 @@ auto Operator::getTraits() const -> llvm::iterator_range<const_trait_iterator> {
return {trait_begin(), trait_end()};
}
auto Operator::attribute_begin() const -> attribute_iterator {
auto Operator::attribute_begin() const -> const_attribute_iterator {
return attributes.begin();
}
auto Operator::attribute_end() const -> attribute_iterator {
auto Operator::attribute_end() const -> const_attribute_iterator {
return attributes.end();
}
auto Operator::getAttributes() const
-> llvm::iterator_range<attribute_iterator> {
-> llvm::iterator_range<const_attribute_iterator> {
return {attribute_begin(), attribute_end()};
}
auto Operator::attribute_begin() -> attribute_iterator {
return attributes.begin();
}
auto Operator::attribute_end() -> attribute_iterator {
return attributes.end();
}
auto Operator::getAttributes() -> llvm::iterator_range<attribute_iterator> {
return {attribute_begin(), attribute_end()};
}
@@ -542,6 +552,7 @@ void Operator::populateOpStructure() {
auto &recordKeeper = def.getRecords();
auto *typeConstraintClass = recordKeeper.getClass("TypeConstraint");
auto *attrClass = recordKeeper.getClass("Attr");
auto *propertyClass = recordKeeper.getClass("Property");
auto *derivedAttrClass = recordKeeper.getClass("DerivedAttr");
auto *opVarClass = recordKeeper.getClass("OpVariable");
numNativeAttributes = 0;
@@ -576,9 +587,14 @@ void Operator::populateOpStructure() {
"derived attributes not allowed in argument list");
attributes.push_back({givenName, Attribute(argDef)});
++numNativeAttributes;
} else if (argDef->isSubClassOf(propertyClass)) {
if (givenName.empty())
PrintFatalError(argDef->getLoc(), "properties must be named");
properties.push_back({givenName, Property(argDef)});
} else {
PrintFatalError(def.getLoc(), "unexpected def type; only defs deriving "
"from TypeConstraint or Attr are allowed");
PrintFatalError(def.getLoc(),
"unexpected def type; only defs deriving "
"from TypeConstraint or Attr or Property are allowed");
}
if (!givenName.empty())
argumentsAndResultsIndex[givenName] = i;
@@ -608,7 +624,7 @@ void Operator::populateOpStructure() {
// `attributes` because we will put their elements' pointers in `arguments`.
// SmallVector may perform re-allocation under the hood when adding new
// elements.
int operandIndex = 0, attrIndex = 0;
int operandIndex = 0, attrIndex = 0, propIndex = 0;
for (unsigned i = 0; i != numArgs; ++i) {
Record *argDef = dyn_cast<DefInit>(argumentValues->getArg(i))->getDef();
if (argDef->isSubClassOf(opVarClass))
@@ -618,11 +634,13 @@ void Operator::populateOpStructure() {
attrOrOperandMapping.push_back(
{OperandOrAttribute::Kind::Operand, operandIndex});
arguments.emplace_back(&operands[operandIndex++]);
} else {
assert(argDef->isSubClassOf(attrClass));
} else if (argDef->isSubClassOf(attrClass)) {
attrOrOperandMapping.push_back(
{OperandOrAttribute::Kind::Attribute, attrIndex});
arguments.emplace_back(&attributes[attrIndex++]);
} else {
assert(argDef->isSubClassOf(propertyClass));
arguments.emplace_back(&properties[propIndex++]);
}
}

View File

@@ -0,0 +1,86 @@
//===- Property.cpp - Property wrapper class ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Property wrapper to simplify using TableGen Record defining a MLIR
// Property.
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Property.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
using llvm::DefInit;
using llvm::Init;
using llvm::Record;
using llvm::StringInit;
// Returns the initializer's value as string if the given TableGen initializer
// is a code or string initializer. Returns the empty StringRef otherwise.
static StringRef getValueAsString(const Init *init) {
if (const auto *str = dyn_cast<StringInit>(init))
return str->getValue().trim();
return {};
}
Property::Property(const Record *record) : def(record) {
assert((record->isSubClassOf("Property") || record->isSubClassOf("Attr")) &&
"must be subclass of TableGen 'Property' class");
}
Property::Property(const DefInit *init) : Property(init->getDef()) {}
StringRef Property::getStorageType() const {
const auto *init = def->getValueInit("storageType");
auto type = getValueAsString(init);
if (type.empty())
return "Property";
return type;
}
StringRef Property::getInterfaceType() const {
const auto *init = def->getValueInit("interfaceType");
return getValueAsString(init);
}
StringRef Property::getConvertFromStorageCall() const {
const auto *init = def->getValueInit("convertFromStorage");
return getValueAsString(init);
}
StringRef Property::getAssignToStorageCall() const {
const auto *init = def->getValueInit("assignToStorage");
return getValueAsString(init);
}
StringRef Property::getConvertToAttributeCall() const {
const auto *init = def->getValueInit("convertToAttribute");
return getValueAsString(init);
}
StringRef Property::getConvertFromAttributeCall() const {
const auto *init = def->getValueInit("convertFromAttribute");
return getValueAsString(init);
}
StringRef Property::getHashPropertyCall() const {
return getValueAsString(def->getValueInit("hashProperty"));
}
bool Property::hasDefaultValue() const { return !getDefaultValue().empty(); }
StringRef Property::getDefaultValue() const {
const auto *init = def->getValueInit("defaultValue");
return getValueAsString(init);
}
const llvm::Record &Property::getDef() const { return *def; }

View File

@@ -11,10 +11,10 @@
// COM: bytecode contains
// COM: module {
// COM: version: 1.12
// COM: "test.versionedB"() {attribute = #test.attr_params<24, 42>} : () -> ()
// COM: "test.versionedB"() <{attribute = #test.attr_params<24, 42>}> : () -> ()
// COM: }
// RUN: mlir-opt %S/versioned-attr-1.12.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK1
// CHECK1: "test.versionedB"() {attribute = #test.attr_params<42, 24>} : () -> ()
// CHECK1: "test.versionedB"() <{attribute = #test.attr_params<42, 24>}> : () -> ()
//===--------------------------------------------------------------------===//
// Test attribute upgrade
@@ -23,7 +23,7 @@
// COM: bytecode contains
// COM: module {
// COM: version: 2.0
// COM: "test.versionedB"() {attribute = #test.attr_params<42, 24>} : () -> ()
// COM: "test.versionedB"() <{attribute = #test.attr_params<42, 24>}> : () -> ()
// COM: }
// RUN: mlir-opt %S/versioned-attr-2.0.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK2
// CHECK2: "test.versionedB"() {attribute = #test.attr_params<42, 24>} : () -> ()
// CHECK2: "test.versionedB"() <{attribute = #test.attr_params<42, 24>}> : () -> ()

View File

@@ -11,10 +11,10 @@
// COM: bytecode contains
// COM: module {
// COM: version: 2.0
// COM: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> ()
// COM: "test.versionedA"() <{dims = 123 : i64, modifier = false}> : () -> ()
// COM: }
// RUN: mlir-opt %S/versioned-op-2.0.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK1
// CHECK1: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> ()
// CHECK1: "test.versionedA"() <{dims = 123 : i64, modifier = false}> : () -> ()
//===--------------------------------------------------------------------===//
// Test upgrade
@@ -23,10 +23,10 @@
// COM: bytecode contains
// COM: module {
// COM: version: 1.12
// COM: "test.versionedA"() {dimensions = 123 : i64} : () -> ()
// COM: "test.versionedA"() <{dimensions = 123 : i64}> : () -> ()
// COM: }
// RUN: mlir-opt %S/versioned-op-1.12.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK2
// CHECK2: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> ()
// CHECK2: "test.versionedA"() <{dims = 123 : i64, modifier = false}> : () -> ()
//===--------------------------------------------------------------------===//
// Test forbidden downgrade
@@ -35,7 +35,7 @@
// COM: bytecode contains
// COM: module {
// COM: version: 2.2
// COM: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> ()
// COM: "test.versionedA"() <{dims = 123 : i64, modifier = false}> : () -> ()
// COM: }
// RUN: not mlir-opt %S/versioned-op-2.2.mlirbc 2>&1 | FileCheck %s --check-prefix=ERR_NEW_VERSION
// ERR_NEW_VERSION: current test dialect version is 2.0, can't parse version: 2.2

View File

@@ -41,16 +41,16 @@ func.func @pack_unpack(%arg0: i1, %arg1: i2) -> (i1, i2) {
//
// CHECK-TUP-LABEL: func.func @materializations_tuple_args(
// CHECK-TUP-SAME: %[[ARG0:.*]]: tuple<tuple<>, i1, tuple<tuple<i2>>>) -> (i1, i2) {
// CHECK-TUP-DAG: %[[V0:.*]] = "test.get_tuple_element"(%[[ARG0]]) {index = 0 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
// CHECK-TUP-DAG: %[[V1:.*]] = "test.get_tuple_element"(%[[ARG0]]) {index = 1 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
// CHECK-TUP-DAG: %[[V2:.*]] = "test.get_tuple_element"(%[[ARG0]]) {index = 2 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
// CHECK-TUP-DAG: %[[V3:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 0 : i32} : (tuple<tuple<i2>>) -> tuple<i2>
// CHECK-TUP-DAG: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 0 : i32} : (tuple<i2>) -> i2
// CHECK-TUP-DAG: %[[V5:.*]] = "test.get_tuple_element"(%[[ARG0]]) {index = 0 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
// CHECK-TUP-DAG: %[[V6:.*]] = "test.get_tuple_element"(%[[ARG0]]) {index = 1 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
// CHECK-TUP-DAG: %[[V7:.*]] = "test.get_tuple_element"(%[[ARG0]]) {index = 2 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
// CHECK-TUP-DAG: %[[V8:.*]] = "test.get_tuple_element"(%[[V7]]) {index = 0 : i32} : (tuple<tuple<i2>>) -> tuple<i2>
// CHECK-TUP-DAG: %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) {index = 0 : i32} : (tuple<i2>) -> i2
// CHECK-TUP-DAG: %[[V0:.*]] = "test.get_tuple_element"(%[[ARG0]]) <{index = 0 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
// CHECK-TUP-DAG: %[[V1:.*]] = "test.get_tuple_element"(%[[ARG0]]) <{index = 1 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
// CHECK-TUP-DAG: %[[V2:.*]] = "test.get_tuple_element"(%[[ARG0]]) <{index = 2 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
// CHECK-TUP-DAG: %[[V3:.*]] = "test.get_tuple_element"(%[[V2]]) <{index = 0 : i32}> : (tuple<tuple<i2>>) -> tuple<i2>
// CHECK-TUP-DAG: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 0 : i32}> : (tuple<i2>) -> i2
// CHECK-TUP-DAG: %[[V5:.*]] = "test.get_tuple_element"(%[[ARG0]]) <{index = 0 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
// CHECK-TUP-DAG: %[[V6:.*]] = "test.get_tuple_element"(%[[ARG0]]) <{index = 1 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
// CHECK-TUP-DAG: %[[V7:.*]] = "test.get_tuple_element"(%[[ARG0]]) <{index = 2 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
// CHECK-TUP-DAG: %[[V8:.*]] = "test.get_tuple_element"(%[[V7]]) <{index = 0 : i32}> : (tuple<tuple<i2>>) -> tuple<i2>
// CHECK-TUP-DAG: %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) <{index = 0 : i32}> : (tuple<i2>) -> i2
// CHECK-TUP-DAG: return %[[V1]], %[[V9]] : i1, i2
// If we only convert the func ops, argument materializations are created from
@@ -64,11 +64,11 @@ func.func @pack_unpack(%arg0: i1, %arg1: i2) -> (i1, i2) {
// CHECK-FUNC-DAG: %[[V1:.*]] = "test.make_tuple"(%[[ARG1]]) : (i2) -> tuple<i2>
// CHECK-FUNC-DAG: %[[V2:.*]] = "test.make_tuple"(%[[V1]]) : (tuple<i2>) -> tuple<tuple<i2>>
// CHECK-FUNC-DAG: %[[V3:.*]] = "test.make_tuple"(%[[V0]], %[[ARG0]], %[[V2]]) : (tuple<>, i1, tuple<tuple<i2>>) -> tuple<tuple<>, i1, tuple<tuple<i2>>>
// CHECK-FUNC-DAG: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 0 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
// CHECK-FUNC-DAG: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 1 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
// CHECK-FUNC-DAG: %[[V6:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 2 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
// CHECK-FUNC-DAG: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 0 : i32} : (tuple<tuple<i2>>) -> tuple<i2>
// CHECK-FUNC-DAG: %[[V8:.*]] = "test.get_tuple_element"(%[[V7]]) {index = 0 : i32} : (tuple<i2>) -> i2
// CHECK-FUNC-DAG: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 0 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
// CHECK-FUNC-DAG: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 1 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
// CHECK-FUNC-DAG: %[[V6:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 2 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
// CHECK-FUNC-DAG: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple<tuple<i2>>) -> tuple<i2>
// CHECK-FUNC-DAG: %[[V8:.*]] = "test.get_tuple_element"(%[[V7]]) <{index = 0 : i32}> : (tuple<i2>) -> i2
// CHECK-FUNC-DAG: return %[[V5]], %[[V8]] : i1, i2
// If we convert both tuple and func ops, basically everything disappears.
@@ -117,11 +117,11 @@ func.func @materializations_tuple_args(%arg0: tuple<tuple<>, i1, tuple<tuple<i2>
// CHECK-FUNC-DAG: %[[V1:.*]] = "test.make_tuple"(%[[ARG1]]) : (i2) -> tuple<i2>
// CHECK-FUNC-DAG: %[[V2:.*]] = "test.make_tuple"(%[[V1]]) : (tuple<i2>) -> tuple<tuple<i2>>
// CHECK-FUNC-DAG: %[[V3:.*]] = "test.make_tuple"(%[[V0]], %[[ARG0]], %[[V2]]) : (tuple<>, i1, tuple<tuple<i2>>) -> tuple<tuple<>, i1, tuple<tuple<i2>>>
// CHECK-FUNC-DAG: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 0 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
// CHECK-FUNC-DAG: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 1 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
// CHECK-FUNC-DAG: %[[V6:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 2 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
// CHECK-FUNC-DAG: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 0 : i32} : (tuple<tuple<i2>>) -> tuple<i2>
// CHECK-FUNC-DAG: %[[V8:.*]] = "test.get_tuple_element"(%[[V7]]) {index = 0 : i32} : (tuple<i2>) -> i2
// CHECK-FUNC-DAG: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 0 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
// CHECK-FUNC-DAG: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 1 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
// CHECK-FUNC-DAG: %[[V6:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 2 : i32}> : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
// CHECK-FUNC-DAG: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple<tuple<i2>>) -> tuple<i2>
// CHECK-FUNC-DAG: %[[V8:.*]] = "test.get_tuple_element"(%[[V7]]) <{index = 0 : i32}> : (tuple<i2>) -> i2
// CHECK-FUNC-DAG: return %[[V5]], %[[V8]] : i1, i2
// If we convert both tuple and func ops, basically everything disappears.

View File

@@ -36,13 +36,13 @@ func.func @if_result(%arg0: tuple<tuple<>, i1, tuple<i2>>, %arg1: i1) -> tuple<t
// CHECK-NEXT: %[[V1:.*]] = "test.make_tuple"(%[[V0]], %[[ARG0]]) : (tuple<>, i1) -> tuple<tuple<>, i1>
// CHECK-NEXT: %[[V2:.*]] = scf.if %[[ARG1]] -> (i1) {
// CHECK-NEXT: %[[V3:.*]] = "test.op"(%[[V1]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
// CHECK-NEXT: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 0 : i32} : (tuple<tuple<>, i1>) -> tuple<>
// CHECK-NEXT: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 1 : i32} : (tuple<tuple<>, i1>) -> i1
// CHECK-NEXT: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 0 : i32}> : (tuple<tuple<>, i1>) -> tuple<>
// CHECK-NEXT: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 1 : i32}> : (tuple<tuple<>, i1>) -> i1
// CHECK-NEXT: scf.yield %[[V5]] : i1
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[V6:.*]] = "test.source"() : () -> tuple<tuple<>, i1>
// CHECK-NEXT: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 0 : i32} : (tuple<tuple<>, i1>) -> tuple<>
// CHECK-NEXT: %[[V8:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 1 : i32} : (tuple<tuple<>, i1>) -> i1
// CHECK-NEXT: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple<tuple<>, i1>) -> tuple<>
// CHECK-NEXT: %[[V8:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 1 : i32}> : (tuple<tuple<>, i1>) -> i1
// CHECK-NEXT: scf.yield %[[V8]] : i1
// CHECK-NEXT: }
// CHECK-NEXT: return %[[V2]] : i1
@@ -94,14 +94,14 @@ func.func @while_operands_results(%arg0: tuple<tuple<>, i1, tuple<i2>>, %arg1: i
// CHECK-NEXT: %[[V1:.*]] = "test.make_tuple"() : () -> tuple<>
// CHECK-NEXT: %[[V2:.*]] = "test.make_tuple"(%[[V1]], %[[ARG2]]) : (tuple<>, i1) -> tuple<tuple<>, i1>
// CHECK-NEXT: %[[V3:.*]] = "test.op"(%[[V2]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
// CHECK-NEXT: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 0 : i32} : (tuple<tuple<>, i1>) -> tuple<>
// CHECK-NEXT: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 1 : i32} : (tuple<tuple<>, i1>) -> i1
// CHECK-NEXT: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 0 : i32}> : (tuple<tuple<>, i1>) -> tuple<>
// CHECK-NEXT: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) <{index = 1 : i32}> : (tuple<tuple<>, i1>) -> i1
// CHECK-NEXT: scf.condition(%[[ARG1]]) %[[V5]] : i1
// CHECK-NEXT: } do {
// CHECK-NEXT: ^bb0(%[[ARG3:.*]]: i1):
// CHECK-NEXT: %[[V6:.*]] = "test.source"() : () -> tuple<tuple<>, i1>
// CHECK-NEXT: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 0 : i32} : (tuple<tuple<>, i1>) -> tuple<>
// CHECK-NEXT: %[[V8:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 1 : i32} : (tuple<tuple<>, i1>) -> i1
// CHECK-NEXT: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple<tuple<>, i1>) -> tuple<>
// CHECK-NEXT: %[[V8:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 1 : i32}> : (tuple<tuple<>, i1>) -> i1
// CHECK-NEXT: scf.yield %[[V8]] : i1
// CHECK-NEXT: }
// CHECK-NEXT: return %[[V0]] : i1

View File

@@ -98,6 +98,7 @@ func.func @shape_of(%value_arg : !shape.value_shape,
// -----
func.func @shape_of_incompatible_return_types(%value_arg : tensor<1x2xindex>) {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'shape.shape_of' op inferred type(s) 'tensor<2xindex>' are incompatible with return type(s) of operation 'tensor<3xindex>'}}
%0 = shape.shape_of %value_arg : tensor<1x2xindex> -> tensor<3xindex>
return
@@ -268,6 +269,7 @@ func.func @fn(%arg: !shape.shape) -> !shape.witness {
// Test that type inference flags the wrong return type.
func.func @const_shape() {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'tensor<3xindex>' are incompatible with return type(s) of operation 'tensor<2xindex>'}}
%0 = shape.const_shape [4, 5, 6] : tensor<2xindex>
return
@@ -276,6 +278,7 @@ func.func @const_shape() {
// -----
func.func @invalid_meet(%arg0 : !shape.shape, %arg1 : index) -> index {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{requires all sizes or shapes}}
%result = shape.meet %arg0, %arg1 : !shape.shape, index -> index
return %result : index
@@ -284,6 +287,7 @@ func.func @invalid_meet(%arg0 : !shape.shape, %arg1 : index) -> index {
// -----
func.func @invalid_meet(%arg0 : tensor<2xindex>, %arg1 : tensor<3xindex>) -> tensor<?xindex> {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{unequal shape cardinality}}
%result = shape.meet %arg0, %arg1 : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
return %result : tensor<?xindex>

View File

@@ -39,6 +39,7 @@ func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>,
// -----
func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xf32> {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{Cannot concat tensors with different sizes on the non-axis dimension 1}}
%0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
@@ -47,6 +48,7 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
// -----
func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'tosa.concat' op inferred type(s) 'tensor<3x2xf32>' are incompatible with return type(s) of operation 'tensor<?x?xi8>}}
%0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
return %0 : tensor<?x?xi8>
@@ -100,6 +102,7 @@ func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: ten
// -----
func.func @test_reduce_sum_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'tosa.reduce_sum' op inferred type(s) 'tensor<1x3x4x5xf32>' are incompatible with return type(s) of operation 'tensor<1x3x4x5xi32>'}}
%0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<2x3x4x5xf32>) -> tensor<1x3x4x5xi32>
return
@@ -108,6 +111,7 @@ func.func @test_reduce_sum_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
// -----
func.func @test_reduce_max_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'tosa.reduce_max' op inferred type(s) 'tensor<2x3x4x1xf32>' are incompatible with return type(s) of operation 'tensor<2x3x4x1xi32>'}}
%0 = "tosa.reduce_max"(%arg0) {axis = 3 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x1xi32>
return
@@ -116,6 +120,7 @@ func.func @test_reduce_max_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
// -----
func.func @test_reduce_min_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'tosa.reduce_min' op inferred type(s) 'tensor<2x1x4x5xf32>' are incompatible with return type(s) of operation 'tensor<2x1x4x5xi32>'}}
%0 = "tosa.reduce_min"(%arg0) {axis = 1 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x1x4x5xi32>
return
@@ -124,6 +129,7 @@ func.func @test_reduce_min_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
// -----
func.func @test_reduce_prod_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'tosa.reduce_prod' op inferred type(s) 'tensor<2x1x4x5xf32>' are incompatible with return type(s) of operation 'tensor<2x3x4x5xf32>'}}
%0 = "tosa.reduce_prod"(%arg0) {axis = 1 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>
return
@@ -132,6 +138,7 @@ func.func @test_reduce_prod_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
// -----
func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'tosa.reshape' op inferred type(s) 'tensor<13x21x3x1xf32>' are incompatible with return type(s) of operation 'tensor<13x21x3x1xi32>'}}
%0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 21, 3, 1>} : (tensor<13x21x3xf32>) -> tensor<13x21x3x1xi32>
return

View File

@@ -7,7 +7,7 @@ func.func @add_to_worklist_after_inplace_update() {
// worklist of the GreedyPatternRewriteDriver (regardless of the value of
// config.max_iterations).
// CHECK: "test.any_attr_of_i32_str"() {attr = 3 : i32} : () -> ()
// CHECK: "test.any_attr_of_i32_str"() <{attr = 3 : i32}> : () -> ()
"test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
return
}

View File

@@ -587,11 +587,6 @@ func.func @bad_arrow(%arg : !unreg.ptr<(i32)->)
// -----
// expected-error @+1 {{attribute 'attr' occurs more than once in the attribute list}}
test.format_symbol_name_attr_op @name { attr = "xx" }
// -----
func.func @forward_reference_type_check() -> (i8) {
cf.br ^bb2

View File

@@ -1437,3 +1437,4 @@ test.dialect_custom_format_fallback custom_format_fallback
// Check that an op with an optional result parses f80 as type.
// CHECK: test.format_optional_result_d_op : f80
test.format_optional_result_d_op : f80

View File

@@ -0,0 +1,20 @@
// # RUN: mlir-opt %s -split-input-file | mlir-opt |FileCheck %s
// # RUN: mlir-opt %s -mlir-print-op-generic -split-input-file | mlir-opt -mlir-print-op-generic | FileCheck %s --check-prefix=GENERIC
// CHECK: test.with_properties
// CHECK-SAME: <{a = 32 : i64, array = array<i64: 1, 2, 3, 4>, b = "foo"}>
// GENERIC: "test.with_properties"()
// GENERIC-SAME: <{a = 32 : i64, array = array<i64: 1, 2, 3, 4>, b = "foo"}> : () -> ()
test.with_properties <{a = 32 : i64, array = array<i64: 1, 2, 3, 4>, b = "foo"}>
// CHECK: test.with_nice_properties
// CHECK-SAME: "foo bar" is -3
// GENERIC: "test.with_nice_properties"()
// GENERIC-SAME: <{prop = {label = "foo bar", value = -3 : i32}}> : () -> ()
test.with_nice_properties "foo bar" is -3
// CHECK: test.with_wrapped_properties
// CHECK-SAME: "content for properties"
// GENERIC: "test.with_wrapped_properties"()
// GENERIC-SAME: <{prop = "content for properties"}> : () -> ()
test.with_wrapped_properties <{prop = "content for properties"}>

View File

@@ -12,5 +12,5 @@ func.func @test() -> i32 {
}
// CHECK-LABEL: func.func @test
// CHECK-NEXT: %[[C:.*]] = "test.constant"() {value = 33 : i32}
// CHECK-NEXT: %[[C:.*]] = "test.constant"() <{value = 33 : i32}>
// CHECK-NEXT: return %[[C]]

View File

@@ -7,5 +7,5 @@ func.func @test() -> i32 {
}
// CHECK-LABEL: func.func @test
// CHECK-NEXT: %[[C:.*]] = "test.constant"() {value = 5 : i32}
// CHECK-NEXT: %[[C:.*]] = "test.constant"() <{value = 5 : i32}>
// CHECK-NEXT: return %[[C]]

View File

@@ -345,8 +345,8 @@ func.func @failedSingleBlockImplicitTerminator_missing_terminator() {
// -----
// expected-error@+1 {{op attribute 'sym_visibility' failed to satisfy constraint: string attribute}}
"test.symbol"() {sym_name = "foo_2", sym_visibility} : () -> ()
// expected-error@+1 {{invalid properties {sym_name = "foo_2", sym_visibility} for op test.symbol: Invalid attribute `sym_visibility` in property conversion: unit}}
"test.symbol"() <{sym_name = "foo_2", sym_visibility}> : () -> ()
// -----
@@ -390,14 +390,14 @@ func.func @failedMissingOperandSizeAttr(%arg: i32) {
// -----
func.func @failedOperandSizeAttrWrongType(%arg: i32) {
// expected-error @+1 {{requires dense i32 array attribute 'operand_segment_sizes'}}
// expected-error @+1 {{attribute 'operand_segment_sizes' failed to satisfy constraint: i32 dense array attribute}}
"test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = 10} : (i32, i32, i32, i32) -> ()
}
// -----
func.func @failedOperandSizeAttrWrongElementType(%arg: i32) {
// expected-error @+1 {{requires dense i32 array attribute 'operand_segment_sizes'}}
// expected-error @+1 {{attribute 'operand_segment_sizes' failed to satisfy constraint: i32 dense array attribute}}
"test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = array<i64: 1, 1, 1, 1>} : (i32, i32, i32, i32) -> ()
}
@@ -655,7 +655,7 @@ func.func @failed_type_traits() {
// Check that we can query traits in attributes
func.func @succeeded_attr_traits() {
// CHECK: "test.attr_with_trait"() {attr = #test.attr_with_trait} : () -> ()
// CHECK: "test.attr_with_trait"() <{attr = #test.attr_with_trait}> : () -> ()
"test.attr_with_trait"() {attr = #test.attr_with_trait} : () -> ()
return
}

View File

@@ -1,7 +1,7 @@
// RUN: mlir-opt -test-int-range-inference %s | FileCheck %s
// CHECK-LABEL: func @constant
// CHECK: %[[cst:.*]] = "test.constant"() {value = 3 : index}
// CHECK: %[[cst:.*]] = "test.constant"() <{value = 3 : index}
// CHECK: return %[[cst]]
func.func @constant() -> index {
%0 = test.with_bounds { umin = 3 : index, umax = 3 : index,
@@ -10,7 +10,7 @@ func.func @constant() -> index {
}
// CHECK-LABEL: func @increment
// CHECK: %[[cst:.*]] = "test.constant"() {value = 4 : index}
// CHECK: %[[cst:.*]] = "test.constant"() <{value = 4 : index}
// CHECK: return %[[cst]]
func.func @increment() -> index {
%0 = test.with_bounds { umin = 3 : index, umax = 3 : index, smin = 0 : index, smax = 0x7fffffffffffffff : index }
@@ -103,8 +103,8 @@ func.func @func_args_unbound(%arg0 : index) -> index {
// CHECK-LABEL: func @propagate_across_while_loop_false()
func.func @propagate_across_while_loop_false() -> index {
// CHECK-DAG: %[[C0:.*]] = "test.constant"() {value = 0
// CHECK-DAG: %[[C1:.*]] = "test.constant"() {value = 1
// CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
// CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
%0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
smin = 0 : index, smax = 0 : index }
%1 = scf.while : () -> index {
@@ -122,8 +122,8 @@ func.func @propagate_across_while_loop_false() -> index {
// CHECK-LABEL: func @propagate_across_while_loop
func.func @propagate_across_while_loop(%arg0 : i1) -> index {
// CHECK-DAG: %[[C0:.*]] = "test.constant"() {value = 0
// CHECK-DAG: %[[C1:.*]] = "test.constant"() {value = 1
// CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
// CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
%0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
smin = 0 : index, smax = 0 : index }
%1 = scf.while : () -> index {
@@ -140,7 +140,7 @@ func.func @propagate_across_while_loop(%arg0 : i1) -> index {
// CHECK-LABEL: func @dont_propagate_across_infinite_loop()
func.func @dont_propagate_across_infinite_loop() -> index {
// CHECK: %[[C0:.*]] = "test.constant"() {value = 0
// CHECK: %[[C0:.*]] = "test.constant"() <{value = 0
%0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
smin = 0 : index, smax = 0 : index }
// CHECK: %[[loopRes:.*]] = scf.while

View File

@@ -10,8 +10,8 @@
// CHECK-SAME: %[[ARG0:.*]]: i1,
// CHECK-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
// CHECK: %[[ARG_MATERIALIZED:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> tuple<i1, i32>
// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 0 : i32} : (tuple<i1, i32>) -> i1
// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 1 : i32} : (tuple<i1, i32>) -> i32
// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
// CHECK: return %[[RET0]], %[[RET1]] : i1, i32
// CHECK-12N-LABEL: func @identity(
// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
@@ -61,12 +61,12 @@ func.func @recursive_decomposition(%arg0: tuple<tuple<tuple<i1>>>) -> tuple<tupl
// CHECK: %[[V2:.*]] = "test.make_tuple"(%[[ARG1]]) : (i2) -> tuple<i2>
// CHECK: %[[V3:.*]] = "test.make_tuple"(%[[V2]]) : (tuple<i2>) -> tuple<tuple<i2>>
// CHECK: %[[V4:.*]] = "test.make_tuple"(%[[V0]], %[[V1]], %[[V3]]) : (tuple<>, tuple<i1>, tuple<tuple<i2>>) -> tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>
// CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 0 : i32} : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<>
// CHECK: %[[V6:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 1 : i32} : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<i1>
// CHECK: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 0 : i32} : (tuple<i1>) -> i1
// CHECK: %[[V8:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 2 : i32} : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
// CHECK: %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) {index = 0 : i32} : (tuple<tuple<i2>>) -> tuple<i2>
// CHECK: %[[V10:.*]] = "test.get_tuple_element"(%[[V9]]) {index = 0 : i32} : (tuple<i2>) -> i2
// CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 0 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<>
// CHECK: %[[V6:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 1 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<i1>
// CHECK: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple<i1>) -> i1
// CHECK: %[[V8:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 2 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
// CHECK: %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) <{index = 0 : i32}> : (tuple<tuple<i2>>) -> tuple<i2>
// CHECK: %[[V10:.*]] = "test.get_tuple_element"(%[[V9]]) <{index = 0 : i32}> : (tuple<i2>) -> i2
// CHECK: return %[[V7]], %[[V10]] : i1, i2
// CHECK-12N-LABEL: func @mixed_recursive_decomposition(
// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
@@ -88,12 +88,12 @@ func.func private @callee(tuple<i1, i32>) -> tuple<i1, i32>
// CHECK-SAME: %[[ARG0:.*]]: i1,
// CHECK-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
// CHECK: %[[ARG_MATERIALIZED:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> tuple<i1, i32>
// CHECK: %[[CALL_ARG0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 0 : i32} : (tuple<i1, i32>) -> i1
// CHECK: %[[CALL_ARG1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 1 : i32} : (tuple<i1, i32>) -> i32
// CHECK: %[[CALL_ARG0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
// CHECK: %[[CALL_ARG1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
// CHECK: %[[DECOMPOSED:.*]]:2 = call @callee(%[[CALL_ARG0]], %[[CALL_ARG1]]) : (i1, i32) -> (i1, i32)
// CHECK: %[[CALL_RESULT_RECOMPOSED:.*]] = "test.make_tuple"(%[[DECOMPOSED]]#0, %[[DECOMPOSED]]#1) : (i1, i32) -> tuple<i1, i32>
// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) {index = 0 : i32} : (tuple<i1, i32>) -> i1
// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) {index = 1 : i32} : (tuple<i1, i32>) -> i32
// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
// CHECK: return %[[RET0]], %[[RET1]] : i1, i32
// CHECK-12N-LABEL: func @caller(
// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
@@ -131,13 +131,13 @@ func.func @caller(%arg0: tuple<>) -> tuple<> {
// CHECK-LABEL: func @unconverted_op_result() -> (i1, i32) {
// CHECK: %[[UNCONVERTED_VALUE:.*]] = "test.source"() : () -> tuple<i1, i32>
// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 0 : i32} : (tuple<i1, i32>) -> i1
// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 1 : i32} : (tuple<i1, i32>) -> i32
// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
// CHECK: return %[[RET0]], %[[RET1]] : i1, i32
// CHECK-12N-LABEL: func @unconverted_op_result() -> (i1, i32) {
// CHECK-12N: %[[UNCONVERTED_VALUE:.*]] = "test.source"() : () -> tuple<i1, i32>
// CHECK-12N: %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 0 : i32} : (tuple<i1, i32>) -> i1
// CHECK-12N: %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 1 : i32} : (tuple<i1, i32>) -> i32
// CHECK-12N: %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
// CHECK-12N: %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
// CHECK-12N: return %[[RET0]], %[[RET1]] : i1, i32
func.func @unconverted_op_result() -> tuple<i1, i32> {
%0 = "test.source"() : () -> (tuple<i1, i32>)
@@ -155,9 +155,9 @@ func.func @unconverted_op_result() -> tuple<i1, i32> {
// CHECK: %[[V0:.*]] = "test.make_tuple"(%[[ARG1]]) : (i32) -> tuple<i32>
// CHECK: %[[V1:.*]] = "test.make_tuple"(%[[ARG0]], %[[V0]]) : (i1, tuple<i32>) -> tuple<i1, tuple<i32>>
// CHECK: %[[V2:.*]] = "test.op"(%[[V1]]) : (tuple<i1, tuple<i32>>) -> tuple<i1, tuple<i32>>
// CHECK: %[[V3:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 0 : i32} : (tuple<i1, tuple<i32>>) -> i1
// CHECK: %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 1 : i32} : (tuple<i1, tuple<i32>>) -> tuple<i32>
// CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 0 : i32} : (tuple<i32>) -> i32
// CHECK: %[[V3:.*]] = "test.get_tuple_element"(%[[V2]]) <{index = 0 : i32}> : (tuple<i1, tuple<i32>>) -> i1
// CHECK: %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) <{index = 1 : i32}> : (tuple<i1, tuple<i32>>) -> tuple<i32>
// CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 0 : i32}> : (tuple<i32>) -> i32
// CHECK: return %[[V3]], %[[V5]] : i1, i32
// CHECK-12N-LABEL: func @nested_unconverted_op_result(
// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
@@ -165,9 +165,9 @@ func.func @unconverted_op_result() -> tuple<i1, i32> {
// CHECK-12N: %[[V0:.*]] = "test.make_tuple"(%[[ARG1]]) : (i32) -> tuple<i32>
// CHECK-12N: %[[V1:.*]] = "test.make_tuple"(%[[ARG0]], %[[V0]]) : (i1, tuple<i32>) -> tuple<i1, tuple<i32>>
// CHECK-12N: %[[V2:.*]] = "test.op"(%[[V1]]) : (tuple<i1, tuple<i32>>) -> tuple<i1, tuple<i32>>
// CHECK-12N: %[[V3:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 0 : i32} : (tuple<i1, tuple<i32>>) -> i1
// CHECK-12N: %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 1 : i32} : (tuple<i1, tuple<i32>>) -> tuple<i32>
// CHECK-12N: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 0 : i32} : (tuple<i32>) -> i32
// CHECK-12N: %[[V3:.*]] = "test.get_tuple_element"(%[[V2]]) <{index = 0 : i32}> : (tuple<i1, tuple<i32>>) -> i1
// CHECK-12N: %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) <{index = 1 : i32}> : (tuple<i1, tuple<i32>>) -> tuple<i32>
// CHECK-12N: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 0 : i32}> : (tuple<i32>) -> i32
// CHECK-12N: return %[[V3]], %[[V5]] : i1, i32
func.func @nested_unconverted_op_result(%arg: tuple<i1, tuple<i32>>) -> tuple<i1, tuple<i32>> {
%0 = "test.op"(%arg) : (tuple<i1, tuple<i32>>) -> (tuple<i1, tuple<i32>>)
@@ -191,12 +191,12 @@ func.func private @callee(tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) -> (tup
// CHECK-SAME: %[[I5:.*]]: i5,
// CHECK-SAME: %[[I6:.*]]: i6) -> (i1, i2, i3, i4, i5, i6) {
// CHECK: %[[ARG_TUPLE:.*]] = "test.make_tuple"(%[[I4]], %[[I5]]) : (i4, i5) -> tuple<i4, i5>
// CHECK: %[[ARG_TUPLE_0:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 0 : i32} : (tuple<i4, i5>) -> i4
// CHECK: %[[ARG_TUPLE_1:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 1 : i32} : (tuple<i4, i5>) -> i5
// CHECK: %[[ARG_TUPLE_0:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) <{index = 0 : i32}> : (tuple<i4, i5>) -> i4
// CHECK: %[[ARG_TUPLE_1:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) <{index = 1 : i32}> : (tuple<i4, i5>) -> i5
// CHECK: %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[ARG_TUPLE_0]], %[[ARG_TUPLE_1]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6)
// CHECK: %[[RET_TUPLE:.*]] = "test.make_tuple"(%[[CALL]]#3, %[[CALL]]#4) : (i4, i5) -> tuple<i4, i5>
// CHECK: %[[RET_TUPLE_0:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) {index = 0 : i32} : (tuple<i4, i5>) -> i4
// CHECK: %[[RET_TUPLE_1:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) {index = 1 : i32} : (tuple<i4, i5>) -> i5
// CHECK: %[[RET_TUPLE_0:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) <{index = 0 : i32}> : (tuple<i4, i5>) -> i4
// CHECK: %[[RET_TUPLE_1:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) <{index = 1 : i32}> : (tuple<i4, i5>) -> i5
// CHECK: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[RET_TUPLE_0]], %[[RET_TUPLE_1]], %[[CALL]]#5 : i1, i2, i3, i4, i5, i6
// CHECK-12N-LABEL: func @caller(
// CHECK-12N-SAME: %[[I1:.*]]: i1,

View File

@@ -2,7 +2,7 @@
// CHECK-LABEL: verifyDirectPattern
func.func @verifyDirectPattern() -> i32 {
// CHECK-NEXT: "test.legal_op_a"() {status = "Success"}
// CHECK-NEXT: "test.legal_op_a"() <{status = "Success"}
%result = "test.illegal_op_a"() : () -> (i32)
// expected-remark@+1 {{op 'func.return' is not legalizable}}
return %result : i32
@@ -10,7 +10,7 @@ func.func @verifyDirectPattern() -> i32 {
// CHECK-LABEL: verifyLargerBenefit
func.func @verifyLargerBenefit() -> i32 {
// CHECK-NEXT: "test.legal_op_a"() {status = "Success"}
// CHECK-NEXT: "test.legal_op_a"() <{status = "Success"}
%result = "test.illegal_op_c"() : () -> (i32)
// expected-remark@+1 {{op 'func.return' is not legalizable}}
return %result : i32

View File

@@ -7,7 +7,7 @@ func.func @foo() -> i32 {
// The new operation should be present in the output and contain an attribute
// with value "42" that results from folding.
// CHECK: "test.op_in_place_fold"(%{{.*}}) {attr = 42 : i32}
// CHECK: "test.op_in_place_fold"(%{{.*}}) <{attr = 42 : i32}
%0 = "test.op_in_place_fold_anchor"(%c42) : (i32) -> (i32)
return %0 : i32
}

View File

@@ -21,6 +21,7 @@
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/ODSSupport.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
@@ -43,6 +44,36 @@
using namespace mlir;
using namespace test;
Attribute MyPropStruct::asAttribute(MLIRContext *ctx) const {
return StringAttr::get(ctx, content);
}
LogicalResult MyPropStruct::setFromAttr(MyPropStruct &prop, Attribute attr,
InFlightDiagnostic *diag) {
StringAttr strAttr = attr.dyn_cast<StringAttr>();
if (!strAttr) {
if (diag)
*diag << "Expect StringAttr but got " << attr;
return failure();
}
prop.content = strAttr.getValue();
return success();
}
llvm::hash_code MyPropStruct::hash() const {
return hash_value(StringRef(content));
}
static LogicalResult setPropertiesFromAttribute(PropertiesWithCustomPrint &prop,
Attribute attr,
InFlightDiagnostic *diagnostic);
static DictionaryAttr
getPropertiesAsAttribute(MLIRContext *ctx,
const PropertiesWithCustomPrint &prop);
static llvm::hash_code computeHash(const PropertiesWithCustomPrint &prop);
static void customPrintProperties(OpAsmPrinter &p,
const PropertiesWithCustomPrint &prop);
static ParseResult customParseProperties(OpAsmParser &parser,
PropertiesWithCustomPrint &prop);
void test::registerTestDialect(DialectRegistry &registry) {
registry.insert<TestDialect>();
}
@@ -514,7 +545,7 @@ Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
::mlir::LogicalResult FormatInferType2Op::inferReturnTypes(
::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
::mlir::RegionRange regions,
OpaqueProperties properties, ::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
return ::mlir::success();
@@ -1264,7 +1295,7 @@ LogicalResult TestOpWithVariadicResultsAndFolder::fold(
}
OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
if (adaptor.getOp() && !(*this)->hasAttr("attr")) {
if (adaptor.getOp() && !(*this)->getAttr("attr")) {
// The folder adds "attr" if not present.
(*this)->setAttr("attr", adaptor.getOp());
return getResult();
@@ -1297,7 +1328,7 @@ OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
MLIRContext *, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands[0].getType() != operands[1].getType()) {
return emitOptionalError(location, "operand type mismatch ",
@@ -1312,16 +1343,17 @@ LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
// refineReturnType, currently only refineReturnType can be omitted.
LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &returnTypes) {
returnTypes.clear();
return OpWithRefineTypeInterfaceOp::refineReturnTypes(
context, location, operands, attributes, regions, returnTypes);
context, location, operands, attributes, properties, regions,
returnTypes);
}
LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
MLIRContext *, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &returnTypes) {
if (operands[0].getType() != operands[1].getType()) {
return emitOptionalError(location, "operand type mismatch ",
@@ -1340,7 +1372,8 @@ LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
MLIRContext *context, std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
// Create return type consisting of the last element of the first operand.
auto operandType = operands.front().getType();
@@ -1797,6 +1830,59 @@ OpFoldResult ManualCppOpWithFold::fold(ArrayRef<Attribute> attributes) {
return nullptr;
}
static LogicalResult
setPropertiesFromAttribute(PropertiesWithCustomPrint &prop, Attribute attr,
InFlightDiagnostic *diagnostic) {
DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr);
if (!dict) {
if (diagnostic)
*diagnostic << "expected DictionaryAttr to set TestProperties";
return failure();
}
auto label = dict.getAs<mlir::StringAttr>("label");
if (!label) {
if (diagnostic)
*diagnostic << "expected StringAttr for key `label`";
return failure();
}
auto valueAttr = dict.getAs<IntegerAttr>("value");
if (!valueAttr) {
if (diagnostic)
*diagnostic << "expected IntegerAttr for key `value`";
return failure();
}
prop.label = std::make_shared<std::string>(label.getValue());
prop.value = valueAttr.getValue().getSExtValue();
return success();
}
static DictionaryAttr
getPropertiesAsAttribute(MLIRContext *ctx,
const PropertiesWithCustomPrint &prop) {
SmallVector<NamedAttribute> attrs;
Builder b{ctx};
attrs.push_back(b.getNamedAttr("label", b.getStringAttr(*prop.label)));
attrs.push_back(b.getNamedAttr("value", b.getI32IntegerAttr(prop.value)));
return b.getDictionaryAttr(attrs);
}
static llvm::hash_code computeHash(const PropertiesWithCustomPrint &prop) {
return llvm::hash_combine(prop.value, StringRef(*prop.label));
}
static void customPrintProperties(OpAsmPrinter &p,
const PropertiesWithCustomPrint &prop) {
p.printKeywordOrString(*prop.label);
p << " is " << prop.value;
}
static ParseResult customParseProperties(OpAsmParser &parser,
PropertiesWithCustomPrint &prop) {
std::string label;
if (parser.parseKeywordOrString(&label) || parser.parseKeyword("is") ||
parser.parseInteger(prop.value))
return failure();
prop.label = std::make_shared<std::string>(std::move(label));
return success();
}
#include "TestOpEnums.cpp.inc"
#include "TestOpInterfaces.cpp.inc"
#include "TestTypeInterfaces.cpp.inc"

View File

@@ -42,6 +42,8 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include <memory>
namespace mlir {
class DLTIDialect;
class RewritePatternSet;
@@ -54,6 +56,30 @@ class RewritePatternSet;
#include "TestOpInterfaces.h.inc"
#include "TestOpsDialect.h.inc"
namespace test {
// Define some classes to exercises the Properties feature.
struct PropertiesWithCustomPrint {
/// A shared_ptr to a const object is safe: it is equivalent to a value-based
/// member. Here the label will be deallocated when the last operation
/// refering to it is destroyed. However there is no pool-allocation: this is
/// offloaded to the client.
std::shared_ptr<const std::string> label;
int value;
};
class MyPropStruct {
public:
std::string content;
// These three methods are invoked through the `MyStructProperty` wrapper
// defined in TestOps.td
mlir::Attribute asAttribute(mlir::MLIRContext *ctx) const;
static mlir::LogicalResult setFromAttr(MyPropStruct &prop,
mlir::Attribute attr,
mlir::InFlightDiagnostic *diag);
llvm::hash_code hash() const;
};
} // namespace test
#define GET_OP_CLASSES
#include "TestOps.h.inc"

View File

@@ -24,6 +24,7 @@ def Test_Dialect : Dialect {
let useDefaultTypePrinterParser = 0;
let useDefaultAttributePrinterParser = 1;
let isExtensible = 1;
let usePropertiesForAttributes = 1;
let dependentDialects = ["::mlir::DLTIDialect"];
let extraClassDeclaration = [{

View File

@@ -430,7 +430,7 @@ def VariadicRegionInferredTypesOp : TEST_Op<"variadic_region_inferred",
let extraClassDeclaration = [{
static mlir::LogicalResult inferReturnTypes(mlir::MLIRContext *context,
std::optional<::mlir::Location> location, mlir::ValueRange operands,
mlir::DictionaryAttr attributes, mlir::RegionRange regions,
mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, mlir::RegionRange regions,
llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
inferredReturnTypes.assign({mlir::IntegerType::get(context, 16)});
return mlir::success();
@@ -2524,7 +2524,7 @@ def FormatInferTypeOp : TEST_Op<"format_infer_type", [InferTypeOpInterface]> {
let extraClassDeclaration = [{
static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
::std::optional<::mlir::Location> location, ::mlir::ValueRange operands,
::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
::mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
return ::mlir::success();
@@ -2547,7 +2547,7 @@ class FormatInferAllTypesBaseOp<string mnemonic, list<Trait> traits = []>
let extraClassDeclaration = [{
static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
::std::optional<::mlir::Location> location, ::mlir::ValueRange operands,
::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
::mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
::mlir::TypeRange operandTypes = operands.getTypes();
inferredReturnTypes.assign(operandTypes.begin(), operandTypes.end());
@@ -2594,7 +2594,7 @@ def FormatInferTypeRegionsOp
let extraClassDeclaration = [{
static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
::std::optional<::mlir::Location> location, ::mlir::ValueRange operands,
::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
::mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
if (regions.empty())
return ::mlir::failure();
@@ -2615,9 +2615,10 @@ def FormatInferTypeVariadicOperandsOp
let extraClassDeclaration = [{
static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
::std::optional<::mlir::Location> location, ::mlir::ValueRange operands,
::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
::mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
FormatInferTypeVariadicOperandsOpAdaptor adaptor(operands, attributes);
FormatInferTypeVariadicOperandsOpAdaptor adaptor(
operands, attributes, *properties.as<Properties *>(), {});
auto aTypes = adaptor.getA().getTypes();
auto bTypes = adaptor.getB().getTypes();
inferredReturnTypes.append(aTypes.begin(), aTypes.end());
@@ -2823,7 +2824,7 @@ class TableGenBuildInferReturnTypeBaseOp<string mnemonic,
let extraClassDeclaration = [{
static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *,
::std::optional<::mlir::Location> location, ::mlir::ValueRange operands,
::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
::mlir::DictionaryAttr attributes, mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
inferredReturnTypes.assign({operands[0].getType()});
return ::mlir::success();
@@ -3280,4 +3281,77 @@ def TestVersionedOpB : TEST_Op<"versionedB"> {
);
}
//===----------------------------------------------------------------------===//
// Test Properties
//===----------------------------------------------------------------------===//
// Op with a properties struct defined inline.
def TestOpWithProperties : TEST_Op<"with_properties"> {
let assemblyFormat = "prop-dict attr-dict";
let arguments = (ins
Property<"int64_t">:$a,
StrAttr:$b, // Attributes can directly be used here.
ArrayProperty<"int64_t", 4>:$array // example of an array
);
}
// Demonstrate how to wrap an existing C++ class named MyPropStruct.
def MyStructProperty : Property<"MyPropStruct"> {
let convertToAttribute = "$_storage.asAttribute($_ctxt)";
let convertFromAttribute = "return MyPropStruct::setFromAttr($_storage, $_attr, $_diag);";
let hashProperty = "$_storage.hash();";
}
def TestOpWithWrappedProperties : TEST_Op<"with_wrapped_properties"> {
let assemblyFormat = "prop-dict attr-dict";
let arguments = (ins
MyStructProperty:$prop
);
}
// Op with a properties struct defined out-of-line. The struct has custom
// printer/parser.
def PropertiesWithCustomPrint : Property<"PropertiesWithCustomPrint"> {
let convertToAttribute = [{
getPropertiesAsAttribute($_ctxt, $_storage)
}];
let convertFromAttribute = [{
return setPropertiesFromAttribute($_storage, $_attr, $_diag);
}];
let hashProperty = [{
computeHash($_storage);
}];
}
def TestOpWithNiceProperties : TEST_Op<"with_nice_properties"> {
let assemblyFormat = "prop-dict attr-dict";
let arguments = (ins
PropertiesWithCustomPrint:$prop
);
let extraClassDeclaration = [{
void printProperties(::mlir::MLIRContext *ctx, ::mlir::OpAsmPrinter &p,
const Properties &prop);
static ::mlir::ParseResult parseProperties(::mlir::OpAsmParser &parser,
::mlir::OperationState &result);
}];
let extraClassDefinition = [{
void TestOpWithNiceProperties::printProperties(::mlir::MLIRContext *ctx,
::mlir::OpAsmPrinter &p, const Properties &prop) {
customPrintProperties(p, prop.prop);
}
::mlir::ParseResult TestOpWithNiceProperties::parseProperties(
::mlir::OpAsmParser &parser,
::mlir::OperationState &result) {
Properties &prop = result.getOrAddProperties<Properties>();
if (customParseProperties(parser, prop.prop))
return failure();
return success();
}
}];
}
#endif // TEST_OPS

View File

@@ -434,7 +434,8 @@ static void invokeCreateWithInferredReturnType(Operation *op) {
SmallVector<Type, 2> inferredReturnTypes;
if (succeeded(OpTy::inferReturnTypes(
context, std::nullopt, values, op->getAttrDictionary(),
op->getRegions(), inferredReturnTypes))) {
op->getPropertiesStorage(), op->getRegions(),
inferredReturnTypes))) {
OperationState state(location, OpTy::getOperationName());
// TODO: Expand to regions.
OpTy::build(b, state, values, op->getAttrs());

View File

@@ -69,8 +69,8 @@ def OpC : NS_Op<"op_c"> {
/// Test that an attribute contraint was generated.
// CHECK: static ::mlir::LogicalResult [[$A_ATTR_CONSTRAINT:__mlir_ods_local_attr_constraint.*]](
// CHECK: if (attr && !((attrPred(attr, *op)))) {
// CHECK-NEXT: return op->emitOpError("attribute '") << attrName
// CHECK: if (attr && !((attrPred(attr, *op))))
// CHECK-NEXT: return getDiag() << "attribute '" << attrName
// CHECK-NEXT: << "' failed to satisfy constraint: an attribute";
/// Test that duplicate attribute constraint was not generated.
@@ -78,8 +78,9 @@ def OpC : NS_Op<"op_c"> {
/// Test that a attribute constraint with a different description was generated.
// CHECK: static ::mlir::LogicalResult [[$O_ATTR_CONSTRAINT:__mlir_ods_local_attr_constraint.*]](
// CHECK: if (attr && !((attrPred(attr, *op)))) {
// CHECK-NEXT: return op->emitOpError("attribute '") << attrName
// CHECK: static ::mlir::LogicalResult [[$O_ATTR_CONSTRAINT:__mlir_ods_local_attr_constraint.*]](
// CHECK: if (attr && !((attrPred(attr, *op))))
// CHECK-NEXT: return getDiag() << "attribute '" << attrName
// CHECK-NEXT: << "' failed to satisfy constraint: another attribute";
/// Test that a successor contraint was generated.

View File

@@ -34,13 +34,13 @@ def OpUsingAllOfThose : Op<Test_Dialect, "OpUsingAllOfThose"> {
// CHECK-NEXT: << " must be TypeInterfaceInNamespace instance, but got " << type;
// CHECK: static ::mlir::LogicalResult {{__mlir_ods_local_attr_constraint.*}}(
// CHECK: if (attr && !((attr.isa<TopLevelAttrInterface>()))) {
// CHECK-NEXT: return op->emitOpError("attribute '") << attrName
// CHECK: if (attr && !((attr.isa<TopLevelAttrInterface>())))
// CHECK-NEXT: return getDiag() << "attribute '" << attrName
// CHECK-NEXT: << "' failed to satisfy constraint: TopLevelAttrInterface instance";
// CHECK: static ::mlir::LogicalResult {{__mlir_ods_local_attr_constraint.*}}(
// CHECK: if (attr && !((attr.isa<test::AttrInterfaceInNamespace>()))) {
// CHECK-NEXT: return op->emitOpError("attribute '") << attrName
// CHECK: if (attr && !((attr.isa<test::AttrInterfaceInNamespace>())))
// CHECK-NEXT: return getDiag() << "attribute '" << attrName
// CHECK-NEXT: << "' failed to satisfy constraint: AttrInterfaceInNamespace instance";
// CHECK: TopLevelAttrInterface OpUsingAllOfThose::getAttr1()

View File

@@ -68,7 +68,7 @@ def AOp : NS_Op<"a_op", []> {
// DEF: ::mlir::LogicalResult AOpAdaptor::verify
// DEF: ::mlir::Attribute tblgen_aAttr;
// DEF-NEXT: while (true) {
// DEF: while (true) {
// DEF-NEXT: if (namedAttrIt == namedAttrRange.end())
// DEF-NEXT: return emitError(loc, "'test.a_op' op ""requires attribute 'aAttr'");
// DEF-NEXT: if (namedAttrIt->getName() == AOp::getAAttrAttrName(*odsOpName)) {
@@ -217,10 +217,10 @@ def AgetOp : Op<Test2_Dialect, "a_get_op", []> {
// DEF: ::mlir::LogicalResult AgetOpAdaptor::verify
// DEF: ::mlir::Attribute tblgen_aAttr;
// DEF-NEXT: while (true)
// DEF: while (true)
// DEF: ::mlir::Attribute tblgen_bAttr;
// DEF-NEXT: ::mlir::Attribute tblgen_cAttr;
// DEF-NEXT: while (true)
// DEF: while (true)
// DEF: if (tblgen_aAttr && !((some-condition)))
// DEF-NEXT: return emitError(loc, "'test2.a_get_op' op ""attribute 'aAttr' failed to satisfy constraint: some attribute kind");
// DEF: if (tblgen_bAttr && !((some-condition)))

View File

@@ -126,7 +126,7 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
// DEFS-LABEL: NS::AOp definitions
// DEFS: AOpGenericAdaptorBase::AOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions)
// DEFS: AOpGenericAdaptorBase::AOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, ::mlir::EmptyProperties properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions)
// DEFS: ::mlir::RegionRange AOpGenericAdaptorBase::getSomeRegions()
// DEFS-NEXT: return odsRegions.drop_front(1);
// DEFS: ::mlir::RegionRange AOpGenericAdaptorBase::getRegions()

View File

@@ -64,7 +64,7 @@ def OptionalGroupA : TestFormat_Op<[{
// CHECK-NEXT: result.addAttribute("a", parser.getBuilder().getUnitAttr())
// CHECK: parser.parseKeyword("bar")
// CHECK-LABEL: OptionalGroupB::print
// CHECK: if (!(*this)->getAttr("a"))
// CHECK: if (!getAAttr())
// CHECK-NEXT: odsPrinter << ' ' << "foo"
// CHECK-NEXT: else
// CHECK-NEXT: odsPrinter << ' ' << "bar"
@@ -74,7 +74,8 @@ def OptionalGroupB : TestFormat_Op<[{
// Optional group anchored on a default-valued attribute:
// CHECK-LABEL: OptionalGroupC::parse
// CHECK: if ((*this)->getAttr("a") != ::mlir::OpBuilder((*this)->getContext()).getStringAttr("default")) {
// CHECK: if (getAAttr() && getAAttr() != ::mlir::OpBuilder((*this)->getContext()).getStringAttr("default")) {
// CHECK-NEXT: odsPrinter << ' ';
// CHECK-NEXT: odsPrinter.printAttributeWithoutType(getAAttr());
// CHECK-NEXT: }

View File

@@ -152,7 +152,7 @@ def OpL3 : NS_Op<"op_with_all_types_constraint",
// CHECK-LABEL: LogicalResult OpL3::inferReturnTypes
// CHECK-NOT: }
// CHECK: ::mlir::Type odsInferredType0 = attributes.get("a").cast<::mlir::TypedAttr>().getType();
// CHECK: ::mlir::Type odsInferredType0 = odsInferredTypeAttr0.getType();
// CHECK: inferredReturnTypes[0] = odsInferredType0;
def OpL4 : NS_Op<"two_inference_edges", [

View File

@@ -5,8 +5,8 @@ func.func @verifyFusedLocs(%arg0 : i32) -> i32 {
%0 = "test.op_a"(%arg0) {attr = 10 : i32} : (i32) -> i32 loc("a")
%result = "test.op_a"(%0) {attr = 20 : i32} : (i32) -> i32 loc("b")
// CHECK: "test.op_b"(%arg0) {attr = 10 : i32} : (i32) -> i32 loc("a")
// CHECK: "test.op_b"(%arg0) {attr = 20 : i32} : (i32) -> i32 loc(fused["b", "a"])
// CHECK: "test.op_b"(%arg0) <{attr = 10 : i32}> : (i32) -> i32 loc("a")
// CHECK: "test.op_b"(%arg0) <{attr = 20 : i32}> : (i32) -> i32 loc(fused["b", "a"])
return %result : i32
}
@@ -41,7 +41,7 @@ func.func @verifyZeroArg() -> i32 {
// CHECK-LABEL: testIgnoreArgMatch
// CHECK-SAME: (%{{[a-z0-9]*}}: i32 loc({{[^)]*}}), %[[ARG1:[a-z0-9]*]]: i32 loc({{[^)]*}}),
func.func @testIgnoreArgMatch(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: f32) {
// CHECK: "test.ignore_arg_match_dst"(%[[ARG1]]) {f = 15 : i64}
// CHECK: "test.ignore_arg_match_dst"(%[[ARG1]]) <{f = 15 : i64}>
"test.ignore_arg_match_src"(%arg0, %arg1, %arg2) {d = 42, e = 24, f = 15} : (i32, i32, i32) -> ()
// CHECK: test.ignore_arg_match_src
@@ -57,7 +57,7 @@ func.func @testIgnoreArgMatch(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: f32) {
// CHECK-LABEL: verifyInterleavedOperandAttribute
// CHECK-SAME: %[[ARG0:.*]]: i32 loc({{[^)]*}}), %[[ARG1:.*]]: i32 loc({{[^)]*}})
func.func @verifyInterleavedOperandAttribute(%arg0: i32, %arg1: i32) {
// CHECK: "test.interleaved_operand_attr2"(%[[ARG0]], %[[ARG1]]) {attr1 = 15 : i64, attr2 = 42 : i64}
// CHECK: "test.interleaved_operand_attr2"(%[[ARG0]], %[[ARG1]]) <{attr1 = 15 : i64, attr2 = 42 : i64}>
"test.interleaved_operand_attr1"(%arg0, %arg1) {attr1 = 15, attr2 = 42} : (i32, i32) -> ()
return
}
@@ -69,13 +69,13 @@ func.func @verifyBenefit(%arg0 : i32) -> i32 {
%2 = "test.op_g"(%1) : (i32) -> i32
// CHECK: "test.op_f"(%arg0)
// CHECK: "test.op_b"(%arg0) {attr = 34 : i32}
// CHECK: "test.op_b"(%arg0) <{attr = 34 : i32}>
return %0 : i32
}
// CHECK-LABEL: verifyNativeCodeCall
func.func @verifyNativeCodeCall(%arg0: i32, %arg1: i32) -> (i32, i32) {
// CHECK: %0 = "test.native_code_call2"(%arg0) {attr = [42, 24]} : (i32) -> i32
// CHECK: %0 = "test.native_code_call2"(%arg0) <{attr = [42, 24]}> : (i32) -> i32
// CHECK: return %0, %arg1
%0 = "test.native_code_call1"(%arg0, %arg1) {choice = true, attr1 = 42, attr2 = 24} : (i32, i32) -> (i32)
%1 = "test.native_code_call1"(%arg0, %arg1) {choice = false, attr1 = 42, attr2 = 24} : (i32, i32) -> (i32)
@@ -215,7 +215,7 @@ func.func @symbolBinding(%arg0: i32) -> i32 {
// An op with one use is matched.
// CHECK: %0 = "test.symbol_binding_b"(%arg0)
// CHECK: %1 = "test.symbol_binding_c"(%0)
// CHECK: %2 = "test.symbol_binding_d"(%0, %1) {attr = 42 : i64}
// CHECK: %2 = "test.symbol_binding_d"(%0, %1) <{attr = 42 : i64}>
%0 = "test.symbol_binding_a"(%arg0) {attr = 42} : (i32) -> (i32)
// An op without any use is not matched.
@@ -239,21 +239,21 @@ func.func @symbolBindingNoResult(%arg0: i32) {
// CHECK-LABEL: succeedMatchOpAttr
func.func @succeedMatchOpAttr() -> i32 {
// CHECK: "test.match_op_attribute2"() {default_valued_attr = 3 : i32, more_attr = 4 : i32, optional_attr = 2 : i32, required_attr = 1 : i32}
// CHECK: "test.match_op_attribute2"() <{default_valued_attr = 3 : i32, more_attr = 4 : i32, optional_attr = 2 : i32, required_attr = 1 : i32}>
%0 = "test.match_op_attribute1"() {required_attr = 1: i32, optional_attr = 2: i32, default_valued_attr = 3: i32, more_attr = 4: i32} : () -> (i32)
return %0: i32
}
// CHECK-LABEL: succeedMatchMissingOptionalAttr
func.func @succeedMatchMissingOptionalAttr() -> i32 {
// CHECK: "test.match_op_attribute2"() {default_valued_attr = 3 : i32, more_attr = 4 : i32, required_attr = 1 : i32}
// CHECK: "test.match_op_attribute2"() <{default_valued_attr = 3 : i32, more_attr = 4 : i32, required_attr = 1 : i32}>
%0 = "test.match_op_attribute1"() {required_attr = 1: i32, default_valued_attr = 3: i32, more_attr = 4: i32} : () -> (i32)
return %0: i32
}
// CHECK-LABEL: succeedMatchMissingDefaultValuedAttr
func.func @succeedMatchMissingDefaultValuedAttr() -> i32 {
// CHECK: "test.match_op_attribute2"() {default_valued_attr = 42 : i32, more_attr = 4 : i32, optional_attr = 2 : i32, required_attr = 1 : i32}
// CHECK: "test.match_op_attribute2"() <{default_valued_attr = 42 : i32, more_attr = 4 : i32, optional_attr = 2 : i32, required_attr = 1 : i32}>
%0 = "test.match_op_attribute1"() {required_attr = 1: i32, optional_attr = 2: i32, more_attr = 4: i32} : () -> (i32)
return %0: i32
}
@@ -267,7 +267,7 @@ func.func @failedMatchAdditionalConstraintNotSatisfied() -> i32 {
// CHECK-LABEL: verifyConstantAttr
func.func @verifyConstantAttr(%arg0 : i32) -> i32 {
// CHECK: "test.op_b"(%arg0) {attr = 17 : i32} : (i32) -> i32 loc("a")
// CHECK: "test.op_b"(%arg0) <{attr = 17 : i32}> : (i32) -> i32 loc("a")
%0 = "test.op_c"(%arg0) : (i32) -> i32 loc("a")
return %0 : i32
}
@@ -275,12 +275,12 @@ func.func @verifyConstantAttr(%arg0 : i32) -> i32 {
// CHECK-LABEL: verifyUnitAttr
func.func @verifyUnitAttr() -> (i32, i32) {
// Unit attribute present in the matched op is propagated as attr2.
// CHECK: "test.match_op_attribute4"() {attr1, attr2} : () -> i32
// CHECK: "test.match_op_attribute4"() <{attr1, attr2}> : () -> i32
%0 = "test.match_op_attribute3"() {attr} : () -> i32
// Since the original op doesn't have the unit attribute, the new op
// only has the constant-constructed unit attribute attr1.
// CHECK: "test.match_op_attribute4"() {attr1} : () -> i32
// CHECK: "test.match_op_attribute4"() <{attr1}> : () -> i32
%1 = "test.match_op_attribute3"() : () -> i32
return %0, %1 : i32, i32
}
@@ -291,7 +291,7 @@ func.func @verifyUnitAttr() -> (i32, i32) {
// CHECK-LABEL: testConstOp
func.func @testConstOp() -> (i32) {
// CHECK-NEXT: [[C0:%.+]] = "test.constant"() {value = 1
// CHECK-NEXT: [[C0:%.+]] = "test.constant"() <{value = 1
%0 = "test.constant"() {value = 1 : i32} : () -> i32
// CHECK-NEXT: return [[C0]]
@@ -300,7 +300,7 @@ func.func @testConstOp() -> (i32) {
// CHECK-LABEL: testConstOpUsed
func.func @testConstOpUsed() -> (i32) {
// CHECK-NEXT: [[C0:%.+]] = "test.constant"() {value = 1
// CHECK-NEXT: [[C0:%.+]] = "test.constant"() <{value = 1
%0 = "test.constant"() {value = 1 : i32} : () -> i32
// CHECK-NEXT: [[V0:%.+]] = "test.op_s"([[C0]])
@@ -312,11 +312,11 @@ func.func @testConstOpUsed() -> (i32) {
// CHECK-LABEL: testConstOpReplaced
func.func @testConstOpReplaced() -> (i32) {
// CHECK-NEXT: [[C0:%.+]] = "test.constant"() {value = 1
// CHECK-NEXT: [[C0:%.+]] = "test.constant"() <{value = 1
%0 = "test.constant"() {value = 1 : i32} : () -> i32
%1 = "test.constant"() {value = 2 : i32} : () -> i32
// CHECK: [[V0:%.+]] = "test.op_s"([[C0]]) {value = 2 : i32}
// CHECK: [[V0:%.+]] = "test.op_s"([[C0]]) <{value = 2 : i32}
%2 = "test.op_r"(%0, %1) : (i32, i32) -> i32
// CHECK: [[V0]]
@@ -325,10 +325,10 @@ func.func @testConstOpReplaced() -> (i32) {
// CHECK-LABEL: testConstOpMatchFailure
func.func @testConstOpMatchFailure() -> (i64) {
// CHECK-DAG: [[C0:%.+]] = "test.constant"() {value = 1
// CHECK-DAG: [[C0:%.+]] = "test.constant"() <{value = 1
%0 = "test.constant"() {value = 1 : i64} : () -> i64
// CHECK-DAG: [[C1:%.+]] = "test.constant"() {value = 2
// CHECK-DAG: [[C1:%.+]] = "test.constant"() <{value = 2
%1 = "test.constant"() {value = 2 : i64} : () -> i64
// CHECK: [[V0:%.+]] = "test.op_r"([[C0]], [[C1]])
@@ -340,7 +340,7 @@ func.func @testConstOpMatchFailure() -> (i64) {
// CHECK-LABEL: testConstOpMatchNonConst
func.func @testConstOpMatchNonConst(%arg0 : i32) -> (i32) {
// CHECK-DAG: [[C0:%.+]] = "test.constant"() {value = 1
// CHECK-DAG: [[C0:%.+]] = "test.constant"() <{value = 1
%0 = "test.constant"() {value = 1 : i32} : () -> i32
// CHECK: [[V0:%.+]] = "test.op_r"([[C0]], %arg0)
@@ -358,14 +358,14 @@ func.func @testConstOpMatchNonConst(%arg0 : i32) -> (i32) {
// CHECK-LABEL: verifyI32EnumAttr
func.func @verifyI32EnumAttr() -> i32 {
// CHECK: "test.i32_enum_attr"() {attr = 10 : i32}
// CHECK: "test.i32_enum_attr"() <{attr = 10 : i32}
%0 = "test.i32_enum_attr"() {attr = 5: i32} : () -> i32
return %0 : i32
}
// CHECK-LABEL: verifyI64EnumAttr
func.func @verifyI64EnumAttr() -> i32 {
// CHECK: "test.i64_enum_attr"() {attr = 10 : i64}
// CHECK: "test.i64_enum_attr"() <{attr = 10 : i64}
%0 = "test.i64_enum_attr"() {attr = 5: i64} : () -> i32
return %0 : i32
}
@@ -522,7 +522,7 @@ func.func @generateVariadicOutputOpInNestedPattern() -> (i32) {
// CHECK-LABEL: redundantTest
func.func @redundantTest(%arg0: i32) -> i32 {
%0 = "test.op_m"(%arg0) : (i32) -> i32
// CHECK: "test.op_m"(%arg0) {optional_attr = 314159265 : i32} : (i32) -> i32
// CHECK: "test.op_m"(%arg0) <{optional_attr = 314159265 : i32}> : (i32) -> i32
return %0 : i32
}

View File

@@ -24,6 +24,7 @@ func.func @testCreateFunctions(%arg0 : tensor<10xf32, !test.smpla>, %arg1 : tens
// -----
func.func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{incompatible with return type}}
%bad = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
return
@@ -32,6 +33,7 @@ func.func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) {
// -----
func.func @testReturnTypeOpInterfaceMismatch(%arg0 : tensor<10xf32>, %arg1 : tensor<20xf32>) {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{operand type mismatch}}
%bad = "test.op_with_infer_type_if"(%arg0, %arg1) : (tensor<10xf32>, tensor<20xf32>) -> tensor<*xf32>
return
@@ -40,6 +42,7 @@ func.func @testReturnTypeOpInterfaceMismatch(%arg0 : tensor<10xf32>, %arg1 : ten
// -----
func.func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{required first operand and result to match}}
%bad = "test.op_with_refine_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
return

View File

@@ -177,6 +177,7 @@ FormatToken FormatLexer::lexIdentifier(const char *tokStart) {
StringSwitch<FormatToken::Kind>(str)
.Case("attr-dict", FormatToken::kw_attr_dict)
.Case("attr-dict-with-keyword", FormatToken::kw_attr_dict_w_keyword)
.Case("prop-dict", FormatToken::kw_prop_dict)
.Case("custom", FormatToken::kw_custom)
.Case("functional-type", FormatToken::kw_functional_type)
.Case("oilist", FormatToken::kw_oilist)

View File

@@ -60,6 +60,7 @@ public:
keyword_start,
kw_attr_dict,
kw_attr_dict_w_keyword,
kw_prop_dict,
kw_custom,
kw_functional_type,
kw_oilist,
@@ -287,6 +288,7 @@ public:
/// These are the kinds of directives.
enum Kind {
AttrDict,
PropDict,
Custom,
FunctionalType,
OIList,

View File

@@ -37,5 +37,6 @@ OpClass::OpClass(StringRef name, StringRef extraClassDeclaration,
void OpClass::finalize() {
Class::finalize();
declare<VisibilityDeclaration>(Visibility::Public);
declare<ExtraClassDeclaration>(extraClassDeclaration, extraClassDefinition);
declare<ExtraClassDeclaration>(extraClassDeclaration.str(),
extraClassDefinition);
}

File diff suppressed because it is too large Load Diff

View File

@@ -132,6 +132,14 @@ private:
bool withKeyword;
};
/// This class represents the `prop-dict` directive. This directive represents
/// the properties of the operation, expressed as a directionary.
class PropDictDirective
: public DirectiveElementBase<DirectiveElement::PropDict> {
public:
explicit PropDictDirective() = default;
};
/// This class represents the `functional-type` directive. This directive takes
/// two arguments and formats them, respectively, as the inputs and results of a
/// FunctionType.
@@ -294,8 +302,9 @@ struct OperationFormat {
};
OperationFormat(const Operator &op)
{
: useProperties(op.getDialect().usePropertiesForAttributes() &&
!op.getAttributes().empty()),
opCppClassName(op.getCppClassName()) {
operandTypes.resize(op.getNumOperands(), TypeResolution());
resultTypes.resize(op.getNumResults(), TypeResolution());
@@ -351,6 +360,12 @@ struct OperationFormat {
/// A flag indicating if this operation has the SingleBlock trait.
bool hasSingleBlockTrait;
/// Indicate whether attribute are stored in properties.
bool useProperties;
/// The Operation class name
StringRef opCppClassName;
/// A map of buildable types to indices.
llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
@@ -389,8 +404,7 @@ static bool shouldFormatSymbolNameAttr(const NamedAttribute *attr) {
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const attrParserCode = R"(
if (parser.parseCustomAttributeWithFallback({0}Attr, {1}, "{0}",
result.attributes)) {{
if (parser.parseCustomAttributeWithFallback({0}Attr, {1})) {{
return ::mlir::failure();
}
)";
@@ -400,30 +414,29 @@ const char *const attrParserCode = R"(
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const genericAttrParserCode = R"(
if (parser.parseAttribute({0}Attr, {1}, "{0}", result.attributes))
if (parser.parseAttribute({0}Attr, {1}))
return ::mlir::failure();
)";
const char *const optionalAttrParserCode = R"(
{
::mlir::OptionalParseResult parseResult =
parser.parseOptionalAttribute({0}Attr, {1}, "{0}", result.attributes);
if (parseResult.has_value() && failed(*parseResult))
return ::mlir::failure();
}
::mlir::OptionalParseResult parseResult{0}Attr =
parser.parseOptionalAttribute({0}Attr, {1});
if (parseResult{0}Attr.has_value() && failed(*parseResult{0}Attr))
return ::mlir::failure();
if (parseResult{0}Attr.has_value() && succeeded(*parseResult{0}Attr))
)";
/// The code snippet used to generate a parser call for a symbol name attribute.
///
/// {0}: The name of the attribute.
const char *const symbolNameAttrParserCode = R"(
if (parser.parseSymbolName({0}Attr, "{0}", result.attributes))
if (parser.parseSymbolName({0}Attr))
return ::mlir::failure();
)";
const char *const optionalSymbolNameAttrParserCode = R"(
// Parsing an optional symbol name doesn't fail, so no need to check the
// result.
(void)parser.parseOptionalSymbolName({0}Attr, "{0}", result.attributes);
(void)parser.parseOptionalSymbolName({0}Attr);
)";
/// The code snippet used to generate a parser call for an enum attribute.
@@ -434,6 +447,7 @@ const char *const optionalSymbolNameAttrParserCode = R"(
/// {3}: The constant builder call to create an attribute of the enum type.
/// {4}: The set of allowed enum keywords.
/// {5}: The error message on failure when the enum isn't present.
/// {6}: The attribute assignment expression
const char *const enumAttrParserCode = R"(
{
::llvm::StringRef attrStr;
@@ -460,7 +474,7 @@ const char *const enumAttrParserCode = R"(
<< "{0} attribute specification: \"" << attrStr << '"';;
{0}Attr = {3};
result.addAttribute("{0}", {0}Attr);
{6}
}
}
)";
@@ -572,6 +586,7 @@ const char *const inferReturnTypesParserCode = R"(
if (::mlir::failed({0}::inferReturnTypes(parser.getContext(),
result.location, result.operands,
result.attributes.getDictionary(parser.getContext()),
result.getRawProperties(),
result.regions, inferredReturnTypes)))
return ::mlir::failure();
result.addTypes(inferredReturnTypes);
@@ -930,7 +945,9 @@ static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
}
/// Generate the parser for a custom directive.
static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) {
static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
bool useProperties,
StringRef opCppClassName) {
body << " {\n";
// Preprocess the directive variables.
@@ -1003,9 +1020,15 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) {
const NamedAttribute *var = attr->getVar();
if (var->attr.isOptional() || var->attr.hasDefaultValue())
body << llvm::formatv(" if ({0}Attr)\n ", var->name);
if (useProperties) {
body << formatv(
" result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;\n",
var->name, opCppClassName);
} else {
body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",
var->name);
}
body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",
var->name);
} else if (auto *operand = dyn_cast<OperandVariable>(param)) {
const NamedTypeConstraint *var = operand->getVar();
if (var->isOptional()) {
@@ -1041,7 +1064,8 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) {
/// Generate the parser for a enum attribute.
static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
FmtContext &attrTypeCtx, bool parseAsOptional) {
FmtContext &attrTypeCtx, bool parseAsOptional,
bool useProperties, StringRef opCppClassName) {
Attribute baseAttr = var->attr.getBaseAttr();
const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
@@ -1076,46 +1100,68 @@ static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
});
errorMessageOS << "]\");";
}
std::string attrAssignment;
if (useProperties) {
attrAssignment =
formatv(" "
"result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;",
var->name, opCppClassName);
} else {
attrAssignment =
formatv("result.addAttribute(\"{0}\", {0}Attr);", var->name);
}
body << formatv(enumAttrParserCode, var->name, enumAttr.getCppNamespace(),
enumAttr.getStringToSymbolFnName(), attrBuilderStr,
validCaseKeywordsStr, errorMessage);
validCaseKeywordsStr, errorMessage, attrAssignment);
}
// Generate the parser for an attribute.
static void genAttrParser(AttributeVariable *attr, MethodBody &body,
FmtContext &attrTypeCtx, bool parseAsOptional) {
FmtContext &attrTypeCtx, bool parseAsOptional,
bool useProperties, StringRef opCppClassName) {
const NamedAttribute *var = attr->getVar();
// Check to see if we can parse this as an enum attribute.
if (canFormatEnumAttr(var))
return genEnumAttrParser(var, body, attrTypeCtx, parseAsOptional);
return genEnumAttrParser(var, body, attrTypeCtx, parseAsOptional,
useProperties, opCppClassName);
// Check to see if we should parse this as a symbol name attribute.
if (shouldFormatSymbolNameAttr(var)) {
body << formatv(parseAsOptional ? optionalSymbolNameAttrParserCode
: symbolNameAttrParserCode,
var->name);
return;
}
} else {
// If this attribute has a buildable type, use that when parsing the
// attribute.
std::string attrTypeStr;
if (std::optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
llvm::raw_string_ostream os(attrTypeStr);
os << tgfmt(*typeBuilder, &attrTypeCtx);
} else {
attrTypeStr = "::mlir::Type{}";
// If this attribute has a buildable type, use that when parsing the
// attribute.
std::string attrTypeStr;
if (std::optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
llvm::raw_string_ostream os(attrTypeStr);
os << tgfmt(*typeBuilder, &attrTypeCtx);
} else {
attrTypeStr = "::mlir::Type{}";
}
if (parseAsOptional) {
body << formatv(optionalAttrParserCode, var->name, attrTypeStr);
} else {
if (attr->shouldBeQualified() ||
var->attr.getStorageType() == "::mlir::Attribute")
body << formatv(genericAttrParserCode, var->name, attrTypeStr);
else
body << formatv(attrParserCode, var->name, attrTypeStr);
}
}
if (parseAsOptional) {
body << formatv(optionalAttrParserCode, var->name, attrTypeStr);
if (useProperties) {
body << formatv(
" if ({0}Attr) result.getOrAddProperties<{1}::Properties>().{0} = "
"{0}Attr;\n",
var->name, opCppClassName);
} else {
if (attr->shouldBeQualified() ||
var->attr.getStorageType() == "::mlir::Attribute")
body << formatv(genericAttrParserCode, var->name, attrTypeStr);
else
body << formatv(attrParserCode, var->name, attrTypeStr);
body << formatv(
" if ({0}Attr) result.attributes.append(\"{0}\", {0}Attr);\n",
var->name);
}
}
@@ -1170,8 +1216,15 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
if (!thenGroup == optional->isInverted()) {
// Add the anchor unit attribute to the operation state.
body << " result.addAttribute(\"" << anchorAttr->getVar()->name
<< "\", parser.getBuilder().getUnitAttr());\n";
if (useProperties) {
body << formatv(
" result.getOrAddProperties<{1}::Properties>().{0} = "
"parser.getBuilder().getUnitAttr();",
anchorAttr->getVar()->name, opCppClassName);
} else {
body << " result.addAttribute(\"" << anchorAttr->getVar()->name
<< "\", parser.getBuilder().getUnitAttr());\n";
}
}
}
@@ -1190,7 +1243,8 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
// parsing of the rest of the elements.
FormatElement *firstElement = thenElements.front();
if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) {
genAttrParser(attrVar, body, attrTypeCtx, /*parseAsOptional=*/true);
genAttrParser(attrVar, body, attrTypeCtx, /*parseAsOptional=*/true,
useProperties, opCppClassName);
body << " if (" << attrVar->getVar()->name << "Attr) {\n";
} else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
body << " if (::mlir::succeeded(parser.parseOptional";
@@ -1248,8 +1302,15 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
body << formatv(oilistParserCode, lelementName);
if (AttributeVariable *unitAttrElem =
oilist->getUnitAttrParsingElement(pelement)) {
body << " result.addAttribute(\"" << unitAttrElem->getVar()->name
<< "\", UnitAttr::get(parser.getContext()));\n";
if (useProperties) {
body << formatv(
" result.getOrAddProperties<{1}::Properties>().{0} = "
"parser.getBuilder().getUnitAttr();",
unitAttrElem->getVar()->name, opCppClassName);
} else {
body << " result.addAttribute(\"" << unitAttrElem->getVar()->name
<< "\", UnitAttr::get(parser.getContext()));\n";
}
} else {
for (FormatElement *el : pelement)
genElementParser(el, body, attrTypeCtx);
@@ -1275,7 +1336,8 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
} else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
bool parseAsOptional =
(genCtx == GenContext::Normal && attr->getVar()->attr.isOptional());
genAttrParser(attr, body, attrTypeCtx, parseAsOptional);
genAttrParser(attr, body, attrTypeCtx, parseAsOptional, useProperties,
opCppClassName);
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
@@ -1311,13 +1373,27 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
/// Directives.
} else if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
body << " if (parser.parseOptionalAttrDict"
<< (attrDict->isWithKeyword() ? "WithKeyword" : "")
<< "(result.attributes))\n"
body.indent() << "{\n";
body.indent() << "auto loc = parser.getCurrentLocation();(void)loc;\n"
<< "if (parser.parseOptionalAttrDict"
<< (attrDict->isWithKeyword() ? "WithKeyword" : "")
<< "(result.attributes))\n"
<< " return ::mlir::failure();\n";
if (useProperties) {
body << "if (failed(verifyInherentAttrs(result.name, result.attributes, "
"[&]() {\n"
<< " return parser.emitError(loc) << \"'\" << "
"result.name.getStringRef() << \"' op \";\n"
<< " })))\n"
<< " return ::mlir::failure();\n";
}
body.unindent() << "}\n";
body.unindent();
} else if (auto *attrDict = dyn_cast<PropDictDirective>(element)) {
body << " if (parseProperties(parser, result))\n"
<< " return ::mlir::failure();\n";
} else if (auto *customDir = dyn_cast<CustomDirective>(element)) {
genCustomDirectiveParser(customDir, body);
genCustomDirectiveParser(customDir, body, useProperties, opCppClassName);
} else if (isa<OperandsDirective>(element)) {
body << " ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
<< " if (parser.parseOperandList(allOperands))\n"
@@ -1571,8 +1647,16 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
MethodBody &body) {
if (!allOperands) {
if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
body << " result.addAttribute(\"operand_segment_sizes\", "
<< "parser.getBuilder().getDenseI32ArrayAttr({";
if (op.getDialect().usePropertiesForAttributes()) {
body << formatv(" "
"result.getOrAddProperties<{0}::Properties>().operand_"
"segment_sizes = "
"(parser.getBuilder().getDenseI32ArrayAttr({{",
op.getCppClassName());
} else {
body << " result.addAttribute(\"operand_segment_sizes\", "
<< "parser.getBuilder().getDenseI32ArrayAttr({";
}
auto interleaveFn = [&](const NamedTypeConstraint &operand) {
// If the operand is variadic emit the parsed size.
if (operand.isVariableLength())
@@ -1586,18 +1670,36 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
for (const NamedTypeConstraint &operand : op.getOperands()) {
if (!operand.isVariadicOfVariadic())
continue;
body << llvm::formatv(
" result.addAttribute(\"{0}\", "
"parser.getBuilder().getDenseI32ArrayAttr({1}OperandGroupSizes));\n",
operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
operand.name);
if (op.getDialect().usePropertiesForAttributes()) {
body << llvm::formatv(
" result.getOrAddProperties<{0}::Properties>().{1} = "
"parser.getBuilder().getDenseI32ArrayAttr({2}OperandGroupSizes);\n",
op.getCppClassName(),
operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
operand.name);
} else {
body << llvm::formatv(
" result.addAttribute(\"{0}\", "
"parser.getBuilder().getDenseI32ArrayAttr({1}OperandGroupSizes));"
"\n",
operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
operand.name);
}
}
}
if (!allResultTypes &&
op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
body << " result.addAttribute(\"result_segment_sizes\", "
<< "parser.getBuilder().getDenseI32ArrayAttr({";
if (op.getDialect().usePropertiesForAttributes()) {
body << formatv(
" "
"result.getOrAddProperties<{0}::Properties>().result_segment_sizes = "
"(parser.getBuilder().getDenseI32ArrayAttr({{",
op.getCppClassName());
} else {
body << " result.addAttribute(\"result_segment_sizes\", "
<< "parser.getBuilder().getDenseI32ArrayAttr({";
}
auto interleaveFn = [&](const NamedTypeConstraint &result) {
// If the result is variadic emit the parsed size.
if (result.isVariableLength())
@@ -1641,6 +1743,14 @@ const char *enumAttrBeginPrinterCode = R"(
auto caseValueStr = {1}(caseValue);
)";
/// Generate the printer for the 'prop-dict' directive.
static void genPropDictPrinter(OperationFormat &fmt, Operator &op,
MethodBody &body) {
body << " _odsPrinter << \" \";\n"
<< " printProperties(this->getContext(), _odsPrinter, "
"getProperties());\n";
}
/// Generate the printer for the 'attr-dict' directive.
static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
MethodBody &body, bool withKeyword) {
@@ -1898,7 +2008,7 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
})
.Case<AttributeVariable>([&](AttributeVariable *element) {
Attribute attr = element->getVar()->attr;
body << "(*this)->getAttr(\"" << element->getVar()->name << "\")";
body << op.getGetterName(element->getVar()->name) << "Attr()";
if (attr.isOptional())
return; // done
if (attr.hasDefaultValue()) {
@@ -1906,7 +2016,8 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
// default value.
FmtContext fctx;
fctx.withBuilder("::mlir::OpBuilder((*this)->getContext())");
body << " != "
body << " && " << op.getGetterName(element->getVar()->name)
<< "Attr() != "
<< tgfmt(attr.getConstBuilderTemplate(), &fctx,
attr.getDefaultValue());
return;
@@ -2063,6 +2174,13 @@ void OperationFormat::genElementPrinter(FormatElement *element,
return;
}
// Emit the attribute dictionary.
if (auto *propDict = dyn_cast<PropDictDirective>(element)) {
genPropDictPrinter(*this, op, body);
lastWasPunctuation = false;
return;
}
// Optionally insert a space before the next element. The AttrDict printer
// already adds a space as necessary.
if (shouldEmitSpace || !lastWasPunctuation)
@@ -2300,6 +2418,7 @@ private:
ConstArgument findSeenArg(StringRef name);
/// Parse the various different directives.
FailureOr<FormatElement *> parsePropDictDirective(SMLoc loc, Context context);
FailureOr<FormatElement *> parseAttrDictDirective(SMLoc loc, Context context,
bool withKeyword);
FailureOr<FormatElement *> parseFunctionalTypeDirective(SMLoc loc,
@@ -2329,6 +2448,7 @@ private:
// The following are various bits of format state used for verification
// during parsing.
bool hasAttrDict = false;
bool hasPropDict = false;
bool hasAllRegions = false, hasAllSuccessors = false;
bool canInferResultTypes = false;
llvm::SmallBitVector seenOperandTypes, seenResultTypes;
@@ -2873,6 +2993,8 @@ FailureOr<FormatElement *>
OpFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
Context ctx) {
switch (kind) {
case FormatToken::kw_prop_dict:
return parsePropDictDirective(loc, ctx);
case FormatToken::kw_attr_dict:
return parseAttrDictDirective(loc, ctx,
/*withKeyword=*/false);
@@ -2925,6 +3047,23 @@ OpFormatParser::parseAttrDictDirective(SMLoc loc, Context context,
return create<AttrDictDirective>(withKeyword);
}
FailureOr<FormatElement *>
OpFormatParser::parsePropDictDirective(SMLoc loc, Context context) {
if (context == TypeDirectiveContext)
return emitError(loc, "'prop-dict' directive can only be used as a "
"top-level directive");
if (context == RefDirectiveContext)
llvm::report_fatal_error("'ref' of 'prop-dict' unsupported");
// Otherwise, this is a top-level context.
if (hasPropDict)
return emitError(loc, "'prop-dict' directive has already been seen");
hasPropDict = true;
return create<PropDictDirective>();
}
LogicalResult OpFormatParser::verifyCustomDirectiveArguments(
SMLoc loc, ArrayRef<FormatElement *> arguments) {
for (FormatElement *argument : arguments) {

View File

@@ -25,7 +25,7 @@ static Operation *createOp(MLIRContext *context, Location loc,
context->allowUnregisteredDialects();
return Operation::create(loc, OperationName(operationName, context),
std::nullopt, std::nullopt, std::nullopt,
std::nullopt, numRegions);
OpaqueProperties(nullptr), std::nullopt, numRegions);
}
namespace {

View File

@@ -38,10 +38,10 @@ TEST(Adaptor, GenericAdaptorsOperandAccess) {
// Using optional instead of plain int here to differentiate absence of
// value from the value 0.
SmallVector<std::optional<int>> v = {0, 4};
OIListSimple::GenericAdaptor<ArrayRef<std::optional<int>>> d(
v, builder.getDictionaryAttr({builder.getNamedAttr(
"operand_segment_sizes",
builder.getDenseI32ArrayAttr({1, 0, 1}))}));
OIListSimple::Properties prop;
prop.operand_segment_sizes = builder.getDenseI32ArrayAttr({1, 0, 1});
OIListSimple::GenericAdaptor<ArrayRef<std::optional<int>>> d(v, {}, prop,
{});
EXPECT_EQ(d.getArg0(), 0);
EXPECT_EQ(d.getArg1(), std::nullopt);
EXPECT_EQ(d.getArg2(), 4);
@@ -51,9 +51,10 @@ TEST(Adaptor, GenericAdaptorsOperandAccess) {
FormatVariadicOfVariadicOperand::FoldAdaptor e({});
{
SmallVector<int> v = {0, 1, 2, 3, 4};
FormatVariadicOfVariadicOperand::GenericAdaptor<ArrayRef<int>> f(
v, builder.getDictionaryAttr({builder.getNamedAttr(
"operand_segments", builder.getDenseI32ArrayAttr({3, 2, 0}))}));
FormatVariadicOfVariadicOperand::Properties prop;
prop.operand_segments = builder.getDenseI32ArrayAttr({3, 2, 0});
FormatVariadicOfVariadicOperand::GenericAdaptor<ArrayRef<int>> f(v, {},
prop, {});
SmallVector<ArrayRef<int>> operand = f.getOperand();
ASSERT_EQ(operand.size(), (std::size_t)3);
EXPECT_THAT(operand[0], ElementsAre(0, 1, 2));

View File

@@ -9,6 +9,7 @@ add_mlir_unittest(MLIRIRTests
PatternMatchTest.cpp
ShapedTypeTest.cpp
TypeTest.cpp
OpPropertiesTest.cpp
DEPENDS
MLIRTestInterfaceIncGen

View File

@@ -0,0 +1,358 @@
//===- TestOpProperties.cpp - Test all properties-related APIs ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/OpDefinition.h"
#include "mlir/Parser/Parser.h"
#include "gtest/gtest.h"
#include <optional>
using namespace mlir;
namespace {
/// Simple structure definining a struct to define "properties" for a given
/// operation. Default values are honored when creating an operation.
struct TestProperties {
int a = -1;
float b = -1.;
std::vector<int64_t> array = {-33};
/// A shared_ptr to a const object is safe: it is equivalent to a value-based
/// member. Here the label will be deallocated when the last operation
/// referring to it is destroyed. However there is no pool-allocation: this is
/// offloaded to the client.
std::shared_ptr<const std::string> label;
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestProperties)
};
/// Convert a DictionaryAttr to a TestProperties struct, optionally emit errors
/// through the provided diagnostic if any. This is used for example during
/// parsing with the generic format.
static LogicalResult
setPropertiesFromAttribute(TestProperties &prop, Attribute attr,
InFlightDiagnostic *diagnostic) {
DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr);
if (!dict) {
if (diagnostic)
*diagnostic << "expected DictionaryAttr to set TestProperties";
return failure();
}
auto aAttr = dict.getAs<IntegerAttr>("a");
if (!aAttr) {
if (diagnostic)
*diagnostic << "expected IntegerAttr for key `a`";
return failure();
}
auto bAttr = dict.getAs<FloatAttr>("b");
if (!bAttr ||
&bAttr.getValue().getSemantics() != &llvm::APFloatBase::IEEEsingle()) {
if (diagnostic)
*diagnostic << "expected FloatAttr for key `b`";
return failure();
}
auto arrayAttr = dict.getAs<DenseI64ArrayAttr>("array");
if (!arrayAttr) {
if (diagnostic)
*diagnostic << "expected DenseI64ArrayAttr for key `array`";
return failure();
}
auto label = dict.getAs<mlir::StringAttr>("label");
if (!label) {
if (diagnostic)
*diagnostic << "expected StringAttr for key `label`";
return failure();
}
prop.a = aAttr.getValue().getSExtValue();
prop.b = bAttr.getValue().convertToFloat();
prop.array.assign(arrayAttr.asArrayRef().begin(),
arrayAttr.asArrayRef().end());
prop.label = std::make_shared<std::string>(label.getValue());
return success();
}
/// Convert a TestProperties struct to a DictionaryAttr, this is used for
/// example during printing with the generic format.
static Attribute getPropertiesAsAttribute(MLIRContext *ctx,
const TestProperties &prop) {
SmallVector<NamedAttribute> attrs;
Builder b{ctx};
attrs.push_back(b.getNamedAttr("a", b.getI32IntegerAttr(prop.a)));
attrs.push_back(b.getNamedAttr("b", b.getF32FloatAttr(prop.b)));
attrs.push_back(b.getNamedAttr("array", b.getDenseI64ArrayAttr(prop.array)));
attrs.push_back(b.getNamedAttr(
"label", b.getStringAttr(prop.label ? *prop.label : "<nullptr>")));
return b.getDictionaryAttr(attrs);
}
inline llvm::hash_code computeHash(const TestProperties &prop) {
// We hash `b` which is a float using its underlying array of char:
unsigned char const *p = reinterpret_cast<unsigned char const *>(&prop.b);
ArrayRef<unsigned char> bBytes{p, sizeof(prop.b)};
return llvm::hash_combine(
prop.a, llvm::hash_combine_range(bBytes.begin(), bBytes.end()),
llvm::hash_combine_range(prop.array.begin(), prop.array.end()),
StringRef(*prop.label));
}
/// A custom operation for the purpose of showcasing how to use "properties".
class OpWithProperties : public Op<OpWithProperties> {
public:
// Begin boilerplate
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpWithProperties)
using Op::Op;
static ArrayRef<StringRef> getAttributeNames() { return {}; }
static StringRef getOperationName() {
return "test_op_properties.op_with_properties";
}
// End boilerplate
// This alias is the only definition needed for enabling "properties" for this
// operation.
using Properties = TestProperties;
static std::optional<mlir::Attribute> getInherentAttr(const Properties &prop,
StringRef name) {
return std::nullopt;
}
static void setInherentAttr(Properties &prop, StringRef name,
mlir::Attribute value) {}
static void populateInherentAttrs(const Properties &prop,
NamedAttrList &attrs) {}
static LogicalResult
verifyInherentAttrs(OperationName opName, NamedAttrList &attrs,
function_ref<InFlightDiagnostic()> getDiag) {
return success();
}
};
// A trivial supporting dialect to register the above operation.
class TestOpPropertiesDialect : public Dialect {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOpPropertiesDialect)
static constexpr StringLiteral getDialectNamespace() {
return StringLiteral("test_op_properties");
}
explicit TestOpPropertiesDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context,
TypeID::get<TestOpPropertiesDialect>()) {
addOperations<OpWithProperties>();
}
};
constexpr StringLiteral mlirSrc = R"mlir(
"test_op_properties.op_with_properties"()
<{a = -42 : i32,
b = -4.200000e+01 : f32,
array = array<i64: 40, 41>,
label = "bar foo"}> : () -> ()
)mlir";
TEST(OpPropertiesTest, Properties) {
MLIRContext context;
context.getOrLoadDialect<TestOpPropertiesDialect>();
ParserConfig config(&context);
// Parse the operation with some properties.
OwningOpRef<Operation *> op = parseSourceString(mlirSrc, config);
ASSERT_TRUE(op.get() != nullptr);
auto opWithProp = dyn_cast<OpWithProperties>(op.get());
ASSERT_TRUE(opWithProp);
{
std::string output;
llvm::raw_string_ostream os(output);
opWithProp.print(os);
ASSERT_STREQ("\"test_op_properties.op_with_properties\"() "
"<{a = -42 : i32, "
"array = array<i64: 40, 41>, "
"b = -4.200000e+01 : f32, "
"label = \"bar foo\"}> : () -> ()\n",
os.str().c_str());
}
// Get a mutable reference to the properties for this operation and modify it
// in place one member at a time.
TestProperties &prop = opWithProp.getProperties();
prop.a = 42;
{
std::string output;
llvm::raw_string_ostream os(output);
opWithProp.print(os);
EXPECT_TRUE(StringRef(os.str()).contains("a = 42"));
EXPECT_TRUE(StringRef(os.str()).contains("b = -4.200000e+01"));
EXPECT_TRUE(StringRef(os.str()).contains("array = array<i64: 40, 41>"));
EXPECT_TRUE(StringRef(os.str()).contains("label = \"bar foo\""));
}
prop.b = 42.;
{
std::string output;
llvm::raw_string_ostream os(output);
opWithProp.print(os);
EXPECT_TRUE(StringRef(os.str()).contains("a = 42"));
EXPECT_TRUE(StringRef(os.str()).contains("b = 4.200000e+01"));
EXPECT_TRUE(StringRef(os.str()).contains("array = array<i64: 40, 41>"));
EXPECT_TRUE(StringRef(os.str()).contains("label = \"bar foo\""));
}
prop.array.push_back(42);
{
std::string output;
llvm::raw_string_ostream os(output);
opWithProp.print(os);
EXPECT_TRUE(StringRef(os.str()).contains("a = 42"));
EXPECT_TRUE(StringRef(os.str()).contains("b = 4.200000e+01"));
EXPECT_TRUE(StringRef(os.str()).contains("array = array<i64: 40, 41, 42>"));
EXPECT_TRUE(StringRef(os.str()).contains("label = \"bar foo\""));
}
prop.label = std::make_shared<std::string>("foo bar");
{
std::string output;
llvm::raw_string_ostream os(output);
opWithProp.print(os);
EXPECT_TRUE(StringRef(os.str()).contains("a = 42"));
EXPECT_TRUE(StringRef(os.str()).contains("b = 4.200000e+01"));
EXPECT_TRUE(StringRef(os.str()).contains("array = array<i64: 40, 41, 42>"));
EXPECT_TRUE(StringRef(os.str()).contains("label = \"foo bar\""));
}
}
// Test diagnostic emission when using invalid dictionary.
TEST(OpPropertiesTest, FailedProperties) {
MLIRContext context;
context.getOrLoadDialect<TestOpPropertiesDialect>();
std::string diagnosticStr;
context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
diagnosticStr += diag.str();
return success();
});
// Parse the operation with some properties.
ParserConfig config(&context);
// Parse an operation with invalid (incomplete) properties.
OwningOpRef<Operation *> owningOp =
parseSourceString("\"test_op_properties.op_with_properties\"() "
"<{a = -42 : i32}> : () -> ()\n",
config);
ASSERT_EQ(owningOp.get(), nullptr);
EXPECT_STREQ(
"invalid properties {a = -42 : i32} for op "
"test_op_properties.op_with_properties: expected FloatAttr for key `b`",
diagnosticStr.c_str());
diagnosticStr.clear();
owningOp = parseSourceString(mlirSrc, config);
Operation *op = owningOp.get();
ASSERT_TRUE(op != nullptr);
Location loc = op->getLoc();
auto opWithProp = dyn_cast<OpWithProperties>(op);
ASSERT_TRUE(opWithProp);
OperationState state(loc, op->getName());
Builder b{&context};
NamedAttrList attrs;
attrs.push_back(b.getNamedAttr("a", b.getStringAttr("foo")));
state.propertiesAttr = attrs.getDictionary(&context);
{
auto diag = op->emitError("setting properties failed: ");
auto result = state.setProperties(op, &diag);
EXPECT_TRUE(result.failed());
}
EXPECT_STREQ("setting properties failed: expected IntegerAttr for key `a`",
diagnosticStr.c_str());
}
TEST(OpPropertiesTest, DefaultValues) {
MLIRContext context;
context.getOrLoadDialect<TestOpPropertiesDialect>();
OperationState state(UnknownLoc::get(&context),
"test_op_properties.op_with_properties");
Operation *op = Operation::create(state);
ASSERT_TRUE(op != nullptr);
{
std::string output;
llvm::raw_string_ostream os(output);
op->print(os);
EXPECT_TRUE(StringRef(os.str()).contains("a = -1"));
EXPECT_TRUE(StringRef(os.str()).contains("b = -1"));
EXPECT_TRUE(StringRef(os.str()).contains("array = array<i64: -33>"));
}
op->erase();
}
TEST(OpPropertiesTest, Cloning) {
MLIRContext context;
context.getOrLoadDialect<TestOpPropertiesDialect>();
ParserConfig config(&context);
// Parse the operation with some properties.
OwningOpRef<Operation *> op = parseSourceString(mlirSrc, config);
ASSERT_TRUE(op.get() != nullptr);
auto opWithProp = dyn_cast<OpWithProperties>(op.get());
ASSERT_TRUE(opWithProp);
Operation *clone = opWithProp->clone();
// Check that op and its clone prints equally
std::string opStr;
std::string cloneStr;
{
llvm::raw_string_ostream os(opStr);
op.get()->print(os);
}
{
llvm::raw_string_ostream os(cloneStr);
clone->print(os);
}
clone->erase();
EXPECT_STREQ(opStr.c_str(), cloneStr.c_str());
}
TEST(OpPropertiesTest, Equivalence) {
MLIRContext context;
context.getOrLoadDialect<TestOpPropertiesDialect>();
ParserConfig config(&context);
// Parse the operation with some properties.
OwningOpRef<Operation *> op = parseSourceString(mlirSrc, config);
ASSERT_TRUE(op.get() != nullptr);
auto opWithProp = dyn_cast<OpWithProperties>(op.get());
ASSERT_TRUE(opWithProp);
llvm::hash_code reference = OperationEquivalence::computeHash(opWithProp);
TestProperties &prop = opWithProp.getProperties();
prop.a = 42;
EXPECT_NE(reference, OperationEquivalence::computeHash(opWithProp));
prop.a = -42;
EXPECT_EQ(reference, OperationEquivalence::computeHash(opWithProp));
prop.b = 42.;
EXPECT_NE(reference, OperationEquivalence::computeHash(opWithProp));
prop.b = -42.;
EXPECT_EQ(reference, OperationEquivalence::computeHash(opWithProp));
prop.array.push_back(42);
EXPECT_NE(reference, OperationEquivalence::computeHash(opWithProp));
prop.array.pop_back();
EXPECT_EQ(reference, OperationEquivalence::computeHash(opWithProp));
}
TEST(OpPropertiesTest, getOrAddProperties) {
MLIRContext context;
context.getOrLoadDialect<TestOpPropertiesDialect>();
OperationState state(UnknownLoc::get(&context),
"test_op_properties.op_with_properties");
// Test `getOrAddProperties` API on OperationState.
TestProperties &prop = state.getOrAddProperties<TestProperties>();
prop.a = 1;
prop.b = 2;
prop.array = {3, 4, 5};
Operation *op = Operation::create(state);
ASSERT_TRUE(op != nullptr);
{
std::string output;
llvm::raw_string_ostream os(output);
op->print(os);
EXPECT_TRUE(StringRef(os.str()).contains("a = 1"));
EXPECT_TRUE(StringRef(os.str()).contains("b = 2"));
EXPECT_TRUE(StringRef(os.str()).contains("array = array<i64: 3, 4, 5>"));
}
op->erase();
}
} // namespace

View File

@@ -22,9 +22,9 @@ static Operation *createOp(MLIRContext *context,
ArrayRef<Type> resultTypes = std::nullopt,
unsigned int numRegions = 0) {
context->allowUnregisteredDialects();
return Operation::create(UnknownLoc::get(context),
OperationName("foo.bar", context), resultTypes,
operands, std::nullopt, std::nullopt, numRegions);
return Operation::create(
UnknownLoc::get(context), OperationName("foo.bar", context), resultTypes,
operands, std::nullopt, nullptr, std::nullopt, numRegions);
}
namespace {

View File

@@ -13,9 +13,9 @@ using namespace mlir;
static Operation *createOp(MLIRContext *context) {
context->allowUnregisteredDialects();
return Operation::create(UnknownLoc::get(context),
OperationName("foo.bar", context), std::nullopt,
std::nullopt, std::nullopt, std::nullopt, 0);
return Operation::create(
UnknownLoc::get(context), OperationName("foo.bar", context), std::nullopt,
std::nullopt, std::nullopt, /*properties=*/nullptr, std::nullopt, 0);
}
namespace {