[mlir][sparse] lower number of entries op to actual code

works both along runtime path and pure codegen path

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D136389
This commit is contained in:
Aart Bik
2022-10-20 16:01:37 -07:00
parent 4c7218e770
commit 0f3e4d1afa
5 changed files with 100 additions and 12 deletions

View File

@@ -277,6 +277,12 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count,
return forOp;
}
/// Translates field index to memSizes index.
static unsigned getMemSizesIndex(unsigned field) {
assert(2 <= field);
return field - 2;
}
/// Creates a pushback op for given field and updates the fields array
/// accordingly.
static void createPushback(OpBuilder &builder, Location loc,
@@ -286,9 +292,9 @@ static void createPushback(OpBuilder &builder, Location loc,
Type etp = fields[field].getType().cast<ShapedType>().getElementType();
if (value.getType() != etp)
value = builder.create<arith::IndexCastOp>(loc, etp, value);
fields[field] =
builder.create<PushBackOp>(loc, fields[field].getType(), fields[1],
fields[field], value, APInt(64, field - 2));
fields[field] = builder.create<PushBackOp>(
loc, fields[field].getType(), fields[1], fields[field], value,
APInt(64, getMemSizesIndex(field)));
}
/// Generates insertion code.
@@ -739,6 +745,25 @@ public:
}
};
/// Sparse codegen rule for number of entries operator.
class SparseNumberOfEntriesConverter
: public OpConversionPattern<NumberOfEntriesOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Query memSizes for the actually stored values size.
auto tuple = getTuple(adaptor.getTensor());
auto fields = tuple.getInputs();
unsigned lastField = fields.size() - 1;
Value field =
constantIndex(rewriter, op.getLoc(), getMemSizesIndex(lastField));
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, fields[1], field);
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
@@ -775,5 +800,6 @@ void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
SparseExpandConverter, SparseCompressConverter,
SparseInsertConverter, SparseToPointersConverter,
SparseToIndicesConverter, SparseToValuesConverter,
SparseConvertConverter>(typeConverter, patterns.getContext());
SparseConvertConverter, SparseNumberOfEntriesConverter>(
typeConverter, patterns.getContext());
}