[mlir][sparse] support type conversion from SoA COO to memrefs. (#82398)

This commit is contained in:
Peiming Liu
2024-02-20 13:19:13 -06:00
committed by GitHub
parent a9b5753220
commit f740366fa6
4 changed files with 99 additions and 9 deletions

View File

@@ -74,11 +74,12 @@ void StorageLayout::foreachField(
callback) const {
const auto lvlTypes = enc.getLvlTypes();
const Level lvlRank = enc.getLvlRank();
const Level cooStart = SparseTensorType(enc).getCOOStart();
const Level end = cooStart == lvlRank ? cooStart : cooStart + 1;
SmallVector<COOSegment> cooSegs = SparseTensorType(enc).getCOOSegments();
FieldIndex fieldIdx = kDataFieldStartingIdx;
ArrayRef cooSegsRef = cooSegs;
// Per-level storage.
for (Level l = 0; l < end; l++) {
for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
const auto lt = lvlTypes[l];
if (isWithPosLT(lt)) {
if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
@@ -88,6 +89,21 @@ void StorageLayout::foreachField(
if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
return;
}
if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
if (!cooSegsRef.front().isSoA) {
// AoS COO, all singletons are fused into one memrefs. Skips the entire
// COO segement.
l = cooSegsRef.front().lvlRange.second;
} else {
// SoA COO, each singleton level has one memref.
l++;
}
// Expire handled COO segment.
cooSegsRef = cooSegsRef.drop_front();
} else {
// Non COO levels.
l++;
}
}
// The values array.
if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
@@ -796,13 +812,46 @@ bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
}
Level mlir::sparse_tensor::SparseTensorType::getCOOStart() const {
if (hasEncoding() && lvlRank > 1)
for (Level l = 0; l < lvlRank - 1; l++)
if (isCOOType(l, /*isUnique=*/false))
return l;
SmallVector<COOSegment> coo = getCOOSegments();
if (!coo.empty()) {
assert(coo.size() == 1);
return coo.front().lvlRange.first;
}
return lvlRank;
}
SmallVector<COOSegment>
mlir::sparse_tensor::SparseTensorType::getCOOSegments() const {
SmallVector<COOSegment> ret;
if (!hasEncoding() || lvlRank <= 1)
return ret;
ArrayRef<LevelType> lts = getLvlTypes();
Level l = 0;
while (l < lvlRank) {
auto lt = lts[l];
if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) {
auto cur = lts.begin() + l;
auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
return !lt.isa<LevelFormat::Singleton>();
});
unsigned cooLen = std::distance(cur, end);
if (cooLen > 1) {
// To support mixed SoA/AoS COO, we should break the segment when the
// storage scheme changes, for now we faithfully assume that all
// consecutive singleton levels have the same storage format as verified
// STEA.
ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
lts[l + 1].isa<LevelPropNonDefault::SoA>()});
}
l += cooLen;
} else {
l++;
}
}
return ret;
}
RankedTensorType
mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
SmallVector<LevelType> lvlTypes;