[mlir] Allow using non-attribute properties in declarative rewrite patterns (#143071)
This commit adds support for non-attribute properties (such as StringProp and I64Prop) in declarative rewrite patterns. The handling for properties follows the handling for attributes in most cases, including in the generation of static matchers. Constraints that are shared between multiple types are supported by making the constraint matcher a templated function, which is the equivalent to passing ::mlir::Attribute for an arbitrary C++ type.
This commit is contained in:
committed by
GitHub
parent
f4df9f1c6e
commit
5ce5ed4b85
@@ -380,6 +380,11 @@ template. The string can be an arbitrary C++ expression that evaluates into some
|
||||
C++ object expected at the `NativeCodeCall` site (here it would be expecting an
|
||||
array attribute). Typically the string should be a function call.
|
||||
|
||||
In the case of properties, the return value of the `NativeCodeCall` should
|
||||
be in terms of the _interface_ type of a property. For example, the `NativeCodeCall`
|
||||
for a `StringProp` should return a `StringRef`, which will copied into the underlying
|
||||
`std::string`, just as if it were an argument to the operation's builder.
|
||||
|
||||
##### `NativeCodeCall` placeholders
|
||||
|
||||
In `NativeCodeCall`, we can use placeholders like `$_builder`, `$N` and `$N...`.
|
||||
@@ -416,14 +421,20 @@ must be either passed by reference or pointer to the variable used as argument
|
||||
so that the matched value can be returned. In the same example, `$val` will be
|
||||
bound to a variable with `Attribute` type (as `I32Attr`) and the type of the
|
||||
second argument in `Foo()` could be `Attribute&` or `Attribute*`. Names with
|
||||
attribute constraints will be captured as `Attribute`s while everything else
|
||||
will be treated as `Value`s.
|
||||
attribute constraints will be captured as `Attribute`s, names with
|
||||
property constraints (which must have a concrete interface type) will be treated
|
||||
as that type, and everything else will be treated as `Value`s.
|
||||
|
||||
Positional placeholders will be substituted by the `dag` object parameters at
|
||||
the `NativeCodeCall` use site. For example, if we define `SomeCall :
|
||||
NativeCodeCall<"someFn($1, $2, $0)">` and use it like `(SomeCall $in0, $in1,
|
||||
$in2)`, then this will be translated into C++ call `someFn($in1, $in2, $in0)`.
|
||||
|
||||
In the case of properties, the placeholder will be bound to a value of the _interface_
|
||||
type of the property. For example, passing in a `StringProp` as an argument to a `NativeCodeCall` will pass a `StringRef` (as if the getter of the matched
|
||||
operation were called) and not a `std::string`. See
|
||||
`mlir/include/mlir/IR/Properties.td` for details on interface vs. storage type.
|
||||
|
||||
Positional range placeholders will be substituted by multiple `dag` object
|
||||
parameters at the `NativeCodeCall` use site. For example, if we define
|
||||
`SomeCall : NativeCodeCall<"someFn($1...)">` and use it like `(SomeCall $in0,
|
||||
|
||||
@@ -401,6 +401,21 @@ class ConfinedProperty<Property p, Pred pred, string newSummary = "">
|
||||
: ConfinedProp<p, pred, newSummary>,
|
||||
Deprecated<"moved to shorter name ConfinedProp">;
|
||||
|
||||
/// Defines a constant value of type `prop` to be used in pattern matching.
|
||||
/// When used as a constraint, forms a matcher that tests that the property is
|
||||
/// equal to the given value (and matches any other constraints on the property).
|
||||
/// The constant value is given as a string and should be of the _interface_ type
|
||||
/// of the attribute.
|
||||
///
|
||||
/// This requires that the given property's inference type be comparable to the
|
||||
/// given value with `==`, and does require specify a concrete property type.
|
||||
class ConstantProp<Property prop, string val>
|
||||
: ConfinedProp<prop,
|
||||
CPred<"$_self == " # val>,
|
||||
"constant '" # prop.summary # "': " # val> {
|
||||
string value = val;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Primitive property combinators
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -73,12 +73,23 @@ public:
|
||||
// specifies an attribute constraint.
|
||||
bool isAttrMatcher() const;
|
||||
|
||||
// Returns true if this DAG leaf is matching a property. That is, it
|
||||
// specifies a property constraint.
|
||||
bool isPropMatcher() const;
|
||||
|
||||
// Returns true if this DAG leaf is describing a property. That is, it
|
||||
// is a subclass of `Property` in tablegen.
|
||||
bool isPropDefinition() const;
|
||||
|
||||
// Returns true if this DAG leaf is wrapping native code call.
|
||||
bool isNativeCodeCall() const;
|
||||
|
||||
// Returns true if this DAG leaf is specifying a constant attribute.
|
||||
bool isConstantAttr() const;
|
||||
|
||||
// Returns true if this DAG leaf is specifying a constant property.
|
||||
bool isConstantProp() const;
|
||||
|
||||
// Returns true if this DAG leaf is specifying an enum case.
|
||||
bool isEnumCase() const;
|
||||
|
||||
@@ -88,9 +99,19 @@ public:
|
||||
// Returns this DAG leaf as a constraint. Asserts if fails.
|
||||
Constraint getAsConstraint() const;
|
||||
|
||||
// Returns this DAG leaf as a property constraint. Asserts if fails. This
|
||||
// allows access to the interface type.
|
||||
PropConstraint getAsPropConstraint() const;
|
||||
|
||||
// Returns this DAG leaf as a property definition. Asserts if fails.
|
||||
Property getAsProperty() const;
|
||||
|
||||
// Returns this DAG leaf as an constant attribute. Asserts if fails.
|
||||
ConstantAttr getAsConstantAttr() const;
|
||||
|
||||
// Returns this DAG leaf as an constant property. Asserts if fails.
|
||||
ConstantProp getAsConstantProp() const;
|
||||
|
||||
// Returns this DAG leaf as an enum case.
|
||||
// Precondition: isEnumCase()
|
||||
EnumCase getAsEnumCase() const;
|
||||
@@ -279,6 +300,10 @@ public:
|
||||
// the DAG of the operation, `operandIndexOrNumValues` specifies the
|
||||
// operand index, and `variadicSubIndex` must be set to `std::nullopt`.
|
||||
//
|
||||
// * Properties not associated with an operation (e.g. as arguments to
|
||||
// native code) have their corresponding PropConstraint stored in the
|
||||
// `dag` field. This constraint is only used when
|
||||
//
|
||||
// * If a symbol is defined in a `variadic` DAG, `dag` specifies the DAG
|
||||
// of the parent operation, `operandIndexOrNumValues` specifies the
|
||||
// declared operand index of the variadic operand in the parent
|
||||
@@ -364,12 +389,20 @@ public:
|
||||
|
||||
// What kind of entity this symbol represents:
|
||||
// * Attr: op attribute
|
||||
// * Prop: op property
|
||||
// * Operand: op operand
|
||||
// * Result: op result
|
||||
// * Value: a value not attached to an op (e.g., from NativeCodeCall)
|
||||
// * MultipleValues: a pack of values not attached to an op (e.g., from
|
||||
// NativeCodeCall). This kind supports indexing.
|
||||
enum class Kind : uint8_t { Attr, Operand, Result, Value, MultipleValues };
|
||||
enum class Kind : uint8_t {
|
||||
Attr,
|
||||
Prop,
|
||||
Operand,
|
||||
Result,
|
||||
Value,
|
||||
MultipleValues
|
||||
};
|
||||
|
||||
// Creates a SymbolInfo instance. `dagAndConstant` is only used for `Attr`
|
||||
// and `Operand` so should be std::nullopt for `Result` and `Value` kind.
|
||||
@@ -384,6 +417,15 @@ public:
|
||||
static SymbolInfo getAttr() {
|
||||
return SymbolInfo(nullptr, Kind::Attr, std::nullopt);
|
||||
}
|
||||
static SymbolInfo getProp(const Operator *op, int index) {
|
||||
return SymbolInfo(op, Kind::Prop,
|
||||
DagAndConstant(nullptr, index, std::nullopt));
|
||||
}
|
||||
static SymbolInfo getProp(const PropConstraint *constraint) {
|
||||
// -1 for anthe `operandIndexOrNumValues` is a sentinel value.
|
||||
return SymbolInfo(nullptr, Kind::Prop,
|
||||
DagAndConstant(constraint, -1, std::nullopt));
|
||||
}
|
||||
static SymbolInfo
|
||||
getOperand(DagNode node, const Operator *op, int operandIndex,
|
||||
std::optional<int> variadicSubIndex = std::nullopt) {
|
||||
@@ -488,6 +530,10 @@ public:
|
||||
// is already bound.
|
||||
bool bindAttr(StringRef symbol);
|
||||
|
||||
// Registers the given `symbol` as bound to a property that satisfies the
|
||||
// given `constraint`. `constraint` must name a concrete interface type.
|
||||
bool bindProp(StringRef symbol, const PropConstraint &constraint);
|
||||
|
||||
// Returns true if the given `symbol` is bound.
|
||||
bool contains(StringRef symbol) const;
|
||||
|
||||
|
||||
@@ -32,9 +32,9 @@ class Pred;
|
||||
// Wrapper class providing helper methods for accesing property constraint
|
||||
// values.
|
||||
class PropConstraint : public Constraint {
|
||||
public:
|
||||
using Constraint::Constraint;
|
||||
|
||||
public:
|
||||
static bool classof(const Constraint *c) { return c->getKind() == CK_Prop; }
|
||||
|
||||
StringRef getInterfaceType() const;
|
||||
@@ -143,6 +143,10 @@ public:
|
||||
// property constraints, this function is added for future-proofing)
|
||||
Property getBaseProperty() const;
|
||||
|
||||
// Returns true if this property is backed by a TableGen definition and that
|
||||
// definition is a subclass of `className`.
|
||||
bool isSubClassOf(StringRef className) const;
|
||||
|
||||
private:
|
||||
// Elements describing a Property, in general fetched from the record.
|
||||
StringRef summary;
|
||||
@@ -169,6 +173,21 @@ struct NamedProperty {
|
||||
Property prop;
|
||||
};
|
||||
|
||||
// Wrapper class providing helper methods for processing constant property
|
||||
// values defined using the `ConstantProp` subclass of `Property`
|
||||
// in TableGen.
|
||||
class ConstantProp : public Property {
|
||||
public:
|
||||
explicit ConstantProp(const llvm::DefInit *def) : Property(def) {
|
||||
assert(isSubClassOf("ConstantProp"));
|
||||
}
|
||||
|
||||
static bool classof(Property *p) { return p->isSubClassOf("ConstantProp"); }
|
||||
|
||||
// Return the constant value of the property as an expression
|
||||
// that produces an interface-type constant.
|
||||
StringRef getValue() const;
|
||||
};
|
||||
} // namespace tblgen
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -205,10 +205,14 @@ static ::llvm::LogicalResult {0}(
|
||||
|
||||
/// Code for a pattern type or attribute constraint.
|
||||
///
|
||||
/// {3}: "Type type" or "Attribute attr".
|
||||
static const char *const patternAttrOrTypeConstraintCode = R"(
|
||||
/// {0}: name of function
|
||||
/// {1}: Condition template
|
||||
/// {2}: Constraint summary
|
||||
/// {3}: "::mlir::Type type" or "::mlirAttribute attr" or "propType prop".
|
||||
/// Can be "T prop" for generic property constraints.
|
||||
static const char *const patternConstraintCode = R"(
|
||||
static ::llvm::LogicalResult {0}(
|
||||
::mlir::PatternRewriter &rewriter, ::mlir::Operation *op, ::mlir::{3},
|
||||
::mlir::PatternRewriter &rewriter, ::mlir::Operation *op, {3},
|
||||
::llvm::StringRef failureStr) {
|
||||
if (!({1})) {
|
||||
return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {
|
||||
@@ -265,15 +269,31 @@ void StaticVerifierFunctionEmitter::emitPatternConstraints() {
|
||||
FmtContext ctx;
|
||||
ctx.addSubst("_op", "*op").withBuilder("rewriter").withSelf("type");
|
||||
for (auto &it : typeConstraints) {
|
||||
os << formatv(patternAttrOrTypeConstraintCode, it.second,
|
||||
os << formatv(patternConstraintCode, it.second,
|
||||
tgfmt(it.first.getConditionTemplate(), &ctx),
|
||||
escapeString(it.first.getSummary()), "Type type");
|
||||
escapeString(it.first.getSummary()), "::mlir::Type type");
|
||||
}
|
||||
ctx.withSelf("attr");
|
||||
for (auto &it : attrConstraints) {
|
||||
os << formatv(patternAttrOrTypeConstraintCode, it.second,
|
||||
os << formatv(patternConstraintCode, it.second,
|
||||
tgfmt(it.first.getConditionTemplate(), &ctx),
|
||||
escapeString(it.first.getSummary()), "Attribute attr");
|
||||
escapeString(it.first.getSummary()),
|
||||
"::mlir::Attribute attr");
|
||||
}
|
||||
ctx.withSelf("prop");
|
||||
for (auto &it : propConstraints) {
|
||||
PropConstraint propConstraint = cast<PropConstraint>(it.first);
|
||||
StringRef interfaceType = propConstraint.getInterfaceType();
|
||||
// Constraints that are generic over multiple interface types are
|
||||
// templatized under the assumption that they'll be used correctly.
|
||||
if (interfaceType.empty()) {
|
||||
interfaceType = "T";
|
||||
os << "template <typename T>";
|
||||
}
|
||||
os << formatv(patternConstraintCode, it.second,
|
||||
tgfmt(propConstraint.getConditionTemplate(), &ctx),
|
||||
escapeString(propConstraint.getSummary()),
|
||||
Twine(interfaceType) + " prop");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -367,10 +387,15 @@ void StaticVerifierFunctionEmitter::collectOpConstraints(
|
||||
void StaticVerifierFunctionEmitter::collectPatternConstraints(
|
||||
const ArrayRef<DagLeaf> constraints) {
|
||||
for (auto &leaf : constraints) {
|
||||
assert(leaf.isOperandMatcher() || leaf.isAttrMatcher());
|
||||
collectConstraint(
|
||||
leaf.isOperandMatcher() ? typeConstraints : attrConstraints,
|
||||
leaf.isOperandMatcher() ? "type" : "attr", leaf.getAsConstraint());
|
||||
assert(leaf.isOperandMatcher() || leaf.isAttrMatcher() ||
|
||||
leaf.isPropMatcher());
|
||||
Constraint constraint = leaf.getAsConstraint();
|
||||
if (leaf.isOperandMatcher())
|
||||
collectConstraint(typeConstraints, "type", constraint);
|
||||
else if (leaf.isAttrMatcher())
|
||||
collectConstraint(attrConstraints, "attr", constraint);
|
||||
else if (leaf.isPropMatcher())
|
||||
collectConstraint(propConstraints, "prop", constraint);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -51,6 +51,16 @@ bool DagLeaf::isAttrMatcher() const {
|
||||
return isSubClassOf("AttrConstraint");
|
||||
}
|
||||
|
||||
bool DagLeaf::isPropMatcher() const {
|
||||
// Property matchers specify a property constraint.
|
||||
return isSubClassOf("PropConstraint");
|
||||
}
|
||||
|
||||
bool DagLeaf::isPropDefinition() const {
|
||||
// Property matchers specify a property definition.
|
||||
return isSubClassOf("Property");
|
||||
}
|
||||
|
||||
bool DagLeaf::isNativeCodeCall() const {
|
||||
return isSubClassOf("NativeCodeCall");
|
||||
}
|
||||
@@ -59,14 +69,26 @@ bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
|
||||
|
||||
bool DagLeaf::isEnumCase() const { return isSubClassOf("EnumCase"); }
|
||||
|
||||
bool DagLeaf::isConstantProp() const { return isSubClassOf("ConstantProp"); }
|
||||
|
||||
bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
|
||||
|
||||
Constraint DagLeaf::getAsConstraint() const {
|
||||
assert((isOperandMatcher() || isAttrMatcher()) &&
|
||||
"the DAG leaf must be operand or attribute");
|
||||
assert((isOperandMatcher() || isAttrMatcher() || isPropMatcher()) &&
|
||||
"the DAG leaf must be operand, attribute, or property");
|
||||
return Constraint(cast<DefInit>(def)->getDef());
|
||||
}
|
||||
|
||||
PropConstraint DagLeaf::getAsPropConstraint() const {
|
||||
assert(isPropMatcher() && "the DAG leaf must be a property matcher");
|
||||
return PropConstraint(cast<DefInit>(def)->getDef());
|
||||
}
|
||||
|
||||
Property DagLeaf::getAsProperty() const {
|
||||
assert(isPropDefinition() && "the DAG leaf must be a property definition");
|
||||
return Property(cast<DefInit>(def)->getDef());
|
||||
}
|
||||
|
||||
ConstantAttr DagLeaf::getAsConstantAttr() const {
|
||||
assert(isConstantAttr() && "the DAG leaf must be constant attribute");
|
||||
return ConstantAttr(cast<DefInit>(def));
|
||||
@@ -77,6 +99,11 @@ EnumCase DagLeaf::getAsEnumCase() const {
|
||||
return EnumCase(cast<DefInit>(def));
|
||||
}
|
||||
|
||||
ConstantProp DagLeaf::getAsConstantProp() const {
|
||||
assert(isConstantProp() && "the DAG leaf must be a constant property value");
|
||||
return ConstantProp(cast<DefInit>(def));
|
||||
}
|
||||
|
||||
std::string DagLeaf::getConditionTemplate() const {
|
||||
return getAsConstraint().getConditionTemplate();
|
||||
}
|
||||
@@ -232,6 +259,7 @@ SymbolInfoMap::SymbolInfo::SymbolInfo(
|
||||
int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
|
||||
switch (kind) {
|
||||
case Kind::Attr:
|
||||
case Kind::Prop:
|
||||
case Kind::Operand:
|
||||
case Kind::Value:
|
||||
return 1;
|
||||
@@ -258,6 +286,18 @@ std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const {
|
||||
// TODO(suderman): Use a more exact type when available.
|
||||
return "::mlir::Attribute";
|
||||
}
|
||||
case Kind::Prop: {
|
||||
if (op)
|
||||
return cast<NamedProperty *>(op->getArg(getArgIndex()))
|
||||
->prop.getInterfaceType()
|
||||
.str();
|
||||
assert(dagAndConstant && dagAndConstant->dag &&
|
||||
"generic properties must carry their constraint");
|
||||
return reinterpret_cast<const DagLeaf *>(dagAndConstant->dag)
|
||||
->getAsPropConstraint()
|
||||
.getInterfaceType()
|
||||
.str();
|
||||
}
|
||||
case Kind::Operand: {
|
||||
// Use operand range for captured operands (to support potential variadic
|
||||
// operands).
|
||||
@@ -300,6 +340,12 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
|
||||
LLVM_DEBUG(dbgs() << repl << " (Attr)\n");
|
||||
return std::string(repl);
|
||||
}
|
||||
case Kind::Prop: {
|
||||
assert(index < 0);
|
||||
auto repl = formatv(fmt, name);
|
||||
LLVM_DEBUG(dbgs() << repl << " (Prop)\n");
|
||||
return std::string(repl);
|
||||
}
|
||||
case Kind::Operand: {
|
||||
assert(index < 0);
|
||||
auto *operand = cast<NamedTypeConstraint *>(op->getArg(getArgIndex()));
|
||||
@@ -388,10 +434,11 @@ std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
|
||||
LLVM_DEBUG(dbgs() << "getAllRangeUse for '" << name << "': ");
|
||||
switch (kind) {
|
||||
case Kind::Attr:
|
||||
case Kind::Prop:
|
||||
case Kind::Operand: {
|
||||
assert(index < 0 && "only allowed for symbol bound to result");
|
||||
auto repl = formatv(fmt, name);
|
||||
LLVM_DEBUG(dbgs() << repl << " (Operand/Attr)\n");
|
||||
LLVM_DEBUG(dbgs() << repl << " (Operand/Attr/Prop)\n");
|
||||
return std::string(repl);
|
||||
}
|
||||
case Kind::Result: {
|
||||
@@ -449,9 +496,11 @@ bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
|
||||
PrintFatalError(loc, error);
|
||||
}
|
||||
|
||||
auto symInfo =
|
||||
isa<NamedAttribute *>(op.getArg(argIndex))
|
||||
? SymbolInfo::getAttr(&op, argIndex)
|
||||
Argument arg = op.getArg(argIndex);
|
||||
SymbolInfo symInfo =
|
||||
isa<NamedAttribute *>(arg) ? SymbolInfo::getAttr(&op, argIndex)
|
||||
: isa<NamedProperty *>(arg)
|
||||
? SymbolInfo::getProp(&op, argIndex)
|
||||
: SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex);
|
||||
|
||||
std::string key = symbol.str();
|
||||
@@ -503,6 +552,13 @@ bool SymbolInfoMap::bindAttr(StringRef symbol) {
|
||||
return symbolInfoMap.count(inserted->first) == 1;
|
||||
}
|
||||
|
||||
bool SymbolInfoMap::bindProp(StringRef symbol,
|
||||
const PropConstraint &constraint) {
|
||||
auto inserted =
|
||||
symbolInfoMap.emplace(symbol.str(), SymbolInfo::getProp(&constraint));
|
||||
return symbolInfoMap.count(inserted->first) == 1;
|
||||
}
|
||||
|
||||
bool SymbolInfoMap::contains(StringRef symbol) const {
|
||||
return find(symbol) != symbolInfoMap.end();
|
||||
}
|
||||
@@ -774,10 +830,23 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
|
||||
if (!treeArgName.empty() && treeArgName != "_") {
|
||||
DagLeaf leaf = tree.getArgAsLeaf(i);
|
||||
|
||||
// In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c),
|
||||
// In (NativeCodeCall<"Foo($_self, $0, $1, $2, $3)"> I8Attr:$a, I8:$b,
|
||||
// $c, I8Prop:$d),
|
||||
if (leaf.isUnspecified()) {
|
||||
// This is case of $c, a Value without any constraints.
|
||||
verifyBind(infoMap.bindValue(treeArgName), treeArgName);
|
||||
} else if (leaf.isPropMatcher()) {
|
||||
// This is case of $d, a binding to a certain property.
|
||||
auto propConstraint = leaf.getAsPropConstraint();
|
||||
if (propConstraint.getInterfaceType().empty()) {
|
||||
PrintFatalError(&def,
|
||||
formatv("binding symbol '{0}' in NativeCodeCall to "
|
||||
"a property constraint without specifying "
|
||||
"that constraint's type is unsupported",
|
||||
treeArgName));
|
||||
}
|
||||
verifyBind(infoMap.bindProp(treeArgName, propConstraint),
|
||||
treeArgName);
|
||||
} else {
|
||||
auto constraint = leaf.getAsConstraint();
|
||||
bool isAttr = leaf.isAttrMatcher() || leaf.isEnumCase() ||
|
||||
|
||||
@@ -112,3 +112,11 @@ Property Property::getBaseProperty() const {
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool Property::isSubClassOf(StringRef className) const {
|
||||
return def && def->isSubClassOf(className);
|
||||
}
|
||||
|
||||
StringRef ConstantProp::getValue() const {
|
||||
return def->getValueAsString("value");
|
||||
}
|
||||
|
||||
@@ -3478,6 +3478,38 @@ def OpWithPropertyPredicates : TEST_Op<"op_with_property_predicates"> {
|
||||
let assemblyFormat = "attr-dict prop-dict";
|
||||
}
|
||||
|
||||
def TestPropPatternOp1 : TEST_Op<"prop_pattern_op_1"> {
|
||||
let arguments = (ins
|
||||
StringProp:$tag,
|
||||
I64Prop:$val,
|
||||
BoolProp:$cond
|
||||
);
|
||||
let results = (outs I32:$results);
|
||||
let assemblyFormat = "$tag $val $cond attr-dict";
|
||||
}
|
||||
|
||||
def TestPropPatternOp2 : TEST_Op<"prop_pattern_op_2"> {
|
||||
let arguments = (ins
|
||||
I32:$input,
|
||||
StringProp:$tag
|
||||
);
|
||||
let assemblyFormat = "$input $tag attr-dict";
|
||||
}
|
||||
|
||||
def : Pat<
|
||||
(TestPropPatternOp1 $tag, NonNegativeI64Prop:$val, ConstantProp<BoolProp, "false">),
|
||||
(TestPropPatternOp1 $tag, (NativeCodeCall<"$0 + 1"> $val), ConstantProp<BoolProp, "true">)>;
|
||||
|
||||
def : Pat<
|
||||
(TestPropPatternOp2 (TestPropPatternOp1 $tag1, $val, ConstantProp<BoolProp, "true">),
|
||||
PropConstraint<CPred<"!$_self.empty()">, "non-empty string">:$tag2),
|
||||
(TestPropPatternOp2
|
||||
(TestPropPatternOp1 $tag1,
|
||||
(NativeCodeCall<"-($0)"> $val),
|
||||
ConstantProp<BoolProp, "false">),
|
||||
(NativeCodeCall<"$0.str() + \".\" + $1.str()"> $tag1, $tag2))
|
||||
>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Dataflow
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -586,7 +586,7 @@ func.func @testMatchMultiVariadicSubSymbol(%arg0: i32, %arg1: i32, %arg2: i32, %
|
||||
|
||||
// CHECK-LABEL: @testMatchMixedVaradicOptional
|
||||
func.func @testMatchMixedVaradicOptional(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) -> () {
|
||||
// CHECK: "test.mixed_variadic_in6"(%arg0, %arg1, %arg2) <{attr1 = 2 : i32}> : (i32, i32, i32) -> ()
|
||||
// CHECK: "test.mixed_variadic_in6"(%arg0, %arg1, %arg2) <{attr1 = 2 : i32}> : (i32, i32, i32) -> ()
|
||||
"test.mixed_variadic_optional_in7"(%arg0, %arg1, %arg2) {attr1 = 2 : i32, operandSegmentSizes = array<i32: 2, 1>} : (i32, i32, i32) -> ()
|
||||
// CHECK: test.mixed_variadic_optional_in7
|
||||
"test.mixed_variadic_optional_in7"(%arg0, %arg1) {attr1 = 2 : i32, operandSegmentSizes = array<i32: 2, 0>} : (i32, i32) -> ()
|
||||
@@ -594,6 +594,32 @@ func.func @testMatchMixedVaradicOptional(%arg0: i32, %arg1: i32, %arg2: i32, %ar
|
||||
return
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test patterns that operate on properties
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: @testSimplePropertyRewrite
|
||||
func.func @testSimplePropertyRewrite() {
|
||||
// CHECK-NEXT: test.prop_pattern_op_1 "o1" 2 true
|
||||
test.prop_pattern_op_1 "o1" 1 false
|
||||
// Pattern not applied when predicate not met
|
||||
// CHECK-NEXT: test.prop_pattern_op_1 "o2" -1 false
|
||||
test.prop_pattern_op_1 "o2" -1 false
|
||||
// Pattern not applied when constant doesn't match
|
||||
// CHCEK-NEXT: test.prop_pattern_op_1 "o3" 1 true
|
||||
test.prop_pattern_op_1 "o3" 1 true
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @testNestedPropertyRewrite
|
||||
func.func @testNestedPropertyRewrite() {
|
||||
// CHECK: %[[v:.*]] = test.prop_pattern_op_1 "s1" -2 false
|
||||
// CHECK: test.prop_pattern_op_2 %[[v]] "s1.t1"
|
||||
%v = test.prop_pattern_op_1 "s1" 1 false
|
||||
test.prop_pattern_op_2 %v "t1"
|
||||
return
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test that natives calls are only called once during rewrites.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -45,3 +45,35 @@ def test1 : Pat<(AOp (BOp:$x $y), $_), (AOp $x, $y)>;
|
||||
// CHECK: tblgen_values.push_back((*x.getODSResults(0).begin()));
|
||||
// CHECK: tblgen_props.y = ::llvm::dyn_cast_if_present<decltype(tblgen_props.y)>(y);
|
||||
// CHECK: tblgen_AOp_0 = rewriter.create<test::AOp>(odsLoc, tblgen_types, tblgen_values, tblgen_props);
|
||||
|
||||
// Note: These use strings to pick up a non-trivial storage/interface type
|
||||
// difference.
|
||||
def COp : NS_Op<"c_op", []> {
|
||||
let arguments = (ins
|
||||
I32:$x,
|
||||
StringProp:$y
|
||||
);
|
||||
|
||||
let results = (outs I32:$z);
|
||||
}
|
||||
|
||||
def DOp : NS_Op<"d_op", []> {
|
||||
let arguments = (ins
|
||||
StringProp:$y
|
||||
);
|
||||
|
||||
let results = (outs I32:$z);
|
||||
}
|
||||
def test2 : Pat<(COp (DOp:$x $y), $_), (COp $x, $y)>;
|
||||
// CHECK-LABEL: struct test2
|
||||
// CHECK: ::llvm::LogicalResult matchAndRewrite
|
||||
// CHECK-DAG: ::llvm::StringRef y;
|
||||
// CHECK-DAG: test::DOp x;
|
||||
// CHECK-DAG: ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;
|
||||
// CHECK: tblgen_ops.push_back(op0);
|
||||
// CHECK: x = castedOp1;
|
||||
// CHECK: tblgen_prop = castedOp1.getProperties().getY();
|
||||
// CHECK: y = tblgen_prop;
|
||||
// CHECK: tblgen_ops.push_back(op1);
|
||||
// CHECK: test::COp::Properties tblgen_props;
|
||||
// CHECK: tblgen_props.setY(y);
|
||||
|
||||
@@ -45,6 +45,24 @@ def DOp : NS_Op<"d_op", []> {
|
||||
|
||||
def Foo : NativeCodeCall<"foo($_builder, $0)">;
|
||||
|
||||
def NonNegProp : PropConstraint<CPred<"$_self >= 0">, "non-negative integer">;
|
||||
|
||||
def EOp : NS_Op<"e_op", []> {
|
||||
let arguments = (ins
|
||||
I32Prop:$x,
|
||||
I64Prop:$y,
|
||||
AnyInteger:$z
|
||||
);
|
||||
let results = (outs I32:$res);
|
||||
}
|
||||
|
||||
def FOp: NS_Op<"f_op", []> {
|
||||
let arguments = (ins
|
||||
I32Prop:$a,
|
||||
AnyInteger:$b
|
||||
);
|
||||
}
|
||||
|
||||
// Test static matcher for duplicate DagNode
|
||||
// ---
|
||||
|
||||
@@ -52,9 +70,16 @@ def Foo : NativeCodeCall<"foo($_builder, $0)">;
|
||||
// CHECK-NEXT: {{.*::mlir::Type type}}
|
||||
// CHECK: static ::llvm::LogicalResult [[$ATTR_CONSTRAINT:__mlir_ods_local_attr_constraint.*]](
|
||||
// CHECK-NEXT: {{.*::mlir::Attribute attr}}
|
||||
// CHECK: template <typename T>
|
||||
// CHECK-NEXT: static ::llvm::LogicalResult [[$PROP_CONSTRAINT:__mlir_ods_local_prop_constraint.*]](
|
||||
// CHECK-NEXT: {{.*T prop}}
|
||||
// CHECK: static ::llvm::LogicalResult [[$DAG_MATCHER:static_dag_matcher.*]](
|
||||
// CHECK: if(::mlir::failed([[$ATTR_CONSTRAINT]]
|
||||
// CHECK: if(::mlir::failed([[$TYPE_CONSTRAINT]]
|
||||
// CHECK: static ::llvm::LogicalResult [[$DAG_MATCHER2:static_dag_matcher.*]](
|
||||
// CHECK-SAME: int32_t &x
|
||||
// CHECK: if(::mlir::failed([[$PROP_CONSTRAINT]]
|
||||
// CHECK: if(::mlir::failed([[$TYPE_CONSTRAINT]]
|
||||
|
||||
// CHECK: if(::mlir::failed([[$DAG_MATCHER]](rewriter, op1, tblgen_ops
|
||||
def : Pat<(AOp (BOp I32Attr:$attr, I32:$int)),
|
||||
@@ -68,3 +93,11 @@ def : Pat<(COp $_, (BOp I32Attr:$attr, I32:$int)),
|
||||
// CHECK: ::llvm::SmallVector<::mlir::Value, 4> [[$ARR:tblgen_variadic_values_.*]];
|
||||
// CHECK: [[$ARR]].push_back([[$VAR]]);
|
||||
def : Pat<(AOp $x), (DOp (variadic (Foo $x)))>;
|
||||
|
||||
// CHECK: if(::mlir::failed([[$DAG_MATCHER2]]({{.*}} x{{[,)]}}
|
||||
def : Pat<(AOp (EOp NonNegProp:$x, NonNegProp:$_, I32:$z)),
|
||||
(AOp $z)>;
|
||||
|
||||
// CHECK: if(::mlir::failed([[$DAG_MATCHER2]]({{.*}} x{{[,)]}}
|
||||
def : Pat<(FOp $_, (EOp NonNegProp:$x, NonNegProp:$_, I32:$z)),
|
||||
(COp $x, $z)>;
|
||||
|
||||
@@ -125,6 +125,11 @@ private:
|
||||
void emitAttributeMatch(DagNode tree, StringRef castedName, int argIndex,
|
||||
int depth);
|
||||
|
||||
// Emits C++ statements for matching the `argIndex`-th argument of the given
|
||||
// DAG `tree` as a property.
|
||||
void emitPropertyMatch(DagNode tree, StringRef castedName, int argIndex,
|
||||
int depth);
|
||||
|
||||
// Emits C++ for checking a match with a corresponding match failure
|
||||
// diagnostic.
|
||||
void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt,
|
||||
@@ -338,7 +343,7 @@ private:
|
||||
// for each DagNode.
|
||||
int staticMatcherCounter = 0;
|
||||
|
||||
// The DagLeaf which contains type or attr constraint.
|
||||
// The DagLeaf which contains type, attr, or prop constraint.
|
||||
SetVector<DagLeaf> constraints;
|
||||
|
||||
// Static type/attribute verification function emitter.
|
||||
@@ -487,6 +492,19 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
|
||||
auto leaf = tree.getArgAsLeaf(i);
|
||||
if (leaf.isAttrMatcher() || leaf.isConstantAttr()) {
|
||||
os << "::mlir::Attribute " << argName << ";\n";
|
||||
} else if (leaf.isPropMatcher()) {
|
||||
StringRef interfaceType = leaf.getAsPropConstraint().getInterfaceType();
|
||||
if (interfaceType.empty())
|
||||
PrintFatalError(loc, "NativeCodeCall cannot have a property operand "
|
||||
"with unspecified interface type");
|
||||
os << interfaceType << " " << argName;
|
||||
if (leaf.isPropDefinition()) {
|
||||
Property propDef = leaf.getAsProperty();
|
||||
// Ensure properties that aren't zero-arg-constructable still work.
|
||||
if (propDef.hasDefaultValue())
|
||||
os << " = " << propDef.getDefaultValue();
|
||||
}
|
||||
os << ";\n";
|
||||
} else {
|
||||
os << "::mlir::Value " << argName << ";\n";
|
||||
}
|
||||
@@ -539,7 +557,7 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
|
||||
auto constraint = leaf.getAsConstraint();
|
||||
|
||||
std::string self;
|
||||
if (leaf.isAttrMatcher() || leaf.isConstantAttr())
|
||||
if (leaf.isAttrMatcher() || leaf.isConstantAttr() || leaf.isPropMatcher())
|
||||
self = argName;
|
||||
else
|
||||
self = formatv("{0}.getType()", argName);
|
||||
@@ -665,6 +683,8 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
|
||||
++nextOperand;
|
||||
} else if (isa<NamedAttribute *>(opArg)) {
|
||||
emitAttributeMatch(tree, castedName, opArgIdx, depth);
|
||||
} else if (isa<NamedProperty *>(opArg)) {
|
||||
emitPropertyMatch(tree, castedName, opArgIdx, depth);
|
||||
} else {
|
||||
PrintFatalError(loc, "unhandled case when matching op");
|
||||
}
|
||||
@@ -942,6 +962,46 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef castedName,
|
||||
os.unindent() << "}\n";
|
||||
}
|
||||
|
||||
void PatternEmitter::emitPropertyMatch(DagNode tree, StringRef castedName,
|
||||
int argIndex, int depth) {
|
||||
Operator &op = tree.getDialectOp(opMap);
|
||||
auto *namedProp = cast<NamedProperty *>(op.getArg(argIndex));
|
||||
|
||||
os << "{\n";
|
||||
os.indent() << formatv(
|
||||
"[[maybe_unused]] auto tblgen_prop = {0}.getProperties().{1}();\n",
|
||||
castedName, op.getGetterName(namedProp->name));
|
||||
|
||||
auto matcher = tree.getArgAsLeaf(argIndex);
|
||||
if (!matcher.isUnspecified()) {
|
||||
if (!matcher.isPropMatcher()) {
|
||||
PrintFatalError(
|
||||
loc, formatv("the {1}-th argument of op '{0}' should be a property",
|
||||
op.getOperationName(), argIndex + 1));
|
||||
}
|
||||
|
||||
// If a constraint is specified, we need to generate function call to its
|
||||
// static verifier.
|
||||
StringRef verifier = staticMatcherHelper.getVerifierName(matcher);
|
||||
emitStaticVerifierCall(
|
||||
verifier, castedName, "tblgen_prop",
|
||||
formatv("\"op '{0}' property '{1}' failed to satisfy constraint: "
|
||||
"'{2}'\"",
|
||||
op.getOperationName(), namedProp->name,
|
||||
escapeString(matcher.getAsConstraint().getSummary()))
|
||||
.str());
|
||||
}
|
||||
|
||||
// Capture the value
|
||||
auto name = tree.getArgName(argIndex);
|
||||
// `$_` is a special symbol to ignore op argument matching.
|
||||
if (!name.empty() && name != "_") {
|
||||
os << formatv("{0} = tblgen_prop;\n", name);
|
||||
}
|
||||
|
||||
os.unindent() << "}\n";
|
||||
}
|
||||
|
||||
void PatternEmitter::emitMatchCheck(
|
||||
StringRef opName, const FmtObjectBase &matchFmt,
|
||||
const llvm::formatv_object_base &failureFmt) {
|
||||
@@ -1384,6 +1444,10 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
|
||||
std::string val = std::to_string(enumCase.getValue());
|
||||
return handleConstantAttr(Attribute(&enumCase.getDef()), val);
|
||||
}
|
||||
if (leaf.isConstantProp()) {
|
||||
auto constantProp = leaf.getAsConstantProp();
|
||||
return constantProp.getValue().str();
|
||||
}
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
|
||||
auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
|
||||
@@ -1710,7 +1774,7 @@ void PatternEmitter::createSeparateLocalVarsForOpArgs(
|
||||
for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
|
||||
const auto *operand =
|
||||
llvm::dyn_cast_if_present<NamedTypeConstraint *>(resultOp.getArg(argIndex));
|
||||
// We do not need special handling for attributes.
|
||||
// We do not need special handling for attributes or properties.
|
||||
if (!operand)
|
||||
continue;
|
||||
|
||||
@@ -1776,7 +1840,7 @@ void PatternEmitter::supplyValuesForOpArgs(
|
||||
if (auto subTree = node.getArgAsNestedDag(argIndex)) {
|
||||
if (!subTree.isNativeCodeCall())
|
||||
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
|
||||
"for creating attribute");
|
||||
"for creating attributes and properties");
|
||||
os << formatv("/*{0}=*/{1}", opArgName, childNodeNames.lookup(argIndex));
|
||||
} else {
|
||||
auto leaf = node.getArgAsLeaf(argIndex);
|
||||
@@ -1788,6 +1852,11 @@ void PatternEmitter::supplyValuesForOpArgs(
|
||||
PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
|
||||
if (!patArgName.empty())
|
||||
os << "/*" << patArgName << "=*/";
|
||||
} else if (leaf.isConstantProp()) {
|
||||
if (!isa<NamedProperty *>(opArg))
|
||||
PrintFatalError(loc, Twine("expected property ") + Twine(argIndex));
|
||||
if (!patArgName.empty())
|
||||
os << "/*" << patArgName << "=*/";
|
||||
} else {
|
||||
os << "/*" << opArgName << "=*/";
|
||||
}
|
||||
@@ -1820,6 +1889,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
|
||||
" tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), "
|
||||
"tmpAttr);\n}\n";
|
||||
const char *setterCmd = (useProperties) ? setPropCmd : addAttrCmd;
|
||||
const char *propSetterCmd = "tblgen_props.{0}({1});\n";
|
||||
|
||||
int numVariadic = 0;
|
||||
bool hasOperandSegmentSizes = false;
|
||||
@@ -1845,6 +1915,28 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isa<NamedProperty *>(resultOp.getArg(argIndex))) {
|
||||
// The argument in the op definition.
|
||||
auto opArgName = resultOp.getArgName(argIndex);
|
||||
auto setterName = resultOp.getSetterName(opArgName);
|
||||
if (auto subTree = node.getArgAsNestedDag(argIndex)) {
|
||||
if (!subTree.isNativeCodeCall())
|
||||
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
|
||||
"for creating property");
|
||||
|
||||
os << formatv(propSetterCmd, setterName,
|
||||
childNodeNames.lookup(argIndex));
|
||||
} else {
|
||||
auto leaf = node.getArgAsLeaf(argIndex);
|
||||
// The argument in the result DAG pattern.
|
||||
auto patArgName = node.getArgName(argIndex);
|
||||
// The argument in the result DAG pattern.
|
||||
os << formatv(propSetterCmd, setterName,
|
||||
handleOpArgument(leaf, patArgName));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto *operand =
|
||||
cast<NamedTypeConstraint *>(resultOp.getArg(argIndex));
|
||||
if (operand->isVariadic()) {
|
||||
@@ -1973,6 +2065,12 @@ StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) {
|
||||
assert(constraint && "attribute constraint was not uniqued");
|
||||
return *constraint;
|
||||
}
|
||||
if (leaf.isPropMatcher()) {
|
||||
std::optional<StringRef> constraint =
|
||||
staticVerifierEmitter.getPropConstraintFn(leaf.getAsConstraint());
|
||||
assert(constraint && "prop constraint was not uniqued");
|
||||
return *constraint;
|
||||
}
|
||||
assert(leaf.isOperandMatcher());
|
||||
return staticVerifierEmitter.getTypeConstraintFn(leaf.getAsConstraint());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user