diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h index fa34bbe3c9d9..41a14575ed10 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h @@ -303,9 +303,9 @@ public: } /// Check if the `LevelType` is in the `LevelFormat`. - template + template constexpr bool isa() const { - return getLvlFmt() == fmt; + return (... || (getLvlFmt() == fmt)) || false; } /// Check if the `LevelType` has the properties diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h index 4e2b85d35c1a..24a5640d820e 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h @@ -18,6 +18,18 @@ namespace mlir { namespace sparse_tensor { +/// A simple structure that encodes a range of levels in a sparse tensor that +/// forms a COO segment. +struct COOSegment { + std::pair lvlRange; // [low, high) + bool isSoA; + + bool isSegmentStart(Level l) const { return l == lvlRange.first; } + bool inSegment(Level l) const { + return l >= lvlRange.first && l < lvlRange.second; + } +}; + //===----------------------------------------------------------------------===// /// A wrapper around `RankedTensorType`, which has three goals: /// @@ -330,6 +342,9 @@ public: /// Returns [un]ordered COO type for this sparse tensor type. RankedTensorType getCOOType(bool ordered) const; + /// Returns a list of COO segments in this sparse tensor type. + SmallVector getCOOSegments() const; + private: // These two must be const, to ensure coherence of the memoized fields. 
const RankedTensorType rtp; diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index db359b4b7a5d..53e78d2c28b1 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -74,11 +74,12 @@ void StorageLayout::foreachField( callback) const { const auto lvlTypes = enc.getLvlTypes(); const Level lvlRank = enc.getLvlRank(); - const Level cooStart = SparseTensorType(enc).getCOOStart(); - const Level end = cooStart == lvlRank ? cooStart : cooStart + 1; + SmallVector cooSegs = SparseTensorType(enc).getCOOSegments(); FieldIndex fieldIdx = kDataFieldStartingIdx; + + ArrayRef cooSegsRef = cooSegs; // Per-level storage. - for (Level l = 0; l < end; l++) { + for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) { const auto lt = lvlTypes[l]; if (isWithPosLT(lt)) { if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt))) @@ -88,6 +89,21 @@ void StorageLayout::foreachField( if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt))) return; } + if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) { + if (!cooSegsRef.front().isSoA) { + // AoS COO, all singletons are fused into one memref. Skips the entire + // COO segment. + l = cooSegsRef.front().lvlRange.second; + } else { + // SoA COO, each singleton level has one memref. + l++; + } + // Expire handled COO segment. + cooSegsRef = cooSegsRef.drop_front(); + } else { + // Non COO levels. + l++; + } } // The values array. 
if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel, @@ -796,13 +812,46 @@ bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl, } Level mlir::sparse_tensor::SparseTensorType::getCOOStart() const { - if (hasEncoding() && lvlRank > 1) - for (Level l = 0; l < lvlRank - 1; l++) - if (isCOOType(l, /*isUnique=*/false)) - return l; + SmallVector coo = getCOOSegments(); + if (!coo.empty()) { + assert(coo.size() == 1); + return coo.front().lvlRange.first; + } return lvlRank; } +SmallVector +mlir::sparse_tensor::SparseTensorType::getCOOSegments() const { + SmallVector ret; + if (!hasEncoding() || lvlRank <= 1) + return ret; + + ArrayRef lts = getLvlTypes(); + Level l = 0; + while (l < lvlRank) { + auto lt = lts[l]; + if (lt.isa()) { + auto cur = lts.begin() + l; + auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) { + return !lt.isa(); + }); + unsigned cooLen = std::distance(cur, end); + if (cooLen > 1) { + // To support mixed SoA/AoS COO, we should break the segment when the + // storage scheme changes, for now we faithfully assume that all + // consecutive singleton levels have the same storage format as verified + // by STEA. 
+ ret.push_back(COOSegment{std::make_pair(l, l + cooLen), + lts[l + 1].isa()}); + } + l += cooLen; + } else { + l++; + } + } + return ret; +} + RankedTensorType mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const { SmallVector lvlTypes; diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir index a3b26972d66f..c1a976c84fec 100644 --- a/mlir/test/Dialect/SparseTensor/codegen.mlir +++ b/mlir/test/Dialect/SparseTensor/codegen.mlir @@ -48,6 +48,10 @@ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton) }> +#SoACOO = #sparse_tensor.encoding<{ + map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)) +}> + #CooPNo = #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : compressed(nonunique), d0 : singleton(nonordered)) }> @@ -67,6 +71,28 @@ func.func @sparse_nop(%arg0: tensor) -> tensor } +// CHECK-LABEL: func @sparse_nop_aos_coo( +// CHECK-SAME: %[[POS:.*0]]: memref, +// CHECK-SAME: %[[AoS_CRD:.*1]]: memref, +// CHECK-SAME: %[[VAL:.*]]: memref, +// CHECK-SAME: %[[A3:.*]]: !sparse_tensor.storage_specifier +// CHECK: return %[[POS]], %[[AoS_CRD]], %[[VAL]], %[[A3]] +func.func @sparse_nop_aos_coo(%arg0: tensor) -> tensor { + return %arg0 : tensor +} + +// CHECK-LABEL: func @sparse_nop_soa_coo( +// CHECK-SAME: %[[POS:.*0]]: memref, +// CHECK-SAME: %[[SoA_CRD_0:.*1]]: memref, +// CHECK-SAME: %[[SoA_CRD_1:.*2]]: memref, +// CHECK-SAME: %[[VAL:.*]]: memref, +// CHECK-SAME: %[[A3:.*]]: !sparse_tensor.storage_specifier +// CHECK: return %[[POS]], %[[SoA_CRD_0]], %[[SoA_CRD_1]], %[[VAL]], %[[A3]] +func.func @sparse_nop_soa_coo(%arg0: tensor) -> tensor { + return %arg0 : tensor +} + + // CHECK-LABEL: func @sparse_nop_multi_ret( // CHECK-SAME: %[[A0:.*0]]: memref, // CHECK-SAME: %[[A1:.*1]]: memref,