Files
clang-p2996/mlir/test/Dialect/SparseTensor/sparse_storage.mlir
wren romano a0615d020a [mlir][sparse] Renaming the STEA field dimLevelType to lvlTypes
This commit is part of the migration of towards the new STEA syntax/design.  In particular, this commit includes the following changes:
* Renaming compiler-internal functions/methods:
  * `SparseTensorEncodingAttr::{getDimLevelType => getLvlTypes}`
  * `Merger::{getDimLevelType => getLvlType}` (for consistency)
  * `sparse_tensor::{getDimLevelType => buildLevelType}` (to help reduce confusion vs actual getter methods)
* Renaming external facets to match:
  * the STEA parser and printer
  * the C and Python bindings
  * PyTACO

However, the actual renaming of the `DimLevelType` itself (along with all the "dlt" names) will be handled in a separate commit.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D150330
2023-05-17 14:24:09 -07:00

79 lines
3.2 KiB
MLIR

// RUN: mlir-opt %s -sparsification= | FileCheck %s
#SparseVector64 = #sparse_tensor.encoding<{
lvlTypes = [ "compressed" ],
posWidth = 64,
crdWidth = 64
}>
#SparseVector32 = #sparse_tensor.encoding<{
lvlTypes = [ "compressed" ],
posWidth = 32,
crdWidth = 32
}>
#trait_mul = {
indexing_maps = [
affine_map<(i) -> (i)>, // a
affine_map<(i) -> (i)>, // b
affine_map<(i) -> (i)> // x (out)
],
iterator_types = ["parallel"],
doc = "x(i) = a(i) * b(i)"
}
// CHECK-LABEL: func @mul64(
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[P0:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<?xi64>
// CHECK: %[[B0:.*]] = arith.index_cast %[[P0]] : i64 to index
// CHECK: %[[P1:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<?xi64>
// CHECK: %[[B1:.*]] = arith.index_cast %[[P1]] : i64 to index
// CHECK: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
// CHECK: %[[IND0:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xi64>
// CHECK: %[[INDC:.*]] = arith.index_cast %[[IND0]] : i64 to index
// CHECK: %[[VAL0:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf64>
// CHECK: %[[VAL1:.*]] = memref.load %{{.*}}[%[[INDC]]] : memref<32xf64>
// CHECK: %[[MUL:.*]] = arith.mulf %[[VAL0]], %[[VAL1]] : f64
// CHECK: store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
// CHECK: }
func.func @mul64(%arga: tensor<32xf64, #SparseVector64>, %argb: tensor<32xf64>, %argx: tensor<32xf64>) -> tensor<32xf64> {
%0 = linalg.generic #trait_mul
ins(%arga, %argb: tensor<32xf64, #SparseVector64>, tensor<32xf64>)
outs(%argx: tensor<32xf64>) {
^bb(%a: f64, %b: f64, %x: f64):
%0 = arith.mulf %a, %b : f64
linalg.yield %0 : f64
} -> tensor<32xf64>
return %0 : tensor<32xf64>
}
// CHECK-LABEL: func @mul32(
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[P0:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<?xi32>
// CHECK: %[[Z0:.*]] = arith.extui %[[P0]] : i32 to i64
// CHECK: %[[B0:.*]] = arith.index_cast %[[Z0]] : i64 to index
// CHECK: %[[P1:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<?xi32>
// CHECK: %[[Z1:.*]] = arith.extui %[[P1]] : i32 to i64
// CHECK: %[[B1:.*]] = arith.index_cast %[[Z1]] : i64 to index
// CHECK: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
// CHECK: %[[IND0:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xi32>
// CHECK: %[[ZEXT:.*]] = arith.extui %[[IND0]] : i32 to i64
// CHECK: %[[INDC:.*]] = arith.index_cast %[[ZEXT]] : i64 to index
// CHECK: %[[VAL0:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf64>
// CHECK: %[[VAL1:.*]] = memref.load %{{.*}}[%[[INDC]]] : memref<32xf64>
// CHECK: %[[MUL:.*]] = arith.mulf %[[VAL0]], %[[VAL1]] : f64
// CHECK: store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
// CHECK: }
func.func @mul32(%arga: tensor<32xf64, #SparseVector32>, %argb: tensor<32xf64>, %argx: tensor<32xf64>) -> tensor<32xf64> {
%0 = linalg.generic #trait_mul
ins(%arga, %argb: tensor<32xf64, #SparseVector32>, tensor<32xf64>)
outs(%argx: tensor<32xf64>) {
^bb(%a: f64, %b: f64, %x: f64):
%0 = arith.mulf %a, %b : f64
linalg.yield %0 : f64
} -> tensor<32xf64>
return %0 : tensor<32xf64>
}