Files
clang-p2996/mlir/test/Examples/transform/Ch4/features.mlir
Oleksandr "Alex" Zinenko 2798b72ae7 [mlir] introduce debug transform dialect extension (#77595)
Introduce a new extension for simple print-debugging of the transform
dialect scripts. The initial version of this extension consists of two
ops that are printing the payload objects associated with transform
dialect values. Similar ops were already available in the test extension
and several downstream projects, and were extensively used for testing.
2024-01-12 13:24:02 +01:00

124 lines
5.5 KiB
MLIR

// RUN: transform-opt-ch4 %s --transform-interpreter --verify-diagnostics
// Matmul as a named operation.
func.func @named(
    %lhs: tensor<512x512xf32>,
    %rhs: tensor<512x512xf32>,
    %bias: tensor<512x512xf32>,
    %output: tensor<512x512xf32>) -> tensor<512x512xf32> {
  // Multiplies %lhs by %rhs using the named linalg op, with %output as the
  // destination operand. %bias is part of the signature but is not used by
  // the body. The diagnostic annotation must stay on the line directly above
  // the operation it refers to.
  // expected-remark @below {{matmul}}
  %product = linalg.matmul
      ins(%lhs, %rhs : tensor<512x512xf32>, tensor<512x512xf32>)
      outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
  func.return %product : tensor<512x512xf32>
}
// Matmul as a generic operation.
func.func @generic(
    %lhs: tensor<512x512xf32>,
    %rhs: tensor<512x512xf32>,
    %bias: tensor<512x512xf32>,
    %output: tensor<512x512xf32>) -> tensor<512x512xf32> {
  // Same computation as a matmul, spelled out as a linalg.generic: two
  // parallel dimensions (i, j) and one reduction dimension (k), with a
  // multiply-accumulate body. %bias is part of the signature but is not used
  // by the body. The diagnostic annotation must stay on the line directly
  // above the operation it refers to.
  // expected-remark @below {{matmul}}
  %result = linalg.generic {
    indexing_maps = [
      affine_map<(i, j, k) -> (i, k)>,
      affine_map<(i, j, k) -> (k, j)>,
      affine_map<(i, j, k) -> (i, j)>],
    iterator_types = ["parallel", "parallel", "reduction"]
  } ins(%lhs, %rhs : tensor<512x512xf32>, tensor<512x512xf32>)
    outs(%output : tensor<512x512xf32>) {
  ^bb0(%a: f32, %b: f32, %acc: f32):
    // acc' = a * b + acc
    %prod = arith.mulf %a, %b : f32
    %sum = arith.addf %prod, %acc : f32
    linalg.yield %sum : f32
  } -> tensor<512x512xf32>
  func.return %result : tensor<512x512xf32>
}
// The module containing named sequences must carry the
// `transform.with_named_sequence` attribute, which enables their verification.
module @transforms attributes { transform.with_named_sequence } {
  // Entry point. Its only argument is the root operation (typically the pass
  // root) handed to the transform interpreter; the handle is marked as
  // consumed.
  transform.named_sequence @__transform_main(
      %root: !transform.any_op {transform.consumed}) {
    // Walks the payload IR associated with %root and invokes the matcher
    // @match_generic_matmul on each operation. When the matcher succeeds,
    // i.e., none of its nested match (transform) operations produced a
    // silenceable failure, the action @print_generic_matmul is invoked with
    // the values the matcher yielded as its arguments. When the matcher fails
    // with a silenceable failure, the failure is silenced (its message is
    // forwarded to the debug stream). Definite failures are propagated
    // immediately and unconditionally, as usual.
    transform.foreach_match in %root
        @match_generic_matmul -> @print_generic_matmul
        : (!transform.any_op) -> !transform.any_op
    transform.yield
  }

  // Action sequence: emits a "matmul" remark at the payload operation(s)
  // associated with the handle it receives.
  transform.named_sequence @print_generic_matmul(
      %matmul: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at %matmul, "matmul" : !transform.any_op
    transform.yield
  }

  // Matcher sequence: checks whether the candidate is a structured linear
  // algebra operation that has the shape of a matrix multiplication, and
  // yields the candidate handle itself.
  transform.named_sequence @match_generic_matmul(
      %candidate: !transform.any_op {transform.readonly}) -> !transform.any_op {
    // Match a structured linear algebra operation.
    transform.match.structured %candidate : !transform.any_op {
    ^bb0(%op: !transform.any_op):
      // Constant parameters used by the comparisons below.
      %zero = transform.param.constant 0 : i64 -> !transform.param<i64>
      %one = transform.param.constant 1 : i64 -> !transform.param<i64>
      %two = transform.param.constant 2 : i64 -> !transform.param<i64>
      %three = transform.param.constant 3 : i64 -> !transform.param<i64>

      // The rank must be equal to 3.
      %op_rank = transform.match.structured.rank %op
          : (!transform.any_op) -> !transform.param<i64>
      transform.match.param.cmpi eq %op_rank, %three : !transform.param<i64>

      // There must be exactly 2 inputs.
      %num_inputs = transform.match.structured.num_inputs %op
          : (!transform.any_op) -> !transform.param<i64>
      transform.match.param.cmpi eq %num_inputs, %two : !transform.param<i64>

      // There must be exactly 1 output (note that structured ops in
      // destination passing style have as many inits as outputs).
      %num_inits = transform.match.structured.num_inits %op
          : (!transform.any_op) -> !transform.param<i64>
      transform.match.param.cmpi eq %num_inits, %one : !transform.param<i64>

      // All inputs and the (single) init are accessed with a projected
      // permutation.
      transform.match.structured.input %op[all] {projected_permutation}
          : !transform.any_op
      transform.match.structured.init %op[0] {projected_permutation}
          : !transform.any_op

      // The body is a mulf/addf contraction with appropriate dimensions.
      transform.match.structured.body %op
          { contraction = ["arith.mulf", "arith.addf"] } : !transform.any_op
      %batch_dims, %lhs_dims, %rhs_dims, %reduction_dims =
          transform.match.structured.classify_contraction_dims %op
          : (!transform.any_op)
          -> (!transform.param<i64>, !transform.param<i64>,
              !transform.param<i64>, !transform.param<i64>)

      // Count the dimensions in each class: exactly one lhs, one rhs and one
      // reduction dimension are required, and no batch dimensions.
      %batch_count = transform.num_associations %batch_dims
          : (!transform.param<i64>) -> !transform.param<i64>
      %lhs_count = transform.num_associations %lhs_dims
          : (!transform.param<i64>) -> !transform.param<i64>
      %rhs_count = transform.num_associations %rhs_dims
          : (!transform.param<i64>) -> !transform.param<i64>
      %reduction_count = transform.num_associations %reduction_dims
          : (!transform.param<i64>) -> !transform.param<i64>
      transform.match.param.cmpi eq %batch_count, %zero : !transform.param<i64>
      transform.match.param.cmpi eq %lhs_count, %one : !transform.param<i64>
      transform.match.param.cmpi eq %rhs_count, %one : !transform.param<i64>
      transform.match.param.cmpi eq %reduction_count, %one : !transform.param<i64>
    }
    transform.yield %candidate : !transform.any_op
  }
}