[mlir][sparse] add sparse tensor type conversion operation

Introduces a conversion from one (sparse) tensor type to another
(sparse) tensor type. See the operation doc for details. Actual
codegen for all cases is still TBD.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D107205
This commit is contained in:
Aart Bik
2021-07-30 17:52:39 -07:00
parent 7f55557765
commit 697ea09d47
5 changed files with 127 additions and 29 deletions

View File

@@ -27,16 +27,23 @@ using namespace mlir::sparse_tensor;
namespace {
/// Internal encoding of primary storage. Keep this enum consistent
/// with the equivalent enum in the sparse runtime support library.
enum PrimaryTypeEnum : uint64_t {
kF64 = 1,
kF32 = 2,
kI64 = 3,
kI32 = 4,
kI16 = 5,
kI8 = 6
};
/// Returns internal type encoding for primary storage. Keep these
/// values consistent with the sparse runtime support library.
static unsigned getPrimaryTypeEncoding(Type tp) {
if (tp.isF64())
return 1;
if (tp.isF32())
return 2;
if (tp.isInteger(64))
return 3;
if (tp.isInteger(32))
return 4;
if (tp.isInteger(16))
return 5;
if (tp.isInteger(8))
return 6;
return 0;
}
/// Returns internal type encoding for overhead storage. Keep these
/// values consistent with the sparse runtime support library.
@@ -170,20 +177,8 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
// Secondary and primary types encoding.
unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
unsigned primary;
if (eltType.isF64())
primary = kF64;
else if (eltType.isF32())
primary = kF32;
else if (eltType.isInteger(64))
primary = kI64;
else if (eltType.isInteger(32))
primary = kI32;
else if (eltType.isInteger(16))
primary = kI16;
else if (eltType.isInteger(8))
primary = kI8;
else
unsigned primary = getPrimaryTypeEncoding(eltType);
if (!primary)
return failure();
params.push_back(
rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr)));
@@ -200,6 +195,17 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
}
};
/// Sparse conversion rule for the convert operator.
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(ConvertOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// TODO: implement conversions lowering
return failure();
}
};
/// Sparse conversion rule for pointer accesses.
class SparseTensorToPointersConverter
: public OpConversionPattern<ToPointersOp> {
@@ -324,8 +330,8 @@ public:
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
SparseTensorNewConverter, SparseTensorToPointersConverter,
SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
SparseTensorToTensorConverter>(typeConverter,
patterns.getContext());
SparseTensorNewConverter, SparseTensorConvertConverter,
SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
SparseTensorToValuesConverter, SparseTensorToTensorConverter>(
typeConverter, patterns.getContext());
}