[mlir][sparse] support type conversion from SoA COO to memrefs. (#82398)
This commit is contained in:
@@ -74,11 +74,12 @@ void StorageLayout::foreachField(
|
||||
callback) const {
|
||||
const auto lvlTypes = enc.getLvlTypes();
|
||||
const Level lvlRank = enc.getLvlRank();
|
||||
const Level cooStart = SparseTensorType(enc).getCOOStart();
|
||||
const Level end = cooStart == lvlRank ? cooStart : cooStart + 1;
|
||||
SmallVector<COOSegment> cooSegs = SparseTensorType(enc).getCOOSegments();
|
||||
FieldIndex fieldIdx = kDataFieldStartingIdx;
|
||||
|
||||
ArrayRef cooSegsRef = cooSegs;
|
||||
// Per-level storage.
|
||||
for (Level l = 0; l < end; l++) {
|
||||
for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
|
||||
const auto lt = lvlTypes[l];
|
||||
if (isWithPosLT(lt)) {
|
||||
if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
|
||||
@@ -88,6 +89,21 @@ void StorageLayout::foreachField(
|
||||
if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
|
||||
return;
|
||||
}
|
||||
if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
|
||||
if (!cooSegsRef.front().isSoA) {
|
||||
// AoS COO, all singletons are fused into one memrefs. Skips the entire
|
||||
// COO segement.
|
||||
l = cooSegsRef.front().lvlRange.second;
|
||||
} else {
|
||||
// SoA COO, each singleton level has one memref.
|
||||
l++;
|
||||
}
|
||||
// Expire handled COO segment.
|
||||
cooSegsRef = cooSegsRef.drop_front();
|
||||
} else {
|
||||
// Non COO levels.
|
||||
l++;
|
||||
}
|
||||
}
|
||||
// The values array.
|
||||
if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
|
||||
@@ -796,13 +812,46 @@ bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
|
||||
}
|
||||
|
||||
Level mlir::sparse_tensor::SparseTensorType::getCOOStart() const {
|
||||
if (hasEncoding() && lvlRank > 1)
|
||||
for (Level l = 0; l < lvlRank - 1; l++)
|
||||
if (isCOOType(l, /*isUnique=*/false))
|
||||
return l;
|
||||
SmallVector<COOSegment> coo = getCOOSegments();
|
||||
if (!coo.empty()) {
|
||||
assert(coo.size() == 1);
|
||||
return coo.front().lvlRange.first;
|
||||
}
|
||||
return lvlRank;
|
||||
}
|
||||
|
||||
SmallVector<COOSegment>
|
||||
mlir::sparse_tensor::SparseTensorType::getCOOSegments() const {
|
||||
SmallVector<COOSegment> ret;
|
||||
if (!hasEncoding() || lvlRank <= 1)
|
||||
return ret;
|
||||
|
||||
ArrayRef<LevelType> lts = getLvlTypes();
|
||||
Level l = 0;
|
||||
while (l < lvlRank) {
|
||||
auto lt = lts[l];
|
||||
if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) {
|
||||
auto cur = lts.begin() + l;
|
||||
auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
|
||||
return !lt.isa<LevelFormat::Singleton>();
|
||||
});
|
||||
unsigned cooLen = std::distance(cur, end);
|
||||
if (cooLen > 1) {
|
||||
// To support mixed SoA/AoS COO, we should break the segment when the
|
||||
// storage scheme changes, for now we faithfully assume that all
|
||||
// consecutive singleton levels have the same storage format as verified
|
||||
// STEA.
|
||||
ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
|
||||
lts[l + 1].isa<LevelPropNonDefault::SoA>()});
|
||||
}
|
||||
l += cooLen;
|
||||
} else {
|
||||
l++;
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
RankedTensorType
|
||||
mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
|
||||
SmallVector<LevelType> lvlTypes;
|
||||
|
||||
Reference in New Issue
Block a user