[mlir][sparse] Implementing sparse=>dense conversion.

Depends On D110882, D110883, D110884

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D110790
This commit is contained in:
wren romano
2021-10-28 13:59:01 -07:00
parent a94b721d26
commit 28882b6575
6 changed files with 501 additions and 21 deletions

View File

@@ -36,7 +36,8 @@ enum Action : uint32_t {
kFromFile = 1,
kFromCOO = 2,
kEmptyCOO = 3,
kToCOO = 4
kToCOO = 4,
kToIter = 5
};
//===----------------------------------------------------------------------===//
@@ -202,7 +203,9 @@ static void sizesFromPtr(ConversionPatternRewriter &rewriter,
sizes.push_back(constantIndex(rewriter, op->getLoc(), shape[i]));
}
/// Generates a temporary buffer of the given size and type.
/// Generates an uninitialized temporary buffer of the given size and
/// type, but returns it as type `memref<? x $tp>` (rather than as type
/// `memref<$sz x $tp>`).
static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
unsigned sz, Type tp) {
auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
@@ -210,6 +213,13 @@ static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{a});
}
/// Emits a `memref.alloca` for a rank-0 buffer that can hold exactly one
/// value of type `tp`. The buffer contents are left uninitialized, and
/// the returned value has type `memref<$tp>`.
static Value genAllocaScalar(ConversionPatternRewriter &rewriter, Location loc,
                             Type tp) {
  // An empty shape yields the rank-0 (scalar) memref type.
  auto scalarTp = MemRefType::get({}, tp);
  Value buffer = rewriter.create<memref::AllocaOp>(loc, scalarTp);
  return buffer;
}
/// Generates a temporary buffer of the given type and given contents.
static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> values) {
@@ -345,6 +355,39 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
rewriter.create<CallOp>(loc, pTp, fn, params);
}
/// Emits a call to the `getNext*` function (i.e., `iter->getNext()`) whose
/// suffix matches the element type of `elemPtr`. The call receives `iter`,
/// `ind`, and `elemPtr` as arguments, where `ind` and `elemPtr` act as
/// out-parameters: when a next element exists it is copied into them and
/// the returned `i1` value is true; otherwise the returned value is false.
/// Only f64, f32, i64, i32, i16, and i8 element types are supported.
static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
                            Value iter, Value ind, Value elemPtr) {
  // Maps a supported element type onto the name of its `getNext*` function.
  auto selectName = [](Type tp) -> StringRef {
    if (tp.isF64())
      return "getNextF64";
    if (tp.isF32())
      return "getNextF32";
    if (tp.isInteger(64))
      return "getNextI64";
    if (tp.isInteger(32))
      return "getNextI32";
    if (tp.isInteger(16))
      return "getNextI16";
    if (tp.isInteger(8))
      return "getNextI8";
    llvm_unreachable("Unknown element type");
  };
  StringRef fnName =
      selectName(elemPtr.getType().cast<ShapedType>().getElementType());
  SmallVector<Value, 3> args{iter, ind, elemPtr};
  Type boolTp = rewriter.getI1Type();
  auto callee = getFunc(op, fnName, boolTp, args, /*emitCInterface=*/true);
  auto callOp = rewriter.create<CallOp>(op->getLoc(), boolTp, callee, args);
  return callOp.getResult(0);
}
/// If the tensor is a sparse constant, generates and returns the pair of
/// the constants for the indices and the values.
static Optional<std::pair<Value, Value>>
@@ -379,6 +422,37 @@ static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
}
/// Generates code to allocate a buffer for a dense tensor of the given
/// type, and zero initialize it. The buffer is created with `memref.alloc`
/// using the shape and element type of `tensorTp`, then filled with the
/// zero constant of the element type via `linalg.fill`. The returned value
/// is the memref itself (not a tensor). This function assumes the
/// TensorType is fully specified (i.e., has static rank and sizes), since
/// no dynamic size operands are passed to the allocation.
/// NOTE(review): no matching deallocation is emitted here; presumably the
/// caller (or a later pass) is responsible for the buffer -- confirm.
// TODO(D112674): support dynamic sizes.
static Value allocDenseTensor(ConversionPatternRewriter &rewriter, Location loc,
                              RankedTensorType tensorTp) {
  Type elemTp = tensorTp.getElementType();
  auto memTp = MemRefType::get(tensorTp.getShape(), elemTp);
  Value mem = rewriter.create<memref::AllocOp>(loc, memTp);
  Value zero = constantZero(rewriter, loc, elemTp);
  // The fill writes into `mem` in place; it is created only for that side
  // effect, so no result of the op is queried.
  rewriter.create<linalg::FillOp>(loc, zero, mem);
  return mem;
}
/// Emits the code that stores one element into the dense buffer produced
/// by allocDenseTensor(). The element is the one most recently written by
/// genGetNextCall(_, ind, elemPtr): its value is loaded from `elemPtr` and
/// its coordinates are loaded from the first `rank` entries of `ind`.
/// The `rank` must equal both the rank of `tensor` and the length of `ind`.
static void insertScalarIntoDenseTensor(ConversionPatternRewriter &rewriter,
                                        Location loc, Value elemPtr,
                                        Value tensor, unsigned rank,
                                        Value ind) {
  // Read each coordinate out of `ind` at a constant position.
  SmallVector<Value, 4> coords(rank);
  for (unsigned d = 0; d != rank; ++d) {
    Value pos = constantIndex(rewriter, loc, d);
    coords[d] = rewriter.create<memref::LoadOp>(loc, ind, pos);
  }
  // Read the scalar from the rank-0 buffer and write it at those coordinates.
  Value scalar = rewriter.create<memref::LoadOp>(loc, elemPtr);
  rewriter.create<memref::StoreOp>(loc, scalar, tensor, coords);
}
//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//
@@ -509,8 +583,49 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
rewriter.replaceOp(op, genNewCall(rewriter, op, params));
return success();
}
if (!encDst || encSrc) {
// TODO: sparse => dense
if (!encDst && encSrc) {
// This is sparse => dense conversion, which is handled as follows:
// dst = new Tensor(0);
// iter = src->toCOO()->getIterator();
// while (elem = iter->getNext()) {
// dst[elem.indices] = elem.value;
// }
Location loc = op->getLoc();
RankedTensorType tensorTp = resType.dyn_cast<RankedTensorType>();
if (!tensorTp)
return failure();
unsigned rank = tensorTp.getRank();
Value dst = allocDenseTensor(rewriter, loc, tensorTp);
Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
Value elemPtr = genAllocaScalar(rewriter, loc, tensorTp.getElementType());
encDst = SparseTensorEncodingAttr::get(
op->getContext(),
SmallVector<SparseTensorEncodingAttr::DimLevelType>(
rank, SparseTensorEncodingAttr::DimLevelType::Dense),
AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
SmallVector<Value, 4> sizes;
SmallVector<Value, 8> params;
// TODO(D112674): support dynamic sizes.
sizesFromType(rewriter, sizes, loc, tensorTp);
newParams(rewriter, params, op, encDst, kToIter, sizes, src);
Value iter = genNewCall(rewriter, op, params);
SmallVector<Value> noArgs;
SmallVector<Type> noTypes;
auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
Block *before = rewriter.createBlock(&whileOp.before(), {}, noTypes);
rewriter.setInsertionPointToEnd(before);
Value cond = genGetNextCall(rewriter, op, iter, ind, elemPtr);
rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
Block *after = rewriter.createBlock(&whileOp.after(), {}, noTypes);
rewriter.setInsertionPointToStart(after);
insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, rank, ind);
rewriter.create<scf::YieldOp>(loc);
rewriter.setInsertionPointAfter(whileOp);
rewriter.replaceOpWithNewOp<memref::TensorLoadOp>(op, resType, dst);
return success();
}
if (!encDst && !encSrc) {
// dense => dense
return failure();
}
// This is a dense => sparse conversion or a sparse constant in COO =>