[mlir][tensor] Relax input type requirement on tensor.splat (#145893)

`tensor.splat` is currently restricted to only accepting input values
that are of integer, index or float type.

This is much more restrictive than the tensor type itself as well as any
lowerings of it.

This PR therefore removes this restriction by using `AnyType` for the
input value. Whether the type is actually valid or not for a tensor
remains verified through the type equality of the result tensor element
type and the input type.
This commit is contained in:
Markus Böck
2025-06-30 09:49:19 +02:00
committed by GitHub
parent 57f7e14b57
commit 8602204d9f
4 changed files with 29 additions and 11 deletions

View File

@@ -615,6 +615,21 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
// -----
// CHECK-LABEL: func @tensor.splat_other(
// CHECK-SAME: %[[F:.*]]: !test.memref_element)
// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<10x2x4x!test.memref_element>
// CHECK: %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK: %[[MAPPED:.*]] = linalg.map
// CHECK: outs(%[[ALLOC_T]] : tensor<10x2x4x!test.memref_element>)
// CHECK: linalg.yield %[[F]]
// CHECK: return %[[MAPPED]] : tensor<10x2x4x!test.memref_element>
func.func @tensor.splat_other(%f: !test.memref_element) -> tensor<10x2x4x!test.memref_element> {
%t = tensor.splat %f : tensor<10x2x4x!test.memref_element>
return %t : tensor<10x2x4x!test.memref_element>
}
// -----
// CHECK-LABEL: func @tensor.concat(
// CHECK-SAME: %[[F:.*]]: tensor<8xf32>)
// CHECK: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]

View File

@@ -466,9 +466,10 @@ func.func @invalid_splat(%v : f32) {
// -----
func.func @invalid_splat(%v : vector<8xf32>) {
// expected-error@+1 {{must be integer/index/float type}}
%w = tensor.splat %v : tensor<8xvector<8xf32>>
// expected-note@+1 {{prior use here}}
func.func @invalid_splat(%v : f32) {
// expected-error@+1 {{expects different type than prior uses: 'i32' vs 'f32'}}
%w = tensor.splat %v : tensor<1xi32>
return
}

View File

@@ -313,13 +313,17 @@ func.func @pad_to_static_size(%arg0: tensor<?x?xf32>, %ub0: index, %ub1: index,
// -----
// CHECK-LABEL: func @test_splat_op
// CHECK-SAME: [[S:%arg[0-9]+]]: f32
func.func @test_splat_op(%s : f32) {
// CHECK: tensor.splat [[S]] : tensor<8xf32>
// CHECK-SAME: %[[S:.*]]: f32
// CHECK-SAME: %[[P:.*]]: !llvm.ptr
func.func @test_splat_op(%s : f32, %p : !llvm.ptr) {
// CHECK: tensor.splat %[[S]] : tensor<8xf32>
%v = tensor.splat %s : tensor<8xf32>
// CHECK: tensor.splat [[S]] : tensor<4xf32>
// CHECK: tensor.splat %[[S]] : tensor<4xf32>
%u = "tensor.splat"(%s) : (f32) -> tensor<4xf32>
// CHECK: tensor.splat %[[P]] : tensor<8x!llvm.ptr>
%w = tensor.splat %p : tensor<8x!llvm.ptr>
return
}