[mlir][sparse] add source materizalization callback for sparse tensor codegen type converter.

Required by scf.for to achieve 1:N type conversion (See D136314).

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D136367
This commit is contained in:
Peiming Liu
2022-10-20 16:53:56 +00:00
parent 4c4909703d
commit d12d4857c5

View File

@@ -745,6 +745,17 @@ public:
mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
addConversion([](Type type) { return type; });
addConversion(convertSparseTensorType);
// Required by scf.for 1:N type conversion.
addSourceMaterialization([](OpBuilder &builder, RankedTensorType tp,
ValueRange inputs,
Location loc) -> Optional<Value> {
if (!getSparseTensorEncoding(tp))
// Not a sparse tensor.
return llvm::None;
// Sparse compiler knows how to cancel out these casts.
return genTuple(builder, loc, tp, inputs);
});
}
//===----------------------------------------------------------------------===//