The only benefit of FunctionPass is that it filters out function declarations. This isn't enough to justify carrying it around, as we can simply filter out declarations when necessary within the pass. We can also explore better scheduling primitives to filter out declarations at the pipeline level in the future. The definition of FunctionPass is left intact for now to allow time for downstream users to migrate. Differential Revision: https://reviews.llvm.org/D117182
236 lines
9.3 KiB
C++
236 lines
9.3 KiB
C++
//===- Bufferize.cpp - Bufferization for `tensor` dialect ops -------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements bufferization of `tensor` dialect ops
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
|
|
#include "PassDetail.h"
|
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/SCF/SCF.h"
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
|
|
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
/// Bufferizes `tensor.cast` by rewriting it into a `memref.cast` that operates
/// on the already-converted source operand.
struct BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
  using OpConversionPattern::OpConversionPattern;

  /// Replaces `op` with a `memref.cast` whose result type is obtained by
  /// running the op's result type through this pattern's type converter.
  ///
  /// Returns a match failure (instead of building an op with a null type) when
  /// the type converter cannot convert the result type.
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resultType = getTypeConverter()->convertType(op.getType());
    // `convertType` yields a null type on failure; bail out before it reaches
    // the op builder.
    if (!resultType)
      return rewriter.notifyMatchFailure(op, "failed to convert result type");

    rewriter.replaceOpWithNewOp<memref::CastOp>(op, resultType,
                                                adaptor.source());
    return success();
  }
};
|
|
|
|
/// Bufferizes `tensor.dim` by rewriting it into a `memref.dim` on the
/// converted source operand.
struct BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
  using OpConversionPattern::OpConversionPattern;

  /// Replaces `op` with a `memref.dim` built from the adaptor's `source` and
  /// `index` operands. Always succeeds.
  LogicalResult
  matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value convertedSource = adaptor.source();
    Value dimIndex = adaptor.index();
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, convertedSource, dimIndex);
    return success();
  }
};
|
|
|
|
/// Bufferizes `tensor.extract` by rewriting it into a `memref.load` from the
/// converted operand.
struct BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
  using OpConversionPattern::OpConversionPattern;

  /// Replaces `op` with a `memref.load` that reads from the adaptor's `tensor`
  /// operand at the adaptor's `indices`. Always succeeds.
  LogicalResult
  matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value convertedSource = adaptor.tensor();
    ValueRange loadIndices = adaptor.indices();
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, convertedSource,
                                                loadIndices);
    return success();
  }
};
|
|
|
|
struct BufferizeFromElementsOp
|
|
: public OpConversionPattern<tensor::FromElementsOp> {
|
|
public:
|
|
using OpConversionPattern::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(tensor::FromElementsOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
Location loc = op.getLoc();
|
|
auto tensorType = op.getType().cast<RankedTensorType>();
|
|
auto shape = tensorType.getShape();
|
|
|
|
// Allocate a buffer for the result.
|
|
auto resultType =
|
|
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
|
Value buffer = rewriter.create<memref::AllocOp>(loc, resultType);
|
|
|
|
// Case: tensor<0xelem_type>.
|
|
if (op.elements().empty()) {
|
|
rewriter.replaceOp(op, {buffer});
|
|
return success();
|
|
}
|
|
|
|
// Case: tensor<elem_type>.
|
|
if (shape.empty()) {
|
|
rewriter.create<memref::StoreOp>(loc, op.elements().front(), buffer);
|
|
rewriter.replaceOp(op, {buffer});
|
|
return success();
|
|
}
|
|
|
|
// Create constants for the range of possible indices [0, max{shape_i}).
|
|
auto maxDim = *std::max_element(shape.begin(), shape.end());
|
|
SmallVector<Value, 2> constants;
|
|
constants.reserve(maxDim);
|
|
for (int i = 0; i < maxDim; ++i)
|
|
constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
|
|
|
|
// Traverse all `elements` and create `memref.store` ops.
|
|
ImplicitLocOpBuilder b(loc, rewriter);
|
|
auto elementIt = adaptor.elements().begin();
|
|
SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
|
|
createStores(/*dim=*/0, buffer, shape, constants, elementIt, indices, b);
|
|
|
|
rewriter.replaceOp(op, {buffer});
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
// Implements backtracking to traverse indices of the output buffer while
|
|
// iterating over op.elements().
|
|
void createStores(int dim, Value buffer, ArrayRef<int64_t> shape,
|
|
ArrayRef<Value> constants, ValueRange::iterator &elementIt,
|
|
SmallVectorImpl<Value> &indices,
|
|
ImplicitLocOpBuilder b) const {
|
|
if (dim == static_cast<int>(shape.size()) - 1) {
|
|
for (int i = 0; i < shape.back(); ++i) {
|
|
indices.back() = constants[i];
|
|
b.create<memref::StoreOp>(*elementIt, buffer, indices);
|
|
++elementIt;
|
|
}
|
|
return;
|
|
}
|
|
for (int i = 0; i < shape[dim]; ++i) {
|
|
indices[dim] = constants[i];
|
|
createStores(dim + 1, buffer, shape, constants, elementIt, indices, b);
|
|
}
|
|
}
|
|
};
|
|
|
|
/// Bufferizes `tensor.generate` by allocating a result memref and filling it
/// from an `scf.parallel` loop nest that hosts the op's original body.
struct BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
  using OpConversionPattern::OpConversionPattern;

  /// Replaces `op` with an allocated buffer whose elements are written by an
  /// `scf.parallel` over [0, extent) with unit step in every dimension.
  /// Always succeeds.
  LogicalResult
  matchAndRewrite(tensor::GenerateOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();

    // Allocate the destination buffer; dynamic dimensions take their sizes
    // from the op's dynamic extent operands.
    auto generatedType = op.getType().cast<RankedTensorType>();
    auto bufferType = MemRefType::get(generatedType.getShape(),
                                      generatedType.getElementType());
    Value buffer = rewriter.create<memref::AllocOp>(loc, bufferType,
                                                    adaptor.dynamicExtents());

    // Build the loop bounds: every dimension iterates over [0, extent) with a
    // step of one.
    int64_t numDims = generatedType.getRank();
    Value zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value oneIdx = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lbs(numDims, zeroIdx);
    SmallVector<Value, 4> strides(numDims, oneIdx);
    SmallVector<Value, 4> ubs;
    int dynamicCursor = 0;
    for (int dim = 0; dim < numDims; dim++) {
      Value extent;
      if (generatedType.isDynamicDim(dim)) {
        // Dynamic extents are consumed in dimension order.
        extent = adaptor.dynamicExtents()[dynamicCursor++];
      } else {
        extent = rewriter.create<arith::ConstantIndexOp>(
            loc, bufferType.getDimSize(dim));
      }
      ubs.push_back(extent);
    }

    // The element computation is emitted as the body of a parallel loop that
    // writes each element of the result buffer.
    //
    // Cloning the body is not an option: cloned ops would have to be
    // legalized, yet the body may hold arbitrary ops for which this conversion
    // has no patterns. Instead, mergeBlockBefore "moves" the existing body
    // into the scf.parallel's body.
    auto loopNest = rewriter.create<scf::ParallelOp>(loc, lbs, ubs, strides);
    Block *loopBody = loopNest.getBody();
    rewriter.mergeBlockBefore(op.getBody(), loopBody->getTerminator(),
                              loopBody->getArguments());

    // The inlined yield now sits right before the scf.yield that the
    // scf.parallel builder created; swap it for a store of the yielded value
    // at the current induction variables.
    Operation *inlinedYield = loopBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(inlinedYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        inlinedYield, inlinedYield->getOperands()[0], buffer,
        loopBody->getArguments());

    rewriter.replaceOp(op, {buffer});
    return success();
  }
};
|
|
|
|
/// Bufferizes `tensor.rank` by rewriting it into a `memref.rank` on the
/// converted operand.
struct BufferizeRankOp : public OpConversionPattern<tensor::RankOp> {
  using OpConversionPattern::OpConversionPattern;

  /// Replaces `op` with a `memref.rank` that keeps the original result type
  /// and reads the adaptor's `tensor` operand. Always succeeds.
  LogicalResult
  matchAndRewrite(tensor::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type rankType = op.getType();
    Value convertedOperand = adaptor.tensor();
    rewriter.replaceOpWithNewOp<memref::RankOp>(op, rankType,
                                                convertedOperand);
    return success();
  }
};
|
|
|
|
/// Pass that applies the `tensor` dialect bufferization patterns to the
/// operation it is scheduled on, using a partial dialect conversion.
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
  /// Sets up the conversion target and patterns, then runs a partial
  /// conversion; signals pass failure if the conversion fails.
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    bufferization::BufferizeTypeConverter converter;

    ConversionTarget legality(*ctx);
    legality.addLegalDialect<scf::SCFDialect, memref::MemRefDialect>();

    // Arithmetic and standard ops are legal only when the type converter
    // considers them legal.
    auto hasLegalTypes = [&converter](Operation *candidate) {
      return converter.isLegal(candidate);
    };
    legality.addDynamicallyLegalDialect<arith::ArithmeticDialect,
                                        StandardOpsDialect>(hasLegalTypes);
    legality.addLegalOp<CallOp, ReturnOp>();
    legality.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
                          tensor::FromElementsOp, tensor::GenerateOp>();
    bufferization::populateBufferizeMaterializationLegality(legality);

    RewritePatternSet conversionPatterns(ctx);
    populateTensorBufferizePatterns(converter, conversionPatterns);

    LogicalResult status = applyPartialConversion(
        getOperation(), legality, std::move(conversionPatterns));
    if (failed(status))
      signalPassFailure();
  }
};
|
|
|
|
} // namespace
|
|
|
|
/// Adds every `tensor` dialect bufferization pattern defined in this file to
/// `patterns`, each constructed with `typeConverter` and the pattern set's
/// context.
void mlir::populateTensorBufferizePatterns(
    bufferization::BufferizeTypeConverter &typeConverter,
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
               BufferizeFromElementsOp, BufferizeGenerateOp, BufferizeRankOp>(
      typeConverter, ctx);
}
|
|
|
|
std::unique_ptr<Pass> mlir::createTensorBufferizePass() {
|
|
return std::make_unique<TensorBufferizePass>();
|
|
}
|