Use the regular dialect conversion driver instead of the 1:N dialect conversion driver. The 1:N dialect conversion driver will be removed soon.
523 lines
20 KiB
C++
523 lines
20 KiB
C++
//===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
|
#include "mlir/Dialect/Complex/IR/Complex.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
|
|
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
|
|
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
|
|
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
namespace mlir {
|
|
#define GEN_PASS_DEF_SPARSEASSEMBLER
|
|
#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
|
|
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
|
|
#define GEN_PASS_DEF_SPARSIFICATIONPASS
|
|
#define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
|
|
#define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
|
|
#define GEN_PASS_DEF_LOWERFOREACHTOSCF
|
|
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
|
|
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
|
|
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
|
|
#define GEN_PASS_DEF_SPARSEVECTORIZATION
|
|
#define GEN_PASS_DEF_SPARSEGPUCODEGEN
|
|
#define GEN_PASS_DEF_STAGESPARSEOPERATIONS
|
|
#define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
|
|
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
|
|
} // namespace mlir
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::sparse_tensor;
|
|
|
|
namespace {
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Passes implementation.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Pass that applies the patterns registered by `populateSparseAssembler`
/// using the greedy pattern rewrite driver.
struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> {
  SparseAssembler() = default;
  SparseAssembler(const SparseAssembler &) = default;

  /// Stores the given flag in the inherited `directOut` member.
  SparseAssembler(bool useDirectOut) { directOut = useDirectOut; }

  /// Builds the pattern set and rewrites the root operation with it.
  void runOnOperation() override {
    RewritePatternSet rewrites(&getContext());
    populateSparseAssembler(rewrites, directOut);
    // The driver's result is discarded, so this pass never signals failure.
    (void)applyPatternsGreedily(getOperation(), std::move(rewrites));
  }
};
|
|
|
|
struct SparseReinterpretMap
|
|
: public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
|
|
SparseReinterpretMap() = default;
|
|
SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
|
|
SparseReinterpretMap(const SparseReinterpretMapOptions &options) {
|
|
scope = options.scope;
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
auto *ctx = &getContext();
|
|
RewritePatternSet patterns(ctx);
|
|
populateSparseReinterpretMap(patterns, scope);
|
|
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
struct PreSparsificationRewritePass
|
|
: public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {
|
|
PreSparsificationRewritePass() = default;
|
|
PreSparsificationRewritePass(const PreSparsificationRewritePass &pass) =
|
|
default;
|
|
|
|
void runOnOperation() override {
|
|
auto *ctx = &getContext();
|
|
RewritePatternSet patterns(ctx);
|
|
populatePreSparsificationRewriting(patterns);
|
|
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
struct SparsificationPass
|
|
: public impl::SparsificationPassBase<SparsificationPass> {
|
|
SparsificationPass() = default;
|
|
SparsificationPass(const SparsificationPass &pass) = default;
|
|
SparsificationPass(const SparsificationOptions &options) {
|
|
parallelization = options.parallelizationStrategy;
|
|
sparseEmitStrategy = options.sparseEmitStrategy;
|
|
enableRuntimeLibrary = options.enableRuntimeLibrary;
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
auto *ctx = &getContext();
|
|
// Translate strategy flags to strategy options.
|
|
SparsificationOptions options(parallelization, sparseEmitStrategy,
|
|
enableRuntimeLibrary);
|
|
// Apply sparsification and cleanup rewriting.
|
|
RewritePatternSet patterns(ctx);
|
|
populateSparsificationPatterns(patterns, options);
|
|
scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
|
|
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
struct StageSparseOperationsPass
|
|
: public impl::StageSparseOperationsBase<StageSparseOperationsPass> {
|
|
StageSparseOperationsPass() = default;
|
|
StageSparseOperationsPass(const StageSparseOperationsPass &pass) = default;
|
|
void runOnOperation() override {
|
|
auto *ctx = &getContext();
|
|
RewritePatternSet patterns(ctx);
|
|
populateStageSparseOperationsPatterns(patterns);
|
|
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
struct LowerSparseOpsToForeachPass
|
|
: public impl::LowerSparseOpsToForeachBase<LowerSparseOpsToForeachPass> {
|
|
LowerSparseOpsToForeachPass() = default;
|
|
LowerSparseOpsToForeachPass(const LowerSparseOpsToForeachPass &pass) =
|
|
default;
|
|
LowerSparseOpsToForeachPass(bool enableRT, bool convert) {
|
|
enableRuntimeLibrary = enableRT;
|
|
enableConvert = convert;
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
auto *ctx = &getContext();
|
|
RewritePatternSet patterns(ctx);
|
|
populateLowerSparseOpsToForeachPatterns(patterns, enableRuntimeLibrary,
|
|
enableConvert);
|
|
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
struct LowerForeachToSCFPass
|
|
: public impl::LowerForeachToSCFBase<LowerForeachToSCFPass> {
|
|
LowerForeachToSCFPass() = default;
|
|
LowerForeachToSCFPass(const LowerForeachToSCFPass &pass) = default;
|
|
|
|
void runOnOperation() override {
|
|
auto *ctx = &getContext();
|
|
RewritePatternSet patterns(ctx);
|
|
populateLowerForeachToSCFPatterns(patterns);
|
|
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
struct LowerSparseIterationToSCFPass
|
|
: public impl::LowerSparseIterationToSCFBase<
|
|
LowerSparseIterationToSCFPass> {
|
|
LowerSparseIterationToSCFPass() = default;
|
|
LowerSparseIterationToSCFPass(const LowerSparseIterationToSCFPass &) =
|
|
default;
|
|
|
|
void runOnOperation() override {
|
|
auto *ctx = &getContext();
|
|
RewritePatternSet patterns(ctx);
|
|
SparseIterationTypeConverter converter;
|
|
ConversionTarget target(*ctx);
|
|
|
|
// The actual conversion.
|
|
target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
|
|
memref::MemRefDialect, scf::SCFDialect,
|
|
sparse_tensor::SparseTensorDialect>();
|
|
target.addIllegalOp<CoIterateOp, ExtractIterSpaceOp, ExtractValOp,
|
|
IterateOp>();
|
|
target.addLegalOp<UnrealizedConversionCastOp>();
|
|
populateLowerSparseIterationToSCFPatterns(converter, patterns);
|
|
|
|
if (failed(applyPartialConversion(getOperation(), target,
|
|
std::move(patterns))))
|
|
signalPassFailure();
|
|
}
|
|
};
|
|
|
|
struct SparseTensorConversionPass
|
|
: public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
|
|
SparseTensorConversionPass() = default;
|
|
SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
|
|
|
|
void runOnOperation() override {
|
|
auto *ctx = &getContext();
|
|
RewritePatternSet patterns(ctx);
|
|
SparseTensorTypeToPtrConverter converter;
|
|
ConversionTarget target(*ctx);
|
|
// Everything in the sparse dialect must go!
|
|
target.addIllegalDialect<SparseTensorDialect>();
|
|
// All dynamic rules below accept new function, call, return, and various
|
|
// tensor and bufferization operations as legal output of the rewriting
|
|
// provided that all sparse tensor types have been fully rewritten.
|
|
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
|
|
return converter.isSignatureLegal(op.getFunctionType());
|
|
});
|
|
target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
|
|
return converter.isSignatureLegal(op.getCalleeType());
|
|
});
|
|
target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
|
|
return converter.isLegal(op.getOperandTypes());
|
|
});
|
|
target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
|
|
return converter.isLegal(op.getOperandTypes());
|
|
});
|
|
target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
|
|
return converter.isLegal(op.getSource().getType()) &&
|
|
converter.isLegal(op.getDest().getType());
|
|
});
|
|
target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
|
|
[&](tensor::ExpandShapeOp op) {
|
|
return converter.isLegal(op.getSrc().getType()) &&
|
|
converter.isLegal(op.getResult().getType());
|
|
});
|
|
target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
|
|
[&](tensor::CollapseShapeOp op) {
|
|
return converter.isLegal(op.getSrc().getType()) &&
|
|
converter.isLegal(op.getResult().getType());
|
|
});
|
|
target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
|
|
[&](bufferization::AllocTensorOp op) {
|
|
return converter.isLegal(op.getType());
|
|
});
|
|
target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
|
|
[&](bufferization::DeallocTensorOp op) {
|
|
return converter.isLegal(op.getTensor().getType());
|
|
});
|
|
// The following operations and dialects may be introduced by the
|
|
// rewriting rules, and are therefore marked as legal.
|
|
target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
|
|
linalg::YieldOp, tensor::ExtractOp,
|
|
tensor::FromElementsOp>();
|
|
target.addLegalDialect<
|
|
arith::ArithDialect, bufferization::BufferizationDialect,
|
|
LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
|
|
|
|
// Populate with rules and apply rewriting rules.
|
|
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
|
|
converter);
|
|
populateCallOpTypeConversionPattern(patterns, converter);
|
|
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
|
|
target);
|
|
populateSparseTensorConversionPatterns(converter, patterns);
|
|
if (failed(applyPartialConversion(getOperation(), target,
|
|
std::move(patterns))))
|
|
signalPassFailure();
|
|
}
|
|
};
|
|
|
|
struct SparseTensorCodegenPass
|
|
: public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {
|
|
SparseTensorCodegenPass() = default;
|
|
SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
|
|
SparseTensorCodegenPass(bool createDeallocs, bool enableInit) {
|
|
createSparseDeallocs = createDeallocs;
|
|
enableBufferInitialization = enableInit;
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
auto *ctx = &getContext();
|
|
RewritePatternSet patterns(ctx);
|
|
SparseTensorTypeToBufferConverter converter;
|
|
ConversionTarget target(*ctx);
|
|
// Most ops in the sparse dialect must go!
|
|
target.addIllegalDialect<SparseTensorDialect>();
|
|
target.addLegalOp<SortOp>();
|
|
target.addLegalOp<PushBackOp>();
|
|
// Storage specifier outlives sparse tensor pipeline.
|
|
target.addLegalOp<GetStorageSpecifierOp>();
|
|
target.addLegalOp<SetStorageSpecifierOp>();
|
|
target.addLegalOp<StorageSpecifierInitOp>();
|
|
// Note that tensor::FromElementsOp might be yield after lowering unpack.
|
|
target.addLegalOp<tensor::FromElementsOp>();
|
|
// All dynamic rules below accept new function, call, return, and
|
|
// various tensor and bufferization operations as legal output of the
|
|
// rewriting provided that all sparse tensor types have been fully
|
|
// rewritten.
|
|
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
|
|
return converter.isSignatureLegal(op.getFunctionType());
|
|
});
|
|
target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
|
|
return converter.isSignatureLegal(op.getCalleeType());
|
|
});
|
|
target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
|
|
return converter.isLegal(op.getOperandTypes());
|
|
});
|
|
target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
|
|
[&](bufferization::AllocTensorOp op) {
|
|
return converter.isLegal(op.getType());
|
|
});
|
|
target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
|
|
[&](bufferization::DeallocTensorOp op) {
|
|
return converter.isLegal(op.getTensor().getType());
|
|
});
|
|
// The following operations and dialects may be introduced by the
|
|
// codegen rules, and are therefore marked as legal.
|
|
target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
|
|
target.addLegalDialect<
|
|
arith::ArithDialect, bufferization::BufferizationDialect,
|
|
complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
|
|
target.addLegalOp<UnrealizedConversionCastOp>();
|
|
// Populate with rules and apply rewriting rules.
|
|
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
|
|
converter);
|
|
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
|
|
target);
|
|
populateSparseTensorCodegenPatterns(
|
|
converter, patterns, createSparseDeallocs, enableBufferInitialization);
|
|
if (failed(applyPartialConversion(getOperation(), target,
|
|
std::move(patterns))))
|
|
signalPassFailure();
|
|
}
|
|
};
|
|
|
|
struct SparseBufferRewritePass
|
|
: public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
|
|
SparseBufferRewritePass() = default;
|
|
SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
|
|
SparseBufferRewritePass(bool enableInit) {
|
|
enableBufferInitialization = enableInit;
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
auto *ctx = &getContext();
|
|
RewritePatternSet patterns(ctx);
|
|
populateSparseBufferRewriting(patterns, enableBufferInitialization);
|
|
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
struct SparseVectorizationPass
|
|
: public impl::SparseVectorizationBase<SparseVectorizationPass> {
|
|
SparseVectorizationPass() = default;
|
|
SparseVectorizationPass(const SparseVectorizationPass &pass) = default;
|
|
SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) {
|
|
vectorLength = vl;
|
|
enableVLAVectorization = vla;
|
|
enableSIMDIndex32 = sidx32;
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
if (vectorLength == 0)
|
|
return signalPassFailure();
|
|
auto *ctx = &getContext();
|
|
RewritePatternSet patterns(ctx);
|
|
populateSparseVectorizationPatterns(
|
|
patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
|
|
vector::populateVectorToVectorCanonicalizationPatterns(patterns);
|
|
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
struct SparseGPUCodegenPass
|
|
: public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> {
|
|
SparseGPUCodegenPass() = default;
|
|
SparseGPUCodegenPass(const SparseGPUCodegenPass &pass) = default;
|
|
SparseGPUCodegenPass(unsigned nT, bool enableRT) {
|
|
numThreads = nT;
|
|
enableRuntimeLibrary = enableRT;
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
auto *ctx = &getContext();
|
|
RewritePatternSet patterns(ctx);
|
|
if (numThreads == 0)
|
|
populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary);
|
|
else
|
|
populateSparseGPUCodegenPatterns(patterns, numThreads);
|
|
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
struct StorageSpecifierToLLVMPass
|
|
: public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
|
|
StorageSpecifierToLLVMPass() = default;
|
|
|
|
void runOnOperation() override {
|
|
auto *ctx = &getContext();
|
|
ConversionTarget target(*ctx);
|
|
RewritePatternSet patterns(ctx);
|
|
StorageSpecifierToLLVMTypeConverter converter;
|
|
|
|
// All ops in the sparse dialect must go!
|
|
target.addIllegalDialect<SparseTensorDialect>();
|
|
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
|
|
return converter.isSignatureLegal(op.getFunctionType());
|
|
});
|
|
target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
|
|
return converter.isSignatureLegal(op.getCalleeType());
|
|
});
|
|
target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
|
|
return converter.isLegal(op.getOperandTypes());
|
|
});
|
|
target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();
|
|
|
|
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
|
|
converter);
|
|
populateCallOpTypeConversionPattern(patterns, converter);
|
|
populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
|
|
populateReturnOpTypeConversionPattern(patterns, converter);
|
|
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
|
|
target);
|
|
populateStorageSpecifierToLLVMPatterns(converter, patterns);
|
|
if (failed(applyPartialConversion(getOperation(), target,
|
|
std::move(patterns))))
|
|
signalPassFailure();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Pass creation methods.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
std::unique_ptr<Pass> mlir::createSparseAssembler() {
|
|
return std::make_unique<SparseAssembler>();
|
|
}
|
|
|
|
std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
|
|
return std::make_unique<SparseReinterpretMap>();
|
|
}
|
|
|
|
std::unique_ptr<Pass>
|
|
mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) {
|
|
SparseReinterpretMapOptions options;
|
|
options.scope = scope;
|
|
return std::make_unique<SparseReinterpretMap>(options);
|
|
}
|
|
|
|
std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
|
|
return std::make_unique<PreSparsificationRewritePass>();
|
|
}
|
|
|
|
std::unique_ptr<Pass> mlir::createSparsificationPass() {
|
|
return std::make_unique<SparsificationPass>();
|
|
}
|
|
|
|
std::unique_ptr<Pass>
|
|
mlir::createSparsificationPass(const SparsificationOptions &options) {
|
|
return std::make_unique<SparsificationPass>(options);
|
|
}
|
|
|
|
std::unique_ptr<Pass> mlir::createStageSparseOperationsPass() {
|
|
return std::make_unique<StageSparseOperationsPass>();
|
|
}
|
|
|
|
std::unique_ptr<Pass> mlir::createLowerSparseOpsToForeachPass() {
|
|
return std::make_unique<LowerSparseOpsToForeachPass>();
|
|
}
|
|
|
|
std::unique_ptr<Pass>
|
|
mlir::createLowerSparseOpsToForeachPass(bool enableRT, bool enableConvert) {
|
|
return std::make_unique<LowerSparseOpsToForeachPass>(enableRT, enableConvert);
|
|
}
|
|
|
|
std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
|
|
return std::make_unique<LowerForeachToSCFPass>();
|
|
}
|
|
|
|
std::unique_ptr<Pass> mlir::createLowerSparseIterationToSCFPass() {
|
|
return std::make_unique<LowerSparseIterationToSCFPass>();
|
|
}
|
|
|
|
std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
|
|
return std::make_unique<SparseTensorConversionPass>();
|
|
}
|
|
|
|
std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
|
|
return std::make_unique<SparseTensorCodegenPass>();
|
|
}
|
|
|
|
std::unique_ptr<Pass>
|
|
mlir::createSparseTensorCodegenPass(bool createSparseDeallocs,
|
|
bool enableBufferInitialization) {
|
|
return std::make_unique<SparseTensorCodegenPass>(createSparseDeallocs,
|
|
enableBufferInitialization);
|
|
}
|
|
|
|
std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
|
|
return std::make_unique<SparseBufferRewritePass>();
|
|
}
|
|
|
|
std::unique_ptr<Pass>
|
|
mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
|
|
return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
|
|
}
|
|
|
|
std::unique_ptr<Pass> mlir::createSparseVectorizationPass() {
|
|
return std::make_unique<SparseVectorizationPass>();
|
|
}
|
|
|
|
std::unique_ptr<Pass>
|
|
mlir::createSparseVectorizationPass(unsigned vectorLength,
|
|
bool enableVLAVectorization,
|
|
bool enableSIMDIndex32) {
|
|
return std::make_unique<SparseVectorizationPass>(
|
|
vectorLength, enableVLAVectorization, enableSIMDIndex32);
|
|
}
|
|
|
|
std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass() {
|
|
return std::make_unique<SparseGPUCodegenPass>();
|
|
}
|
|
|
|
std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass(unsigned numThreads,
|
|
bool enableRT) {
|
|
return std::make_unique<SparseGPUCodegenPass>(numThreads, enableRT);
|
|
}
|
|
|
|
std::unique_ptr<Pass> mlir::createStorageSpecifierToLLVMPass() {
|
|
return std::make_unique<StorageSpecifierToLLVMPass>();
|
|
}
|