[mlir][sparse] refine insertion code

builds an SSA cycle for the compress insertion loop
adds a cast on index type mismatch during push_back

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D136186
Aart Bik, 2022-10-18 10:35:00 -07:00
parent 3e8eff3747
commit d22df0ebba
2 changed files with 97 additions and 14 deletions

@@ -265,11 +265,14 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
 }
 
 /// Creates a straightforward counting for-loop.
-static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count) {
+static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count,
+                            SmallVectorImpl<Value> &fields) {
   Type indexType = builder.getIndexType();
   Value zero = constantZero(builder, loc, indexType);
   Value one = constantOne(builder, loc, indexType);
-  scf::ForOp forOp = builder.create<scf::ForOp>(loc, zero, count, one);
+  scf::ForOp forOp = builder.create<scf::ForOp>(loc, zero, count, one, fields);
+  for (unsigned i = 0, e = fields.size(); i < e; i++)
+    fields[i] = forOp.getRegionIterArg(i);
   builder.setInsertionPointToStart(forOp.getBody());
   return forOp;
 }
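
The extra fields parameter threads the sparse storage memrefs through the loop as iter_args: inside the body each field is rebound to its region argument, so updates can be yielded back around the cycle. A minimal sketch of the same iter_args pattern, assuming a valid OpBuilder and Location; buildRunningSum is a hypothetical helper, not part of this patch:

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Builds: %sum = scf.for %i = %lb to %ub step %c1 iter_args(%acc = %init) {
//           %next = arith.addi %acc, %i
//           scf.yield %next
//         }
static Value buildRunningSum(OpBuilder &builder, Location loc, Value lb,
                             Value ub, Value init) {
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  auto forOp = builder.create<scf::ForOp>(loc, lb, ub, one, ValueRange{init});
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToStart(forOp.getBody());
  // Use the region iter arg inside the body, never the original %init.
  Value acc = forOp.getRegionIterArg(0);
  Value next =
      builder.create<arith::AddIOp>(loc, acc, forOp.getInductionVar());
  builder.create<scf::YieldOp>(loc, ValueRange{next});
  // The loop result carries the final value of the SSA cycle.
  return forOp.getResult(0);
}

Rebinding through getRegionIterArg is what makes the later scf.yield legal; reading the original SSA values inside the body would use stale state from before the loop.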
@@ -280,6 +283,9 @@ static void createPushback(OpBuilder &builder, Location loc,
                            SmallVectorImpl<Value> &fields, unsigned field,
                            Value value) {
   assert(field < fields.size());
+  Type etp = fields[field].getType().cast<ShapedType>().getElementType();
+  if (value.getType() != etp)
+    value = builder.create<arith::IndexCastOp>(loc, etp, value);
   fields[field] =
       builder.create<PushBackOp>(loc, fields[field].getType(), fields[1],
                                  fields[field], value, APInt(64, field));
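
The new cast covers a common mismatch: induction variables and memref.load results have type index, while the target memref may carry a narrower integer element type chosen by the sparse encoding (e.g. i32 when indexBitWidth = 32). A sketch of the guard in isolation, under the same assumption that one side of the cast is index; castToElementType is a hypothetical name:

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Casts `v` to the element type expected by push_back, if needed.
// arith.index_cast converts between `index` and any integer width.
static Value castToElementType(OpBuilder &builder, Location loc, Value v,
                               Type etp) {
  if (v.getType() != etp)
    return builder.create<arith::IndexCastOp>(loc, etp, v);
  return v;
}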
@@ -298,11 +304,8 @@ static void genInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
   if (rank != 1 || !isCompressedDim(rtp, 0) || !isUniqueDim(rtp, 0) ||
       !isOrderedDim(rtp, 0))
     return; // TODO: add codegen
-  // push_back memSizes pointers-0 0
   // push_back memSizes indices-0 index
   // push_back memSizes values value
-  Value zero = constantIndex(builder, loc, 0);
-  createPushback(builder, loc, fields, 2, zero);
   createPushback(builder, loc, fields, 3, indices[0]);
   createPushback(builder, loc, fields, 4, value);
 }
@@ -316,9 +319,12 @@ static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
   if (rtp.getShape().size() != 1 || !isCompressedDim(rtp, 0) ||
       !isUniqueDim(rtp, 0) || !isOrderedDim(rtp, 0))
     return; // TODO: add codegen
   // push_back memSizes pointers-0 0
+  // push_back memSizes pointers-0 memSizes[2]
   Value zero = constantIndex(builder, loc, 0);
+  Value two = constantIndex(builder, loc, 2);
+  Value size = builder.create<memref::LoadOp>(loc, fields[1], two);
   createPushback(builder, loc, fields, 2, zero);
+  createPushback(builder, loc, fields, 2, size);
 }
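
The two hunks above split the insertion protocol: genInsert now only appends to the indices and values arrays, while genEndInsert writes the pointer array once, as [0, memSizes[2]], where memSizes[2] tracks the number of stored entries. A plain C++ model of the resulting 1-D compressed storage, an illustration of the scheme rather than the generated IR:

#include <cstdint>
#include <vector>

// Mirrors the refined scheme for a sorted, unique, 1-D compressed vector.
struct SparseVectorStorage {
  std::vector<uint64_t> pointers; // written once by endInsert
  std::vector<uint64_t> indices;  // one entry per insert
  std::vector<double> values;     // one entry per insert

  // genInsert: push_back index and value only.
  void insert(uint64_t i, double v) {
    indices.push_back(i);
    values.push_back(v);
  }
  // genEndInsert: pointers = [0, number of stored entries].
  void endInsert() { pointers = {0, indices.size()}; }
};
// After insert(2, a), insert(5, b), insert(7, c) and endInsert():
// pointers = [0, 3], indices = [2, 5, 7], values = [a, b, c].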
@@ -460,6 +466,7 @@ public:
     Location loc = op.getLoc();
     SmallVector<Value, 8> fields;
     createAllocFields(rewriter, loc, resType, adaptor.getOperands(), fields);
+    // Replace operation with resulting memrefs.
     rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
     return success();
   }
@@ -504,6 +511,7 @@ public:
     // Generate optional insertion finalization code.
     if (op.getHasInserts())
       genEndInsert(rewriter, op.getLoc(), srcType, fields);
+    // Replace operation with resulting memrefs.
     rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), srcType, fields));
     return success();
   }
@@ -591,23 +599,26 @@ public:
     // sparsity of the expanded access pattern.
     //
     // Generate
-    //    for (i = 0; i < count; i++) {
+    //    out_memrefs = for (i = 0; i < count; i++)(in_memrefs) {
     //      index = added[i];
     //      value = values[index];
-    //      insert({prev_indices, index}, value);
+    //      new_memrefs = insert(in_memrefs, {prev_indices, index}, value);
     //      values[index] = 0;
     //      filled[index] = false;
+    //      yield new_memrefs
     //    }
-    Value i = createFor(rewriter, loc, count).getInductionVar();
+    scf::ForOp loop = createFor(rewriter, loc, count, fields);
+    Value i = loop.getInductionVar();
     Value index = rewriter.create<memref::LoadOp>(loc, added, i);
     Value value = rewriter.create<memref::LoadOp>(loc, values, index);
     indices.push_back(index);
-    // TODO: generate yield cycle
     genInsert(rewriter, loc, dstType, fields, indices, value);
     rewriter.create<memref::StoreOp>(loc, constantZero(rewriter, loc, eltType),
                                      values, index);
     rewriter.create<memref::StoreOp>(loc, constantI1(rewriter, loc, false),
                                      filled, index);
+    rewriter.create<scf::YieldOp>(loc, fields);
     // Deallocate the buffers on exit of the full loop nest.
     Operation *parent = op;
     for (; isa<scf::ForOp>(parent->getParentOp()) ||
@@ -620,7 +631,9 @@ public:
     rewriter.create<memref::DeallocOp>(loc, values);
     rewriter.create<memref::DeallocOp>(loc, filled);
     rewriter.create<memref::DeallocOp>(loc, added);
-    rewriter.replaceOp(op, genTuple(rewriter, loc, dstType, fields));
+    // Replace operation with resulting memrefs.
+    rewriter.replaceOp(op,
+                       genTuple(rewriter, loc, dstType, loop->getResults()));
     return success();
   }
 };
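
This hunk closes the SSA cycle that the removed TODO pointed at: the fields enter the loop as iter_args via createFor, each iteration yields the possibly reallocated memrefs, and only loop->getResults() is valid after the loop, which is why it replaces fields in the final genTuple. A condensed sketch of that shape, with a hypothetical helper name and the per-iteration inserts elided:

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Threads `fields` through an scf.for: in_memrefs -> body -> yield ->
// out_memrefs. The caller must use the returned loop results.
static llvm::SmallVector<Value>
buildFieldLoop(OpBuilder &builder, Location loc, Value lb, Value ub,
               Value step, llvm::SmallVector<Value> fields) {
  auto loop = builder.create<scf::ForOp>(loc, lb, ub, step, fields);
  {
    OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(loop.getBody());
    // Rebind fields to the block arguments before any use in the body.
    for (unsigned i = 0, e = fields.size(); i < e; i++)
      fields[i] = loop.getRegionIterArg(i);
    // ... per-iteration push_backs would update `fields` here ...
    builder.create<scf::YieldOp>(loc, fields);
  }
  // Everything after the loop must use its results, not its inputs.
  return llvm::SmallVector<Value>(loop.getResults().begin(),
                                  loop.getResults().end());
}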
@@ -641,6 +654,7 @@ public:
     // Generate insertion.
     Value value = adaptor.getValue();
     genInsert(rewriter, op->getLoc(), dstType, fields, indices, value);
+    // Replace operation with resulting memrefs.
     rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), dstType, fields));
     return success();
   }