[mlir][sparse] lower number of entries op to actual code

Works along both the runtime path and the pure codegen path.
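
To make the effect concrete, here is a rough sketch of the runtime-path lowering (the shape, the #CSR encoding, the !llvm.ptr<i8> handle %ptr, and the exact @sparseValuesF64 callee are illustrative assumptions for this sketch; the callee name is simply "sparseValues" plus the primary type suffix):

  // Before: ask the sparse tensor for its number of stored entries.
  %noe = sparse_tensor.number_of_entries %t : tensor<32x64xf64, #CSR>

  // After conversion (roughly): fetch the values array from the runtime
  // and take its dynamic size.
  %c0   = arith.constant 0 : index
  %vals = call @sparseValuesF64(%ptr) : (!llvm.ptr<i8>) -> memref<?xf64>
  %noe  = memref.dim %vals, %c0 : memref<?xf64>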

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D136389
Author: Aart Bik
Date:   2022-10-20 16:01:37 -07:00
Parent: 4c7218e770
Commit: 0f3e4d1afa
5 changed files with 100 additions and 12 deletions

@@ -205,6 +205,15 @@ static void newParams(OpBuilder &builder, SmallVector<Value, 8> &params,
   params.push_back(ptr);
 }
 
+/// Generates a call to obtain the values array.
+static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp,
+                           ValueRange ptr) {
+  SmallString<15> name{"sparseValues",
+                       primaryTypeFunctionSuffix(tp.getElementType())};
+  return createFuncCall(builder, loc, name, tp, ptr, EmitCInterface::On)
+      .getResult(0);
+}
+
 /// Generates a call to release/delete a `SparseTensorCOO`.
 static void genDelCOOCall(OpBuilder &builder, Location loc, Type elemTp,
                           Value coo) {
@@ -903,11 +912,28 @@ public:
   LogicalResult
   matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type resType = op.getType();
-    Type eltType = resType.cast<ShapedType>().getElementType();
-    SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltType)};
-    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
-                          EmitCInterface::On);
+    auto resType = op.getType().cast<ShapedType>();
+    rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType,
+                                         adaptor.getOperands()));
     return success();
   }
 };
+
+/// Sparse conversion rule for number of entries operator.
+class SparseNumberOfEntriesConverter
+    : public OpConversionPattern<NumberOfEntriesOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // Query values array size for the actually stored values size.
+    Type eltType = op.getTensor().getType().cast<ShapedType>().getElementType();
+    auto resTp = MemRefType::get({ShapedType::kDynamicSize}, eltType);
+    Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands());
+    rewriter.replaceOpWithNewOp<memref::DimOp>(op, values,
+                                               constantIndex(rewriter, loc, 0));
+    return success();
+  }
+};
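
Design note: the runtime path derives the entry count from data it already materializes; it reuses the sparseValues entry point via genValuesCall and takes the dynamic size of the returned memref with memref.dim, rather than introducing a dedicated runtime query. The result therefore reflects exactly the number of stored values, independent of the tensor's dimension sizes.
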
@@ -1250,9 +1276,10 @@ void mlir::populateSparseTensorConversionPatterns(
                SparseTensorConcatConverter, SparseTensorAllocConverter,
                SparseTensorDeallocConverter, SparseTensorToPointersConverter,
                SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
-               SparseTensorLoadConverter, SparseTensorInsertConverter,
-               SparseTensorExpandConverter, SparseTensorCompressConverter,
-               SparseTensorOutConverter>(typeConverter, patterns.getContext());
+               SparseNumberOfEntriesConverter, SparseTensorLoadConverter,
+               SparseTensorInsertConverter, SparseTensorExpandConverter,
+               SparseTensorCompressConverter, SparseTensorOutConverter>(
+      typeConverter, patterns.getContext());
   patterns.add<SparseTensorConvertConverter>(typeConverter,
                                              patterns.getContext(), options);
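
These patterns are consumed by whichever dialect conversion pass calls populateSparseTensorConversionPatterns (typically driven through applyPartialConversion), so registering SparseNumberOfEntriesConverter in this list is the only wiring the runtime path needs.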