Files
clang-p2996/mlir/lib/Dialect/Linalg/Transforms/SparseLowering.cpp
Aart Bik 0b1764a3d7 [mlir][sparse] sparse tensor storage implementation
This revision connects the generated sparse code with an actual
sparse storage scheme, which can be initialized from a test file.
Lacking a first-class citizen SparseTensor type (with buffer),
the storage is hidden behind an opaque pointer with some "glue"
to bring the pointer back to tensor land. Rather than generating
sparse setup code for each different annotated tensor (viz. the
"pack" methods in TACO), a single "one-size-fits-all" implementation
has been added to the runtime support library.  Many details and
abstractions need to be refined in the future, but this revision
allows full end-to-end integration testing and performance
benchmarking (with an annotated Linalg op on one end
and a JIT/AOT executable on the other).

Reviewed By: nicolasvasilache, bixia

Differential Revision: https://reviews.llvm.org/D95847
2021-02-10 11:57:24 -08:00

139 lines
4.9 KiB
C++

//===- SparseLowering.cpp - Lowers sparse primitives to library calls. ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
using namespace mlir;
namespace {
/// Returns a symbol reference to the function called `name`.
///
/// The function is looked up in the `ModuleOp` that encloses `op`. When no
/// `FuncOp` with that name exists yet, a private one is created through a
/// builder constructed on the module's body region, using the types of
/// `operands` as inputs and `result` as the result type. Later requests for
/// the same name find that function and create nothing.
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result,
                                 ValueRange operands) {
  MLIRContext *ctx = op->getContext();
  ModuleOp parentModule = op->getParentOfType<ModuleOp>();
  if (!parentModule.lookupSymbol<FuncOp>(name)) {
    // First request for this symbol: add the function to the module.
    OpBuilder declBuilder(parentModule.getBodyRegion());
    FunctionType signature =
        FunctionType::get(ctx, operands.getTypes(), result);
    FuncOp decl = declBuilder.create<FuncOp>(op->getLoc(), name, signature);
    decl.setPrivate();
  }
  return SymbolRefAttr::get(ctx, name);
}
/// Sparse conversion rule to remove opaque pointer cast.
///
/// A `linalg::SparseTensorFromPointerOp` is replaced by its first (already
/// converted) operand, so the cast operation itself disappears and its users
/// see that operand directly.
class TensorFromPointerConverter
    : public OpConversionPattern<linalg::SparseTensorFromPointerOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  /// Replaces `op` with `operands[0]`; always succeeds.
  LogicalResult
  matchAndRewrite(linalg::SparseTensorFromPointerOp op,
                  ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, operands[0]);
    return success();
  }
};
/// Sparse conversion rule for dimension accesses.
///
/// A `DimOp` whose first converted operand has LLVM pointer type is replaced
/// by a `CallOp` to the function `sparseDimSize`, which receives all converted
/// operands and yields the original result type of the `DimOp`.
class TensorToDimSizeConverter : public OpConversionPattern<DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(DimOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only apply when the first operand has been converted to an LLVM
    // pointer; any other `DimOp` is rejected by this pattern.
    Type sourceType = operands[0].getType();
    if (!sourceType.isa<LLVM::LLVMPointerType>())
      return failure();

    Type dimType = op.getType();
    FlatSymbolRefAttr callee = getFunc(op, "sparseDimSize", dimType, operands);
    rewriter.replaceOpWithNewOp<CallOp>(op, dimType, callee, operands);
    return success();
  }
};
/// Sparse conversion rule for pointer accesses.
///
/// A `linalg::SparseTensorToPointersMemRefOp` is replaced by a `CallOp` to
/// `sparsePtrsI64` when the element type of its result is `index` or a 64-bit
/// integer. Any other element type makes the pattern fail.
class TensorToPointersConverter
    : public OpConversionPattern<linalg::SparseTensorToPointersMemRefOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(linalg::SparseTensorToPointersMemRefOp op,
                  ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type memrefType = op.getType();
    Type elemType = memrefType.cast<ShapedType>().getElementType();
    // Bail out early on element types without a matching library function.
    if (!elemType.isIndex() && !elemType.isInteger(64))
      return failure();

    FlatSymbolRefAttr callee =
        getFunc(op, "sparsePtrsI64", memrefType, operands);
    rewriter.replaceOpWithNewOp<CallOp>(op, memrefType, callee, operands);
    return success();
  }
};
/// Sparse conversion rule for index accesses.
///
/// A `linalg::SparseTensorToIndicesMemRefOp` is replaced by a `CallOp` to
/// `sparseIndxsI64` when the element type of its result is `index` or a 64-bit
/// integer. Any other element type makes the pattern fail.
class TensorToIndicesConverter
    : public OpConversionPattern<linalg::SparseTensorToIndicesMemRefOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(linalg::SparseTensorToIndicesMemRefOp op,
                  ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type memrefType = op.getType();
    Type elemType = memrefType.cast<ShapedType>().getElementType();
    // Bail out early on element types without a matching library function.
    if (!elemType.isIndex() && !elemType.isInteger(64))
      return failure();

    FlatSymbolRefAttr callee =
        getFunc(op, "sparseIndxsI64", memrefType, operands);
    rewriter.replaceOpWithNewOp<CallOp>(op, memrefType, callee, operands);
    return success();
  }
};
/// Sparse conversion rule for value accesses.
///
/// A `linalg::SparseTensorToValuesMemRefOp` is replaced by a `CallOp` to
/// `sparseValsF64` when the element type of its result is `f64`. Any other
/// element type makes the pattern fail.
class TensorToValuesConverter
    : public OpConversionPattern<linalg::SparseTensorToValuesMemRefOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(linalg::SparseTensorToValuesMemRefOp op,
                  ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type memrefType = op.getType();
    Type elemType = memrefType.cast<ShapedType>().getElementType();
    // Bail out early on element types without a matching library function.
    if (!elemType.isF64())
      return failure();

    FlatSymbolRefAttr callee =
        getFunc(op, "sparseValsF64", memrefType, operands);
    rewriter.replaceOpWithNewOp<CallOp>(op, memrefType, callee, operands);
    return success();
  }
};
} // namespace
/// Adds the sparse conversion rules defined in this file to `patterns`.
///
/// Each rule is constructed with `context`; the rules are inserted in the
/// order: pointer cast removal, dimension size, pointers, indices, values.
void linalg::populateSparsificationConversionPatterns(
    MLIRContext *context, OwningRewritePatternList &patterns) {
  patterns.insert<TensorFromPointerConverter>(context);
  patterns.insert<TensorToDimSizeConverter>(context);
  patterns.insert<TensorToPointersConverter>(context);
  patterns.insert<TensorToIndicesConverter>(context);
  patterns.insert<TensorToValuesConverter>(context);
}