[mlir][sparse] Support packing external data into arbitrary sparse tensor encoding.

We previously only supported packing two arrays (values and coordinates) into COO tensors.
This patch allows packing inputs into an arbitrary sparse tensor format.

It also deletes the "implicit" data canonicalization performed inside the sparse compiler;
users are now required to canonicalize the data before passing it to the sparse compiler.
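
For illustration (an added sketch, not part of the original change): for a
CSR-encoded 3x4 tensor with entries (0,1)=1.0, (1,0)=2.0, (1,3)=3.0, the
canonicalized buffers (entries sorted by coordinates, no duplicates) would
look like the following C++ arrays; the variable names are hypothetical:

  double  values[]    = {1.0, 2.0, 3.0}; // sorted by (row, col), no duplicates
  int64_t positions[] = {0, 1, 3, 3};    // row i owns entries [positions[i], positions[i+1])
  int64_t coords[]    = {1, 0, 3};       // column coordinate of each entry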

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D150916
Peiming Liu
2023-05-16 22:16:21 +00:00
parent fe69bb6441
commit de56088866
9 changed files with 290 additions and 358 deletions


@@ -1214,192 +1214,94 @@ public:
matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Query memSizes for the actually stored values.
// FIXME: the nse value computed in this way might be wrong when there is
// any "compressed-hi" level.
rewriter.replaceOp(
op, genValMemSize(rewriter, op.getLoc(), adaptor.getTensor()));
return success();
}
};
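// Added note: with a "compressed-hi" level, each batch keeps its own
// (lo, hi) position pair and its slots in the value buffer may be only
// partially used, so the value memSize can over-count the stored entries.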
static void populateCompressedWithHiPosArray(OpBuilder &builder, Location loc,
ArrayRef<unsigned> batchDimSzs,
Value posMemRef, unsigned nse,
PackOp op) {
SmallVector<Value> lbs, ubs, steps;
Value c0 = constantIndex(builder, loc, 0);
Value c1 = constantIndex(builder, loc, 1);
Value c2 = constantIndex(builder, loc, 2);
for (unsigned dimSz : batchDimSzs) {
lbs.push_back(c0);
ubs.push_back(constantIndex(builder, loc, dimSz));
steps.push_back(c1);
}
auto tensorType = op.getValues().getType();
auto memrefType =
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
Value batV = builder.create<bufferization::ToMemrefOp>(loc, memrefType,
op.getValues());
scf::buildLoopNest(
builder, loc, lbs, ubs, steps,
[&ubs, c0, c1, c2, nse, batV, posMemRef](OpBuilder &builder, Location loc,
ValueRange ivs) {
// Linearize index variables
Value crd = linearize(builder, loc, ivs, ubs);
Value len = constantIndex(builder, loc, nse);
Value pLo = builder.create<arith::MulIOp>(loc, crd, len);
SmallVector<Value> indices(ivs.begin(), ivs.end());
auto whileOp = builder.create<scf::WhileOp>(
loc, TypeRange{builder.getIndexType()}, ValueRange{len},
[&indices, c0, c1, batV](OpBuilder &builder, Location loc,
ValueRange vs) {
Value curLen = vs.front();
Value pred = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, curLen, c0);
auto ifOp = builder.create<scf::IfOp>(
loc, TypeRange{builder.getI1Type()}, pred, true);
{
OpBuilder::InsertionGuard guard(builder);
// if len == 0.
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
builder.create<scf::YieldOp>(loc,
constantI1(builder, loc, false));
// Else branch.
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
indices.push_back(
builder.create<arith::SubIOp>(loc, curLen, c1));
Value val = builder.create<memref::LoadOp>(loc, batV, indices);
indices.pop_back();
Value cont = builder.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::OEQ, val,
constantZero(builder, loc, val.getType()));
builder.create<scf::YieldOp>(loc, cont);
}
builder.create<scf::ConditionOp>(loc, ifOp.getResults()[0], vs);
},
[c1](OpBuilder &builder, Location loc, ValueRange vs) {
// len --;
Value nxLen = builder.create<arith::SubIOp>(loc, vs.front(), c1);
builder.create<scf::YieldOp>(loc, nxLen);
});
len = whileOp.getResults()[0];
Value pHi = builder.create<arith::AddIOp>(loc, pLo, len);
// Stores position lower bound.
Value idx = builder.create<arith::MulIOp>(loc, crd, c2);
genStore(builder, loc, pLo, posMemRef, idx);
// Stores position upper bound.
idx = builder.create<arith::AddIOp>(loc, idx, c1);
genStore(builder, loc, pHi, posMemRef, idx);
});
}
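// Illustrative example (added comment): with batchDimSzs = {2}, nse = 4, and
// values [[1.0, 2.0, 0.0, 0.0], [3.0, 0.0, 0.0, 0.0]], the loop nest trims the
// trailing zeros of each batch and stores posMemRef = [0, 2, 4, 5], i.e. the
// (pLo, pHi) pair of each batch at indices crd * 2 and crd * 2 + 1.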
struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(PackOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const unsigned batchedLvls = op.getNumBatchedLvls();
unsigned nse = op.getValues().getType().getDimSize(batchedLvls);
Location loc = op.getLoc();
const auto stt = getSparseTensorType(op.getResult());
assert(isCOOType(stt.getEncoding(), batchedLvls, true));
unsigned batchedCount = 1;
SmallVector<unsigned> batchDimSzs;
batchDimSzs.reserve(batchedLvls);
for (unsigned i = 0; i < batchedLvls; i++) {
// Should already be guaranteed by verifier.
assert(!ShapedType::isDynamic(stt.getDimShape()[i]));
batchedCount *= stt.getDimShape()[i];
batchDimSzs.push_back(stt.getDimShape()[i]);
}
SmallVector<Value> fields;
Location loc = op.getLoc();
foreachFieldAndTypeInSparseTensor(
stt,
[&rewriter, &fields, &op, &batchDimSzs, nse, batchedCount, stt,
[&rewriter, &fields, &op, &stt,
loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
Level /*lvl*/, DimLevelType dlt) -> bool {
assert(fields.size() == fIdx);
Value field;
switch (fKind) {
case SparseTensorFieldKind::StorageSpec:
field = SparseTensorSpecifier::getInitValue(rewriter, loc, stt);
break;
case SparseTensorFieldKind::PosMemRef: {
// TACO-style COO starts with a PosBuffer
const auto posTp = stt.getPosType();
if (isCompressedDLT(dlt)) {
auto memrefType = MemRefType::get({batchedCount + 1}, posTp);
field = rewriter.create<memref::AllocOp>(loc, memrefType);
Value c0 = constantIndex(rewriter, loc, 0);
genStore(rewriter, loc, c0, field, c0);
for (unsigned i = 1; i <= batchedCount; i++) {
// The position memref will hold the values
// [0, nse, 2 * nse, ..., batchedCount * nse].
Value idx = constantIndex(rewriter, loc, i);
Value val = constantIndex(rewriter, loc, nse * i);
genStore(rewriter, loc, val, field, idx);
}
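// For example (added illustrative comment), with batchedCount = 3 and
// nse = 4, the stores above produce the position array [0, 4, 8, 12].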
} else {
assert(isCompressedWithHiDLT(dlt) && !batchDimSzs.empty());
MemRefType posMemTp = MemRefType::get({batchedCount * 2}, posTp);
field = rewriter.create<memref::AllocOp>(loc, posMemTp);
populateCompressedWithHiPosArray(rewriter, loc, batchDimSzs,
field, nse, op);
if (fKind == SparseTensorFieldKind::StorageSpec) {
fields.push_back(
SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
} else {
// Else simply takes the inputs.
Value field = fKind == SparseTensorFieldKind::ValMemRef
? op.getValues()
: op.getLevels()[fIdx];
auto tensorType = field.getType().cast<RankedTensorType>();
auto memrefType = MemRefType::get(tensorType.getShape(),
tensorType.getElementType());
field = rewriter.create<bufferization::ToMemrefOp>(
op->getLoc(), memrefType, field);
if (memrefType.getRank() > 1) {
// Flattens the buffer to rank 1.
auto reassoc = getReassociationForFlattening(memrefType);
field =
rewriter.create<memref::CollapseShapeOp>(loc, field, reassoc);
}
break;
}
case SparseTensorFieldKind::CrdMemRef: {
auto tensorType = op.getCoordinates().getType();
auto memrefType = MemRefType::get(tensorType.getShape(),
tensorType.getElementType());
field = rewriter.create<bufferization::ToMemrefOp>(
op->getLoc(), memrefType, op.getCoordinates());
break;
}
case SparseTensorFieldKind::ValMemRef: {
auto tensorType = op.getValues().getType();
auto memrefType = MemRefType::get(tensorType.getShape(),
tensorType.getElementType());
field = rewriter.create<bufferization::ToMemrefOp>(
op->getLoc(), memrefType, op.getValues());
break;
}
}
assert(field);
if (auto memrefTp = dyn_cast<MemRefType>(field.getType());
memrefTp && memrefTp.getRank() > 1) {
ReassociationIndices reassociation;
for (int i = 0, e = memrefTp.getRank(); i < e; i++)
reassociation.push_back(i);
// Flattens the buffer to rank 1. The value buffer might need to be
// collapsed as well due to batching.
field = rewriter.create<memref::CollapseShapeOp>(
loc, field, ArrayRef<ReassociationIndices>(reassociation));
}
if (fType != field.getType())
field = rewriter.create<memref::CastOp>(loc, fType, field);
fields.push_back(field);
// Returns true to continue the iteration.
fields.push_back(field);
}
return true;
});
MutSparseTensorDescriptor desc(stt, fields);
auto noe = linalg::createOrFoldDimOp(rewriter, loc, op.getValues(), 0);
Value c1 = constantIndex(rewriter, loc, 1);
Value c2 = constantIndex(rewriter, loc, 2);
Value posBack = c1; // index to the last value in the position array
Value memSize = c2; // memory size of the current array
// Sets up SparseTensorSpecifier.
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
// FIXME: dim/lvl confusion!
const auto sh = stt.getDimShape()[lvl];
assert(!ShapedType::isDynamic(sh));
desc.setLvlSize(rewriter, loc, lvl, constantIndex(rewriter, loc, sh));
if (lvl == 0)
desc.setPosMemSize(rewriter, loc, lvl, constantIndex(rewriter, loc, 2));
assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));
desc.setCrdMemSize(rewriter, loc, lvl, noe);
// FIXME: dim/lvl confusion!
// Sets up the level size.
auto lvlSize = constantIndex(rewriter, loc, stt.getDimShape()[lvl]);
desc.setLvlSize(rewriter, loc, lvl, lvlSize);
// Sets up the memory size by reading the last value in the position array.
DimLevelType dlt = stt.getLvlType(lvl);
// Simply forwards the position index when this is a dense level.
if (isDenseDLT(dlt)) {
memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, posBack);
posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
continue;
}
if (isDLTWithPos(dlt)) {
assert(isCompressedDLT(dlt) || isCompressedWithHiDLT(dlt));
if (isCompressedWithHiDLT(dlt)) {
memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2);
posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
}
desc.setPosMemSize(rewriter, loc, lvl, memSize);
// The last value in the position array is the memory size for the next level.
memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), posBack);
posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
}
assert(isDLTWithCrd(dlt));
desc.setCrdMemSize(rewriter, loc, lvl, memSize);
}
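// Worked example (added comment, following the logic above): for a CSR tensor
// whose position buffer is [0, 2, 5], the last position value (5) becomes the
// coordinate memory size of the compressed level and, after the loop, the
// value memory size.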
desc.setValMemSize(rewriter, loc, noe);
desc.setValMemSize(rewriter, loc, memSize);
rewriter.replaceOp(op, genTuple(rewriter, loc, desc));
return success();
@@ -1568,10 +1470,8 @@ static LogicalResult genBatchedUnpackOp(UnpackOp op, unsigned nBatched,
struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
using OpConversionPattern::OpConversionPattern;
SparseUnpackOpConverter(TypeConverter &typeConverter, MLIRContext *context,
bool createDeallocs)
: OpConversionPattern(typeConverter, context),
createDeallocs(createDeallocs) {}
SparseUnpackOpConverter(TypeConverter &typeConverter, MLIRContext *context)
: OpConversionPattern(typeConverter, context) {}
LogicalResult
matchAndRewrite(UnpackOp op, OpAdaptor adaptor,
@@ -1582,26 +1482,9 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
assert(isCOOType(srcTp.getEncoding(), nBatched, true) &&
desc.getFields().size() == 4); // specifier + pos + crds + values
(void)srcTp;
auto logicRes = nBatched == 0
? genUnBatchedUnpackOp(op, desc, rewriter)
: genBatchedUnpackOp(op, nBatched, desc, rewriter);
Value posBuf = desc.getPosMemRef(nBatched);
if (createDeallocs) {
// Unpack ends the lifetime of the sparse tensor. While the value array
// and coordinate array are unpacked and returned, the position array
// becomes useless and needs to be freed (if the user requests it).
// FIXME: Depending on whether the tensor being unpacked was created by
// PackOp or not, we may or may not need to free other memref fields of
// the sparse tensor too (PackOp borrows the value/coordinate buffers).
rewriter.create<memref::DeallocOp>(op.getLoc(), posBuf);
}
return logicRes;
return nBatched == 0 ? genUnBatchedUnpackOp(op, desc, rewriter)
: genBatchedUnpackOp(op, nBatched, desc, rewriter);
}
private:
const bool createDeallocs;
};
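// Added note: since PackOp now simply borrows the user-provided value and
// level buffers, UnpackOp no longer frees the position buffer here; freeing
// the buffers (if needed) is left to the caller.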
struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
@@ -1755,11 +1638,11 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
void mlir::populateSparseTensorCodegenPatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns,
bool createSparseDeallocs, bool enableBufferInitialization) {
patterns.add<SparsePackOpConverter, SparseReturnConverter,
SparseCallConverter, SparseDimOpConverter, SparseCastConverter,
SparseExtractSliceConverter, SparseTensorLoadConverter,
SparseExpandConverter, SparseCompressConverter,
SparseInsertConverter,
patterns.add<SparsePackOpConverter, SparseUnpackOpConverter,
SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
SparseCastConverter, SparseExtractSliceConverter,
SparseTensorLoadConverter, SparseExpandConverter,
SparseCompressConverter, SparseInsertConverter,
SparseSliceGetterOpConverter<ToSliceOffsetOp,
StorageSpecifierKind::DimOffset>,
SparseSliceGetterOpConverter<ToSliceStrideOp,
@@ -1769,7 +1652,7 @@ void mlir::populateSparseTensorCodegenPatterns(
SparseConvertConverter, SparseNewOpConverter,
SparseNumberOfEntriesConverter>(typeConverter,
patterns.getContext());
patterns.add<SparseTensorDeallocConverter, SparseUnpackOpConverter>(
patterns.add<SparseTensorDeallocConverter>(
typeConverter, patterns.getContext(), createSparseDeallocs);
patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
enableBufferInitialization);