[mlir][sparse] simplify ConvertOp rewriting rules (#68350)
Canonicalize a complex ConvertOp into multiple stages, such that each stage can be lowered either by a direct conversion or by sorting.
This commit is contained in:
@@ -679,6 +679,50 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: use a new SortCOO operation here instead of reusing convert op.
|
||||
struct SparseSortCOOConverter : public OpConversionPattern<ConvertOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(ConvertOp op, ConvertOpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Direct conversion should have already been lowered.
|
||||
if (!op.isSortCOOConvert())
|
||||
return failure();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
MLIRContext *ctx = op.getContext();
|
||||
|
||||
SparseTensorType srcStt = getSparseTensorType(op.getSource());
|
||||
SparseTensorType dstStt = getSparseTensorType(op.getDest());
|
||||
|
||||
// TODO: This should be verification rules for sort_coo operation.
|
||||
assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
|
||||
isUniqueCOOType(srcStt.getRankedTensorType()) &&
|
||||
isUniqueCOOType(dstStt.getRankedTensorType()));
|
||||
|
||||
assert(dstStt.hasSameDimToLvl(srcStt));
|
||||
|
||||
// We don't need a mutable descriptor here as we perform sorting in-place.
|
||||
auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getSource());
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
|
||||
auto crd = desc.getAOSMemRef();
|
||||
auto val = desc.getValMemRef();
|
||||
|
||||
// Otherwise we need another data shuffle and a non-identity map.
|
||||
assert(dstStt.hasSameDimToLvl(srcStt));
|
||||
auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);
|
||||
|
||||
rewriter.create<SortOp>(loc, nnz, crd, ValueRange{val}, id,
|
||||
rewriter.getIndexAttr(0),
|
||||
SparseTensorSortKind::HybridQuickSort);
|
||||
|
||||
// Since we do in-place sorting, the destinate tensor will have the same set
|
||||
// of memrefs as the source tensor.
|
||||
rewriter.replaceOp(op, adaptor.getSource());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Op, StorageSpecifierKind kind>
|
||||
class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
|
||||
public:
|
||||
@@ -1101,6 +1145,9 @@ public:
|
||||
LogicalResult
|
||||
matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (op.isSortCOOConvert())
|
||||
return failure();
|
||||
|
||||
SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
|
||||
SparseTensorEncodingAttr encSrc =
|
||||
getSparseTensorEncoding(op.getSource().getType());
|
||||
@@ -1554,6 +1601,7 @@ void mlir::populateSparseTensorCodegenPatterns(
|
||||
SparseCastConverter, SparseExtractSliceConverter,
|
||||
SparseTensorLoadConverter, SparseExpandConverter,
|
||||
SparseCompressConverter, SparseInsertConverter,
|
||||
SparseSortCOOConverter,
|
||||
SparseSliceGetterOpConverter<ToSliceOffsetOp,
|
||||
StorageSpecifierKind::DimOffset>,
|
||||
SparseSliceGetterOpConverter<ToSliceStrideOp,
|
||||
|
||||
Reference in New Issue
Block a user