[mlir][sparse] adding SparseTensorType::get{Pointer,Index}Type methods

Depends On D143800

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D143946
This commit is contained in:
wren romano
2023-02-15 13:31:05 -08:00
parent 8bd0e9481c
commit ae7942e296
5 changed files with 26 additions and 13 deletions

View File

@@ -160,7 +160,7 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
// Append linear x pointers, initialized to zero. Since each compressed
// dimension initially already has a single zero entry, this maintains
// the desired "linear + 1" length property at all times.
Type ptrType = stt.getEncoding().getPointerType();
Type ptrType = stt.getPointerType();
Value ptrZero = constantZero(builder, loc, ptrType);
createPushback(builder, loc, desc, SparseTensorFieldKind::PtrMemRef, l,
ptrZero, linear);
@@ -279,8 +279,7 @@ static void createAllocFields(OpBuilder &builder, Location loc,
// to all zeros, sets the dimSizes to known values and gives all pointer
// fields an initial zero entry, so that it is easier to maintain the
// "linear + 1" length property.
Value ptrZero =
constantZero(builder, loc, stt.getEncoding().getPointerType());
Value ptrZero = constantZero(builder, loc, stt.getPointerType());
for (Level lvlRank = stt.getLvlRank(), l = 0; l < lvlRank; l++) {
// Fills dim sizes array.
// FIXME: this method seems to set *level* sizes, but the name is confusing
@@ -546,7 +545,7 @@ static void genEndInsert(OpBuilder &builder, Location loc,
// times?
//
if (l > 0) {
Type ptrType = stt.getEncoding().getPointerType();
Type ptrType = stt.getPointerType();
Value ptrMemRef = desc.getPtrMemRef(l);
Value hi = desc.getPtrMemSize(builder, loc, l);
Value zero = constantIndex(builder, loc, 0);