Files
clang-p2996/mlir/test/lib/Transforms/TestSparsification.cpp
Aart Bik 0b1764a3d7 [mlir][sparse] sparse tensor storage implementation
This revision connects the generated sparse code with an actual
sparse storage scheme, which can be initialized from a test file.
Lacking a first-class citizen SparseTensor type (with buffer),
the storage is hidden behind an opaque pointer with some "glue"
to bring the pointer back to tensor land. Rather than generating
sparse setup code for each different annotated tensor (viz. the
"pack" methods in TACO), a single "one-size-fits-all" implementation
has been added to the runtime support library.  Many details and
abstractions need to be refined in the future, but this revision
allows full end-to-end integration testing and performance
benchmarking (with, on one end, an annotated Linalg
op and, on the other end, a JIT/AOT executable).

Reviewed By: nicolasvasilache, bixia

Differential Revision: https://reviews.llvm.org/D95847
2021-02-10 11:57:24 -08:00

143 lines
5.0 KiB
C++

//===- TestSparsification.cpp - Test sparsification of tensors ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
namespace {
/// Test pass that applies the Linalg sparsification rewrite patterns to each
/// function and, when `lower` is set, additionally applies the sparsification
/// conversion patterns to the sparse tensor primitives. The integer-valued
/// command-line options are translated into linalg::SparsificationOptions by
/// the helper methods below.
struct TestSparsification
    : public PassWrapper<TestSparsification, FunctionPass> {
  TestSparsification() = default;
  // The copy constructor intentionally copies none of the Option members:
  // each one is bound to its owning pass through `*this`, so a copy gets
  // freshly constructed options (with their in-class defaults) instead.
  // NOTE(review): presumably the pass infrastructure transfers option values
  // when it clones a pass - confirm before relying on direct copies.
  TestSparsification(const TestSparsification &) {}

  /// Parallelization strategy selector (0 = none; see parallelOption()).
  Option<int32_t> parallelization{
      *this, "parallelization-strategy",
      llvm::cl::desc("Set the parallelization strategy"), llvm::cl::init(0)};
  /// Vectorization strategy selector (0 = none; see vectorOption()).
  Option<int32_t> vectorization{
      *this, "vectorization-strategy",
      llvm::cl::desc("Set the vectorization strategy"), llvm::cl::init(0)};
  /// Vector length forwarded to the sparsification options.
  Option<int32_t> vectorLength{
      *this, "vl", llvm::cl::desc("Set the vector length"), llvm::cl::init(1)};
  /// Integer type selector for pointer storage (see typeOption()).
  Option<int32_t> ptrType{*this, "ptr-type",
                          llvm::cl::desc("Set the pointer type"),
                          llvm::cl::init(0)};
  /// Integer type selector for index storage (see typeOption()).
  Option<int32_t> indType{*this, "ind-type",
                          llvm::cl::desc("Set the index type"),
                          llvm::cl::init(0)};
  /// Forwarded as the last argument of linalg::SparsificationOptions.
  Option<bool> fastOutput{*this, "fast-output",
                          llvm::cl::desc("Allows fast output buffers"),
                          llvm::cl::init(false)};
  /// When set, the sparse primitives are converted after sparsification.
  Option<bool> lower{*this, "lower", llvm::cl::desc("Lower sparse primitives"),
                     llvm::cl::init(false)};

  /// Registers all dialects required by testing.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<scf::SCFDialect, vector::VectorDialect, LLVM::LLVMDialect>();
  }

  /// Returns the parallelization strategy given on the command line. Any
  /// value other than 1-4 (including the default 0) selects kNone.
  linalg::SparseParallelizationStrategy parallelOption() const {
    switch (parallelization) {
    default:
      return linalg::SparseParallelizationStrategy::kNone;
    case 1:
      return linalg::SparseParallelizationStrategy::kDenseOuterLoop;
    case 2:
      return linalg::SparseParallelizationStrategy::kAnyStorageOuterLoop;
    case 3:
      return linalg::SparseParallelizationStrategy::kDenseAnyLoop;
    case 4:
      return linalg::SparseParallelizationStrategy::kAnyStorageAnyLoop;
    }
  }

  /// Returns the vectorization strategy given on the command line. Any value
  /// other than 1-2 (including the default 0) selects kNone.
  linalg::SparseVectorizationStrategy vectorOption() const {
    switch (vectorization) {
    default:
      return linalg::SparseVectorizationStrategy::kNone;
    case 1:
      return linalg::SparseVectorizationStrategy::kDenseInnerLoop;
    case 2:
      return linalg::SparseVectorizationStrategy::kAnyStorageInnerLoop;
    }
  }

  /// Returns the integer type selected by `option`:
  /// 1 = kI64, 2 = kI32, 3 = kI16, 4 = kI8, anything else = kNative.
  static linalg::SparseIntType typeOption(int32_t option) {
    switch (option) {
    default:
      return linalg::SparseIntType::kNative;
    case 1:
      return linalg::SparseIntType::kI64;
    case 2:
      return linalg::SparseIntType::kI32;
    case 3:
      return linalg::SparseIntType::kI16;
    case 4:
      return linalg::SparseIntType::kI8;
    }
  }

  /// Runs sparsification (and, optionally, lowering) on the current function.
  void runOnFunction() override {
    auto *ctx = &getContext();
    // Use the same handle for both rewriting phases below.
    auto func = getFunction();

    // Translate strategy flags to strategy options.
    linalg::SparsificationOptions options(parallelOption(), vectorOption(),
                                          vectorLength, typeOption(ptrType),
                                          typeOption(indType), fastOutput);

    // Apply rewriting; the result of the greedy driver is deliberately
    // discarded.
    OwningRewritePatternList patterns;
    linalg::populateSparsificationPatterns(ctx, patterns, options);
    vector::populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
    (void)applyPatternsAndFoldGreedily(func, std::move(patterns));

    if (!lower)
      return;

    // Lower sparse primitives to calls into runtime support library: the four
    // sparse tensor ops below must be converted away, while CallOp is legal.
    // NOTE(review): if these conversion patterns insert runtime-library
    // declarations into the enclosing module, doing so from a FunctionPass is
    // not safe under multi-threaded pass execution - confirm.
    OwningRewritePatternList conversionPatterns;
    ConversionTarget target(*ctx);
    target.addIllegalOp<linalg::SparseTensorFromPointerOp,
                        linalg::SparseTensorToPointersMemRefOp,
                        linalg::SparseTensorToIndicesMemRefOp,
                        linalg::SparseTensorToValuesMemRefOp>();
    target.addLegalOp<CallOp>();
    linalg::populateSparsificationConversionPatterns(ctx, conversionPatterns);
    if (failed(applyPartialConversion(func, target,
                                      std::move(conversionPatterns))))
      signalPassFailure();
  }
};
} // end anonymous namespace
namespace mlir {
namespace test {
void registerTestSparsification() {
PassRegistration<TestSparsification> sparsificationPass(
"test-sparsification",
"Test automatic generation of sparse tensor code");
}
} // namespace test
} // namespace mlir