From 07bf1ddb4eb0abfff20542fd4459bace1f72107f Mon Sep 17 00:00:00 2001 From: Peiming Liu <36770114+PeimingLiu@users.noreply.github.com> Date: Thu, 1 Feb 2024 17:11:33 -0600 Subject: [PATCH] [mlir][sparse] support non-id map for [Dis]assembleOp (#80355) --- .../SparseTensor/IR/SparseTensorDialect.cpp | 2 - .../Transforms/SparseReinterpretMap.cpp | 37 +++++++++++++- .../SparseTensor/sparse_reinterpret_map.mlir | 48 +++++++++++++++++++ 3 files changed, 84 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 6033ebf6897c..27125bc7ed45 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -1016,8 +1016,6 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, return op->emitError("the sparse-tensor must have static shape"); if (!stt.hasEncoding()) return op->emitError("the sparse-tensor must have an encoding attribute"); - if (!stt.isIdentity()) - return op->emitError("the sparse-tensor must have the identity mapping"); // Verifies the trailing COO. 
Level cooStartLvl = stt.getCOOStart(); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp index a0f7b55ce444..fbe2fc31ab8b 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp @@ -656,6 +656,40 @@ struct TensorInsertDemapper } }; +struct SparseAssembleDemapper : public OpRewritePattern<AssembleOp> { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AssembleOp op, + PatternRewriter &rewriter) const override { + if (!hasAnyNonIdentityOperandsOrResults(op)) + return failure(); + + assert(hasAnySparseResult(op)); + auto stt = getSparseTensorType(op.getResult()); + rewriter.modifyOpInPlace( + op, [&op, &stt]() { op.getResult().setType(stt.getDemappedType()); }); + rewriter.setInsertionPointAfter(op); + Value out = genRemap(rewriter, stt.getEncoding(), op.getResult()); + rewriter.replaceAllUsesExcept(op, out, out.getDefiningOp()); + return success(); + } +}; + +struct SparseDisassembleDemapper + : public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> { + using DemapInsRewriter::DemapInsRewriter; + LogicalResult rewriteOp(DisassembleOp op, OpAdaptor adaptor, + PatternRewriter &rewriter) const { + if (!hasAnyNonIdentityOperandsOrResults(op)) + return failure(); + + assert(hasAnySparseOperandOrResult(op)); + rewriter.modifyOpInPlace(op, [&op, &adaptor]() { + op.getTensorMutable().assign(adaptor.getTensor()); + }); + return success(); + } +}; + struct ForeachOpDemapper : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> { using DemapInsRewriter::DemapInsRewriter; @@ -758,7 +792,8 @@ void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns, if (scope == ReinterpretMapScope::kAll || scope == ReinterpretMapScope::kExceptGeneric) { patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>, - TensorAllocDemapper<tensor::EmptyOp>, TensorInsertDemapper, + TensorAllocDemapper<tensor::EmptyOp>, SparseAssembleDemapper, + SparseDisassembleDemapper, TensorInsertDemapper,
ForeachOpDemapper>(patterns.getContext()); } } diff --git a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir index 46f04cca03ed..54de1024323b 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir @@ -80,3 +80,51 @@ func.func @sparse_foreach_reinterpret_map(%6 : tensor<2x4xf64, #BSR>) -> tensor< %9 = sparse_tensor.load %8 hasInserts : tensor<2x4xf64, #BSR> return %9 : tensor<2x4xf64, #BSR> } + + +// ----- + +#BSR = #sparse_tensor.encoding<{ + map = ( i, j ) -> + ( i floordiv 2 : dense, + j floordiv 2 : compressed, + i mod 2 : dense, + j mod 2 : dense + ) +}> +// CHECK-DAG: #[[$remap:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 floordiv 2 : dense, d1 floordiv 2 : compressed, d0 mod 2 : dense, d1 mod 2 : dense) }> +// CHECK-DAG: #[[$demap:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : dense, d1 : compressed, d2 : dense, d3 : dense) }> +
+// CHECK-LABEL: func.func @sparse_assemble_reinterpret_map(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?xf64>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<?xindex>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<?xindex>) -> tensor<2x4xf64, #[[$remap]]> {
+// CHECK: %[[VAL_3:.*]] = sparse_tensor.assemble %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : tensor<?xf64>, tensor<?xindex>, tensor<?xindex> to tensor<1x2x2x2xf64, #[[$demap]]>
+// CHECK: %[[VAL_4:.*]] = sparse_tensor.reinterpret_map %[[VAL_3]] : tensor<1x2x2x2xf64, #[[$demap]]> to tensor<2x4xf64, #[[$remap]]>
+// CHECK: return %[[VAL_4]] : tensor<2x4xf64, #[[$remap]]>
+// CHECK: }
+func.func @sparse_assemble_reinterpret_map(%val : tensor<?xf64>, %pos:tensor<?xindex>, %crd:tensor<?xindex>) -> tensor<2x4xf64, #BSR> { + %0 = sparse_tensor.assemble %val, %pos, %crd + : tensor<?xf64>, tensor<?xindex>, tensor<?xindex> to tensor<2x4xf64, #BSR> + return %0 : tensor<2x4xf64, #BSR> +} + +// CHECK-LABEL: func.func @sparse_disassemble_reinterpret_map( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x4xf64, #[[$remap]]>, +// CHECK-SAME:
%[[VAL_1:.*]]: tensor<?xf64>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor<?xindex>, +// CHECK-SAME: %[[VAL_3:.*]]: tensor<?xindex>) -> (tensor<?xf64>, tensor<?xindex>, tensor<?xindex>) { +// CHECK: %[[VAL_4:.*]] = sparse_tensor.reinterpret_map %[[VAL_0]] : tensor<2x4xf64, #[[$remap]]> to tensor<1x2x2x2xf64, #[[$demap]]> +// CHECK: %[[VAL_5:.*]], %[[VAL_6:.*]]:2, %[[VAL_7:.*]], %[[VAL_8:.*]]:2 = sparse_tensor.disassemble %[[VAL_4]] : tensor<1x2x2x2xf64, #[[$demap]]> +// CHECK: return +// CHECK: } +func.func @sparse_disassemble_reinterpret_map(%sp : tensor<2x4xf64, #BSR>, + %od : tensor<?xf64>, + %op : tensor<?xindex>, + %oi : tensor<?xindex>) + -> (tensor<?xf64>, tensor<?xindex>, tensor<?xindex>) { + %rd, %rp, %ri, %dl, %pl, %il = sparse_tensor.disassemble %sp : tensor<2x4xf64, #BSR> + outs(%od, %op, %oi : tensor<?xf64>, tensor<?xindex>, tensor<?xindex>) + -> tensor<?xf64>, (tensor<?xindex>, tensor<?xindex>), index, (index, index) + return %rd, %rp, %ri : tensor<?xf64>, tensor<?xindex>, tensor<?xindex> +}