[mlir][sparse] unify sparse_tensor.out rewriting rules (#70518)

This commit is contained in:
Peiming Liu
2023-10-27 16:46:58 -07:00
committed by GitHub
parent 01828c4323
commit 7d608ee2bb
4 changed files with 23 additions and 80 deletions

View File

@@ -270,13 +270,6 @@ static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp,
.getResult(0);
}
/// Generates a call to release/delete a `SparseTensorCOO`.
static void genDelCOOCall(OpBuilder &builder, Location loc, Type elemTp,
Value coo) {
SmallString<21> name{"delSparseTensorCOO", primaryTypeFunctionSuffix(elemTp)};
createFuncCall(builder, loc, name, {}, coo, EmitCInterface::Off);
}
//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//
@@ -707,37 +700,6 @@ public:
}
};
/// Sparse conversion rule for the output operator.
class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(OutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const Location loc = op->getLoc();
const auto srcTp = getSparseTensorType(op.getTensor());
// Convert to default permuted COO.
Value src = adaptor.getOperands()[0];
SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
Value coo = NewCallParams(rewriter, loc)
.genBuffers(srcTp.withoutDimToLvl(), dimSizes)
.genNewCall(Action::kToCOO, src);
// Then output the tensor to external file with coordinates in the
// externally visible lexicographic coordinate order. A sort is
// required if the source was not in that order yet (note that the
// sort can be dropped altogether if external format does not care
// about the order at all, but here we assume it does).
const Value sort = constantI1(rewriter, loc, !srcTp.isIdentity());
SmallVector<Value, 3> outParams{coo, adaptor.getOperands()[1], sort};
const Type elemTp = srcTp.getElementType();
SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(elemTp)};
createFuncCall(rewriter, loc, name, {}, outParams, EmitCInterface::Off);
genDelCOOCall(rewriter, loc, elemTp, coo);
rewriter.eraseOp(op);
return success();
}
};
/// Sparse conversion rule for the sparse_tensor.pack operator.
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
public:
@@ -789,6 +751,5 @@ void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
SparseTensorLoadConverter, SparseTensorInsertConverter,
SparseTensorExpandConverter, SparseTensorCompressConverter,
SparseTensorOutConverter, SparseTensorAssembleConverter>(
typeConverter, patterns.getContext());
SparseTensorAssembleConverter>(typeConverter, patterns.getContext());
}