[mlir][sparse] support 'batch' dimensions in sparse_tensor.print (#91411)
This commit is contained in:
@@ -785,45 +785,61 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
// Helper to print contents of a single memref. Note that for the "push_back"
|
||||
// vectors, this prints the full capacity, not just the size. This is done
|
||||
// on purpose, so that clients see how much storage has been allocated in
|
||||
// total. Contents of the extra capacity in the buffer may be uninitialized
|
||||
// (unless the flag enable-buffer-initialization is set to true).
|
||||
// Helper to print contents of a single memref. For "push_back" vectors,
|
||||
// we assume that the previous getters for pos/crd/val have added a
|
||||
// slice-to-size view to make sure we just print the size and not the
|
||||
// full capacity.
|
||||
//
|
||||
// Generates code to print:
|
||||
// Generates code to print (1-dim or higher):
|
||||
// ( a0, a1, ... )
|
||||
static void printContents(PatternRewriter &rewriter, Location loc,
|
||||
Value vec) {
|
||||
auto shape = cast<ShapedType>(vec.getType()).getShape();
|
||||
SmallVector<Value> idxs;
|
||||
printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
|
||||
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
|
||||
}
|
||||
|
||||
// Helper to the helper.
|
||||
static void printContentsLevel(PatternRewriter &rewriter, Location loc,
|
||||
Value vec, unsigned i, ArrayRef<int64_t> shape,
|
||||
SmallVectorImpl<Value> &idxs) {
|
||||
// Open bracket.
|
||||
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
|
||||
// For loop over elements.
|
||||
// Generate for loop.
|
||||
auto zero = constantIndex(rewriter, loc, 0);
|
||||
auto size = rewriter.create<memref::DimOp>(loc, vec, zero);
|
||||
auto index = constantIndex(rewriter, loc, i);
|
||||
auto size = rewriter.create<memref::DimOp>(loc, vec, index);
|
||||
auto step = constantIndex(rewriter, loc, 1);
|
||||
auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
|
||||
idxs.push_back(forOp.getInductionVar());
|
||||
rewriter.setInsertionPointToStart(forOp.getBody());
|
||||
auto idx = forOp.getInductionVar();
|
||||
auto val = rewriter.create<memref::LoadOp>(loc, vec, idx);
|
||||
if (llvm::isa<ComplexType>(val.getType())) {
|
||||
// Since the vector dialect does not support complex types in any op,
|
||||
// we split those into (real, imag) pairs here.
|
||||
Value real = rewriter.create<complex::ReOp>(loc, val);
|
||||
Value imag = rewriter.create<complex::ImOp>(loc, val);
|
||||
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
|
||||
rewriter.create<vector::PrintOp>(loc, real,
|
||||
vector::PrintPunctuation::Comma);
|
||||
rewriter.create<vector::PrintOp>(loc, imag,
|
||||
vector::PrintPunctuation::Close);
|
||||
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
|
||||
if (i < shape.size() - 1) {
|
||||
// Enter deeper loop nest.
|
||||
printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
|
||||
} else {
|
||||
rewriter.create<vector::PrintOp>(loc, val,
|
||||
vector::PrintPunctuation::Comma);
|
||||
// Actual contents printing.
|
||||
auto val = rewriter.create<memref::LoadOp>(loc, vec, idxs);
|
||||
if (llvm::isa<ComplexType>(val.getType())) {
|
||||
// Since the vector dialect does not support complex types in any op,
|
||||
// we split those into (real, imag) pairs here.
|
||||
Value real = rewriter.create<complex::ReOp>(loc, val);
|
||||
Value imag = rewriter.create<complex::ImOp>(loc, val);
|
||||
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
|
||||
rewriter.create<vector::PrintOp>(loc, real,
|
||||
vector::PrintPunctuation::Comma);
|
||||
rewriter.create<vector::PrintOp>(loc, imag,
|
||||
vector::PrintPunctuation::Close);
|
||||
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
|
||||
} else {
|
||||
rewriter.create<vector::PrintOp>(loc, val,
|
||||
vector::PrintPunctuation::Comma);
|
||||
}
|
||||
}
|
||||
idxs.pop_back();
|
||||
rewriter.setInsertionPointAfter(forOp);
|
||||
// Close bracket and end of line.
|
||||
// Close bracket.
|
||||
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
|
||||
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
|
||||
}
|
||||
|
||||
// Helper method to print run-time lvl/dim sizes.
|
||||
|
||||
Reference in New Issue
Block a user