clang-p2996/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
Aart Bik 99b3849d89 [mlir][sparse] introduce vectorization pass for sparse loops
This brings back previous SIMD functionality, but in a separate pass.
The idea is to improve this new pass incrementally, going beyond for-loops
to while-loops for co-iteration as well (masking), while introducing new
abstractions to make the lowering more progressive. The separation of
sparsification and vectorization is a very good first step on this journey.

Also brings back ArmSVE support

Still to be fine-tuned:
  + use of "index" in SIMD loop (viz. a[i] = i)
  + check that all ops really have SIMD support
  + check all forms of reductions
  + chain reduction SIMD values

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D138236
2022-11-21 16:12:12 -08:00
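
As a quick illustration of the split described above, the two stages can now
be chained in a standard MLIR pass pipeline. This is a minimal sketch using
only the factory functions defined in this file; the helper name
buildSparsePipeline and the option values are illustrative, not part of this
commit:

    #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
    #include "mlir/Pass/PassManager.h"

    // Hypothetical helper: sparsification first emits loops, then the new
    // vectorization pass rewrites those loops into SIMD form.
    static void buildSparsePipeline(mlir::PassManager &pm) {
      pm.addPass(mlir::createPreSparsificationRewritePass());
      pm.addPass(mlir::createSparsificationPass());
      pm.addPass(mlir::createPostSparsificationRewritePass());
      // Illustrative options: 16 lanes, no VLA (ArmSVE) vectors, 64-bit
      // index arithmetic inside the SIMD loops.
      pm.addPass(mlir::createSparseVectorizationPass(
          /*vectorLength=*/16, /*enableVLAVectorization=*/false,
          /*enableSIMDIndex32=*/false));
    }

With the passes separated, vectorization can be retuned or disabled without
touching the sparsification step itself.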


//===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
#define GEN_PASS_DEF_SPARSIFICATIONPASS
#define GEN_PASS_DEF_POSTSPARSIFICATIONREWRITE
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
#define GEN_PASS_DEF_SPARSEVECTORIZATION
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Passes implementation.
//===----------------------------------------------------------------------===//
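
// Greedily applies the pre-sparsification rewrite patterns to prepare
// kernels for sparsification.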
struct PreSparsificationRewritePass
    : public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {
  PreSparsificationRewritePass() = default;
  PreSparsificationRewritePass(const PreSparsificationRewritePass &pass) =
      default;

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populatePreSparsificationRewriting(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

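// Lowers sparse kernels to explicit loop form (sparsification proper),
// mixing in scf.for canonicalization patterns as cleanup.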
struct SparsificationPass
    : public impl::SparsificationPassBase<SparsificationPass> {
  SparsificationPass() = default;
  SparsificationPass(const SparsificationPass &pass) = default;
  SparsificationPass(const SparsificationOptions &options) {
    parallelization = options.parallelizationStrategy;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    // Translate strategy flags to strategy options.
    SparsificationOptions options(parallelization);
    // Apply sparsification and cleanup rewriting.
    RewritePatternSet patterns(ctx);
    populateSparsificationPatterns(patterns, options);
    scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

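// Greedily applies the post-sparsification rewrite patterns; flags select
// runtime-library support and the foreach/convert rewrites.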
struct PostSparsificationRewritePass
    : public impl::PostSparsificationRewriteBase<
          PostSparsificationRewritePass> {
  PostSparsificationRewritePass() = default;
  PostSparsificationRewritePass(const PostSparsificationRewritePass &pass) =
      default;
  PostSparsificationRewritePass(bool enableRT, bool foreach, bool convert) {
    enableRuntimeLibrary = enableRT;
    enableForeach = foreach;
    enableConvert = convert;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populatePostSparsificationRewriting(patterns, enableRuntimeLibrary,
                                        enableForeach, enableConvert);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

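// Partial conversion that rewrites sparse tensor types into opaque pointers
// to runtime-library storage.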
struct SparseTensorConversionPass
    : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
  SparseTensorConversionPass() = default;
  SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
  SparseTensorConversionPass(const SparseTensorConversionOptions &options) {
    sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy);
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeToPtrConverter converter;
    ConversionTarget target(*ctx);
    // Everything in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    // All dynamic rules below accept new function, call, return, and various
    // tensor and bufferization operations as legal output of the rewriting
    // provided that all sparse tensor types have been fully rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
      return converter.isLegal(op.getSource().getType()) &&
             converter.isLegal(op.getDest().getType());
    });
    target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
        [&](tensor::ExpandShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
        [&](tensor::CollapseShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          return converter.isLegal(op.getType());
        });
    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
        [&](bufferization::DeallocTensorOp op) {
          return converter.isLegal(op.getTensor().getType());
        });
    // The following operations and dialects may be introduced by the
    // rewriting rules, and are therefore marked as legal.
    target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
                      linalg::YieldOp, tensor::ExtractOp>();
    target.addLegalDialect<
        arith::ArithDialect, bufferization::BufferizationDialect,
        LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
    // Translate strategy flags to strategy options.
    SparseTensorConversionOptions options(
        sparseToSparseConversionStrategy(sparseToSparse));
    // Populate with rules and apply rewriting rules.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateSparseTensorConversionPatterns(converter, patterns, options);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

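// Partial conversion that compiles sparse tensor types into actual buffers,
// keeping the sort/sort_coo/push_back operations legal for later lowering.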
struct SparseTensorCodegenPass
    : public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {
  SparseTensorCodegenPass() = default;
  SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
  SparseTensorCodegenPass(bool enableInit) {
    enableBufferInitialization = enableInit;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeToBufferConverter converter;
    ConversionTarget target(*ctx);
    // Most ops in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    target.addLegalOp<SortOp>();
    target.addLegalOp<SortCooOp>();
    target.addLegalOp<PushBackOp>();
    // All dynamic rules below accept new function, call, return, and various
    // tensor and bufferization operations as legal output of the rewriting
    // provided that all sparse tensor types have been fully rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          return converter.isLegal(op.getType());
        });
    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
        [&](bufferization::DeallocTensorOp op) {
          return converter.isLegal(op.getTensor().getType());
        });
    // The following operations and dialects may be introduced by the
    // codegen rules, and are therefore marked as legal.
    target.addLegalOp<linalg::FillOp>();
    target.addLegalDialect<arith::ArithDialect,
                           bufferization::BufferizationDialect,
                           memref::MemRefDialect, scf::SCFDialect>();
    target.addLegalOp<UnrealizedConversionCastOp>();
    // Populate with rules and apply rewriting rules.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateSparseTensorCodegenPatterns(converter, patterns,
                                        enableBufferInitialization);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

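// Greedily applies the sparse buffer rewrite patterns; buffer initialization
// can be toggled through the pass option.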
struct SparseBufferRewritePass
    : public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
  SparseBufferRewritePass() = default;
  SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
  SparseBufferRewritePass(bool enableInit) {
    enableBufferInitialization = enableInit;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseBufferRewriting(patterns, enableBufferInitialization);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

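// Converts loops produced by sparsification into SIMD loops of the requested
// vector length, mixing in vector-to-vector canonicalization as cleanup.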
struct SparseVectorizationPass
    : public impl::SparseVectorizationBase<SparseVectorizationPass> {
  SparseVectorizationPass() = default;
  SparseVectorizationPass(const SparseVectorizationPass &pass) = default;
  SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) {
    vectorLength = vl;
    enableVLAVectorization = vla;
    enableSIMDIndex32 = sidx32;
  }

  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateSparseVectorizationPatterns(
        patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
    vector::populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }
};

} // namespace
//===----------------------------------------------------------------------===//
// Strategy flag methods.
//===----------------------------------------------------------------------===//

SparseToSparseConversionStrategy
mlir::sparseToSparseConversionStrategy(int32_t flag) {
  switch (flag) {
  default:
    return SparseToSparseConversionStrategy::kAuto;
  case 1:
    return SparseToSparseConversionStrategy::kViaCOO;
  case 2:
    return SparseToSparseConversionStrategy::kDirect;
  }
}

//===----------------------------------------------------------------------===//
// Pass creation methods.
//===----------------------------------------------------------------------===//

std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
  return std::make_unique<PreSparsificationRewritePass>();
}

std::unique_ptr<Pass> mlir::createSparsificationPass() {
  return std::make_unique<SparsificationPass>();
}

std::unique_ptr<Pass>
mlir::createSparsificationPass(const SparsificationOptions &options) {
  return std::make_unique<SparsificationPass>(options);
}

std::unique_ptr<Pass> mlir::createPostSparsificationRewritePass() {
  return std::make_unique<PostSparsificationRewritePass>();
}

std::unique_ptr<Pass>
mlir::createPostSparsificationRewritePass(bool enableRT, bool enableForeach,
                                          bool enableConvert) {
  return std::make_unique<PostSparsificationRewritePass>(
      enableRT, enableForeach, enableConvert);
}

std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
  return std::make_unique<SparseTensorConversionPass>();
}

std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
    const SparseTensorConversionOptions &options) {
  return std::make_unique<SparseTensorConversionPass>(options);
}

std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
  return std::make_unique<SparseTensorCodegenPass>();
}

std::unique_ptr<Pass>
mlir::createSparseTensorCodegenPass(bool enableBufferInitialization) {
  return std::make_unique<SparseTensorCodegenPass>(enableBufferInitialization);
}

std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
  return std::make_unique<SparseBufferRewritePass>();
}

std::unique_ptr<Pass>
mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
  return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
}

std::unique_ptr<Pass> mlir::createSparseVectorizationPass() {
  return std::make_unique<SparseVectorizationPass>();
}

std::unique_ptr<Pass>
mlir::createSparseVectorizationPass(unsigned vectorLength,
                                    bool enableVLAVectorization,
                                    bool enableSIMDIndex32) {
  return std::make_unique<SparseVectorizationPass>(
      vectorLength, enableVLAVectorization, enableSIMDIndex32);
}