[mlir][sparse] support 'batch' dimensions in sparse_tensor.print (#91411)

This commit is contained in:
Aart Bik
2024-05-07 19:01:36 -07:00
committed by GitHub
parent 584253c4e2
commit c4e5a8a4d3
4 changed files with 130 additions and 34 deletions

View File

@@ -785,45 +785,61 @@ public:
}
private:
// Helper to print contents of a single memref. Note that for the "push_back"
// vectors, this prints the full capacity, not just the size. This is done
// on purpose, so that clients see how much storage has been allocated in
// total. Contents of the extra capacity in the buffer may be uninitialized
// (unless the flag enable-buffer-initialization is set to true).
// Helper to print contents of a single memref. For "push_back" vectors,
// we assume that the previous getters for pos/crd/val have added a
// slice-to-size view to make sure we just print the size and not the
// full capacity.
//
// Generates code to print:
// Generates code to print (1-dim or higher):
// ( a0, a1, ... )
static void printContents(PatternRewriter &rewriter, Location loc,
Value vec) {
auto shape = cast<ShapedType>(vec.getType()).getShape();
SmallVector<Value> idxs;
printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
}
// Helper to the helper.
static void printContentsLevel(PatternRewriter &rewriter, Location loc,
Value vec, unsigned i, ArrayRef<int64_t> shape,
SmallVectorImpl<Value> &idxs) {
// Open bracket.
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
// For loop over elements.
// Generate for loop.
auto zero = constantIndex(rewriter, loc, 0);
auto size = rewriter.create<memref::DimOp>(loc, vec, zero);
auto index = constantIndex(rewriter, loc, i);
auto size = rewriter.create<memref::DimOp>(loc, vec, index);
auto step = constantIndex(rewriter, loc, 1);
auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
idxs.push_back(forOp.getInductionVar());
rewriter.setInsertionPointToStart(forOp.getBody());
auto idx = forOp.getInductionVar();
auto val = rewriter.create<memref::LoadOp>(loc, vec, idx);
if (llvm::isa<ComplexType>(val.getType())) {
// Since the vector dialect does not support complex types in any op,
// we split those into (real, imag) pairs here.
Value real = rewriter.create<complex::ReOp>(loc, val);
Value imag = rewriter.create<complex::ImOp>(loc, val);
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
rewriter.create<vector::PrintOp>(loc, real,
vector::PrintPunctuation::Comma);
rewriter.create<vector::PrintOp>(loc, imag,
vector::PrintPunctuation::Close);
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
if (i < shape.size() - 1) {
// Enter deeper loop nest.
printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
} else {
rewriter.create<vector::PrintOp>(loc, val,
vector::PrintPunctuation::Comma);
// Actual contents printing.
auto val = rewriter.create<memref::LoadOp>(loc, vec, idxs);
if (llvm::isa<ComplexType>(val.getType())) {
// Since the vector dialect does not support complex types in any op,
// we split those into (real, imag) pairs here.
Value real = rewriter.create<complex::ReOp>(loc, val);
Value imag = rewriter.create<complex::ImOp>(loc, val);
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
rewriter.create<vector::PrintOp>(loc, real,
vector::PrintPunctuation::Comma);
rewriter.create<vector::PrintOp>(loc, imag,
vector::PrintPunctuation::Close);
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
} else {
rewriter.create<vector::PrintOp>(loc, val,
vector::PrintPunctuation::Comma);
}
}
idxs.pop_back();
rewriter.setInsertionPointAfter(forOp);
// Close bracket and end of line.
// Close bracket.
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
}
// Helper method to print run-time lvl/dim sizes.