Files
clang-p2996/mlir/test/Integration/Dialect/Transform/match_batch_matmul.mlir
Oleksandr "Alex" Zinenko 2798b72ae7 [mlir] introduce debug transform dialect extension (#77595)
Introduce a new extension for simple print-debugging of the transform
dialect scripts. The initial version of this extension consists of two
ops that are printing the payload objects associated with transform
dialect values. Similar ops were already available in the test extenion
and several downstream projects, and were extensively used for testing.
2024-01-12 13:24:02 +01:00

67 lines
3.8 KiB
MLIR

// RUN: mlir-opt %s --transform-preload-library='transform-library-paths=%p/match_matmul_common.mlir' --transform-interpreter --verify-diagnostics
module attributes { transform.with_named_sequence } {
transform.named_sequence @_match_matmul_like(
%entry: !transform.any_op {transform.readonly},
%rank: !transform.param<i64> {transform.readonly})
-> (!transform.any_op, !transform.any_op, !transform.param<i64>,
!transform.type, !transform.type, !transform.type,
!transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>)
transform.named_sequence @match_bmm(%entry: !transform.any_op {transform.readonly})
-> (!transform.any_op, !transform.any_op, !transform.param<i64>,
!transform.type, !transform.type, !transform.type, !transform.param<i64>) {
transform.match.operation_name %entry ["linalg.batch_matmul", "linalg.generic"] : !transform.any_op
%c3 = transform.param.constant 4 : i64 -> !transform.param<i64>
%fill, %bmm, %dims, %lhs_type, %rhs_type, %res_type, %batch, %m, %n, %k =
transform.include @_match_matmul_like failures(propagate) (%entry, %c3)
: (!transform.any_op, !transform.param<i64>)
-> (!transform.any_op, !transform.any_op, !transform.param<i64>,
!transform.type, !transform.type, !transform.type,
!transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>)
transform.yield %fill, %bmm, %dims, %lhs_type, %rhs_type, %res_type, %batch
: !transform.any_op, !transform.any_op, !transform.param<i64>, !transform.type, !transform.type, !transform.type, !transform.param<i64>
}
transform.named_sequence @print_bmm(
%fill: !transform.any_op {transform.readonly},
%bmm: !transform.any_op {transform.readonly},
%dims: !transform.param<i64> {transform.readonly},
%lhs_type: !transform.type {transform.readonly},
%rhs_type: !transform.type {transform.readonly},
%res_type: !transform.type {transform.readonly},
%batch: !transform.param<i64> {transform.readonly}) {
transform.debug.emit_remark_at %fill, "fill" : !transform.any_op
transform.debug.emit_remark_at %bmm, "batch matmul" : !transform.any_op
transform.debug.emit_param_as_remark %dims, "dimensions" at %bmm : !transform.param<i64>, !transform.any_op
transform.debug.emit_param_as_remark %lhs_type, "LHS type" at %bmm : !transform.type, !transform.any_op
transform.debug.emit_param_as_remark %rhs_type, "RHS type" at %bmm : !transform.type, !transform.any_op
transform.debug.emit_param_as_remark %res_type, "result type" at %bmm : !transform.type, !transform.any_op
transform.debug.emit_param_as_remark %batch, "batch dimension" at %bmm : !transform.param<i64>, !transform.any_op
transform.yield
}
transform.named_sequence @__transform_main(%root: !transform.any_op {transform.consumed}) {
transform.foreach_match in %root
@match_bmm -> @print_bmm
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}
func.func @bmm_simple(%lhs: tensor<40x10x20xf16>, %rhs: tensor<40x20x15xf32>) -> tensor<40x10x15xf64>{
%cst = arith.constant 0.0 : f64
%empty = tensor.empty() : tensor<40x10x15xf64>
// expected-remark @below {{fill}}
%fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<40x10x15xf64>) -> tensor<40x10x15xf64>
// expected-remark @below {{batch matmul}}
// expected-remark @below {{dimensions 40 : i64, 10 : i64, 15 : i64, 20 : i64}}
// expected-remark @below {{LHS type f16}}
// expected-remark @below {{RHS type f32}}
// expected-remark @below {{result type f64}}
// expected-remark @below {{batch dimension 0}}
%result = linalg.batch_matmul ins(%lhs, %rhs: tensor<40x10x20xf16>, tensor<40x20x15xf32>) outs(%fill: tensor<40x10x15xf64>) -> tensor<40x10x15xf64>
return %result : tensor<40x10x15xf64>
}