[mlir][sparse] implement sparse tensor init operation
Next step towards supporting sparse tensor outputs. Also some minor refactoring of enum constants, as well as replacing tensor arguments with proper buffer arguments (the latter is required for more general sizes arguments for the sparse_tensor.init operation, as well as more general sparse_tensor.convert operations later). Reviewed By: wrengr Differential Revision: https://reviews.llvm.org/D111771
This commit is contained in:
@@ -29,6 +29,16 @@ using namespace mlir::sparse_tensor;
|
||||
|
||||
namespace {
|
||||
|
||||
/// New tensor storage action. Keep these values consistent with
|
||||
/// the sparse runtime support library.
|
||||
enum Action : uint32_t {
|
||||
kEmpty = 0,
|
||||
kFromFile = 1,
|
||||
kFromCOO = 2,
|
||||
kEmptyCOO = 3,
|
||||
kToCOO = 4
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Helper methods.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -105,18 +115,10 @@ inline static Value constantI32(ConversionPatternRewriter &rewriter,
|
||||
return rewriter.create<arith::ConstantIntOp>(loc, i, 32);
|
||||
}
|
||||
|
||||
/// Returns integers of given width and values as a constant tensor.
|
||||
/// We cast the static shape into a dynamic shape to ensure that the
|
||||
/// method signature remains uniform across different tensor dimensions.
|
||||
static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width,
|
||||
Location loc, ArrayRef<APInt> values) {
|
||||
Type etp = rewriter.getIntegerType(width);
|
||||
unsigned sz = values.size();
|
||||
RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
|
||||
RankedTensorType tt2 = RankedTensorType::get({ShapedType::kDynamicSize}, etp);
|
||||
auto elts = rewriter.create<arith::ConstantOp>(
|
||||
loc, DenseElementsAttr::get(tt1, values));
|
||||
return rewriter.create<tensor::CastOp>(loc, tt2, elts);
|
||||
/// Generates a constant of `i8` type.
|
||||
inline static Value constantI8(ConversionPatternRewriter &rewriter,
|
||||
Location loc, int8_t i) {
|
||||
return rewriter.create<arith::ConstantIntOp>(loc, i, 8);
|
||||
}
|
||||
|
||||
/// Returns a function reference (first hit also inserts into module). Sets
|
||||
@@ -142,43 +144,70 @@ static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Generates a temporary buffer of the given size and type.
|
||||
static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
|
||||
unsigned sz, Type tp) {
|
||||
auto memTp = MemRefType::get({ShapedType::kDynamicSize}, tp);
|
||||
Value a = constantIndex(rewriter, loc, sz);
|
||||
return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{a});
|
||||
}
|
||||
|
||||
/// Fills a temporary buffer of the given type with arguments.
|
||||
static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
|
||||
ArrayRef<Value> values) {
|
||||
unsigned sz = values.size();
|
||||
assert(sz >= 1);
|
||||
Value buffer = genAlloca(rewriter, loc, sz, values[0].getType());
|
||||
for (unsigned i = 0; i < sz; i++) {
|
||||
Value idx = constantIndex(rewriter, loc, i);
|
||||
rewriter.create<memref::StoreOp>(loc, values[i], buffer, idx);
|
||||
}
|
||||
return buffer;
|
||||
}
|
||||
|
||||
/// Generates a call into the "swiss army knife" method of the sparse runtime
|
||||
/// support library for materializing sparse tensors into the computation. The
|
||||
/// method returns the call value and assigns the permutation to 'perm'.
|
||||
static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
|
||||
SparseTensorEncodingAttr &enc, uint32_t action,
|
||||
Value &perm, Value ptr = Value()) {
|
||||
Value &perm, ValueRange szs, Value ptr = Value()) {
|
||||
Location loc = op->getLoc();
|
||||
ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
|
||||
SmallVector<Value, 8> params;
|
||||
// Sparsity annotations in tensor constant form.
|
||||
SmallVector<APInt, 4> attrs;
|
||||
unsigned sz = enc.getDimLevelType().size();
|
||||
SmallVector<Value, 4> attrs;
|
||||
ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
|
||||
unsigned sz = dlt.size();
|
||||
for (unsigned i = 0; i < sz; i++)
|
||||
attrs.push_back(
|
||||
APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i])));
|
||||
params.push_back(getTensor(rewriter, 8, loc, attrs));
|
||||
attrs.push_back(constantI8(rewriter, loc, getDimLevelTypeEncoding(dlt[i])));
|
||||
params.push_back(genBuffer(rewriter, loc, attrs));
|
||||
// Dimension sizes array of the enveloping *dense* tensor. Useful for either
|
||||
// verification of external data, or for construction of internal data.
|
||||
auto shape = resType.getShape();
|
||||
SmallVector<APInt, 4> sizes;
|
||||
for (unsigned i = 0; i < sz; i++) {
|
||||
uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
|
||||
sizes.push_back(APInt(64, s));
|
||||
SmallVector<Value, 4> sizes;
|
||||
if (szs.size() > 0) {
|
||||
for (Value s : szs)
|
||||
sizes.push_back(
|
||||
rewriter.create<arith::IndexCastOp>(loc, s, rewriter.getI64Type()));
|
||||
} else {
|
||||
for (unsigned i = 0; i < sz; i++) {
|
||||
uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
|
||||
sizes.push_back(constantI64(rewriter, loc, s));
|
||||
}
|
||||
}
|
||||
params.push_back(getTensor(rewriter, 64, loc, sizes));
|
||||
params.push_back(genBuffer(rewriter, loc, sizes));
|
||||
// Dimension order permutation array. This is the "identity" permutation by
|
||||
// default, or otherwise the "reverse" permutation of a given ordering, so
|
||||
// that indices can be mapped quickly to the right position.
|
||||
SmallVector<APInt, 4> rev(sz);
|
||||
SmallVector<Value, 4> rev(sz);
|
||||
if (AffineMap p = enc.getDimOrdering()) {
|
||||
for (unsigned i = 0; i < sz; i++)
|
||||
rev[p.getDimPosition(i)] = APInt(64, i);
|
||||
rev[p.getDimPosition(i)] = constantI64(rewriter, loc, i);
|
||||
} else {
|
||||
for (unsigned i = 0; i < sz; i++)
|
||||
rev[i] = APInt(64, i);
|
||||
rev[i] = constantI64(rewriter, loc, i);
|
||||
}
|
||||
perm = getTensor(rewriter, 64, loc, rev);
|
||||
perm = genBuffer(rewriter, loc, rev);
|
||||
params.push_back(perm);
|
||||
// Secondary and primary types encoding.
|
||||
unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
|
||||
@@ -309,18 +338,6 @@ static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
|
||||
return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
|
||||
}
|
||||
|
||||
/// Generates code to stack-allocate a `memref<?xindex>` where the `?`
|
||||
/// is the given `rank`. This array is intended to serve as a reusable
|
||||
/// buffer for storing the indices of a single tensor element, to avoid
|
||||
/// allocation in the body of loops.
|
||||
static Value allocaIndices(ConversionPatternRewriter &rewriter, Location loc,
|
||||
int64_t rank) {
|
||||
auto indexTp = rewriter.getIndexType();
|
||||
auto memTp = MemRefType::get({ShapedType::kDynamicSize}, indexTp);
|
||||
Value arg = constantIndex(rewriter, loc, rank);
|
||||
return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Conversion rules.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -378,8 +395,25 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
|
||||
if (!enc)
|
||||
return failure();
|
||||
Value perm;
|
||||
rewriter.replaceOp(op, genNewCall(rewriter, op, enc, kFromFile, perm, {},
|
||||
adaptor.getOperands()[0]));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse conversion rule for the init operator.
|
||||
class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(InitOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type resType = op.getType();
|
||||
auto enc = getSparseTensorEncoding(resType);
|
||||
if (!enc)
|
||||
return failure();
|
||||
Value perm;
|
||||
rewriter.replaceOp(
|
||||
op, genNewCall(rewriter, op, enc, 0, perm, adaptor.getOperands()[0]));
|
||||
op, genNewCall(rewriter, op, enc, kEmpty, perm, adaptor.getOperands()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -402,8 +436,9 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
|
||||
// yield the fastest conversion but avoids the need for a full
|
||||
// O(N^2) conversion matrix.
|
||||
Value perm;
|
||||
Value coo = genNewCall(rewriter, op, encDst, 3, perm, src);
|
||||
rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, coo));
|
||||
Value coo = genNewCall(rewriter, op, encDst, kToCOO, perm, {}, src);
|
||||
rewriter.replaceOp(
|
||||
op, genNewCall(rewriter, op, encDst, kFromCOO, perm, {}, coo));
|
||||
return success();
|
||||
}
|
||||
if (!encDst || encSrc) {
|
||||
@@ -439,8 +474,9 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
|
||||
Location loc = op->getLoc();
|
||||
ShapedType shape = resType.cast<ShapedType>();
|
||||
Value perm;
|
||||
Value ptr = genNewCall(rewriter, op, encDst, 2, perm);
|
||||
Value ind = allocaIndices(rewriter, loc, shape.getRank());
|
||||
Value ptr = genNewCall(rewriter, op, encDst, kEmptyCOO, perm, {});
|
||||
Value ind =
|
||||
genAlloca(rewriter, loc, shape.getRank(), rewriter.getIndexType());
|
||||
SmallVector<Value> lo;
|
||||
SmallVector<Value> hi;
|
||||
SmallVector<Value> st;
|
||||
@@ -478,7 +514,8 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
|
||||
genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
|
||||
return {};
|
||||
});
|
||||
rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, ptr));
|
||||
rewriter.replaceOp(
|
||||
op, genNewCall(rewriter, op, encDst, kFromCOO, perm, {}, ptr));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -637,9 +674,9 @@ public:
|
||||
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
|
||||
SparseTensorNewConverter, SparseTensorConvertConverter,
|
||||
SparseTensorReleaseConverter, SparseTensorToPointersConverter,
|
||||
SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
|
||||
SparseTensorToTensorConverter>(typeConverter,
|
||||
patterns.getContext());
|
||||
SparseTensorNewConverter, SparseTensorInitConverter,
|
||||
SparseTensorConvertConverter, SparseTensorReleaseConverter,
|
||||
SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
|
||||
SparseTensorToValuesConverter, SparseTensorToTensorConverter>(
|
||||
typeConverter, patterns.getContext());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user