[mlir][sparse] Improve the implementation of sparse_tensor.new for the codegen path.

Rewrite a NewOp into a NewOp of a sorted COO tensor and a ConvertOp for
converting the sorted COO tensor to the destination tensor type.

Generate code for a NewOp of a sorted COO tensor using the new bulk reader API,
sorting the elements only when the input is not already sorted.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D144504
This commit is contained in:
bixia1
2023-02-28 12:52:16 -08:00
parent c888a0ce88
commit 2c81d43241
4 changed files with 252 additions and 129 deletions

View File

@@ -1289,6 +1289,137 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
}
};
/// Sparse codegen rule for the sparse_tensor.new operator when the result is
/// a sorted COO tensor: reads the file through the bulk reader API directly
/// into the COO storage buffers, sorting only when the input was not sorted.
struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    const auto encDst = dstTp.getEncoding();
    // Creating COO with NewOp is handled by direct IR codegen. All other cases
    // are handled by rewriting.
    if (!dstTp.hasEncoding() || getCOOStart(encDst) != 0)
      return failure();

    // Implement the NewOp(filename) as follows:
    //   reader = getSparseTensorReader(filename)
    //   nse = getSparseTensorNNZ()
    //   tmp = bufferization.alloc_tensor an ordered COO with
    //         dst dim ordering, size_hint = nse
    //   indices = to_indices_buffer(tmp)
    //   values = to_values(tmp)
    //   isSorted = getSparseTensorReaderRead(indices, values, dimOrdering)
    //   if (!isSorted) sort_coo(nse, indices, values)
    //   update storage specifier
    //   dst = sparse_tensor.ConvertOp tmp

    // Create a sparse tensor reader.
    Value fileName = op.getSource();
    Type opaqueTp = getOpaquePointerType(rewriter);
    Value reader = createFuncCall(rewriter, loc, "createSparseTensorReader",
                                  {opaqueTp}, {fileName}, EmitCInterface::Off)
                       .getResult(0);

    Type indexTp = rewriter.getIndexType();
    const Dimension dimRank = dstTp.getDimRank();

    // If the result tensor has dynamic dimensions, get the dynamic sizes from
    // the sparse tensor reader.
    SmallVector<Value> dynSizes;
    if (dstTp.hasDynamicDimShape()) {
      Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp);
      // Note: this runtime call returns no results, so no `getResult(0)` here
      // (indexing a result on a zero-result call would assert).
      createFuncCall(rewriter, loc, "copySparseTensorReaderDimSizes", {},
                     {reader, dimSizes}, EmitCInterface::On);
      ArrayRef<int64_t> dstShape = dstTp.getRankedTensorType().getShape();
      for (auto &d : llvm::enumerate(dstShape)) {
        if (d.value() == ShapedType::kDynamic) {
          dynSizes.push_back(rewriter.create<memref::LoadOp>(
              loc, dimSizes, constantIndex(rewriter, loc, d.index())));
        }
      }
    }

    Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNNZ",
                               {indexTp}, {reader}, EmitCInterface::Off)
                    .getResult(0);
    // Construct allocation for each field.
    SmallVector<Value> fields;
    createAllocFields(rewriter, loc, dstTp, dynSizes, /*enableInit=*/false,
                      fields, nse);
    MutSparseTensorDescriptor desc(dstTp, fields);

    // Read the COO tensor data.
    Type eltTp = dstTp.getElementType();
    Type indBufEleTp = getIndexOverheadType(rewriter, encDst);
    SmallString<32> getReadFuncName{"getSparseTensorReaderRead",
                                    overheadTypeFunctionSuffix(indBufEleTp),
                                    primaryTypeFunctionSuffix(eltTp)};

    Value xs = desc.getAOSMemRef();
    Value ys = desc.getValMemRef();
    // Build the dim-to-level mapping buffer; identity when no dim ordering.
    SmallVector<Value> dim2lvlValues(dimRank, Value());
    if (auto dimOrder = encDst.getDimOrdering()) {
      assert(dimOrder.isPermutation() && "Got non-permutation");
      for (uint64_t l = 0; l < dimRank; l++) {
        uint64_t d = dimOrder.getDimPosition(l);
        dim2lvlValues[d] = constantIndex(rewriter, loc, l);
      }
    } else {
      for (uint64_t l = 0; l < dimRank; l++)
        dim2lvlValues[l] = constantIndex(rewriter, loc, l);
    }
    Value dim2lvl = allocaBuffer(rewriter, loc, dim2lvlValues);

    Value f = constantI1(rewriter, loc, false);
    Value isSorted =
        createFuncCall(rewriter, loc, getReadFuncName, {f.getType()},
                       {reader, dim2lvl, xs, ys}, EmitCInterface::On)
            .getResult(0);

    // If the destination tensor is a sorted COO, we need to sort the COO tensor
    // data if the input elements aren't sorted yet.
    if (encDst.isOrderedLvl(dimRank - 1)) {
      Value notSorted = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, isSorted, f);
      scf::IfOp ifOp =
          rewriter.create<scf::IfOp>(loc, notSorted, /*else*/ false);
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      rewriter.create<SortCooOp>(
          loc, nse, xs, ValueRange{ys}, rewriter.getIndexAttr(dimRank),
          rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
      rewriter.setInsertionPointAfter(ifOp);
    }

    // Set PtrMemRef0[1] = nse.
    Value c1 = constantIndex(rewriter, loc, 1);
    Value ptrMemref0 = desc.getPtrMemRef(0);
    Type ptrEleTy = getMemRefType(ptrMemref0).getElementType();
    // Cast nse to the pointer element type when they differ.
    Value ptrNse =
        ptrEleTy == nse.getType()
            ? nse
            : rewriter.create<arith::IndexCastOp>(loc, ptrEleTy, nse);
    rewriter.create<memref::StoreOp>(loc, ptrNse, ptrMemref0, c1);

    // Update storage specifier: AOS index buffer holds dimRank*nse entries.
    Value idxSize = rewriter.create<arith::MulIOp>(
        loc, nse, constantIndex(rewriter, loc, dimRank));
    desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::IdxMemSize, 0,
                           idxSize);
    desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::ValMemSize,
                           std::nullopt, nse);

    // Release the sparse tensor reader.
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);

    // Replace operation with resulting memrefs.
    rewriter.replaceOp(op, genTuple(rewriter, loc, dstTp, fields));
    return success();
  }
};
} // namespace
//===----------------------------------------------------------------------===//
@@ -1308,8 +1439,8 @@ void mlir::populateSparseTensorCodegenPatterns(
SparseInsertConverter, SparseToPointersConverter,
SparseToIndicesConverter, SparseToIndicesBufferConverter,
SparseToValuesConverter, SparseConvertConverter,
SparseNumberOfEntriesConverter>(typeConverter,
patterns.getContext());
SparseNewOpConverter, SparseNumberOfEntriesConverter>(
typeConverter, patterns.getContext());
patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
enableBufferInitialization);
}