[mlir][sparse] fix crash on sparse_tensor.foreach operation on tensors with complex<T> elements.
Reviewed By: aartbik, bixia Differential Revision: https://reviews.llvm.org/D138223
This commit is contained in:
@@ -170,6 +170,35 @@ static void getDynamicSizes(RankedTensorType tp,
|
||||
}
|
||||
}
|
||||
|
||||
static LogicalResult genForeachOnSparseConstant(ForeachOp op,
|
||||
RewriterBase &rewriter,
|
||||
SparseElementsAttr attr) {
|
||||
auto loc = op.getLoc();
|
||||
SmallVector<Value> reduc = op.getInitArgs();
|
||||
|
||||
// Foreach on constant.
|
||||
foreachInSparseConstant(
|
||||
loc, rewriter, attr,
|
||||
[&reduc, &rewriter, op](ArrayRef<Value> coords, Value v) mutable {
|
||||
SmallVector<Value> args;
|
||||
args.append(coords.begin(), coords.end());
|
||||
args.push_back(v);
|
||||
args.append(reduc);
|
||||
// Clones the foreach op to get a copy of the loop body.
|
||||
auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
|
||||
assert(args.size() == cloned.getBody()->getNumArguments());
|
||||
Operation *yield = cloned.getBody()->getTerminator();
|
||||
rewriter.mergeBlockBefore(cloned.getBody(), op, args);
|
||||
// clean up
|
||||
rewriter.eraseOp(cloned);
|
||||
reduc = yield->getOperands();
|
||||
rewriter.eraseOp(yield);
|
||||
});
|
||||
|
||||
rewriter.replaceOp(op, reduc);
|
||||
return success();
|
||||
}
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// The actual sparse tensor rewriting rules.
|
||||
//===---------------------------------------------------------------------===//
|
||||
@@ -752,36 +781,7 @@ public:
|
||||
// rule.
|
||||
if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
|
||||
if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
|
||||
// Foreach on constant.
|
||||
DenseElementsAttr indicesAttr = attr.getIndices();
|
||||
DenseElementsAttr valuesAttr = attr.getValues();
|
||||
|
||||
SmallVector<Value> args;
|
||||
for (int i = 0, e = valuesAttr.size(); i < e; i++) {
|
||||
auto valAttr = valuesAttr.getValues<TypedAttr>()[i];
|
||||
for (int j = 0; j < rank; j++) {
|
||||
auto coordAttr = indicesAttr.getValues<IntegerAttr>()[i * rank + j];
|
||||
auto coord = rewriter.create<arith::ConstantIndexOp>(
|
||||
loc, coordAttr.getInt());
|
||||
// Remaps coordinates.
|
||||
args.push_back(coord);
|
||||
}
|
||||
// Remaps value.
|
||||
auto val = rewriter.create<arith::ConstantOp>(loc, valAttr);
|
||||
args.push_back(val);
|
||||
// Remaps iteration args.
|
||||
args.append(reduc);
|
||||
auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
|
||||
Operation *yield = cloned.getBody()->getTerminator();
|
||||
rewriter.mergeBlockBefore(cloned.getBody(), op, args);
|
||||
// clean up
|
||||
args.clear();
|
||||
rewriter.eraseOp(cloned);
|
||||
reduc = yield->getOperands();
|
||||
rewriter.eraseOp(yield);
|
||||
}
|
||||
rewriter.replaceOp(op, reduc);
|
||||
return success();
|
||||
return genForeachOnSparseConstant(op, rewriter, attr);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user