[mlir][Transforms] GreedyPatternRewriteDriver debugging: Detect faulty patterns

Compute operation finger prints to detect incorrect API usage in RewritePatterns. Does not work for dialect conversion patterns.

Detect patterns that:
* Returned `failure` but changed the IR.
* Returned `success` but did not change the IR.
* Inserted/removed/modified ops, bypassing the rewriter. Not all cases are detected.

These new checks are quite expensive, so they are only enabled with `-DMLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON`. Failures manifest as fatal errors (`llvm::report_fatal_error`) or crashes (accessing deallocated memory). To get better debugging information, run `mlir-opt -debug` (to see which pattern is broken) with ASAN (to see where memory was deallocated).

Differential Revision: https://reviews.llvm.org/D144552
This commit is contained in:
Matthias Springer
2023-05-24 16:14:47 +02:00
parent 0ea5eb143c
commit e6d90a0d5e
4 changed files with 190 additions and 3 deletions

View File

@@ -141,6 +141,10 @@ set(MLIR_INSTALL_AGGREGATE_OBJECTS 1 CACHE BOOL
set(MLIR_BUILD_MLIR_C_DYLIB 0 CACHE BOOL "Builds libMLIR-C shared library.")
configure_file(
${MLIR_MAIN_INCLUDE_DIR}/mlir/Config/mlir-config.h.cmake
${MLIR_INCLUDE_DIR}/mlir/Config/mlir-config.h)
#-------------------------------------------------------------------------------
# Python Bindings Configuration
# Requires:

View File

@@ -0,0 +1,22 @@
//===- mlir-config.h - MLIR configuration ------------------------*- C -*-===*//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/* This file enumerates variables from the MLIR configuration so that they
can be in exported headers and won't override package specific directives.
This is a C header that can be included in the mlir-c headers. */
#ifndef MLIR_CONFIG_H
#define MLIR_CONFIG_H
/* Enable expensive checks to detect invalid pattern API usage. Failed checks
manifest as fatal errors or invalid memory accesses (e.g., accessing
deallocated memory) that cause a crash. Running with ASAN is recommended for
easier debugging. */
#cmakedefine01 MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
#endif

View File

@@ -429,6 +429,38 @@ public:
static bool classof(const OpBuilder::Listener *base);
};
/// A listener that forwards all notifications to another listener. This
/// struct can be used as a base to create listener chains, so that multiple
/// listeners can be notified of IR changes.
struct ForwardingListener : public RewriterBase::Listener {
ForwardingListener(Listener *listener) : listener(listener) {}
void notifyOperationInserted(Operation *op) override {
listener->notifyOperationInserted(op);
}
void notifyBlockCreated(Block *block) override {
listener->notifyBlockCreated(block);
}
void notifyOperationModified(Operation *op) override {
listener->notifyOperationModified(op);
}
void notifyOperationReplaced(Operation *op,
ValueRange replacement) override {
listener->notifyOperationReplaced(op, replacement);
}
void notifyOperationRemoved(Operation *op) override {
listener->notifyOperationRemoved(op);
}
LogicalResult notifyMatchFailure(
Location loc,
function_ref<void(Diagnostic &)> reasonCallback) override {
return listener->notifyMatchFailure(loc, reasonCallback);
}
private:
Listener *listener;
};
/// Move the blocks that belong to "region" before the given position in
/// another region "parent". The two regions must be different. The caller
/// is responsible for creating or updating the operation transferring flow

View File

@@ -11,6 +11,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Config/mlir-config.h"
#include "mlir/IR/Action.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -30,10 +32,108 @@ using namespace mlir;
#define DEBUG_TYPE "greedy-rewriter"
//===----------------------------------------------------------------------===//
// GreedyPatternRewriteDriver
// Debugging Infrastructure
//===----------------------------------------------------------------------===//
namespace {
#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// A helper struct that stores finger prints of ops in order to detect broken
/// RewritePatterns. A rewrite pattern is broken if it modifies IR without
/// using the rewriter API or if it returns an inconsistent return value.
struct DebugFingerPrints : public RewriterBase::ForwardingListener {
DebugFingerPrints(RewriterBase::Listener *driver)
: RewriterBase::ForwardingListener(driver) {}
/// Compute finger prints of the given op and its nested ops.
void computeFingerPrints(Operation *topLevel) {
this->topLevel = topLevel;
this->topLevelFingerPrint.emplace(topLevel);
topLevel->walk([&](Operation *op) { fingerprints.try_emplace(op, op); });
}
/// Clear all finger prints.
void clear() {
topLevel = nullptr;
topLevelFingerPrint.reset();
fingerprints.clear();
}
void notifyRewriteSuccess() {
// Pattern application success => IR must have changed.
OperationFingerPrint afterFingerPrint(topLevel);
if (*topLevelFingerPrint == afterFingerPrint) {
// Note: Run "mlir-opt -debug" to see which pattern is broken.
llvm::report_fatal_error(
"pattern returned success but IR did not change");
}
for (const auto &it : fingerprints) {
// Skip top-level op, its finger print is never invalidated.
if (it.first == topLevel)
continue;
// Note: Finger print computation may crash when an op was erased
// without notifying the rewriter. (Run with ASAN to see where the op was
// erased; the op was probably erased directly, bypassing the rewriter
// API.) Finger print computation does may not crash if a new op was
// created at the same memory location. (But then the finger print should
// have changed.)
if (it.second != OperationFingerPrint(it.first)) {
// Note: Run "mlir-opt -debug" to see which pattern is broken.
llvm::report_fatal_error("operation finger print changed");
}
}
}
void notifyRewriteFailure() {
// Pattern application failure => IR must not have changed.
OperationFingerPrint afterFingerPrint(topLevel);
if (*topLevelFingerPrint != afterFingerPrint) {
// Note: Run "mlir-opt -debug" to see which pattern is broken.
llvm::report_fatal_error("pattern returned failure but IR did change");
}
}
protected:
/// Invalidate the finger print of the given op, i.e., remove it from the map.
void invalidateFingerPrint(Operation *op) {
// Invalidate all finger prints until the top level.
while (op && op != topLevel) {
fingerprints.erase(op);
op = op->getParentOp();
}
}
void notifyOperationInserted(Operation *op) override {
RewriterBase::ForwardingListener::notifyOperationInserted(op);
invalidateFingerPrint(op->getParentOp());
}
void notifyOperationModified(Operation *op) override {
RewriterBase::ForwardingListener::notifyOperationModified(op);
invalidateFingerPrint(op);
}
void notifyOperationRemoved(Operation *op) override {
RewriterBase::ForwardingListener::notifyOperationRemoved(op);
op->walk([this](Operation *op) { invalidateFingerPrint(op); });
}
/// Operation finger prints to detect invalid pattern API usage. IR is checked
/// against these finger prints after pattern application to detect cases
/// where IR was modified directly, bypassing the rewriter API.
DenseMap<Operation *, OperationFingerPrint> fingerprints;
/// Top-level operation of the current greedy rewrite.
Operation *topLevel = nullptr;
/// Finger print of the top-level operation.
std::optional<OperationFingerPrint> topLevelFingerPrint;
};
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
//===----------------------------------------------------------------------===//
// GreedyPatternRewriteDriver
//===----------------------------------------------------------------------===//
/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
/// applies the locally optimal patterns.
///
@@ -122,21 +222,36 @@ private:
/// The low-level pattern applicator.
PatternApplicator matcher;
#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
DebugFingerPrints debugFingerPrints;
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
};
} // namespace
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config)
: PatternRewriter(ctx), folder(ctx, this), config(config),
matcher(patterns) {
: PatternRewriter(ctx), folder(ctx, this), config(config), matcher(patterns)
#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// clang-format off
, debugFingerPrints(this)
// clang-format on
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
{
worklist.reserve(64);
// Apply a simple cost model based solely on pattern benefit.
matcher.applyDefaultCostModel();
// Set up listener.
#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Send IR notifications to the debug handler. This handler will then forward
// all notifications to this GreedyPatternRewriteDriver.
setListener(&debugFingerPrints);
#else
setListener(this);
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
}
bool GreedyPatternRewriteDriver::processWorklist() {
@@ -231,15 +346,28 @@ bool GreedyPatternRewriteDriver::processWorklist() {
function_ref<LogicalResult(const Pattern &)> onSuccess = {};
#endif
#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
debugFingerPrints.computeFingerPrints(
/*topLevel=*/config.scope ? config.scope->getParentOp() : op);
auto clearFingerprints =
llvm::make_scope_exit([&]() { debugFingerPrints.clear(); });
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
LogicalResult matchResult =
matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
if (succeeded(matchResult)) {
LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
debugFingerPrints.notifyRewriteSuccess();
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
changed = true;
++numRewrites;
} else {
LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
debugFingerPrints.notifyRewriteFailure();
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
}
}
@@ -247,6 +375,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
}
void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
assert(op && "expected valid op");
// Gather potential ancestors while looking for a "scope" parent region.
SmallVector<Operation *, 8> ancestors;
Region *region = nullptr;