[mlir] Add PDL C & Python usage (#94714)

Following a rather direct approach to expose PDL usage from C and then
Python. This doesn't yes plumb through adding support for custom
matchers through this interface, so constrained to basics initially.

This also exposes greedy rewrite driver. Only way currently to define
patterns is via PDL (just to keep small). The creation of the PDL
pattern module could be improved to avoid folks potentially accessing
the module used to construct it post construction. No ergonomic work
done yet.

---------

Signed-off-by: Jacques Pienaar <jpienaar@google.com>
This commit is contained in:
Jacques Pienaar
2024-06-11 07:45:12 -07:00
committed by GitHub
parent 38ccee0034
commit 18cf1cd92b
15 changed files with 424 additions and 2 deletions

View File

@@ -39,6 +39,7 @@
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Pass.h"
#include "mlir-c/Rewrite.h"
// The 'mlir' Python package is relocatable and supports co-existing in multiple
// projects. Each project must define its outer package prefix with this define
@@ -284,6 +285,26 @@ static inline MlirModule mlirPythonCapsuleToModule(PyObject *capsule) {
return module;
}
/** Creates a capsule object encapsulating the raw C-API
* MlirFrozenRewritePatternSet.
* The returned capsule does not extend or affect ownership of any Python
* objects that reference the module in any way. */
static inline PyObject *
mlirPythonFrozenRewritePatternSetToCapsule(MlirFrozenRewritePatternSet pm) {
return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(pm),
MLIR_PYTHON_CAPSULE_PASS_MANAGER, NULL);
}
/** Extracts an MlirFrozenRewritePatternSet from a capsule as produced from
* mlirPythonFrozenRewritePatternSetToCapsule. If the capsule is not of the
* right type, then a null module is returned. */
static inline MlirFrozenRewritePatternSet
mlirPythonCapsuleToFrozenRewritePatternSet(PyObject *capsule) {
void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_PASS_MANAGER);
MlirFrozenRewritePatternSet pm = {ptr};
return pm;
}
/** Creates a capsule object encapsulating the raw C-API MlirPassManager.
* The returned capsule does not extend or affect ownership of any Python
* objects that reference the module in any way. */

View File

@@ -0,0 +1,60 @@
//===-- mlir-c/Rewrite.h - Helpers for C API to Rewrites ----------*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header declares the registration and creation method for
// rewrite patterns.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_C_REWRITE_H
#define MLIR_C_REWRITE_H
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Config/mlir-config.h"
//===----------------------------------------------------------------------===//
/// Opaque type declarations (see mlir-c/IR.h for more details).
//===----------------------------------------------------------------------===//
#define DEFINE_C_API_STRUCT(name, storage) \
struct name { \
storage *ptr; \
}; \
typedef struct name name
DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet
mlirFreezeRewritePattern(MlirRewritePatternSet op);
MLIR_CAPI_EXPORTED void
mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
MlirModule op, MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig);
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
DEFINE_C_API_STRUCT(MlirPDLPatternModule, void);
MLIR_CAPI_EXPORTED MlirPDLPatternModule
mlirPDLPatternModuleFromModule(MlirModule op);
MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op);
MLIR_CAPI_EXPORTED MlirRewritePatternSet
mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op);
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
#undef DEFINE_C_API_STRUCT
#endif // MLIR_C_REWRITE_H

View File

@@ -198,6 +198,27 @@ struct type_caster<MlirModule> {
};
};
/// Casts object <-> MlirFrozenRewritePatternSet.
template <>
struct type_caster<MlirFrozenRewritePatternSet> {
PYBIND11_TYPE_CASTER(MlirFrozenRewritePatternSet,
_("MlirFrozenRewritePatternSet"));
bool load(handle src, bool) {
py::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
return value.ptr != nullptr;
}
static handle cast(MlirFrozenRewritePatternSet v, return_value_policy,
handle) {
py::object capsule = py::reinterpret_steal<py::object>(
mlirPythonFrozenRewritePatternSetToCapsule(v));
return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("rewrite"))
.attr("FrozenRewritePatternSet")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.release();
};
};
/// Casts object <-> MlirOperation.
template <>
struct type_caster<MlirOperation> {

View File

@@ -22,6 +22,7 @@
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Transforms.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "llvm/ADT/DenseMap.h"

View File

@@ -11,6 +11,7 @@
#include "Globals.h"
#include "IRModule.h"
#include "Pass.h"
#include "Rewrite.h"
namespace py = pybind11;
using namespace mlir;
@@ -116,6 +117,9 @@ PYBIND11_MODULE(_mlir, m) {
populateIRInterfaces(irModule);
populateIRTypes(irModule);
auto rewriteModule = m.def_submodule("rewrite", "MLIR Rewrite Bindings");
populateRewriteSubmodule(rewriteModule);
// Define and populate PassManager submodule.
auto passModule =
m.def_submodule("passmanager", "MLIR Pass Management Bindings");

View File

@@ -0,0 +1,110 @@
//===- Rewrite.cpp - Rewrite ----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "Rewrite.h"
#include "IRModule.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Rewrite.h"
#include "mlir/Config/mlir-config.h"
namespace py = pybind11;
using namespace mlir;
using namespace py::literals;
using namespace mlir::python;
namespace {
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
/// Owning Wrapper around a PDLPatternModule.
class PyPDLPatternModule {
public:
PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
: module(other.module) {
other.module.ptr = nullptr;
}
~PyPDLPatternModule() {
if (module.ptr != nullptr)
mlirPDLPatternModuleDestroy(module);
}
MlirPDLPatternModule get() { return module; }
private:
MlirPDLPatternModule module;
};
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
/// Owning Wrapper around a FrozenRewritePatternSet.
class PyFrozenRewritePatternSet {
public:
PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
: set(other.set) {
other.set.ptr = nullptr;
}
~PyFrozenRewritePatternSet() {
if (set.ptr != nullptr)
mlirFrozenRewritePatternSetDestroy(set);
}
MlirFrozenRewritePatternSet get() { return set; }
pybind11::object getCapsule() {
return py::reinterpret_steal<py::object>(
mlirPythonFrozenRewritePatternSetToCapsule(get()));
}
static pybind11::object createFromCapsule(pybind11::object capsule) {
MlirFrozenRewritePatternSet rawPm =
mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
if (rawPm.ptr == nullptr)
throw py::error_already_set();
return py::cast(PyFrozenRewritePatternSet(rawPm),
py::return_value_policy::move);
}
private:
MlirFrozenRewritePatternSet set;
};
} // namespace
/// Create the `mlir.rewrite` here.
void mlir::python::populateRewriteSubmodule(py::module &m) {
//----------------------------------------------------------------------------
// Mapping of the top-level PassManager
//----------------------------------------------------------------------------
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
py::class_<PyPDLPatternModule>(m, "PDLModule", py::module_local())
.def(py::init<>([](MlirModule module) {
return mlirPDLPatternModuleFromModule(module);
}),
"module"_a, "Create a PDL module from the given module.")
.def("freeze", [](PyPDLPatternModule &self) {
return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
mlirRewritePatternSetFromPDLPatternModule(self.get())));
});
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCg
py::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet",
py::module_local())
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyFrozenRewritePatternSet::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
&PyFrozenRewritePatternSet::createFromCapsule);
m.def(
"apply_patterns_and_fold_greedily",
[](MlirModule module, MlirFrozenRewritePatternSet set) {
auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
if (mlirLogicalResultIsFailure(status))
// FIXME: Not sure this is the right error to throw here.
throw py::value_error("pattern application failed to converge");
},
"module"_a, "set"_a,
"Applys the given patterns to the given module greedily while folding "
"results.");
}

View File

@@ -0,0 +1,22 @@
//===- Rewrite.h - Rewrite Submodules of pybind module --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_BINDINGS_PYTHON_REWRITE_H
#define MLIR_BINDINGS_PYTHON_REWRITE_H
#include "PybindUtils.h"
namespace mlir {
namespace python {
void populateRewriteSubmodule(pybind11::module &m);
} // namespace python
} // namespace mlir
#endif // MLIR_BINDINGS_PYTHON_REWRITE_H

View File

@@ -1,6 +1,9 @@
add_mlir_upstream_c_api_library(MLIRCAPITransforms
Passes.cpp
Rewrite.cpp
LINK_LIBS PUBLIC
MLIRIR
MLIRTransforms
MLIRTransformUtils
)

View File

@@ -0,0 +1,83 @@
//===- Rewrite.cpp - C API for Rewrite Patterns ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir-c/Rewrite.h"
#include "mlir-c/Transforms.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
assert(module.ptr && "unexpected null module");
return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
}
inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
return {module};
}
inline mlir::FrozenRewritePatternSet *
unwrap(MlirFrozenRewritePatternSet module) {
assert(module.ptr && "unexpected null module");
return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr);
}
inline MlirFrozenRewritePatternSet wrap(mlir::FrozenRewritePatternSet *module) {
return {module};
}
MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op)));
op.ptr = nullptr;
return wrap(m);
}
void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) {
delete unwrap(op);
op.ptr = nullptr;
}
MlirLogicalResult
mlirApplyPatternsAndFoldGreedily(MlirModule op,
MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig) {
return wrap(
mlir::applyPatternsAndFoldGreedily(unwrap(op), *unwrap(patterns)));
}
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) {
assert(module.ptr && "unexpected null module");
return static_cast<mlir::PDLPatternModule *>(module.ptr);
}
inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) {
return {module};
}
MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) {
return wrap(new mlir::PDLPatternModule(
mlir::OwningOpRef<mlir::ModuleOp>(unwrap(op))));
}
void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op) {
delete unwrap(op);
op.ptr = nullptr;
}
MlirRewritePatternSet
mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
auto *m = new mlir::RewritePatternSet(std::move(*unwrap(op)));
op.ptr = nullptr;
return wrap(m);
}
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH

View File

@@ -21,6 +21,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
_mlir_libs/__init__.py
ir.py
passmanager.py
rewrite.py
dialects/_ods_common.py
# The main _mlir module has submodules: include stubs from each.
@@ -448,6 +449,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
IRModule.cpp
IRTypes.cpp
Pass.cpp
Rewrite.cpp
# Headers must be included explicitly so they are installed.
Globals.h

View File

@@ -6,7 +6,7 @@ from ._pdl_ops_gen import *
from ._pdl_ops_gen import _Dialect
from .._mlir_libs._mlirDialectsPDL import *
from .._mlir_libs._mlirDialectsPDL import OperationType
from ..extras.meta import region_op
try:
from ..ir import *
@@ -127,6 +127,9 @@ class PatternOp(PatternOp):
return self.regions[0].blocks[0]
pattern = region_op(PatternOp.__base__)
@_ods_cext.register_operation(_Dialect, replace=True)
class ReplaceOp(ReplaceOp):
"""Specialization for PDL replace op class."""
@@ -195,6 +198,9 @@ class RewriteOp(RewriteOp):
return self.regions[0].blocks[0]
rewrite = region_op(RewriteOp)
@_ods_cext.register_operation(_Dialect, replace=True)
class TypeOp(TypeOp):
"""Specialization for PDL type op class."""

View File

@@ -0,0 +1,5 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._mlir_libs._mlir.rewrite import *

View File

@@ -0,0 +1,67 @@
# RUN: %PYTHON %s 2>&1 | FileCheck %s
from mlir.dialects import arith, func, pdl
from mlir.dialects.builtin import module
from mlir.ir import *
from mlir.rewrite import *
def construct_and_print_in_module(f):
print("\nTEST:", f.__name__)
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
module = f(module)
if module is not None:
print(module)
return f
# CHECK-LABEL: TEST: test_add_to_mul
# CHECK: arith.muli
@construct_and_print_in_module
def test_add_to_mul(module_):
index_type = IndexType.get()
# Create a test case.
@module(sym_name="ir")
def ir():
@func.func(index_type, index_type)
def add_func(a, b):
return arith.addi(a, b)
# Create a rewrite from add to mul. This will match
# - operation name is arith.addi
# - operands are index types.
# - there are two operands.
with Location.unknown():
m = Module.create()
with InsertionPoint(m.body):
# Change all arith.addi with index types to arith.muli.
@pdl.pattern(benefit=1, sym_name="addi_to_mul")
def pat():
# Match arith.addi with index types.
index_type = pdl.TypeOp(IndexType.get())
operand0 = pdl.OperandOp(index_type)
operand1 = pdl.OperandOp(index_type)
op0 = pdl.OperationOp(
name="arith.addi", args=[operand0, operand1], types=[index_type]
)
# Replace the matched op with arith.muli.
@pdl.rewrite()
def rew():
newOp = pdl.OperationOp(
name="arith.muli", args=[operand0, operand1], types=[index_type]
)
pdl.ReplaceOp(op0, with_op=newOp)
# Create a PDL module from module and freeze it. At this point the ownership
# of the module is transferred to the PDL module. This ownership transfer is
# not yet captured Python side/has sharp edges. So best to construct the
# module and PDL module in same scope.
# FIXME: This should be made more robust.
frozen = PDLModule(m).freeze()
# Could apply frozen pattern set multiple times.
apply_patterns_and_fold_greedily(module_, frozen)
return module_

View File

@@ -420,6 +420,7 @@ mlir_c_api_cc_library(
"include/mlir-c/Interfaces.h",
"include/mlir-c/Pass.h",
"include/mlir-c/RegisterEverything.h",
"include/mlir-c/Rewrite.h",
"include/mlir-c/Support.h",
"include/mlir/CAPI/AffineExpr.h",
"include/mlir/CAPI/AffineMap.h",
@@ -866,7 +867,10 @@ mlir_c_api_cc_library(
mlir_c_api_cc_library(
name = "CAPITransforms",
srcs = ["lib/CAPI/Transforms/Passes.cpp"],
srcs = [
"lib/CAPI/Transforms/Passes.cpp",
"lib/CAPI/Transforms/Rewrite.cpp",
],
hdrs = ["include/mlir-c/Transforms.h"],
capi_deps = [
":CAPIIR",
@@ -876,7 +880,10 @@ mlir_c_api_cc_library(
],
includes = ["include"],
deps = [
":IR",
":Pass",
":Rewrite",
":TransformUtils",
":Transforms",
],
)
@@ -939,6 +946,7 @@ cc_library(
textual_hdrs = glob(MLIR_BINDINGS_PYTHON_HEADERS),
deps = [
":CAPIIRHeaders",
":CAPITransformsHeaders",
"@local_config_python//:python_headers",
"@pybind11",
],
@@ -957,6 +965,7 @@ cc_library(
textual_hdrs = glob(MLIR_BINDINGS_PYTHON_HEADERS),
deps = [
":CAPIIR",
":CAPITransforms",
"@local_config_python//:python_headers",
"@pybind11",
],
@@ -981,6 +990,7 @@ MLIR_PYTHON_BINDINGS_SOURCES = [
"lib/Bindings/Python/IRModule.cpp",
"lib/Bindings/Python/IRTypes.cpp",
"lib/Bindings/Python/Pass.cpp",
"lib/Bindings/Python/Rewrite.cpp",
]
cc_library(

View File

@@ -82,6 +82,13 @@ filegroup(
],
)
filegroup(
name = "RewritePyFiles",
srcs = [
"mlir/rewrite.py",
],
)
filegroup(
name = "RuntimePyFiles",
srcs = glob([