[mlir][sparse] add sparse tensor type conversion operation
Introduces a conversion from one (sparse) tensor type to another (sparse) tensor type. See the operation doc for details. Actual codegen for all cases is still TBD. Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D107205
This commit is contained in:
@@ -27,16 +27,23 @@ using namespace mlir::sparse_tensor;
|
||||
|
||||
namespace {
|
||||
|
||||
/// Internal encoding of primary storage. Keep this enum consistent
|
||||
/// with the equivalent enum in the sparse runtime support library.
|
||||
enum PrimaryTypeEnum : uint64_t {
|
||||
kF64 = 1,
|
||||
kF32 = 2,
|
||||
kI64 = 3,
|
||||
kI32 = 4,
|
||||
kI16 = 5,
|
||||
kI8 = 6
|
||||
};
|
||||
/// Returns internal type encoding for primary storage. Keep these
|
||||
/// values consistent with the sparse runtime support library.
|
||||
static unsigned getPrimaryTypeEncoding(Type tp) {
|
||||
if (tp.isF64())
|
||||
return 1;
|
||||
if (tp.isF32())
|
||||
return 2;
|
||||
if (tp.isInteger(64))
|
||||
return 3;
|
||||
if (tp.isInteger(32))
|
||||
return 4;
|
||||
if (tp.isInteger(16))
|
||||
return 5;
|
||||
if (tp.isInteger(8))
|
||||
return 6;
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// Returns internal type encoding for overhead storage. Keep these
|
||||
/// values consistent with the sparse runtime support library.
|
||||
@@ -170,20 +177,8 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
|
||||
// Secondary and primary types encoding.
|
||||
unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
|
||||
unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
|
||||
unsigned primary;
|
||||
if (eltType.isF64())
|
||||
primary = kF64;
|
||||
else if (eltType.isF32())
|
||||
primary = kF32;
|
||||
else if (eltType.isInteger(64))
|
||||
primary = kI64;
|
||||
else if (eltType.isInteger(32))
|
||||
primary = kI32;
|
||||
else if (eltType.isInteger(16))
|
||||
primary = kI16;
|
||||
else if (eltType.isInteger(8))
|
||||
primary = kI8;
|
||||
else
|
||||
unsigned primary = getPrimaryTypeEncoding(eltType);
|
||||
if (!primary)
|
||||
return failure();
|
||||
params.push_back(
|
||||
rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr)));
|
||||
@@ -200,6 +195,17 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse conversion rule for the convert operator.
|
||||
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(ConvertOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// TODO: implement conversions lowering
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse conversion rule for pointer accesses.
|
||||
class SparseTensorToPointersConverter
|
||||
: public OpConversionPattern<ToPointersOp> {
|
||||
@@ -324,8 +330,8 @@ public:
|
||||
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
|
||||
SparseTensorNewConverter, SparseTensorToPointersConverter,
|
||||
SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
|
||||
SparseTensorToTensorConverter>(typeConverter,
|
||||
patterns.getContext());
|
||||
SparseTensorNewConverter, SparseTensorConvertConverter,
|
||||
SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
|
||||
SparseTensorToValuesConverter, SparseTensorToTensorConverter>(
|
||||
typeConverter, patterns.getContext());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user