[mlir][sparse] Add codegen rule for the push_back operator.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D134372
This commit is contained in:
bixia1
2022-09-21 09:26:09 -07:00
parent 47afaf2eb0
commit 4132bce9e5
3 changed files with 120 additions and 2 deletions

View File

@@ -613,6 +613,61 @@ public:
}
};
/// Sparse codegen rule for the push_back operator.
class SparsePushBackConverter : public OpConversionPattern<PushBackOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(PushBackOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Lower push_back(buffer, value) to:
// if (size(buffer) >= capacity(buffer))
// new_capacity = capacity(buffer)*2
// new_buffer = realloc(buffer, new_capacity)
// buffer = new_buffer
// store(buffer, value)
// size(buffer)++
Location loc = op->getLoc();
Value c0 = constantIndex(rewriter, loc, 0);
Value buffer = adaptor.getInBuffer();
Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
Value idx = constantIndex(rewriter, loc, op.getIdx().getZExtValue());
Value bufferSizes = adaptor.getBufferSizes();
Value size = rewriter.create<memref::LoadOp>(loc, bufferSizes, idx);
Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge,
size, capacity);
Value value = adaptor.getValue();
auto bufferType =
MemRefType::get({ShapedType::kDynamicSize}, value.getType());
scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
/*else=*/true);
// True branch.
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
Value c2 = constantIndex(rewriter, loc, 2);
capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
Value newBuffer =
rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
rewriter.create<scf::YieldOp>(loc, newBuffer);
// False branch.
rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
rewriter.create<scf::YieldOp>(loc, buffer);
// Add the value to the end of the buffer.
rewriter.setInsertionPointAfter(ifOp);
buffer = ifOp.getResult(0);
rewriter.create<memref::StoreOp>(loc, value, buffer, size);
// Increment the size of the buffer by 1.
Value c1 = constantIndex(rewriter, loc, 1);
size = rewriter.create<arith::AddIOp>(loc, size, c1);
rewriter.create<memref::StoreOp>(loc, size, bufferSizes, idx);
rewriter.replaceOp(op, buffer);
return success();
}
};
/// Base class for getter-like operations, e.g., to_indices, to_pointers.
template <typename SourceOp, typename Base>
class SparseGetterOpConverter : public OpConversionPattern<SourceOp> {
@@ -697,6 +752,7 @@ void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
SparseCastConverter, SparseTensorAllocConverter,
SparseTensorDeallocConverter, SparseTensorLoadConverter,
SparseExpandConverter, SparseCompressConverter,
SparseToPointersConverter, SparseToIndicesConverter,
SparseToValuesConverter>(typeConverter, patterns.getContext());
SparsePushBackConverter, SparseToPointersConverter,
SparseToIndicesConverter, SparseToValuesConverter>(
typeConverter, patterns.getContext());
}