Files
clang-p2996/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
River Riddle 9a277af2d4 [mlir][Pass] Add support for generating pass utilities via tablegen
This revision adds support for generating utilities for passes such as options/statistics/etc. that can be inferred from the tablegen definition. This removes additional boilerplate from the pass, and also makes it easier to remove the reliance on the pass registry to provide certain things (e.g. the pass argument).

Differential Revision: https://reviews.llvm.org/D76659
2020-04-01 02:10:46 -07:00

182 lines
6.9 KiB
C++

//===- LegalizeStandardForSPIRV.cpp - Legalize ops for SPIR-V lowering ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass legalizes operations before the conversion to SPIR-V
// dialect to handle ops that cannot be lowered directly.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
namespace {
/// Rewrite pattern that folds a SubViewOp feeding a LoadOp into the load
/// itself, so that the load addresses the subview's source memref directly.
struct LoadOpOfSubViewFolder final : OpRewritePattern<LoadOp> {
  using OpRewritePattern<LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(LoadOp loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Rewrite pattern that folds a SubViewOp feeding a StoreOp into the store
/// itself, so that the store addresses the subview's source memref directly.
struct StoreOpOfSubViewFolder final : OpRewritePattern<StoreOp> {
  using OpRewritePattern<StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(StoreOp storeOp,
                                PatternRewriter &rewriter) const override;
};
} // namespace
//===----------------------------------------------------------------------===//
// Utility functions for op legalization.
//===----------------------------------------------------------------------===//
/// Rebases the `indices` of a load/store whose memref operand is produced by
/// `subViewOp`, so that they address the subview's source memref instead.
/// For every dimension `d` the new index is
///
///   sourceIndices[d] = subview_offset[d] + indices[d] * subview_stride[d]
///
/// For example
///
///   %0 = ... : memref<12x42xf32>
///   %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
///          memref<4x4xf32, offset=?, strides=[?, ?]>
///   %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
///
/// could be folded into
///
///   %2 = load %0[%arg0 + %i1 * %stride1, %arg1 + %i2 * %stride2] :
///          memref<12x42xf32>
///
/// The rebased values are written to `sourceIndices`. Returns failure, without
/// creating any IR and without touching `sourceIndices`, when the subview has
/// no dynamic offset operands or when its static strides cannot be recovered.
static LogicalResult
resolveSourceIndices(Location loc, PatternRewriter &rewriter,
                     SubViewOp subViewOp, ValueRange indices,
                     SmallVectorImpl<Value> &sourceIndices) {
  // TODO: Only dynamic offsets are handled. Folding might still be possible
  // after canonicalization has turned the offsets into static values.
  if (!subViewOp.getNumOffsets())
    return failure();

  ValueRange offsets = subViewOp.offsets();

  SmallVector<Value, 2> strides;
  if (!subViewOp.getNumStrides()) {
    // Static strides: let the op derive them (from the strides of its result
    // and of the base memref), then materialize each one as an index constant.
    SmallVector<int64_t, 2> constStrides;
    if (failed(subViewOp.getStaticStrides(constStrides)))
      return failure();
    strides.reserve(offsets.size());
    IndexType indexTy = IndexType::get(rewriter.getContext());
    for (int64_t strideVal : constStrides)
      strides.emplace_back(rewriter.create<ConstantOp>(
          loc, rewriter.getIntegerAttr(indexTy, strideVal)));
  } else {
    // Dynamic strides: the stride operands can be used as-is.
    strides = llvm::to_vector<2>(subViewOp.strides());
  }
  assert(offsets.size() == strides.size());

  // Build `offset + index * stride` for each dimension, in dimension order.
  assert(indices.size() == strides.size());
  sourceIndices.clear();
  sourceIndices.reserve(indices.size());
  for (unsigned dim = 0, rank = indices.size(); dim < rank; ++dim) {
    Value scaled = rewriter.create<MulIOp>(loc, indices[dim], strides[dim]);
    sourceIndices.push_back(
        rewriter.create<AddIOp>(loc, offsets[dim], scaled).getResult());
  }
  return success();
}
//===----------------------------------------------------------------------===//
// Folding SubViewOp and LoadOp.
//===----------------------------------------------------------------------===//
/// Folds `load (subview %src ...)[%idx...]` into a load that reads directly
/// from %src, with the indices rebased by resolveSourceIndices. Fails when the
/// loaded memref is not produced by a SubViewOp or the indices cannot be
/// rebased.
LogicalResult
LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp,
                                       PatternRewriter &rewriter) const {
  Operation *memrefDef = loadOp.memref().getDefiningOp();
  SubViewOp subViewOp = dyn_cast_or_null<SubViewOp>(memrefDef);
  if (!subViewOp)
    return failure();

  SmallVector<Value, 4> rebasedIndices;
  LogicalResult resolved =
      resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
                           loadOp.indices(), rebasedIndices);
  if (failed(resolved))
    return failure();

  rewriter.replaceOpWithNewOp<LoadOp>(loadOp, subViewOp.source(),
                                      rebasedIndices);
  return success();
}
//===----------------------------------------------------------------------===//
// Folding SubViewOp and StoreOp.
//===----------------------------------------------------------------------===//
/// Folds `store %v, (subview %src ...)[%idx...]` into a store that writes %v
/// directly into %src, with the indices rebased by resolveSourceIndices. Fails
/// when the stored-to memref is not produced by a SubViewOp or the indices
/// cannot be rebased.
LogicalResult
StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp,
                                        PatternRewriter &rewriter) const {
  Operation *memrefDef = storeOp.memref().getDefiningOp();
  SubViewOp subViewOp = dyn_cast_or_null<SubViewOp>(memrefDef);
  if (!subViewOp)
    return failure();

  SmallVector<Value, 4> rebasedIndices;
  LogicalResult resolved =
      resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
                           storeOp.indices(), rebasedIndices);
  if (failed(resolved))
    return failure();

  Value storedValue = storeOp.value();
  rewriter.replaceOpWithNewOp<StoreOp>(storeOp, storedValue,
                                       subViewOp.source(), rebasedIndices);
  return success();
}
//===----------------------------------------------------------------------===//
// Hook for adding patterns.
//===----------------------------------------------------------------------===//
/// Appends to `patterns` the rewrites that fold subview ops into their
/// load/store users ahead of the SPIR-V conversion.
void mlir::populateStdLegalizationPatternsForSPIRVLowering(
    MLIRContext *context, OwningRewritePatternList &patterns) {
  patterns.insert<LoadOpOfSubViewFolder>(context);
  patterns.insert<StoreOpOfSubViewFolder>(context);
}
//===----------------------------------------------------------------------===//
// Pass for testing just the legalization patterns.
//===----------------------------------------------------------------------===//
namespace {
/// Pass used for testing just the standard-to-SPIR-V legalization patterns
/// registered by populateStdLegalizationPatternsForSPIRVLowering.
struct SPIRVLegalization final : public OperationPass<SPIRVLegalization> {
/// Include the generated pass utilities. The macro selects this pass's
/// section of the tablegen-generated Passes.h.inc, which is expanded inside
/// the class body and therefore must stay here.
#define GEN_PASS_LegalizeStandardForSPIRV
#include "mlir/Conversion/Passes.h.inc"
  /// Applies the legalization patterns to the operation being processed.
  void runOnOperation() override;
};
} // namespace
void SPIRVLegalization::runOnOperation() {
OwningRewritePatternList patterns;
auto *context = &getContext();
populateStdLegalizationPatternsForSPIRVLowering(context, patterns);
applyPatternsGreedily(getOperation()->getRegions(), patterns);
}
/// Returns a new, caller-owned instance of the SPIR-V legalization test pass.
std::unique_ptr<Pass> mlir::createLegalizeStdOpsForSPIRVLoweringPass() {
  std::unique_ptr<Pass> pass = std::make_unique<SPIRVLegalization>();
  return pass;
}