[mlir][sparse] Improve the implementation of sparse_tensor.new for the codegen path.
Rewrite a NewOp into a NewOp of a sorted COO tensor and a ConvertOp for converting the sorted COO tensor to the destination tensor type. Codegen a NewOp of a sorted COO tensor to use the new bulk reader API and sort the elements only when the input is not sorted.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D144504
This commit is contained in:
@@ -1289,6 +1289,137 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse codegen rule for the sparse_tensor.new operator when the result is
/// an ordered COO tensor whose level ordering starts at 0. The file contents
/// are read directly into the COO storage buffers via the bulk reader API, and
/// are sorted afterwards only if the reader reports the input as unsorted.
/// All other destination types are handled by rewriting (New + Convert).
struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    const auto encDst = dstTp.getEncoding();
    // Creating COO with NewOp is handled by direct IR codegen. All other cases
    // are handled by rewriting.
    if (!dstTp.hasEncoding() || getCOOStart(encDst) != 0)
      return failure();

    // Implement the NewOp(filename) as follows:
    //   reader = getSparseTensorReader(filename)
    //   nse = getSparseTensorNNZ()
    //   tmp = bufferization.alloc_tensor an ordered COO with
    //         dst dim ordering, size_hint = nse
    //   indices = to_indices_buffer(tmp)
    //   values = to_values(tmp)
    //   isSorted = getSparseTensorReaderRead(indices, values, dimOrdering)
    //   if (!isSorted) sort_coo(nse, indices, values)
    //   update storage specifier
    //   dst = sparse_tensor.ConvertOp tmp

    // Create a sparse tensor reader for the given file name.
    Value fileName = op.getSource();
    Type opaqueTp = getOpaquePointerType(rewriter);
    Value reader = createFuncCall(rewriter, loc, "createSparseTensorReader",
                                  {opaqueTp}, {fileName}, EmitCInterface::Off)
                       .getResult(0);

    Type indexTp = rewriter.getIndexType();
    const Dimension dimRank = dstTp.getDimRank();

    // If the result tensor has dynamic dimensions, get the dynamic sizes from
    // the sparse tensor reader.
    SmallVector<Value> dynSizes;
    if (dstTp.hasDynamicDimShape()) {
      Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp);
      // Note: this runtime call returns no results, so do not query a result
      // value here (getResult(0) on a zero-result call is out of range).
      createFuncCall(rewriter, loc, "copySparseTensorReaderDimSizes", {},
                     {reader, dimSizes}, EmitCInterface::On);
      ArrayRef<int64_t> dstShape = dstTp.getRankedTensorType().getShape();
      for (auto &d : llvm::enumerate(dstShape)) {
        if (d.value() == ShapedType::kDynamic) {
          dynSizes.push_back(rewriter.create<memref::LoadOp>(
              loc, dimSizes, constantIndex(rewriter, loc, d.index())));
        }
      }
    }

    Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNNZ",
                               {indexTp}, {reader}, EmitCInterface::Off)
                    .getResult(0);
    // Construct allocation for each field of the destination tensor, using
    // nse as the size hint so the COO buffers are allocated exactly once.
    SmallVector<Value> fields;
    createAllocFields(rewriter, loc, dstTp, dynSizes, /*enableInit=*/false,
                      fields, nse);
    MutSparseTensorDescriptor desc(dstTp, fields);

    // Read the COO tensor data in bulk; pick the reader entry point that
    // matches the index overhead type and the element type.
    Type eltTp = dstTp.getElementType();
    Type indBufEleTp = getIndexOverheadType(rewriter, encDst);
    SmallString<32> getReadFuncName{"getSparseTensorReaderRead",
                                    overheadTypeFunctionSuffix(indBufEleTp),
                                    primaryTypeFunctionSuffix(eltTp)};

    Value xs = desc.getAOSMemRef();
    Value ys = desc.getValMemRef();
    // Build the dim-to-level mapping buffer passed to the reader: either the
    // inverse of the encoding's dimension ordering, or identity when absent.
    SmallVector<Value> dim2lvlValues(dimRank, Value());
    if (auto dimOrder = encDst.getDimOrdering()) {
      assert(dimOrder.isPermutation() && "Got non-permutation");
      for (uint64_t l = 0; l < dimRank; l++) {
        uint64_t d = dimOrder.getDimPosition(l);
        dim2lvlValues[d] = constantIndex(rewriter, loc, l);
      }
    } else {
      for (uint64_t l = 0; l < dimRank; l++)
        dim2lvlValues[l] = constantIndex(rewriter, loc, l);
    }
    Value dim2lvl = allocaBuffer(rewriter, loc, dim2lvlValues);

    // The reader returns an i1 flag telling whether the file data arrived
    // already sorted in the requested level order.
    Value f = constantI1(rewriter, loc, false);
    Value isSorted =
        createFuncCall(rewriter, loc, getReadFuncName, {f.getType()},
                       {reader, dim2lvl, xs, ys}, EmitCInterface::On)
            .getResult(0);

    // If the destination tensor is a sorted COO, we need to sort the COO tensor
    // data if the input elements aren't sorted yet.
    if (encDst.isOrderedLvl(dimRank - 1)) {
      Value notSorted = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, isSorted, f);
      scf::IfOp ifOp =
          rewriter.create<scf::IfOp>(loc, notSorted, /*else*/ false);
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      rewriter.create<SortCooOp>(
          loc, nse, xs, ValueRange{ys}, rewriter.getIndexAttr(dimRank),
          rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
      rewriter.setInsertionPointAfter(ifOp);
    }

    // Set PtrMemRef0[1] = nse, so the single top-level "pointer" segment
    // spans all stored entries.
    Value c1 = constantIndex(rewriter, loc, 1);
    Value ptrMemref0 = desc.getPtrMemRef(0);
    Type ptrEleTy = getMemRefType(ptrMemref0).getElementType();
    Value ptrNse =
        ptrEleTy == nse.getType()
            ? nse
            : rewriter.create<arith::IndexCastOp>(loc, ptrEleTy, nse);
    rewriter.create<memref::StoreOp>(loc, ptrNse, ptrMemref0, c1);

    // Update the storage specifier with the actual used sizes: the AOS index
    // buffer holds dimRank indices per entry, the value buffer nse values.
    Value idxSize = rewriter.create<arith::MulIOp>(
        loc, nse, constantIndex(rewriter, loc, dimRank));
    desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::IdxMemSize, 0,
                           idxSize);
    desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::ValMemSize,
                           std::nullopt, nse);

    // Release the sparse tensor reader.
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);

    // Replace operation with resulting memrefs.
    rewriter.replaceOp(op, genTuple(rewriter, loc, dstTp, fields));
    return success();
  }
};
|
||||
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -1308,8 +1439,8 @@ void mlir::populateSparseTensorCodegenPatterns(
|
||||
SparseInsertConverter, SparseToPointersConverter,
|
||||
SparseToIndicesConverter, SparseToIndicesBufferConverter,
|
||||
SparseToValuesConverter, SparseConvertConverter,
|
||||
SparseNumberOfEntriesConverter>(typeConverter,
|
||||
patterns.getContext());
|
||||
SparseNewOpConverter, SparseNumberOfEntriesConverter>(
|
||||
typeConverter, patterns.getContext());
|
||||
patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
|
||||
enableBufferInitialization);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user