This patch implements the lowering of `vector.deinterleave` for n-dimensional vectors. The lowering unrolls the leading dimensions of the n-D source vector into a series of 1-D vectors, applies `vector.deinterleave` to each of them, and inserts the two results of each 1-D deinterleave back into the two n-D result vectors.

From:

```
%0, %1 = vector.deinterleave %arg0 : vector<2x8xi32> -> vector<2x4xi32>
```

To:

```
%cst = arith.constant dense<0> : vector<2x4xi32>
%0 = vector.extract %arg0[0] : vector<8xi32> from vector<2x8xi32>
%res1, %res2 = vector.deinterleave %0 : vector<8xi32> -> vector<4xi32>
%1 = vector.insert %res1, %cst [0] : vector<4xi32> into vector<2x4xi32>
%2 = vector.insert %res2, %cst [0] : vector<4xi32> into vector<2x4xi32>
%3 = vector.extract %arg0[1] : vector<8xi32> from vector<2x8xi32>
%res1_0, %res2_1 = vector.deinterleave %3 : vector<8xi32> -> vector<4xi32>
%4 = vector.insert %res1_0, %1 [1] : vector<4xi32> into vector<2x4xi32>
%5 = vector.insert %res2_1, %2 [1] : vector<4xi32> into vector<2x4xi32>
...etc.
```

For larger leading shapes, the extract/deinterleave/insert sequence repeats for every remaining leading position. A scalable dimension cannot be unrolled, so unrolling stops at the first scalable dimension.
69 lines
3.9 KiB
MLIR
69 lines
3.9 KiB
MLIR
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s

// Tests for the n-D `vector.deinterleave` lowering applied through the
// `transform.apply_patterns.vector.lower_interleave` pattern set (see the
// transform sequence at the end of this file). Leading fixed-size dims are
// unrolled into 1-D `vector.deinterleave` ops whose two results are inserted
// back into the n-D results; a scalable dim stops the unrolling.
|
|
|
|
// The single leading dim of a 2-D deinterleave is unrolled: each row is
// extracted, deinterleaved as a 1-D vector, and the two outputs are inserted
// into the two (zero-initialized) 2-D results.
// CHECK-LABEL: @vector_deinterleave_2d
// CHECK-SAME:    %[[SRC:.*]]: vector<2x8xi32>) -> (vector<2x4xi32>, vector<2x4xi32>)
func.func @vector_deinterleave_2d(%a: vector<2x8xi32>) -> (vector<2x4xi32>, vector<2x4xi32>) {
  // CHECK:       %[[CST:.*]] = arith.constant dense<0> : vector<2x4xi32>
  // Row 0.
  // CHECK:       %[[SRC_0:.*]] = vector.extract %[[SRC]][0] : vector<8xi32> from vector<2x8xi32>
  // CHECK:       %[[UNZIP_0:.*]], %[[UNZIP_1:.*]] = vector.deinterleave %[[SRC_0]] : vector<8xi32> -> vector<4xi32>
  // CHECK:       %[[RES_0:.*]] = vector.insert %[[UNZIP_0]], %[[CST]] [0] : vector<4xi32> into vector<2x4xi32>
  // CHECK:       %[[RES_1:.*]] = vector.insert %[[UNZIP_1]], %[[CST]] [0] : vector<4xi32> into vector<2x4xi32>
  // Row 1, accumulating into the row-0 results.
  // CHECK:       %[[SRC_1:.*]] = vector.extract %[[SRC]][1] : vector<8xi32> from vector<2x8xi32>
  // CHECK:       %[[UNZIP_2:.*]], %[[UNZIP_3:.*]] = vector.deinterleave %[[SRC_1]] : vector<8xi32> -> vector<4xi32>
  // CHECK:       %[[RES_2:.*]] = vector.insert %[[UNZIP_2]], %[[RES_0]] [1] : vector<4xi32> into vector<2x4xi32>
  // CHECK:       %[[RES_3:.*]] = vector.insert %[[UNZIP_3]], %[[RES_1]] [1] : vector<4xi32> into vector<2x4xi32>
  // CHECK-NEXT:  return %[[RES_2]], %[[RES_3]] : vector<2x4xi32>, vector<2x4xi32>
  %0, %1 = vector.deinterleave %a : vector<2x8xi32> -> vector<2x4xi32>
  return %0, %1 : vector<2x4xi32>, vector<2x4xi32>
}
|
|
|
|
// Same as the fixed-size 2-D case, but the trailing (deinterleaved) dim is
// scalable; the fixed leading dim is still unrolled and every 1-D slice must
// keep its scalable [8] -> [4] type.
// CHECK-LABEL: @vector_deinterleave_2d_scalable
// CHECK-SAME:    %[[SRC:.*]]: vector<2x[8]xi32>) -> (vector<2x[4]xi32>, vector<2x[4]xi32>)
func.func @vector_deinterleave_2d_scalable(%a: vector<2x[8]xi32>) -> (vector<2x[4]xi32>, vector<2x[4]xi32>) {
  // CHECK:       %[[CST:.*]] = arith.constant dense<0> : vector<2x[4]xi32>
  // Row 0.
  // CHECK:       %[[SRC_0:.*]] = vector.extract %[[SRC]][0] : vector<[8]xi32> from vector<2x[8]xi32>
  // CHECK:       %[[UNZIP_0:.*]], %[[UNZIP_1:.*]] = vector.deinterleave %[[SRC_0]] : vector<[8]xi32> -> vector<[4]xi32>
  // CHECK:       %[[RES_0:.*]] = vector.insert %[[UNZIP_0]], %[[CST]] [0] : vector<[4]xi32> into vector<2x[4]xi32>
  // CHECK:       %[[RES_1:.*]] = vector.insert %[[UNZIP_1]], %[[CST]] [0] : vector<[4]xi32> into vector<2x[4]xi32>
  // Row 1, accumulating into the row-0 results.
  // CHECK:       %[[SRC_1:.*]] = vector.extract %[[SRC]][1] : vector<[8]xi32> from vector<2x[8]xi32>
  // CHECK:       %[[UNZIP_2:.*]], %[[UNZIP_3:.*]] = vector.deinterleave %[[SRC_1]] : vector<[8]xi32> -> vector<[4]xi32>
  // CHECK:       %[[RES_2:.*]] = vector.insert %[[UNZIP_2]], %[[RES_0]] [1] : vector<[4]xi32> into vector<2x[4]xi32>
  // CHECK:       %[[RES_3:.*]] = vector.insert %[[UNZIP_3]], %[[RES_1]] [1] : vector<[4]xi32> into vector<2x[4]xi32>
  // CHECK-NEXT:  return %[[RES_2]], %[[RES_3]] : vector<2x[4]xi32>, vector<2x[4]xi32>
  %0, %1 = vector.deinterleave %a : vector<2x[8]xi32> -> vector<2x[4]xi32>
  return %0, %1 : vector<2x[4]xi32>, vector<2x[4]xi32>
}
|
|
|
|
// All three fixed-size leading dims of a 4-D deinterleave are unrolled, giving
// exactly 1 * 2 * 3 = 6 one-dimensional deinterleaves.
// CHECK-LABEL: @vector_deinterleave_4d
// CHECK-SAME:    %[[SRC:.*]]: vector<1x2x3x8xi64>) -> (vector<1x2x3x4xi64>, vector<1x2x3x4xi64>)
func.func @vector_deinterleave_4d(%a: vector<1x2x3x8xi64>) -> (vector<1x2x3x4xi64>, vector<1x2x3x4xi64>) {
  // CHECK:       %[[CST:.*]] = arith.constant dense<0> : vector<1x2x3x4xi64>
  // First position [0, 0, 0]: both results start from the zero constant.
  // CHECK:       %[[SRC_0:.*]] = vector.extract %[[SRC]][0, 0, 0] : vector<8xi64> from vector<1x2x3x8xi64>
  // CHECK:       %[[UNZIP_0:.*]], %[[UNZIP_1:.*]] = vector.deinterleave %[[SRC_0]] : vector<8xi64> -> vector<4xi64>
  // CHECK:       %[[RES_0:.*]] = vector.insert %[[UNZIP_0]], %[[CST]] [0, 0, 0] : vector<4xi64> into vector<1x2x3x4xi64>
  // CHECK:       %[[RES_1:.*]] = vector.insert %[[UNZIP_1]], %[[CST]] [0, 0, 0] : vector<4xi64> into vector<1x2x3x4xi64>
  // The remaining five positions each get one 1-D deinterleave; the
  // negative check pins the total to exactly six.
  // CHECK-COUNT-5: vector.deinterleave %{{.*}} : vector<8xi64> -> vector<4xi64>
  // CHECK-NOT:   vector.deinterleave
  // CHECK:       return %{{.*}}, %{{.*}} : vector<1x2x3x4xi64>, vector<1x2x3x4xi64>
  %0, %1 = vector.deinterleave %a : vector<1x2x3x8xi64> -> vector<1x2x3x4xi64>
  return %0, %1 : vector<1x2x3x4xi64>, vector<1x2x3x4xi64>
}
|
|
|
|
// The scalable dim blocks unrolling, so only the first two (fixed-size) dims
// are unrolled: 1 * 3 = 3 deinterleaves remain, each still n-D with a leading
// scalable dim.
// CHECK-LABEL: @vector_deinterleave_nd_with_scalable_dim
// CHECK-SAME:    %[[SRC:.*]]: vector<1x3x[2]x2x3x8xf16>) -> (vector<1x3x[2]x2x3x4xf16>, vector<1x3x[2]x2x3x4xf16>)
func.func @vector_deinterleave_nd_with_scalable_dim(
    %a: vector<1x3x[2]x2x3x8xf16>) -> (vector<1x3x[2]x2x3x4xf16>, vector<1x3x[2]x2x3x4xf16>) {
  // Slices are taken at 2-D positions (first one is [0, 0]).
  // CHECK:       vector.extract %[[SRC]][0, 0] : vector<[2]x2x3x8xf16> from vector<1x3x[2]x2x3x8xf16>
  // CHECK-COUNT-3: vector.deinterleave %{{.*}} : vector<[2]x2x3x8xf16> -> vector<[2]x2x3x4xf16>
  // No further unrolling may happen past the scalable dim.
  // CHECK-NOT:   vector.deinterleave
  // CHECK:       return %{{.*}}, %{{.*}} : vector<1x3x[2]x2x3x4xf16>, vector<1x3x[2]x2x3x4xf16>
  %0, %1 = vector.deinterleave %a : vector<1x3x[2]x2x3x8xf16> -> vector<1x3x[2]x2x3x4xf16>
  return %0, %1 : vector<1x3x[2]x2x3x4xf16>, vector<1x3x[2]x2x3x4xf16>
}
|
|
|
|
module attributes {transform.with_named_sequence} {
  // Entry point picked up by `--transform-interpreter` (see the RUN line):
  // matches every `func.func` in the payload and applies the
  // `vector.lower_interleave` pattern set to it. The tests in this file rely
  // on that set to lower n-D `vector.deinterleave`.
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %funcs = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %funcs {
      transform.apply_patterns.vector.lower_interleave
    } : !transform.any_op
    transform.yield
  }
}
|