This revision connects the generated sparse code with an actual sparse storage scheme, which can be initialized from a test file. Lacking a first-class SparseTensor type (with buffer), the storage is hidden behind an opaque pointer with some "glue" to bring the pointer back to tensor land. Rather than generating sparse setup code for each different annotated tensor (viz. the "pack" methods in TACO), a single "one-size-fits-all" implementation has been added to the runtime support library. Many details and abstractions need to be refined in the future, but this revision allows full end-to-end integration testing and performance benchmarking (with, on one end, an annotated Linalg op and, on the other end, a JIT/AOT executable). Reviewed By: nicolasvasilache, bixia Differential Revision: https://reviews.llvm.org/D95847
139 lines
4.9 KiB
C++
139 lines
4.9 KiB
C++
//===- SparseLowering.cpp - Lowers sparse primitives to library calls. ---===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
|
|
/// Returns function reference (first hit also inserts into module).
|
|
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result,
|
|
ValueRange operands) {
|
|
MLIRContext *context = op->getContext();
|
|
auto module = op->getParentOfType<ModuleOp>();
|
|
auto func = module.lookupSymbol<FuncOp>(name);
|
|
if (!func) {
|
|
OpBuilder moduleBuilder(module.getBodyRegion());
|
|
moduleBuilder
|
|
.create<FuncOp>(op->getLoc(), name,
|
|
FunctionType::get(context, operands.getTypes(), result))
|
|
.setPrivate();
|
|
}
|
|
return SymbolRefAttr::get(context, name);
|
|
}
|
|
|
|
/// Sparse conversion rule to remove opaque pointer cast.
|
|
class TensorFromPointerConverter
|
|
: public OpConversionPattern<linalg::SparseTensorFromPointerOp> {
|
|
using OpConversionPattern::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(linalg::SparseTensorFromPointerOp op,
|
|
ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
rewriter.replaceOp(op, operands[0]);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Sparse conversion rule for dimension accesses.
|
|
class TensorToDimSizeConverter : public OpConversionPattern<DimOp> {
|
|
public:
|
|
using OpConversionPattern::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(DimOp op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (!operands[0].getType().isa<LLVM::LLVMPointerType>())
|
|
return failure();
|
|
Type resType = op.getType();
|
|
StringRef name = "sparseDimSize";
|
|
rewriter.replaceOpWithNewOp<CallOp>(
|
|
op, resType, getFunc(op, name, resType, operands), operands);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Sparse conversion rule for pointer accesses.
|
|
class TensorToPointersConverter
|
|
: public OpConversionPattern<linalg::SparseTensorToPointersMemRefOp> {
|
|
public:
|
|
using OpConversionPattern::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(linalg::SparseTensorToPointersMemRefOp op,
|
|
ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
Type resType = op.getType();
|
|
Type eltType = resType.cast<ShapedType>().getElementType();
|
|
StringRef name;
|
|
if (eltType.isIndex() || eltType.isInteger(64))
|
|
name = "sparsePtrsI64";
|
|
else
|
|
return failure();
|
|
rewriter.replaceOpWithNewOp<CallOp>(
|
|
op, resType, getFunc(op, name, resType, operands), operands);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Sparse conversion rule for index accesses.
|
|
class TensorToIndicesConverter
|
|
: public OpConversionPattern<linalg::SparseTensorToIndicesMemRefOp> {
|
|
public:
|
|
using OpConversionPattern::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(linalg::SparseTensorToIndicesMemRefOp op,
|
|
ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
Type resType = op.getType();
|
|
Type eltType = resType.cast<ShapedType>().getElementType();
|
|
StringRef name;
|
|
if (eltType.isIndex() || eltType.isInteger(64))
|
|
name = "sparseIndxsI64";
|
|
else
|
|
return failure();
|
|
rewriter.replaceOpWithNewOp<CallOp>(
|
|
op, resType, getFunc(op, name, resType, operands), operands);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Sparse conversion rule for value accesses.
|
|
class TensorToValuesConverter
|
|
: public OpConversionPattern<linalg::SparseTensorToValuesMemRefOp> {
|
|
public:
|
|
using OpConversionPattern::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(linalg::SparseTensorToValuesMemRefOp op,
|
|
ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
Type resType = op.getType();
|
|
Type eltType = resType.cast<ShapedType>().getElementType();
|
|
StringRef name;
|
|
if (eltType.isF64())
|
|
name = "sparseValsF64";
|
|
else
|
|
return failure();
|
|
rewriter.replaceOpWithNewOp<CallOp>(
|
|
op, resType, getFunc(op, name, resType, operands), operands);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Populates the given patterns list with conversion rules required for
|
|
/// the sparsification of linear algebra operations.
|
|
void linalg::populateSparsificationConversionPatterns(
|
|
MLIRContext *context, OwningRewritePatternList &patterns) {
|
|
patterns.insert<TensorFromPointerConverter, TensorToDimSizeConverter,
|
|
TensorToPointersConverter, TensorToIndicesConverter,
|
|
TensorToValuesConverter>(context);
|
|
}
|