[mlir][sparse] Refactoring: remove dependence on tuple type when lowering sparse tensors.
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D133390
This commit is contained in:
@@ -24,7 +24,6 @@ namespace mlir {
|
||||
#define GEN_PASS_DEF_SPARSIFICATIONPASS
|
||||
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
|
||||
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
|
||||
#define GEN_PASS_DEF_SPARSETENSORSTORAGEEXPANSION
|
||||
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
|
||||
} // namespace mlir
|
||||
|
||||
@@ -154,9 +153,8 @@ struct SparseTensorCodegenPass
|
||||
RewritePatternSet patterns(ctx);
|
||||
SparseTensorTypeToBufferConverter converter;
|
||||
ConversionTarget target(*ctx);
|
||||
// Almost everything in the sparse dialect must go!
|
||||
// Everything in the sparse dialect must go!
|
||||
target.addIllegalDialect<SparseTensorDialect>();
|
||||
target.addLegalOp<StorageGetOp, StorageSetOp, StorageOp>();
|
||||
// All dynamic rules below accept new function, call, return, and various
|
||||
// tensor and bufferization operations as legal output of the rewriting
|
||||
// provided that all sparse tensor types have been fully rewritten.
|
||||
@@ -181,53 +179,13 @@ struct SparseTensorCodegenPass
|
||||
target.addLegalDialect<arith::ArithmeticDialect,
|
||||
bufferization::BufferizationDialect,
|
||||
memref::MemRefDialect, scf::SCFDialect>();
|
||||
// Populate with rules and apply rewriting rules.
|
||||
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
|
||||
converter);
|
||||
populateCallOpTypeConversionPattern(patterns, converter);
|
||||
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
|
||||
target);
|
||||
populateSparseTensorCodegenPatterns(converter, patterns);
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
struct SparseTensorStorageExpansionPass
|
||||
: public impl::SparseTensorStorageExpansionBase<
|
||||
SparseTensorStorageExpansionPass> {
|
||||
|
||||
SparseTensorStorageExpansionPass() = default;
|
||||
SparseTensorStorageExpansionPass(
|
||||
const SparseTensorStorageExpansionPass &pass) = default;
|
||||
|
||||
void runOnOperation() override {
|
||||
auto *ctx = &getContext();
|
||||
RewritePatternSet patterns(ctx);
|
||||
SparseTensorStorageTupleExpander converter;
|
||||
ConversionTarget target(*ctx);
|
||||
// Now, everything in the sparse dialect must go!
|
||||
target.addIllegalDialect<SparseTensorDialect>();
|
||||
// All dynamic rules below accept new function, call, return.
|
||||
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
|
||||
return converter.isSignatureLegal(op.getFunctionType());
|
||||
});
|
||||
target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
|
||||
return converter.isSignatureLegal(op.getCalleeType());
|
||||
});
|
||||
target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
|
||||
return converter.isLegal(op.getOperandTypes());
|
||||
});
|
||||
// We generate UnrealizedConversionCastOp to intermix tuples and a
|
||||
// list of types.
|
||||
target.addLegalOp<UnrealizedConversionCastOp>();
|
||||
// Populate with rules and apply rewriting rules.
|
||||
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
|
||||
converter);
|
||||
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
|
||||
target);
|
||||
populateSparseTensorStorageExpansionPatterns(converter, patterns);
|
||||
populateSparseTensorCodegenPatterns(converter, patterns);
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
signalPassFailure();
|
||||
@@ -277,7 +235,3 @@ std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
|
||||
std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
|
||||
return std::make_unique<SparseTensorCodegenPass>();
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::createSparseTensorStorageExpansionPass() {
|
||||
return std::make_unique<SparseTensorStorageExpansionPass>();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user