[mlir][sparse] reuse tensor.insert operation to insert elements into … (#84987)
…a sparse tensor.
This commit is contained in:
@@ -1014,24 +1014,29 @@ public:
|
||||
};
|
||||
|
||||
/// Sparse codegen rule for the insert operator.
|
||||
class SparseInsertConverter : public OpConversionPattern<InsertOp> {
|
||||
class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(InsertOp op, OpAdaptor adaptor,
|
||||
matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto stt = getSparseTensorType(adaptor.getDest());
|
||||
if (!stt.hasEncoding())
|
||||
return failure();
|
||||
assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
|
||||
|
||||
Location loc = op.getLoc();
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getDest());
|
||||
TypeRange flatSpTensorTps = desc.getFields().getTypes();
|
||||
SmallVector<Value> params = llvm::to_vector(desc.getFields());
|
||||
params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
|
||||
params.push_back(adaptor.getValue());
|
||||
SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
|
||||
params.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
|
||||
params.push_back(adaptor.getScalar());
|
||||
SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
|
||||
params, /*genCall=*/true);
|
||||
SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
|
||||
// Replace operation with resulting memrefs.
|
||||
rewriter.replaceOp(op,
|
||||
genTuple(rewriter, loc, op.getTensor().getType(), ret));
|
||||
genTuple(rewriter, loc, op.getDest().getType(), ret));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user