// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// RUN: mlir-opt %s --linalg-generalize-named-ops \
// RUN:   --sparsification --sparse-tensor-codegen \
// RUN:   --canonicalize --cse | FileCheck %s

#CSR = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "compressed" ],
  dimOrdering = affine_map<(i,j) -> (i,j)>
}>

//
// Computes C = A x B with all matrices sparse (SpMSpM) in CSR.
//
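// For intuition only (a hedged sketch, not checked by FileCheck): in this
// CSR layout each matrix is stored as a (pointers, indices, values) triple.
// For a hypothetical 4x4 matrix with nonzeros a, b on row 0 and c on row 2:
//
//   | a 0 0 b |        pointers = [ 0, 2, 2, 3, 3 ]
//   | 0 0 0 0 |   =>   indices  = [ 0, 3, 1 ]
//   | 0 c 0 0 |        values   = [ a, b, c ]
//   | 0 0 0 0 |
//
// These triples correspond to the memref<?xindex>/memref<?xf64> buffers
// that appear throughout the generated code below.
//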
// CHECK-LABEL: func.func private @_insert_D_C_4_4_f64_0_0(
// CHECK-SAME: %[[VAL_0:.*]]: memref<2xindex>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<3xindex>,
// CHECK-SAME: %[[VAL_2:[^ ]+]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_3:.*]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_4:.*]]: memref<?xf64>,
// CHECK-SAME: %[[VAL_5:[^ ]+]]: index,
// CHECK-SAME: %[[VAL_6:.*]]: index,
// CHECK-SAME: %[[VAL_7:.*]]: f64) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant false
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_5]], %[[VAL_9]] : index
// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_5]]] : memref<?xindex>
// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref<?xindex>
// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_9]]] : memref<3xindex>
// CHECK: %[[VAL_14:.*]] = arith.subi %[[VAL_12]], %[[VAL_9]] : index
// CHECK: %[[VAL_15:.*]] = arith.cmpi ult, %[[VAL_11]], %[[VAL_12]] : index
// CHECK: %[[VAL_16:.*]] = scf.if %[[VAL_15]] -> (i1) {
// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_14]]] : memref<?xindex>
// CHECK: %[[VAL_18:.*]] = arith.cmpi eq, %[[VAL_17]], %[[VAL_6]] : index
// CHECK: scf.yield %[[VAL_18]] : i1
// CHECK: } else {
// CHECK: memref.store %[[VAL_13]], %[[VAL_2]]{{\[}}%[[VAL_5]]] : memref<?xindex>
// CHECK: scf.yield %[[VAL_8]] : i1
// CHECK: }
// CHECK: %[[VAL_19:.*]] = scf.if %[[VAL_20:.*]] -> (memref<?xindex>) {
// CHECK: scf.yield %[[VAL_3]] : memref<?xindex>
// CHECK: } else {
// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_13]], %[[VAL_9]] : index
// CHECK: memref.store %[[VAL_21]], %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref<?xindex>
// CHECK: %[[VAL_22:.*]] = sparse_tensor.push_back %[[VAL_1]], %[[VAL_3]], %[[VAL_6]] {idx = 1 : index} : memref<3xindex>, memref<?xindex>, index
// CHECK: scf.yield %[[VAL_22]] : memref<?xindex>
// CHECK: }
// CHECK: %[[VAL_23:.*]] = sparse_tensor.push_back %[[VAL_1]], %[[VAL_4]], %[[VAL_7]] {idx = 2 : index} : memref<3xindex>, memref<?xf64>, f64
// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_24:.*]], %[[VAL_23]] : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
// CHECK: }

// CHECK-LABEL: func.func @matmul(
// CHECK-SAME: %[[VAL_0:.*0]]: memref<2xindex>,
// CHECK-SAME: %[[VAL_1:.*1]]: memref<3xindex>,
// CHECK-SAME: %[[VAL_2:.*2]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_3:.*3]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_4:.*4]]: memref<?xf64>,
// CHECK-SAME: %[[VAL_5:.*5]]: memref<2xindex>,
// CHECK-SAME: %[[VAL_6:.*6]]: memref<3xindex>,
// CHECK-SAME: %[[VAL_7:.*7]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_8:.*8]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_9:.*9]]: memref<?xf64>) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_14:.*]] = arith.constant false
// CHECK-DAG: %[[VAL_15:.*]] = arith.constant true
// CHECK: %[[VAL_16:.*]] = memref.alloc() : memref<2xindex>
// CHECK: %[[VAL_17:.*]] = memref.alloc() : memref<3xindex>
// CHECK: %[[VAL_18:.*]] = memref.alloc() : memref<16xindex>
// CHECK: %[[VAL_19:.*]] = memref.cast %[[VAL_18]] : memref<16xindex> to memref<?xindex>
// CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<16xindex>
// CHECK: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<16xindex> to memref<?xindex>
// CHECK: %[[VAL_22:.*]] = memref.alloc() : memref<16xf64>
// CHECK: %[[VAL_23:.*]] = memref.cast %[[VAL_22]] : memref<16xf64> to memref<?xf64>
// CHECK: linalg.fill ins(%[[VAL_12]] : index) outs(%[[VAL_17]] : memref<3xindex>)
// CHECK: memref.store %[[VAL_10]], %[[VAL_16]]{{\[}}%[[VAL_12]]] : memref<2xindex>
// CHECK: memref.store %[[VAL_10]], %[[VAL_16]]{{\[}}%[[VAL_13]]] : memref<2xindex>
// CHECK: %[[VAL_24:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_19]], %[[VAL_12]] {idx = 0 : index} : memref<3xindex>, memref<?xindex>, index
// CHECK: %[[VAL_25:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_24]], %[[VAL_12]], %[[VAL_10]] {idx = 0 : index} : memref<3xindex>, memref<?xindex>, index, index
// CHECK: %[[VAL_26:.*]] = memref.alloc() : memref<4xf64>
// CHECK: %[[VAL_27:.*]] = memref.alloc() : memref<4xi1>
// CHECK: %[[VAL_28:.*]] = memref.alloc() : memref<4xindex>
// CHECK: %[[VAL_29:.*]] = memref.cast %[[VAL_28]] : memref<4xindex> to memref<?xindex>
// CHECK: linalg.fill ins(%[[VAL_11]] : f64) outs(%[[VAL_26]] : memref<4xf64>)
// CHECK: linalg.fill ins(%[[VAL_14]] : i1) outs(%[[VAL_27]] : memref<4xi1>)
// CHECK: %[[VAL_30:.*]]:5 = scf.for %[[VAL_31:.*]] = %[[VAL_12]] to %[[VAL_10]] step %[[VAL_13]] iter_args(%[[VAL_32:.*]] = %[[VAL_16]], %[[VAL_33:.*]] = %[[VAL_17]], %[[VAL_34:.*]] = %[[VAL_25]], %[[VAL_35:.*]] = %[[VAL_21]], %[[VAL_36:.*]] = %[[VAL_23]]) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_31]]] : memref<?xindex>
// CHECK: %[[VAL_38:.*]] = arith.addi %[[VAL_31]], %[[VAL_13]] : index
// CHECK: %[[VAL_39:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_38]]] : memref<?xindex>
// CHECK: %[[VAL_40:.*]] = scf.for %[[VAL_41:.*]] = %[[VAL_37]] to %[[VAL_39]] step %[[VAL_13]] iter_args(%[[VAL_42:.*]] = %[[VAL_12]]) -> (index) {
// CHECK: %[[VAL_43:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_41]]] : memref<?xindex>
// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_41]]] : memref<?xf64>
// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_43]]] : memref<?xindex>
// CHECK: %[[VAL_46:.*]] = arith.addi %[[VAL_43]], %[[VAL_13]] : index
// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_46]]] : memref<?xindex>
// CHECK: %[[VAL_48:.*]] = scf.for %[[VAL_49:.*]] = %[[VAL_45]] to %[[VAL_47]] step %[[VAL_13]] iter_args(%[[VAL_50:.*]] = %[[VAL_42]]) -> (index) {
// CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_49]]] : memref<?xindex>
// CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_51]]] : memref<4xf64>
// CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_49]]] : memref<?xf64>
// CHECK: %[[VAL_54:.*]] = arith.mulf %[[VAL_44]], %[[VAL_53]] : f64
// CHECK: %[[VAL_55:.*]] = arith.addf %[[VAL_52]], %[[VAL_54]] : f64
// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_51]]] : memref<4xi1>
// CHECK: %[[VAL_57:.*]] = arith.cmpi eq, %[[VAL_56]], %[[VAL_14]] : i1
// CHECK: %[[VAL_58:.*]] = scf.if %[[VAL_57]] -> (index) {
// CHECK: memref.store %[[VAL_15]], %[[VAL_27]]{{\[}}%[[VAL_51]]] : memref<4xi1>
// CHECK: memref.store %[[VAL_51]], %[[VAL_28]]{{\[}}%[[VAL_50]]] : memref<4xindex>
// CHECK: %[[VAL_59:.*]] = arith.addi %[[VAL_50]], %[[VAL_13]] : index
// CHECK: scf.yield %[[VAL_59]] : index
// CHECK: } else {
// CHECK: scf.yield %[[VAL_50]] : index
// CHECK: }
// CHECK: memref.store %[[VAL_55]], %[[VAL_26]]{{\[}}%[[VAL_51]]] : memref<4xf64>
// CHECK: scf.yield %[[VAL_60:.*]] : index
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: scf.yield %[[VAL_61:.*]] : index
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: sparse_tensor.sort %[[VAL_62:.*]], %[[VAL_29]] : memref<?xindex>
// CHECK: %[[VAL_63:.*]]:5 = scf.for %[[VAL_64:.*]] = %[[VAL_12]] to %[[VAL_62]] step %[[VAL_13]] iter_args(%[[VAL_65:.*]] = %[[VAL_32]], %[[VAL_66:.*]] = %[[VAL_33]], %[[VAL_67:.*]] = %[[VAL_34]], %[[VAL_68:.*]] = %[[VAL_35]], %[[VAL_69:.*]] = %[[VAL_36]]) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>) {
// CHECK: %[[VAL_70:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_64]]] : memref<4xindex>
// CHECK: %[[VAL_71:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_70]]] : memref<4xf64>
// CHECK: %[[VAL_72:.*]]:5 = func.call @_insert_D_C_4_4_f64_0_0(%[[VAL_65]], %[[VAL_66]], %[[VAL_67]], %[[VAL_68]], %[[VAL_69]], %[[VAL_31]], %[[VAL_70]], %[[VAL_71]]) : (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>, index, index, f64) -> (memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>)
// CHECK: memref.store %[[VAL_11]], %[[VAL_26]]{{\[}}%[[VAL_70]]] : memref<4xf64>
// CHECK: memref.store %[[VAL_14]], %[[VAL_27]]{{\[}}%[[VAL_70]]] : memref<4xi1>
// CHECK: scf.yield %[[VAL_72]]#0, %[[VAL_72]]#1, %[[VAL_72]]#2, %[[VAL_72]]#3, %[[VAL_72]]#4 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
// CHECK: }
// CHECK: scf.yield %[[VAL_73:.*]]#0, %[[VAL_73]]#1, %[[VAL_73]]#2, %[[VAL_73]]#3, %[[VAL_73]]#4 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: memref.dealloc %[[VAL_26]] : memref<4xf64>
// CHECK: memref.dealloc %[[VAL_27]] : memref<4xi1>
// CHECK: memref.dealloc %[[VAL_28]] : memref<4xindex>
// CHECK: %[[VAL_74:.*]] = memref.load %[[VAL_75:.*]]#1{{\[}}%[[VAL_12]]] : memref<3xindex>
// CHECK: %[[VAL_76:.*]] = memref.load %[[VAL_75]]#2{{\[}}%[[VAL_12]]] : memref<?xindex>
// CHECK: %[[VAL_77:.*]] = scf.for %[[VAL_78:.*]] = %[[VAL_13]] to %[[VAL_74]] step %[[VAL_13]] iter_args(%[[VAL_79:.*]] = %[[VAL_76]]) -> (index) {
// CHECK: %[[VAL_80:.*]] = memref.load %[[VAL_75]]#2{{\[}}%[[VAL_78]]] : memref<?xindex>
// CHECK: %[[VAL_81:.*]] = arith.cmpi eq, %[[VAL_80]], %[[VAL_12]] : index
// CHECK: %[[VAL_82:.*]] = arith.select %[[VAL_81]], %[[VAL_79]], %[[VAL_80]] : index
// CHECK: scf.if %[[VAL_81]] {
// CHECK: memref.store %[[VAL_79]], %[[VAL_75]]#2{{\[}}%[[VAL_78]]] : memref<?xindex>
// CHECK: }
// CHECK: scf.yield %[[VAL_82]] : index
// CHECK: }
// CHECK: return %[[VAL_75]]#0, %[[VAL_75]]#1, %[[VAL_75]]#2, %[[VAL_75]]#3, %[[VAL_75]]#4 : memref<2xindex>, memref<3xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
// CHECK: }
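//
// The kernel below is the input the CHECK lines above were generated from:
// linalg.matmul is generalized to a linalg.generic and then sparsified into
// the loop nest above, which accumulates each output row in a dense
// one-row workspace (values memref<4xf64>, filled flags memref<4xi1>, and
// added-indices memref<4xindex>) before compressing it into the output CSR
// buffers via the synthesized @_insert_D_C_4_4_f64_0_0 helper.
//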
func.func @matmul(%A: tensor<4x8xf64, #CSR>,
                  %B: tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> {
  %C = bufferization.alloc_tensor() : tensor<4x4xf64, #CSR>
  %D = linalg.matmul
    ins(%A, %B: tensor<4x8xf64, #CSR>, tensor<8x4xf64, #CSR>)
    outs(%C: tensor<4x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
  return %D: tensor<4x4xf64, #CSR>
}