[mlir][sparse] support sparsifying batch levels (#83898)
This commit is contained in:
@@ -429,11 +429,18 @@ static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
|
||||
}
|
||||
|
||||
/// Creates the reassociation array.
|
||||
static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
|
||||
ReassociationIndices reassociation;
|
||||
for (int i = 0, e = srcTp.getRank(); i < e; i++)
|
||||
reassociation.push_back(i);
|
||||
return reassociation;
|
||||
static SmallVector<ReassociationIndices>
|
||||
getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls) {
|
||||
SmallVector<ReassociationIndices> ret(batchLvls + 1, {});
|
||||
// Create reassociation in the form:
|
||||
// {0}, {1}, ..., {batchLvl - 1}, {batchLvl, ..., rank}
|
||||
for (unsigned i = 0; i < batchLvls; i++)
|
||||
ret[i].push_back(i);
|
||||
|
||||
for (int i = batchLvls, e = srcTp.getRank(); i < e; i++)
|
||||
ret.back().push_back(i);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -1287,9 +1294,10 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
|
||||
: op.getLevels()[fIdx];
|
||||
// TODO: handle batch.
|
||||
TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
|
||||
if (mem.getType().getRank() > 1) {
|
||||
// Flattens the buffer to rank 1.
|
||||
auto reassoc = getReassociationForFlattening(mem.getType());
|
||||
if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
|
||||
// Flattens the buffer to batchLvlRank.
|
||||
auto reassoc = getReassociationForFlattening(
|
||||
mem.getType(), stt.getBatchLvlRank());
|
||||
mem = rewriter.create<memref::CastOp>(
|
||||
loc, fType,
|
||||
rewriter.create<memref::CollapseShapeOp>(loc, mem, reassoc));
|
||||
@@ -1325,11 +1333,17 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
|
||||
// Sets up the memory size by reading the last value in position array.
|
||||
LevelType lt = stt.getLvlType(lvl);
|
||||
// Simply forwards the position index when this is a dense level.
|
||||
if (isDenseLT(lt)) {
|
||||
if (lt.isa<LevelFormat::Dense>()) {
|
||||
memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
|
||||
posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
|
||||
continue;
|
||||
}
|
||||
if (lt.isa<LevelFormat::Batch>()) {
|
||||
// Skips batch levels as it is not linearized.
|
||||
// FIXME: this assumes that every batch has the same number of nse, need
|
||||
// to be generalized to handle varied-size batches.
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isWithPosLT(lt)) {
|
||||
assert(isCompressedLT(lt) || isLooseCompressedLT(lt));
|
||||
@@ -1343,7 +1357,12 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
|
||||
}
|
||||
desc.setPosMemSize(rewriter, loc, lvl, memSize);
|
||||
// The last value in position array is the memory size for next level.
|
||||
memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), posBack);
|
||||
// FIXME: this assumes that every batch has the same number of nse, need
|
||||
// to be generalized to handle varied-size batches.
|
||||
SmallVector<Value> batched(stt.getBatchLvlRank(),
|
||||
constantIndex(rewriter, loc, 0));
|
||||
batched.push_back(posBack);
|
||||
memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), batched);
|
||||
posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
|
||||
}
|
||||
assert(isWithCrdLT(lt) && lvl <= trailCOOStart);
|
||||
@@ -1413,8 +1432,9 @@ struct SparseDisassembleOpConverter
|
||||
retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
|
||||
}
|
||||
Value flatOut = dst;
|
||||
if (dst.getType().getRank() != 1) {
|
||||
auto reassoc = getReassociationForFlattening(dst.getType());
|
||||
if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
|
||||
auto reassoc =
|
||||
getReassociationForFlattening(dst.getType(), stt.getBatchLvlRank());
|
||||
flatOut = rewriter.create<memref::CollapseShapeOp>(loc, dst, reassoc);
|
||||
}
|
||||
Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
|
||||
|
||||
Reference in New Issue
Block a user