Files
clang-p2996/mlir/test/Integration/Dialect/Transform/match_reduction.mlir
Oleksandr "Alex" Zinenko 2798b72ae7 [mlir] introduce debug transform dialect extension (#77595)
Introduce a new extension for simple print-debugging of the transform
dialect scripts. The initial version of this extension consists of two
ops that are printing the payload objects associated with transform
dialect values. Similar ops were already available in the test extension
and several downstream projects, and were extensively used for testing.
2024-01-12 13:24:02 +01:00

320 lines
14 KiB
MLIR

// RUN: mlir-opt %s --transform-interpreter --verify-diagnostics
module attributes { transform.with_named_sequence } {
// Matcher used for the ops surrounding the reduction. It accepts a structured
// op whose dimensions are all parallel, whose inputs all satisfy
// `projected_permutation`, whose inits all satisfy `permutation`, and which
// has exactly one init. On success the candidate handle is yielded unchanged.
transform.named_sequence @_reduce_leading_trailing(
    %entry: !transform.any_op {transform.readonly}) -> (!transform.any_op) {
  %one = transform.param.constant 1 : i64 -> !transform.param<i64>
  transform.match.structured %entry : !transform.any_op {
  ^bb0(%candidate: !transform.any_op):
    transform.match.structured.dim %candidate[all] {parallel} : !transform.any_op
    transform.match.structured.input %candidate[all] {projected_permutation} : !transform.any_op
    transform.match.structured.init %candidate[all] {permutation} : !transform.any_op
    // Require a single init.
    %num_inits = transform.match.structured.num_inits %candidate : (!transform.any_op) -> !transform.param<i64>
    transform.match.param.cmpi eq %num_inits, %one : !transform.param<i64>
  }
  transform.yield %entry : !transform.any_op
}
// Matcher for a reduction whose init is produced by `linalg.fill`, together
// with optional ops before and after it.
//
// Yields, in order: the (possibly empty) leading op, the fill op, the
// reduction op itself, the (possibly empty) trailing op, followed by three
// parameters: the rank, the dimensions and the elemental bitwidth of the init.
transform.named_sequence @fill_reduce_leading_trailing(
    %entry: !transform.any_op {transform.readonly})
    -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op,
        !transform.param<i64>, !transform.param<i64>, !transform.param<i64>) {
  %one = transform.param.constant 1 : i64 -> !transform.param<i64>
  %two = transform.param.constant 2 : i64 -> !transform.param<i64>
  %four = transform.param.constant 4 : i64 -> !transform.param<i64>

  %rank_out, %dims_out, %bitwidth_out, %leading_op, %init_value, %trailing_op =
      transform.match.structured failures(propagate) %entry
      : (!transform.any_op) -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
                                !transform.any_op, !transform.any_value, !transform.any_op) {
  ^bb0(%reduction: !transform.any_op):
    // The rank must lie in the closed interval [2, 4].
    %rank = transform.match.structured.rank %reduction : (!transform.any_op) -> !transform.param<i64>
    transform.match.param.cmpi ge %rank, %two : !transform.param<i64>
    transform.match.param.cmpi le %rank, %four : !transform.param<i64>

    // The last dimension is a reduction, every other dimension is parallel.
    transform.match.structured.dim %reduction[-1] {reduction} : !transform.any_op
    transform.match.structured.dim %reduction[except(-1)] {parallel} : !transform.any_op
    %dims = transform.match.structured.dim %reduction[all] : (!transform.any_op) -> !transform.param<i64>

    // Exactly one input and exactly one init.
    %num_inputs = transform.match.structured.num_inputs %reduction : (!transform.any_op) -> !transform.param<i64>
    %num_inits = transform.match.structured.num_inits %reduction : (!transform.any_op) -> !transform.param<i64>
    transform.match.param.cmpi eq %num_inputs, %one : !transform.param<i64>
    transform.match.param.cmpi eq %num_inits, %one : !transform.param<i64>

    transform.match.structured.input %reduction[0] {projected_permutation} : !transform.any_op
    transform.match.structured.init %reduction[0] {projected_permutation} : !transform.any_op
    %init_operand = transform.match.structured.init %reduction[0] : (!transform.any_op) -> !transform.any_value

    // This dance is necessary to obtain an empty handle when the result has
    // no single user, instead of failing the entire match.
    %maybe_trailing = transform.sequence %reduction : (!transform.any_op) -> !transform.any_op failures(suppress) {
    ^bb0(%user_scope: !transform.any_op):
      %matched_user = transform.match.structured failures(propagate) %user_scope : (!transform.any_op) -> !transform.any_op {
      ^bb0(%user_scope_inner: !transform.any_op):
        %single_user = transform.match.structured.result %user_scope_inner[0] {single} : (!transform.any_op) -> !transform.any_op
        %trailing_match = transform.include @_reduce_leading_trailing failures(propagate) (%single_user) : (!transform.any_op) -> !transform.any_op
        transform.match.structured.yield %trailing_match : !transform.any_op
      }
      transform.yield %matched_user : !transform.any_op
    }

    // Errors are suppressed as a way to implement optionality. They cannot be
    // suppressed in the include because it keeps matching after
    // "get_defining_op" fails, which breaks the single-op precondition of the
    // following ops. We don't want to propagate that failure though.
    //
    // Additionally, the sequence cannot be placed inside the call because its
    // first operand must be an operation handle (the verifier asserts!) and
    // there is no such handle available there.
    //
    // TODO: extend the structured matching to gracefully handle empty handles
    // or provide the suppress-errors-but-stop failure mode for includes to
    // implement optionality.
    %maybe_leading = transform.sequence %reduction : (!transform.any_op) -> !transform.any_op failures(suppress) {
    ^bb0(%producer_scope: !transform.any_op):
      %matched_producer = transform.match.structured failures(propagate) %producer_scope : (!transform.any_op) -> !transform.any_op {
      ^bb0(%producer_scope_inner: !transform.any_op):
        %producer = transform.match.structured.input %producer_scope_inner[0] : (!transform.any_op) -> !transform.any_op
        %leading_match = transform.include @_reduce_leading_trailing failures(propagate) (%producer) : (!transform.any_op) -> !transform.any_op
        transform.match.structured.yield %leading_match : !transform.any_op
      }
      transform.yield %matched_producer : !transform.any_op
    }

    %bitwidth = transform.match.structured.elemental_bitwidth %init_operand : (!transform.any_value) -> !transform.param<i64>
    transform.match.structured.body %reduction { reduction_position = 0 } : !transform.any_op
    transform.match.structured.yield %rank, %dims, %bitwidth, %maybe_leading, %init_operand, %maybe_trailing
        : !transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
          !transform.any_op, !transform.any_value, !transform.any_op
  }

  // The init value must be defined by a linalg.fill.
  %fill_op = transform.get_defining_op %init_value : (!transform.any_value) -> !transform.any_op
  transform.match.operation_name %fill_op ["linalg.fill"] : !transform.any_op

  transform.yield %leading_op, %fill_op, %entry, %trailing_op, %rank_out, %dims_out, %bitwidth_out
      : !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op,
        !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
}
// Action paired with @fill_reduce_leading_trailing: emits one remark at each
// of the four op handles, then reports the three parameters as remarks
// attached to the reduction op.
transform.named_sequence @print_reduce_leading_trailing(
    %leading: !transform.any_op {transform.readonly},
    %fill: !transform.any_op {transform.readonly},
    %reduction: !transform.any_op {transform.readonly},
    %trailing: !transform.any_op {transform.readonly},
    %rank: !transform.param<i64> {transform.readonly},
    %dims: !transform.param<i64> {transform.readonly},
    %bitwidth: !transform.param<i64> {transform.readonly}) {
  // Remarks at the payload ops.
  transform.debug.emit_remark_at %leading, "leading" : !transform.any_op
  transform.debug.emit_remark_at %fill, "fill" : !transform.any_op
  transform.debug.emit_remark_at %reduction, "reduction" : !transform.any_op
  transform.debug.emit_remark_at %trailing, "trailing" : !transform.any_op

  // Parameter values, all reported at the reduction op.
  transform.debug.emit_param_as_remark %rank, "rank" at %reduction : !transform.param<i64>, !transform.any_op
  transform.debug.emit_param_as_remark %dims, "dimensions" at %reduction : !transform.param<i64>, !transform.any_op
  transform.debug.emit_param_as_remark %bitwidth, "bitwidth" at %reduction : !transform.param<i64>, !transform.any_op
  transform.yield
}
// Entry point of the transform script: runs transform.foreach_match on %root,
// pairing the @fill_reduce_leading_trailing matcher with the
// @print_reduce_leading_trailing action.
transform.named_sequence @__transform_main(
    %root: !transform.any_op {transform.consumed}) {
  transform.foreach_match in %root
      @fill_reduce_leading_trailing -> @print_reduce_leading_trailing
      : (!transform.any_op) -> !transform.any_op
  transform.yield
}
}
// Type aliases used by several of the payload functions below: the 8x64 f32
// tensor that is reduced and the 8-element f32 tensor holding the result.
!in_tensor_t = tensor<8x64xf32>
!out_tensor_t = tensor<8xf32>
// Payload: an all-parallel generic feeding a reduction over the last
// dimension. Remarks are expected at the fill, the leading op and the
// reduction; none is expected for a trailing op.
func.func @eltwise_reduce(%arg : !in_tensor_t) -> (!out_tensor_t) {
  %neg_zero = arith.constant -0.000000e+00 : f32
  %acc_empty = tensor.empty() : !out_tensor_t
  // expected-remark @below {{fill}}
  %acc_init = linalg.fill ins(%neg_zero : f32) outs(%acc_empty : !out_tensor_t) -> !out_tensor_t
  %eltwise_empty = tensor.empty() : !in_tensor_t
  // expected-remark @below {{leading}}
  %eltwise = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%arg : !in_tensor_t) outs(%eltwise_empty : !in_tensor_t) {
  ^bb0(%in: f32, %out: f32):
    %doubled = arith.addf %in, %in : f32
    %quadrupled = arith.addf %doubled, %doubled : f32
    linalg.yield %quadrupled : f32
  } -> !in_tensor_t
  // expected-remark @below {{reduction}}
  // expected-remark @below {{rank 2}}
  // expected-remark @below {{dimensions 8 : i64, 64 : i64}}
  // expected-remark @below {{bitwidth 32 : i64}}
  %reduced = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%eltwise : !in_tensor_t) outs(%acc_init : !out_tensor_t) {
  ^bb0(%elem: f32, %acc: f32):
    %sum = arith.addf %elem, %acc : f32
    linalg.yield %sum : f32
  } -> !out_tensor_t
  return %reduced : !out_tensor_t
}
// Payload: a reduction over the last dimension whose result feeds an
// all-parallel generic. The reduction reads the function argument directly.
// Remarks are expected at the fill, the reduction and the trailing op; none
// is expected for a leading op.
func.func @reduce_eltwise(%arg : !in_tensor_t) -> (!out_tensor_t) {
  %neg_zero = arith.constant -0.000000e+00 : f32
  %acc_empty = tensor.empty() : !out_tensor_t
  // expected-remark @below {{fill}}
  %acc_init = linalg.fill ins(%neg_zero : f32) outs(%acc_empty : !out_tensor_t) -> !out_tensor_t
  // expected-remark @below {{reduction}}
  // expected-remark @below {{rank 2}}
  // expected-remark @below {{dimensions 8 : i64, 64 : i64}}
  // expected-remark @below {{bitwidth 32 : i64}}
  %reduced = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%arg : !in_tensor_t) outs(%acc_init : !out_tensor_t) {
  ^bb0(%elem: f32, %acc: f32):
    %sum = arith.addf %elem, %acc : f32
    linalg.yield %sum : f32
  } -> !out_tensor_t
  %sqrt_empty = tensor.empty() : !out_tensor_t
  // expected-remark @below {{trailing}}
  %rooted = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>,
                       affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]}
      ins(%reduced : !out_tensor_t) outs(%sqrt_empty : !out_tensor_t) {
  ^bb0(%in: f32, %out: f32):
    %root = math.sqrt %in : f32
    linalg.yield %root : f32
  } -> !out_tensor_t
  return %rooted : !out_tensor_t
}
// Payload: an all-parallel generic, then a reduction over the last dimension,
// then another all-parallel generic. Remarks are expected at the fill, the
// leading op, the reduction and the trailing op.
func.func @eltwise_reduce_eltwise(%arg : !in_tensor_t) -> (!out_tensor_t) {
  %neg_zero = arith.constant -0.000000e+00 : f32
  %acc_empty = tensor.empty() : !out_tensor_t
  // expected-remark @below {{fill}}
  %acc_init = linalg.fill ins(%neg_zero : f32) outs(%acc_empty : !out_tensor_t) -> !out_tensor_t
  %eltwise_empty = tensor.empty() : !in_tensor_t
  // expected-remark @below {{leading}}
  %eltwise = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%arg : !in_tensor_t) outs(%eltwise_empty : !in_tensor_t) {
  ^bb0(%in: f32, %out: f32):
    %doubled = arith.addf %in, %in : f32
    %quadrupled = arith.addf %doubled, %doubled : f32
    linalg.yield %quadrupled : f32
  } -> !in_tensor_t
  // expected-remark @below {{reduction}}
  // expected-remark @below {{rank 2}}
  // expected-remark @below {{dimensions 8 : i64, 64 : i64}}
  // expected-remark @below {{bitwidth 32 : i64}}
  %reduced = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%eltwise : !in_tensor_t) outs(%acc_init : !out_tensor_t) {
  ^bb0(%elem: f32, %acc: f32):
    %sum = arith.addf %elem, %acc : f32
    linalg.yield %sum : f32
  } -> !out_tensor_t
  %sqrt_empty = tensor.empty() : !out_tensor_t
  // expected-remark @below {{trailing}}
  %rooted = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>,
                       affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]}
      ins(%reduced : !out_tensor_t) outs(%sqrt_empty : !out_tensor_t) {
  ^bb0(%val: f32, %unused: f32):
    %root = math.sqrt %val : f32
    linalg.yield %root : f32
  } -> !out_tensor_t
  return %rooted : !out_tensor_t
}
// Payload: same chain of ops as @eltwise_reduce_eltwise, except that the
// fill is created after the leading all-parallel generic instead of before
// it. Remarks are expected at the leading op, the fill, the reduction and
// the trailing op.
func.func @eltwise_reduce_eltwise_swapped(%arg : !in_tensor_t) -> (!out_tensor_t) {
  %neg_zero = arith.constant -0.000000e+00 : f32
  %eltwise_empty = tensor.empty() : !in_tensor_t
  // expected-remark @below {{leading}}
  %eltwise = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%arg : !in_tensor_t) outs(%eltwise_empty : !in_tensor_t) {
  ^bb0(%in: f32, %out: f32):
    %doubled = arith.addf %in, %in : f32
    %quadrupled = arith.addf %doubled, %doubled : f32
    linalg.yield %quadrupled : f32
  } -> !in_tensor_t
  %acc_empty = tensor.empty() : !out_tensor_t
  // expected-remark @below {{fill}}
  %acc_init = linalg.fill ins(%neg_zero : f32) outs(%acc_empty : !out_tensor_t) -> !out_tensor_t
  // expected-remark @below {{reduction}}
  // expected-remark @below {{rank 2}}
  // expected-remark @below {{dimensions 8 : i64, 64 : i64}}
  // expected-remark @below {{bitwidth 32 : i64}}
  %reduced = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%eltwise : !in_tensor_t) outs(%acc_init : !out_tensor_t) {
  ^bb0(%elem: f32, %acc: f32):
    %sum = arith.addf %elem, %acc : f32
    linalg.yield %sum : f32
  } -> !out_tensor_t
  %sqrt_empty = tensor.empty() : !out_tensor_t
  // expected-remark @below {{trailing}}
  %rooted = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>,
                       affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]}
      ins(%reduced : !out_tensor_t) outs(%sqrt_empty : !out_tensor_t) {
  ^bb0(%val: f32, %unused: f32):
    %root = math.sqrt %val : f32
    linalg.yield %root : f32
  } -> !out_tensor_t
  return %rooted : !out_tensor_t
}
// Payload: a filled reduction over the last dimension plus a second,
// unrelated linalg.fill whose result is returned directly. Remarks are
// expected only at the first fill and at the reduction; the second fill
// carries no expectation.
func.func @reduction_with_extra_op_in_func(%arg0: tensor<8x479xf32>, %arg1: tensor<32x32xf32>) -> (tensor<8xf32>, tensor<32xf32>) {
  %zero = arith.constant 0.0 : f32
  %acc_empty = tensor.empty() : tensor<8xf32>
  // expected-remark @below {{fill}}
  %acc_init = linalg.fill ins(%zero : f32) outs(%acc_empty : tensor<8xf32>) -> tensor<8xf32>
  // expected-remark @below {{reduction}}
  // expected-remark @below {{rank 2}}
  // expected-remark @below {{dimensions 8 : i64, 479 : i64}}
  // expected-remark @below {{bitwidth 32 : i64}}
  %reduced = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%arg0 : tensor<8x479xf32>)
      outs(%acc_init : tensor<8xf32>) {
  ^bb0(%elem: f32, %acc: f32):
    %sum = arith.addf %elem, %acc : f32
    linalg.yield %sum : f32
  } -> tensor<8xf32>
  %extra_empty = tensor.empty() : tensor<32xf32>
  %extra_fill = linalg.fill ins(%zero : f32) outs(%extra_empty : tensor<32xf32>) -> tensor<32xf32>
  return %reduced, %extra_fill : tensor<8xf32>, tensor<32xf32>
}