From 8602204d9fc483c7c58fa4e4d422d9bffb4e4e95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Markus=20B=C3=B6ck?= Date: Mon, 30 Jun 2025 09:49:19 +0200 Subject: [PATCH] [mlir][tensor] Relax input type requirement on `tensor.splat` (#145893) `tensor.splat` is currently restricted to only accepting input values that are of integer, index or float type. This is much more restrictive than the tensor type itself as well as any lowerings of it. This PR therefore removes this restriction by using `AnyType` for the input value. Whether the type is actually valid or not for a tensor remains verified through the type equality of the result tensor element type and the input type. --- mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td | 6 ++---- mlir/test/Dialect/Tensor/bufferize.mlir | 15 +++++++++++++++ mlir/test/Dialect/Tensor/invalid.mlir | 7 ++++--- mlir/test/Dialect/Tensor/ops.mlir | 12 ++++++++---- 4 files changed, 29 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index 47962f75558e..7d396e5c64c2 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -1771,8 +1771,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [ ]> { let summary = "tensor splat or broadcast operation"; let description = [{ - Broadcast the operand to all elements of the result tensor. The operand is - required to be of integer/index/float type. + Broadcast the operand to all elements of the result tensor. An additional argument of type `index` must be provided for each dynamic dimension present in the result type. @@ -1795,8 +1794,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [ ``` }]; - let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat], - "integer/index/float type">:$input, + let arguments = (ins AnyType:$input, Variadic:$dynamicSizes); let results = (outs AnyRankedTensor:$aggregate); diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir index c0adc8a49bf7..296ca02564e3 100644 --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -615,6 +615,21 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> { // ----- +// CHECK-LABEL: func @tensor.splat_other( +// CHECK-SAME: %[[F:.*]]: !test.memref_element) +// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<10x2x4x!test.memref_element> +// CHECK: %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]] +// CHECK: %[[MAPPED:.*]] = linalg.map +// CHECK: outs(%[[ALLOC_T]] : tensor<10x2x4x!test.memref_element>) +// CHECK: linalg.yield %[[F]] +// CHECK: return %[[MAPPED]] : tensor<10x2x4x!test.memref_element> +func.func @tensor.splat_other(%f: !test.memref_element) -> tensor<10x2x4x!test.memref_element> { + %t = tensor.splat %f : tensor<10x2x4x!test.memref_element> + return %t : tensor<10x2x4x!test.memref_element> +} + +// ----- + // CHECK-LABEL: func @tensor.concat( // CHECK-SAME: %[[F:.*]]: tensor<8xf32>) // CHECK: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]] diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir index f35d52e70008..665657a67dc6 100644 --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -466,9 +466,10 @@ func.func @invalid_splat(%v : f32) { // ----- -func.func @invalid_splat(%v : vector<8xf32>) { - // expected-error@+1 {{must be integer/index/float type}} - %w = tensor.splat %v : tensor<8xvector<8xf32>> +// expected-note@+1 {{prior use here}} +func.func @invalid_splat(%v : f32) { + // expected-error@+1 {{expects different type than prior uses: 'i32' vs 'f32'}} + %w = tensor.splat %v : tensor<1xi32> return } diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir index 930986211cb6..681a934ba069 100644 --- a/mlir/test/Dialect/Tensor/ops.mlir +++ b/mlir/test/Dialect/Tensor/ops.mlir @@ -313,13 +313,17 @@ func.func @pad_to_static_size(%arg0: tensor, %ub0: index, %ub1: index, // ----- // CHECK-LABEL: func @test_splat_op -// CHECK-SAME: [[S:%arg[0-9]+]]: f32 -func.func @test_splat_op(%s : f32) { - // CHECK: tensor.splat [[S]] : tensor<8xf32> +// CHECK-SAME: %[[S:.*]]: f32 +// CHECK-SAME: %[[P:.*]]: !llvm.ptr +func.func @test_splat_op(%s : f32, %p : !llvm.ptr) { + // CHECK: tensor.splat %[[S]] : tensor<8xf32> %v = tensor.splat %s : tensor<8xf32> - // CHECK: tensor.splat [[S]] : tensor<4xf32> + // CHECK: tensor.splat %[[S]] : tensor<4xf32> %u = "tensor.splat"(%s) : (f32) -> tensor<4xf32> + + // CHECK: tensor.splat %[[P]] : tensor<8x!llvm.ptr> + %w = tensor.splat %p : tensor<8x!llvm.ptr> return }