[mlir][sparse] Add codegen rule for the push_back operator.
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D134372
This commit is contained in:
@@ -613,6 +613,61 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse codegen rule for the push_back operator.
|
||||
class SparsePushBackConverter : public OpConversionPattern<PushBackOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(PushBackOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Lower push_back(buffer, value) to:
|
||||
// if (size(buffer) >= capacity(buffer))
|
||||
// new_capacity = capacity(buffer)*2
|
||||
// new_buffer = realloc(buffer, new_capacity)
|
||||
// buffer = new_buffer
|
||||
// store(buffer, value)
|
||||
// size(buffer)++
|
||||
Location loc = op->getLoc();
|
||||
Value c0 = constantIndex(rewriter, loc, 0);
|
||||
Value buffer = adaptor.getInBuffer();
|
||||
Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
|
||||
Value idx = constantIndex(rewriter, loc, op.getIdx().getZExtValue());
|
||||
Value bufferSizes = adaptor.getBufferSizes();
|
||||
Value size = rewriter.create<memref::LoadOp>(loc, bufferSizes, idx);
|
||||
Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge,
|
||||
size, capacity);
|
||||
Value value = adaptor.getValue();
|
||||
auto bufferType =
|
||||
MemRefType::get({ShapedType::kDynamicSize}, value.getType());
|
||||
scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
|
||||
/*else=*/true);
|
||||
// True branch.
|
||||
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
|
||||
Value c2 = constantIndex(rewriter, loc, 2);
|
||||
capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
|
||||
Value newBuffer =
|
||||
rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
|
||||
rewriter.create<scf::YieldOp>(loc, newBuffer);
|
||||
|
||||
// False branch.
|
||||
rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
|
||||
rewriter.create<scf::YieldOp>(loc, buffer);
|
||||
|
||||
// Add the value to the end of the buffer.
|
||||
rewriter.setInsertionPointAfter(ifOp);
|
||||
buffer = ifOp.getResult(0);
|
||||
rewriter.create<memref::StoreOp>(loc, value, buffer, size);
|
||||
|
||||
// Increment the size of the buffer by 1.
|
||||
Value c1 = constantIndex(rewriter, loc, 1);
|
||||
size = rewriter.create<arith::AddIOp>(loc, size, c1);
|
||||
rewriter.create<memref::StoreOp>(loc, size, bufferSizes, idx);
|
||||
|
||||
rewriter.replaceOp(op, buffer);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Base class for getter-like operations, e.g., to_indices, to_pointers.
|
||||
template <typename SourceOp, typename Base>
|
||||
class SparseGetterOpConverter : public OpConversionPattern<SourceOp> {
|
||||
@@ -697,6 +752,7 @@ void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
|
||||
SparseCastConverter, SparseTensorAllocConverter,
|
||||
SparseTensorDeallocConverter, SparseTensorLoadConverter,
|
||||
SparseExpandConverter, SparseCompressConverter,
|
||||
SparseToPointersConverter, SparseToIndicesConverter,
|
||||
SparseToValuesConverter>(typeConverter, patterns.getContext());
|
||||
SparsePushBackConverter, SparseToPointersConverter,
|
||||
SparseToIndicesConverter, SparseToValuesConverter>(
|
||||
typeConverter, patterns.getContext());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user