[mlir][sparse] Introducing options for the SparseTensorConversion pass

This is work towards: https://github.com/llvm/llvm-project/issues/51652

This differential sets up the options and threads them through everywhere, but doesn't actually use them yet.  The differential that finally makes use of them is D122061, which is the final differential in the chain that fixes bug 51652.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D122054
This commit is contained in:
wren romano
2022-03-18 19:10:40 -07:00
parent 110295ebb7
commit c7e24db412
6 changed files with 124 additions and 15 deletions

View File

@@ -453,7 +453,18 @@ class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
/// Options to control sparse code generation.
SparseTensorConversionOptions options;
public:
using OpConversionPattern::OpConversionPattern;
SparseTensorConvertConverter(MLIRContext *context,
SparseTensorConversionOptions o)
: OpConversionPattern<ConvertOp>(context), options(o) {}
SparseTensorConvertConverter(TypeConverter &typeConv, MLIRContext *context,
SparseTensorConversionOptions o)
: OpConversionPattern<ConvertOp>(typeConv, context), options(o) {}
LogicalResult
matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -825,14 +836,17 @@ public:
/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns) {
void mlir::populateSparseTensorConversionPatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns,
const SparseTensorConversionOptions &options) {
patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
SparseCastConverter, SparseTensorNewConverter,
SparseTensorInitConverter, SparseTensorConvertConverter,
SparseTensorReleaseConverter, SparseTensorToPointersConverter,
SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
SparseTensorLoadConverter, SparseTensorLexInsertConverter,
SparseTensorExpandConverter, SparseTensorCompressConverter,
SparseTensorOutConverter>(typeConverter, patterns.getContext());
SparseTensorInitConverter, SparseTensorReleaseConverter,
SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
SparseTensorToValuesConverter, SparseTensorLoadConverter,
SparseTensorLexInsertConverter, SparseTensorExpandConverter,
SparseTensorCompressConverter, SparseTensorOutConverter>(
typeConverter, patterns.getContext());
patterns.add<SparseTensorConvertConverter>(typeConverter,
patterns.getContext(), options);
}