[mlir][sparse] reuse tensor.insert operation to insert elements into … (#84987)

…a sparse tensor.
This commit is contained in:
Peiming Liu
2024-03-12 16:59:17 -07:00
committed by GitHub
parent 1c3b15e9f5
commit 94e27c265a
26 changed files with 106 additions and 182 deletions

View File

@@ -1014,24 +1014,29 @@ public:
};
/// Sparse codegen rule for the insert operator.
class SparseInsertConverter : public OpConversionPattern<InsertOp> {
class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(InsertOp op, OpAdaptor adaptor,
matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto stt = getSparseTensorType(adaptor.getDest());
if (!stt.hasEncoding())
return failure();
assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
Location loc = op.getLoc();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
auto desc = getDescriptorFromTensorTuple(adaptor.getDest());
TypeRange flatSpTensorTps = desc.getFields().getTypes();
SmallVector<Value> params = llvm::to_vector(desc.getFields());
params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
params.push_back(adaptor.getValue());
SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
params.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
params.push_back(adaptor.getScalar());
SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
params, /*genCall=*/true);
SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
// Replace operation with resulting memrefs.
rewriter.replaceOp(op,
genTuple(rewriter, loc, op.getTensor().getType(), ret));
genTuple(rewriter, loc, op.getDest().getType(), ret));
return success();
}
};