[mlir][sparse] lower number of entries op to actual code
works both along runtime path and pure codegen path Reviewed By: Peiming Differential Revision: https://reviews.llvm.org/D136389
This commit is contained in:
@@ -205,6 +205,15 @@ static void newParams(OpBuilder &builder, SmallVector<Value, 8> ¶ms,
|
||||
params.push_back(ptr);
|
||||
}
|
||||
|
||||
/// Generates a call to obtain the values array.
|
||||
static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp,
|
||||
ValueRange ptr) {
|
||||
SmallString<15> name{"sparseValues",
|
||||
primaryTypeFunctionSuffix(tp.getElementType())};
|
||||
return createFuncCall(builder, loc, name, tp, ptr, EmitCInterface::On)
|
||||
.getResult(0);
|
||||
}
|
||||
|
||||
/// Generates a call to release/delete a `SparseTensorCOO`.
|
||||
static void genDelCOOCall(OpBuilder &builder, Location loc, Type elemTp,
|
||||
Value coo) {
|
||||
@@ -903,11 +912,28 @@ public:
|
||||
LogicalResult
|
||||
matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type resType = op.getType();
|
||||
Type eltType = resType.cast<ShapedType>().getElementType();
|
||||
SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltType)};
|
||||
replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
|
||||
EmitCInterface::On);
|
||||
auto resType = op.getType().cast<ShapedType>();
|
||||
rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType,
|
||||
adaptor.getOperands()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse conversion rule for number of entries operator.
|
||||
class SparseNumberOfEntriesConverter
|
||||
: public OpConversionPattern<NumberOfEntriesOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
// Query values array size for the actually stored values size.
|
||||
Type eltType = op.getTensor().getType().cast<ShapedType>().getElementType();
|
||||
auto resTp = MemRefType::get({ShapedType::kDynamicSize}, eltType);
|
||||
Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands());
|
||||
rewriter.replaceOpWithNewOp<memref::DimOp>(op, values,
|
||||
constantIndex(rewriter, loc, 0));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -1250,9 +1276,10 @@ void mlir::populateSparseTensorConversionPatterns(
|
||||
SparseTensorConcatConverter, SparseTensorAllocConverter,
|
||||
SparseTensorDeallocConverter, SparseTensorToPointersConverter,
|
||||
SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
|
||||
SparseTensorLoadConverter, SparseTensorInsertConverter,
|
||||
SparseTensorExpandConverter, SparseTensorCompressConverter,
|
||||
SparseTensorOutConverter>(typeConverter, patterns.getContext());
|
||||
SparseNumberOfEntriesConverter, SparseTensorLoadConverter,
|
||||
SparseTensorInsertConverter, SparseTensorExpandConverter,
|
||||
SparseTensorCompressConverter, SparseTensorOutConverter>(
|
||||
typeConverter, patterns.getContext());
|
||||
|
||||
patterns.add<SparseTensorConvertConverter>(typeConverter,
|
||||
patterns.getContext(), options);
|
||||
|
||||
Reference in New Issue
Block a user