[mlir] add a chapter on matchers to the transform dialect tutorial (#76725)
These operations has been available for a while, but were not described in the tutorial. Add a new chapter on using and defining match operations.
This commit is contained in:
committed by
GitHub
parent
633d9184f5
commit
4cb2ef4fe3
581
mlir/docs/Tutorials/transform/Ch4.md
Normal file
581
mlir/docs/Tutorials/transform/Ch4.md
Normal file
@@ -0,0 +1,581 @@
|
||||
# Chapter 4: Matching Payload with Transform Operations
|
||||
|
||||
**Check the continuously-tested version of MLIR files under
|
||||
[mlir/test/Examples/transform/Ch4](https://github.com/llvm/llvm-project/tree/main/mlir/test/Examples/transform/Ch4).**
|
||||
|
||||
Up until now, we were applying transform dialect scripts under the assumption
|
||||
that specific payload operations are identified by the caller when the transform
|
||||
dialect interpreter is invoked. This may be seen as contrary to the idea of
|
||||
driving transformations from a dialect since the transformation targets must be
|
||||
identified through mechanisms external to the transform dialect interpreter, for
|
||||
example, when invoking the interpreter programmatically in C++ or through pass
|
||||
arguments as seen in previous chapters. It also adds practical overhead due to
|
||||
increased interaction with the interpreter in C++, and cognitive overhead of
|
||||
manipulating two interfaces at once. To remedy this, Transform dialect proposes
|
||||
a subset of operations for _matching_ payload operations that need to be
|
||||
transformed.
|
||||
|
||||
_Match_ operations are simply transform operations with some additional
|
||||
guarantees. In particular, they are not expected to modify the payload IR and
|
||||
are expected to fail if their operands (typically payload operation handles) are
|
||||
not associated with payload IR objects having desired properties, such as
|
||||
operation names or kinds of arguments. Using simple combinator operations, it
|
||||
becomes possible to set up a higher-level match and rewrite infrastructure
|
||||
directly within the transform dialect.
|
||||
|
||||
|
||||
## Simple match
|
||||
|
||||
Let us reconsider the “fully connected layer” example from [Chapter
|
||||
1](Ch1.md#chaining-transformations-with-handles), reproduced below for
|
||||
convenience.
|
||||
|
||||
|
||||
```mlir
|
||||
// Original function to optimize.
|
||||
func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
|
||||
%bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
|
||||
-> tensor<512x512xf32> {
|
||||
// Matrix-matrix multiplication.
|
||||
%matmul = linalg.matmul
|
||||
ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
|
||||
outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>
|
||||
|
||||
// Elementwise addition.
|
||||
%biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
|
||||
ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
|
||||
outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
|
||||
|
||||
// Elementwise max with 0 (ReLU).
|
||||
%c0f = arith.constant 0.0 : f32
|
||||
%relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
|
||||
ins(%biased, %c0f : tensor<512x512xf32>, f32)
|
||||
outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
|
||||
func.return %relued : tensor<512x512xf32>
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
|
||||
In Chapter 1, we were calling the test transform interpreter pass with
|
||||
additional arguments, `bind-first-extra-to-ops=linalg.matmul
|
||||
bind-second-extra-to-ops=linalg.elemwise_binary`, to provide initial
|
||||
associations for operation handles. Instead, we can use match operations to
|
||||
discover relevant operations in the payload IR. Match operations can be combined
|
||||
with “regular” transform operations using, e.g., the
|
||||
`transform.collect_matching` combinator operation that leverages the concept of
|
||||
named sequences to organize matchers.
|
||||
|
||||
|
||||
```mlir
|
||||
// The module containing named sequences must have an attribute allowing them
|
||||
// to enable verification.
|
||||
module @transforms attributes { transform.with_named_sequence } {
|
||||
// Entry point. This takes as the only argument the root operation (typically
|
||||
// pass root) given to the transform interpreter.
|
||||
transform.named_sequence @__transform_main(
|
||||
%root: !transform.any_op {transform.readonly}) {
|
||||
// Collect operations that match the criteria specified in named sequence.
|
||||
// If the named sequence fails with a silenceable failure, silences it (the
|
||||
// message is forwarded to the debug stream). If the named sequence
|
||||
// succeeds, appends its results to the results of this operation.
|
||||
%elemwise = transform.collect_matching @match_elemwise in %root
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
%matmul = transform.collect_matching @match_matmul in %root
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
transform.include @print_elemwise failures(propagate) (%elemwise)
|
||||
: (!transform.any_op) -> ()
|
||||
transform.include @print_matmul failures(propagate) (%matmul)
|
||||
: (!transform.any_op) -> ()
|
||||
|
||||
transform.yield
|
||||
}
|
||||
|
||||
// This is a matcher sequence. It is given an operation to match and the
|
||||
// match is considered successful unless any nested operation produces a
|
||||
// failure. The values yielded by this operation will be forwarded to the
|
||||
// rewriter sequence on success.
|
||||
transform.named_sequence @match_elemwise(
|
||||
%entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
|
||||
transform.match.operation_name %entry ["linalg.elemwise_binary"]
|
||||
: !transform.any_op
|
||||
transform.yield %entry : !transform.any_op
|
||||
}
|
||||
transform.named_sequence @match_matmul(
|
||||
%entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
|
||||
transform.match.operation_name %entry ["linalg.matmul"] : !transform.any_op
|
||||
transform.yield %entry : !transform.any_op
|
||||
}
|
||||
|
||||
// This is a rewriter sequence.
|
||||
transform.named_sequence @print_elemwise(
|
||||
%elemwise_binary: !transform.any_op {transform.readonly}) {
|
||||
transform.test_print_remark_at_operand
|
||||
%elemwise_binary, "elementwise binary" : !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
transform.named_sequence @print_matmul(
|
||||
%matmul: !transform.any_op {transform.readonly}) {
|
||||
transform.test_print_remark_at_operand %matmul, "matmul" : !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
|
||||
This script can be executed using the non-test interpreter pass running on the
|
||||
root operation of the translation unit without additional flags: `mlir-opt
|
||||
--transform-interpreter`. It will emit corresponding remarks at
|
||||
`linalg.elemwise_binary` and `linalg.matmul` operations. In debug builds, the
|
||||
infrastructure provides a convenient method to understand the matching process
|
||||
by passing `-debug-only=transform-matcher` to `mlir-opt` or a derived tool. It
|
||||
will print the silenceable failure messages produced by the match operations
|
||||
into the debug stream, for example:
|
||||
|
||||
|
||||
```
|
||||
<...>
|
||||
[transform-matcher] matching %0 = linalg.matmul ins(%arg0, %arg1 : tensor<512x512xf32>, tensor<512x512xf32>) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32> @0x5622eee08410
|
||||
[transform-matcher] matcher match_elemwise failed: wrong operation name
|
||||
<...>
|
||||
```
|
||||
|
||||
|
||||
This is now sufficient to run the rest of the transform script from Chapter 1,
|
||||
substituting `%arg1` with `%matmul` and `%arg2` with `%elemwise`.
|
||||
|
||||
|
||||
## Matching Chains of Operations
|
||||
|
||||
The matcher above remains naive as it matches _all_ operations of the certain
|
||||
kind under the payload root. These operations may or may not be related, and
|
||||
may, for example, belong to different functions. Even if they are in a single
|
||||
function, if there are multiple groups of such operations, we wouldn’t be able
|
||||
to differentiate them with this approach. In reality, we want to match a
|
||||
specific group of operations where a `matmul` operation produces a result that
|
||||
is used by an elementwise operation, which in turn feeds another elementwise
|
||||
operation in a similar way.
|
||||
|
||||
This can be achieved using the following matcher sequence.
|
||||
|
||||
|
||||
```mlir
|
||||
// This is also a matcher sequence. It is similarly given an operation to
|
||||
// match and nested operations must succeed in order for a match to be deemed
|
||||
// successful. It starts matching from the last operation in the use-def chain
|
||||
// and goes back because each operand (use) has exactly one definition.
|
||||
transform.named_sequence @match_matmul_elemwise(
|
||||
%last: !transform.any_op {transform.readonly})
|
||||
-> (!transform.any_op, !transform.any_op, !transform.any_op) {
|
||||
// The last operation must be an elementwise binary.
|
||||
transform.match.operation_name %last ["linalg.elemwise_binary"]
|
||||
: !transform.any_op
|
||||
// Its first operand must be defined by another operation, to which we
|
||||
// will get a handle here. We are guaranteed that the first operand exists
|
||||
// because we know the operation is binary, but even in absence of such a
|
||||
// guarantee, this operation would have produced a silenceable failure when
|
||||
// `%last` does not have enough operands.
|
||||
%middle = transform.get_producer_of_operand %last[0]
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
// The defining operation must itself be an elementwise binary.
|
||||
transform.match.operation_name %middle ["linalg.elemwise_binary"]
|
||||
: !transform.any_op
|
||||
// And the first operand of that operation must be defined by yet another
|
||||
// operation.
|
||||
%matmul = transform.get_producer_of_operand %middle[0]
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
// And that operation is a matmul.
|
||||
transform.match.operation_name %matmul ["linalg.matmul"] : !transform.any_op
|
||||
// We will yield the handles to the matmul and the two elementwise
|
||||
// operations separately.
|
||||
transform.yield %matmul, %middle, %last
|
||||
: !transform.any_op, !transform.any_op, !transform.any_op
|
||||
}
|
||||
```
|
||||
|
||||
This matcher is applicable in presence of other `elemwise` and `matmul`
|
||||
operations and will return the triple of _related_ operations rather than
|
||||
operations in the order in which they are found. It can be exercised similarly
|
||||
to the previous incarnation, as follows.
|
||||
|
||||
```mlir
|
||||
// Alternative entry point.
|
||||
transform.named_sequence @__transform_main(
|
||||
%root: !transform.any_op {transform.readonly}) {
|
||||
// Collect groups of operations that match the criteria specified in the
|
||||
// named sequence.
|
||||
%matmul, %el1, %el2 = transform.collect_matching @match_matmul_elemwise in %root
|
||||
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
|
||||
%elemwise = transform.merge_handles %el1, %el2 : !transform.any_op
|
||||
|
||||
transform.include @print_elemwise failures(propagate) (%elemwise)
|
||||
: (!transform.any_op) -> ()
|
||||
transform.include @print_matmul failures(propagate) (%matmul)
|
||||
: (!transform.any_op) -> ()
|
||||
|
||||
transform.yield
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## Defining Match Operations
|
||||
|
||||
The matcher of a chain of operations is correct in presence of other operations,
|
||||
but is still insufficiently robust for many cases of interest. In particular,
|
||||
using `transform.get_producer_of_operand %last[0]` requires that the _first_
|
||||
operand of elementwise operations is produced by another operation. The same
|
||||
transformation strategy may however apply regardless of the operand position:
|
||||
many binary operations are associative. Let us use this opportunity to introduce
|
||||
a new match operation. Specifically, we would like this operation to succeed if
|
||||
_any_ of the operands satisfies certain conditions that can be expressed as
|
||||
other match operations. We also want it to return some of the state and the
|
||||
position of the matched operand in the operand list.
|
||||
|
||||
Match operations are defined similarly to other transform operations, with the
|
||||
only difference of additionally implementing the `MatchOpInterface`. Note that
|
||||
this interface has _no additional methods_ (though it may add some eventually)
|
||||
and is only used as a verification contract that the operation is intended for
|
||||
matching and will not attempt to transform the payload. The minimal definition
|
||||
of our operation is as follows.
|
||||
|
||||
|
||||
```tablegen
|
||||
// Define the new operation. By convention, prefix its name with `match`
|
||||
// followed by the name of the dialect extension.
|
||||
def HasOperandSatisfyingOp : TransformDialectOp<"match.my.has_operand_satisfying",
|
||||
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||
// Indicate that the operation implements MatchOpInterface in addition to
|
||||
// the TransformOpInterface. This interface is only used as a tag at this
|
||||
// point and has no methods that are mandatory to implement.
|
||||
MatchOpInterface,
|
||||
SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> {
|
||||
let summary = "Succeed if any of the operands matches all nested criteria";
|
||||
let arguments = (ins TransformHandleTypeInterface:$op);
|
||||
let results = (outs TransformParamTypeInterface:$position,
|
||||
Variadic<Transform_AnyHandleOrParamType>:$results);
|
||||
|
||||
// Match operations can be arbitrarily complex, e.g., containing regions.
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
let hasVerifier = 1;
|
||||
let assemblyFormat = [{
|
||||
$op `:` functional-type($op, results) attr-dict-with-keyword $body
|
||||
}];
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
It takes as argument the handle associated with the payload operations whose
|
||||
operands it will match, has an associated single-block region containing the
|
||||
match criteria, and returns the position of the matched operand as well as any
|
||||
other transform value yielded from the body on the successful match.
|
||||
|
||||
The matching logic is implemented in the `apply` method of the
|
||||
`TransformOpInterface` and is easily composable with other transform operations.
|
||||
All facilities for managing the interpreter state and recursively entering the
|
||||
blocks are available in the same way as they are for “regular” transform
|
||||
operations. Match operations are expected to return a silenceable failure to
|
||||
indicate failure to match, and to immediately propagate definite failures. If
|
||||
they have nested operations, they are expected to handle and, in most cases,
|
||||
silence the silenceable failures produced when applying those operations. For
|
||||
our operation, the matching is essentially a loop iterating over all operands of
|
||||
the (single) payload operation and applying nested transform ops until they all
|
||||
succeed for one of the operands.
|
||||
|
||||
|
||||
```cpp
|
||||
// Matcher ops implement `apply` similarly to other transform ops. They are not
|
||||
// expected to modify payload, but use the tri-state result to signal failure or
|
||||
// success to match, as well as potential irrecoverable errors.
|
||||
mlir::DiagnosedSilenceableFailure
|
||||
mlir::transform::HasOperandSatisfyingOp::apply(
|
||||
mlir::transform::TransformRewriter &rewriter,
|
||||
mlir::transform::TransformResults &results,
|
||||
mlir::transform::TransformState &state) {
|
||||
// For simplicity, only handle a single payload op. Actual implementations
|
||||
// can use `SingleOpMatcher` trait to simplify implementation and document
|
||||
// this expectation.
|
||||
auto payloadOps = state.getPayloadOps(getOp());
|
||||
if (!llvm::hasSingleElement(payloadOps))
|
||||
return emitSilenceableError() << "expected single payload";
|
||||
|
||||
// Iterate over all operands of the payload op to see if they can be matched
|
||||
// using the body of this op.
|
||||
Operation *payload = *payloadOps.begin();
|
||||
for (OpOperand &operand : payload->getOpOperands()) {
|
||||
// Create a scope for transform values defined in the body. This corresponds
|
||||
// to the syntactic scope of the region attached to this op. Any values
|
||||
// associated with payloads from now on will be automatically dissociated
|
||||
// when this object is destroyed, i.e. at the end of the iteration.
|
||||
// Associate the block argument handle with the operand.
|
||||
auto matchScope = state.make_region_scope(getBody());
|
||||
if (failed(state.mapBlockArgument(getBody().getArgument(0),
|
||||
{operand.get()}))) {
|
||||
return DiagnosedSilenceableFailure::definiteFailure();
|
||||
}
|
||||
|
||||
// Iterate over all nested matchers with the current mapping and see if they
|
||||
// succeed.
|
||||
bool matchSucceeded = true;
|
||||
for (Operation &matcher : getBody().front().without_terminator()) {
|
||||
// Matcher ops are applied similarly to any other transform op.
|
||||
DiagnosedSilenceableFailure diag =
|
||||
state.applyTransform(cast<TransformOpInterface>(matcher));
|
||||
|
||||
// Definite failures are immediately propagated as they are irrecoverable.
|
||||
if (diag.isDefiniteFailure())
|
||||
return diag;
|
||||
|
||||
// On success, keep checking the remaining conditions.
|
||||
if (diag.succeeded())
|
||||
continue;
|
||||
|
||||
// Report failure-to-match for debugging purposes and stop matching this
|
||||
// operand.
|
||||
assert(diag.isSilenceableFailure());
|
||||
DEBUG_MATCHER(DBGS_MATCHER()
|
||||
<< "failed to match operand #" << operand.getOperandNumber()
|
||||
<< ": " << diag.getMessage());
|
||||
(void)diag.silence();
|
||||
matchSucceeded = false;
|
||||
break;
|
||||
}
|
||||
// If failed to match this operand, try other operands.
|
||||
if (!matchSucceeded)
|
||||
continue;
|
||||
|
||||
// If we reached this point, the matching succeeded for the current operand.
|
||||
// Remap the values associated with terminator operands to be associated
|
||||
// with op results, and also map the parameter result to the operand's
|
||||
// position. Note that it is safe to do here despite the end of the scope
|
||||
// as `results` are integrated into `state` by the interpreter after `apply`
|
||||
// returns rather than immediately.
|
||||
SmallVector<SmallVector<MappedValue>> yieldedMappings;
|
||||
transform::detail::prepareValueMappings(
|
||||
yieldedMappings, getBody().front().getTerminator()->getOperands(),
|
||||
state);
|
||||
results.setParams(getPosition().cast<OpResult>(),
|
||||
{rewriter.getI32IntegerAttr(operand.getOperandNumber())});
|
||||
for (auto &&[result, mapping] : llvm::zip(getResults(), yieldedMappings))
|
||||
results.setMappedValues(result, mapping);
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
// If we reached this point, none of the operands succeeded the match.
|
||||
return emitSilenceableError()
|
||||
<< "none of the operands satisfied the conditions";
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
|
||||
By convention, operations implementing `MatchOpInterface` must not modify
|
||||
payload IR and must therefore specify that they only read operand handles and
|
||||
payload as their effects.
|
||||
|
||||
|
||||
```cpp
|
||||
void transform::CollectMatchingOp::getEffects(
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
onlyReadsHandle(getRoot(), effects);
|
||||
producesHandle(getResults(), effects);
|
||||
onlyReadsPayload(effects);
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
This operation can now be included in a transform dialect extension, loaded and
|
||||
used in our matcher. Specifically, we will use it to indicate that either of the
|
||||
operands of the “max” elementwise operation in our example can be produced by
|
||||
the previous elementwise operation. The previous operation will still require
|
||||
the matmul to produce the first operand for simplicity. The updated matcher
|
||||
sequence looks as follows.
|
||||
|
||||
|
||||
```mlir
|
||||
transform.named_sequence @match_matmul_elemwise(
|
||||
%last: !transform.any_op {transform.readonly})
|
||||
-> (!transform.any_op, !transform.any_op, !transform.any_op,
|
||||
!transform.param<i32>) {
|
||||
// The last operation must be an elementwise binary.
|
||||
transform.match.operation_name %last ["linalg.elemwise_binary"]
|
||||
: !transform.any_op
|
||||
|
||||
// One of its operands must be defined by another operation, to which we
|
||||
// will get a handle here. This is achieved thanks to a newly defined
|
||||
// operation that tries to match operands one by one using the match
|
||||
// operations nested in its region.
|
||||
%pos, %middle = transform.match.my.has_operand_satisfying %last
|
||||
: (!transform.any_op) -> (!transform.param<i32>, !transform.any_op) {
|
||||
^bb0(%operand: !transform.any_value):
|
||||
// The operand must be defined by an operation.
|
||||
%def = transform.get_defining_op %operand
|
||||
: (!transform.any_value) -> !transform.any_op
|
||||
// The defining operation must itself be an elementwise binary.
|
||||
transform.match.operation_name %def ["linalg.elemwise_binary"]
|
||||
: !transform.any_op
|
||||
transform.yield %def : !transform.any_op
|
||||
}
|
||||
|
||||
// And the first operand of that operation must be defined by yet another
|
||||
// operation.
|
||||
%matmul = transform.get_producer_of_operand %middle[0]
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
// And that operation is a matmul.
|
||||
transform.match.operation_name %matmul ["linalg.matmul"] : !transform.any_op
|
||||
// We will yield the handles to the matmul and the two elementwise
|
||||
// operations separately.
|
||||
transform.yield %matmul, %middle, %last, %pos
|
||||
: !transform.any_op, !transform.any_op, !transform.any_op,
|
||||
!transform.param<i32>
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
This achieves the desired effect and matches both `max(add(matmul(...), bias),
|
||||
0)` and `max(0, add(matmul(...), bias))` in the same values. The `%pos` value is
|
||||
a transform dialect _parameter_, which is used to store lists of entities known
|
||||
to be constant throughout the transform application. Most often, parameters are
|
||||
numeric values, but they can generally be any MLIR attributes.
|
||||
|
||||
In order to demonstrate that groups of operations are matched independently of
|
||||
each other, let us use the `transform.foreach_match` operation that allows one
|
||||
to implement a simple high-level pattern rewriting approach within the transform
|
||||
dialect (for advanced or lower-level pattern rewriting, consider PDL(L) or C++
|
||||
rewriting APIs). It maps a matcher named sequence to an action named sequence,
|
||||
and the latter gets invoked whenever the former succeeds.
|
||||
|
||||
|
||||
```mlir
|
||||
// Traverses the payload IR associated with the operand handle, invoking
|
||||
// @match_matmul_elemwise on each of the operations. If the named sequence
|
||||
// succeeds, i.e., if none of the nested match (transform) operations
|
||||
// produced a silenceable failure, invokes @print_matmul_elemwise and
|
||||
// forwards the values yielded as arguments of the new invocation. If the
|
||||
// named sequence fails with a silenceable failure, silences it (the message
|
||||
// is forwarded to the debug stream). Definite failures are propagated
|
||||
// immediately and unconditionally, as usual.
|
||||
transform.foreach_match in %root
|
||||
@match_matmul_elemwise -> @print_matmul_elemwise
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
```
|
||||
|
||||
|
||||
The `@print_matmul_elemwise` named sequence, available in `multiple.mlir`, will
|
||||
use the parameter with the position of the operand to differentiate the two
|
||||
groups.
|
||||
|
||||
|
||||
## Matchers for Inferred Features
|
||||
|
||||
The matcher sequences described above, although useful to drive transformations
|
||||
from within the transform dialect interpreter, are rather basic since they
|
||||
mostly rely on operation names and use-def chains. Alternative implementations
|
||||
using APIs or various declarative rewrite rules are barely less expressive and
|
||||
sometimes more concise. The real power of transform dialect matcher ops lies in
|
||||
the possibility to define matchers of _inferred properties_ of payloads, i.e.,
|
||||
properties that are not directly accessible as an attribute of an operation or
|
||||
any straightforward relation between IR components.
|
||||
|
||||
The utility of such matchers can be easily demonstrated by slightly modifying
|
||||
our original example. If matrix multiplication is expressed as a special case of
|
||||
tensor contraction using `linalg.generic` instead of `linalg.matmul`, the
|
||||
operation name-based matcher no longer applies. Yet such a representation is
|
||||
very common and can appear both in the original input and during the course of
|
||||
transformation, e.g., where a higher-dimensional contraction is decomposed into
|
||||
loops around a matrix multiplication.
|
||||
|
||||
In order to be a (potentially transposed) matrix multiplication, the
|
||||
`linalg.generic` operation must have the following features:
|
||||
|
||||
|
||||
|
||||
* Total rank of 3.
|
||||
* Two inputs accessed as projected permutation of iteration dimensions.
|
||||
* One output accessed as projected permutation of iteration dimensions.
|
||||
* Iteration dimensions can be subdivided into LHS parallel, RHS parallel and reduction dimensions.
|
||||
* The body block consists of a multiplication and an addition.
|
||||
|
||||
Most of these features can be derived from the properties of the operation,
|
||||
e.g., the total rank corresponds to the number of entries in the `iterators`
|
||||
attribute, but almost none of them are immediately accessible in the IR or in
|
||||
any declarative form, which is usually limited to checking the presence or the
|
||||
exact match of an attribute or a type. The transform dialect allows these
|
||||
features to be implemented in the `apply` method of a matcher op and reused
|
||||
across multiple matching cases. For structured linear algebra payload
|
||||
operations, many such match operations are readily available in the `structured`
|
||||
extension. They are sufficient to implement a matrix multiplication matcher
|
||||
using the features listed above almost verbatim.
|
||||
|
||||
|
||||
```mlir
|
||||
transform.named_sequence @match_generic_matmul(
|
||||
%candidate: !transform.any_op {transform.readonly}) -> !transform.any_op {
|
||||
// Match a structured linear algebra operation.
|
||||
transform.match.structured %candidate : !transform.any_op {
|
||||
^bb0(%c: !transform.any_op):
|
||||
// With a rank equal to 3.
|
||||
%rank = transform.match.structured.rank %c
|
||||
: (!transform.any_op) -> !transform.param<i64>
|
||||
%c3 = transform.param.constant 3 : i64 -> !transform.param<i64>
|
||||
transform.match.param.cmpi eq %rank, %c3 : !transform.param<i64>
|
||||
|
||||
// With 2 inputs.
|
||||
%n_ins = transform.match.structured.num_inputs %c
|
||||
: (!transform.any_op) -> !transform.param<i64>
|
||||
%c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
|
||||
transform.match.param.cmpi eq %n_ins, %c2 : !transform.param<i64>
|
||||
|
||||
// With 1 output (note that structured ops in destination passing style
|
||||
// has as many inits as outputs).
|
||||
%n_inits = transform.match.structured.num_inits %c
|
||||
: (!transform.any_op) -> !transform.param<i64>
|
||||
%c1 = transform.param.constant 1 : i64 -> !transform.param<i64>
|
||||
transform.match.param.cmpi eq %n_inits, %c1 : !transform.param<i64>
|
||||
|
||||
// All inputs and inits are accessed with a projected permutation.
|
||||
transform.match.structured.input %c[all] {projected_permutation}
|
||||
: !transform.any_op
|
||||
transform.match.structured.init %c[0] {projected_permutation}
|
||||
: !transform.any_op
|
||||
|
||||
// The body is a mulf/addf contraction with appropriate dimensions.
|
||||
transform.match.structured.body %c
|
||||
{ contraction = ["arith.mulf", "arith.addf"] } : !transform.any_op
|
||||
%batch, %lhs, %rhs, %reduction =
|
||||
transform.match.structured.classify_contraction_dims %c
|
||||
: (!transform.any_op)
|
||||
-> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
|
||||
!transform.param<i64>)
|
||||
|
||||
|
||||
// There is one of lhs, rhs and reduction dimensions and zero batch
|
||||
// dimensions.
|
||||
%n_batch = transform.num_associations %batch
|
||||
: (!transform.param<i64>) -> !transform.param<i64>
|
||||
%n_lhs = transform.num_associations %lhs
|
||||
: (!transform.param<i64>) -> !transform.param<i64>
|
||||
%n_rhs = transform.num_associations %rhs
|
||||
: (!transform.param<i64>) -> !transform.param<i64>
|
||||
%n_reduction = transform.num_associations %reduction
|
||||
: (!transform.param<i64>) -> !transform.param<i64>
|
||||
%c0 = transform.param.constant 0 : i64 -> !transform.param<i64>
|
||||
transform.match.param.cmpi eq %n_batch, %c0 : !transform.param<i64>
|
||||
transform.match.param.cmpi eq %n_lhs, %c1 : !transform.param<i64>
|
||||
transform.match.param.cmpi eq %n_rhs, %c1 : !transform.param<i64>
|
||||
transform.match.param.cmpi eq %n_reduction, %c1 : !transform.param<i64>
|
||||
}
|
||||
transform.yield %candidate : !transform.any_op
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
While this example leverages the contraction-specific matchers that have a
|
||||
rather non-trivial C++ implementation, the transform dialect is sufficiently
|
||||
flexible to implement this reasoning directly if desired. One could, for
|
||||
example, obtain the access map of each input as a parameter and extract the
|
||||
accessed dimensions as other parameters that can be compared with each other to
|
||||
ensure the subscripts are `m,k` for LHS, `k,n` for RHS and `m,n` for the
|
||||
init/result given the `m,n,k` notation for loops.
|
||||
|
||||
@@ -26,6 +26,7 @@ The tutorial is divided into the following chapters.
|
||||
- [Chapter #1](Ch1.md): Combining Existing Transformations
|
||||
- [Chapter #2](Ch2.md): Adding a Simple New Transformation Operation
|
||||
- [Chapter #3](Ch3.md): More than Simple Transform Operations
|
||||
- [Chapter #4](Ch4.md): Matching Payload with Transform Operations
|
||||
- [Chapter H](ChH.md): Reproducing Halide Schedule
|
||||
|
||||
The code corresponding to this tutorial is located under
|
||||
|
||||
@@ -2,3 +2,4 @@ add_custom_target(TransformExample)
|
||||
|
||||
add_subdirectory(Ch2)
|
||||
add_subdirectory(Ch3)
|
||||
add_subdirectory(Ch4)
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This is the top-level file for the Transform dialect tutorial chapter 2.
|
||||
// This is the top-level file for the Transform dialect tutorial chapter 3.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
21
mlir/examples/transform/Ch4/CMakeLists.txt
Normal file
21
mlir/examples/transform/Ch4/CMakeLists.txt
Normal file
@@ -0,0 +1,21 @@
|
||||
# For a better top-level template to copy, see examples/standalone.
|
||||
|
||||
include_directories(${CMAKE_CURRENT_BINARY_DIR})
|
||||
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
|
||||
|
||||
add_subdirectory(include)
|
||||
add_subdirectory(lib)
|
||||
|
||||
add_dependencies(TransformExample transform-opt-ch4)
|
||||
add_llvm_example(transform-opt-ch4
|
||||
transform-opt/transform-opt.cpp)
|
||||
|
||||
target_link_libraries(transform-opt-ch4
|
||||
PRIVATE
|
||||
MLIRIR
|
||||
MLIRMlirOptMain
|
||||
MLIRSideEffectInterfaces
|
||||
MLIRTransformDialectTransforms
|
||||
MyExtensionCh4
|
||||
)
|
||||
14
mlir/examples/transform/Ch4/include/CMakeLists.txt
Normal file
14
mlir/examples/transform/Ch4/include/CMakeLists.txt
Normal file
@@ -0,0 +1,14 @@
|
||||
# Tell Tablegen to use MyExtension.td as input.
|
||||
set(LLVM_TARGET_DEFINITIONS MyExtension.td)
|
||||
|
||||
# Ask Tablegen to generate op declarations and definitions from ODS.
|
||||
mlir_tablegen(MyExtension.h.inc -gen-op-decls)
|
||||
mlir_tablegen(MyExtension.cpp.inc -gen-op-defs)
|
||||
|
||||
# Add a CMakeTarget we can depend on to ensure the generation happens before the
|
||||
# compilation.
|
||||
add_public_tablegen_target(MyExtensionCh4IncGen)
|
||||
|
||||
# Don't forget to generate the documentation, this will produce a
|
||||
# MyExtensionCh4.md under Tutorials/transform
|
||||
add_mlir_doc(MyExtension MyExtensionCh4 Tutorials/transform/ -gen-op-doc)
|
||||
30
mlir/examples/transform/Ch4/include/MyExtension.h
Normal file
30
mlir/examples/transform/Ch4/include/MyExtension.h
Normal file
@@ -0,0 +1,30 @@
|
||||
//===-- MyExtension.h - Transform dialect tutorial --------------*- c++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines Transform dialect extension operations used in the
|
||||
// Chapter 4 of the Transform dialect tutorial.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Bytecode/BytecodeOpInterface.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformOps.h"
|
||||
|
||||
namespace mlir {
|
||||
class CallOpInterface;
|
||||
namespace func {
|
||||
class CallOp;
|
||||
} // namespace func
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "MyExtension.h.inc"
|
||||
|
||||
// Registers our Transform dialect extension.
|
||||
void registerMyExtension(::mlir::DialectRegistry ®istry);
|
||||
46
mlir/examples/transform/Ch4/include/MyExtension.td
Normal file
46
mlir/examples/transform/Ch4/include/MyExtension.td
Normal file
@@ -0,0 +1,46 @@
|
||||
//===-- MyExtension.td - Transform dialect tutorial --------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines Transform dialect extension operations used in the
|
||||
// Chapter 4 of the Transform dialect tutorial.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MY_EXTENSION
|
||||
#define MY_EXTENSION
|
||||
|
||||
include "mlir/Dialect/Transform/IR/MatchInterfaces.td"
|
||||
include "mlir/Dialect/Transform/IR/TransformDialect.td"
|
||||
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
||||
// Define the new operation. By convention, prefix its name with `match`
|
||||
// followed by the name of the dialect extension.
|
||||
def HasOperandSatisfyingOp : TransformDialectOp<"match.my.has_operand_satisfying",
|
||||
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||
// Indicate that the operation implements MatchOpInterface in addition to
|
||||
// the TransformOpInterface. This interface is only used as a tag at this
|
||||
// point and has no methods that are mandatory to implement.
|
||||
MatchOpInterface,
|
||||
SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> {
|
||||
let summary = "Succeed if any of the operands matches all nested criteria";
|
||||
let arguments = (ins TransformHandleTypeInterface:$op);
|
||||
let results = (outs TransformParamTypeInterface:$position,
|
||||
Variadic<Transform_AnyHandleOrParamType>:$results);
|
||||
|
||||
// Match operations can be arbitrarily complex, e.g., containing regions.
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
let hasVerifier = 1;
|
||||
let assemblyFormat = [{
|
||||
$op `:` functional-type($op, results) attr-dict-with-keyword $body
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // MY_EXTENSION
|
||||
20
mlir/examples/transform/Ch4/lib/CMakeLists.txt
Normal file
20
mlir/examples/transform/Ch4/lib/CMakeLists.txt
Normal file
@@ -0,0 +1,20 @@
|
||||
# Outside examples, this should be `add_mlir_library`.
|
||||
add_mlir_example_library(
|
||||
# Library called MyExtension.
|
||||
MyExtensionCh4
|
||||
|
||||
# Built from the following source files.
|
||||
MyExtension.cpp
|
||||
|
||||
# Make includes visible without top-level path.
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/examples/transform/Ch4/include
|
||||
|
||||
# Make sure ODS declaration and definitions are generated before compiling this.
|
||||
DEPENDS
|
||||
MyExtensionCh4IncGen
|
||||
|
||||
# Link in the transform dialect, an all generated dialects.
|
||||
LINK_LIBS PRIVATE
|
||||
MLIRTransformDialect
|
||||
)
|
||||
207
mlir/examples/transform/Ch4/lib/MyExtension.cpp
Normal file
207
mlir/examples/transform/Ch4/lib/MyExtension.cpp
Normal file
@@ -0,0 +1,207 @@
|
||||
//===-- MyExtension.cpp - Transform dialect tutorial ----------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines Transform dialect extension operations used in the
|
||||
// Chapter 4 of the Transform dialect tutorial.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "MyExtension.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#define DEBUG_TYPE_MATCHER "transform-matcher"
|
||||
#define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
|
||||
#define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "MyExtension.cpp.inc"
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// MyExtension
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
// Define a new transform dialect extension. This uses the CRTP idiom to
|
||||
// identify extensions.
|
||||
class MyExtension
|
||||
: public ::mlir::transform::TransformDialectExtension<MyExtension> {
|
||||
public:
|
||||
// The extension must derive the base constructor.
|
||||
using Base::Base;
|
||||
|
||||
// This function initializes the extension, similarly to `initialize` in
|
||||
// dialect definitions. List individual operations and dependent dialects
|
||||
// here.
|
||||
void init();
|
||||
};
|
||||
|
||||
void MyExtension::init() {
|
||||
// Register the additional match operations with the dialect similarly to
|
||||
// other transform operations. List all operations generated from ODS. This
|
||||
// call will perform additional checks that the operations implement the
|
||||
// transform and memory effect interfaces required by the dialect interpreter
|
||||
// and assert if they do not.
|
||||
registerTransformOps<
|
||||
#define GET_OP_LIST
|
||||
#include "MyExtension.cpp.inc"
|
||||
>();
|
||||
}
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// HasOperandSatisfyingOp
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
/// Returns `true` if both types implement one of the interfaces provided as
|
||||
/// template parameters.
|
||||
template <typename... Tys>
|
||||
static bool implementSameInterface(mlir::Type t1, mlir::Type t2) {
|
||||
return ((llvm::isa<Tys>(t1) && llvm::isa<Tys>(t2)) || ... || false);
|
||||
}
|
||||
|
||||
/// Returns `true` if both types implement one of the transform dialect
|
||||
/// interfaces.
|
||||
static bool implementSameTransformInterface(mlir::Type t1, mlir::Type t2) {
|
||||
return implementSameInterface<
|
||||
mlir::transform::TransformHandleTypeInterface,
|
||||
mlir::transform::TransformParamTypeInterface,
|
||||
mlir::transform::TransformValueHandleTypeInterface>(t1, t2);
|
||||
}
|
||||
|
||||
// Matcher ops implement `apply` similarly to other transform ops. They are not
|
||||
// expected to modify payload, but use the tri-state result to signal failure or
|
||||
// success to match, as well as potential irrecoverable errors.
|
||||
mlir::DiagnosedSilenceableFailure
|
||||
mlir::transform::HasOperandSatisfyingOp::apply(
|
||||
mlir::transform::TransformRewriter &rewriter,
|
||||
mlir::transform::TransformResults &results,
|
||||
mlir::transform::TransformState &state) {
|
||||
// For simplicity, only handle a single payload op. Actual implementations
|
||||
// can use `SingleOpMatcher` trait to simplify implementation and document
|
||||
// this expectation.
|
||||
auto payloadOps = state.getPayloadOps(getOp());
|
||||
if (!llvm::hasSingleElement(payloadOps))
|
||||
return emitSilenceableError() << "expected single payload";
|
||||
|
||||
// Iterate over all operands of the payload op to see if they can be matched
|
||||
// using the body of this op.
|
||||
Operation *payload = *payloadOps.begin();
|
||||
for (OpOperand &operand : payload->getOpOperands()) {
|
||||
// Create a scope for transform values defined in the body. This corresponds
|
||||
// to the syntactic scope of the region attached to this op. Any values
|
||||
// associated with payloads from now on will be automatically dissociated
|
||||
// when this object is destroyed, i.e. at the end of the iteration.
|
||||
// Associate the block argument handle with the operand.
|
||||
auto matchScope = state.make_region_scope(getBody());
|
||||
if (failed(state.mapBlockArgument(getBody().getArgument(0),
|
||||
{operand.get()}))) {
|
||||
return DiagnosedSilenceableFailure::definiteFailure();
|
||||
}
|
||||
|
||||
// Iterate over all nested matchers with the current mapping and see if they
|
||||
// succeed.
|
||||
bool matchSucceeded = true;
|
||||
for (Operation &matcher : getBody().front().without_terminator()) {
|
||||
// Matcher ops are applied similarly to any other transform op.
|
||||
DiagnosedSilenceableFailure diag =
|
||||
state.applyTransform(cast<TransformOpInterface>(matcher));
|
||||
|
||||
// Definite failures are immediately propagated as they are irrecoverable.
|
||||
if (diag.isDefiniteFailure())
|
||||
return diag;
|
||||
|
||||
// On success, keep checking the remaining conditions.
|
||||
if (diag.succeeded())
|
||||
continue;
|
||||
|
||||
// Report failure-to-match for debugging purposes and stop matching this
|
||||
// operand.
|
||||
assert(diag.isSilenceableFailure());
|
||||
DEBUG_MATCHER(DBGS_MATCHER()
|
||||
<< "failed to match operand #" << operand.getOperandNumber()
|
||||
<< ": " << diag.getMessage());
|
||||
(void)diag.silence();
|
||||
matchSucceeded = false;
|
||||
break;
|
||||
}
|
||||
// If failed to match this operand, try other operands.
|
||||
if (!matchSucceeded)
|
||||
continue;
|
||||
|
||||
// If we reached this point, the matching succeeded for the current operand.
|
||||
// Remap the values associated with terminator operands to be associated
|
||||
// with op results, and also map the parameter result to the operand's
|
||||
// position. Note that it is safe to do here despite the end of the scope
|
||||
// as `results` are integrated into `state` by the interpreter after `apply`
|
||||
// returns rather than immediately.
|
||||
SmallVector<SmallVector<MappedValue>> yieldedMappings;
|
||||
transform::detail::prepareValueMappings(
|
||||
yieldedMappings, getBody().front().getTerminator()->getOperands(),
|
||||
state);
|
||||
results.setParams(getPosition().cast<OpResult>(),
|
||||
{rewriter.getI32IntegerAttr(operand.getOperandNumber())});
|
||||
for (auto &&[result, mapping] : llvm::zip(getResults(), yieldedMappings))
|
||||
results.setMappedValues(result, mapping);
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
// If we reached this point, none of the operands succeeded the match.
|
||||
return emitSilenceableError()
|
||||
<< "none of the operands satisfied the conditions";
|
||||
}
|
||||
|
||||
// By convention, operations implementing MatchOpInterface must not modify
|
||||
// payload IR and must therefore specify that they only read operand handles and
|
||||
// payload as their effects.
|
||||
void mlir::transform::HasOperandSatisfyingOp::getEffects(
|
||||
llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance> &effects) {
|
||||
onlyReadsPayload(effects);
|
||||
onlyReadsHandle(getOp(), effects);
|
||||
producesHandle(getPosition(), effects);
|
||||
producesHandle(getResults(), effects);
|
||||
}
|
||||
|
||||
// Verify well-formedness of the operation and emit diagnostics if it is
|
||||
// ill-formed.
|
||||
mlir::LogicalResult mlir::transform::HasOperandSatisfyingOp::verify() {
|
||||
mlir::Block &bodyBlock = getBody().front();
|
||||
if (bodyBlock.getNumArguments() != 1 ||
|
||||
!isa<TransformValueHandleTypeInterface>(
|
||||
bodyBlock.getArgument(0).getType())) {
|
||||
return emitOpError()
|
||||
<< "expects the body to have one value handle argument";
|
||||
}
|
||||
if (bodyBlock.getTerminator()->getNumOperands() != getNumResults() - 1) {
|
||||
return emitOpError() << "expects the body to yield "
|
||||
<< (getNumResults() - 1) << " values, got "
|
||||
<< bodyBlock.getTerminator()->getNumOperands();
|
||||
}
|
||||
for (auto &&[i, operand, result] :
|
||||
llvm::enumerate(bodyBlock.getTerminator()->getOperands().getTypes(),
|
||||
getResults().getTypes())) {
|
||||
if (implementSameTransformInterface(operand, result))
|
||||
continue;
|
||||
return emitOpError() << "expects terminator operand #" << i
|
||||
<< " and result #" << (i + 1)
|
||||
<< " to implement the same transform interface";
|
||||
}
|
||||
|
||||
for (Operation &op : bodyBlock.without_terminator()) {
|
||||
if (!isa<TransformOpInterface>(op) || !isa<MatchOpInterface>(op)) {
|
||||
InFlightDiagnostic diag = emitOpError()
|
||||
<< "expects body to contain match ops";
|
||||
diag.attachNote(op.getLoc()) << "non-match operation";
|
||||
return diag;
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void registerMyExtension(::mlir::DialectRegistry ®istry) {
|
||||
registry.addExtensions<MyExtension>();
|
||||
}
|
||||
55
mlir/examples/transform/Ch4/transform-opt/transform-opt.cpp
Normal file
55
mlir/examples/transform/Ch4/transform-opt/transform-opt.cpp
Normal file
@@ -0,0 +1,55 @@
|
||||
//===-- transform-opt.cpp - Transform dialect tutorial entry point --------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This is the top-level file for the Transform dialect tutorial chapter 4.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "MyExtension.h"
|
||||
|
||||
#include "mlir/Dialect/Transform/Transforms/Passes.h"
|
||||
#include "mlir/IR/DialectRegistry.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/InitAllDialects.h"
|
||||
#include "mlir/InitAllExtensions.h"
|
||||
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include <cstdlib>
|
||||
|
||||
namespace test {
|
||||
void registerTestTransformDialectExtension(mlir::DialectRegistry &);
|
||||
} // namespace test
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
// Register all "core" dialects and our transform dialect extension.
|
||||
mlir::DialectRegistry registry;
|
||||
mlir::registerAllDialects(registry);
|
||||
mlir::registerAllExtensions(registry);
|
||||
registerMyExtension(registry);
|
||||
|
||||
// Register a handful of cleanup passes that we can run to make the output IR
|
||||
// look nicer.
|
||||
mlir::registerCanonicalizerPass();
|
||||
mlir::registerCSEPass();
|
||||
mlir::registerSymbolDCEPass();
|
||||
mlir::transform::registerInterpreterPass();
|
||||
|
||||
// Register the test passes.
|
||||
#ifdef MLIR_INCLUDE_TESTS
|
||||
test::registerTestTransformDialectExtension(registry);
|
||||
#else
|
||||
llvm::errs() << "warning: MLIR built without test extension, interpreter "
|
||||
"testing will not be available\n";
|
||||
#endif // MLIR_INCLUDE_TESTS
|
||||
|
||||
// Delegate to the MLIR utility for parsing and pass management.
|
||||
return mlir::MlirOptMain(argc, argv, "transform-opt-ch4", registry)
|
||||
.succeeded()
|
||||
? EXIT_SUCCESS
|
||||
: EXIT_FAILURE;
|
||||
}
|
||||
@@ -166,6 +166,8 @@ if(LLVM_BUILD_EXAMPLES)
|
||||
list(APPEND MLIR_TEST_DEPENDS
|
||||
transform-opt-ch2
|
||||
transform-opt-ch3
|
||||
transform-opt-ch4
|
||||
mlir-minimal-opt
|
||||
)
|
||||
if(MLIR_ENABLE_EXECUTION_ENGINE)
|
||||
list(APPEND MLIR_TEST_DEPENDS
|
||||
|
||||
123
mlir/test/Examples/transform/Ch4/features.mlir
Normal file
123
mlir/test/Examples/transform/Ch4/features.mlir
Normal file
@@ -0,0 +1,123 @@
|
||||
// RUN: transform-opt-ch4 %s --transform-interpreter --verify-diagnostics
|
||||
|
||||
// Matmul as a named operation.
|
||||
func.func @named(
|
||||
%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
|
||||
%bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
|
||||
-> tensor<512x512xf32> {
|
||||
// expected-remark @below {{matmul}}
|
||||
%matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
|
||||
outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>
|
||||
func.return %matmul : tensor<512x512xf32>
|
||||
}
|
||||
|
||||
// Matmul as a generic operation.
|
||||
func.func @generic(
|
||||
%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
|
||||
%bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
|
||||
-> tensor<512x512xf32> {
|
||||
// expected-remark @below {{matmul}}
|
||||
%matmul = linalg.generic {
|
||||
iterator_types = ["parallel", "parallel", "reduction"],
|
||||
indexing_maps = [
|
||||
affine_map<(d0, d1, d2) -> (d0, d2)>,
|
||||
affine_map<(d0, d1, d2) -> (d2, d1)>,
|
||||
affine_map<(d0, d1, d2) -> (d0, d1)>]
|
||||
} ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
|
||||
outs(%output: tensor<512x512xf32>) {
|
||||
^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
|
||||
%0 = arith.mulf %arg0, %arg1 : f32
|
||||
%1 = arith.addf %0, %arg2 : f32
|
||||
linalg.yield %1 : f32
|
||||
} -> tensor<512x512xf32>
|
||||
return %matmul : tensor<512x512xf32>
|
||||
}
|
||||
|
||||
// The module containing named sequences must have an attribute allowing them
|
||||
// to enable verification.
|
||||
module @transforms attributes { transform.with_named_sequence } {
|
||||
// Entry point. This takes as the only argument the root operation (typically
|
||||
// pass root) given to the transform interpreter.
|
||||
transform.named_sequence @__transform_main(
|
||||
%root: !transform.any_op {transform.consumed}) {
|
||||
|
||||
// Traverses the payload IR associated with the operand handle, invoking
|
||||
// @match_matmul_elemwise on each of the operations. If the named sequence
|
||||
// succeeds, i.e., if none of the nested match (transform) operations
|
||||
// produced a silenceable failure, invokes @print_matmul_elemwise and
|
||||
// forwards the values yielded as arguments of the new invocation. If the
|
||||
// named sequence fails with a silenceable failure, silences it (the message
|
||||
// is forwarded to the debug stream). Definite failures are propagated
|
||||
// immediately and unconditionally, as usual.
|
||||
transform.foreach_match in %root
|
||||
@match_generic_matmul -> @print_generic_matmul
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
|
||||
transform.yield
|
||||
}
|
||||
|
||||
// This is an action sequence.
|
||||
transform.named_sequence @print_generic_matmul(
|
||||
%matmul: !transform.any_op {transform.readonly}) {
|
||||
transform.test_print_remark_at_operand %matmul, "matmul" : !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
|
||||
transform.named_sequence @match_generic_matmul(
|
||||
%candidate: !transform.any_op {transform.readonly}) -> !transform.any_op {
|
||||
// Match a structured linear algebra operation.
|
||||
transform.match.structured %candidate : !transform.any_op {
|
||||
^bb0(%c: !transform.any_op):
|
||||
// With a rank equal to 3.
|
||||
%rank = transform.match.structured.rank %c
|
||||
: (!transform.any_op) -> !transform.param<i64>
|
||||
%c3 = transform.param.constant 3 : i64 -> !transform.param<i64>
|
||||
transform.match.param.cmpi eq %rank, %c3 : !transform.param<i64>
|
||||
|
||||
// With 2 inputs.
|
||||
%n_ins = transform.match.structured.num_inputs %c
|
||||
: (!transform.any_op) -> !transform.param<i64>
|
||||
%c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
|
||||
transform.match.param.cmpi eq %n_ins, %c2 : !transform.param<i64>
|
||||
|
||||
// With 1 output (note that structured ops in destination passing style
|
||||
// has as many inits as outputs).
|
||||
%n_inits = transform.match.structured.num_inits %c
|
||||
: (!transform.any_op) -> !transform.param<i64>
|
||||
%c1 = transform.param.constant 1 : i64 -> !transform.param<i64>
|
||||
transform.match.param.cmpi eq %n_inits, %c1 : !transform.param<i64>
|
||||
|
||||
// All inputs and inits are accessed with a projected permutation.
|
||||
transform.match.structured.input %c[all] {projected_permutation}
|
||||
: !transform.any_op
|
||||
transform.match.structured.init %c[0] {projected_permutation}
|
||||
: !transform.any_op
|
||||
|
||||
// The body is a mulf/addf contraction with appropriate dimensions.
|
||||
transform.match.structured.body %c
|
||||
{ contraction = ["arith.mulf", "arith.addf"] } : !transform.any_op
|
||||
%batch, %lhs, %rhs, %reduction =
|
||||
transform.match.structured.classify_contraction_dims %c
|
||||
: (!transform.any_op)
|
||||
-> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
|
||||
!transform.param<i64>)
|
||||
|
||||
// There is one of lhs, rhs and reduction dimensions and zero batch
|
||||
// dimensions.
|
||||
%n_batch = transform.num_associations %batch
|
||||
: (!transform.param<i64>) -> !transform.param<i64>
|
||||
%n_lhs = transform.num_associations %lhs
|
||||
: (!transform.param<i64>) -> !transform.param<i64>
|
||||
%n_rhs = transform.num_associations %rhs
|
||||
: (!transform.param<i64>) -> !transform.param<i64>
|
||||
%n_reduction = transform.num_associations %reduction
|
||||
: (!transform.param<i64>) -> !transform.param<i64>
|
||||
%c0 = transform.param.constant 0 : i64 -> !transform.param<i64>
|
||||
transform.match.param.cmpi eq %n_batch, %c0 : !transform.param<i64>
|
||||
transform.match.param.cmpi eq %n_lhs, %c1 : !transform.param<i64>
|
||||
transform.match.param.cmpi eq %n_rhs, %c1 : !transform.param<i64>
|
||||
transform.match.param.cmpi eq %n_reduction, %c1 : !transform.param<i64>
|
||||
}
|
||||
transform.yield %candidate : !transform.any_op
|
||||
}
|
||||
}
|
||||
131
mlir/test/Examples/transform/Ch4/multiple.mlir
Normal file
131
mlir/test/Examples/transform/Ch4/multiple.mlir
Normal file
@@ -0,0 +1,131 @@
|
||||
// RUN: transform-opt-ch4 %s --transform-interpreter --verify-diagnostics
|
||||
|
||||
// Matmul+ReLU.
|
||||
func.func @fc_relu_operands_00(
|
||||
%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
|
||||
%bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
|
||||
-> tensor<512x512xf32> {
|
||||
// Matrix-matrix multiplication.
|
||||
// expected-remark @below {{matmul # 0}}
|
||||
%matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
|
||||
outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>
|
||||
|
||||
// Elementwise addition.
|
||||
// expected-remark @below {{add # 0}}
|
||||
%biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
|
||||
ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
|
||||
outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
|
||||
|
||||
// Elementwise max with 0 (ReLU).
|
||||
%c0f = arith.constant 0.0 : f32
|
||||
// expected-remark @below {{max # 0}}
|
||||
%relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
|
||||
ins(%biased, %c0f : tensor<512x512xf32>, f32)
|
||||
outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
|
||||
func.return %relued : tensor<512x512xf32>
|
||||
}
|
||||
|
||||
// Matmul+ReLU with swapped operands.
|
||||
func.func @fc_relu_operands_01(
|
||||
%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
|
||||
%bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
|
||||
-> tensor<512x512xf32> {
|
||||
// Matrix-matrix multiplication.
|
||||
// expected-remark @below {{matmul # 1}}
|
||||
%matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
|
||||
outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>
|
||||
|
||||
// Elementwise addition.
|
||||
// expected-remark @below {{add # 1}}
|
||||
%biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
|
||||
ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
|
||||
outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
|
||||
|
||||
// Elementwise max with 0 (ReLU).
|
||||
%c0f = arith.constant 0.0 : f32
|
||||
// expected-remark @below {{max # 1}}
|
||||
%relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
|
||||
ins(%c0f, %biased : f32, tensor<512x512xf32>)
|
||||
outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
|
||||
func.return %relued : tensor<512x512xf32>
|
||||
}
|
||||
|
||||
// The module containing named sequences must have an attribute allowing them
|
||||
// to enable verification.
|
||||
module @transforms attributes { transform.with_named_sequence } {
|
||||
// Entry point. This takes as the only argument the root operation (typically
|
||||
// pass root) given to the transform interpreter.
|
||||
transform.named_sequence @__transform_main(
|
||||
%root: !transform.any_op {transform.consumed}) {
|
||||
|
||||
// Traverses the payload IR associated with the operand handle, invoking
|
||||
// @match_matmul_elemwise on each of the operations. If the named sequence
|
||||
// succeeds, i.e., if none of the nested match (transform) operations
|
||||
// produced a silenceable failure, invokes @print_matmul_elemwise and
|
||||
// forwards the values yielded as arguments of the new invocation. If the
|
||||
// named sequence fails with a silenceable failure, silences it (the message
|
||||
// is forwarded to the debug stream). Definite failures are propagated
|
||||
// immediately and unconditionally, as usual.
|
||||
transform.foreach_match in %root
|
||||
@match_matmul_elemwise -> @print_matmul_elemwise
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
|
||||
transform.yield
|
||||
}
|
||||
|
||||
// This is an action sequence.
|
||||
transform.named_sequence @print_matmul_elemwise(
|
||||
%matmul: !transform.any_op {transform.readonly},
|
||||
%add: !transform.any_op {transform.readonly},
|
||||
%max: !transform.any_op {transform.readonly},
|
||||
%pos: !transform.param<i32> {transform.readonly}) {
|
||||
transform.test_print_param %pos, "matmul #" at %matmul
|
||||
: !transform.param<i32>, !transform.any_op
|
||||
transform.test_print_param %pos, "add #" at %add
|
||||
: !transform.param<i32>, !transform.any_op
|
||||
transform.test_print_param %pos, "max #" at %max
|
||||
: !transform.param<i32>, !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
|
||||
// This is also a matcher sequence. It is similarly given an operation to
|
||||
// match and nested operations must succeed in order for a match to be deemed
|
||||
// successful. It starts matching from the last operation in the use-def chain
|
||||
// and goes back because each operand (use) has exactly one definition.
|
||||
transform.named_sequence @match_matmul_elemwise(
|
||||
%last: !transform.any_op {transform.readonly})
|
||||
-> (!transform.any_op, !transform.any_op, !transform.any_op,
|
||||
!transform.param<i32>) {
|
||||
// The last operation must be an elementwise binary.
|
||||
transform.match.operation_name %last ["linalg.elemwise_binary"]
|
||||
: !transform.any_op
|
||||
|
||||
// One of its operands must be defined by another operation, to which we
|
||||
// will get a handle here. This is achieved thanks to a newly defined
|
||||
// operation that tries to match operands one by one using the match
|
||||
// operations nested in its region.
|
||||
%pos, %middle = transform.match.my.has_operand_satisfying %last
|
||||
: (!transform.any_op) -> (!transform.param<i32>, !transform.any_op) {
|
||||
^bb0(%operand: !transform.any_value):
|
||||
// The operand must be defined by an operation.
|
||||
%def = transform.get_defining_op %operand
|
||||
: (!transform.any_value) -> !transform.any_op
|
||||
// The defining operation must itself be an elementwise binary.
|
||||
transform.match.operation_name %def ["linalg.elemwise_binary"]
|
||||
: !transform.any_op
|
||||
transform.yield %def : !transform.any_op
|
||||
}
|
||||
|
||||
// And the first operand of that operation must be defined by yet another
|
||||
// operation.
|
||||
%matmul = transform.get_producer_of_operand %middle[0]
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
// And that operation is a matmul.
|
||||
transform.match.operation_name %matmul ["linalg.matmul"] : !transform.any_op
|
||||
// We will yield the handles to the matmul and the two elementwise
|
||||
// operations separately.
|
||||
transform.yield %matmul, %middle, %last, %pos
|
||||
: !transform.any_op, !transform.any_op, !transform.any_op,
|
||||
!transform.param<i32>
|
||||
}
|
||||
}
|
||||
139
mlir/test/Examples/transform/Ch4/sequence.mlir
Normal file
139
mlir/test/Examples/transform/Ch4/sequence.mlir
Normal file
@@ -0,0 +1,139 @@
|
||||
// RUN: transform-opt-ch4 %s --transform-interpreter --verify-diagnostics
|
||||
//
|
||||
// RUN: transform-opt-ch4 %s \
|
||||
// RUN: --transform-interpreter='entry-point=__transform_main_v2' \
|
||||
// RUN: --verify-diagnostics
|
||||
|
||||
// ****************************** IMPORTANT NOTE ******************************
|
||||
//
|
||||
// If you are changing this file, you may also need to change
|
||||
// mlir/docs/Tutorials/Transform accordingly.
|
||||
//
|
||||
// ****************************************************************************
|
||||
|
||||
// Original function to optimize.
|
||||
func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
|
||||
%bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
|
||||
-> tensor<512x512xf32> {
|
||||
// Matrix-matrix multiplication.
|
||||
// expected-remark @below {{matmul}}
|
||||
%matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
|
||||
outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>
|
||||
|
||||
// Elementwise addition.
|
||||
// expected-remark @below {{elementwise binary}}
|
||||
%biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
|
||||
ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
|
||||
outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
|
||||
|
||||
// Elementwise max with 0 (ReLU).
|
||||
%c0f = arith.constant 0.0 : f32
|
||||
// expected-remark @below {{elementwise binary}}
|
||||
%relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
|
||||
ins(%biased, %c0f : tensor<512x512xf32>, f32)
|
||||
outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
|
||||
func.return %relued : tensor<512x512xf32>
|
||||
}
|
||||
|
||||
// The module containing named sequences must have an attribute allowing them
|
||||
// to enable verification.
|
||||
module @transforms attributes { transform.with_named_sequence } {
|
||||
// Entry point. This takes as the only argument the root operation (typically
|
||||
// pass root) given to the transform interpreter.
|
||||
transform.named_sequence @__transform_main(
|
||||
%root: !transform.any_op {transform.readonly}) {
|
||||
// Collect operations that match the criteria specified in the named
|
||||
// sequence. If the named sequence fails with a silenceable failure,
|
||||
// silences it (the message is forwarded to the debug stream). If the named
|
||||
// sequence succeeds, appends its results to the results of this operation.
|
||||
%elemwise = transform.collect_matching @match_elemwise in %root
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
%matmul = transform.collect_matching @match_matmul in %root
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
|
||||
transform.include @print_elemwise failures(propagate) (%elemwise)
|
||||
: (!transform.any_op) -> ()
|
||||
transform.include @print_matmul failures(propagate) (%matmul)
|
||||
: (!transform.any_op) -> ()
|
||||
|
||||
transform.yield
|
||||
}
|
||||
|
||||
// Alternative entry point.
|
||||
transform.named_sequence @__transform_main_v2(
|
||||
%root: !transform.any_op {transform.readonly}) {
|
||||
// Collect groups of operations that match the criteria specified in the
|
||||
// named sequence.
|
||||
%matmul, %el1, %el2 = transform.collect_matching @match_matmul_elemwise in %root
|
||||
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
|
||||
%elemwise = transform.merge_handles %el1, %el2 : !transform.any_op
|
||||
|
||||
transform.include @print_elemwise failures(propagate) (%elemwise)
|
||||
: (!transform.any_op) -> ()
|
||||
transform.include @print_matmul failures(propagate) (%matmul)
|
||||
: (!transform.any_op) -> ()
|
||||
|
||||
transform.yield
|
||||
}
|
||||
|
||||
// This is a matcher sequence. It is given an operation to match and the
|
||||
// match is considered successful unless any nested operation produces a
|
||||
// failure. The values yielded by this operation will be forwarded to the
|
||||
// rewriter sequence on success.
|
||||
transform.named_sequence @match_elemwise(
|
||||
%entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
|
||||
transform.match.operation_name %entry ["linalg.elemwise_binary"]
|
||||
: !transform.any_op
|
||||
transform.yield %entry : !transform.any_op
|
||||
}
|
||||
transform.named_sequence @match_matmul(
|
||||
%entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
|
||||
transform.match.operation_name %entry ["linalg.matmul"] : !transform.any_op
|
||||
transform.yield %entry : !transform.any_op
|
||||
}
|
||||
|
||||
// This is an action sequence.
|
||||
transform.named_sequence @print_elemwise(
|
||||
%elemwise_binary: !transform.any_op {transform.readonly}) {
|
||||
transform.test_print_remark_at_operand
|
||||
%elemwise_binary, "elementwise binary" : !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
transform.named_sequence @print_matmul(
|
||||
%matmul: !transform.any_op {transform.readonly}) {
|
||||
transform.test_print_remark_at_operand %matmul, "matmul" : !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
|
||||
// This is also a matcher sequence. It is similarly given an operation to
|
||||
// match and nested operations must succeed in order for a match to be deemed
|
||||
// successful. It starts matching from the last operation in the use-def chain
|
||||
// and goes back because each operand (use) has exactly one definition.
|
||||
transform.named_sequence @match_matmul_elemwise(
|
||||
%last: !transform.any_op {transform.readonly})
|
||||
-> (!transform.any_op, !transform.any_op, !transform.any_op) {
|
||||
// The last operation must be an elementwise binary.
|
||||
transform.match.operation_name %last ["linalg.elemwise_binary"]
|
||||
: !transform.any_op
|
||||
// Its first operand must be defined by another operation, to which we
|
||||
// will get a handle here. We are guaranteed that the first operand exists
|
||||
// because we know the operation is binary, but even in absence of such a
|
||||
// guarantee, this operation would have produced a silenceable failure when
|
||||
// `%last` does not have enough operands.
|
||||
%middle = transform.get_producer_of_operand %last[0]
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
// The defining operation must itself be an elementwise binary.
|
||||
transform.match.operation_name %middle ["linalg.elemwise_binary"]
|
||||
: !transform.any_op
|
||||
// And the first operand of that operation must be defined by yet another
|
||||
// operation.
|
||||
%matmul = transform.get_producer_of_operand %middle[0]
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
// And that operation is a matmul.
|
||||
transform.match.operation_name %matmul ["linalg.matmul"] : !transform.any_op
|
||||
// We will yield the handles to the matmul and the two elementwise
|
||||
// operations separately.
|
||||
transform.yield %matmul, %middle, %last
|
||||
: !transform.any_op, !transform.any_op, !transform.any_op
|
||||
}
|
||||
}
|
||||
@@ -154,8 +154,9 @@ tools.extend(
|
||||
ToolSubst("toyc-ch5", unresolved="ignore"),
|
||||
ToolSubst("toyc-ch6", unresolved="ignore"),
|
||||
ToolSubst("toyc-ch7", unresolved="ignore"),
|
||||
ToolSubst('transform-opt-ch2', unresolved='ignore'),
|
||||
ToolSubst('transform-opt-ch3', unresolved='ignore'),
|
||||
ToolSubst("transform-opt-ch2", unresolved="ignore"),
|
||||
ToolSubst("transform-opt-ch3", unresolved="ignore"),
|
||||
ToolSubst("transform-opt-ch4", unresolved="ignore"),
|
||||
ToolSubst("%mlir_lib_dir", config.mlir_lib_dir, unresolved="ignore"),
|
||||
ToolSubst("%mlir_src_dir", config.mlir_src_root, unresolved="ignore"),
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user