[mlir][sparse] unify sparse_tensor.out rewriting rules (#70518)
This commit is contained in:
@@ -270,13 +270,6 @@ static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp,
|
||||
.getResult(0);
|
||||
}
|
||||
|
||||
/// Generates a call to release/delete a `SparseTensorCOO`.
|
||||
static void genDelCOOCall(OpBuilder &builder, Location loc, Type elemTp,
|
||||
Value coo) {
|
||||
SmallString<21> name{"delSparseTensorCOO", primaryTypeFunctionSuffix(elemTp)};
|
||||
createFuncCall(builder, loc, name, {}, coo, EmitCInterface::Off);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Conversion rules.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -707,37 +700,6 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse conversion rule for the output operator.
|
||||
class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(OutOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
const Location loc = op->getLoc();
|
||||
const auto srcTp = getSparseTensorType(op.getTensor());
|
||||
// Convert to default permuted COO.
|
||||
Value src = adaptor.getOperands()[0];
|
||||
SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
|
||||
Value coo = NewCallParams(rewriter, loc)
|
||||
.genBuffers(srcTp.withoutDimToLvl(), dimSizes)
|
||||
.genNewCall(Action::kToCOO, src);
|
||||
// Then output the tensor to external file with coordinates in the
|
||||
// externally visible lexicographic coordinate order. A sort is
|
||||
// required if the source was not in that order yet (note that the
|
||||
// sort can be dropped altogether if external format does not care
|
||||
// about the order at all, but here we assume it does).
|
||||
const Value sort = constantI1(rewriter, loc, !srcTp.isIdentity());
|
||||
SmallVector<Value, 3> outParams{coo, adaptor.getOperands()[1], sort};
|
||||
const Type elemTp = srcTp.getElementType();
|
||||
SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(elemTp)};
|
||||
createFuncCall(rewriter, loc, name, {}, outParams, EmitCInterface::Off);
|
||||
genDelCOOCall(rewriter, loc, elemTp, coo);
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse conversion rule for the sparse_tensor.pack operator.
|
||||
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
|
||||
public:
|
||||
@@ -789,6 +751,5 @@ void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
|
||||
SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
|
||||
SparseTensorLoadConverter, SparseTensorInsertConverter,
|
||||
SparseTensorExpandConverter, SparseTensorCompressConverter,
|
||||
SparseTensorOutConverter, SparseTensorAssembleConverter>(
|
||||
typeConverter, patterns.getContext());
|
||||
SparseTensorAssembleConverter>(typeConverter, patterns.getContext());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user