[mlir][sparse] Add codegen for expand op.
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D133454
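For background, the expand op supplies the temporary buffers used for access pattern expansion when inserting into a sparse result: a dense scratch array of values, a filled switch, a list of touched indices, and a running count, all sized by the innermost stored dimension. The plain C++ sketch below is an illustrative analogue only, not code from this commit or from the MLIR tree; the names ExpandedAccess and insert are made up for the illustration.

// Illustrative analogue of access pattern expansion (not MLIR code).
// The expand op's four results correspond to `values`, `filled`,
// `added`, and `count` below.
#include <cstddef>
#include <vector>

struct ExpandedAccess {
  std::vector<double> values; // scratch values for one expanded row
  std::vector<bool> filled;   // filled switch per position
  std::vector<size_t> added;  // positions touched so far
  size_t count = 0;           // number of touched positions

  explicit ExpandedAccess(size_t sz)
      : values(sz, 0.0), filled(sz, false), added(sz) {}

  // Accumulate a value at position i, recording i the first time it is hit.
  void insert(size_t i, double v) {
    if (!filled[i]) {
      filled[i] = true;
      added[count++] = i;
    }
    values[i] += v;
  }
};

A later compress step only needs to reset the `count` positions recorded in `added`, which is why the O(N) reset generated by the code below happens once on entry of the loop nest and is amortized over the innermost loops.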
@@ -19,6 +19,7 @@
 
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
@@ -474,6 +475,58 @@ public:
   }
 };
 
+/// Sparse codegen rule for the expand op.
+class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
+    Type eltType = srcType.getElementType();
+    Type boolType = rewriter.getIntegerType(1);
+    Type idxType = rewriter.getIndexType();
+    // All initialization should be done on entry of the loop nest.
+    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
+    // Determine the size for access expansion (always the innermost stored
+    // dimension size, translated back to original dimension). Note that we
+    // recursively rewrite the new DimOp on the **original** tensor.
+    auto enc = getSparseTensorEncoding(srcType);
+    unsigned innerDim = srcType.getRank() - 1;
+    if (AffineMap p = enc.getDimOrdering())
+      innerDim = p.getDimPosition(innerDim);
+    Value sz = rewriter.create<tensor::DimOp>(loc, op.getTensor(), innerDim);
+    // Generate a memref for `sz` elements of type `t`.
+    auto genAlloc = [&](Type t) {
+      auto memTp = MemRefType::get({ShapedType::kDynamicSize}, t);
+      return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
+    };
+    // Allocate temporary buffers for values, filled-switch, and indices.
+    // We do not use stack buffers for this, since the expanded size may
+    // be rather large (as it envelops a single expanded dense dimension).
+    Value values = genAlloc(eltType);
+    Value filled = genAlloc(boolType);
+    Value indices = genAlloc(idxType);
+    Value zero = constantZero(rewriter, loc, idxType);
+    // Reset the values/filled-switch to all-zero/false. Note that this
+    // introduces an O(N) operation into the computation, but this reset
+    // operation is amortized over the innermost loops for the access
+    // pattern expansion. As noted in the operation doc, we would like
+    // to amortize this setup cost even between kernels.
+    rewriter.create<linalg::FillOp>(
+        loc, ValueRange{constantZero(rewriter, loc, eltType)},
+        ValueRange{values});
+    rewriter.create<linalg::FillOp>(
+        loc, ValueRange{constantZero(rewriter, loc, boolType)},
+        ValueRange{filled});
+    // Replace expansion op with these buffers and initial index.
+    assert(op.getNumResults() == 4);
+    rewriter.replaceOp(op, {values, filled, indices, zero});
+    return success();
+  }
+};
+
 /// Sparse codegen rule for pointer accesses.
 class SparseToPointersConverter
     : public SparseGetterOpConverter<ToPointersOp, SparseToPointersConverter> {
@@ -533,8 +586,9 @@ mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
 void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                                                RewritePatternSet &patterns) {
   patterns.add<SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
-               SparseCastConverter, SparseTensorAllocConverter,
-               SparseTensorDeallocConverter, SparseToPointersConverter,
-               SparseToIndicesConverter, SparseToValuesConverter,
-               SparseTensorLoadConverter>(typeConverter, patterns.getContext());
+               SparseCastConverter, SparseExpandConverter,
+               SparseTensorAllocConverter, SparseTensorDeallocConverter,
+               SparseToPointersConverter, SparseToIndicesConverter,
+               SparseToValuesConverter, SparseTensorLoadConverter>(
+      typeConverter, patterns.getContext());
 }
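For completeness, here is a minimal standalone sketch (not part of this commit) of how the populated codegen patterns, now including SparseExpandConverter, can be driven through the generic dialect-conversion driver. The helper runSparseTensorCodegen is hypothetical and the conversion-target legality setup is elided; in the LLVM tree this wiring is done by the sparse tensor codegen pass itself.

// Hypothetical standalone driver for the sparse-tensor codegen patterns.
// Assumes the usual conversion-pass boilerplate; legality setup elided.
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

static LogicalResult runSparseTensorCodegen(ModuleOp module) {
  MLIRContext *ctx = module.getContext();

  // Maps sparse tensor types to their underlying buffer representation.
  SparseTensorTypeToBufferConverter converter;

  // Collect all codegen rewrite rules, including SparseExpandConverter.
  RewritePatternSet patterns(ctx);
  populateSparseTensorCodegenPatterns(converter, patterns);

  // A real pass marks sparse_tensor ops illegal and their memref/linalg
  // replacements legal here; that setup is omitted from this sketch.
  ConversionTarget target(*ctx);

  return applyPartialConversion(module.getOperation(), target,
                                std::move(patterns));
}

Only the pattern population and the type converter come from this file; everything else above is illustrative.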