[mlir][Transforms] GreedyPatternRewriteDriver debugging: Detect faulty patterns
Compute operation finger prints to detect incorrect API usage in RewritePatterns. Does not work for dialect conversion patterns. Detect patterns that: * Returned `failure` but changed the IR. * Returned `success` but did not change the IR. * Inserted/removed/modified ops, bypassing the rewriter. Not all cases are detected. These new checks are quite expensive, so they are only enabled with `-DMLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON`. Failures manifest as fatal errors (`llvm::report_fatal_error`) or crashes (accessing deallocated memory). To get better debugging information, run `mlir-opt -debug` (to see which pattern is broken) with ASAN (to see where memory was deallocated). Differential Revision: https://reviews.llvm.org/D144552
This commit is contained in:
@@ -141,6 +141,10 @@ set(MLIR_INSTALL_AGGREGATE_OBJECTS 1 CACHE BOOL
|
||||
|
||||
set(MLIR_BUILD_MLIR_C_DYLIB 0 CACHE BOOL "Builds libMLIR-C shared library.")
|
||||
|
||||
configure_file(
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Config/mlir-config.h.cmake
|
||||
${MLIR_INCLUDE_DIR}/mlir/Config/mlir-config.h)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Python Bindings Configuration
|
||||
# Requires:
|
||||
|
||||
22
mlir/include/mlir/Config/mlir-config.h.cmake
Normal file
22
mlir/include/mlir/Config/mlir-config.h.cmake
Normal file
@@ -0,0 +1,22 @@
|
||||
//===- mlir-config.h - MLIR configuration ------------------------*- C -*-===*//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/* This file enumerates variables from the MLIR configuration so that they
|
||||
can be in exported headers and won't override package specific directives.
|
||||
This is a C header that can be included in the mlir-c headers. */
|
||||
|
||||
#ifndef MLIR_CONFIG_H
|
||||
#define MLIR_CONFIG_H
|
||||
|
||||
/* Enable expensive checks to detect invalid pattern API usage. Failed checks
|
||||
manifest as fatal errors or invalid memory accesses (e.g., accessing
|
||||
deallocated memory) that cause a crash. Running with ASAN is recommended for
|
||||
easier debugging. */
|
||||
#cmakedefine01 MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
|
||||
#endif
|
||||
@@ -429,6 +429,38 @@ public:
|
||||
static bool classof(const OpBuilder::Listener *base);
|
||||
};
|
||||
|
||||
/// A listener that forwards all notifications to another listener. This
|
||||
/// struct can be used as a base to create listener chains, so that multiple
|
||||
/// listeners can be notified of IR changes.
|
||||
struct ForwardingListener : public RewriterBase::Listener {
|
||||
ForwardingListener(Listener *listener) : listener(listener) {}
|
||||
|
||||
void notifyOperationInserted(Operation *op) override {
|
||||
listener->notifyOperationInserted(op);
|
||||
}
|
||||
void notifyBlockCreated(Block *block) override {
|
||||
listener->notifyBlockCreated(block);
|
||||
}
|
||||
void notifyOperationModified(Operation *op) override {
|
||||
listener->notifyOperationModified(op);
|
||||
}
|
||||
void notifyOperationReplaced(Operation *op,
|
||||
ValueRange replacement) override {
|
||||
listener->notifyOperationReplaced(op, replacement);
|
||||
}
|
||||
void notifyOperationRemoved(Operation *op) override {
|
||||
listener->notifyOperationRemoved(op);
|
||||
}
|
||||
LogicalResult notifyMatchFailure(
|
||||
Location loc,
|
||||
function_ref<void(Diagnostic &)> reasonCallback) override {
|
||||
return listener->notifyMatchFailure(loc, reasonCallback);
|
||||
}
|
||||
|
||||
private:
|
||||
Listener *listener;
|
||||
};
|
||||
|
||||
/// Move the blocks that belong to "region" before the given position in
|
||||
/// another region "parent". The two regions must be different. The caller
|
||||
/// is responsible for creating or updating the operation transferring flow
|
||||
|
||||
@@ -11,6 +11,8 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
#include "mlir/Config/mlir-config.h"
|
||||
#include "mlir/IR/Action.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
@@ -30,10 +32,108 @@ using namespace mlir;
|
||||
#define DEBUG_TYPE "greedy-rewriter"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GreedyPatternRewriteDriver
|
||||
// Debugging Infrastructure
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
/// A helper struct that stores finger prints of ops in order to detect broken
|
||||
/// RewritePatterns. A rewrite pattern is broken if it modifies IR without
|
||||
/// using the rewriter API or if it returns an inconsistent return value.
|
||||
struct DebugFingerPrints : public RewriterBase::ForwardingListener {
|
||||
DebugFingerPrints(RewriterBase::Listener *driver)
|
||||
: RewriterBase::ForwardingListener(driver) {}
|
||||
|
||||
/// Compute finger prints of the given op and its nested ops.
|
||||
void computeFingerPrints(Operation *topLevel) {
|
||||
this->topLevel = topLevel;
|
||||
this->topLevelFingerPrint.emplace(topLevel);
|
||||
topLevel->walk([&](Operation *op) { fingerprints.try_emplace(op, op); });
|
||||
}
|
||||
|
||||
/// Clear all finger prints.
|
||||
void clear() {
|
||||
topLevel = nullptr;
|
||||
topLevelFingerPrint.reset();
|
||||
fingerprints.clear();
|
||||
}
|
||||
|
||||
void notifyRewriteSuccess() {
|
||||
// Pattern application success => IR must have changed.
|
||||
OperationFingerPrint afterFingerPrint(topLevel);
|
||||
if (*topLevelFingerPrint == afterFingerPrint) {
|
||||
// Note: Run "mlir-opt -debug" to see which pattern is broken.
|
||||
llvm::report_fatal_error(
|
||||
"pattern returned success but IR did not change");
|
||||
}
|
||||
for (const auto &it : fingerprints) {
|
||||
// Skip top-level op, its finger print is never invalidated.
|
||||
if (it.first == topLevel)
|
||||
continue;
|
||||
// Note: Finger print computation may crash when an op was erased
|
||||
// without notifying the rewriter. (Run with ASAN to see where the op was
|
||||
// erased; the op was probably erased directly, bypassing the rewriter
|
||||
// API.) Finger print computation does may not crash if a new op was
|
||||
// created at the same memory location. (But then the finger print should
|
||||
// have changed.)
|
||||
if (it.second != OperationFingerPrint(it.first)) {
|
||||
// Note: Run "mlir-opt -debug" to see which pattern is broken.
|
||||
llvm::report_fatal_error("operation finger print changed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void notifyRewriteFailure() {
|
||||
// Pattern application failure => IR must not have changed.
|
||||
OperationFingerPrint afterFingerPrint(topLevel);
|
||||
if (*topLevelFingerPrint != afterFingerPrint) {
|
||||
// Note: Run "mlir-opt -debug" to see which pattern is broken.
|
||||
llvm::report_fatal_error("pattern returned failure but IR did change");
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
/// Invalidate the finger print of the given op, i.e., remove it from the map.
|
||||
void invalidateFingerPrint(Operation *op) {
|
||||
// Invalidate all finger prints until the top level.
|
||||
while (op && op != topLevel) {
|
||||
fingerprints.erase(op);
|
||||
op = op->getParentOp();
|
||||
}
|
||||
}
|
||||
|
||||
void notifyOperationInserted(Operation *op) override {
|
||||
RewriterBase::ForwardingListener::notifyOperationInserted(op);
|
||||
invalidateFingerPrint(op->getParentOp());
|
||||
}
|
||||
|
||||
void notifyOperationModified(Operation *op) override {
|
||||
RewriterBase::ForwardingListener::notifyOperationModified(op);
|
||||
invalidateFingerPrint(op);
|
||||
}
|
||||
|
||||
void notifyOperationRemoved(Operation *op) override {
|
||||
RewriterBase::ForwardingListener::notifyOperationRemoved(op);
|
||||
op->walk([this](Operation *op) { invalidateFingerPrint(op); });
|
||||
}
|
||||
|
||||
/// Operation finger prints to detect invalid pattern API usage. IR is checked
|
||||
/// against these finger prints after pattern application to detect cases
|
||||
/// where IR was modified directly, bypassing the rewriter API.
|
||||
DenseMap<Operation *, OperationFingerPrint> fingerprints;
|
||||
|
||||
/// Top-level operation of the current greedy rewrite.
|
||||
Operation *topLevel = nullptr;
|
||||
|
||||
/// Finger print of the top-level operation.
|
||||
std::optional<OperationFingerPrint> topLevelFingerPrint;
|
||||
};
|
||||
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GreedyPatternRewriteDriver
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
|
||||
/// applies the locally optimal patterns.
|
||||
///
|
||||
@@ -122,21 +222,36 @@ private:
|
||||
|
||||
/// The low-level pattern applicator.
|
||||
PatternApplicator matcher;
|
||||
|
||||
#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
DebugFingerPrints debugFingerPrints;
|
||||
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
};
|
||||
} // namespace
|
||||
|
||||
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
|
||||
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
|
||||
const GreedyRewriteConfig &config)
|
||||
: PatternRewriter(ctx), folder(ctx, this), config(config),
|
||||
matcher(patterns) {
|
||||
: PatternRewriter(ctx), folder(ctx, this), config(config), matcher(patterns)
|
||||
#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
// clang-format off
|
||||
, debugFingerPrints(this)
|
||||
// clang-format on
|
||||
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
{
|
||||
worklist.reserve(64);
|
||||
|
||||
// Apply a simple cost model based solely on pattern benefit.
|
||||
matcher.applyDefaultCostModel();
|
||||
|
||||
// Set up listener.
|
||||
#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
// Send IR notifications to the debug handler. This handler will then forward
|
||||
// all notifications to this GreedyPatternRewriteDriver.
|
||||
setListener(&debugFingerPrints);
|
||||
#else
|
||||
setListener(this);
|
||||
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
}
|
||||
|
||||
bool GreedyPatternRewriteDriver::processWorklist() {
|
||||
@@ -231,15 +346,28 @@ bool GreedyPatternRewriteDriver::processWorklist() {
|
||||
function_ref<LogicalResult(const Pattern &)> onSuccess = {};
|
||||
#endif
|
||||
|
||||
#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
debugFingerPrints.computeFingerPrints(
|
||||
/*topLevel=*/config.scope ? config.scope->getParentOp() : op);
|
||||
auto clearFingerprints =
|
||||
llvm::make_scope_exit([&]() { debugFingerPrints.clear(); });
|
||||
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
|
||||
LogicalResult matchResult =
|
||||
matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
|
||||
|
||||
if (succeeded(matchResult)) {
|
||||
LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
|
||||
#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
debugFingerPrints.notifyRewriteSuccess();
|
||||
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
changed = true;
|
||||
++numRewrites;
|
||||
} else {
|
||||
LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
|
||||
#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
debugFingerPrints.notifyRewriteFailure();
|
||||
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
}
|
||||
}
|
||||
|
||||
@@ -247,6 +375,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
|
||||
}
|
||||
|
||||
void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
|
||||
assert(op && "expected valid op");
|
||||
// Gather potential ancestors while looking for a "scope" parent region.
|
||||
SmallVector<Operation *, 8> ancestors;
|
||||
Region *region = nullptr;
|
||||
|
||||
Reference in New Issue
Block a user