[mlir][sparse] Refactoring: remove dependence on tuple type when lowering sparse tensors.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D133390
This commit is contained in:
Peiming Liu
2022-09-07 00:49:44 +00:00
parent 300155911a
commit edca72f5bc
12 changed files with 317 additions and 876 deletions

View File

@@ -24,7 +24,6 @@ namespace mlir {
#define GEN_PASS_DEF_SPARSIFICATIONPASS
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
#define GEN_PASS_DEF_SPARSETENSORSTORAGEEXPANSION
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
} // namespace mlir
@@ -154,9 +153,8 @@ struct SparseTensorCodegenPass
RewritePatternSet patterns(ctx);
SparseTensorTypeToBufferConverter converter;
ConversionTarget target(*ctx);
// Almost everything in the sparse dialect must go!
// Everything in the sparse dialect must go!
target.addIllegalDialect<SparseTensorDialect>();
target.addLegalOp<StorageGetOp, StorageSetOp, StorageOp>();
// All dynamic rules below accept new function, call, return, and various
// tensor and bufferization operations as legal output of the rewriting
// provided that all sparse tensor types have been fully rewritten.
@@ -181,53 +179,13 @@ struct SparseTensorCodegenPass
target.addLegalDialect<arith::ArithmeticDialect,
bufferization::BufferizationDialect,
memref::MemRefDialect, scf::SCFDialect>();
// Populate with rules and apply rewriting rules.
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
converter);
populateCallOpTypeConversionPattern(patterns, converter);
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
target);
populateSparseTensorCodegenPatterns(converter, patterns);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
}
};
struct SparseTensorStorageExpansionPass
: public impl::SparseTensorStorageExpansionBase<
SparseTensorStorageExpansionPass> {
SparseTensorStorageExpansionPass() = default;
SparseTensorStorageExpansionPass(
const SparseTensorStorageExpansionPass &pass) = default;
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
SparseTensorStorageTupleExpander converter;
ConversionTarget target(*ctx);
// Now, everything in the sparse dialect must go!
target.addIllegalDialect<SparseTensorDialect>();
// All dynamic rules below accept new function, call, return.
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return converter.isSignatureLegal(op.getFunctionType());
});
target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
return converter.isSignatureLegal(op.getCalleeType());
});
target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
return converter.isLegal(op.getOperandTypes());
});
// We generate UnrealizedConversionCastOp to intermix tuples and a
// list of types.
target.addLegalOp<UnrealizedConversionCastOp>();
// Populate with rules and apply rewriting rules.
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
converter);
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
target);
populateSparseTensorStorageExpansionPatterns(converter, patterns);
populateSparseTensorCodegenPatterns(converter, patterns);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
@@ -277,7 +235,3 @@ std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
return std::make_unique<SparseTensorCodegenPass>();
}
std::unique_ptr<Pass> mlir::createSparseTensorStorageExpansionPass() {
return std::make_unique<SparseTensorStorageExpansionPass>();
}