[mlir][sparse] unconditionally use IndexType for sparse_tensor.specifier
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D144574
This commit is contained in:
@@ -67,25 +67,18 @@ static void flattenOperands(ValueRange operands,
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds index conversions where needed.
|
||||
static Value toType(OpBuilder &builder, Location loc, Value value, Type tp) {
|
||||
if (value.getType() != tp)
|
||||
return builder.create<arith::IndexCastOp>(loc, tp, value);
|
||||
return value;
|
||||
}
|
||||
|
||||
/// Generates a load with proper index typing.
|
||||
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
|
||||
idx = toType(builder, loc, idx, builder.getIndexType());
|
||||
idx = genCast(builder, loc, idx, builder.getIndexType());
|
||||
return builder.create<memref::LoadOp>(loc, mem, idx);
|
||||
}
|
||||
|
||||
/// Generates a store with proper index typing and (for indices) proper value.
|
||||
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
|
||||
Value idx) {
|
||||
idx = toType(builder, loc, idx, builder.getIndexType());
|
||||
val = toType(builder, loc, val,
|
||||
mem.getType().cast<ShapedType>().getElementType());
|
||||
idx = genCast(builder, loc, idx, builder.getIndexType());
|
||||
val = genCast(builder, loc, val,
|
||||
mem.getType().cast<ShapedType>().getElementType());
|
||||
builder.create<memref::StoreOp>(loc, val, mem, idx);
|
||||
}
|
||||
|
||||
@@ -141,7 +134,7 @@ static void createPushback(OpBuilder &builder, Location loc,
|
||||
|
||||
auto pushBackOp = builder.create<PushBackOp>(
|
||||
loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl), field,
|
||||
toType(builder, loc, value, etp), repeat);
|
||||
genCast(builder, loc, value, etp), repeat);
|
||||
|
||||
desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer());
|
||||
desc.setSpecifierField(builder, loc, specFieldKind, lvl,
|
||||
@@ -338,7 +331,7 @@ static Value genCompressed(OpBuilder &builder, Location loc,
|
||||
msz = builder.create<arith::DivUIOp>(loc, msz, idxStrideC);
|
||||
}
|
||||
Value phim1 = builder.create<arith::SubIOp>(
|
||||
loc, toType(builder, loc, phi, indexType), one);
|
||||
loc, genCast(builder, loc, phi, indexType), one);
|
||||
// Conditional expression.
|
||||
Value lt =
|
||||
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, plo, phi);
|
||||
@@ -350,9 +343,9 @@ static Value genCompressed(OpBuilder &builder, Location loc,
|
||||
builder, loc, desc.getMemRefField(idxIndex),
|
||||
idxStride > 1 ? builder.create<arith::MulIOp>(loc, phim1, idxStrideC)
|
||||
: phim1);
|
||||
Value eq = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
|
||||
toType(builder, loc, crd, indexType),
|
||||
indices[lvl]);
|
||||
Value eq = builder.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::eq, genCast(builder, loc, crd, indexType),
|
||||
indices[lvl]);
|
||||
builder.create<scf::YieldOp>(loc, eq);
|
||||
builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
|
||||
if (lvl > 0)
|
||||
@@ -1226,8 +1219,8 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
|
||||
// Converts MemRefs back to Tensors.
|
||||
Value data = rewriter.create<bufferization::ToTensorOp>(loc, dataBuf);
|
||||
Value indices = rewriter.create<bufferization::ToTensorOp>(loc, idxBuf);
|
||||
Value nnz = toType(rewriter, loc, desc.getValMemSize(rewriter, loc),
|
||||
op.getNnz().getType());
|
||||
Value nnz = genCast(rewriter, loc, desc.getValMemSize(rewriter, loc),
|
||||
op.getNnz().getType());
|
||||
|
||||
rewriter.replaceOp(op, {data, indices, nnz});
|
||||
return success();
|
||||
|
||||
Reference in New Issue
Block a user