This patch enables continuous tiling of a target structured op using diminishing tile sizes. When a tensor dimension is not evenly divisible by the tile size, tiling leaves an irregularly tiled leftover chunk. Continuous tiling retiles that leftover chunk with a smaller tile size and repeats the process with exponentially diminishing tile sizes, eventually producing a chain of loops in which each leftover chunk is tiled with a progressively smaller size.

Adds a `continuous_tile_sizes` op to the transform dialect. Given a tile size and a dimension, it computes a series of diminishing tile sizes that can be used to tile the target along the given dimension, together with the series of chunk sizes that the corresponding tile sizes should be applied to.

Adds a `multiway` attribute to `transform.structured.split` that enables multiway splitting of a single target op along the given dimension, as specified by a list enumerating the chunk sizes.
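For illustration, here is a minimal transform script sketching how the two additions are intended to compose. The op syntax, the `target_size` attribute name, and the result types are assumptions based on the description above, not verbatim from the patch:

    module attributes {transform.with_named_sequence} {
      transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
        %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
        // Assumed syntax: compute exponentially diminishing tile sizes along
        // dimension 1 starting from tile size 9, plus the chunk sizes each
        // tile size applies to.
        %tiles, %chunks = transform.structured.continuous_tile_sizes %0 { dimension = 1, target_size = 9 } : (!transform.any_op) -> !transform.any_op
        // Assumed syntax: multiway split of the target along dimension 1,
        // splitting it into one op per chunk size.
        %low, %high = transform.structured.split %0 after %chunks { dimension = 1, multiway } : !transform.any_op, !transform.any_op
        transform.yield
      }
    }

The updated test file below exercises the existing two-way form of `transform.structured.split`.
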
// RUN: mlir-opt %s --transform-interpreter --split-input-file -verify-diagnostics | FileCheck %s
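// These tests exercise two-way `transform.structured.split`: static and
// dynamic split points, repeated splitting along several dimensions, and the
// diagnostics emitted for invalid uses.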

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1:2 = transform.structured.split %0 after 42 { dimension = 0 } : !transform.any_op
    transform.yield
  }
}

func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32

// CHECK: #[[$ADD_42_MAP:.+]] = affine_map<(d0) -> (d0 + 42)>

// CHECK-LABEL: @one_d_static
// CHECK-SAME: %[[IN:.+]]: tensor<100xf32>, %[[OUT:.+]]: tensor<100xf32>
func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
  // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN]][0] [42] [1] : tensor<100xf32> to tensor<42xf32>
  // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT]][0] [42] [1] : tensor<100xf32> to tensor<42xf32>
  // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
  // CHECK: ins(%[[IN_SLICE_LOW]]
  // CHECK: outs(%[[OUT_SLICE_LOW]]
  // CHECK: linalg.index 0
  // CHECK: func.call @elem
  // CHECK: %[[RES_PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [42] [1]
  //
  // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN]][42] [58] [1] : tensor<100xf32> to tensor<58xf32>
  // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[RES_PARTIAL]][42] [58] [1] : tensor<100xf32> to tensor<58xf32>
  // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
  // CHECK: ins(%[[IN_SLICE_HIGH]]
  // CHECK: outs(%[[OUT_SLICE_HIGH]]
  // CHECK: %[[IDX:.+]] = linalg.index 0
  // CHECK: affine.apply #[[$ADD_42_MAP]](%[[IDX]])
  // CHECK: func.call @elem
  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[RES_PARTIAL]][42] [58] [1]
  %0 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
  ^bb0(%0: f32, %1: f32):
    %i = linalg.index 0 : index
    %call_res = func.call @elem(%0, %i, %i) : (f32, index, index) -> f32
    linalg.yield %call_res : f32
  } -> tensor<100xf32>

  // CHECK: return %[[RES]]
  return %0 : tensor<100xf32>
}

// -----

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1:2 = transform.structured.split %0 after 42 { dimension = 0 } : !transform.any_op
    transform.yield
  }
}

func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32

// CHECK-LABEL: @one_d_static_overflow
// CHECK-SAME: %[[IN:.+]]: tensor<10xf32>, %[[OUT:.+]]: tensor<10xf32>
func.func @one_d_static_overflow(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> tensor<10xf32> {
  // Folding is sufficiently powerful to detect the static overflow and avoid
  // the splitting altogether.
  // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
  // CHECK: ins(%[[IN]]
  // CHECK: outs(%[[OUT]]
  // CHECK: linalg.index 0
  // CHECK: func.call @elem
  %0 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<10xf32>) outs(%arg1: tensor<10xf32>) {
  ^bb0(%0: f32, %1: f32):
    %i = linalg.index 0 : index
    %call_res = func.call @elem(%0, %i, %i) : (f32, index, index) -> f32
    linalg.yield %call_res : f32
  } -> tensor<10xf32>
  return %0 : tensor<10xf32>
}

// -----
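
// Splitting at a dynamic point supplied by another payload op. The two part
// sizes become dynamic (tensor<?xf32>) and are clamped and computed with
// affine.min / affine.apply on the runtime value.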

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.match ops{["func.call"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.structured.split %0 after %1 { dimension = 0 } : !transform.any_op, !transform.any_op
    transform.yield
  }
}

func.func private @get_size() -> index

// CHECK: #[[$MAP_MIN_100:.+]] = affine_map<()[s0] -> (s0, 100)>
// CHECK: #[[$MAP_S_MINUS_100:.+]] = affine_map<()[s0] -> (-s0 + 100)>

// CHECK-LABEL: @dynamic
func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
  // CHECK: %[[SPLIT:.+]] = call @get_size
  // CHECK: %[[SPLIT_LOW:.+]] = affine.min #[[$MAP_MIN_100]]()[%[[SPLIT]]
  // CHECK: %[[SPLIT_HIGH_1:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
  // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN:.+]][0] [%[[SPLIT_LOW]]] [1] : tensor<100xf32> to tensor<?xf32>
  // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT:.+]][0] [%[[SPLIT_LOW]]] [1] : tensor<100xf32> to tensor<?xf32>
  // CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
  // CHECK: ins(%[[IN_SLICE_LOW]]
  // CHECK: outs(%[[OUT_SLICE_LOW]]
  // CHECK: %[[PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [%[[SPLIT_LOW]]] [1]
  //
  // CHECK: %[[SPLIT_HIGH_2:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
  // CHECK: %[[SPLIT_HIGH_3:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
  // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN:.+]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_2]]] [1] : tensor<100xf32> to tensor<?xf32>
  // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[PARTIAL:.+]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_3]]] [1] : tensor<100xf32> to tensor<?xf32>
  // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
  // CHECK: ins(%[[IN_SLICE_HIGH]]
  // CHECK: outs(%[[OUT_SLICE_HIGH]]
  // CHECK: %[[SPLIT_HIGH_4:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
  // CHECK: tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[PARTIAL]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_4]]] [1]
  %0 = func.call @get_size() : () -> index
  %1 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
  ^bb0(%3: f32, %4: f32):
    %5 = arith.addf %3, %4 : f32
    linalg.yield %5 : f32
  } -> tensor<100xf32>
  return %1 : tensor<100xf32>
}

// -----

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1:2 = transform.structured.split %0 after 4 { dimension = 0 } : !transform.any_op
    %2:2 = transform.structured.split %1#1 after 16 { dimension = 1 } : !transform.any_op
    transform.yield
  }
}

func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32

// CHECK-LABEL: @two_d
func.func @two_d(%arg0: tensor<10x34xf32>,
                 %arg1: tensor<10x34xf32>) -> tensor<10x34xf32> {
  // Check the overall structure: split along the dimension 0, and then split
  // the second half only along the dimension 1.
  // CHECK: %[[IN_1:.+]] = tensor.extract_slice %[[IN:.+]][0, 0]
  // CHECK: %[[OUT_1:.+]] = tensor.extract_slice %[[OUT:.+]][0, 0]
  // CHECK: %[[RES_1:.+]] = linalg.generic
  // CHECK-SAME: ins(%[[IN_1]] : tensor<4x34xf32>)
  // CHECK-SAME: outs(%[[OUT_1]] : tensor<4x34xf32>)
  // CHECK: %[[PARTIAL_1:.+]] = tensor.insert_slice %[[RES_1]] into %[[OUT]]
  //
  // CHECK: %[[IN_2:.+]] = tensor.extract_slice %[[IN]]
  // CHECK: %[[OUT_2:.+]] = tensor.extract_slice %[[PARTIAL_1]]
  // Note that `extract_slice` taking a slice from another `extract_slice` result
  // is folded to use the operand of the first `extract_slice`.
  // CHECK: %[[IN_21:.+]] = tensor.extract_slice %[[IN_2]]
  // CHECK: %[[OUT_21:.+]] = tensor.extract_slice %[[OUT_2]]
  // CHECK: %[[RES_21:.+]] = linalg.generic
  // CHECK-SAME: ins(%[[IN_21]] : tensor<6x16xf32>)
  // CHECK-SAME: outs(%[[OUT_21]] : tensor<6x16xf32>)
  // CHECK: %[[PARTIAL_21:.+]] = tensor.insert_slice %[[RES_21]] into %[[OUT_2]]
  //
  // CHECK: %[[IN_22:.+]] = tensor.extract_slice %[[IN_2]]
  // CHECK: %[[OUT_22:.+]] = tensor.extract_slice %[[PARTIAL_21]]
  // CHECK: %[[RES_22:.+]] = linalg.generic
  // CHECK-SAME: ins(%[[IN_22]] : tensor<6x18xf32>)
  // CHECK-SAME: outs(%[[OUT_22]] : tensor<6x18xf32>)
  // CHECK: %[[PARTIAL_22:.+]] = tensor.insert_slice %[[RES_22]] into %[[PARTIAL_21]]
  // CHECK: %[[PARTIAL_2:.+]] = tensor.insert_slice %[[PARTIAL_22]] into %[[PARTIAL_1]]
  %0 = linalg.generic {
    indexing_maps = [affine_map<(i, j) -> (i, j)>,
                     affine_map<(i, j) -> (i, j)>],
    iterator_types = ["parallel", "parallel"]
  }
  ins(%arg0: tensor<10x34xf32>)
  outs(%arg1: tensor<10x34xf32>) {
  ^bb0(%0: f32, %1: f32):
    %i = linalg.index 0 : index
    %j = linalg.index 1 : index
    %call_res = func.call @elem(%0, %i, %j) : (f32, index, index) -> f32
    linalg.yield %call_res : f32
  } -> tensor<10x34xf32>
  return %0 : tensor<10x34xf32>
}

// -----
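
// The remaining cases exercise the diagnostics of `transform.structured.split`.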

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.consumed}) {
    // expected-error @below {{expects either a dynamic or a static split point to be provided}}
    %0:2 = "transform.structured.split"(%arg1) { dimension = 1, static_chunk_sizes = -9223372036854775808 } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.match ops{["func.call"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error @below {{expected dynamic split point handle to point to a single-result index-typed op}}
    transform.structured.split %0 after %1 { dimension = 0 } : !transform.any_op, !transform.any_op
    transform.yield
  }
}

func.func private @get_size() -> i64

func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
  // expected-note @below {{dynamic split point}}
  %0 = func.call @get_size() : () -> i64
  %1 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
  ^bb0(%3: f32, %4: f32):
    linalg.yield %3 : f32
  } -> tensor<100xf32>
  return %1 : tensor<100xf32>
}

// -----

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.match ops{["func.call"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error @below {{expected the dynamic split point handle to point to as many operations (0) as the target handle (1)}}
    transform.structured.split %0 after %1 { dimension = 0 } : !transform.any_op, !transform.any_op
    transform.yield
  }
}

func.func private @get_size() -> i64

func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
  %1 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
  ^bb0(%3: f32, %4: f32):
    linalg.yield %3 : f32
  } -> tensor<100xf32>
  return %1 : tensor<100xf32>
}

// -----

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.return"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error @below {{only applies to structured ops}}
    transform.structured.split %0 after 16 { dimension = 1 } : !transform.any_op
    transform.yield
  }
}

func.func @noop(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
  // expected-note @below {{target op}}
  return %arg0 : tensor<100xf32>
}

// -----

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error @below {{dimension 1 does not exist in target op}}
    transform.structured.split %0 after 16 { dimension = 1 } : !transform.any_op
    transform.yield
  }
}

func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
  // expected-note @below {{target op}}
  %0 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
  ^bb0(%0: f32, %1: f32):
    linalg.yield %0 : f32
  } -> tensor<100xf32>
  return %0 : tensor<100xf32>
}

// -----

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error @below {{splitting does not produce the second part for a subset of targets}}
    // expected-note @below {{expected splitting to produce the second part of all or none of the targets}}
    %1:2 = transform.structured.split %0 after 142 { dimension = 0 } : !transform.any_op
    transform.yield
  }
}

func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32

func.func @split_one_but_not_other(
    %arg0: tensor<100xf32>, %arg1: tensor<100xf32>,
    %arg2: tensor<200xf32>, %arg3: tensor<200xf32>)
    -> (tensor<100xf32>, tensor<200xf32>) {
  // expected-note @below {{first target with no second part}}
  %0 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
  ^bb0(%arg4: f32, %arg5: f32):
    %i = linalg.index 0 : index
    %call_res = func.call @elem(%arg4, %i, %i) : (f32, index, index) -> f32
    linalg.yield %call_res : f32
  } -> tensor<100xf32>

  %1 = linalg.generic {
    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
    iterator_types = ["parallel"]
  }
  ins(%arg2: tensor<200xf32>) outs(%arg3: tensor<200xf32>) {
  ^bb0(%arg4: f32, %arg5: f32):
    %i = linalg.index 0 : index
    %call_res = func.call @elem(%arg4, %i, %i) : (f32, index, index) -> f32
    linalg.yield %call_res : f32
  } -> tensor<200xf32>

  return %0, %1 : tensor<100xf32>, tensor<200xf32>
}