[mlir][sparse] support sparsifying batch levels (#83898)

commit 52b69aa32f
parent 8cc8fdaf5c
Author: Peiming Liu
Date:   2024-03-04 14:39:06 -08:00 (committed by GitHub)

14 changed files with 224 additions and 83 deletions

@@ -429,11 +429,18 @@ static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
 }
 
 /// Creates the reassociation array.
-static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
-  ReassociationIndices reassociation;
-  for (int i = 0, e = srcTp.getRank(); i < e; i++)
-    reassociation.push_back(i);
-  return reassociation;
+static SmallVector<ReassociationIndices>
+getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls) {
+  SmallVector<ReassociationIndices> ret(batchLvls + 1, {});
+  // Create reassociation in the form:
+  // {0}, {1}, ..., {batchLvls - 1}, {batchLvls, ..., rank - 1}
+  for (unsigned i = 0; i < batchLvls; i++)
+    ret[i].push_back(i);
+
+  for (int i = batchLvls, e = srcTp.getRank(); i < e; i++)
+    ret.back().push_back(i);
+
+  return ret;
 }
 
 //===----------------------------------------------------------------------===//
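
For a rank-4 source with two batch levels, the new helper yields the groups {0}, {1}, {2, 3}, where the old one collapsed everything into a single group. A minimal standalone sketch of the same grouping logic, using std::vector in place of MLIR's SmallVector<ReassociationIndices>:

#include <cstdio>
#include <vector>

// Mirrors getReassociationForFlattening: the first `batchLvls`
// dimensions each stay in their own group; all remaining dimensions
// are collapsed into one trailing group.
static std::vector<std::vector<int>> reassociation(int rank, int batchLvls) {
  std::vector<std::vector<int>> groups(batchLvls + 1);
  for (int i = 0; i < batchLvls; i++)
    groups[i].push_back(i);
  for (int i = batchLvls; i < rank; i++)
    groups.back().push_back(i);
  return groups;
}

int main() {
  // Rank 4 with two batch levels prints: { 0 } { 1 } { 2 3 }
  for (const auto &g : reassociation(4, 2)) {
    printf("{");
    for (int d : g)
      printf(" %d", d);
    printf(" } ");
  }
  printf("\n");
}
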
@@ -1287,9 +1294,10 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
               : op.getLevels()[fIdx];
       // TODO: handle batch.
       TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
-      if (mem.getType().getRank() > 1) {
-        // Flattens the buffer to rank 1.
-        auto reassoc = getReassociationForFlattening(mem.getType());
+      if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
+        // Flattens the buffer to rank (batchLvlRank + 1).
+        auto reassoc = getReassociationForFlattening(
+            mem.getType(), stt.getBatchLvlRank());
         mem = rewriter.create<memref::CastOp>(
             loc, fType,
             rewriter.create<memref::CollapseShapeOp>(loc, mem, reassoc));
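
The guard and the collapse can be checked shape-wise: with one batch level, a hypothetical memref<2x3x8xf64> input (rank 3 > batchLvlRank + 1 = 2) collapses to memref<2x24xf64>, leaving the batch dimension intact. A toy sketch of that shape computation; the shape values are invented for illustration:

#include <cstdio>

int main() {
  // Hypothetical input buffer: one batch level (2) in front of two
  // storage dimensions (3 x 8).
  const int shape[] = {2, 3, 8};
  const int rank = 3, batchLvlRank = 1;
  // Same guard as the pattern: collapse only if more than one
  // dimension trails the batch levels.
  if (rank > batchLvlRank + 1) {
    int collapsed = 1;
    for (int i = batchLvlRank; i < rank; i++)
      collapsed *= shape[i];
    // With reassociation {0}, {1, 2}: 2x3x8 -> 2x24.
    printf("collapsed to memref<%dx%dxf64>\n", shape[0], collapsed);
  }
}
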
@@ -1325,11 +1333,17 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
       // Sets up the memory size by reading the last value in position array.
       LevelType lt = stt.getLvlType(lvl);
       // Simply forwards the position index when this is a dense level.
-      if (isDenseLT(lt)) {
+      if (lt.isa<LevelFormat::Dense>()) {
         memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
         posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
         continue;
       }
+      if (lt.isa<LevelFormat::Batch>()) {
+        // Skips batch levels as they are not linearized.
+        // FIXME: this assumes that every batch has the same number of nse;
+        // needs to be generalized to handle varied-size batches.
+        continue;
+      }
 
       if (isWithPosLT(lt)) {
         assert(isCompressedLT(lt) || isLooseCompressedLT(lt));
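
Dense levels keep folding their sizes into the linearized memSize, while batch levels are deliberately left out of that product: each batch carries its own copy of the per-level buffers. A small sketch of the size bookkeeping for a hypothetical [batch, dense, compressed] (batched CSR) tensor, assuming the positions buffer of a batched level carries one leading dimension per batch level and using the equal-nse assumption from the FIXME:

#include <cstdio>

int main() {
  // Hypothetical batched CSR tensor: 3 batches of a 4 x 5 matrix,
  // level types [batch, dense, compressed].
  const int batches = 3, rows = 4, nsePerBatch = 7;
  int memSize = 1; // Running linearized size, as in the conversion.
  // Batch level: skipped; it does not enter the linearized size.
  // Dense level: memSize = lvlSize * memSize.
  memSize = rows * memSize; // 4
  // Compressed level: the positions buffer would then have shape
  // batches x (memSize + 1), i.e. 3 x 5 here, with the last entry of
  // every row equal to nsePerBatch (the equal-nse assumption).
  printf("positions: %d x %d, last entry per batch: %d\n", batches,
         memSize + 1, nsePerBatch);
}
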
@@ -1343,7 +1357,12 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
         }
         desc.setPosMemSize(rewriter, loc, lvl, memSize);
         // The last value in position array is the memory size for next level.
-        memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), posBack);
+        // FIXME: this assumes that every batch has the same number of nse;
+        // needs to be generalized to handle varied-size batches.
+        SmallVector<Value> batched(stt.getBatchLvlRank(),
+                                   constantIndex(rewriter, loc, 0));
+        batched.push_back(posBack);
+        memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), batched);
         posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
       }
       assert(isWithCrdLT(lt) && lvl <= trailCOOStart);
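
Since the positions buffer of a batched level is multi-dimensional, the load of its last entry now prepends one zero index per batch level, i.e. it reads position [0, ..., 0, posBack] and trusts that every batch ends at the same value. A toy sketch of that read and the assumption behind it; the data is made up:

#include <cassert>
#include <cstdio>

int main() {
  // Made-up positions buffer for a batched compressed level:
  // 2 batches, CSR over 4 rows, so 5 position entries per batch.
  // Both batches store 6 entries, matching the equal-nse assumption.
  const int positions[2][5] = {{0, 2, 3, 5, 6}, {0, 1, 4, 4, 6}};
  const int posBack = 4; // Index of the last position entry.
  // The conversion reads batch 0 only: [0, posBack].
  int memSize = positions[0][posBack];
  // This is what the FIXME is about: the value is only correct for
  // every batch if all batches store the same number of entries.
  assert(memSize == positions[1][posBack]);
  printf("next memSize = %d\n", memSize);
}
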
@@ -1413,8 +1432,9 @@ struct SparseDisassembleOpConverter
       retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
     }
     Value flatOut = dst;
-    if (dst.getType().getRank() != 1) {
-      auto reassoc = getReassociationForFlattening(dst.getType());
+    if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
+      auto reassoc =
+          getReassociationForFlattening(dst.getType(), stt.getBatchLvlRank());
       flatOut = rewriter.create<memref::CollapseShapeOp>(loc, dst, reassoc);
     }
     Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
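
The disassemble side mirrors the assemble side: the destination buffer is only collapsed when more than one dimension trails the batch levels, and the slice to the actual size then operates on the flattened trailing dimension. A shape-level sketch, assuming genSliceToSize narrows the trailing dimension and with values invented for illustration:

#include <cstdio>

int main() {
  // Hypothetical destination buffer memref<2x4x8xf64> with one batch
  // level; assume 6 stored entries per batch.
  const int shape[] = {2, 4, 8};
  const int rank = 3, batchLvlRank = 1, sz = 6;
  if (rank > batchLvlRank + 1) {
    int flat = 1;
    for (int i = batchLvlRank; i < rank; i++)
      flat *= shape[i]; // 2x4x8 -> 2x32
    // The slice then narrows the flattened dimension to sz.
    printf("flattened: %dx%d, sliced: %dx%d\n", shape[0], flat, shape[0], sz);
  }
}
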