[mlir][sparse] Add codegen for expand op.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D133454
This commit is contained in:
bixia1
2022-09-07 14:34:04 -07:00
parent 0f2f1c2be1
commit 8a583bd53d
4 changed files with 112 additions and 5 deletions

View File

@@ -19,6 +19,7 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
@@ -474,6 +475,58 @@ public:
}
};
/// Sparse codegen rule for the expand op.
///
/// The op's four results are replaced by three freshly allocated
/// one-dimensional buffers (values, filled-switch, indices), each holding
/// as many elements as the innermost stored dimension of the tensor, plus
/// an initial index of zero. The values and filled-switch buffers are
/// reset to zero/false; the indices buffer is left uninitialized.
class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  /// Rewrites `op` into buffer allocations and fills.
  ///
  /// Returns failure, without touching the IR, when the tensor operand
  /// carries no sparse encoding; returns success otherwise.
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value tensor = op.getTensor();
    ShapedType srcType = tensor.getType().cast<ShapedType>();
    // Without an encoding there is no stored-dimension order to consult;
    // bail out before any IR is created so the failed match is side-effect
    // free.
    auto enc = getSparseTensorEncoding(srcType);
    if (!enc)
      return failure();
    Location loc = op->getLoc();
    Type eltType = srcType.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest. Anchor
    // on the value rather than on its defining op: a tensor that is a block
    // argument has no defining op, and in that case the insertion point
    // becomes the start of the block that owns the argument.
    rewriter.setInsertionPointAfterValue(tensor);
    // Determine the size for access expansion (always the innermost stored
    // dimension size, translated back to original dimension). Note that we
    // recursively rewrite the new DimOp on the **original** tensor.
    unsigned innerDim = srcType.getRank() - 1;
    if (AffineMap p = enc.getDimOrdering())
      innerDim = p.getDimPosition(innerDim);
    Value sz = rewriter.create<tensor::DimOp>(loc, tensor, innerDim);
    // Generate a memref for `sz` elements of type `t`.
    auto genAlloc = [&](Type t) {
      auto memTp = MemRefType::get({ShapedType::kDynamicSize}, t);
      return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
    };
    // Allocate temporary buffers for values, filled-switch, and indices.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(eltType);
    Value filled = genAlloc(boolType);
    Value indices = genAlloc(idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial index.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, indices, zero});
    return success();
  }
};
/// Sparse codegen rule for pointer accesses.
class SparseToPointersConverter
: public SparseGetterOpConverter<ToPointersOp, SparseToPointersConverter> {
@@ -533,8 +586,9 @@ mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
/// Adds the sparse tensor codegen conversion patterns to `patterns`. Every
/// listed pattern is constructed from `typeConverter` and the context of the
/// pattern set.
// NOTE(review): this span is taken from a diff view, so both the previous
// and the updated template-argument lists appear below; the updated list is
// the one that includes SparseExpandConverter -- confirm against the file.
void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
SparseCastConverter, SparseTensorAllocConverter,
SparseTensorDeallocConverter, SparseToPointersConverter,
SparseToIndicesConverter, SparseToValuesConverter,
SparseTensorLoadConverter>(typeConverter, patterns.getContext());
SparseCastConverter, SparseExpandConverter,
SparseTensorAllocConverter, SparseTensorDeallocConverter,
SparseToPointersConverter, SparseToIndicesConverter,
SparseToValuesConverter, SparseTensorLoadConverter>(
typeConverter, patterns.getContext());
}