Files
clang-p2996/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
Oleksandr "Alex" Zinenko e4384149b5 [mlir] use transform-interpreter in test passes (#70040)
Update most test passes to use the transform-interpreter pass instead of
the test-transform-dialect-interpreter-pass. The new "main" interpreter
pass has a named entry point instead of looking up the top-level op with
`PossibleTopLevelOpTrait`, which is arguably a more understandable
interface. The change is mechanical, rewriting an unnamed sequence into
a named one and wrapping the transform IR in to a module when necessary.

Add an option to the transform-interpreter pass to target a tagged
payload op instead of the root anchor op, which is also useful for repro
generation.

Only the test in the transform dialect proper and the examples have not
been updated yet. These will be updated separately after a more careful
consideration of testing coverage of the transform interpreter logic.
2023-10-24 16:12:34 +02:00

163 lines
8.2 KiB
MLIR

// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
// CHECK-LABEL: func @nop_shape_cast
// CHECK-SAME: %[[A:.*]]: vector<16xf32>
// CHECK: return %[[A]] : vector<16xf32>
func.func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
%0 = vector.shape_cast %arg0 : vector<16xf32> to vector<16xf32>
return %0 : vector<16xf32>
}
// CHECK-LABEL: func @cancel_shape_cast
// CHECK-SAME: %[[A:.*]]: vector<16xf32>
// CHECK: return %[[A]] : vector<16xf32>
func.func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
%0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
%1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32>
return %1 : vector<16xf32>
}
// Shape up and downcasts for 2-D vectors, for supporting conversion to
// llvm.matrix operations
// CHECK-LABEL: func @shape_casts
func.func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) {
// CHECK-DAG: %[[cst22:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
// CHECK-DAG: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
// CHECK: %[[ex0:.*]] = vector.extract %{{.*}}[0] : vector<2xf32> from vector<2x2xf32>
//
// CHECK: %[[in0:.*]] = vector.insert_strided_slice %[[ex0]], %[[cst]]
// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
//
// CHECK: %[[ex1:.*]] = vector.extract %{{.*}}[1] : vector<2xf32> from vector<2x2xf32>
//
// CHECK: %[[in2:.*]] = vector.insert_strided_slice %[[ex1]], %[[in0]]
// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
//
%0 = vector.shape_cast %a : vector<2x2xf32> to vector<4xf32>
// CHECK: %[[add:.*]] = arith.addf %[[in2]], %[[in2]] : vector<4xf32>
%r0 = arith.addf %0, %0: vector<4xf32>
//
// CHECK: %[[ss0:.*]] = vector.extract_strided_slice %[[add]]
// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} :
// CHECK-SAME: vector<4xf32> to vector<2xf32>
//
// CHECK: %[[res0:.*]] = vector.insert %[[ss0]], %[[cst22]] [0] :
// CHECK-SAME: vector<2xf32> into vector<2x2xf32>
//
// CHECK: %[[s2:.*]] = vector.extract_strided_slice %[[add]]
// CHECK-SAME: {offsets = [2], sizes = [2], strides = [1]} :
// CHECK-SAME: vector<4xf32> to vector<2xf32>
//
// CHECK: %[[res1:.*]] = vector.insert %[[s2]], %[[res0]] [1] :
// CHECK-SAME: vector<2xf32> into vector<2x2xf32>
//
%1 = vector.shape_cast %r0 : vector<4xf32> to vector<2x2xf32>
// CHECK: return %[[add]], %[[res1]] : vector<4xf32>, vector<2x2xf32>
return %r0, %1 : vector<4xf32>, vector<2x2xf32>
}
// CHECK-LABEL: func @shape_cast_2d2d
// CHECK-SAME: %[[A:.*]]: vector<3x2xf32>
// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<3x2xf32>
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : f32 into vector<2x3xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : f32 from vector<3x2xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<2x3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<3x2xf32>
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 2] : f32 into vector<2x3xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : f32 from vector<3x2xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] : f32 into vector<2x3xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<3x2xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<2x3xf32>
// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : f32 from vector<3x2xf32>
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] : f32 into vector<2x3xf32>
// CHECK: return %[[T11]] : vector<2x3xf32>
func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
%s = vector.shape_cast %arg0: vector<3x2xf32> to vector<2x3xf32>
return %s : vector<2x3xf32>
}
// CHECK-LABEL: func @shape_cast_3d1d
// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<6xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0, 0] : f32 from vector<1x3x2xf32>
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0] : f32 into vector<6xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 0, 1] : f32 from vector<1x3x2xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : f32 into vector<6xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 1, 0] : f32 from vector<1x3x2xf32>
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : f32 into vector<6xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[A]][0, 1, 1] : f32 from vector<1x3x2xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [3] : f32 into vector<6xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2, 0] : f32 from vector<1x3x2xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [4] : f32 into vector<6xf32>
// CHECK: %[[T10:.*]] = vector.extract %[[A]][0, 2, 1] : f32 from vector<1x3x2xf32>
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [5] : f32 into vector<6xf32>
// CHECK: return %[[T11]] : vector<6xf32>
func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
%s = vector.shape_cast %arg0 : vector<1x3x2xf32> to vector<6xf32>
return %s : vector<6xf32>
}
// CHECK-LABEL: func @shape_cast_1d3d
// CHECK-SAME: %[[A:.*]]: vector<6xf32>
// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x1x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<6xf32>
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0, 0] : f32 into vector<2x1x3xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[A]][1] : f32 from vector<6xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 0, 1] : f32 into vector<2x1x3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][2] : f32 from vector<6xf32>
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 0, 2] : f32 into vector<2x1x3xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[A]][3] : f32 from vector<6xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0, 0] : f32 into vector<2x1x3xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[A]][4] : f32 from vector<6xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 0, 1] : f32 into vector<2x1x3xf32>
// CHECK: %[[T10:.*]] = vector.extract %[[A]][5] : f32 from vector<6xf32>
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 0, 2] : f32 into vector<2x1x3xf32>
// CHECK: return %[[T11]] : vector<2x1x3xf32>
func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
%s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32>
return %s : vector<2x1x3xf32>
}
// CHECK-LABEL: func.func @shape_cast_0d1d(
// CHECK-SAME: %[[VAL_0:.*]]: vector<f32>) -> vector<1xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
// CHECK: %[[VAL_2:.*]] = vector.extractelement %[[VAL_0]][] : vector<f32>
// CHECK: %[[VAL_3:.*]] = vector.insert %[[VAL_2]], %[[VAL_1]] [0] : f32 into vector<1xf32>
// CHECK: return %[[VAL_3]] : vector<1xf32>
// CHECK: }
func.func @shape_cast_0d1d(%arg0 : vector<f32>) -> vector<1xf32> {
%s = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
return %s : vector<1xf32>
}
// CHECK-LABEL: func.func @shape_cast_1d0d(
// CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>) -> vector<f32> {
// CHECK: %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<f32>
// CHECK: %[[VAL_2:.*]] = vector.extract %[[VAL_0]][0] : f32 from vector<1xf32>
// CHECK: %[[VAL_3:.*]] = vector.insertelement %[[VAL_2]], %[[VAL_1]][] : vector<f32>
// CHECK: return %[[VAL_3]] : vector<f32>
// CHECK: }
func.func @shape_cast_1d0d(%arg0 : vector<1xf32>) -> vector<f32> {
%s = vector.shape_cast %arg0 : vector<1xf32> to vector<f32>
return %s : vector<f32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %f {
transform.apply_patterns.vector.lower_shape_cast
} : !transform.any_op
transform.yield
}
}