[mlir][sparse] prepare runtime support lib for multiple dim level types

We are moving from just dense/compressed to more general dim level
types, so we need more than just an "i1" array for annotations.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D102520
This commit is contained in:
Aart Bik
2021-05-14 19:11:39 -07:00
parent fcd12fed41
commit 56fd4c1cf8
3 changed files with 125 additions and 22 deletions

View File

@@ -41,6 +41,19 @@ static unsigned getOverheadTypeEncoding(unsigned width) {
}
}
/// Returns internal dimension level type encoding.
static unsigned
getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
switch (dlt) {
case SparseTensorEncodingAttr::DimLevelType::Dense:
return 0;
case SparseTensorEncodingAttr::DimLevelType::Compressed:
return 1;
case SparseTensorEncodingAttr::DimLevelType::Singleton:
return 2;
}
}
/// Returns function reference (first hit also inserts into module).
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result,
ValueRange operands) {
@@ -107,12 +120,12 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
// Sparsity annotations in tensor constant form. Note that we cast
// the static shape into a dynamic shape to ensure that the method
// signature remains uniform accross different tensor dimensions.
SmallVector<bool, 4> attrs;
SmallVector<APInt, 4> attrs;
unsigned sz = enc.getDimLevelType().size();
for (unsigned i = 0; i < sz; i++)
attrs.push_back(enc.getDimLevelType()[i] ==
SparseTensorEncodingAttr::DimLevelType::Compressed);
Type etp = rewriter.getIntegerType(1);
attrs.push_back(
APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i])));
Type etp = rewriter.getIntegerType(8);
RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
RankedTensorType tt2 =
RankedTensorType::get({ShapedType::kDynamicSize}, etp);