[mlir][sparse] support sparse tensor element type conversion in codegen path
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D144578
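For illustration, a conversion of the kind this patch lets the codegen path handle directly: a sparse_tensor.convert where source and destination share the same encoding but differ only in element type. The snippet below is a hypothetical sketch, not a test from the patch; the #SV encoding name, shapes, and function name are made up, and sparse encoding syntax varies across MLIR revisions.

    #SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>

    // Same encoding on both sides, only the element type differs (f32 to f64);
    // previously the codegen path would reject this and defer to rewriting.
    func.func @convert_elt_type(%arg0: tensor<64xf32, #SV>) -> tensor<64xf64, #SV> {
      %0 = sparse_tensor.convert %arg0 : tensor<64xf32, #SV> to tensor<64xf64, #SV>
      return %0 : tensor<64xf64, #SV>
    }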
@@ -1030,11 +1030,73 @@ public:
     SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
     SparseTensorEncodingAttr encSrc =
         getSparseTensorEncoding(op.getSource().getType());
-    if (encDst != encSrc) {
-      // This should be handled by rewriting before codegen.
+    // Different encoding (except for different bitwidth) should be handled by
+    // rewriting.
+    if (encDst.withoutBitWidths() != encSrc.withoutBitWidths()) {
       return failure();
     }
-    rewriter.replaceOp(op, adaptor.getSource());
+
+    Type retElemTp = op.getResult().getType().getElementType();
+    Type srcElemTp = op.getSource().getType().getElementType();
+    // Fold the trivial cases.
+    if (retElemTp == srcElemTp && encDst == encSrc) {
+      rewriter.replaceOp(op, adaptor.getSource());
+      return success();
+    }
+    //
+    // Do element-wise type conversion without using InsertOp.
+    //
+    // for each memref in srcTensor:
+    //   dst = memref.alloc
+    //   if srcMemRefType != dstMemRefType:
+    //     for every i: dst[i] = cast(src[i])
+    //   else:
+    //     dst = memref.copy(src)
+    Location loc = op.getLoc();
+    auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource());
+    SmallVector<Value> fields;
+    foreachFieldAndTypeInSparseTensor(
+        SparseTensorType(op.getResult().getType().cast<RankedTensorType>()),
+        [&rewriter, &fields, srcDesc,
+         loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
+              DimLevelType /*dlt*/) -> bool {
+          // Simply reuses the storage specifier as it is an SSA value.
+          if (fKind == SparseTensorFieldKind::StorageSpec) {
+            fields.push_back(srcDesc.getSpecifier());
+          } else {
+            // Allocates new memrefs.
+            Value srcMem = srcDesc.getMemRefField(fIdx);
+            // TODO: We can instead use the actual memSize in the specifier,
+            // though that would require a subViewOp to avoid overflow when
+            // copying values.
+            Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
+            auto dstMem = rewriter.create<memref::AllocOp>(
+                loc, fTp.cast<MemRefType>(), sz);
+            if (fTp != srcMem.getType()) {
+              // Converts the element type.
+              scf::buildLoopNest(
+                  rewriter, loc, constantIndex(rewriter, loc, 0), sz,
+                  constantIndex(rewriter, loc, 1),
+                  [srcMem, &dstMem](OpBuilder &builder, Location loc,
+                                    ValueRange ivs) {
+                    Value v = builder.create<memref::LoadOp>(loc, srcMem, ivs);
+                    Value casted = genCast(builder, loc, v,
+                                           dstMem.getType().getElementType());
+                    builder.create<memref::StoreOp>(loc, casted, dstMem, ivs);
+                  });
+            } else {
+              // TODO: We can even reuse the same memref for the new tensor,
+              // but that requires reference-counting based memory management
+              // for memrefs shared between multiple sparse tensors.
+              rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
+            }
+            fields.push_back(dstMem);
+          }
+          return true;
+        });
+
+    rewriter.replaceOp(
+        op, genTuple(rewriter, loc, op.getResult().getType(), fields));
     return success();
   }
 };
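For reference, a sketch of the IR the element-wise branch above builds for a single values memref, assuming an f32 to f64 conversion (SSA names are hypothetical, and genCast is presumed to emit arith.extf for this widening):

    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %sz = memref.dim %src, %c0 : memref<?xf32>
    %dst = memref.alloc(%sz) : memref<?xf64>
    scf.for %i = %c0 to %sz step %c1 {
      %v = memref.load %src[%i] : memref<?xf32>
      %vc = arith.extf %v : f32 to f64
      memref.store %vc, %dst[%i] : memref<?xf64>
    }

Note that the loop bound is the full buffer size from memref.dim rather than the actual memSize stored in the specifier, which is exactly the trade-off the first TODO in the patch calls out.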