The default vector.contract lowering essentially yields a series of sdot/ddot
operations. However, for some layouts a series of saxpy/daxpy operations,
chained through fma, is more efficient. This CL introduces a choice between
the two lowering paths; a default heuristic to pick between them is to follow.
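As a rough scalar-level illustration of the two strategies (this is not the
generated IR; the row-major layout and function names are illustrative
assumptions), the dot path reduces along each row, while the axpy path scales
one column at a time and accumulates with multiply-adds:

```c
#include <stddef.h>

// Dot-style matvec: y[i] = sum_j A[i][j] * x[j].
// Each output element needs a horizontal reduction over a row of A.
static void matvec_dot(size_t n, const float *A, const float *x, float *y) {
  for (size_t i = 0; i < n; ++i) {
    float acc = y[i];
    for (size_t j = 0; j < n; ++j)
      acc += A[i * n + j] * x[j];
    y[i] = acc;
  }
}

// Axpy-style matvec: y += A[:, j] * x[j] for each j.
// The inner loop is a chain of multiply-adds with no reduction.
static void matvec_axpy(size_t n, const float *A, const float *x, float *y) {
  for (size_t j = 0; j < n; ++j)
    for (size_t i = 0; i < n; ++i)
      y[i] += A[i * n + j] * x[j];
}
```

In the vectorized lowering the axpy form becomes a chain of vector.fma ops fed
by splatted scalars, as checked in the test below, whereas the dot form needs a
horizontal reduction per output element.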
Some preliminary AVX2 performance numbers for matrix-times-vector are shown
below. Here, dot performs best for 64x64 A x b and saxpy for 64x64 A^T x b;
the saxpy path avoids the per-row horizontal reductions and, for the transposed
layout, can use the stored rows directly without an in-register transpose.
```
------------------------------------------------------------
            A x b                     A^T x b
------------------------------------------------------------
GFLOPS    sdot (reassoc)  saxpy     sdot (reassoc)  saxpy
------------------------------------------------------------
1x1          0.6            0.9        0.6            0.9
2x2          2.5            3.2        2.4            3.5
4x4          6.4            8.4        4.9           11.8
8x8         11.7            6.1        5.0           29.6
16x16       20.7           10.8        7.3           43.3
32x32       29.3            7.9        6.4           51.8
64x64       38.9                                     79.3
128x128     32.4                                     40.7
------------------------------------------------------------
```
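For scale, assuming the usual 2*m*n flop count for a matrix-vector product, a
64x64 matvec is 2 * 64 * 64 = 8192 flops, so 38.9 GFLOPS corresponds to roughly
0.21 us per A x b and 79.3 GFLOPS to roughly 0.10 us per A^T x b.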
Reviewed By: nicolasvasilache, ftynse
Differential Revision: https://reviews.llvm.org/D83012
// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-axpy=1 | FileCheck %s

// matvec: C(i) += A(i, j) * x(j)  (matrix operand first).
#matvec_accesses = [
  affine_map<(i, j) -> (i, j)>,
  affine_map<(i, j) -> (j)>,
  affine_map<(i, j) -> (i)>
]
#matvec_trait = {
  indexing_maps = #matvec_accesses,
  iterator_types = ["parallel", "reduction"]
}

// mattransvec: C(i) += A(j, i) * x(j), i.e. A^T x.
#mattransvec_accesses = [
  affine_map<(i, j) -> (j, i)>,
  affine_map<(i, j) -> (j)>,
  affine_map<(i, j) -> (i)>
]
#mattransvec_trait = {
  indexing_maps = #mattransvec_accesses,
  iterator_types = ["parallel", "reduction"]
}

// vecmat: C(i) += x(j) * A(i, j)  (vector operand first).
#vecmat_accesses = [
  affine_map<(i, j) -> (j)>,
  affine_map<(i, j) -> (i, j)>,
  affine_map<(i, j) -> (i)>
]
#vecmat_trait = {
  indexing_maps = #vecmat_accesses,
  iterator_types = ["parallel", "reduction"]
}

// vecmattrans: C(i) += x(j) * A(j, i), i.e. A^T x  (vector operand first).
#vecmattrans_accesses = [
  affine_map<(i, j) -> (j)>,
  affine_map<(i, j) -> (j, i)>,
  affine_map<(i, j) -> (i)>
]
#vecmattrans_trait = {
  indexing_maps = #vecmattrans_accesses,
  iterator_types = ["parallel", "reduction"]
}

// CHECK-LABEL: func @matvec2x2
// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<2x2xf32>
// CHECK: %[[T0:.*]] = load %[[A]][] : memref<vector<2x2xf32>>
// CHECK: %[[T1:.*]] = load %[[B]][] : memref<vector<2xf32>>
// CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0, 0] : vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0, 0] : f32 into vector<2x2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[T0]][1, 0] : vector<2x2xf32>
// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T4]] [0, 1] : f32 into vector<2x2xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[T0]][0, 1] : vector<2x2xf32>
// CHECK: %[[T8:.*]] = vector.insert %[[T7]], %[[T6]] [1, 0] : f32 into vector<2x2xf32>
// CHECK: %[[T9:.*]] = vector.extract %[[T0]][1, 1] : vector<2x2xf32>
// CHECK: %[[T10:.*]] = vector.insert %[[T9]], %[[T8]] [1, 1] : f32 into vector<2x2xf32>
// CHECK: %[[T11:.*]] = vector.extract %[[T10]][0] : vector<2x2xf32>
// CHECK: %[[T12:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
// CHECK: %[[T13:.*]] = splat %[[T12]] : vector<2xf32>
// CHECK: %[[T14:.*]] = vector.fma %[[T11]], %[[T13]], %[[T2]] : vector<2xf32>
// CHECK: %[[T15:.*]] = vector.extract %[[T10]][1] : vector<2x2xf32>
// CHECK: %[[T16:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
// CHECK: %[[T17:.*]] = splat %[[T16]] : vector<2xf32>
// CHECK: %[[T18:.*]] = vector.fma %[[T15]], %[[T17]], %[[T14]] : vector<2xf32>
// CHECK: store %[[T18]], %[[C]][] : memref<vector<2xf32>>
func @matvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
                %arg2: memref<vector<2xf32>>) {
  %A = load %arg0[] : memref<vector<2x2xf32>>
  %x = load %arg1[] : memref<vector<2xf32>>
  %b = load %arg2[] : memref<vector<2xf32>>
  %0 = vector.contract #matvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  store %0, %arg2[] : memref<vector<2xf32>>
  return
}

// CHECK-LABEL: func @mattransvec2x2
// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
// CHECK: %[[T0:.*]] = load %[[A]][] : memref<vector<2x2xf32>>
// CHECK: %[[T1:.*]] = load %[[B]][] : memref<vector<2xf32>>
// CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<2xf32>
// CHECK: %[[T6:.*]] = vector.fma %[[T3]], %[[T5]], %[[T2]] : vector<2xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
// CHECK: %[[T9:.*]] = splat %[[T8]] : vector<2xf32>
// CHECK: %[[T10:.*]] = vector.fma %[[T7]], %[[T9]], %[[T6]] : vector<2xf32>
// CHECK: store %[[T10]], %[[C]][] : memref<vector<2xf32>>
func @mattransvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
                     %arg2: memref<vector<2xf32>>) {
  %A = load %arg0[] : memref<vector<2x2xf32>>
  %x = load %arg1[] : memref<vector<2xf32>>
  %b = load %arg2[] : memref<vector<2xf32>>
  %0 = vector.contract #mattransvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  store %0, %arg2[] : memref<vector<2xf32>>
  return
}

// CHECK-LABEL: func @vecmat2x2
// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<2x2xf32>
// CHECK: %[[T0:.*]] = load %[[A]][] : memref<vector<2x2xf32>>
// CHECK: %[[T1:.*]] = load %[[B]][] : memref<vector<2xf32>>
// CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0, 0] : vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0, 0] : f32 into vector<2x2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[T0]][1, 0] : vector<2x2xf32>
// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T4]] [0, 1] : f32 into vector<2x2xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[T0]][0, 1] : vector<2x2xf32>
// CHECK: %[[T8:.*]] = vector.insert %[[T7]], %[[T6]] [1, 0] : f32 into vector<2x2xf32>
// CHECK: %[[T9:.*]] = vector.extract %[[T0]][1, 1] : vector<2x2xf32>
// CHECK: %[[T10:.*]] = vector.insert %[[T9]], %[[T8]] [1, 1] : f32 into vector<2x2xf32>
// CHECK: %[[T11:.*]] = vector.extract %[[T10]][0] : vector<2x2xf32>
// CHECK: %[[T12:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
// CHECK: %[[T13:.*]] = splat %[[T12]] : vector<2xf32>
// CHECK: %[[T14:.*]] = vector.fma %[[T11]], %[[T13]], %[[T2]] : vector<2xf32>
// CHECK: %[[T15:.*]] = vector.extract %[[T10]][1] : vector<2x2xf32>
// CHECK: %[[T16:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
// CHECK: %[[T17:.*]] = splat %[[T16]] : vector<2xf32>
// CHECK: %[[T18:.*]] = vector.fma %[[T15]], %[[T17]], %[[T14]] : vector<2xf32>
// CHECK: store %[[T18]], %[[C]][] : memref<vector<2xf32>>
func @vecmat2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
                %arg2: memref<vector<2xf32>>) {
  %A = load %arg0[] : memref<vector<2x2xf32>>
  %x = load %arg1[] : memref<vector<2xf32>>
  %b = load %arg2[] : memref<vector<2xf32>>
  %0 = vector.contract #vecmat_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
  store %0, %arg2[] : memref<vector<2xf32>>
  return
}

// CHECK-LABEL: func @vecmattrans2x2
// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
// CHECK: %[[T0:.*]] = load %[[A]][] : memref<vector<2x2xf32>>
// CHECK: %[[T1:.*]] = load %[[B]][] : memref<vector<2xf32>>
// CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<2xf32>
// CHECK: %[[T6:.*]] = vector.fma %[[T3]], %[[T5]], %[[T2]] : vector<2xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
// CHECK: %[[T9:.*]] = splat %[[T8]] : vector<2xf32>
// CHECK: %[[T10:.*]] = vector.fma %[[T7]], %[[T9]], %[[T6]] : vector<2xf32>
// CHECK: store %[[T10]], %[[C]][] : memref<vector<2xf32>>
func @vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
                     %arg2: memref<vector<2xf32>>) {
  %A = load %arg0[] : memref<vector<2x2xf32>>
  %x = load %arg1[] : memref<vector<2xf32>>
  %b = load %arg2[] : memref<vector<2xf32>>
  %0 = vector.contract #vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
  store %0, %arg2[] : memref<vector<2xf32>>
  return
}