[mlir][sparse] Implementing sparse=>dense conversion.
Depends On: D110882, D110883, D110884; Reviewed By: aartbik; Differential Revision: https://reviews.llvm.org/D110790
This commit is contained in:
@@ -36,7 +36,8 @@ enum Action : uint32_t {
|
||||
kFromFile = 1,
|
||||
kFromCOO = 2,
|
||||
kEmptyCOO = 3,
|
||||
kToCOO = 4
|
||||
kToCOO = 4,
|
||||
kToIter = 5
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -202,7 +203,9 @@ static void sizesFromPtr(ConversionPatternRewriter &rewriter,
|
||||
sizes.push_back(constantIndex(rewriter, op->getLoc(), shape[i]));
|
||||
}
|
||||
|
||||
/// Generates a temporary buffer of the given size and type.
|
||||
/// Generates an uninitialized temporary buffer of the given size and
|
||||
/// type, but returns it as type `memref<? x $tp>` (rather than as type
|
||||
/// `memref<$sz x $tp>`).
|
||||
static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
|
||||
unsigned sz, Type tp) {
|
||||
auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
|
||||
@@ -210,6 +213,13 @@ static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
|
||||
return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{a});
|
||||
}
|
||||
|
||||
/// Generates an uninitialized temporary buffer with room for one value
|
||||
/// of the given type, and returns the `memref<$tp>`.
|
||||
static Value genAllocaScalar(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Type tp) {
|
||||
return rewriter.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
|
||||
}
|
||||
|
||||
/// Generates a temporary buffer of the given type and given contents.
|
||||
static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
|
||||
ArrayRef<Value> values) {
|
||||
@@ -345,6 +355,39 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
|
||||
rewriter.create<CallOp>(loc, pTp, fn, params);
|
||||
}
|
||||
|
||||
/// Generates a call to `iter->getNext()`. If there is a next element,
|
||||
/// then it is copied into the out-parameters `ind` and `elemPtr`,
|
||||
/// and the return value is true. If there isn't a next element, then
|
||||
/// the return value is false.
|
||||
static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
|
||||
Value iter, Value ind, Value elemPtr) {
|
||||
Location loc = op->getLoc();
|
||||
Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
|
||||
StringRef name;
|
||||
if (elemTp.isF64())
|
||||
name = "getNextF64";
|
||||
else if (elemTp.isF32())
|
||||
name = "getNextF32";
|
||||
else if (elemTp.isInteger(64))
|
||||
name = "getNextI64";
|
||||
else if (elemTp.isInteger(32))
|
||||
name = "getNextI32";
|
||||
else if (elemTp.isInteger(16))
|
||||
name = "getNextI16";
|
||||
else if (elemTp.isInteger(8))
|
||||
name = "getNextI8";
|
||||
else
|
||||
llvm_unreachable("Unknown element type");
|
||||
SmallVector<Value, 3> params;
|
||||
params.push_back(iter);
|
||||
params.push_back(ind);
|
||||
params.push_back(elemPtr);
|
||||
Type i1 = rewriter.getI1Type();
|
||||
auto fn = getFunc(op, name, i1, params, /*emitCInterface=*/true);
|
||||
auto call = rewriter.create<CallOp>(loc, i1, fn, params);
|
||||
return call.getResult(0);
|
||||
}
|
||||
|
||||
/// If the tensor is a sparse constant, generates and returns the pair of
|
||||
/// the constants for the indices and the values.
|
||||
static Optional<std::pair<Value, Value>>
|
||||
@@ -379,6 +422,37 @@ static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
|
||||
return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
|
||||
}
|
||||
|
||||
/// Generates code to allocate a tensor of the given type, and zero
|
||||
/// initialize it. This function assumes the TensorType is fully
|
||||
/// specified (i.e., has static rank and sizes).
|
||||
// TODO(D112674): support dynamic sizes.
|
||||
static Value allocDenseTensor(ConversionPatternRewriter &rewriter, Location loc,
|
||||
RankedTensorType tensorTp) {
|
||||
Type elemTp = tensorTp.getElementType();
|
||||
auto memTp = MemRefType::get(tensorTp.getShape(), elemTp);
|
||||
Value mem = rewriter.create<memref::AllocOp>(loc, memTp);
|
||||
Value zero = constantZero(rewriter, loc, elemTp);
|
||||
rewriter.create<linalg::FillOp>(loc, zero, mem).result();
|
||||
return mem;
|
||||
}
|
||||
|
||||
/// Inserts the element returned by genGetNextCall(_, ind, elemPtr) into
|
||||
/// the tensor created by allocDenseTensor(). The `rank` is the rank
|
||||
/// of the `tensor` and the length of `ind`.
|
||||
static void insertScalarIntoDenseTensor(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value elemPtr,
|
||||
Value tensor, unsigned rank,
|
||||
Value ind) {
|
||||
SmallVector<Value, 4> ivs;
|
||||
ivs.reserve(rank);
|
||||
for (unsigned i = 0; i < rank; i++) {
|
||||
Value idx = constantIndex(rewriter, loc, i);
|
||||
ivs.push_back(rewriter.create<memref::LoadOp>(loc, ind, idx));
|
||||
}
|
||||
Value elemV = rewriter.create<memref::LoadOp>(loc, elemPtr);
|
||||
rewriter.create<memref::StoreOp>(loc, elemV, tensor, ivs);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Conversion rules.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -509,8 +583,49 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
|
||||
rewriter.replaceOp(op, genNewCall(rewriter, op, params));
|
||||
return success();
|
||||
}
|
||||
if (!encDst || encSrc) {
|
||||
// TODO: sparse => dense
|
||||
if (!encDst && encSrc) {
|
||||
// This is sparse => dense conversion, which is handled as follows:
|
||||
// dst = new Tensor(0);
|
||||
// iter = src->toCOO()->getIterator();
|
||||
// while (elem = iter->getNext()) {
|
||||
// dst[elem.indices] = elem.value;
|
||||
// }
|
||||
Location loc = op->getLoc();
|
||||
RankedTensorType tensorTp = resType.dyn_cast<RankedTensorType>();
|
||||
if (!tensorTp)
|
||||
return failure();
|
||||
unsigned rank = tensorTp.getRank();
|
||||
Value dst = allocDenseTensor(rewriter, loc, tensorTp);
|
||||
Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
|
||||
Value elemPtr = genAllocaScalar(rewriter, loc, tensorTp.getElementType());
|
||||
encDst = SparseTensorEncodingAttr::get(
|
||||
op->getContext(),
|
||||
SmallVector<SparseTensorEncodingAttr::DimLevelType>(
|
||||
rank, SparseTensorEncodingAttr::DimLevelType::Dense),
|
||||
AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
|
||||
SmallVector<Value, 4> sizes;
|
||||
SmallVector<Value, 8> params;
|
||||
// TODO(D112674): support dynamic sizes.
|
||||
sizesFromType(rewriter, sizes, loc, tensorTp);
|
||||
newParams(rewriter, params, op, encDst, kToIter, sizes, src);
|
||||
Value iter = genNewCall(rewriter, op, params);
|
||||
SmallVector<Value> noArgs;
|
||||
SmallVector<Type> noTypes;
|
||||
auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
|
||||
Block *before = rewriter.createBlock(&whileOp.before(), {}, noTypes);
|
||||
rewriter.setInsertionPointToEnd(before);
|
||||
Value cond = genGetNextCall(rewriter, op, iter, ind, elemPtr);
|
||||
rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
|
||||
Block *after = rewriter.createBlock(&whileOp.after(), {}, noTypes);
|
||||
rewriter.setInsertionPointToStart(after);
|
||||
insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, rank, ind);
|
||||
rewriter.create<scf::YieldOp>(loc);
|
||||
rewriter.setInsertionPointAfter(whileOp);
|
||||
rewriter.replaceOpWithNewOp<memref::TensorLoadOp>(op, resType, dst);
|
||||
return success();
|
||||
}
|
||||
if (!encDst && !encSrc) {
|
||||
// dense => dense
|
||||
return failure();
|
||||
}
|
||||
// This is a dense => sparse conversion or a sparse constant in COO =>
|
||||
|
||||
Reference in New Issue
Block a user