[mlir][sparse] unconditionally use IndexType for sparse_tensor.specifier

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D144574
Author: Peiming Liu
Date: 2023-02-22 18:44:00 +00:00
parent 475bbea5be
commit 44ff23d5e4
17 changed files with 247 additions and 308 deletions
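The patch drops the file-local toType helper (removed in the first hunk below) in favor of the genCast utility, presumably the shared helper from CodegenUtils, so specifier-related values can stay in IndexType unconditionally and are converted only at the boundaries. For reference, a minimal sketch of the index-only cast the removed helper performed is shown here; the name castWhenNeeded is illustrative, and the real genCast is assumed to cover additional type pairs as well.

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Minimal sketch (illustrative name): emit an arith.index_cast only when the
// value does not already have the requested type, e.g. i64 <-> index.
static Value castWhenNeeded(OpBuilder &builder, Location loc, Value value,
                            Type tp) {
  if (value.getType() != tp)
    return builder.create<arith::IndexCastOp>(loc, tp, value);
  return value;
}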


@@ -67,25 +67,18 @@ static void flattenOperands(ValueRange operands,
}
}
-/// Adds index conversions where needed.
-static Value toType(OpBuilder &builder, Location loc, Value value, Type tp) {
-  if (value.getType() != tp)
-    return builder.create<arith::IndexCastOp>(loc, tp, value);
-  return value;
-}
/// Generates a load with proper index typing.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
-  idx = toType(builder, loc, idx, builder.getIndexType());
+  idx = genCast(builder, loc, idx, builder.getIndexType());
return builder.create<memref::LoadOp>(loc, mem, idx);
}
/// Generates a store with proper index typing and (for indices) proper value.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
Value idx) {
-  idx = toType(builder, loc, idx, builder.getIndexType());
-  val = toType(builder, loc, val,
-               mem.getType().cast<ShapedType>().getElementType());
+  idx = genCast(builder, loc, idx, builder.getIndexType());
+  val = genCast(builder, loc, val,
+               mem.getType().cast<ShapedType>().getElementType());
builder.create<memref::StoreOp>(loc, val, mem, idx);
}
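As a usage sketch of the two helpers above (a hypothetical call site, not part of this patch; it assumes genLoad and genCast are in scope as in this file), loading a coordinate from a narrow-integer memref at a possibly integer-typed position and producing an index-typed result looks roughly like the following, which mirrors the genLoad/genCast pairing visible in the genCompressed hunk further down.

// Hypothetical caller: crdMem holds coordinates as (possibly narrow) integers
// and pos may be integer-typed; the result is always index-typed.
static Value loadCrdAsIndex(OpBuilder &builder, Location loc, Value crdMem,
                            Value pos) {
  Value crd = genLoad(builder, loc, crdMem, pos); // genLoad casts pos to index
  return genCast(builder, loc, crd, builder.getIndexType());
}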
@@ -141,7 +134,7 @@ static void createPushback(OpBuilder &builder, Location loc,
auto pushBackOp = builder.create<PushBackOp>(
loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl), field,
-      toType(builder, loc, value, etp), repeat);
+      genCast(builder, loc, value, etp), repeat);
desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer());
desc.setSpecifierField(builder, loc, specFieldKind, lvl,
@@ -338,7 +331,7 @@ static Value genCompressed(OpBuilder &builder, Location loc,
msz = builder.create<arith::DivUIOp>(loc, msz, idxStrideC);
}
Value phim1 = builder.create<arith::SubIOp>(
-      loc, toType(builder, loc, phi, indexType), one);
+      loc, genCast(builder, loc, phi, indexType), one);
// Conditional expression.
Value lt =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, plo, phi);
@@ -350,9 +343,9 @@ static Value genCompressed(OpBuilder &builder, Location loc,
builder, loc, desc.getMemRefField(idxIndex),
idxStride > 1 ? builder.create<arith::MulIOp>(loc, phim1, idxStrideC)
: phim1);
-  Value eq = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
-                                           toType(builder, loc, crd, indexType),
-                                           indices[lvl]);
+  Value eq = builder.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::eq, genCast(builder, loc, crd, indexType),
+      indices[lvl]);
builder.create<scf::YieldOp>(loc, eq);
builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
if (lvl > 0)
@@ -1226,8 +1219,8 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
// Converts MemRefs back to Tensors.
Value data = rewriter.create<bufferization::ToTensorOp>(loc, dataBuf);
Value indices = rewriter.create<bufferization::ToTensorOp>(loc, idxBuf);
-    Value nnz = toType(rewriter, loc, desc.getValMemSize(rewriter, loc),
-                       op.getNnz().getType());
+    Value nnz = genCast(rewriter, loc, desc.getValMemSize(rewriter, loc),
+                        op.getNnz().getType());
rewriter.replaceOp(op, {data, indices, nnz});
return success();
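In this last hunk the cast runs in the opposite direction: the descriptor reports the value-buffer size as an index, and genCast narrows it to whatever type op.getNnz() carries. A rough sketch of that narrowing idiom follows; the helper name narrowCount is hypothetical and is only meant to illustrate the index-to-integer case assumed here.

#include <cassert>

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Hypothetical sketch: narrow an index-typed element count to the integer
// type an op expects, using the same cast-only-if-needed idiom as genCast.
static Value narrowCount(OpBuilder &builder, Location loc, Value count,
                         Type nnzTp) {
  assert(count.getType().isIndex() && "expected an index-typed count");
  if (count.getType() == nnzTp)
    return count;
  return builder.create<arith::IndexCastOp>(loc, nnzTp, count);
}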