[mlir][sparse] Using SparseTensorType in SparsePackOpConverter
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D147465
This commit is contained in:
@@ -1235,29 +1235,28 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
|
||||
matchAndRewrite(PackOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
const auto rtp = getRankedTensorType(op.getResult());
|
||||
assert(isUniqueCOOType(rtp));
|
||||
const auto stt = getSparseTensorType(op.getResult());
|
||||
assert(isUniqueCOOType(stt));
|
||||
|
||||
SmallVector<Value> fields;
|
||||
Location loc = op.getLoc();
|
||||
|
||||
foreachFieldAndTypeInSparseTensor(
|
||||
rtp,
|
||||
[&rewriter, &fields, &op, rtp,
|
||||
stt,
|
||||
[&rewriter, &fields, &op, stt,
|
||||
loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
|
||||
Level /*lvl*/, DimLevelType /*dlt*/) -> bool {
|
||||
assert(fields.size() == fIdx);
|
||||
auto enc = getSparseTensorEncoding(rtp);
|
||||
Value field;
|
||||
switch (fKind) {
|
||||
case SparseTensorFieldKind::StorageSpec:
|
||||
field = SparseTensorSpecifier::getInitValue(rewriter, loc, rtp);
|
||||
field = SparseTensorSpecifier::getInitValue(rewriter, loc, stt);
|
||||
break;
|
||||
case SparseTensorFieldKind::PosMemRef: {
|
||||
// TACO-style COO starts with a PosBuffer
|
||||
// By creating a constant value for it, we avoid the complexity of
|
||||
// memory management.
|
||||
const auto posTp = enc.getPosType();
|
||||
const auto posTp = stt.getPosType();
|
||||
auto tensorType = RankedTensorType::get({2}, posTp);
|
||||
auto memrefType = MemRefType::get(tensorType.getShape(),
|
||||
tensorType.getElementType());
|
||||
@@ -1306,13 +1305,11 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
|
||||
return true;
|
||||
});
|
||||
|
||||
MutSparseTensorDescriptor desc(rtp, fields);
|
||||
MutSparseTensorDescriptor desc(stt, fields);
|
||||
auto noe = linalg::createOrFoldDimOp(rewriter, loc, op.getValues(), 0);
|
||||
// FIXME: should use `SparseTensorType::getLvlRank` in lieu of
|
||||
// `RankedTensorType::getRank`, because the latter introduces dim/lvl
|
||||
// ambiguity.
|
||||
for (Level lvl = 0, lvlRank = rtp.getRank(); lvl < lvlRank; lvl++) {
|
||||
const auto sh = rtp.getShape()[lvl];
|
||||
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
|
||||
// FIXME: dim/lvl confusion!
|
||||
const auto sh = stt.getDimShape()[lvl];
|
||||
assert(!ShapedType::isDynamic(sh));
|
||||
desc.setLvlSize(rewriter, loc, lvl, constantIndex(rewriter, loc, sh));
|
||||
if (lvl == 0)
|
||||
|
||||
Reference in New Issue
Block a user