[mlir][sparse] fix crash on sparse_tensor.foreach operation on tensors with complex<T> elements.

Reviewed By: aartbik, bixia

Differential Revision: https://reviews.llvm.org/D138223
This commit is contained in:
Peiming Liu
2022-11-17 17:49:23 +00:00
parent 48dbf35302
commit 8d615a23ef
4 changed files with 98 additions and 38 deletions

View File

@@ -170,6 +170,35 @@ static void getDynamicSizes(RankedTensorType tp,
}
}
static LogicalResult genForeachOnSparseConstant(ForeachOp op,
RewriterBase &rewriter,
SparseElementsAttr attr) {
auto loc = op.getLoc();
SmallVector<Value> reduc = op.getInitArgs();
// Foreach on constant.
foreachInSparseConstant(
loc, rewriter, attr,
[&reduc, &rewriter, op](ArrayRef<Value> coords, Value v) mutable {
SmallVector<Value> args;
args.append(coords.begin(), coords.end());
args.push_back(v);
args.append(reduc);
// Clones the foreach op to get a copy of the loop body.
auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
assert(args.size() == cloned.getBody()->getNumArguments());
Operation *yield = cloned.getBody()->getTerminator();
rewriter.mergeBlockBefore(cloned.getBody(), op, args);
// clean up
rewriter.eraseOp(cloned);
reduc = yield->getOperands();
rewriter.eraseOp(yield);
});
rewriter.replaceOp(op, reduc);
return success();
}
//===---------------------------------------------------------------------===//
// The actual sparse tensor rewriting rules.
//===---------------------------------------------------------------------===//
@@ -752,36 +781,7 @@ public:
// rule.
if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
// Foreach on constant.
DenseElementsAttr indicesAttr = attr.getIndices();
DenseElementsAttr valuesAttr = attr.getValues();
SmallVector<Value> args;
for (int i = 0, e = valuesAttr.size(); i < e; i++) {
auto valAttr = valuesAttr.getValues<TypedAttr>()[i];
for (int j = 0; j < rank; j++) {
auto coordAttr = indicesAttr.getValues<IntegerAttr>()[i * rank + j];
auto coord = rewriter.create<arith::ConstantIndexOp>(
loc, coordAttr.getInt());
// Remaps coordinates.
args.push_back(coord);
}
// Remaps value.
auto val = rewriter.create<arith::ConstantOp>(loc, valAttr);
args.push_back(val);
// Remaps iteration args.
args.append(reduc);
auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
Operation *yield = cloned.getBody()->getTerminator();
rewriter.mergeBlockBefore(cloned.getBody(), op, args);
// clean up
args.clear();
rewriter.eraseOp(cloned);
reduc = yield->getOperands();
rewriter.eraseOp(yield);
}
rewriter.replaceOp(op, reduc);
return success();
return genForeachOnSparseConstant(op, rewriter, attr);
}
}