[mlir][sparse] simplify ConvertOp rewriting rules (#68350)

Canonicalize complex convertOp into multiple stages, such that it can
either be done by a direct conversion or by sorting.
This commit is contained in:
Peiming Liu
2023-10-11 09:34:11 -07:00
committed by GitHub
parent 12b87f6ef7
commit dda3dc5e38
7 changed files with 277 additions and 255 deletions

View File

@@ -679,6 +679,50 @@ public:
}
};
// TODO: use a new SortCOO operation here instead of reusing convert op.
struct SparseSortCOOConverter : public OpConversionPattern<ConvertOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(ConvertOp op, ConvertOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Direct conversion should have already been lowered.
if (!op.isSortCOOConvert())
return failure();
Location loc = op.getLoc();
MLIRContext *ctx = op.getContext();
SparseTensorType srcStt = getSparseTensorType(op.getSource());
SparseTensorType dstStt = getSparseTensorType(op.getDest());
// TODO: This should be verification rules for sort_coo operation.
assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
isUniqueCOOType(srcStt.getRankedTensorType()) &&
isUniqueCOOType(dstStt.getRankedTensorType()));
assert(dstStt.hasSameDimToLvl(srcStt));
// We don't need a mutable descriptor here as we perform sorting in-place.
auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getSource());
auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
auto crd = desc.getAOSMemRef();
auto val = desc.getValMemRef();
// Otherwise we need another data shuffle and a non-identity map.
assert(dstStt.hasSameDimToLvl(srcStt));
auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);
rewriter.create<SortOp>(loc, nnz, crd, ValueRange{val}, id,
rewriter.getIndexAttr(0),
SparseTensorSortKind::HybridQuickSort);
// Since we do in-place sorting, the destinate tensor will have the same set
// of memrefs as the source tensor.
rewriter.replaceOp(op, adaptor.getSource());
return success();
}
};
template <typename Op, StorageSpecifierKind kind>
class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
public:
@@ -1101,6 +1145,9 @@ public:
LogicalResult
matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (op.isSortCOOConvert())
return failure();
SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
SparseTensorEncodingAttr encSrc =
getSparseTensorEncoding(op.getSource().getType());
@@ -1554,6 +1601,7 @@ void mlir::populateSparseTensorCodegenPatterns(
SparseCastConverter, SparseExtractSliceConverter,
SparseTensorLoadConverter, SparseExpandConverter,
SparseCompressConverter, SparseInsertConverter,
SparseSortCOOConverter,
SparseSliceGetterOpConverter<ToSliceOffsetOp,
StorageSpecifierKind::DimOffset>,
SparseSliceGetterOpConverter<ToSliceStrideOp,