[mlir][sparse] unify support of (dis)assemble between direct IR/lib path (#71880)

Note that the (dis)assemble operations still make some simplifying
assumptions (e.g. trailing 2-D COO in AoS format) but now at least both
the direct IR and support library path behave exactly the same.

Generalizing the ops is still TBD.
This commit is contained in:
Aart Bik
2023-11-13 10:05:00 -08:00
committed by GitHub
parent 5fdb70be7b
commit af8428c0d9
9 changed files with 312 additions and 297 deletions

View File

@@ -46,17 +46,6 @@ static std::optional<Type> convertSparseTensorTypes(Type type) {
return std::nullopt;
}
/// Swaps `op` for a `func::CallOp` that targets the function reference
/// obtained from `getFunc()` for the given `name`, result type(s), and
/// operands. The reference is requested on the `ModuleOp` enclosing `op`.
/// Returns the newly created call operation.
static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
                                          StringRef name, TypeRange resultType,
                                          ValueRange operands,
                                          EmitCInterface emitCInterface) {
  // Resolve the callee on the module that encloses the op being replaced.
  auto parentModule = op->getParentOfType<ModuleOp>();
  auto callee =
      getFunc(parentModule, name, resultType, operands, emitCInterface);
  return rewriter.replaceOpWithNewOp<func::CallOp>(op, resultType, callee,
                                                   operands);
}
/// Generates call to lookup a level-size. N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
@@ -264,11 +253,36 @@ private:
};
/// Generates a call to obtain the values array.
static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp,
ValueRange ptr) {
SmallString<15> name{"sparseValues",
primaryTypeFunctionSuffix(tp.getElementType())};
return createFuncCall(builder, loc, name, tp, ptr, EmitCInterface::On)
static Value genValuesCall(OpBuilder &builder, Location loc,
SparseTensorType stt, Value ptr) {
auto eltTp = stt.getElementType();
auto resTp = MemRefType::get({ShapedType::kDynamic}, eltTp);
SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltTp)};
return createFuncCall(builder, loc, name, resTp, {ptr}, EmitCInterface::On)
.getResult(0);
}
/// Generates a call that retrieves the positions array of level `l` for
/// the sparse tensor referenced by `ptr`. The callee name is
/// "sparsePositions" followed by the overhead-type suffix of the position
/// type of `stt`; the result is a 1-D dynamically sized memref of that
/// position type.
static Value genPositionsCall(OpBuilder &builder, Location loc,
                              SparseTensorType stt, Value ptr, Level l) {
  Type posTp = stt.getPosType();
  SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
  auto memTp = MemRefType::get({ShapedType::kDynamic}, posTp);
  // The level is passed to the callee as an index constant.
  Value lvl = constantIndex(builder, loc, l);
  auto call = createFuncCall(builder, loc, name, memTp, {ptr, lvl},
                             EmitCInterface::On);
  return call.getResult(0);
}
/// Generates a call that retrieves the coordinates array of level `l` for
/// the sparse tensor referenced by `ptr`. The callee name is
/// "sparseCoordinates" followed by the overhead-type suffix of the
/// coordinate type of `stt`; the result is a 1-D dynamically sized memref
/// of that coordinate type.
static Value genCoordinatesCall(OpBuilder &builder, Location loc,
                                SparseTensorType stt, Value ptr, Level l) {
  Type crdTp = stt.getCrdType();
  SmallString<19> name{"sparseCoordinates", overheadTypeFunctionSuffix(crdTp)};
  auto memTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
  // The level is passed to the callee as an index constant.
  Value lvl = constantIndex(builder, loc, l);
  auto call = createFuncCall(builder, loc, name, memTp, {ptr, lvl},
                             EmitCInterface::On);
  return call.getResult(0);
}
@@ -391,7 +405,7 @@ public:
SmallVector<Value> dimSizes;
dimSizes.reserve(dimRank);
unsigned operandCtr = 0;
for (Dimension d = 0; d < dimRank; ++d) {
for (Dimension d = 0; d < dimRank; d++) {
dimSizes.push_back(
stt.isDynamicDim(d)
? adaptor.getOperands()[operandCtr++]
@@ -423,7 +437,7 @@ public:
dimSizes.reserve(dimRank);
auto shape = op.getType().getShape();
unsigned operandCtr = 0;
for (Dimension d = 0; d < dimRank; ++d) {
for (Dimension d = 0; d < dimRank; d++) {
dimSizes.push_back(stt.isDynamicDim(d)
? adaptor.getOperands()[operandCtr++]
: constantIndex(rewriter, loc, shape[d]));
@@ -487,12 +501,10 @@ public:
LogicalResult
matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type resTp = op.getType();
Type posTp = cast<ShapedType>(resTp).getElementType();
SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
Value lvl = constantIndex(rewriter, op->getLoc(), op.getLevel());
replaceOpWithFuncCall(rewriter, op, name, resTp, {adaptor.getTensor(), lvl},
EmitCInterface::On);
auto stt = getSparseTensorType(op.getTensor());
auto poss = genPositionsCall(rewriter, op.getLoc(), stt,
adaptor.getTensor(), op.getLevel());
rewriter.replaceOp(op, poss);
return success();
}
};
@@ -505,29 +517,14 @@ public:
LogicalResult
matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// TODO: use `SparseTensorType::getCrdType` instead.
Type resType = op.getType();
const Type crdTp = cast<ShapedType>(resType).getElementType();
SmallString<19> name{"sparseCoordinates",
overheadTypeFunctionSuffix(crdTp)};
Location loc = op->getLoc();
Value lvl = constantIndex(rewriter, loc, op.getLevel());
// The function returns a MemRef without a layout.
MemRefType callRetType = get1DMemRefType(crdTp, false);
SmallVector<Value> operands{adaptor.getTensor(), lvl};
auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, callRetType,
operands, EmitCInterface::On);
Value callRet =
rewriter.create<func::CallOp>(loc, callRetType, fn, operands)
.getResult(0);
auto stt = getSparseTensorType(op.getTensor());
auto crds = genCoordinatesCall(rewriter, op.getLoc(), stt,
adaptor.getTensor(), op.getLevel());
// Cast the MemRef type to the type expected by the users, though these
// two types should be compatible at runtime.
if (resType != callRetType)
callRet = rewriter.create<memref::CastOp>(loc, resType, callRet);
rewriter.replaceOp(op, callRet);
if (op.getType() != crds.getType())
crds = rewriter.create<memref::CastOp>(op.getLoc(), op.getType(), crds);
rewriter.replaceOp(op, crds);
return success();
}
};
@@ -539,9 +536,9 @@ public:
LogicalResult
matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto resType = cast<ShapedType>(op.getType());
rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType,
adaptor.getOperands()));
auto stt = getSparseTensorType(op.getTensor());
auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
rewriter.replaceOp(op, vals);
return success();
}
};
@@ -554,13 +551,11 @@ public:
LogicalResult
matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
// Query values array size for the actually stored values size.
Type eltType = cast<ShapedType>(op.getTensor().getType()).getElementType();
auto resTp = MemRefType::get({ShapedType::kDynamic}, eltType);
Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands());
rewriter.replaceOpWithNewOp<memref::DimOp>(op, values,
constantIndex(rewriter, loc, 0));
auto stt = getSparseTensorType(op.getTensor());
auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
auto zero = constantIndex(rewriter, op.getLoc(), 0);
rewriter.replaceOpWithNewOp<memref::DimOp>(op, vals, zero);
return success();
}
};
@@ -701,7 +696,7 @@ public:
}
};
/// Sparse conversion rule for the sparse_tensor.pack operator.
/// Sparse conversion rule for the sparse_tensor.assemble operator.
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -710,9 +705,12 @@ public:
ConversionPatternRewriter &rewriter) const override {
const Location loc = op->getLoc();
const auto dstTp = getSparseTensorType(op.getResult());
// AssembleOps always returns a static shaped tensor result.
assert(dstTp.hasStaticDimShape());
SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, dstTp);
// Use a library method to transfer the external buffers from
// clients to the internal SparseTensorStorage. Since we cannot
// assume clients transfer ownership of the buffers, this method
// will copy all data over into a new SparseTensorStorage.
Value dst =
NewCallParams(rewriter, loc)
.genBuffers(dstTp.withoutDimToLvl(), dimSizes)
@@ -724,6 +722,115 @@ public:
}
};
/// Sparse conversion rule for the sparse_tensor.disassemble operator.
///
/// Rewrites a `DisassembleOp` into calls that fetch the values, positions,
/// and coordinates buffers of the tensor (through `genValuesCall`,
/// `genPositionsCall`, and `genCoordinatesCall`). Every buffer is turned
/// back into a tensor and cast to the matching result type of the op, and
/// the length of every returned buffer is appended as a trailing result.
class SparseTensorDisassembleConverter
: public OpConversionPattern<DisassembleOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// We simply expose the buffers to the external client. This
// assumes the client only reads the buffers (usually copying it
// to the external data structures, such as numpy arrays).
Location loc = op->getLoc();
auto stt = getSparseTensorType(op.getTensor());
// `retVal` collects the returned buffers and `retLen` the length of each
// buffer; both are filled in lockstep, so entry i of one describes entry i
// of the other.
SmallVector<Value> retVal;
SmallVector<Value> retLen;
// Get the values buffer first.
auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
auto valLenTp = op.getValLen().getType();
auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
retVal.push_back(vals);
retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
// Then get the positions and coordinates buffers.
// Note that `retLen` already holds the values length in slot 0, so
// `retLen.size() - 1` below is the index of the next level buffer within
// the level-length results of the op.
const Level lvlRank = stt.getLvlRank();
Level trailCOOLen = 0;
for (Level l = 0; l < lvlRank; l++) {
if (!stt.isUniqueLvl(l) &&
(stt.isCompressedLvl(l) || stt.isLooseCompressedLvl(l))) {
// A `(loose)compressed_nu` level marks the start of the trailing COO
// region. Since the target coordinate buffer used for trailing
// COO is passed in as AoS scheme and SparseTensorStorage uses a SoA
// scheme, we cannot simply use the internal buffers.
trailCOOLen = lvlRank - l;
break;
}
if (stt.isWithPos(l)) {
auto poss =
genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
retVal.push_back(poss);
retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
}
if (stt.isWithCrd(l)) {
auto crds =
genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
auto crdLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
retVal.push_back(crds);
retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
}
}
// Handle AoS vs. SoA mismatch for COO.
// NOTE(review): only the coordinates of levels `cooStartLvl` and
// `cooStartLvl + 1` are copied below, i.e. this path handles a trailing COO
// region of exactly two levels -- confirm other lengths are rejected
// before this pattern runs.
if (trailCOOLen != 0) {
uint64_t cooStartLvl = lvlRank - trailCOOLen;
assert(!stt.isUniqueLvl(cooStartLvl) &&
(stt.isCompressedLvl(cooStartLvl) ||
stt.isLooseCompressedLvl(cooStartLvl)));
// Positions.
auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
cooStartLvl);
auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
retVal.push_back(poss);
retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
// Coordinates, copied over with:
// for (i = 0; i < crdLen; i++)
// buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
// The destination `buf` is the memref form of the op's out-level operand
// at the current level-buffer index.
auto buf =
genToMemref(rewriter, loc, op.getOutLevels()[retLen.size() - 1]);
auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
cooStartLvl);
auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
cooStartLvl + 1);
auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds0, 0);
// Two coordinates are stored per entry, so the reported buffer length is
// twice the length of a single coordinates array.
auto two = constantIndex(rewriter, loc, 2);
auto bufLen = rewriter.create<arith::MulIOp>(loc, crdLen, two);
Type indexType = rewriter.getIndexType();
auto zero = constantZero(rewriter, loc, indexType);
auto one = constantOne(rewriter, loc, indexType);
scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, zero, crdLen, one);
auto idx = forOp.getInductionVar();
// Emit the loop body: load both coordinates of entry `idx` and store them
// side by side in row `idx` of `buf`.
rewriter.setInsertionPointToStart(forOp.getBody());
auto c0 = rewriter.create<memref::LoadOp>(loc, crds0, idx);
auto c1 = rewriter.create<memref::LoadOp>(loc, crds1, idx);
SmallVector<Value> args;
args.push_back(idx);
args.push_back(zero);
rewriter.create<memref::StoreOp>(loc, c0, buf, args);
args[1] = one;
rewriter.create<memref::StoreOp>(loc, c1, buf, args);
// Continue emitting code after the loop.
rewriter.setInsertionPointAfter(forOp);
auto bufLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
retVal.push_back(buf);
retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
}
// Converts MemRefs back to Tensors.
// Each buffer is cast to the type of the op result at the same index.
assert(retVal.size() + retLen.size() == op.getNumResults());
for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
auto tensor = rewriter.create<bufferization::ToTensorOp>(loc, retVal[i]);
retVal[i] =
rewriter.create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
}
// Appends the actual memory length used in each buffer returned.
retVal.append(retLen.begin(), retLen.end());
rewriter.replaceOp(op, retVal);
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
@@ -752,5 +859,6 @@ void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
SparseTensorLoadConverter, SparseTensorInsertConverter,
SparseTensorExpandConverter, SparseTensorCompressConverter,
SparseTensorAssembleConverter>(typeConverter, patterns.getContext());
SparseTensorAssembleConverter, SparseTensorDisassembleConverter>(
typeConverter, patterns.getContext());
}