[mlir][sparse] complete migration to sparse tensor type

A very elaborate, but also very fun revision because all
puzzle pieces are finally "falling into place".

1. replaces linalg annotations + flags with proper sparse tensor types
2. adds rigorous verification of the sparse tensor type and sparse primitives
3. removes glue and clutter on opaque pointers in favor of sparse tensor types
4. migrates all tests to use sparse tensor types

NOTE: next CL will remove *all* obsoleted sparse code in Linalg

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D102095
Aart Bik
2021-05-10 10:34:21 -07:00
parent b1c3c2e4fc
commit 96a23911f6
25 changed files with 1478 additions and 2122 deletions

@@ -1,4 +1,4 @@
//===- SparseTensorLowering.cpp - Sparse tensor primitives lowering -------===//
//===- SparseTensorLowering.cpp - Sparse tensor primitives conversion -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,9 +6,9 @@
//
//===----------------------------------------------------------------------===//
//
// Lower sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the lowering
// simple. In principle, these primitives could also be lowered to actual
// Convert sparse tensor primitives to calls into a runtime support library.
// Note that this is a current implementation choice to keep the conversion
// simple. In principle, these primitives could also be converted to actual
// elaborate IR code that implements the primitives on the selected sparse
// tensor storage schemes.
//
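Every pattern in this file follows the same basic shape: match one sparse primitive, pick a runtime entry point based on the result type, and replace the op with a call obtained through getFunc. Purely as a schematic (SomePrimitiveOp and the entry-point name are placeholders, not actual ops or library functions):

// Schematic only: each pattern below specializes this shape for one sparse
// primitive; the runtime function name is chosen from the result type.
template <typename SomePrimitiveOp>
class SparsePrimitiveCallConverter
    : public OpConversionPattern<SomePrimitiveOp> {
public:
  using OpConversionPattern<SomePrimitiveOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(SomePrimitiveOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type resType = op.getType();
    StringRef name = "sparseXXX"; // placeholder; chosen from resType below
    rewriter.replaceOpWithNewOp<CallOp>(
        op, resType, getFunc(op, name, resType, operands), operands);
    return success();
  }
};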
@@ -22,9 +22,24 @@
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace mlir::sparse_tensor;
namespace {
/// Returns internal type encoding for overhead storage.
static unsigned getOverheadTypeEncoding(unsigned width) {
switch (width) {
default:
return 1;
case 32:
return 2;
case 16:
return 3;
case 8:
return 4;
}
}
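Spelled out, the table above maps a 64-bit (or any other unlisted) width to code 1, a 32-bit width to 2, 16-bit to 3, and 8-bit to 4. A tiny hypothetical check, not part of the patch:

#include <cassert>
// Hypothetical check only: the default case covers 64-bit and any other
// width not listed in the switch above.
static void checkOverheadTypeEncoding() {
  assert(getOverheadTypeEncoding(64) == 1); // default case
  assert(getOverheadTypeEncoding(32) == 2);
  assert(getOverheadTypeEncoding(16) == 3);
  assert(getOverheadTypeEncoding(8) == 4);
}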
/// Returns function reference (first hit also inserts into module).
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result,
ValueRange operands) {
@@ -41,14 +56,14 @@ static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result,
return SymbolRefAttr::get(context, name);
}
/// Sparse conversion rule to remove opaque pointer cast.
class SparseTensorFromPointerConverter
: public OpConversionPattern<sparse_tensor::FromPointerOp> {
/// Sparse conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(sparse_tensor::FromPointerOp op, ArrayRef<Value> operands,
matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOp(op, operands[0]);
rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
return success();
}
};
@@ -71,18 +86,77 @@ public:
}
};
/// Sparse conversion rule for the new operator.
class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(NewOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Type resType = op.getType();
Type eltType = resType.cast<ShapedType>().getElementType();
MLIRContext *context = op->getContext();
SmallVector<Value, 5> params;
// Sparse encoding.
auto enc = getSparseTensorEncoding(resType);
if (!enc)
return failure();
// User pointer.
params.push_back(operands[0]);
// Sparsity annotations.
SmallVector<bool, 4> attrs;
unsigned sz = enc.getDimLevelType().size();
for (unsigned i = 0; i < sz; i++)
attrs.push_back(enc.getDimLevelType()[i] ==
SparseTensorEncodingAttr::DimLevelType::Compressed);
auto elts = DenseElementsAttr::get(
RankedTensorType::get({sz}, rewriter.getIntegerType(1)), attrs);
params.push_back(rewriter.create<ConstantOp>(loc, elts));
// Secondary and primary types encoding.
unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
unsigned primary;
if (eltType.isF64())
primary = 1;
else if (eltType.isF32())
primary = 2;
else if (eltType.isInteger(32))
primary = 3;
else if (eltType.isInteger(16))
primary = 4;
else if (eltType.isInteger(8))
primary = 5;
else
return failure();
params.push_back(
rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr)));
params.push_back(
rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd)));
params.push_back(
rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary)));
// Generate the call to create new tensor.
Type ptrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
StringRef name = "newSparseTensor";
rewriter.replaceOpWithNewOp<CallOp>(
op, ptrType, getFunc(op, name, ptrType, params), params);
return success();
}
};
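The element-type dispatch above amounts to a second encoding table: f64 -> 1, f32 -> 2, i32 -> 3, i16 -> 4, i8 -> 5, with anything else rejected via failure(). As a sketch of an alternative factoring only (the patch keeps the dispatch inline; this helper is not part of the diff):

// Sketch only, not part of this patch: the same primary-type table factored
// into a helper analogous to getOverheadTypeEncoding; 0 signals an
// unsupported element type and would map back to the failure() path above.
static unsigned getPrimaryTypeEncoding(Type eltType) {
  if (eltType.isF64())
    return 1;
  if (eltType.isF32())
    return 2;
  if (eltType.isInteger(32))
    return 3;
  if (eltType.isInteger(16))
    return 4;
  if (eltType.isInteger(8))
    return 5;
  return 0; // unsupported element type
}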
/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
: public OpConversionPattern<sparse_tensor::ToPointersOp> {
: public OpConversionPattern<ToPointersOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(sparse_tensor::ToPointersOp op, ArrayRef<Value> operands,
matchAndRewrite(ToPointersOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Type resType = op.getType();
Type eltType = resType.cast<ShapedType>().getElementType();
StringRef name;
if (eltType.isIndex() || eltType.isInteger(64))
if (eltType.isIndex())
name = "sparsePointers";
else if (eltType.isInteger(64))
name = "sparsePointers64";
else if (eltType.isInteger(32))
name = "sparsePointers32";
@@ -99,17 +173,18 @@ public:
};
/// Sparse conversion rule for index accesses.
class SparseTensorToIndicesConverter
: public OpConversionPattern<sparse_tensor::ToIndicesOp> {
class SparseTensorToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(sparse_tensor::ToIndicesOp op, ArrayRef<Value> operands,
matchAndRewrite(ToIndicesOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Type resType = op.getType();
Type eltType = resType.cast<ShapedType>().getElementType();
StringRef name;
if (eltType.isIndex() || eltType.isInteger(64))
if (eltType.isIndex())
name = "sparseIndices";
else if (eltType.isInteger(64))
name = "sparseIndices64";
else if (eltType.isInteger(32))
name = "sparseIndices32";
@@ -126,12 +201,11 @@ public:
};
/// Sparse conversion rule for value accesses.
class SparseTensorToValuesConverter
: public OpConversionPattern<sparse_tensor::ToValuesOp> {
class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(sparse_tensor::ToValuesOp op, ArrayRef<Value> operands,
matchAndRewrite(ToValuesOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Type resType = op.getType();
Type eltType = resType.cast<ShapedType>().getElementType();
@@ -158,8 +232,10 @@ public:
/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorConversionPatterns(RewritePatternSet &patterns) {
patterns.add<SparseTensorFromPointerConverter, SparseTensorToDimSizeConverter,
SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
SparseTensorToValuesConverter>(patterns.getContext());
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
SparseTensorNewConverter, SparseTensorToPointersConverter,
SparseTensorToIndicesConverter, SparseTensorToValuesConverter>(
typeConverter, patterns.getContext());
}
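The extra TypeConverter parameter lets a driving pass decide how sparse tensor types themselves are rewritten (for example, to the opaque i8* handle returned by newSparseTensor). A minimal sketch of such a pass, where SparseTensorConversionPass and SparseTensorTypeConverter are illustrative names rather than something this diff defines:

// Sketch only: wiring the patterns into a conversion pass. The pass and
// converter names are assumptions for illustration.
struct SparseTensorConversionPass
    : public PassWrapper<SparseTensorConversionPass, OperationPass<ModuleOp>> {
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    SparseTensorTypeConverter converter; // hypothetical: sparse tensor -> i8*
    RewritePatternSet patterns(ctx);
    ConversionTarget target(*ctx);
    // An op is legal once none of its operand/result types are sparse tensors.
    target.markUnknownOpDynamicallyLegal(
        [&](Operation *op) { return converter.isLegal(op); });
    populateSparseTensorConversionPatterns(converter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};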