[mlir][sparse] support sparse tensor element type conversion in codegen path

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D144578
Author: Peiming Liu
Date: 2023-02-22 19:04:02 +00:00
parent 230e61658b
commit 85dbb3fc4b
8 changed files with 265 additions and 22 deletions
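
As a rough illustration (a hand-written sketch, not taken from the patch's test files; the #CSR32/#CSR64 names and shapes are made up), the codegen path can now lower a sparse_tensor.convert directly whenever the two encodings agree on everything except their pointer/index bitwidths, even when the element type changes as well:

  #CSR32 = #sparse_tensor.encoding<{
    dimLevelType = ["dense", "compressed"],
    pointerBitWidth = 32,
    indexBitWidth = 32
  }>

  #CSR64 = #sparse_tensor.encoding<{
    dimLevelType = ["dense", "compressed"],
    pointerBitWidth = 64,
    indexBitWidth = 64
  }>

  // Same level types, different bitwidths and element type: handled directly
  // by codegen after this change instead of bailing out to the rewriting path.
  func.func @widen(%arg0: tensor<8x8xf32, #CSR32>) -> tensor<8x8xf64, #CSR64> {
    %0 = sparse_tensor.convert %arg0
        : tensor<8x8xf32, #CSR32> to tensor<8x8xf64, #CSR64>
    return %0 : tensor<8x8xf64, #CSR64>
  }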


@@ -1030,11 +1030,73 @@ public:
     SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
     SparseTensorEncodingAttr encSrc =
         getSparseTensorEncoding(op.getSource().getType());
-    if (encDst != encSrc) {
-      // This should be handled by rewriting before codegen.
+    // Different encoding (except for different bitwidth) should be handled by
+    // rewriting.
+    if (encDst.withoutBitWidths() != encSrc.withoutBitWidths()) {
       return failure();
     }
-    rewriter.replaceOp(op, adaptor.getSource());
+
+    Type retElemTp = op.getResult().getType().getElementType();
+    Type srcElemTp = op.getSource().getType().getElementType();
+    // Fold the trivial cases.
+    if (retElemTp == srcElemTp && encDst == encSrc) {
+      rewriter.replaceOp(op, adaptor.getSource());
+      return success();
+    }
+    //
+    // Do element-wise type conversion without using InsertOp.
+    //
+    // for each memref in srcTensor:
+    //   dst = memref.alloc
+    //   if srcMemRefType != dstMemRefType:
+    //     for every dst[i] = cast(src[i])
+    //   else:
+    //     dst = memref.copy(src)
+    Location loc = op.getLoc();
+    auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource());
+    SmallVector<Value> fields;
+    foreachFieldAndTypeInSparseTensor(
+        SparseTensorType(op.getResult().getType().cast<RankedTensorType>()),
+        [&rewriter, &fields, srcDesc,
+         loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
+              DimLevelType /*dlt*/) -> bool {
+          // Simply reuses the storage specifier as it is an SSA value.
+          if (fKind == SparseTensorFieldKind::StorageSpec) {
+            fields.push_back(srcDesc.getSpecifier());
+          } else {
+            // Allocates new memrefs
+            Value srcMem = srcDesc.getMemRefField(fIdx);
+            // TODO: We can instead use the actual memSize in specifier, that
+            // would require a subViewOp to avoid overflow when copying
+            // values.
+            Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
+            auto dstMem = rewriter.create<memref::AllocOp>(
+                loc, fTp.cast<MemRefType>(), sz);
+            if (fTp != srcMem.getType()) {
+              // Converts elements type.
+              scf::buildLoopNest(
+                  rewriter, loc, constantIndex(rewriter, loc, 0), sz,
+                  constantIndex(rewriter, loc, 1),
+                  [srcMem, &dstMem](OpBuilder &builder, Location loc,
+                                    ValueRange ivs) {
+                    Value v = builder.create<memref::LoadOp>(loc, srcMem, ivs);
+                    Value casted = genCast(builder, loc, v,
+                                           dstMem.getType().getElementType());
+                    builder.create<memref::StoreOp>(loc, casted, dstMem, ivs);
+                  });
+            } else {
+              // TODO: We can even reuse the same memref for the new tensor,
+              // but that requires a `ref-counting` based memory management
+              // for shared memrefs between multiple sparse tensors.
+              rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
+            }
+            fields.push_back(dstMem);
+          }
+          return true;
+        });
+
+    rewriter.replaceOp(
+        op, genTuple(rewriter, loc, op.getResult().getType(), fields));
     return success();
   }
 };
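
For the non-trivial case, the lambda above conceptually turns each storage memref whose type changes into an alloc followed by a cast-and-copy loop, and falls back to memref.copy when the type is unchanged. A hand-written sketch of roughly the IR this emits for a single positions array widened from 32 to 64 bits (placeholder SSA names; the concrete cast op is whatever genCast selects for the given types):

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  // Buffer size taken from the source memref (see the TODO above about using
  // the specifier's memSize plus a subview instead).
  %sz = memref.dim %src_pos, %c0 : memref<?xi32>
  %dst_pos = memref.alloc(%sz) : memref<?xi64>
  scf.for %i = %c0 to %sz step %c1 {
    %v = memref.load %src_pos[%i] : memref<?xi32>
    // Illustrative cast; genCast picks the concrete conversion op.
    %w = arith.extui %v : i32 to i64
    memref.store %w, %dst_pos[%i] : memref<?xi64>
  }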