[mlir][sparse] refine insertion code
Builds an SSA cycle for the compress insertion loop; adds casting on index mismatch during push_back.

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D136186
This commit is contained in:
@@ -265,11 +265,14 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
 }
 
 /// Creates a straightforward counting for-loop.
-static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count) {
+static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count,
+                            SmallVectorImpl<Value> &fields) {
   Type indexType = builder.getIndexType();
   Value zero = constantZero(builder, loc, indexType);
   Value one = constantOne(builder, loc, indexType);
-  scf::ForOp forOp = builder.create<scf::ForOp>(loc, zero, count, one);
+  scf::ForOp forOp = builder.create<scf::ForOp>(loc, zero, count, one, fields);
+  for (unsigned i = 0, e = fields.size(); i < e; i++)
+    fields[i] = forOp.getRegionIterArg(i);
   builder.setInsertionPointToStart(forOp.getBody());
   return forOp;
 }
@@ -280,6 +283,9 @@ static void createPushback(OpBuilder &builder, Location loc,
                            SmallVectorImpl<Value> &fields, unsigned field,
                            Value value) {
   assert(field < fields.size());
+  Type etp = fields[field].getType().cast<ShapedType>().getElementType();
+  if (value.getType() != etp)
+    value = builder.create<arith::IndexCastOp>(loc, etp, value);
   fields[field] =
       builder.create<PushBackOp>(loc, fields[field].getType(), fields[1],
                                  fields[field], value, APInt(64, field));
@@ -298,11 +304,8 @@ static void genInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
   if (rank != 1 || !isCompressedDim(rtp, 0) || !isUniqueDim(rtp, 0) ||
       !isOrderedDim(rtp, 0))
     return; // TODO: add codegen
-  // push_back memSizes pointers-0 0
   // push_back memSizes indices-0 index
   // push_back memSizes values value
-  Value zero = constantIndex(builder, loc, 0);
-  createPushback(builder, loc, fields, 2, zero);
   createPushback(builder, loc, fields, 3, indices[0]);
   createPushback(builder, loc, fields, 4, value);
 }
@@ -316,9 +319,12 @@ static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
   if (rtp.getShape().size() != 1 || !isCompressedDim(rtp, 0) ||
       !isUniqueDim(rtp, 0) || !isOrderedDim(rtp, 0))
     return; // TODO: add codegen
+  // push_back memSizes pointers-0 0
   // push_back memSizes pointers-0 memSizes[2]
+  Value zero = constantIndex(builder, loc, 0);
   Value two = constantIndex(builder, loc, 2);
   Value size = builder.create<memref::LoadOp>(loc, fields[1], two);
+  createPushback(builder, loc, fields, 2, zero);
   createPushback(builder, loc, fields, 2, size);
 }
 
@@ -460,6 +466,7 @@ public:
     Location loc = op.getLoc();
     SmallVector<Value, 8> fields;
     createAllocFields(rewriter, loc, resType, adaptor.getOperands(), fields);
+    // Replace operation with resulting memrefs.
     rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
     return success();
   }
@@ -504,6 +511,7 @@ public:
     // Generate optional insertion finalization code.
     if (op.getHasInserts())
      genEndInsert(rewriter, op.getLoc(), srcType, fields);
+    // Replace operation with resulting memrefs.
    rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), srcType, fields));
     return success();
   }
@@ -591,23 +599,26 @@ public:
    // sparsity of the expanded access pattern.
    //
    // Generate
-    //    for (i = 0; i < count; i++) {
+    //    out_memrefs = for (i = 0; i < count; i++)(in_memrefs) {
    //      index = added[i];
    //      value = values[index];
-    //      insert({prev_indices, index}, value);
+    //      new_memrefs = insert(in_memrefs, {prev_indices, index}, value);
    //      values[index] = 0;
    //      filled[index] = false;
+    //      yield new_memrefs
    //    }
-    Value i = createFor(rewriter, loc, count).getInductionVar();
+    scf::ForOp loop = createFor(rewriter, loc, count, fields);
+    Value i = loop.getInductionVar();
    Value index = rewriter.create<memref::LoadOp>(loc, added, i);
    Value value = rewriter.create<memref::LoadOp>(loc, values, index);
    indices.push_back(index);
-    // TODO: generate yield cycle
    genInsert(rewriter, loc, dstType, fields, indices, value);
    rewriter.create<memref::StoreOp>(loc, constantZero(rewriter, loc, eltType),
                                     values, index);
    rewriter.create<memref::StoreOp>(loc, constantI1(rewriter, loc, false),
                                     filled, index);
+    rewriter.create<scf::YieldOp>(loc, fields);
    // Deallocate the buffers on exit of the full loop nest.
    Operation *parent = op;
    for (; isa<scf::ForOp>(parent->getParentOp()) ||
@@ -620,7 +631,9 @@ public:
     rewriter.create<memref::DeallocOp>(loc, values);
     rewriter.create<memref::DeallocOp>(loc, filled);
     rewriter.create<memref::DeallocOp>(loc, added);
-    rewriter.replaceOp(op, genTuple(rewriter, loc, dstType, fields));
+    // Replace operation with resulting memrefs.
+    rewriter.replaceOp(op,
+                       genTuple(rewriter, loc, dstType, loop->getResults()));
     return success();
   }
 };
@@ -641,6 +654,7 @@ public:
     // Generate insertion.
     Value value = adaptor.getValue();
     genInsert(rewriter, op->getLoc(), dstType, fields, indices, value);
+    // Replace operation with resulting memrefs.
    rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), dstType, fields));
     return success();
   }
Reference in New Issue
Block a user