[mlir][sparse] lower number of entries op to actual code
works both along runtime path and pure codegen path Reviewed By: Peiming Differential Revision: https://reviews.llvm.org/D136389
This commit is contained in:
@@ -277,6 +277,12 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count,
|
||||
return forOp;
|
||||
}
|
||||
|
||||
/// Translates field index to memSizes index.
|
||||
static unsigned getMemSizesIndex(unsigned field) {
|
||||
assert(2 <= field);
|
||||
return field - 2;
|
||||
}
|
||||
|
||||
/// Creates a pushback op for given field and updates the fields array
|
||||
/// accordingly.
|
||||
static void createPushback(OpBuilder &builder, Location loc,
|
||||
@@ -286,9 +292,9 @@ static void createPushback(OpBuilder &builder, Location loc,
|
||||
Type etp = fields[field].getType().cast<ShapedType>().getElementType();
|
||||
if (value.getType() != etp)
|
||||
value = builder.create<arith::IndexCastOp>(loc, etp, value);
|
||||
fields[field] =
|
||||
builder.create<PushBackOp>(loc, fields[field].getType(), fields[1],
|
||||
fields[field], value, APInt(64, field - 2));
|
||||
fields[field] = builder.create<PushBackOp>(
|
||||
loc, fields[field].getType(), fields[1], fields[field], value,
|
||||
APInt(64, getMemSizesIndex(field)));
|
||||
}
|
||||
|
||||
/// Generates insertion code.
|
||||
@@ -739,6 +745,25 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse codegen rule for number of entries operator.
|
||||
class SparseNumberOfEntriesConverter
|
||||
: public OpConversionPattern<NumberOfEntriesOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Query memSizes for the actually stored values size.
|
||||
auto tuple = getTuple(adaptor.getTensor());
|
||||
auto fields = tuple.getInputs();
|
||||
unsigned lastField = fields.size() - 1;
|
||||
Value field =
|
||||
constantIndex(rewriter, op.getLoc(), getMemSizesIndex(lastField));
|
||||
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, fields[1], field);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -775,5 +800,6 @@ void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
|
||||
SparseExpandConverter, SparseCompressConverter,
|
||||
SparseInsertConverter, SparseToPointersConverter,
|
||||
SparseToIndicesConverter, SparseToValuesConverter,
|
||||
SparseConvertConverter>(typeConverter, patterns.getContext());
|
||||
SparseConvertConverter, SparseNumberOfEntriesConverter>(
|
||||
typeConverter, patterns.getContext());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user