As discussed in https://llvm.discourse.group/t/mlir-support-for-sparse-tensors/2020, this CL is the start of sparse tensor compiler support in MLIR. Starting with a "dense" kernel expressed in the Linalg dialect together with per-dimension sparsity annotations on the tensors, the compiler automatically lowers the kernel to sparse code using the methods described in Fredrik Kjolstad's thesis. Many details are still TBD. For example, the sparse "bufferization" is done purely locally, since we don't have a global solution for propagating sparsity yet. Furthermore, code to input and output the sparse tensors is missing. Nevertheless, with some hand modifications, the generated MLIR can already be converted easily into runnable code. Reviewed By: nicolasvasilache, ftynse. Differential Revision: https://reviews.llvm.org/D90994
43 lines
1.3 KiB
C++
43 lines
1.3 KiB
C++
//===- TestSparsification.cpp - Test sparsification of tensors ------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
|
|
struct TestSparsification
|
|
: public PassWrapper<TestSparsification, FunctionPass> {
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<scf::SCFDialect>();
|
|
}
|
|
void runOnFunction() override {
|
|
auto *ctx = &getContext();
|
|
OwningRewritePatternList patterns;
|
|
linalg::populateSparsificationPatterns(ctx, patterns);
|
|
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
} // end anonymous namespace
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
|
|
/// Registers the TestSparsification pass under the name
/// "test-sparsification" together with a short description.
void registerTestSparsification() {
  PassRegistration<TestSparsification> sparsificationPass(
      "test-sparsification",
      "Test automatic generation of sparse tensor code");
}
|
|
|
|
} // namespace test
|
|
} // namespace mlir
|