diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 02e62930a742..d58ee84bee63 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2920,8 +2920,8 @@ def Vector_SplatOp : Vector_Op<"splat", [ ]> { let summary = "vector splat or broadcast operation"; let description = [{ - Broadcast the operand to all elements of the result vector. The operand is - required to be of integer/index/float type. + Broadcast the operand to all elements of the result vector. The type of the + operand must match the element type of the vector type. Example: @@ -2931,8 +2931,7 @@ def Vector_SplatOp : Vector_Op<"splat", [ ``` }]; - let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat], - "integer/index/float type">:$input); + let arguments = (ins AnyType:$input); let results = (outs AnyVectorOfAnyRank:$aggregate); let builders = [ diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index ec7cee7b2c64..4935ec8ba8e6 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1975,6 +1975,15 @@ func.func @flat_transpose_scalable(%arg0: vector<[16]xf32>) -> vector<[16]xf32> // ----- +// expected-note @+1 {{prior use here}} +func.func @vector_splat_type_mismatch(%a: f32) { + // expected-error @+1 {{expects different type than prior uses: 'i32' vs 'f32'}} + %0 = vector.splat %a : vector<1xi32> + return +} + +// ----- + //===----------------------------------------------------------------------===// // vector.load //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index c59f7bd00190..0121bcdbbba4 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -149,7 +149,7 @@ func.func @vector_transfer_ops_tensor(%arg0: tensor, } // CHECK-LABEL: @vector_broadcast -func.func @vector_broadcast(%a: f32, %b: vector, %c: vector<16xf32>, %d: vector<1x16xf32>, %e: vector<8x1xf32>) -> vector<8x16xf32> { +func.func @vector_broadcast(%a: f32, %b: vector, %c: vector<16xf32>, %d: vector<1x16xf32>, %e: vector<8x1xf32>, %f: vector<8x1x!llvm.ptr<1>>) { // CHECK: vector.broadcast %{{.*}} : f32 to vector %0 = vector.broadcast %a : f32 to vector // CHECK: vector.broadcast %{{.*}} : vector to vector<4xf32> @@ -162,7 +162,9 @@ func.func @vector_broadcast(%a: f32, %b: vector, %c: vector<16xf32>, %d: ve %4 = vector.broadcast %d : vector<1x16xf32> to vector<8x16xf32> // CHECK-NEXT: vector.broadcast %{{.*}} : vector<8x1xf32> to vector<8x16xf32> %5 = vector.broadcast %e : vector<8x1xf32> to vector<8x16xf32> - return %4 : vector<8x16xf32> + // CHECK-NEXT: vector.broadcast %{{.*}} : vector<8x1x!llvm.ptr<1>> to vector<8x16x!llvm.ptr<1>> + %6 = vector.broadcast %f : vector<8x1x!llvm.ptr<1>> to vector<8x16x!llvm.ptr<1>> + return } // CHECK-LABEL: @shuffle0D @@ -959,13 +961,16 @@ func.func @vector_scan(%0: vector<4x8x16x32xf32>) -> vector<4x8x16x32xf32> { } // CHECK-LABEL: func @test_splat_op -// CHECK-SAME: [[S:%arg[0-9]+]]: f32 -func.func @test_splat_op(%s : f32) { - // CHECK: vector.splat [[S]] : vector<8xf32> +// CHECK-SAME: %[[s:.*]]: f32, %[[s2:.*]]: !llvm.ptr<1> +func.func @test_splat_op(%s : f32, %s2 : !llvm.ptr<1>) { + // CHECK: vector.splat %[[s]] : vector<8xf32> %v = vector.splat %s : vector<8xf32> - // CHECK: vector.splat [[S]] : vector<4xf32> + // CHECK: vector.splat %[[s]] : vector<4xf32> %u = "vector.splat"(%s) : (f32) -> vector<4xf32> + + // CHECK: vector.splat %[[s2]] : vector<16x!llvm.ptr<1>> + %w = vector.splat %s2 : vector<16x!llvm.ptr<1>> return }