[mlir][sparse] support 'batch' dimensions in sparse_tensor.print (#91411)

This commit is contained in:
Aart Bik
2024-05-07 19:01:36 -07:00
committed by GitHub
parent 584253c4e2
commit c4e5a8a4d3
4 changed files with 130 additions and 34 deletions

View File

@@ -417,11 +417,17 @@ static void genEndInsert(OpBuilder &builder, Location loc,
/// Generates a subview into the sizes.
static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
Value sz) {
auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType();
auto memTp = llvm::cast<MemRefType>(mem.getType());
// For higher-dimensional memrefs, we assume that the innermost
// dimension is always of the right size.
// TODO: generate complex truncating view here too?
if (memTp.getRank() > 1)
return mem;
// Truncate linear memrefs to given size.
return builder
.create<memref::SubViewOp>(
loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem,
ValueRange{}, ValueRange{sz}, ValueRange{},
loc, MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()),
mem, ValueRange{}, ValueRange{sz}, ValueRange{},
ArrayRef<int64_t>{0}, // static offset
ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
ArrayRef<int64_t>{1}) // static stride