[mlir][sparse] add source materizalization callback for sparse tensor codegen type converter.
Required by scf.for to achieve 1:N type conversion (See D136314). Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D136367
This commit is contained in:
@@ -745,6 +745,17 @@ public:
|
||||
mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
|
||||
addConversion([](Type type) { return type; });
|
||||
addConversion(convertSparseTensorType);
|
||||
|
||||
// Required by scf.for 1:N type conversion.
|
||||
addSourceMaterialization([](OpBuilder &builder, RankedTensorType tp,
|
||||
ValueRange inputs,
|
||||
Location loc) -> Optional<Value> {
|
||||
if (!getSparseTensorEncoding(tp))
|
||||
// Not a sparse tensor.
|
||||
return llvm::None;
|
||||
// Sparse compiler knows how to cancel out these casts.
|
||||
return genTuple(builder, loc, tp, inputs);
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Reference in New Issue
Block a user