[mlir][tensor] Relax input type requirement on tensor.splat (#145893)
`tensor.splat` is currently restricted to only accepting input values that are of integer, index or float type. This is much more restrictive than the tensor type itself as well as any lowerings of it. This PR therefore removes this restriction by using `AnyType` for the input value. Whether the type is actually valid or not for a tensor remains verified through the type equality of the result tensor element type and the input type.
This commit is contained in:
@@ -615,6 +615,21 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @tensor.splat_other(
|
||||
// CHECK-SAME: %[[F:.*]]: !test.memref_element)
|
||||
// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<10x2x4x!test.memref_element>
|
||||
// CHECK: %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
|
||||
// CHECK: %[[MAPPED:.*]] = linalg.map
|
||||
// CHECK: outs(%[[ALLOC_T]] : tensor<10x2x4x!test.memref_element>)
|
||||
// CHECK: linalg.yield %[[F]]
|
||||
// CHECK: return %[[MAPPED]] : tensor<10x2x4x!test.memref_element>
|
||||
func.func @tensor.splat_other(%f: !test.memref_element) -> tensor<10x2x4x!test.memref_element> {
|
||||
%t = tensor.splat %f : tensor<10x2x4x!test.memref_element>
|
||||
return %t : tensor<10x2x4x!test.memref_element>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @tensor.concat(
|
||||
// CHECK-SAME: %[[F:.*]]: tensor<8xf32>)
|
||||
// CHECK: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
|
||||
|
||||
@@ -466,9 +466,10 @@ func.func @invalid_splat(%v : f32) {
|
||||
|
||||
// -----
|
||||
|
||||
func.func @invalid_splat(%v : vector<8xf32>) {
|
||||
// expected-error@+1 {{must be integer/index/float type}}
|
||||
%w = tensor.splat %v : tensor<8xvector<8xf32>>
|
||||
// expected-note@+1 {{prior use here}}
|
||||
func.func @invalid_splat(%v : f32) {
|
||||
// expected-error@+1 {{expects different type than prior uses: 'i32' vs 'f32'}}
|
||||
%w = tensor.splat %v : tensor<1xi32>
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -313,13 +313,17 @@ func.func @pad_to_static_size(%arg0: tensor<?x?xf32>, %ub0: index, %ub1: index,
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @test_splat_op
|
||||
// CHECK-SAME: [[S:%arg[0-9]+]]: f32
|
||||
func.func @test_splat_op(%s : f32) {
|
||||
// CHECK: tensor.splat [[S]] : tensor<8xf32>
|
||||
// CHECK-SAME: %[[S:.*]]: f32
|
||||
// CHECK-SAME: %[[P:.*]]: !llvm.ptr
|
||||
func.func @test_splat_op(%s : f32, %p : !llvm.ptr) {
|
||||
// CHECK: tensor.splat %[[S]] : tensor<8xf32>
|
||||
%v = tensor.splat %s : tensor<8xf32>
|
||||
|
||||
// CHECK: tensor.splat [[S]] : tensor<4xf32>
|
||||
// CHECK: tensor.splat %[[S]] : tensor<4xf32>
|
||||
%u = "tensor.splat"(%s) : (f32) -> tensor<4xf32>
|
||||
|
||||
// CHECK: tensor.splat %[[P]] : tensor<8x!llvm.ptr>
|
||||
%w = tensor.splat %p : tensor<8x!llvm.ptr>
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user