[mlir][python] remove mixins (#68853)
This PR replaces the mixin `OpView` extension mechanism with the standard inheritance mechanism. Why? Firstly, mixins are not very pythonic (inheritance is usually used for this), a little convoluted, and too "tight" (can only be used in the immediately adjacent `_ext.py`). Secondly, it (mixins) are now blocking are correct implementation of "value builders" (see [here](https://github.com/llvm/llvm-project/pull/68764)) where the problem becomes how to choose the correct base class that the value builder should call. This PR looks big/complicated but appearances are deceiving; 4 things were needed to make this work: 1. Drop `skipDefaultBuilders` in `OpPythonBindingGen::emitDefaultOpBuilders` 2. Former mixin extension classes are converted to inherit from the generated `OpView` instead of being "mixins" a. extension classes that simply were calling into an already generated `super().__init__` continue to do so b. (almost all) extension classes that were calling `self.build_generic` because of a lack of default builder being generated can now also just call `super().__init__` 3. To handle the [lone single use-case](https://sourcegraph.com/search?q=context%3Aglobal+select_opview_mixin&patternType=standard&sm=1&groupBy=repo) of `select_opview_mixin`, namely [linalg](https://github.com/llvm/llvm-project/blob/main/mlir/python/mlir/dialects/_linalg_ops_ext.py#L38), only a small change was necessary in `opdsl/lang/emitter.py` (thanks to the emission/generation of default builders/`__init__`s) 4. since the `extend_opview_class` decorator is removed, we need a way to register extension classes as the desired `OpView` that `op.opview` conjures into existence; so we do the standard thing and just enable replacing the existing registered `OpView` i.e., `register_operation(_Dialect, replace=True)`. Note, the upgrade path for the common case is to change an extension to inherit from the generated builder and decorate it with `register_operation(_Dialect, replace=True)`. In the slightly more complicated case where `super().__init(self.build_generic(...))` is called in the extension's `__init__`, this needs to be updated to call `__init__` in `OpView`, i.e., the grandparent (see updated docs). Note, also `<DIALECT>_ext.py` files/modules will no longer be automatically loaded. Note, the PR has 3 base commits that look funny but this was done for the purpose of tracking the line history of moving the `<DIALECT>_ops_ext.py` class into `<DIALECT>.py` and updating (commit labeled "fix").
This commit is contained in:
@@ -1017,90 +1017,79 @@ very generic signature.
|
||||
|
||||
#### Extending Generated Op Classes
|
||||
|
||||
Note that this is a rather complex mechanism and this section errs on the side
|
||||
of explicitness. Users are encouraged to find an example and duplicate it if
|
||||
they don't feel the need to understand the subtlety. The `builtin` dialect
|
||||
provides some relatively simple examples.
|
||||
|
||||
As mentioned above, the build system generates Python sources like
|
||||
`_{DIALECT_NAMESPACE}_ops_gen.py` for each dialect with Python bindings. It is
|
||||
often desirable to to use these generated classes as a starting point for
|
||||
further customization, so an extension mechanism is provided to make this easy
|
||||
(you are always free to do ad-hoc patching in your `{DIALECT_NAMESPACE}.py` file
|
||||
but we prefer a more standard mechanism that is applied uniformly).
|
||||
|
||||
To provide extensions, add a `_{DIALECT_NAMESPACE}_ops_ext.py` file to the
|
||||
`dialects` module (i.e. adjacent to your `{DIALECT_NAMESPACE}.py` top-level and
|
||||
the `*_ops_gen.py` file). Using the `builtin` dialect and `FuncOp` as an
|
||||
example, the generated code will include an import like this:
|
||||
often desirable to use these generated classes as a starting point for
|
||||
further customization, so an extension mechanism is provided to make this easy.
|
||||
This mechanism uses conventional inheritance combined with `OpView` registration.
|
||||
For example, the default builder for `arith.constant`
|
||||
|
||||
```python
|
||||
try:
|
||||
from . import _builtin_ops_ext as _ods_ext_module
|
||||
except ImportError:
|
||||
_ods_ext_module = None
|
||||
class ConstantOp(_ods_ir.OpView):
|
||||
OPERATION_NAME = "arith.constant"
|
||||
|
||||
_ODS_REGIONS = (0, True)
|
||||
|
||||
def __init__(self, value, *, loc=None, ip=None):
|
||||
...
|
||||
```
|
||||
|
||||
Then for each generated concrete `OpView` subclass, it will apply a decorator
|
||||
like:
|
||||
expects `value` to be a `TypedAttr` (e.g., `IntegerAttr` or `FloatAttr`).
|
||||
Thus, a natural extension is a builder that accepts a MLIR type and a Python value and instantiates the appropriate `TypedAttr`:
|
||||
|
||||
```python
|
||||
@_ods_cext.register_operation(_Dialect)
|
||||
@_ods_extend_opview_class(_ods_ext_module)
|
||||
class FuncOp(_ods_ir.OpView):
|
||||
from typing import Union
|
||||
|
||||
from mlir.ir import Type, IntegerAttr, FloatAttr
|
||||
from mlir.dialects._arith_ops_gen import _Dialect, ConstantOp
|
||||
from mlir.dialects._ods_common import _cext
|
||||
|
||||
@_cext.register_operation(_Dialect, replace=True)
|
||||
class ConstantOpExt(ConstantOp):
|
||||
def __init__(
|
||||
self, result: Type, value: Union[int, float], *, loc=None, ip=None
|
||||
):
|
||||
if isinstance(value, int):
|
||||
super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
|
||||
elif isinstance(value, float):
|
||||
super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
|
||||
else:
|
||||
raise NotImplementedError(f"Building `arith.constant` not supported for {result=} {value=}")
|
||||
```
|
||||
|
||||
See the `_ods_common.py` `extend_opview_class` function for details of the
|
||||
mechanism. At a high level:
|
||||
|
||||
* If the extension module exists, locate an extension class for the op (in
|
||||
this example, `FuncOp`):
|
||||
* First by looking for an attribute with the exact name in the extension
|
||||
module.
|
||||
* Falling back to calling a `select_opview_mixin(parent_opview_cls)`
|
||||
function defined in the extension module.
|
||||
* If a mixin class is found, a new subclass is dynamically created that
|
||||
multiply inherits from `({_builtin_ops_ext.FuncOp},
|
||||
_builtin_ops_gen.FuncOp)`.
|
||||
|
||||
The mixin class should not inherit from anything (i.e. directly extends `object`
|
||||
only). The facility is typically used to define custom `__init__` methods,
|
||||
properties, instance methods and static methods. Due to the inheritance
|
||||
ordering, the mixin class can act as though it extends the generated `OpView`
|
||||
subclass in most contexts (i.e. `issubclass(_builtin_ops_ext.FuncOp, OpView)`
|
||||
will return `False` but usage generally allows you treat it as duck typed as an
|
||||
`OpView`).
|
||||
|
||||
There are a couple of recommendations, given how the class hierarchy is defined:
|
||||
|
||||
* For static methods that need to instantiate the actual "leaf" op (which is
|
||||
dynamically generated and would result in circular dependencies to try to
|
||||
reference by name), prefer to use `@classmethod` and the concrete subclass
|
||||
will be provided as your first `cls` argument. See
|
||||
`_builtin_ops_ext.FuncOp.from_py_func` as an example.
|
||||
* If seeking to replace the generated `__init__` method entirely, you may
|
||||
actually want to invoke the super-super-class `mlir.ir.OpView` constructor
|
||||
directly, as it takes an `mlir.ir.Operation`, which is likely what you are
|
||||
constructing (i.e. the generated `__init__` method likely adds more API
|
||||
constraints than you want to expose in a custom builder).
|
||||
|
||||
A pattern that comes up frequently is wanting to provide a sugared `__init__`
|
||||
method which has optional or type-polymorphism/implicit conversions but to
|
||||
otherwise want to invoke the default op building logic. For such cases, it is
|
||||
recommended to use an idiom such as:
|
||||
which enables building an instance of `arith.constant` like so:
|
||||
|
||||
```python
|
||||
def __init__(self, sugar, spice, *, loc=None, ip=None):
|
||||
... massage into result_type, operands, attributes ...
|
||||
OpView.__init__(self, self.build_generic(
|
||||
results=[result_type],
|
||||
operands=operands,
|
||||
attributes=attributes,
|
||||
loc=loc,
|
||||
ip=ip))
|
||||
from mlir.ir import F32Type
|
||||
|
||||
a = ConstantOpExt(F32Type.get(), 42.42)
|
||||
b = ConstantOpExt(IntegerType.get_signless(32), 42)
|
||||
```
|
||||
|
||||
Refer to the documentation for `build_generic` for more information.
|
||||
Note, three key aspects of the extension mechanism in this example:
|
||||
|
||||
1. `ConstantOpExt` directly inherits from the generated `ConstantOp`;
|
||||
2. in this, simplest, case all that's required is a call to the super class' initializer, i.e., `super().__init__(...)`;
|
||||
3. in order to register `ConstantOpExt` as the preferred `OpView` that is returned by `mlir.ir.Operation.opview` (see [Operations, Regions and Blocks](#operations-regions-and-blocks))
|
||||
we decorate the class with `@_cext.register_operation(_Dialect, replace=True)`, **where the `replace=True` must be used**.
|
||||
|
||||
In some more complex cases it might be necessary to explicitly build the `OpView` through `OpView.build_generic` (see [Default Builder](#default-builder)), just as is performed by the generated builders.
|
||||
I.e., we must call `OpView.build_generic` **and pass the result to `OpView.__init__`**, where the small issue becomes that the latter is already overridden by the generated builder.
|
||||
Thus, we must call a method of a super class' super class (the "grandparent"); for example:
|
||||
|
||||
```python
|
||||
from mlir.dialects._scf_ops_gen import _Dialect, ForOp
|
||||
from mlir.dialects._ods_common import _cext
|
||||
|
||||
@_cext.register_operation(_Dialect, replace=True)
|
||||
class ForOpExt(ForOp):
|
||||
def __init__(self, lower_bound, upper_bound, step, iter_args, *, loc=None, ip=None):
|
||||
...
|
||||
super(ForOp, self).__init__(self.build_generic(...))
|
||||
```
|
||||
|
||||
where `OpView.__init__` is called via `super(ForOp, self).__init__`.
|
||||
Note, there are alternatives ways to implement this (e.g., explicitly writing `OpView.__init__`); see any discussion on Python inheritance.
|
||||
|
||||
## Providing Python bindings for a dialect
|
||||
|
||||
|
||||
@@ -77,10 +77,10 @@ public:
|
||||
pybind11::object pyClass);
|
||||
|
||||
/// Adds a concrete implementation operation class.
|
||||
/// Raises an exception if the mapping already exists.
|
||||
/// Raises an exception if the mapping already exists and replace == false.
|
||||
/// This is intended to be called by implementation code.
|
||||
void registerOperationImpl(const std::string &operationName,
|
||||
pybind11::object pyClass);
|
||||
pybind11::object pyClass, bool replace = false);
|
||||
|
||||
/// Returns the custom Attribute builder for Attribute kind.
|
||||
std::optional<pybind11::function>
|
||||
|
||||
@@ -96,9 +96,9 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
|
||||
}
|
||||
|
||||
void PyGlobals::registerOperationImpl(const std::string &operationName,
|
||||
py::object pyClass) {
|
||||
py::object pyClass, bool replace) {
|
||||
py::object &found = operationClassMap[operationName];
|
||||
if (found) {
|
||||
if (found && !replace) {
|
||||
throw std::runtime_error((llvm::Twine("Operation '") + operationName +
|
||||
"' is already registered.")
|
||||
.str());
|
||||
|
||||
@@ -41,7 +41,7 @@ PYBIND11_MODULE(_mlir, m) {
|
||||
"dialect_namespace"_a, "dialect_class"_a,
|
||||
"Testing hook for directly registering a dialect")
|
||||
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
|
||||
"operation_name"_a, "operation_class"_a,
|
||||
"operation_name"_a, "operation_class"_a, "replace"_a = false,
|
||||
"Testing hook for directly registering an operation");
|
||||
|
||||
// Aside from making the globals accessible to python, having python manage
|
||||
@@ -63,12 +63,13 @@ PYBIND11_MODULE(_mlir, m) {
|
||||
"Class decorator for registering a custom Dialect wrapper");
|
||||
m.def(
|
||||
"register_operation",
|
||||
[](const py::object &dialectClass) -> py::cpp_function {
|
||||
[](const py::object &dialectClass, bool replace) -> py::cpp_function {
|
||||
return py::cpp_function(
|
||||
[dialectClass](py::object opClass) -> py::object {
|
||||
[dialectClass, replace](py::object opClass) -> py::object {
|
||||
std::string operationName =
|
||||
opClass.attr("OPERATION_NAME").cast<std::string>();
|
||||
PyGlobals::get().registerOperationImpl(operationName, opClass);
|
||||
PyGlobals::get().registerOperationImpl(operationName, opClass,
|
||||
replace);
|
||||
|
||||
// Dict-stuff the new opClass by name onto the dialect class.
|
||||
py::object opClassName = opClass.attr("__name__");
|
||||
@@ -76,7 +77,7 @@ PYBIND11_MODULE(_mlir, m) {
|
||||
return opClass;
|
||||
});
|
||||
},
|
||||
"dialect_class"_a,
|
||||
"dialect_class"_a, "replace"_a = false,
|
||||
"Produce a class decorator for registering an Operation class as part of "
|
||||
"a dialect");
|
||||
m.def(
|
||||
|
||||
@@ -52,7 +52,6 @@ declare_mlir_dialect_python_bindings(
|
||||
TD_FILE dialects/AffineOps.td
|
||||
SOURCES
|
||||
dialects/affine.py
|
||||
dialects/_affine_ops_ext.py
|
||||
DIALECT_NAME affine
|
||||
GEN_ENUM_BINDINGS)
|
||||
|
||||
@@ -78,7 +77,6 @@ declare_mlir_dialect_python_bindings(
|
||||
TD_FILE dialects/BufferizationOps.td
|
||||
SOURCES
|
||||
dialects/bufferization.py
|
||||
dialects/_bufferization_ops_ext.py
|
||||
DIALECT_NAME bufferization
|
||||
GEN_ENUM_BINDINGS_TD_FILE
|
||||
"../../include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td"
|
||||
@@ -90,7 +88,6 @@ declare_mlir_dialect_python_bindings(
|
||||
TD_FILE dialects/BuiltinOps.td
|
||||
SOURCES
|
||||
dialects/builtin.py
|
||||
dialects/_builtin_ops_ext.py
|
||||
DIALECT_NAME builtin)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
@@ -115,7 +112,6 @@ declare_mlir_dialect_python_bindings(
|
||||
TD_FILE dialects/FuncOps.td
|
||||
SOURCES
|
||||
dialects/func.py
|
||||
dialects/_func_ops_ext.py
|
||||
DIALECT_NAME func)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
@@ -131,7 +127,6 @@ declare_mlir_dialect_python_bindings(
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
TD_FILE dialects/LinalgOps.td
|
||||
SOURCES
|
||||
dialects/_linalg_ops_ext.py
|
||||
SOURCES_GLOB
|
||||
dialects/linalg/*.py
|
||||
DIALECT_NAME linalg
|
||||
@@ -152,7 +147,6 @@ ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
TD_FILE dialects/TransformPDLExtensionOps.td
|
||||
SOURCES
|
||||
dialects/_transform_pdl_extension_ops_ext.py
|
||||
dialects/transform/pdl.py
|
||||
DIALECT_NAME transform
|
||||
EXTENSION_NAME transform_pdl_extension)
|
||||
@@ -162,7 +156,6 @@ declare_mlir_dialect_python_bindings(
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
TD_FILE dialects/TransformOps.td
|
||||
SOURCES
|
||||
dialects/_transform_ops_ext.py
|
||||
dialects/transform/__init__.py
|
||||
_mlir_libs/_mlir/dialects/transform/__init__.pyi
|
||||
DIALECT_NAME transform
|
||||
@@ -175,7 +168,6 @@ declare_mlir_dialect_extension_python_bindings(
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
TD_FILE dialects/BufferizationTransformOps.td
|
||||
SOURCES
|
||||
dialects/_bufferization_transform_ops_ext.py
|
||||
dialects/transform/bufferization.py
|
||||
DIALECT_NAME transform
|
||||
EXTENSION_NAME bufferization_transform)
|
||||
@@ -185,7 +177,6 @@ declare_mlir_dialect_extension_python_bindings(
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
TD_FILE dialects/GPUTransformOps.td
|
||||
SOURCES
|
||||
dialects/_gpu_transform_ops_ext.py
|
||||
dialects/transform/gpu.py
|
||||
DIALECT_NAME transform
|
||||
EXTENSION_NAME gpu_transform)
|
||||
@@ -195,7 +186,6 @@ declare_mlir_dialect_extension_python_bindings(
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
TD_FILE dialects/SCFLoopTransformOps.td
|
||||
SOURCES
|
||||
dialects/_loop_transform_ops_ext.py
|
||||
dialects/transform/loop.py
|
||||
DIALECT_NAME transform
|
||||
EXTENSION_NAME loop_transform)
|
||||
@@ -205,7 +195,6 @@ declare_mlir_dialect_extension_python_bindings(
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
TD_FILE dialects/MemRefTransformOps.td
|
||||
SOURCES
|
||||
dialects/_memref_transform_ops_ext.py
|
||||
dialects/transform/memref.py
|
||||
DIALECT_NAME transform
|
||||
EXTENSION_NAME memref_transform)
|
||||
@@ -224,7 +213,6 @@ declare_mlir_dialect_extension_python_bindings(
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
TD_FILE dialects/LinalgStructuredTransformOps.td
|
||||
SOURCES
|
||||
dialects/_structured_transform_ops_ext.py
|
||||
dialects/transform/structured.py
|
||||
DIALECT_NAME transform
|
||||
EXTENSION_NAME structured_transform
|
||||
@@ -246,7 +234,6 @@ declare_mlir_dialect_extension_python_bindings(
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
TD_FILE dialects/TensorTransformOps.td
|
||||
SOURCES
|
||||
dialects/_tensor_transform_ops_ext.py
|
||||
dialects/transform/tensor.py
|
||||
DIALECT_NAME transform
|
||||
EXTENSION_NAME tensor_transform)
|
||||
@@ -276,7 +263,6 @@ declare_mlir_dialect_python_bindings(
|
||||
TD_FILE dialects/ArithOps.td
|
||||
SOURCES
|
||||
dialects/arith.py
|
||||
dialects/_arith_ops_ext.py
|
||||
DIALECT_NAME arith
|
||||
GEN_ENUM_BINDINGS)
|
||||
|
||||
@@ -286,7 +272,6 @@ declare_mlir_dialect_python_bindings(
|
||||
TD_FILE dialects/MemRefOps.td
|
||||
SOURCES
|
||||
dialects/memref.py
|
||||
dialects/_memref_ops_ext.py
|
||||
DIALECT_NAME memref)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
@@ -295,7 +280,6 @@ declare_mlir_dialect_python_bindings(
|
||||
TD_FILE dialects/MLProgramOps.td
|
||||
SOURCES
|
||||
dialects/ml_program.py
|
||||
dialects/_ml_program_ops_ext.py
|
||||
DIALECT_NAME ml_program)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
@@ -339,7 +323,6 @@ declare_mlir_dialect_python_bindings(
|
||||
TD_FILE dialects/PDLOps.td
|
||||
SOURCES
|
||||
dialects/pdl.py
|
||||
dialects/_pdl_ops_ext.py
|
||||
_mlir_libs/_mlir/dialects/pdl.pyi
|
||||
DIALECT_NAME pdl)
|
||||
|
||||
@@ -357,7 +340,6 @@ declare_mlir_dialect_python_bindings(
|
||||
TD_FILE dialects/SCFOps.td
|
||||
SOURCES
|
||||
dialects/scf.py
|
||||
dialects/_scf_ops_ext.py
|
||||
DIALECT_NAME scf)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
@@ -383,7 +365,6 @@ declare_mlir_dialect_python_bindings(
|
||||
TD_FILE dialects/TensorOps.td
|
||||
SOURCES
|
||||
dialects/tensor.py
|
||||
dialects/_tensor_ops_ext.py
|
||||
DIALECT_NAME tensor)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
|
||||
from ._ods_common import get_op_results_or_values as _get_op_results_or_values
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
|
||||
class AffineStoreOp:
|
||||
"""Specialization for the Affine store operation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
value: Union[Operation, OpView, Value],
|
||||
memref: Union[Operation, OpView, Value],
|
||||
map: AffineMap=None,
|
||||
*,
|
||||
map_operands=None,
|
||||
loc=None,
|
||||
ip=None
|
||||
):
|
||||
"""Creates an affine store operation.
|
||||
|
||||
- `value`: the value to store into the memref.
|
||||
- `memref`: the buffer to store into.
|
||||
- `map`: the affine map that maps the map_operands to the index of the
|
||||
memref.
|
||||
- `map_operands`: the list of arguments to substitute the dimensions,
|
||||
then symbols in the affine map, in increasing order.
|
||||
"""
|
||||
map = map if map is not None else []
|
||||
map_operands = map_operands if map_operands is not None else []
|
||||
operands = [
|
||||
_get_op_result_or_value(value),
|
||||
_get_op_result_or_value(memref),
|
||||
*[_get_op_result_or_value(op) for op in map_operands]
|
||||
]
|
||||
results = []
|
||||
attributes = {"map": AffineMapAttr.get(map)}
|
||||
regions = None
|
||||
_ods_successors = None
|
||||
super().__init__(self.build_generic(
|
||||
attributes=attributes,
|
||||
results=results,
|
||||
operands=operands,
|
||||
successors=_ods_successors,
|
||||
regions=regions,
|
||||
loc=loc,
|
||||
ip=ip
|
||||
))
|
||||
@@ -1,69 +0,0 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ._ods_common import get_default_loc_context as _get_default_loc_context
|
||||
|
||||
from typing import Any, List, Union
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
|
||||
def _isa(obj: Any, cls: type):
|
||||
try:
|
||||
cls(obj)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _is_any_of(obj: Any, classes: List[type]):
|
||||
return any(_isa(obj, cls) for cls in classes)
|
||||
|
||||
|
||||
def _is_integer_like_type(type: Type):
|
||||
return _is_any_of(type, [IntegerType, IndexType])
|
||||
|
||||
|
||||
def _is_float_type(type: Type):
|
||||
return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
|
||||
|
||||
|
||||
class ConstantOp:
|
||||
"""Specialization for the constant op class."""
|
||||
|
||||
def __init__(
|
||||
self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
|
||||
):
|
||||
if isinstance(value, int):
|
||||
super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
|
||||
elif isinstance(value, float):
|
||||
super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
|
||||
else:
|
||||
super().__init__(value, loc=loc, ip=ip)
|
||||
|
||||
@classmethod
|
||||
def create_index(cls, value: int, *, loc=None, ip=None):
|
||||
"""Create an index-typed constant."""
|
||||
return cls(
|
||||
IndexType.get(context=_get_default_loc_context(loc)), value, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
return self.results[0].type
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return Attribute(self.operation.attributes["value"])
|
||||
|
||||
@property
|
||||
def literal_value(self) -> Union[int, float]:
|
||||
if _is_integer_like_type(self.type):
|
||||
return IntegerAttr(self.value).value
|
||||
elif _is_float_type(self.type):
|
||||
return FloatAttr(self.value).value
|
||||
else:
|
||||
raise ValueError("only integer and float constants have literal values")
|
||||
@@ -1,41 +0,0 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from typing import Sequence, Union
|
||||
from ..ir import *
|
||||
from ._ods_common import get_default_loc_context
|
||||
|
||||
from typing import Any, List, Union
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
|
||||
class AllocTensorOp:
|
||||
"""Extends the bufferization.alloc_tensor op."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tensor_type: Type,
|
||||
dynamic_sizes: Sequence[Value],
|
||||
copy: Value,
|
||||
size_hint: Value,
|
||||
escape: BoolAttr,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None
|
||||
):
|
||||
"""Constructs an `alloc_tensor` with static and/or dynamic sizes."""
|
||||
context = get_default_loc_context(loc)
|
||||
attributes = {}
|
||||
if escape:
|
||||
attributes["escape"] = escape
|
||||
op = self.build_generic(
|
||||
results=[tensor_type],
|
||||
operands=[dynamic_sizes, copy, size_hint],
|
||||
attributes=attributes,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
OpView.__init__(self, op)
|
||||
@@ -1,128 +0,0 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ..dialects import transform
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional, overload, Union
|
||||
|
||||
|
||||
class EmptyTensorToAllocTensorOp:
|
||||
"""Specialization for EmptyTensorToAllocTensorOp class."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
transformed_type: Type,
|
||||
target: Union[Operation, OpView, Value],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(self, target: Union[Operation, OpView, Value], *, loc=None, ip=None):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformed_type_or_target: Type,
|
||||
target_or_none: Optional[Union[Operation, OpView, Value]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None
|
||||
):
|
||||
if isinstance(transformed_type_or_target, Type):
|
||||
transformed_type = transformed_type_or_target
|
||||
target = target_or_none
|
||||
else:
|
||||
transformed_type = transform.OperationType.get("bufferization.alloc_tensor")
|
||||
target = transformed_type_or_target
|
||||
|
||||
super().__init__(
|
||||
transformed_type,
|
||||
target,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
class OneShotBufferizeOp:
|
||||
"""Specialization for OneShotBufferizeOp class."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
transformed_type: Type,
|
||||
target: Union[Operation, OpView, Value],
|
||||
*,
|
||||
allow_return_allocs_from_loops: Optional[bool] = None,
|
||||
allow_unknown_ops: Optional[bool] = None,
|
||||
bufferize_function_boundaries: Optional[bool] = None,
|
||||
function_boundary_type_conversion: Optional[Enum] = None,
|
||||
memcpy_op: Optional[str] = None,
|
||||
print_conflicts: Optional[bool] = None,
|
||||
test_analysis_only: Optional[bool] = None,
|
||||
loc=None,
|
||||
ip=None
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, OpView, Value],
|
||||
*,
|
||||
allow_return_allocs_from_loops: Optional[bool] = None,
|
||||
allow_unknown_ops: Optional[bool] = None,
|
||||
bufferize_function_boundaries: Optional[bool] = None,
|
||||
function_boundary_type_conversion: Optional[Enum] = None,
|
||||
memcpy_op: Optional[str] = None,
|
||||
print_conflicts: Optional[bool] = None,
|
||||
test_analysis_only: Optional[bool] = None,
|
||||
loc=None,
|
||||
ip=None
|
||||
):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformed_type_or_target: Type,
|
||||
target_or_none: Optional[Union[Operation, OpView, Value]] = None,
|
||||
*,
|
||||
allow_return_allocs_from_loops: Optional[bool] = None,
|
||||
allow_unknown_ops: Optional[bool] = None,
|
||||
bufferize_function_boundaries: Optional[bool] = None,
|
||||
function_boundary_type_conversion: Optional[Enum] = None,
|
||||
memcpy_op: Optional[str] = None,
|
||||
print_conflicts: Optional[bool] = None,
|
||||
test_analysis_only: Optional[bool] = None,
|
||||
loc=None,
|
||||
ip=None
|
||||
):
|
||||
if isinstance(transformed_type_or_target, Type):
|
||||
transformed_type = transformed_type_or_target
|
||||
target = target_or_none
|
||||
else:
|
||||
transformed_type = transform.AnyOpType.get()
|
||||
target = transformed_type_or_target
|
||||
|
||||
super().__init__(
|
||||
transformed_type,
|
||||
target,
|
||||
allow_return_allocs_from_loops=allow_return_allocs_from_loops,
|
||||
allow_unknown_ops=allow_unknown_ops,
|
||||
bufferize_function_boundaries=bufferize_function_boundaries,
|
||||
function_boundary_type_conversion=function_boundary_type_conversion,
|
||||
memcpy_op=memcpy_op,
|
||||
print_conflicts=print_conflicts,
|
||||
test_analysis_only=test_analysis_only,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
@@ -1,20 +0,0 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
|
||||
class ModuleOp:
|
||||
"""Specialization for the module op class."""
|
||||
|
||||
def __init__(self, *, loc=None, ip=None):
|
||||
super().__init__(self.build_generic(results=[], operands=[], loc=loc, ip=ip))
|
||||
body = self.regions[0].blocks.append()
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
return self.regions[0].blocks[0]
|
||||
@@ -1,319 +0,0 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ._ods_common import get_default_loc_context as _get_default_loc_context
|
||||
|
||||
import inspect
|
||||
|
||||
from typing import Any, List, Optional, Sequence, Union
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
|
||||
RESULT_ATTRIBUTE_NAME = "res_attrs"
|
||||
|
||||
|
||||
class ConstantOp:
|
||||
"""Specialization for the constant op class."""
|
||||
|
||||
def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
|
||||
super().__init__(result, value, loc=loc, ip=ip)
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
return self.results[0].type
|
||||
|
||||
|
||||
class FuncOp:
|
||||
"""Specialization for the func op class."""
|
||||
|
||||
def __init__(
|
||||
self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
|
||||
):
|
||||
"""
|
||||
Create a FuncOp with the provided `name`, `type`, and `visibility`.
|
||||
- `name` is a string representing the function name.
|
||||
- `type` is either a FunctionType or a pair of list describing inputs and
|
||||
results.
|
||||
- `visibility` is a string matching `public`, `private`, or `nested`. None
|
||||
implies private visibility.
|
||||
- `body_builder` is an optional callback, when provided a new entry block
|
||||
is created and the callback is invoked with the new op as argument within
|
||||
an InsertionPoint context already set for the block. The callback is
|
||||
expected to insert a terminator in the block.
|
||||
"""
|
||||
sym_name = StringAttr.get(str(name))
|
||||
|
||||
# If the type is passed as a tuple, build a FunctionType on the fly.
|
||||
if isinstance(type, tuple):
|
||||
type = FunctionType.get(inputs=type[0], results=type[1])
|
||||
|
||||
type = TypeAttr.get(type)
|
||||
sym_visibility = (
|
||||
StringAttr.get(str(visibility)) if visibility is not None else None
|
||||
)
|
||||
super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
|
||||
if body_builder:
|
||||
entry_block = self.add_entry_block()
|
||||
with InsertionPoint(entry_block):
|
||||
body_builder(self)
|
||||
|
||||
@property
|
||||
def is_external(self):
|
||||
return len(self.regions[0].blocks) == 0
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
return self.regions[0]
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
return FunctionType(TypeAttr(self.attributes["function_type"]).value)
|
||||
|
||||
@property
|
||||
def visibility(self):
|
||||
return self.attributes["sym_visibility"]
|
||||
|
||||
@property
|
||||
def name(self) -> StringAttr:
|
||||
return StringAttr(self.attributes["sym_name"])
|
||||
|
||||
@property
|
||||
def entry_block(self):
|
||||
if self.is_external:
|
||||
raise IndexError("External function does not have a body")
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
|
||||
"""
|
||||
Add an entry block to the function body using the function signature to
|
||||
infer block arguments.
|
||||
Returns the newly created block
|
||||
"""
|
||||
if not self.is_external:
|
||||
raise IndexError("The function already has an entry block!")
|
||||
self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs)
|
||||
return self.body.blocks[0]
|
||||
|
||||
@property
|
||||
def arg_attrs(self):
|
||||
return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
|
||||
|
||||
@arg_attrs.setter
|
||||
def arg_attrs(self, attribute: Union[ArrayAttr, list]):
|
||||
if isinstance(attribute, ArrayAttr):
|
||||
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
|
||||
else:
|
||||
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
|
||||
attribute, context=self.context
|
||||
)
|
||||
|
||||
@property
|
||||
def arguments(self):
|
||||
return self.entry_block.arguments
|
||||
|
||||
@property
|
||||
def result_attrs(self):
|
||||
return self.attributes[RESULT_ATTRIBUTE_NAME]
|
||||
|
||||
@result_attrs.setter
|
||||
def result_attrs(self, attribute: ArrayAttr):
|
||||
self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
|
||||
|
||||
@classmethod
|
||||
def from_py_func(
|
||||
FuncOp,
|
||||
*inputs: Type,
|
||||
results: Optional[Sequence[Type]] = None,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Decorator to define an MLIR FuncOp specified as a python function.
|
||||
|
||||
Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
|
||||
active for the current thread (i.e. established in a `with` block).
|
||||
|
||||
When applied as a decorator to a Python function, an entry block will
|
||||
be constructed for the FuncOp with types as specified in `*inputs`. The
|
||||
block arguments will be passed positionally to the Python function. In
|
||||
addition, if the Python function accepts keyword arguments generally or
|
||||
has a corresponding keyword argument, the following will be passed:
|
||||
* `func_op`: The `func` op being defined.
|
||||
|
||||
By default, the function name will be the Python function `__name__`. This
|
||||
can be overriden by passing the `name` argument to the decorator.
|
||||
|
||||
If `results` is not specified, then the decorator will implicitly
|
||||
insert a `ReturnOp` with the `Value`'s returned from the decorated
|
||||
function. It will also set the `FuncOp` type with the actual return
|
||||
value types. If `results` is specified, then the decorated function
|
||||
must return `None` and no implicit `ReturnOp` is added (nor are the result
|
||||
types updated). The implicit behavior is intended for simple, single-block
|
||||
cases, and users should specify result types explicitly for any complicated
|
||||
cases.
|
||||
|
||||
The decorated function can further be called from Python and will insert
|
||||
a `CallOp` at the then-current insertion point, returning either None (
|
||||
if no return values), a unary Value (for one result), or a list of Values).
|
||||
This mechanism cannot be used to emit recursive calls (by construction).
|
||||
"""
|
||||
|
||||
def decorator(f):
|
||||
from . import func
|
||||
|
||||
# Introspect the callable for optional features.
|
||||
sig = inspect.signature(f)
|
||||
has_arg_func_op = False
|
||||
for param in sig.parameters.values():
|
||||
if param.kind == param.VAR_KEYWORD:
|
||||
has_arg_func_op = True
|
||||
if param.name == "func_op" and (
|
||||
param.kind == param.POSITIONAL_OR_KEYWORD
|
||||
or param.kind == param.KEYWORD_ONLY
|
||||
):
|
||||
has_arg_func_op = True
|
||||
|
||||
# Emit the FuncOp.
|
||||
implicit_return = results is None
|
||||
symbol_name = name or f.__name__
|
||||
function_type = FunctionType.get(
|
||||
inputs=inputs, results=[] if implicit_return else results
|
||||
)
|
||||
func_op = FuncOp(name=symbol_name, type=function_type)
|
||||
with InsertionPoint(func_op.add_entry_block()):
|
||||
func_args = func_op.entry_block.arguments
|
||||
func_kwargs = {}
|
||||
if has_arg_func_op:
|
||||
func_kwargs["func_op"] = func_op
|
||||
return_values = f(*func_args, **func_kwargs)
|
||||
if not implicit_return:
|
||||
return_types = list(results)
|
||||
assert return_values is None, (
|
||||
"Capturing a python function with explicit `results=` "
|
||||
"requires that the wrapped function returns None."
|
||||
)
|
||||
else:
|
||||
# Coerce return values, add ReturnOp and rewrite func type.
|
||||
if return_values is None:
|
||||
return_values = []
|
||||
elif isinstance(return_values, tuple):
|
||||
return_values = list(return_values)
|
||||
elif isinstance(return_values, Value):
|
||||
# Returning a single value is fine, coerce it into a list.
|
||||
return_values = [return_values]
|
||||
elif isinstance(return_values, OpView):
|
||||
# Returning a single operation is fine, coerce its results a list.
|
||||
return_values = return_values.operation.results
|
||||
elif isinstance(return_values, Operation):
|
||||
# Returning a single operation is fine, coerce its results a list.
|
||||
return_values = return_values.results
|
||||
else:
|
||||
return_values = list(return_values)
|
||||
func.ReturnOp(return_values)
|
||||
# Recompute the function type.
|
||||
return_types = [v.type for v in return_values]
|
||||
function_type = FunctionType.get(
|
||||
inputs=inputs, results=return_types
|
||||
)
|
||||
func_op.attributes["function_type"] = TypeAttr.get(function_type)
|
||||
|
||||
def emit_call_op(*call_args):
|
||||
call_op = func.CallOp(
|
||||
return_types, FlatSymbolRefAttr.get(symbol_name), call_args
|
||||
)
|
||||
if return_types is None:
|
||||
return None
|
||||
elif len(return_types) == 1:
|
||||
return call_op.result
|
||||
else:
|
||||
return call_op.results
|
||||
|
||||
wrapped = emit_call_op
|
||||
wrapped.__name__ = f.__name__
|
||||
wrapped.func_op = func_op
|
||||
return wrapped
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class CallOp:
|
||||
"""Specialization for the call op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
calleeOrResults: Union[FuncOp, List[Type]],
|
||||
argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
|
||||
arguments: Optional[List] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
"""Creates an call operation.
|
||||
|
||||
The constructor accepts three different forms:
|
||||
|
||||
1. A function op to be called followed by a list of arguments.
|
||||
2. A list of result types, followed by the name of the function to be
|
||||
called as string, following by a list of arguments.
|
||||
3. A list of result types, followed by the name of the function to be
|
||||
called as symbol reference attribute, followed by a list of arguments.
|
||||
|
||||
For example
|
||||
|
||||
f = func.FuncOp("foo", ...)
|
||||
func.CallOp(f, [args])
|
||||
func.CallOp([result_types], "foo", [args])
|
||||
|
||||
In all cases, the location and insertion point may be specified as keyword
|
||||
arguments if not provided by the surrounding context managers.
|
||||
"""
|
||||
|
||||
# TODO: consider supporting constructor "overloads", e.g., through a custom
|
||||
# or pybind-provided metaclass.
|
||||
if isinstance(calleeOrResults, FuncOp):
|
||||
if not isinstance(argumentsOrCallee, list):
|
||||
raise ValueError(
|
||||
"when constructing a call to a function, expected "
|
||||
+ "the second argument to be a list of call arguments, "
|
||||
+ f"got {type(argumentsOrCallee)}"
|
||||
)
|
||||
if arguments is not None:
|
||||
raise ValueError(
|
||||
"unexpected third argument when constructing a call"
|
||||
+ "to a function"
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
calleeOrResults.type.results,
|
||||
FlatSymbolRefAttr.get(
|
||||
calleeOrResults.name.value, context=_get_default_loc_context(loc)
|
||||
),
|
||||
argumentsOrCallee,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(argumentsOrCallee, list):
|
||||
raise ValueError(
|
||||
"when constructing a call to a function by name, "
|
||||
+ "expected the second argument to be a string or a "
|
||||
+ f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}"
|
||||
)
|
||||
|
||||
if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
|
||||
super().__init__(
|
||||
calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip
|
||||
)
|
||||
elif isinstance(argumentsOrCallee, str):
|
||||
super().__init__(
|
||||
calleeOrResults,
|
||||
FlatSymbolRefAttr.get(
|
||||
argumentsOrCallee, context=_get_default_loc_context(loc)
|
||||
),
|
||||
arguments,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
@@ -1,124 +0,0 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ..dialects import transform
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Optional, Sequence, Union, overload
|
||||
|
||||
|
||||
class MapForallToBlocks:
|
||||
"""Specialization for MapForallToBlocks class."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, OpView, Value],
|
||||
*,
|
||||
grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
|
||||
generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
|
||||
loc=None,
|
||||
ip=None
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, OpView, Value],
|
||||
*,
|
||||
grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
|
||||
generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
|
||||
loc=None,
|
||||
ip=None
|
||||
):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
result_type_or_target: Union[Operation, OpView, Type, Value],
|
||||
target_or_none: Optional[Union[Operation, OpView, Value]] = None,
|
||||
*,
|
||||
grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
|
||||
generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
|
||||
loc=None,
|
||||
ip=None
|
||||
):
|
||||
if isinstance(result_type_or_target, Type):
|
||||
result_type = result_type_or_target
|
||||
target = target_or_none
|
||||
else:
|
||||
result_type = transform.AnyOpType.get()
|
||||
target = result_type_or_target
|
||||
|
||||
super().__init__(
|
||||
result_type,
|
||||
target,
|
||||
grid_dims=grid_dims,
|
||||
generate_gpu_launch=generate_gpu_launch,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
class MapNestedForallToThreads:
|
||||
"""Specialization for MapNestedForallToThreads class."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, OpView, Value],
|
||||
*,
|
||||
block_dims: Optional[Sequence[int]] = None,
|
||||
warp_size: Optional[Sequence[int]] = None,
|
||||
sync_after_distribute: Optional[bool] = None,
|
||||
loc=None,
|
||||
ip=None
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, OpView, Value],
|
||||
*,
|
||||
block_dims: Optional[Sequence[int]] = None,
|
||||
warp_size: Optional[Sequence[int]] = None,
|
||||
sync_after_distribute: Optional[bool] = None,
|
||||
loc=None,
|
||||
ip=None
|
||||
):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
result_type_or_target: Union[Operation, OpView, Value, Type],
|
||||
target_or_none: Optional[Union[Operation, OpView, Value]] = None,
|
||||
*,
|
||||
block_dims: Optional[Union[Sequence[int], Attribute]] = None,
|
||||
warp_size: Optional[Union[Sequence[int], Attribute]] = None,
|
||||
sync_after_distribute: Optional[bool] = None,
|
||||
loc=None,
|
||||
ip=None
|
||||
):
|
||||
if isinstance(result_type_or_target, Type):
|
||||
result_type = result_type_or_target
|
||||
target = target_or_none
|
||||
else:
|
||||
result_type = result_type_or_target.type
|
||||
target = result_type_or_target
|
||||
super().__init__(
|
||||
result_type,
|
||||
target,
|
||||
block_dims=block_dims,
|
||||
warp_size=warp_size,
|
||||
sync_after_distribute=sync_after_distribute,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
@@ -1,47 +0,0 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from typing import Optional, Sequence, Union
|
||||
from ..ir import *
|
||||
from ._ods_common import get_default_loc_context
|
||||
from .._mlir_libs._mlirDialectsLinalg import fill_builtin_region
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
|
||||
|
||||
|
||||
def isa(cls: Type, ty: Type):
|
||||
try:
|
||||
cls(ty)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
class StructuredOpMixin:
|
||||
"""All structured ops use the same mixin class."""
|
||||
|
||||
def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
|
||||
super().__init__(
|
||||
self.build_generic(
|
||||
results=list(results),
|
||||
operands=[list(inputs), list(outputs)],
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def select_opview_mixin(parent_opview_cls):
|
||||
# TODO: This shouldn't be a heuristic: we should have a way to annotate
|
||||
# the OpView to note that it is a structured op.
|
||||
if (
|
||||
"__init__" not in parent_opview_cls.__dict__
|
||||
and hasattr(parent_opview_cls, "inputs")
|
||||
and hasattr(parent_opview_cls, "outputs")
|
||||
and hasattr(parent_opview_cls, "result_tensors")
|
||||
):
|
||||
return StructuredOpMixin
|
||||
@@ -1,134 +0,0 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
class GetParentForOp:
|
||||
"""Extension for GetParentForOp."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
num_loops: Optional[int] = None,
|
||||
ip=None,
|
||||
loc=None,
|
||||
):
|
||||
if num_loops is None:
|
||||
num_loops = 1
|
||||
super().__init__(
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
num_loops=num_loops,
|
||||
ip=ip,
|
||||
loc=loc,
|
||||
)
|
||||
|
||||
|
||||
class LoopOutlineOp:
|
||||
"""Extension for LoopOutlineOp."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
function_type: Type,
|
||||
call_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
func_name: Union[str, StringAttr],
|
||||
ip=None,
|
||||
loc=None,
|
||||
):
|
||||
super().__init__(
|
||||
function_type,
|
||||
call_type,
|
||||
_get_op_result_or_value(target),
|
||||
func_name=(
|
||||
func_name
|
||||
if isinstance(func_name, StringAttr)
|
||||
else StringAttr.get(func_name)
|
||||
),
|
||||
ip=ip,
|
||||
loc=loc,
|
||||
)
|
||||
|
||||
|
||||
class LoopPeelOp:
|
||||
"""Extension for LoopPeelOp."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
main_loop_type: Type,
|
||||
remainder_loop_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
fail_if_already_divisible: Union[bool, BoolAttr] = False,
|
||||
ip=None,
|
||||
loc=None,
|
||||
):
|
||||
super().__init__(
|
||||
main_loop_type,
|
||||
remainder_loop_type,
|
||||
_get_op_result_or_value(target),
|
||||
fail_if_already_divisible=(
|
||||
fail_if_already_divisible
|
||||
if isinstance(fail_if_already_divisible, BoolAttr)
|
||||
else BoolAttr.get(fail_if_already_divisible)
|
||||
),
|
||||
ip=ip,
|
||||
loc=loc,
|
||||
)
|
||||
|
||||
|
||||
class LoopPipelineOp:
|
||||
"""Extension for LoopPipelineOp."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
iteration_interval: Optional[Union[int, IntegerAttr]] = None,
|
||||
read_latency: Optional[Union[int, IntegerAttr]] = None,
|
||||
ip=None,
|
||||
loc=None,
|
||||
):
|
||||
if iteration_interval is None:
|
||||
iteration_interval = 1
|
||||
if read_latency is None:
|
||||
read_latency = 10
|
||||
super().__init__(
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
iteration_interval=iteration_interval,
|
||||
read_latency=read_latency,
|
||||
ip=ip,
|
||||
loc=loc,
|
||||
)
|
||||
|
||||
|
||||
class LoopUnrollOp:
|
||||
"""Extension for LoopUnrollOp."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
factor: Union[int, IntegerAttr],
|
||||
ip=None,
|
||||
loc=None,
|
||||
):
|
||||
super().__init__(
|
||||
_get_op_result_or_value(target),
|
||||
factor=factor,
|
||||
ip=ip,
|
||||
loc=loc,
|
||||
)
|
||||
@@ -1,36 +0,0 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
|
||||
from ._ods_common import get_op_results_or_values as _get_op_results_or_values
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
|
||||
class LoadOp:
|
||||
"""Specialization for the MemRef load operation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
memref: Union[Operation, OpView, Value],
|
||||
indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None
|
||||
):
|
||||
"""Creates a memref load operation.
|
||||
|
||||
Args:
|
||||
memref: the buffer to load from.
|
||||
indices: the list of subscripts, may be empty for zero-dimensional
|
||||
buffers.
|
||||
loc: user-visible location of the operation.
|
||||
ip: insertion point.
|
||||
"""
|
||||
indices_resolved = [] if indices is None else _get_op_results_or_values(indices)
|
||||
super().__init__(memref, indices_resolved, loc=loc, ip=ip)
|
||||
@@ -1,114 +0,0 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ..dialects import transform
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Optional, overload, Union
|
||||
|
||||
|
||||
class MemRefAllocaToGlobalOp:
|
||||
"""Specialization for MemRefAllocaToGlobalOp class."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
get_global_type: Type,
|
||||
global_type: Type,
|
||||
alloca: Union[Operation, OpView, Value],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(self, alloca: Union[Operation, OpView, Value], *, loc=None, ip=None):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
get_global_type_or_alloca: Union[Operation, OpView, Type, Value],
|
||||
global_type_or_none: Optional[Type] = None,
|
||||
alloca_or_none: Optional[Union[Operation, OpView, Value]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None
|
||||
):
|
||||
if isinstance(get_global_type_or_alloca, Type):
|
||||
get_global_type = get_global_type_or_alloca
|
||||
global_type = global_type_or_none
|
||||
alloca = alloca_or_none
|
||||
else:
|
||||
get_global_type = transform.AnyOpType.get()
|
||||
global_type = transform.AnyOpType.get()
|
||||
alloca = get_global_type_or_alloca
|
||||
|
||||
super().__init__(
|
||||
get_global_type,
|
||||
global_type,
|
||||
alloca,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
class MemRefMultiBufferOp:
|
||||
"""Specialization for MemRefMultiBufferOp class."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
transformed_type: Type,
|
||||
target: Union[Operation, OpView, Value],
|
||||
factor: Union[int, IntegerAttr],
|
||||
*,
|
||||
skip_analysis: Optional[bool] = None,
|
||||
loc=None,
|
||||
ip=None
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, OpView, Value],
|
||||
factor: Union[int, IntegerAttr],
|
||||
*,
|
||||
skip_analysis: Optional[bool] = None,
|
||||
loc=None,
|
||||
ip=None
|
||||
):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformed_type_or_target: Type,
|
||||
target_or_factor: Union[int, IntegerAttr, Operation, OpView, Value] = None,
|
||||
factor_or_none: Optional[Union[int, IntegerAttr]] = None,
|
||||
*,
|
||||
skip_analysis: Optional[bool] = None,
|
||||
loc=None,
|
||||
ip=None
|
||||
):
|
||||
if isinstance(transformed_type_or_target, Type):
|
||||
transformed_type = transformed_type_or_target
|
||||
target = target_or_factor
|
||||
factor = factor_or_none
|
||||
else:
|
||||
transformed_type = transform.AnyOpType.get()
|
||||
target = transformed_type_or_target
|
||||
factor = target_or_factor
|
||||
|
||||
super().__init__(
|
||||
transformed_type,
|
||||
target,
|
||||
factor,
|
||||
skip_analysis=skip_analysis,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
@@ -1,113 +0,0 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from typing import Union
|
||||
from ..ir import *
|
||||
from ._ods_common import get_default_loc_context as _get_default_loc_context
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from ._ml_program_ops_gen import *
|
||||
|
||||
|
||||
ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
|
||||
RESULT_ATTRIBUTE_NAME = "res_attrs"
|
||||
|
||||
|
||||
class FuncOp:
|
||||
"""Specialization for the func op class."""
|
||||
|
||||
def __init__(
|
||||
self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
|
||||
):
|
||||
"""
|
||||
Create a FuncOp with the provided `name`, `type`, and `visibility`.
|
||||
- `name` is a string representing the function name.
|
||||
- `type` is either a FunctionType or a pair of list describing inputs and
|
||||
results.
|
||||
- `visibility` is a string matching `public`, `private`, or `nested`. None
|
||||
implies private visibility.
|
||||
- `body_builder` is an optional callback, when provided a new entry block
|
||||
is created and the callback is invoked with the new op as argument within
|
||||
an InsertionPoint context already set for the block. The callback is
|
||||
expected to insert a terminator in the block.
|
||||
"""
|
||||
sym_name = StringAttr.get(str(name))
|
||||
|
||||
# If the type is passed as a tuple, build a FunctionType on the fly.
|
||||
if isinstance(type, tuple):
|
||||
type = FunctionType.get(inputs=type[0], results=type[1])
|
||||
|
||||
type = TypeAttr.get(type)
|
||||
sym_visibility = (
|
||||
StringAttr.get(str(visibility)) if visibility is not None else None
|
||||
)
|
||||
super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
|
||||
if body_builder:
|
||||
entry_block = self.add_entry_block()
|
||||
with InsertionPoint(entry_block):
|
||||
body_builder(self)
|
||||
|
||||
@property
|
||||
def is_external(self):
|
||||
return len(self.regions[0].blocks) == 0
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
return self.regions[0]
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
return FunctionType(TypeAttr(self.attributes["function_type"]).value)
|
||||
|
||||
@property
|
||||
def visibility(self):
|
||||
return self.attributes["sym_visibility"]
|
||||
|
||||
@property
|
||||
def name(self) -> StringAttr:
|
||||
return StringAttr(self.attributes["sym_name"])
|
||||
|
||||
@property
|
||||
def entry_block(self):
|
||||
if self.is_external:
|
||||
raise IndexError("External function does not have a body")
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
def add_entry_block(self):
|
||||
"""
|
||||
Add an entry block to the function body using the function signature to
|
||||
infer block arguments.
|
||||
Returns the newly created block
|
||||
"""
|
||||
if not self.is_external:
|
||||
raise IndexError("The function already has an entry block!")
|
||||
self.body.blocks.append(*self.type.inputs)
|
||||
return self.body.blocks[0]
|
||||
|
||||
@property
|
||||
def arg_attrs(self):
|
||||
return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
|
||||
|
||||
@arg_attrs.setter
|
||||
def arg_attrs(self, attribute: Union[ArrayAttr, list]):
|
||||
if isinstance(attribute, ArrayAttr):
|
||||
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
|
||||
else:
|
||||
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
|
||||
attribute, context=self.context
|
||||
)
|
||||
|
||||
@property
|
||||
def arguments(self):
|
||||
return self.entry_block.arguments
|
||||
|
||||
@property
|
||||
def result_attrs(self):
|
||||
return self.attributes[RESULT_ATTRIBUTE_NAME]
|
||||
|
||||
@result_attrs.setter
|
||||
def result_attrs(self, attribute: ArrayAttr):
|
||||
self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
|
||||
@@ -9,7 +9,6 @@ from typing import Sequence as _Sequence, Union as _Union
|
||||
|
||||
__all__ = [
|
||||
"equally_sized_accessor",
|
||||
"extend_opview_class",
|
||||
"get_default_loc_context",
|
||||
"get_op_result_or_value",
|
||||
"get_op_results_or_values",
|
||||
@@ -18,64 +17,6 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
def extend_opview_class(ext_module):
|
||||
"""Decorator to extend an OpView class from an extension module.
|
||||
|
||||
Extension modules can expose various entry-points:
|
||||
Stand-alone class with the same name as a parent OpView class (i.e.
|
||||
"ReturnOp"). A name-based match is attempted first before falling back
|
||||
to a below mechanism.
|
||||
|
||||
def select_opview_mixin(parent_opview_cls):
|
||||
If defined, allows an appropriate mixin class to be selected dynamically
|
||||
based on the parent OpView class. Should return NotImplemented if a
|
||||
decision is not made.
|
||||
|
||||
Args:
|
||||
ext_module: A module from which to locate extensions. Can be None if not
|
||||
available.
|
||||
|
||||
Returns:
|
||||
A decorator that takes an OpView subclass and further extends it as
|
||||
needed.
|
||||
"""
|
||||
|
||||
def class_decorator(parent_opview_cls: type):
|
||||
if ext_module is None:
|
||||
return parent_opview_cls
|
||||
mixin_cls = NotImplemented
|
||||
# First try to resolve by name.
|
||||
try:
|
||||
mixin_cls = getattr(ext_module, parent_opview_cls.__name__)
|
||||
except AttributeError:
|
||||
# Fall back to a select_opview_mixin hook.
|
||||
try:
|
||||
select_mixin = getattr(ext_module, "select_opview_mixin")
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
mixin_cls = select_mixin(parent_opview_cls)
|
||||
|
||||
if mixin_cls is NotImplemented or mixin_cls is None:
|
||||
return parent_opview_cls
|
||||
|
||||
# Have a mixin_cls. Create an appropriate subclass.
|
||||
try:
|
||||
|
||||
class LocalOpView(mixin_cls, parent_opview_cls):
|
||||
pass
|
||||
|
||||
except TypeError as e:
|
||||
raise TypeError(
|
||||
f"Could not mixin {mixin_cls} into {parent_opview_cls}"
|
||||
) from e
|
||||
LocalOpView.__name__ = parent_opview_cls.__name__
|
||||
LocalOpView.__qualname__ = parent_opview_cls.__qualname__
|
||||
return LocalOpView
|
||||
|
||||
return class_decorator
|
||||
|
||||
|
||||
def segmented_accessor(elements, raw_segments, idx):
|
||||
"""
|
||||
Returns a slice of elements corresponding to the idx-th segment.
|
||||
|
||||
@@ -1,271 +0,0 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ..dialects import pdl
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Union, Optional, Sequence, Mapping
|
||||
from ._ods_common import (
|
||||
get_op_result_or_value as _get_value,
|
||||
get_op_results_or_values as _get_values,
|
||||
)
|
||||
|
||||
|
||||
class ApplyNativeConstraintOp:
|
||||
"""Specialization for PDL apply native constraint op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: Union[str, StringAttr],
|
||||
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if args is None:
|
||||
args = []
|
||||
args = _get_values(args)
|
||||
super().__init__(name, args, loc=loc, ip=ip)
|
||||
|
||||
|
||||
class ApplyNativeRewriteOp:
|
||||
"""Specialization for PDL apply native rewrite op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
results: Sequence[Type],
|
||||
name: Union[str, StringAttr],
|
||||
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if args is None:
|
||||
args = []
|
||||
args = _get_values(args)
|
||||
super().__init__(results, name, args, loc=loc, ip=ip)
|
||||
|
||||
|
||||
class AttributeOp:
|
||||
"""Specialization for PDL attribute op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
valueType: Optional[Union[OpView, Operation, Value]] = None,
|
||||
value: Optional[Attribute] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
valueType = valueType if valueType is None else _get_value(valueType)
|
||||
result = pdl.AttributeType.get()
|
||||
super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
|
||||
|
||||
|
||||
class EraseOp:
|
||||
"""Specialization for PDL erase op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
operation: Optional[Union[OpView, Operation, Value]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
operation = _get_value(operation)
|
||||
super().__init__(operation, loc=loc, ip=ip)
|
||||
|
||||
|
||||
class OperandOp:
|
||||
"""Specialization for PDL operand op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type: Optional[Union[OpView, Operation, Value]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
type = type if type is None else _get_value(type)
|
||||
result = pdl.ValueType.get()
|
||||
super().__init__(result, valueType=type, loc=loc, ip=ip)
|
||||
|
||||
|
||||
class OperandsOp:
|
||||
"""Specialization for PDL operands op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
types: Optional[Union[OpView, Operation, Value]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
types = types if types is None else _get_value(types)
|
||||
result = pdl.RangeType.get(pdl.ValueType.get())
|
||||
super().__init__(result, valueType=types, loc=loc, ip=ip)
|
||||
|
||||
|
||||
class OperationOp:
|
||||
"""Specialization for PDL operand op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: Optional[Union[str, StringAttr]] = None,
|
||||
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
|
||||
attributes: Optional[Mapping[str, Union[OpView, Operation, Value]]] = None,
|
||||
types: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if types is None:
|
||||
types = []
|
||||
if attributes is None:
|
||||
attributes = {}
|
||||
if args is None:
|
||||
args = []
|
||||
args = _get_values(args)
|
||||
attrNames = []
|
||||
attrValues = []
|
||||
for attrName, attrValue in attributes.items():
|
||||
attrNames.append(StringAttr.get(attrName))
|
||||
attrValues.append(_get_value(attrValue))
|
||||
attrNames = ArrayAttr.get(attrNames)
|
||||
types = _get_values(types)
|
||||
result = pdl.OperationType.get()
|
||||
super().__init__(
|
||||
result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
|
||||
class PatternOp:
|
||||
"""Specialization for PDL pattern op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
benefit: Union[IntegerAttr, int],
|
||||
name: Optional[Union[StringAttr, str]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
"""Creates an PDL `pattern` operation."""
|
||||
super().__init__(benefit, sym_name=name, loc=loc, ip=ip)
|
||||
self.regions[0].blocks.append()
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
"""Return the body (block) of the pattern."""
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
|
||||
class ReplaceOp:
|
||||
"""Specialization for PDL replace op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
op: Union[OpView, Operation, Value],
|
||||
*,
|
||||
with_op: Optional[Union[OpView, Operation, Value]] = None,
|
||||
with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if with_values is None:
|
||||
with_values = []
|
||||
op = _get_value(op)
|
||||
with_op = with_op if with_op is None else _get_value(with_op)
|
||||
with_values = _get_values(with_values)
|
||||
super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip)
|
||||
|
||||
|
||||
class ResultOp:
|
||||
"""Specialization for PDL result op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent: Union[OpView, Operation, Value],
|
||||
index: Union[IntegerAttr, int],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
parent = _get_value(parent)
|
||||
result = pdl.ValueType.get()
|
||||
super().__init__(result, parent, index, loc=loc, ip=ip)
|
||||
|
||||
|
||||
class ResultsOp:
|
||||
"""Specialization for PDL results op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
result: Type,
|
||||
parent: Union[OpView, Operation, Value],
|
||||
index: Optional[Union[IntegerAttr, int]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
parent = _get_value(parent)
|
||||
super().__init__(result, parent, index=index, loc=loc, ip=ip)
|
||||
|
||||
|
||||
class RewriteOp:
|
||||
"""Specialization for PDL rewrite op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: Optional[Union[OpView, Operation, Value]] = None,
|
||||
name: Optional[Union[StringAttr, str]] = None,
|
||||
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if args is None:
|
||||
args = []
|
||||
root = root if root is None else _get_value(root)
|
||||
args = _get_values(args)
|
||||
super().__init__(args, root=root, name=name, loc=loc, ip=ip)
|
||||
|
||||
def add_body(self):
|
||||
"""Add body (block) to the rewrite."""
|
||||
self.regions[0].blocks.append()
|
||||
return self.body
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
"""Return the body (block) of the rewrite."""
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
|
||||
class TypeOp:
|
||||
"""Specialization for PDL type op class."""
|
||||
|
||||
def __init__(
|
||||
self, constantType: Optional[Union[TypeAttr, Type]] = None, *, loc=None, ip=None
|
||||
):
|
||||
result = pdl.TypeType.get()
|
||||
super().__init__(result, constantType=constantType, loc=loc, ip=ip)
|
||||
|
||||
|
||||
class TypesOp:
|
||||
"""Specialization for PDL types op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if constantTypes is None:
|
||||
constantTypes = []
|
||||
result = pdl.RangeType.get(pdl.TypeType.get())
|
||||
super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)
|
||||
@@ -1,107 +0,0 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
from ._ods_common import (
|
||||
get_op_result_or_value as _get_op_result_or_value,
|
||||
get_op_results_or_values as _get_op_results_or_values,
|
||||
)
|
||||
|
||||
|
||||
class ForOp:
|
||||
"""Specialization for the SCF for op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lower_bound,
|
||||
upper_bound,
|
||||
step,
|
||||
iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
"""Creates an SCF `for` operation.
|
||||
|
||||
- `lower_bound` is the value to use as lower bound of the loop.
|
||||
- `upper_bound` is the value to use as upper bound of the loop.
|
||||
- `step` is the value to use as loop step.
|
||||
- `iter_args` is a list of additional loop-carried arguments or an operation
|
||||
producing them as results.
|
||||
"""
|
||||
if iter_args is None:
|
||||
iter_args = []
|
||||
iter_args = _get_op_results_or_values(iter_args)
|
||||
|
||||
results = [arg.type for arg in iter_args]
|
||||
super().__init__(
|
||||
self.build_generic(
|
||||
regions=1,
|
||||
results=results,
|
||||
operands=[
|
||||
_get_op_result_or_value(o) for o in [lower_bound, upper_bound, step]
|
||||
]
|
||||
+ list(iter_args),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
)
|
||||
self.regions[0].blocks.append(self.operands[0].type, *results)
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
"""Returns the body (block) of the loop."""
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
@property
|
||||
def induction_variable(self):
|
||||
"""Returns the induction variable of the loop."""
|
||||
return self.body.arguments[0]
|
||||
|
||||
@property
|
||||
def inner_iter_args(self):
|
||||
"""Returns the loop-carried arguments usable within the loop.
|
||||
|
||||
To obtain the loop-carried operands, use `iter_args`.
|
||||
"""
|
||||
return self.body.arguments[1:]
|
||||
|
||||
|
||||
class IfOp:
|
||||
"""Specialization for the SCF if op class."""
|
||||
|
||||
def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None):
|
||||
"""Creates an SCF `if` operation.
|
||||
|
||||
- `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed.
|
||||
- `hasElse` determines whether the if operation has the else branch.
|
||||
"""
|
||||
operands = []
|
||||
operands.append(cond)
|
||||
results = []
|
||||
results.extend(results_)
|
||||
super().__init__(
|
||||
self.build_generic(
|
||||
regions=2, results=results, operands=operands, loc=loc, ip=ip
|
||||
)
|
||||
)
|
||||
self.regions[0].blocks.append(*[])
|
||||
if hasElse:
|
||||
self.regions[1].blocks.append(*[])
|
||||
|
||||
@property
|
||||
def then_block(self):
|
||||
"""Returns the then block of the if operation."""
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
@property
|
||||
def else_block(self):
|
||||
"""Returns the else block of the if operation."""
|
||||
return self.regions[1].blocks[0]
|
||||
@@ -1,759 +0,0 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ..dialects import transform
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import List, Optional, Sequence, Tuple, Union, overload
|
||||
|
||||
StaticIntLike = Union[int, IntegerAttr]
|
||||
ValueLike = Union[Operation, OpView, Value]
|
||||
MixedInt = Union[StaticIntLike, ValueLike]
|
||||
|
||||
IntOrAttrList = Sequence[Union[IntegerAttr, int]]
|
||||
OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
|
||||
|
||||
BoolOrAttrList = Sequence[Union[BoolAttr, bool]]
|
||||
OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]]
|
||||
|
||||
MixedValues = Union[Sequence[Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]
|
||||
|
||||
DynamicIndexList = Sequence[Union[MixedInt, Sequence[MixedInt]]]
|
||||
|
||||
|
||||
def _dispatch_dynamic_index_list(
|
||||
indices: Union[DynamicIndexList, ArrayAttr],
|
||||
) -> Tuple[List[ValueLike], Union[List[int], ArrayAttr], List[bool]]:
|
||||
"""Dispatches a list of indices to the appropriate form.
|
||||
|
||||
This is similar to the custom `DynamicIndexList` directive upstream:
|
||||
provided indices may be in the form of dynamic SSA values or static values,
|
||||
and they may be scalable (i.e., as a singleton list) or not. This function
|
||||
dispatches each index into its respective form. It also extracts the SSA
|
||||
values and static indices from various similar structures, respectively.
|
||||
"""
|
||||
dynamic_indices = []
|
||||
static_indices = [ShapedType.get_dynamic_size()] * len(indices)
|
||||
scalable_indices = [False] * len(indices)
|
||||
|
||||
# ArrayAttr: Extract index values.
|
||||
if isinstance(indices, ArrayAttr):
|
||||
indices = [idx for idx in indices]
|
||||
|
||||
def process_nonscalable_index(i, index):
|
||||
"""Processes any form of non-scalable index.
|
||||
|
||||
Returns False if the given index was scalable and thus remains
|
||||
unprocessed; True otherwise.
|
||||
"""
|
||||
if isinstance(index, int):
|
||||
static_indices[i] = index
|
||||
elif isinstance(index, IntegerAttr):
|
||||
static_indices[i] = index.value # pytype: disable=attribute-error
|
||||
elif isinstance(index, (Operation, Value, OpView)):
|
||||
dynamic_indices.append(index)
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
# Process each index at a time.
|
||||
for i, index in enumerate(indices):
|
||||
if not process_nonscalable_index(i, index):
|
||||
# If it wasn't processed, it must be a scalable index, which is
|
||||
# provided as a Sequence of one value, so extract and process that.
|
||||
scalable_indices[i] = True
|
||||
assert len(index) == 1
|
||||
ret = process_nonscalable_index(i, index[0])
|
||||
assert ret
|
||||
|
||||
return dynamic_indices, static_indices, scalable_indices
|
||||
|
||||
|
||||
# Dispatches `MixedValues` that all represents integers in various forms into
|
||||
# the following three categories:
|
||||
# - `dynamic_values`: a list of `Value`s, potentially from op results;
|
||||
# - `packed_values`: a value handle, potentially from an op result, associated
|
||||
# to one or more payload operations of integer type;
|
||||
# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python
|
||||
# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`.
|
||||
# The input is in the form for `packed_values`, only that result is set and the
|
||||
# other two are empty. Otherwise, the input can be a mix of the other two forms,
|
||||
# and for each dynamic value, a special value is added to the `static_values`.
|
||||
def _dispatch_mixed_values(
|
||||
values: MixedValues,
|
||||
) -> Tuple[List[Value], Union[Operation, Value, OpView], DenseI64ArrayAttr]:
|
||||
dynamic_values = []
|
||||
packed_values = None
|
||||
static_values = None
|
||||
if isinstance(values, ArrayAttr):
|
||||
static_values = values
|
||||
elif isinstance(values, (Operation, Value, OpView)):
|
||||
packed_values = values
|
||||
else:
|
||||
static_values = []
|
||||
for size in values or []:
|
||||
if isinstance(size, int):
|
||||
static_values.append(size)
|
||||
else:
|
||||
static_values.append(ShapedType.get_dynamic_size())
|
||||
dynamic_values.append(size)
|
||||
static_values = DenseI64ArrayAttr.get(static_values)
|
||||
|
||||
return (dynamic_values, packed_values, static_values)
|
||||
|
||||
|
||||
def _get_value_or_attribute_value(
|
||||
value_or_attr: Union[any, Attribute, ArrayAttr]
|
||||
) -> any:
|
||||
if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
|
||||
return value_or_attr.value
|
||||
if isinstance(value_or_attr, ArrayAttr):
|
||||
return _get_value_list(value_or_attr)
|
||||
return value_or_attr
|
||||
|
||||
|
||||
def _get_value_list(
|
||||
sequence_or_array_attr: Union[Sequence[any], ArrayAttr]
|
||||
) -> Sequence[any]:
|
||||
return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]
|
||||
|
||||
|
||||
def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr:
|
||||
if values is None:
|
||||
return None
|
||||
|
||||
# Turn into a Python list of Python ints.
|
||||
values = _get_value_list(values)
|
||||
|
||||
# Make an ArrayAttr of IntegerAttrs out of it.
|
||||
return ArrayAttr.get(
|
||||
[IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
|
||||
)
|
||||
|
||||
|
||||
def _get_int_array_array_attr(
|
||||
values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
|
||||
) -> ArrayAttr:
|
||||
"""Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.
|
||||
|
||||
The input has to be a collection of collection of integers, where any
|
||||
Python Sequence and ArrayAttr are admissible collections and Python ints and
|
||||
any IntegerAttr are admissible integers. Both levels of collections are
|
||||
turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
|
||||
If the input is None, an empty ArrayAttr is returned.
|
||||
"""
|
||||
if values is None:
|
||||
return None
|
||||
|
||||
# Make sure the outer level is a list.
|
||||
values = _get_value_list(values)
|
||||
|
||||
# The inner level is now either invalid or a mixed sequence of ArrayAttrs and
|
||||
# Sequences. Make sure the nested values are all lists.
|
||||
values = [_get_value_list(nested) for nested in values]
|
||||
|
||||
# Turn each nested list into an ArrayAttr.
|
||||
values = [_get_int_array_attr(nested) for nested in values]
|
||||
|
||||
# Turn the outer list into an ArrayAttr.
|
||||
return ArrayAttr.get(values)
|
||||
|
||||
|
||||
class BufferizeToAllocationOp:
|
||||
"""Specialization for BufferizeToAllocationOp class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, OpView, Value],
|
||||
*,
|
||||
memory_space: Optional[Union[int, str, Attribute]] = None,
|
||||
memcpy_op: Optional[str] = None,
|
||||
alloc_op: Optional[str] = None,
|
||||
bufferize_destination_only: Optional[bool] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
# No other types are allowed, so hard-code those here.
|
||||
allocated_buffer_type = transform.AnyValueType.get()
|
||||
new_ops_type = transform.AnyOpType.get()
|
||||
|
||||
if isinstance(memory_space, int):
|
||||
memory_space = str(memory_space)
|
||||
if isinstance(memory_space, str):
|
||||
memory_space = Attribute.parse(memory_space)
|
||||
|
||||
super().__init__(
|
||||
allocated_buffer_type,
|
||||
new_ops_type,
|
||||
target,
|
||||
memory_space=memory_space,
|
||||
memcpy_op=memcpy_op,
|
||||
alloc_op=alloc_op,
|
||||
bufferize_destination_only=bufferize_destination_only,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
class DecomposeOp:
|
||||
"""Specialization for DecomposeOp class."""
|
||||
|
||||
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
|
||||
transformed_type = transform.AnyOpType.get()
|
||||
super().__init__(transformed_type, target, loc=loc, ip=ip)
|
||||
|
||||
|
||||
class FuseIntoContainingOp:
|
||||
"""Specialization for FuseIntoContainingOp class."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
fused_op_type: Type,
|
||||
new_containing_op_type: Type,
|
||||
producer_op: Union[Operation, OpView, Value],
|
||||
containing_op: Union[Operation, OpView, Value],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
producer_op: Union[Operation, OpView, Value],
|
||||
containing_op: Union[Operation, OpView, Value],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fused_op_type_or_producer_op: Union[Operation, OpView, Type, Value],
|
||||
new_containing_op_type_or_containing_op: Union[Operation, OpView, Type, Value],
|
||||
producer_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
|
||||
containing_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if isinstance(fused_op_type_or_producer_op, Type):
|
||||
if not isinstance(new_containing_op_type_or_containing_op, Type):
|
||||
raise TypeError(
|
||||
"If 'fused_op_type_or_producer_op' is a type, then "
|
||||
"'new_containing_op_type_or_containing_op' is expected "
|
||||
"to be one as well."
|
||||
)
|
||||
fused_op_type = fused_op_type_or_producer_op
|
||||
new_containing_op_type = new_containing_op_type_or_containing_op
|
||||
producer_op = producer_op_or_none
|
||||
containing_op = containing_op_or_none
|
||||
else:
|
||||
fused_op_type = transform.AnyOpType.get()
|
||||
new_containing_op_type = transform.AnyOpType.get()
|
||||
producer_op = fused_op_type_or_producer_op
|
||||
containing_op = new_containing_op_type_or_containing_op
|
||||
|
||||
super().__init__(
|
||||
fused_op_type,
|
||||
new_containing_op_type,
|
||||
producer_op,
|
||||
containing_op,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
class GeneralizeOp:
|
||||
"""Specialization for GeneralizeOp class."""
|
||||
|
||||
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
|
||||
transformed_type = transform.AnyOpType.get()
|
||||
super().__init__(transformed_type, target, loc=loc, ip=ip)
|
||||
|
||||
|
||||
class InterchangeOp:
|
||||
"""Specialization for InterchangeOp class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
iterator_interchange: OptionalIntList = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
transformed_type = transform.AnyOpType.get()
|
||||
super().__init__(
|
||||
transformed_type,
|
||||
target,
|
||||
iterator_interchange=iterator_interchange,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
class MapCopyToThreadsOp:
|
||||
"""Specialization for MapCopyToThreadsOp class."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
forall_op_type: Type,
|
||||
tiled_op_type: Type,
|
||||
target: Union[Operation, OpView, Value],
|
||||
*,
|
||||
total_num_threads: Union[int, IntegerAttr],
|
||||
desired_bit_alignment: Union[int, IntegerAttr],
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, OpView, Value],
|
||||
*,
|
||||
total_num_threads: Union[int, IntegerAttr],
|
||||
desired_bit_alignment: Union[int, IntegerAttr],
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
forall_op_type_or_target: Union[Operation, OpView, Type, Value],
|
||||
tiled_op_type_or_none: Optional[Type] = None,
|
||||
target_or_none: Optional[Union[Operation, OpView, Value]] = None,
|
||||
*,
|
||||
total_num_threads: Union[int, IntegerAttr],
|
||||
desired_bit_alignment: Union[int, IntegerAttr],
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if isinstance(forall_op_type_or_target, Type):
|
||||
forall_op_type = forall_op_type_or_target
|
||||
tiled_op_type = tiled_op_type_or_none
|
||||
target = target_or_none
|
||||
else:
|
||||
forall_op_type = transform.AnyOpType.get()
|
||||
tiled_op_type = transform.AnyOpType.get()
|
||||
target = forall_op_type_or_target
|
||||
|
||||
super().__init__(
|
||||
forall_op_type,
|
||||
tiled_op_type,
|
||||
target,
|
||||
total_num_threads=total_num_threads,
|
||||
desired_bit_alignment=desired_bit_alignment,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
class VectorizeOp:
|
||||
"""Specialization for VectorizeOp class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, OpView, Value],
|
||||
vector_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
|
||||
*,
|
||||
vectorize_nd_extract: Optional[bool] = None,
|
||||
scalable_sizes: OptionalBoolList = None,
|
||||
static_vector_sizes: OptionalIntList = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if (
|
||||
scalable_sizes is None
|
||||
and static_vector_sizes is None
|
||||
and vector_sizes is None
|
||||
):
|
||||
dynamic_vector_sizes = []
|
||||
elif scalable_sizes is None and static_vector_sizes is None:
|
||||
(
|
||||
dynamic_vector_sizes,
|
||||
static_vector_sizes,
|
||||
scalable_sizes,
|
||||
) = _dispatch_dynamic_index_list(vector_sizes)
|
||||
elif scalable_sizes is None or static_vector_sizes is None:
|
||||
raise TypeError(
|
||||
"'scalable_sizes' and 'static_vector_sizes' must either both "
|
||||
"be given explicitly or both be given as part of 'vector_sizes'."
|
||||
)
|
||||
else:
|
||||
dynamic_vector_sizes = vector_sizes
|
||||
|
||||
super().__init__(
|
||||
target,
|
||||
vector_sizes=dynamic_vector_sizes,
|
||||
static_vector_sizes=static_vector_sizes,
|
||||
scalable_sizes=scalable_sizes,
|
||||
vectorize_nd_extract=vectorize_nd_extract,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
class MatchOp:
|
||||
"""Specialization for MatchOp class."""
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def match_op_names(
|
||||
cls,
|
||||
target: Union[Operation, Value],
|
||||
names: Union[str, Sequence[str]],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def match_op_names(
|
||||
cls,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
names: Union[str, Sequence[str]],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def match_op_names(
|
||||
cls,
|
||||
result_type_or_target: Union[Type, Operation, Value],
|
||||
target_or_names: Union[Operation, Value, Sequence[str], str],
|
||||
names_or_none: Optional[Union[Sequence[str], str]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if isinstance(result_type_or_target, Type):
|
||||
result_type = result_type_or_target
|
||||
target = target_or_names
|
||||
names = names_or_none
|
||||
else:
|
||||
result_type = transform.AnyOpType.get()
|
||||
target = result_type_or_target
|
||||
names = target_or_names
|
||||
|
||||
if isinstance(names, str):
|
||||
names = [names]
|
||||
|
||||
return cls(
|
||||
result_type,
|
||||
target,
|
||||
ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
class MultiTileSizesOp:
|
||||
"""Specialization for MultiTileSizesOp class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
dimension: Union[int, IntegerAttr],
|
||||
target_size: Union[int, IntegerAttr],
|
||||
divisor: Optional[Optional[Union[int, IntegerAttr]]] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
super().__init__(
|
||||
result_type,
|
||||
result_type,
|
||||
result_type,
|
||||
target,
|
||||
dimension=dimension,
|
||||
target_size=target_size,
|
||||
divisor=divisor,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
class PadOp:
|
||||
"""Specialization for PadOp class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, OpView, Value],
|
||||
*,
|
||||
padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
|
||||
padding_dimensions: OptionalIntList = None,
|
||||
pad_to_multiple_of: OptionalIntList = None,
|
||||
pack_paddings: OptionalIntList = None,
|
||||
transpose_paddings: Optional[
|
||||
Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]
|
||||
] = None,
|
||||
copy_back_op: Optional[Union[str, StringAttr]] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
transpose_paddings = _get_int_array_array_attr(transpose_paddings)
|
||||
|
||||
any_op_type = transform.AnyOpType.get()
|
||||
super().__init__(
|
||||
any_op_type,
|
||||
any_op_type,
|
||||
any_op_type,
|
||||
target,
|
||||
padding_values=padding_values,
|
||||
padding_dimensions=padding_dimensions,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
pack_paddings=pack_paddings,
|
||||
transpose_paddings=transpose_paddings,
|
||||
copy_back_op=copy_back_op,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
class ScalarizeOp:
|
||||
"""Specialization for ScalarizeOp class."""
|
||||
|
||||
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
|
||||
result_type = transform.AnyOpType.get()
|
||||
super().__init__(result_type, target, loc=loc, ip=ip)
|
||||
|
||||
|
||||
class SplitOp:
|
||||
"""Specialization for SplitOp class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, Value],
|
||||
dimension: Union[int, Attribute],
|
||||
split_point: Union[int, Operation, Value, Attribute],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if isinstance(split_point, int):
|
||||
static_split_point = split_point
|
||||
dynamic_split_point = None
|
||||
else:
|
||||
static_split_point = ShapedType.get_dynamic_size()
|
||||
dynamic_split_point = split_point
|
||||
|
||||
super().__init__(
|
||||
target.type,
|
||||
target.type,
|
||||
target,
|
||||
dimension=dimension,
|
||||
static_split_point=static_split_point,
|
||||
dynamic_split_point=dynamic_split_point,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
class TileUsingForOp:
|
||||
"""Specialization for TileUsingForOp class."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
loop_types: Union[Type, List[Type]],
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
|
||||
interchange: OptionalIntList = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, Value, OpView],
|
||||
*,
|
||||
sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
|
||||
interchange: OptionalIntList = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
loop_types_or_target: Union[Type, List[Type], Operation, Value],
|
||||
target_or_none: Optional[Union[Operation, Value, OpView]] = None,
|
||||
*,
|
||||
sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
|
||||
interchange: OptionalIntList = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
(
|
||||
dynamic_sizes,
|
||||
static_sizes,
|
||||
scalable_sizes,
|
||||
) = _dispatch_dynamic_index_list(sizes)
|
||||
|
||||
num_loops = sum(v if v == 0 else 1 for v in static_sizes)
|
||||
|
||||
if isinstance(loop_types_or_target, (Operation, Value, OpView)):
|
||||
loop_types = [transform.AnyOpType.get()] * num_loops
|
||||
target = loop_types_or_target
|
||||
assert (
|
||||
target_or_none is None
|
||||
), "Cannot construct TileUsingForOp with two targets."
|
||||
else:
|
||||
loop_types = (
|
||||
([loop_types_or_target] * num_loops)
|
||||
if isinstance(loop_types_or_target, Type)
|
||||
else loop_types_or_target
|
||||
)
|
||||
target = target_or_none
|
||||
|
||||
super().__init__(
|
||||
target.type,
|
||||
loop_types,
|
||||
target,
|
||||
dynamic_sizes=dynamic_sizes,
|
||||
static_sizes=static_sizes,
|
||||
interchange=interchange,
|
||||
scalable_sizes=scalable_sizes,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
class TileUsingForallOp:
|
||||
"""Specialization for TileUsingForallOp class."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
loops_type: Type,
|
||||
tiled_op_type: Type,
|
||||
target: Union[Operation, Value, OpView],
|
||||
*,
|
||||
num_threads: Optional[MixedValues] = None,
|
||||
tile_sizes: MixedValues = None,
|
||||
mapping=None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, Value, OpView],
|
||||
*,
|
||||
num_threads: Optional[MixedValues] = None,
|
||||
tile_sizes: MixedValues = None,
|
||||
mapping=None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
loops_type_or_target: Union[
|
||||
Type, Union[Operation, Value, OpView] # loops_type
|
||||
], # target
|
||||
tiled_op_type_or_none: Optional[Type] = None,
|
||||
target_or_none: Optional[Union[Operation, Value, OpView]] = None,
|
||||
*,
|
||||
num_threads: MixedValues = None,
|
||||
tile_sizes: MixedValues = None,
|
||||
mapping=None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
# `Type` arguments in the front are optional: add default values to front.
|
||||
if isinstance(loops_type_or_target, Type):
|
||||
# First overload: type arguments provided.
|
||||
if not isinstance(tiled_op_type_or_none, Type):
|
||||
raise TypeError(
|
||||
"If 'loops_type_or_target' is a type, then "
|
||||
"'tiled_op_type_or_none' is expected to be one as well."
|
||||
)
|
||||
loops_type = loops_type_or_target
|
||||
tiled_op_type = tiled_op_type_or_none
|
||||
target = target_or_none
|
||||
else:
|
||||
# Last overload: type arguments missing.
|
||||
loops_type = transform.AnyOpType.get()
|
||||
tiled_op_type = transform.AnyOpType.get()
|
||||
target = loops_type_or_target
|
||||
|
||||
# Unpack mixed num_threads.
|
||||
(
|
||||
dynamic_num_threads,
|
||||
packed_num_threads,
|
||||
num_threads_attr,
|
||||
) = _dispatch_mixed_values(num_threads)
|
||||
|
||||
# Unpack mixed tile_sizes.
|
||||
(
|
||||
dynamic_tile_sizes,
|
||||
packed_tile_sizes,
|
||||
tile_sizes_attr,
|
||||
) = _dispatch_mixed_values(tile_sizes)
|
||||
|
||||
super().__init__(
|
||||
loops_type,
|
||||
tiled_op_type,
|
||||
target=target,
|
||||
tile_sizes=dynamic_tile_sizes,
|
||||
packed_tile_sizes=packed_tile_sizes,
|
||||
static_tile_sizes=tile_sizes_attr,
|
||||
num_threads=dynamic_num_threads,
|
||||
packed_num_threads=packed_num_threads,
|
||||
static_num_threads=num_threads_attr,
|
||||
mapping=mapping,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
class VectorizeChildrenAndApplyPatternsOp:
|
||||
"""Specialization for VectorizeChildrenAndApplyPatternsOp class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
disable_multi_reduction_to_contract_patterns: bool = False,
|
||||
disable_transfer_permutation_map_lowering_patterns: bool = False,
|
||||
vectorize_nd_extract: bool = False,
|
||||
vectorize_padding: bool = False,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
transformed_type = transform.AnyOpType.get()
|
||||
super().__init__(
|
||||
transformed_type,
|
||||
target,
|
||||
disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns,
|
||||
disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns,
|
||||
vectorize_nd_extract=vectorize_nd_extract,
|
||||
vectorize_padding=vectorize_padding,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
@@ -1,44 +0,0 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Any, Optional, Sequence, Union
|
||||
from ._ods_common import (
|
||||
get_op_result_or_value as _get_op_result_or_value,
|
||||
get_op_results_or_values as _get_op_results_or_values,
|
||||
)
|
||||
|
||||
|
||||
class EmptyOp:
|
||||
"""Extends the tensor.empty op."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sizes: Sequence[Union[int, Value]],
|
||||
element_type: Type,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None
|
||||
):
|
||||
"""Constructs an `empty` with mixed static/dynamic sizes."""
|
||||
# TODO: Refactor the EmptyOp to take an element type attribute and
|
||||
# then use normal result type inference, unifying the Python and C++ side
|
||||
# with a standard mechanism (versus stashing that in builders).
|
||||
dynamic_sizes = []
|
||||
static_sizes = []
|
||||
for s in sizes:
|
||||
if isinstance(s, int):
|
||||
static_sizes.append(s)
|
||||
else:
|
||||
static_sizes.append(ShapedType.get_dynamic_size())
|
||||
dynamic_sizes.append(s)
|
||||
result_type = RankedTensorType.get(static_sizes, element_type)
|
||||
op = self.build_generic(
|
||||
results=[result_type], operands=dynamic_sizes, attributes={}, loc=loc, ip=ip
|
||||
)
|
||||
OpView.__init__(self, op)
|
||||
@@ -1,64 +0,0 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ..dialects import transform
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Optional, overload, Union
|
||||
|
||||
|
||||
class MakeLoopIndependentOp:
|
||||
"""Specialization for MakeLoopIndependentOp class."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
transformed_type: Type,
|
||||
target: Union[Operation, OpView, Value],
|
||||
num_loops: Union[int, IntegerAttr],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, OpView, Value],
|
||||
num_loops: Union[int, IntegerAttr],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None
|
||||
):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformed_type_or_target: Type,
|
||||
target_or_num_loops: Union[int, IntegerAttr, Operation, OpView, Value] = None,
|
||||
num_loops_or_none: Optional[Union[int, IntegerAttr]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None
|
||||
):
|
||||
if isinstance(transformed_type_or_target, Type):
|
||||
transformed_type = transformed_type_or_target
|
||||
target = target_or_num_loops
|
||||
num_loops = num_loops_or_none
|
||||
else:
|
||||
transformed_type = transform.AnyOpType.get()
|
||||
target = transformed_type_or_target
|
||||
num_loops = target_or_num_loops
|
||||
|
||||
super().__init__(
|
||||
transformed_type,
|
||||
target,
|
||||
num_loops,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
@@ -1,176 +0,0 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ._ods_common import (
|
||||
get_op_result_or_value as _get_op_result_or_value,
|
||||
get_op_results_or_values as _get_op_results_or_values,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
|
||||
class CastOp:
|
||||
def __init__(
|
||||
self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)
|
||||
|
||||
|
||||
class ApplyPatternsOp:
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, Value, OpView],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
operands = []
|
||||
operands.append(_get_op_result_or_value(target))
|
||||
super().__init__(
|
||||
self.build_generic(
|
||||
attributes={},
|
||||
results=[],
|
||||
operands=operands,
|
||||
successors=None,
|
||||
regions=None,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
)
|
||||
self.regions[0].blocks.append()
|
||||
|
||||
@property
|
||||
def patterns(self) -> Block:
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
|
||||
class testGetParentOp:
|
||||
def __init__(
|
||||
self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
isolated_from_above: bool = False,
|
||||
op_name: Optional[str] = None,
|
||||
deduplicate: bool = False,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
super().__init__(
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
isolated_from_above=isolated_from_above,
|
||||
op_name=op_name,
|
||||
deduplicate=deduplicate,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
class MergeHandlesOp:
|
||||
def __init__(
|
||||
self,
|
||||
handles: Sequence[Union[Operation, Value]],
|
||||
*,
|
||||
deduplicate: bool = False,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
super().__init__(
|
||||
[_get_op_result_or_value(h) for h in handles],
|
||||
deduplicate=deduplicate,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
class ReplicateOp:
|
||||
def __init__(
|
||||
self,
|
||||
pattern: Union[Operation, Value],
|
||||
handles: Sequence[Union[Operation, Value]],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
super().__init__(
|
||||
[_get_op_result_or_value(h).type for h in handles],
|
||||
_get_op_result_or_value(pattern),
|
||||
[_get_op_result_or_value(h) for h in handles],
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
class SequenceOp:
|
||||
def __init__(
|
||||
self,
|
||||
failure_propagation_mode,
|
||||
results: Sequence[Type],
|
||||
target: Union[Operation, Value, Type],
|
||||
extra_bindings: Optional[
|
||||
Union[Sequence[Value], Sequence[Type], Operation, OpView]
|
||||
] = None,
|
||||
):
|
||||
root = (
|
||||
_get_op_result_or_value(target)
|
||||
if isinstance(target, (Operation, Value))
|
||||
else None
|
||||
)
|
||||
root_type = root.type if not isinstance(target, Type) else target
|
||||
|
||||
if extra_bindings is None:
|
||||
extra_bindings = []
|
||||
if isinstance(extra_bindings, (Operation, OpView)):
|
||||
extra_bindings = _get_op_results_or_values(extra_bindings)
|
||||
|
||||
extra_binding_types = []
|
||||
if len(extra_bindings) != 0:
|
||||
if isinstance(extra_bindings[0], Type):
|
||||
extra_binding_types = extra_bindings
|
||||
extra_bindings = []
|
||||
else:
|
||||
extra_binding_types = [v.type for v in extra_bindings]
|
||||
|
||||
super().__init__(
|
||||
results_=results,
|
||||
failure_propagation_mode=failure_propagation_mode,
|
||||
root=root,
|
||||
extra_bindings=extra_bindings,
|
||||
)
|
||||
self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types))
|
||||
|
||||
@property
|
||||
def body(self) -> Block:
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
@property
|
||||
def bodyTarget(self) -> Value:
|
||||
return self.body.arguments[0]
|
||||
|
||||
@property
|
||||
def bodyExtraArgs(self) -> BlockArgumentList:
|
||||
return self.body.arguments[1:]
|
||||
|
||||
|
||||
class YieldOp:
|
||||
def __init__(
|
||||
self,
|
||||
operands: Optional[Union[Operation, Sequence[Value]]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if operands is None:
|
||||
operands = []
|
||||
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
|
||||
@@ -1,55 +0,0 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ._ods_common import (
|
||||
get_op_result_or_value as _get_op_result_or_value,
|
||||
get_op_results_or_values as _get_op_results_or_values,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Union
|
||||
|
||||
class PDLMatchOp:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
pattern_name: Union[Attribute, str],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
super().__init__(
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
pattern_name,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
class WithPDLPatternsOp:
|
||||
|
||||
def __init__(self,
|
||||
target: Union[Operation, Value, Type],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None):
|
||||
root = _get_op_result_or_value(target) if not isinstance(target,
|
||||
Type) else None
|
||||
root_type = target if isinstance(target, Type) else root.type
|
||||
super().__init__(root=root, loc=loc, ip=ip)
|
||||
self.regions[0].blocks.append(root_type)
|
||||
|
||||
@property
|
||||
def body(self) -> Block:
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
@property
|
||||
def bodyTarget(self) -> Value:
|
||||
return self.body.arguments[0]
|
||||
@@ -1,5 +1,50 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._affine_ops_gen import *
|
||||
from ._affine_ops_gen import _Dialect
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ._ods_common import (
|
||||
get_op_result_or_value as _get_op_result_or_value,
|
||||
get_op_results_or_values as _get_op_results_or_values,
|
||||
_cext as _ods_cext,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class AffineStoreOp(AffineStoreOp):
|
||||
"""Specialization for the Affine store operation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
value: Union[Operation, OpView, Value],
|
||||
memref: Union[Operation, OpView, Value],
|
||||
map: AffineMap = None,
|
||||
*,
|
||||
map_operands=None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
"""Creates an affine store operation.
|
||||
|
||||
- `value`: the value to store into the memref.
|
||||
- `memref`: the buffer to store into.
|
||||
- `map`: the affine map that maps the map_operands to the index of the
|
||||
memref.
|
||||
- `map_operands`: the list of arguments to substitute the dimensions,
|
||||
then symbols in the affine map, in increasing order.
|
||||
"""
|
||||
map = map if map is not None else []
|
||||
map_operands = map_operands if map_operands is not None else []
|
||||
indicies = [_get_op_result_or_value(op) for op in map_operands]
|
||||
_ods_successors = None
|
||||
super().__init__(
|
||||
value, memref, indicies, AffineMapAttr.get(map), loc=loc, ip=ip
|
||||
)
|
||||
|
||||
@@ -3,4 +3,75 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._arith_ops_gen import *
|
||||
from ._arith_ops_gen import _Dialect
|
||||
from ._arith_enum_gen import *
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ._ods_common import (
|
||||
get_default_loc_context as _get_default_loc_context,
|
||||
_cext as _ods_cext,
|
||||
)
|
||||
|
||||
from typing import Any, List, Union
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
|
||||
def _isa(obj: Any, cls: type):
|
||||
try:
|
||||
cls(obj)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _is_any_of(obj: Any, classes: List[type]):
|
||||
return any(_isa(obj, cls) for cls in classes)
|
||||
|
||||
|
||||
def _is_integer_like_type(type: Type):
|
||||
return _is_any_of(type, [IntegerType, IndexType])
|
||||
|
||||
|
||||
def _is_float_type(type: Type):
|
||||
return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class ConstantOp(ConstantOp):
|
||||
"""Specialization for the constant op class."""
|
||||
|
||||
def __init__(
|
||||
self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
|
||||
):
|
||||
if isinstance(value, int):
|
||||
super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
|
||||
elif isinstance(value, float):
|
||||
super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
|
||||
else:
|
||||
super().__init__(value, loc=loc, ip=ip)
|
||||
|
||||
@classmethod
|
||||
def create_index(cls, value: int, *, loc=None, ip=None):
|
||||
"""Create an index-typed constant."""
|
||||
return cls(
|
||||
IndexType.get(context=_get_default_loc_context(loc)), value, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
return self.results[0].type
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return Attribute(self.operation.attributes["value"])
|
||||
|
||||
@property
|
||||
def literal_value(self) -> Union[int, float]:
|
||||
if _is_integer_like_type(self.type):
|
||||
return IntegerAttr(self.value).value
|
||||
elif _is_float_type(self.type):
|
||||
return FloatAttr(self.value).value
|
||||
else:
|
||||
raise ValueError("only integer and float constants have literal values")
|
||||
|
||||
@@ -3,4 +3,40 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._bufferization_ops_gen import *
|
||||
from ._bufferization_ops_gen import _Dialect
|
||||
from ._bufferization_enum_gen import *
|
||||
|
||||
try:
|
||||
from typing import Sequence, Union
|
||||
from ..ir import *
|
||||
from ._ods_common import get_default_loc_context, _cext as _ods_cext
|
||||
|
||||
from typing import Any, List, Union
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class AllocTensorOp(AllocTensorOp):
|
||||
"""Extends the bufferization.alloc_tensor op."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tensor_type: Type,
|
||||
dynamic_sizes: Sequence[Value],
|
||||
copy: Value,
|
||||
size_hint: Value,
|
||||
escape: BoolAttr,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
"""Constructs an `alloc_tensor` with static and/or dynamic sizes."""
|
||||
super().__init__(
|
||||
tensor_type,
|
||||
dynamic_sizes,
|
||||
copy=copy,
|
||||
size_hint=size_hint,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
@@ -3,3 +3,23 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._builtin_ops_gen import *
|
||||
from ._builtin_ops_gen import _Dialect
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ._ods_common import _cext as _ods_cext
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class ModuleOp(ModuleOp):
|
||||
"""Specialization for the module op class."""
|
||||
|
||||
def __init__(self, *, loc=None, ip=None):
|
||||
super().__init__(loc=loc, ip=ip)
|
||||
body = self.regions[0].blocks.append()
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
@@ -3,3 +3,326 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._func_ops_gen import *
|
||||
from ._func_ops_gen import _Dialect
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ._ods_common import (
|
||||
get_default_loc_context as _get_default_loc_context,
|
||||
_cext as _ods_cext,
|
||||
)
|
||||
|
||||
import inspect
|
||||
|
||||
from typing import Any, List, Optional, Sequence, Union
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
|
||||
RESULT_ATTRIBUTE_NAME = "res_attrs"
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class ConstantOp(ConstantOp):
|
||||
"""Specialization for the constant op class."""
|
||||
|
||||
def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
|
||||
super().__init__(result, value, loc=loc, ip=ip)
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
return self.results[0].type
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class FuncOp(FuncOp):
|
||||
"""Specialization for the func op class."""
|
||||
|
||||
def __init__(
|
||||
self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
|
||||
):
|
||||
"""
|
||||
Create a FuncOp with the provided `name`, `type`, and `visibility`.
|
||||
- `name` is a string representing the function name.
|
||||
- `type` is either a FunctionType or a pair of list describing inputs and
|
||||
results.
|
||||
- `visibility` is a string matching `public`, `private`, or `nested`. None
|
||||
implies private visibility.
|
||||
- `body_builder` is an optional callback, when provided a new entry block
|
||||
is created and the callback is invoked with the new op as argument within
|
||||
an InsertionPoint context already set for the block. The callback is
|
||||
expected to insert a terminator in the block.
|
||||
"""
|
||||
sym_name = StringAttr.get(str(name))
|
||||
|
||||
# If the type is passed as a tuple, build a FunctionType on the fly.
|
||||
if isinstance(type, tuple):
|
||||
type = FunctionType.get(inputs=type[0], results=type[1])
|
||||
|
||||
type = TypeAttr.get(type)
|
||||
sym_visibility = (
|
||||
StringAttr.get(str(visibility)) if visibility is not None else None
|
||||
)
|
||||
super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
|
||||
if body_builder:
|
||||
entry_block = self.add_entry_block()
|
||||
with InsertionPoint(entry_block):
|
||||
body_builder(self)
|
||||
|
||||
@property
|
||||
def is_external(self):
|
||||
return len(self.regions[0].blocks) == 0
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
return self.regions[0]
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
return FunctionType(TypeAttr(self.attributes["function_type"]).value)
|
||||
|
||||
@property
|
||||
def visibility(self):
|
||||
return self.attributes["sym_visibility"]
|
||||
|
||||
@property
|
||||
def name(self) -> StringAttr:
|
||||
return StringAttr(self.attributes["sym_name"])
|
||||
|
||||
@property
|
||||
def entry_block(self):
|
||||
if self.is_external:
|
||||
raise IndexError("External function does not have a body")
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
|
||||
"""
|
||||
Add an entry block to the function body using the function signature to
|
||||
infer block arguments.
|
||||
Returns the newly created block
|
||||
"""
|
||||
if not self.is_external:
|
||||
raise IndexError("The function already has an entry block!")
|
||||
self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs)
|
||||
return self.body.blocks[0]
|
||||
|
||||
@property
|
||||
def arg_attrs(self):
|
||||
return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
|
||||
|
||||
@arg_attrs.setter
|
||||
def arg_attrs(self, attribute: Union[ArrayAttr, list]):
|
||||
if isinstance(attribute, ArrayAttr):
|
||||
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
|
||||
else:
|
||||
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
|
||||
attribute, context=self.context
|
||||
)
|
||||
|
||||
@property
|
||||
def arguments(self):
|
||||
return self.entry_block.arguments
|
||||
|
||||
@property
|
||||
def result_attrs(self):
|
||||
return self.attributes[RESULT_ATTRIBUTE_NAME]
|
||||
|
||||
@result_attrs.setter
|
||||
def result_attrs(self, attribute: ArrayAttr):
|
||||
self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
|
||||
|
||||
@classmethod
|
||||
def from_py_func(
|
||||
FuncOp,
|
||||
*inputs: Type,
|
||||
results: Optional[Sequence[Type]] = None,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Decorator to define an MLIR FuncOp specified as a python function.
|
||||
|
||||
Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
|
||||
active for the current thread (i.e. established in a `with` block).
|
||||
|
||||
When applied as a decorator to a Python function, an entry block will
|
||||
be constructed for the FuncOp with types as specified in `*inputs`. The
|
||||
block arguments will be passed positionally to the Python function. In
|
||||
addition, if the Python function accepts keyword arguments generally or
|
||||
has a corresponding keyword argument, the following will be passed:
|
||||
* `func_op`: The `func` op being defined.
|
||||
|
||||
By default, the function name will be the Python function `__name__`. This
|
||||
can be overriden by passing the `name` argument to the decorator.
|
||||
|
||||
If `results` is not specified, then the decorator will implicitly
|
||||
insert a `ReturnOp` with the `Value`'s returned from the decorated
|
||||
function. It will also set the `FuncOp` type with the actual return
|
||||
value types. If `results` is specified, then the decorated function
|
||||
must return `None` and no implicit `ReturnOp` is added (nor are the result
|
||||
types updated). The implicit behavior is intended for simple, single-block
|
||||
cases, and users should specify result types explicitly for any complicated
|
||||
cases.
|
||||
|
||||
The decorated function can further be called from Python and will insert
|
||||
a `CallOp` at the then-current insertion point, returning either None (
|
||||
if no return values), a unary Value (for one result), or a list of Values).
|
||||
This mechanism cannot be used to emit recursive calls (by construction).
|
||||
"""
|
||||
|
||||
def decorator(f):
|
||||
from . import func
|
||||
|
||||
# Introspect the callable for optional features.
|
||||
sig = inspect.signature(f)
|
||||
has_arg_func_op = False
|
||||
for param in sig.parameters.values():
|
||||
if param.kind == param.VAR_KEYWORD:
|
||||
has_arg_func_op = True
|
||||
if param.name == "func_op" and (
|
||||
param.kind == param.POSITIONAL_OR_KEYWORD
|
||||
or param.kind == param.KEYWORD_ONLY
|
||||
):
|
||||
has_arg_func_op = True
|
||||
|
||||
# Emit the FuncOp.
|
||||
implicit_return = results is None
|
||||
symbol_name = name or f.__name__
|
||||
function_type = FunctionType.get(
|
||||
inputs=inputs, results=[] if implicit_return else results
|
||||
)
|
||||
func_op = FuncOp(name=symbol_name, type=function_type)
|
||||
with InsertionPoint(func_op.add_entry_block()):
|
||||
func_args = func_op.entry_block.arguments
|
||||
func_kwargs = {}
|
||||
if has_arg_func_op:
|
||||
func_kwargs["func_op"] = func_op
|
||||
return_values = f(*func_args, **func_kwargs)
|
||||
if not implicit_return:
|
||||
return_types = list(results)
|
||||
assert return_values is None, (
|
||||
"Capturing a python function with explicit `results=` "
|
||||
"requires that the wrapped function returns None."
|
||||
)
|
||||
else:
|
||||
# Coerce return values, add ReturnOp and rewrite func type.
|
||||
if return_values is None:
|
||||
return_values = []
|
||||
elif isinstance(return_values, tuple):
|
||||
return_values = list(return_values)
|
||||
elif isinstance(return_values, Value):
|
||||
# Returning a single value is fine, coerce it into a list.
|
||||
return_values = [return_values]
|
||||
elif isinstance(return_values, OpView):
|
||||
# Returning a single operation is fine, coerce its results a list.
|
||||
return_values = return_values.operation.results
|
||||
elif isinstance(return_values, Operation):
|
||||
# Returning a single operation is fine, coerce its results a list.
|
||||
return_values = return_values.results
|
||||
else:
|
||||
return_values = list(return_values)
|
||||
func.ReturnOp(return_values)
|
||||
# Recompute the function type.
|
||||
return_types = [v.type for v in return_values]
|
||||
function_type = FunctionType.get(
|
||||
inputs=inputs, results=return_types
|
||||
)
|
||||
func_op.attributes["function_type"] = TypeAttr.get(function_type)
|
||||
|
||||
def emit_call_op(*call_args):
|
||||
call_op = func.CallOp(
|
||||
return_types, FlatSymbolRefAttr.get(symbol_name), call_args
|
||||
)
|
||||
if return_types is None:
|
||||
return None
|
||||
elif len(return_types) == 1:
|
||||
return call_op.result
|
||||
else:
|
||||
return call_op.results
|
||||
|
||||
wrapped = emit_call_op
|
||||
wrapped.__name__ = f.__name__
|
||||
wrapped.func_op = func_op
|
||||
return wrapped
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class CallOp(CallOp):
|
||||
"""Specialization for the call op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
calleeOrResults: Union[FuncOp, List[Type]],
|
||||
argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
|
||||
arguments: Optional[List] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
"""Creates an call operation.
|
||||
|
||||
The constructor accepts three different forms:
|
||||
|
||||
1. A function op to be called followed by a list of arguments.
|
||||
2. A list of result types, followed by the name of the function to be
|
||||
called as string, following by a list of arguments.
|
||||
3. A list of result types, followed by the name of the function to be
|
||||
called as symbol reference attribute, followed by a list of arguments.
|
||||
|
||||
For example
|
||||
|
||||
f = func.FuncOp("foo", ...)
|
||||
func.CallOp(f, [args])
|
||||
func.CallOp([result_types], "foo", [args])
|
||||
|
||||
In all cases, the location and insertion point may be specified as keyword
|
||||
arguments if not provided by the surrounding context managers.
|
||||
"""
|
||||
|
||||
# TODO: consider supporting constructor "overloads", e.g., through a custom
|
||||
# or pybind-provided metaclass.
|
||||
if isinstance(calleeOrResults, FuncOp):
|
||||
if not isinstance(argumentsOrCallee, list):
|
||||
raise ValueError(
|
||||
"when constructing a call to a function, expected "
|
||||
+ "the second argument to be a list of call arguments, "
|
||||
+ f"got {type(argumentsOrCallee)}"
|
||||
)
|
||||
if arguments is not None:
|
||||
raise ValueError(
|
||||
"unexpected third argument when constructing a call"
|
||||
+ "to a function"
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
calleeOrResults.type.results,
|
||||
FlatSymbolRefAttr.get(
|
||||
calleeOrResults.name.value, context=_get_default_loc_context(loc)
|
||||
),
|
||||
argumentsOrCallee,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(argumentsOrCallee, list):
|
||||
raise ValueError(
|
||||
"when constructing a call to a function by name, "
|
||||
+ "expected the second argument to be a string or a "
|
||||
+ f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}"
|
||||
)
|
||||
|
||||
if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
|
||||
super().__init__(
|
||||
calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip
|
||||
)
|
||||
elif isinstance(argumentsOrCallee, str):
|
||||
super().__init__(
|
||||
calleeOrResults,
|
||||
FlatSymbolRefAttr.get(
|
||||
argumentsOrCallee, context=_get_default_loc_context(loc)
|
||||
),
|
||||
arguments,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
@@ -310,7 +310,7 @@ def emit_named_structured_op(
|
||||
)
|
||||
|
||||
# Set the index attributes used to compute the indexing maps.
|
||||
named_op = getattr(linalg, op_class_name)(ins, outs, result_types)
|
||||
named_op = getattr(linalg, op_class_name)(result_types, ins, outs)
|
||||
for name, value in index_attrs.items():
|
||||
named_op.operation.attributes[name] = value
|
||||
|
||||
|
||||
@@ -296,35 +296,39 @@ def quantized_matmul(
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
def matmul_transpose_a(A=TensorDef(T1, S.K, S.N),
|
||||
B=TensorDef(T2, S.K, S.M),
|
||||
C=TensorDef(U, S.M, S.N, output=True),
|
||||
cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
|
||||
"""Performs a matrix multiplication of two 2D inputs with lhs operand
|
||||
transposed.
|
||||
def matmul_transpose_a(
|
||||
A=TensorDef(T1, S.K, S.N),
|
||||
B=TensorDef(T2, S.K, S.M),
|
||||
C=TensorDef(U, S.M, S.N, output=True),
|
||||
cast=TypeFnAttrDef(default=TypeFn.cast_signed),
|
||||
):
|
||||
"""Performs a matrix multiplication of two 2D inputs with lhs operand
|
||||
transposed.
|
||||
|
||||
Numeric casting is performed on the operands to the inner multiply, promoting
|
||||
them to the same data type as the accumulator/output.
|
||||
"""
|
||||
domain(D.m, D.n, D.k)
|
||||
implements(ContractionOpInterface)
|
||||
C[D.m, D.n] += cast(U, A[D.k, D.m]) * cast(U, B[D.k, D.n])
|
||||
Numeric casting is performed on the operands to the inner multiply, promoting
|
||||
them to the same data type as the accumulator/output.
|
||||
"""
|
||||
domain(D.m, D.n, D.k)
|
||||
implements(ContractionOpInterface)
|
||||
C[D.m, D.n] += cast(U, A[D.k, D.m]) * cast(U, B[D.k, D.n])
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
def matmul_transpose_b(A=TensorDef(T1, S.M, S.K),
|
||||
B=TensorDef(T2, S.N, S.K),
|
||||
C=TensorDef(U, S.M, S.N, output=True),
|
||||
cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
|
||||
"""Performs a matrix multiplication of two 2D inputs with rhs operand
|
||||
transposed.
|
||||
def matmul_transpose_b(
|
||||
A=TensorDef(T1, S.M, S.K),
|
||||
B=TensorDef(T2, S.N, S.K),
|
||||
C=TensorDef(U, S.M, S.N, output=True),
|
||||
cast=TypeFnAttrDef(default=TypeFn.cast_signed),
|
||||
):
|
||||
"""Performs a matrix multiplication of two 2D inputs with rhs operand
|
||||
transposed.
|
||||
|
||||
Numeric casting is performed on the operands to the inner multiply, promoting
|
||||
them to the same data type as the accumulator/output.
|
||||
"""
|
||||
domain(D.m, D.n, D.k)
|
||||
implements(ContractionOpInterface)
|
||||
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.n, D.k])
|
||||
Numeric casting is performed on the operands to the inner multiply, promoting
|
||||
them to the same data type as the accumulator/output.
|
||||
"""
|
||||
domain(D.m, D.n, D.k)
|
||||
implements(ContractionOpInterface)
|
||||
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.n, D.k])
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
@@ -390,36 +394,41 @@ def batch_matmul(
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
def batch_matmul_transpose_a(A=TensorDef(T1, Batch, S.K, S.M),
|
||||
B=TensorDef(T2, Batch, S.K, S.N),
|
||||
C=TensorDef(U, Batch, S.M, S.N, output=True)):
|
||||
"""Performs a batched matrix multiplication of two 3D inputs where lhs operand
|
||||
has its non-batch dimensions transposed.
|
||||
def batch_matmul_transpose_a(
|
||||
A=TensorDef(T1, Batch, S.K, S.M),
|
||||
B=TensorDef(T2, Batch, S.K, S.N),
|
||||
C=TensorDef(U, Batch, S.M, S.N, output=True),
|
||||
):
|
||||
"""Performs a batched matrix multiplication of two 3D inputs where lhs operand
|
||||
has its non-batch dimensions transposed.
|
||||
|
||||
Numeric casting is performed on the operands to the inner multiply, promoting
|
||||
them to the same data type as the accumulator/output.
|
||||
"""
|
||||
domain(D.b, D.m, D.n, D.k)
|
||||
implements(ContractionOpInterface)
|
||||
C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) \
|
||||
* TypeFn.cast_signed(U, B[D.b, D.k, D.n])
|
||||
Numeric casting is performed on the operands to the inner multiply, promoting
|
||||
them to the same data type as the accumulator/output.
|
||||
"""
|
||||
domain(D.b, D.m, D.n, D.k)
|
||||
implements(ContractionOpInterface)
|
||||
C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) * TypeFn.cast_signed(
|
||||
U, B[D.b, D.k, D.n]
|
||||
)
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
def batch_matmul_transpose_b(A=TensorDef(T1, Batch, S.M, S.K),
|
||||
B=TensorDef(T2, Batch, S.N, S.K),
|
||||
C=TensorDef(U, Batch, S.M, S.N, output=True)):
|
||||
"""Performs a batched matrix multiplication of two 3D inputs where rhs operand
|
||||
has its non-batch dimensions transposed.
|
||||
def batch_matmul_transpose_b(
|
||||
A=TensorDef(T1, Batch, S.M, S.K),
|
||||
B=TensorDef(T2, Batch, S.N, S.K),
|
||||
C=TensorDef(U, Batch, S.M, S.N, output=True),
|
||||
):
|
||||
"""Performs a batched matrix multiplication of two 3D inputs where rhs operand
|
||||
has its non-batch dimensions transposed.
|
||||
|
||||
Numeric casting is performed on the operands to the inner multiply, promoting
|
||||
them to the same data type as the accumulator/output.
|
||||
"""
|
||||
domain(D.b, D.m, D.n, D.k)
|
||||
implements(ContractionOpInterface)
|
||||
C[D.b, D.m,
|
||||
D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
|
||||
U, B[D.b, D.n, D.k])
|
||||
Numeric casting is performed on the operands to the inner multiply, promoting
|
||||
them to the same data type as the accumulator/output.
|
||||
"""
|
||||
domain(D.b, D.m, D.n, D.k)
|
||||
implements(ContractionOpInterface)
|
||||
C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
|
||||
U, B[D.b, D.n, D.k]
|
||||
)
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
|
||||
@@ -3,3 +3,41 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._memref_ops_gen import *
|
||||
from ._memref_ops_gen import _Dialect
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ._ods_common import (
|
||||
get_op_result_or_value as _get_op_result_or_value,
|
||||
get_op_results_or_values as _get_op_results_or_values,
|
||||
_cext as _ods_cext,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class LoadOp(LoadOp):
|
||||
"""Specialization for the MemRef load operation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
memref: Union[Operation, OpView, Value],
|
||||
indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
"""Creates a memref load operation.
|
||||
|
||||
Args:
|
||||
memref: the buffer to load from.
|
||||
indices: the list of subscripts, may be empty for zero-dimensional
|
||||
buffers.
|
||||
loc: user-visible location of the operation.
|
||||
ip: insertion point.
|
||||
"""
|
||||
indices_resolved = [] if indices is None else _get_op_results_or_values(indices)
|
||||
super().__init__(memref, indices_resolved, loc=loc, ip=ip)
|
||||
|
||||
@@ -2,4 +2,118 @@
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from typing import Union
|
||||
|
||||
from ._ml_program_ops_gen import *
|
||||
from ._ml_program_ops_gen import _Dialect
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ._ods_common import (
|
||||
get_default_loc_context as _get_default_loc_context,
|
||||
_cext as _ods_cext,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
|
||||
ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
|
||||
RESULT_ATTRIBUTE_NAME = "res_attrs"
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class FuncOp(FuncOp):
|
||||
"""Specialization for the func op class."""
|
||||
|
||||
def __init__(
|
||||
self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
|
||||
):
|
||||
"""
|
||||
Create a FuncOp with the provided `name`, `type`, and `visibility`.
|
||||
- `name` is a string representing the function name.
|
||||
- `type` is either a FunctionType or a pair of list describing inputs and
|
||||
results.
|
||||
- `visibility` is a string matching `public`, `private`, or `nested`. None
|
||||
implies private visibility.
|
||||
- `body_builder` is an optional callback, when provided a new entry block
|
||||
is created and the callback is invoked with the new op as argument within
|
||||
an InsertionPoint context already set for the block. The callback is
|
||||
expected to insert a terminator in the block.
|
||||
"""
|
||||
sym_name = StringAttr.get(str(name))
|
||||
|
||||
# If the type is passed as a tuple, build a FunctionType on the fly.
|
||||
if isinstance(type, tuple):
|
||||
type = FunctionType.get(inputs=type[0], results=type[1])
|
||||
|
||||
type = TypeAttr.get(type)
|
||||
sym_visibility = (
|
||||
StringAttr.get(str(visibility)) if visibility is not None else None
|
||||
)
|
||||
super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
|
||||
if body_builder:
|
||||
entry_block = self.add_entry_block()
|
||||
with InsertionPoint(entry_block):
|
||||
body_builder(self)
|
||||
|
||||
@property
|
||||
def is_external(self):
|
||||
return len(self.regions[0].blocks) == 0
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
return self.regions[0]
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
return FunctionType(TypeAttr(self.attributes["function_type"]).value)
|
||||
|
||||
@property
|
||||
def visibility(self):
|
||||
return self.attributes["sym_visibility"]
|
||||
|
||||
@property
|
||||
def name(self) -> StringAttr:
|
||||
return StringAttr(self.attributes["sym_name"])
|
||||
|
||||
@property
|
||||
def entry_block(self):
|
||||
if self.is_external:
|
||||
raise IndexError("External function does not have a body")
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
def add_entry_block(self):
|
||||
"""
|
||||
Add an entry block to the function body using the function signature to
|
||||
infer block arguments.
|
||||
Returns the newly created block
|
||||
"""
|
||||
if not self.is_external:
|
||||
raise IndexError("The function already has an entry block!")
|
||||
self.body.blocks.append(*self.type.inputs)
|
||||
return self.body.blocks[0]
|
||||
|
||||
@property
|
||||
def arg_attrs(self):
|
||||
return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
|
||||
|
||||
@arg_attrs.setter
|
||||
def arg_attrs(self, attribute: Union[ArrayAttr, list]):
|
||||
if isinstance(attribute, ArrayAttr):
|
||||
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
|
||||
else:
|
||||
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
|
||||
attribute, context=self.context
|
||||
)
|
||||
|
||||
@property
|
||||
def arguments(self):
|
||||
return self.entry_block.arguments
|
||||
|
||||
@property
|
||||
def result_attrs(self):
|
||||
return self.attributes[RESULT_ATTRIBUTE_NAME]
|
||||
|
||||
@result_attrs.setter
|
||||
def result_attrs(self, attribute: ArrayAttr):
|
||||
self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
|
||||
|
||||
@@ -3,4 +3,289 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._pdl_ops_gen import *
|
||||
from ._pdl_ops_gen import _Dialect
|
||||
from .._mlir_libs._mlirDialectsPDL import *
|
||||
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ..dialects import pdl
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Union, Optional, Sequence, Mapping
|
||||
from ._ods_common import (
|
||||
get_op_result_or_value as _get_value,
|
||||
get_op_results_or_values as _get_values,
|
||||
_cext as _ods_cext,
|
||||
)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class ApplyNativeConstraintOp(ApplyNativeConstraintOp):
|
||||
"""Specialization for PDL apply native constraint op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: Union[str, StringAttr],
|
||||
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if args is None:
|
||||
args = []
|
||||
args = _get_values(args)
|
||||
super().__init__(name, args, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class ApplyNativeRewriteOp(ApplyNativeRewriteOp):
|
||||
"""Specialization for PDL apply native rewrite op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
results: Sequence[Type],
|
||||
name: Union[str, StringAttr],
|
||||
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if args is None:
|
||||
args = []
|
||||
args = _get_values(args)
|
||||
super().__init__(results, name, args, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class AttributeOp(AttributeOp):
|
||||
"""Specialization for PDL attribute op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
valueType: Optional[Union[OpView, Operation, Value]] = None,
|
||||
value: Optional[Attribute] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
valueType = valueType if valueType is None else _get_value(valueType)
|
||||
result = pdl.AttributeType.get()
|
||||
super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class EraseOp(EraseOp):
|
||||
"""Specialization for PDL erase op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
operation: Optional[Union[OpView, Operation, Value]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
operation = _get_value(operation)
|
||||
super().__init__(operation, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class OperandOp(OperandOp):
|
||||
"""Specialization for PDL operand op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type: Optional[Union[OpView, Operation, Value]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
type = type if type is None else _get_value(type)
|
||||
result = pdl.ValueType.get()
|
||||
super().__init__(result, valueType=type, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class OperandsOp(OperandsOp):
|
||||
"""Specialization for PDL operands op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
types: Optional[Union[OpView, Operation, Value]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
types = types if types is None else _get_value(types)
|
||||
result = pdl.RangeType.get(pdl.ValueType.get())
|
||||
super().__init__(result, valueType=types, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class OperationOp(OperationOp):
|
||||
"""Specialization for PDL operand op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: Optional[Union[str, StringAttr]] = None,
|
||||
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
|
||||
attributes: Optional[Mapping[str, Union[OpView, Operation, Value]]] = None,
|
||||
types: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if types is None:
|
||||
types = []
|
||||
if attributes is None:
|
||||
attributes = {}
|
||||
if args is None:
|
||||
args = []
|
||||
args = _get_values(args)
|
||||
attrNames = []
|
||||
attrValues = []
|
||||
for attrName, attrValue in attributes.items():
|
||||
attrNames.append(StringAttr.get(attrName))
|
||||
attrValues.append(_get_value(attrValue))
|
||||
attrNames = ArrayAttr.get(attrNames)
|
||||
types = _get_values(types)
|
||||
result = pdl.OperationType.get()
|
||||
super().__init__(
|
||||
result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class PatternOp(PatternOp):
|
||||
"""Specialization for PDL pattern op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
benefit: Union[IntegerAttr, int],
|
||||
name: Optional[Union[StringAttr, str]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
"""Creates an PDL `pattern` operation."""
|
||||
super().__init__(benefit, sym_name=name, loc=loc, ip=ip)
|
||||
self.regions[0].blocks.append()
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
"""Return the body (block) of the pattern."""
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class ReplaceOp(ReplaceOp):
|
||||
"""Specialization for PDL replace op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
op: Union[OpView, Operation, Value],
|
||||
*,
|
||||
with_op: Optional[Union[OpView, Operation, Value]] = None,
|
||||
with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if with_values is None:
|
||||
with_values = []
|
||||
op = _get_value(op)
|
||||
with_op = with_op if with_op is None else _get_value(with_op)
|
||||
with_values = _get_values(with_values)
|
||||
super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class ResultOp(ResultOp):
|
||||
"""Specialization for PDL result op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent: Union[OpView, Operation, Value],
|
||||
index: Union[IntegerAttr, int],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
parent = _get_value(parent)
|
||||
result = pdl.ValueType.get()
|
||||
super().__init__(result, parent, index, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class ResultsOp(ResultsOp):
|
||||
"""Specialization for PDL results op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
result: Type,
|
||||
parent: Union[OpView, Operation, Value],
|
||||
index: Optional[Union[IntegerAttr, int]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
parent = _get_value(parent)
|
||||
super().__init__(result, parent, index=index, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class RewriteOp(RewriteOp):
|
||||
"""Specialization for PDL rewrite op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: Optional[Union[OpView, Operation, Value]] = None,
|
||||
name: Optional[Union[StringAttr, str]] = None,
|
||||
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if args is None:
|
||||
args = []
|
||||
root = root if root is None else _get_value(root)
|
||||
args = _get_values(args)
|
||||
super().__init__(args, root=root, name=name, loc=loc, ip=ip)
|
||||
|
||||
def add_body(self):
|
||||
"""Add body (block) to the rewrite."""
|
||||
self.regions[0].blocks.append()
|
||||
return self.body
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
"""Return the body (block) of the rewrite."""
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class TypeOp(TypeOp):
|
||||
"""Specialization for PDL type op class."""
|
||||
|
||||
def __init__(
|
||||
self, constantType: Optional[Union[TypeAttr, Type]] = None, *, loc=None, ip=None
|
||||
):
|
||||
result = pdl.TypeType.get()
|
||||
super().__init__(result, constantType=constantType, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class TypesOp(TypesOp):
|
||||
"""Specialization for PDL types op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if constantTypes is None:
|
||||
constantTypes = []
|
||||
result = pdl.RangeType.get(pdl.TypeType.get())
|
||||
super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)
|
||||
|
||||
@@ -3,7 +3,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._python_test_ops_gen import *
|
||||
from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestIntegerRankedTensorType
|
||||
from .._mlir_libs._mlirPythonTest import (
|
||||
TestAttr,
|
||||
TestType,
|
||||
TestTensorValue,
|
||||
TestIntegerRankedTensorType,
|
||||
)
|
||||
|
||||
|
||||
def register_python_test_dialect(context, load=True):
|
||||
|
||||
@@ -2,11 +2,122 @@
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from ._scf_ops_gen import *
|
||||
from ._scf_ops_gen import _Dialect
|
||||
from .arith import constant
|
||||
from ..ir import *
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ._ods_common import (
|
||||
get_op_result_or_value as _get_op_result_or_value,
|
||||
get_op_results_or_values as _get_op_results_or_values,
|
||||
_cext as _ods_cext,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
|
||||
_ForOp = ForOp
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class ForOp(_ForOp):
|
||||
"""Specialization for the SCF for op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lower_bound,
|
||||
upper_bound,
|
||||
step,
|
||||
iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
"""Creates an SCF `for` operation.
|
||||
|
||||
- `lower_bound` is the value to use as lower bound of the loop.
|
||||
- `upper_bound` is the value to use as upper bound of the loop.
|
||||
- `step` is the value to use as loop step.
|
||||
- `iter_args` is a list of additional loop-carried arguments or an operation
|
||||
producing them as results.
|
||||
"""
|
||||
if iter_args is None:
|
||||
iter_args = []
|
||||
iter_args = _get_op_results_or_values(iter_args)
|
||||
|
||||
results = [arg.type for arg in iter_args]
|
||||
super(_ForOp, self).__init__(
|
||||
self.build_generic(
|
||||
regions=1,
|
||||
results=results,
|
||||
operands=[
|
||||
_get_op_result_or_value(o) for o in [lower_bound, upper_bound, step]
|
||||
]
|
||||
+ list(iter_args),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
)
|
||||
self.regions[0].blocks.append(self.operands[0].type, *results)
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
"""Returns the body (block) of the loop."""
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
@property
|
||||
def induction_variable(self):
|
||||
"""Returns the induction variable of the loop."""
|
||||
return self.body.arguments[0]
|
||||
|
||||
@property
|
||||
def inner_iter_args(self):
|
||||
"""Returns the loop-carried arguments usable within the loop.
|
||||
|
||||
To obtain the loop-carried operands, use `iter_args`.
|
||||
"""
|
||||
return self.body.arguments[1:]
|
||||
|
||||
|
||||
_IfOp = IfOp
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class IfOp(_IfOp):
|
||||
"""Specialization for the SCF if op class."""
|
||||
|
||||
def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None):
|
||||
"""Creates an SCF `if` operation.
|
||||
|
||||
- `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed.
|
||||
- `hasElse` determines whether the if operation has the else branch.
|
||||
"""
|
||||
operands = []
|
||||
operands.append(cond)
|
||||
results = []
|
||||
results.extend(results_)
|
||||
super(_IfOp, self).__init__(
|
||||
self.build_generic(
|
||||
regions=2, results=results, operands=operands, loc=loc, ip=ip
|
||||
)
|
||||
)
|
||||
self.regions[0].blocks.append(*[])
|
||||
if hasElse:
|
||||
self.regions[1].blocks.append(*[])
|
||||
|
||||
@property
|
||||
def then_block(self):
|
||||
"""Returns the then block of the if operation."""
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
@property
|
||||
def else_block(self):
|
||||
"""Returns the else block of the if operation."""
|
||||
return self.regions[1].blocks[0]
|
||||
|
||||
|
||||
def for_(
|
||||
|
||||
@@ -3,3 +3,40 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._tensor_ops_gen import *
|
||||
from ._tensor_ops_gen import _Dialect
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Sequence, Union
|
||||
from ._ods_common import _cext as _ods_cext
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class EmptyOp(EmptyOp):
|
||||
"""Extends the tensor.empty op."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sizes: Sequence[Union[int, Value]],
|
||||
element_type: Type,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
"""Constructs an `empty` with mixed static/dynamic sizes."""
|
||||
# TODO: Refactor the EmptyOp to take an element type attribute and
|
||||
# then use normal result type inference, unifying the Python and C++ side
|
||||
# with a standard mechanism (versus stashing that in builders).
|
||||
dynamic_sizes = []
|
||||
static_sizes = []
|
||||
for s in sizes:
|
||||
if isinstance(s, int):
|
||||
static_sizes.append(s)
|
||||
else:
|
||||
static_sizes.append(ShapedType.get_dynamic_size())
|
||||
dynamic_sizes.append(s)
|
||||
result_type = RankedTensorType.get(static_sizes, element_type)
|
||||
super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip)
|
||||
|
||||
@@ -4,4 +4,174 @@
|
||||
|
||||
from .._transform_enum_gen import *
|
||||
from .._transform_ops_gen import *
|
||||
from .._transform_ops_gen import _Dialect
|
||||
from ..._mlir_libs._mlirDialectsTransform import *
|
||||
|
||||
try:
|
||||
from ...ir import *
|
||||
from .._ods_common import (
|
||||
get_op_result_or_value as _get_op_result_or_value,
|
||||
get_op_results_or_values as _get_op_results_or_values,
|
||||
_cext as _ods_cext,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class CastOp(CastOp):
|
||||
def __init__(
|
||||
self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class ApplyPatternsOp(ApplyPatternsOp):
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, Value, OpView],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
super().__init__(target, loc=loc, ip=ip)
|
||||
self.regions[0].blocks.append()
|
||||
|
||||
@property
|
||||
def patterns(self) -> Block:
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class GetParentOp(GetParentOp):
|
||||
def __init__(
|
||||
self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
isolated_from_above: bool = False,
|
||||
op_name: Optional[str] = None,
|
||||
deduplicate: bool = False,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
super().__init__(
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
isolated_from_above=isolated_from_above,
|
||||
op_name=op_name,
|
||||
deduplicate=deduplicate,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class MergeHandlesOp(MergeHandlesOp):
|
||||
def __init__(
|
||||
self,
|
||||
handles: Sequence[Union[Operation, Value]],
|
||||
*,
|
||||
deduplicate: bool = False,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
super().__init__(
|
||||
[_get_op_result_or_value(h) for h in handles],
|
||||
deduplicate=deduplicate,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class ReplicateOp(ReplicateOp):
|
||||
def __init__(
|
||||
self,
|
||||
pattern: Union[Operation, Value],
|
||||
handles: Sequence[Union[Operation, Value]],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
super().__init__(
|
||||
[_get_op_result_or_value(h).type for h in handles],
|
||||
_get_op_result_or_value(pattern),
|
||||
[_get_op_result_or_value(h) for h in handles],
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class SequenceOp(SequenceOp):
|
||||
def __init__(
|
||||
self,
|
||||
failure_propagation_mode,
|
||||
results: Sequence[Type],
|
||||
target: Union[Operation, Value, Type],
|
||||
extra_bindings: Optional[
|
||||
Union[Sequence[Value], Sequence[Type], Operation, OpView]
|
||||
] = None,
|
||||
):
|
||||
root = (
|
||||
_get_op_result_or_value(target)
|
||||
if isinstance(target, (Operation, Value))
|
||||
else None
|
||||
)
|
||||
root_type = root.type if not isinstance(target, Type) else target
|
||||
|
||||
if extra_bindings is None:
|
||||
extra_bindings = []
|
||||
if isinstance(extra_bindings, (Operation, OpView)):
|
||||
extra_bindings = _get_op_results_or_values(extra_bindings)
|
||||
|
||||
extra_binding_types = []
|
||||
if len(extra_bindings) != 0:
|
||||
if isinstance(extra_bindings[0], Type):
|
||||
extra_binding_types = extra_bindings
|
||||
extra_bindings = []
|
||||
else:
|
||||
extra_binding_types = [v.type for v in extra_bindings]
|
||||
|
||||
super().__init__(
|
||||
results_=results,
|
||||
failure_propagation_mode=failure_propagation_mode,
|
||||
root=root,
|
||||
extra_bindings=extra_bindings,
|
||||
)
|
||||
self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types))
|
||||
|
||||
@property
|
||||
def body(self) -> Block:
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
@property
|
||||
def bodyTarget(self) -> Value:
|
||||
return self.body.arguments[0]
|
||||
|
||||
@property
|
||||
def bodyExtraArgs(self) -> BlockArgumentList:
|
||||
return self.body.arguments[1:]
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class YieldOp(YieldOp):
|
||||
def __init__(
|
||||
self,
|
||||
operands: Optional[Union[Operation, Sequence[Value]]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if operands is None:
|
||||
operands = []
|
||||
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
|
||||
|
||||
@@ -3,3 +3,132 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from .._bufferization_transform_ops_gen import *
|
||||
from .._bufferization_transform_ops_gen import _Dialect
|
||||
|
||||
try:
|
||||
from ...ir import *
|
||||
from ...dialects import transform
|
||||
from .._ods_common import _cext as _ods_cext
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional, overload, Union
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class EmptyTensorToAllocTensorOp(EmptyTensorToAllocTensorOp):
|
||||
"""Specialization for EmptyTensorToAllocTensorOp class."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
transformed_type: Type,
|
||||
target: Union[Operation, OpView, Value],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(self, target: Union[Operation, OpView, Value], *, loc=None, ip=None):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformed_type_or_target: Type,
|
||||
target_or_none: Optional[Union[Operation, OpView, Value]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if isinstance(transformed_type_or_target, Type):
|
||||
transformed_type = transformed_type_or_target
|
||||
target = target_or_none
|
||||
else:
|
||||
transformed_type = transform.OperationType.get("bufferization.alloc_tensor")
|
||||
target = transformed_type_or_target
|
||||
|
||||
super().__init__(
|
||||
transformed_type,
|
||||
target,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class OneShotBufferizeOp(OneShotBufferizeOp):
|
||||
"""Specialization for OneShotBufferizeOp class."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
transformed_type: Type,
|
||||
target: Union[Operation, OpView, Value],
|
||||
*,
|
||||
allow_return_allocs_from_loops: Optional[bool] = None,
|
||||
allow_unknown_ops: Optional[bool] = None,
|
||||
bufferize_function_boundaries: Optional[bool] = None,
|
||||
function_boundary_type_conversion: Optional[Enum] = None,
|
||||
memcpy_op: Optional[str] = None,
|
||||
print_conflicts: Optional[bool] = None,
|
||||
test_analysis_only: Optional[bool] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, OpView, Value],
|
||||
*,
|
||||
allow_return_allocs_from_loops: Optional[bool] = None,
|
||||
allow_unknown_ops: Optional[bool] = None,
|
||||
bufferize_function_boundaries: Optional[bool] = None,
|
||||
function_boundary_type_conversion: Optional[Enum] = None,
|
||||
memcpy_op: Optional[str] = None,
|
||||
print_conflicts: Optional[bool] = None,
|
||||
test_analysis_only: Optional[bool] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformed_type_or_target: Type,
|
||||
target_or_none: Optional[Union[Operation, OpView, Value]] = None,
|
||||
*,
|
||||
allow_return_allocs_from_loops: Optional[bool] = None,
|
||||
allow_unknown_ops: Optional[bool] = None,
|
||||
bufferize_function_boundaries: Optional[bool] = None,
|
||||
function_boundary_type_conversion: Optional[Enum] = None,
|
||||
memcpy_op: Optional[str] = None,
|
||||
print_conflicts: Optional[bool] = None,
|
||||
test_analysis_only: Optional[bool] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if isinstance(transformed_type_or_target, Type):
|
||||
transformed_type = transformed_type_or_target
|
||||
target = target_or_none
|
||||
else:
|
||||
transformed_type = transform.AnyOpType.get()
|
||||
target = transformed_type_or_target
|
||||
|
||||
super().__init__(
|
||||
transformed_type,
|
||||
target,
|
||||
allow_return_allocs_from_loops=allow_return_allocs_from_loops,
|
||||
allow_unknown_ops=allow_unknown_ops,
|
||||
bufferize_function_boundaries=bufferize_function_boundaries,
|
||||
function_boundary_type_conversion=function_boundary_type_conversion,
|
||||
memcpy_op=memcpy_op,
|
||||
print_conflicts=print_conflicts,
|
||||
test_analysis_only=test_analysis_only,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
@@ -3,3 +3,128 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from .._gpu_transform_ops_gen import *
|
||||
from .._gpu_transform_ops_gen import _Dialect
|
||||
|
||||
try:
|
||||
from ...ir import *
|
||||
from ...dialects import transform
|
||||
from .._ods_common import _cext as _ods_cext
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Optional, Sequence, Union, overload
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class MapForallToBlocks(MapForallToBlocks):
|
||||
"""Specialization for MapForallToBlocks class."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, OpView, Value],
|
||||
*,
|
||||
grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
|
||||
generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, OpView, Value],
|
||||
*,
|
||||
grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
|
||||
generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
result_type_or_target: Union[Operation, OpView, Type, Value],
|
||||
target_or_none: Optional[Union[Operation, OpView, Value]] = None,
|
||||
*,
|
||||
grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
|
||||
generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if isinstance(result_type_or_target, Type):
|
||||
result_type = result_type_or_target
|
||||
target = target_or_none
|
||||
else:
|
||||
result_type = transform.AnyOpType.get()
|
||||
target = result_type_or_target
|
||||
|
||||
super().__init__(
|
||||
result_type,
|
||||
target,
|
||||
grid_dims=grid_dims,
|
||||
generate_gpu_launch=generate_gpu_launch,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class MapNestedForallToThreads(MapNestedForallToThreads):
|
||||
"""Specialization for MapNestedForallToThreads class."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, OpView, Value],
|
||||
*,
|
||||
block_dims: Optional[Sequence[int]] = None,
|
||||
warp_size: Optional[Sequence[int]] = None,
|
||||
sync_after_distribute: Optional[bool] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, OpView, Value],
|
||||
*,
|
||||
block_dims: Optional[Sequence[int]] = None,
|
||||
warp_size: Optional[Sequence[int]] = None,
|
||||
sync_after_distribute: Optional[bool] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
result_type_or_target: Union[Operation, OpView, Value, Type],
|
||||
target_or_none: Optional[Union[Operation, OpView, Value]] = None,
|
||||
*,
|
||||
block_dims: Optional[Union[Sequence[int], Attribute]] = None,
|
||||
warp_size: Optional[Union[Sequence[int], Attribute]] = None,
|
||||
sync_after_distribute: Optional[bool] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if isinstance(result_type_or_target, Type):
|
||||
result_type = result_type_or_target
|
||||
target = target_or_none
|
||||
else:
|
||||
result_type = result_type_or_target.type
|
||||
target = result_type_or_target
|
||||
super().__init__(
|
||||
result_type,
|
||||
target,
|
||||
block_dims=block_dims,
|
||||
warp_size=warp_size,
|
||||
sync_after_distribute=sync_after_distribute,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
@@ -3,3 +3,143 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from .._loop_transform_ops_gen import *
|
||||
from .._loop_transform_ops_gen import _Dialect
|
||||
|
||||
try:
|
||||
from ...ir import *
|
||||
from .._ods_common import (
|
||||
get_op_result_or_value as _get_op_result_or_value,
|
||||
_cext as _ods_cext,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class GetParentForOp(GetParentForOp):
|
||||
"""Extension for GetParentForOp."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
num_loops: Optional[int] = None,
|
||||
ip=None,
|
||||
loc=None,
|
||||
):
|
||||
if num_loops is None:
|
||||
num_loops = 1
|
||||
super().__init__(
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
num_loops=num_loops,
|
||||
ip=ip,
|
||||
loc=loc,
|
||||
)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class LoopOutlineOp(LoopOutlineOp):
|
||||
"""Extension for LoopOutlineOp."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
function_type: Type,
|
||||
call_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
func_name: Union[str, StringAttr],
|
||||
ip=None,
|
||||
loc=None,
|
||||
):
|
||||
super().__init__(
|
||||
function_type,
|
||||
call_type,
|
||||
_get_op_result_or_value(target),
|
||||
func_name=(
|
||||
func_name
|
||||
if isinstance(func_name, StringAttr)
|
||||
else StringAttr.get(func_name)
|
||||
),
|
||||
ip=ip,
|
||||
loc=loc,
|
||||
)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class LoopPeelOp(LoopPeelOp):
|
||||
"""Extension for LoopPeelOp."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
main_loop_type: Type,
|
||||
remainder_loop_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
fail_if_already_divisible: Union[bool, BoolAttr] = False,
|
||||
ip=None,
|
||||
loc=None,
|
||||
):
|
||||
super().__init__(
|
||||
main_loop_type,
|
||||
remainder_loop_type,
|
||||
_get_op_result_or_value(target),
|
||||
fail_if_already_divisible=(
|
||||
fail_if_already_divisible
|
||||
if isinstance(fail_if_already_divisible, BoolAttr)
|
||||
else BoolAttr.get(fail_if_already_divisible)
|
||||
),
|
||||
ip=ip,
|
||||
loc=loc,
|
||||
)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class LoopPipelineOp(LoopPipelineOp):
|
||||
"""Extension for LoopPipelineOp."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
iteration_interval: Optional[Union[int, IntegerAttr]] = None,
|
||||
read_latency: Optional[Union[int, IntegerAttr]] = None,
|
||||
ip=None,
|
||||
loc=None,
|
||||
):
|
||||
if iteration_interval is None:
|
||||
iteration_interval = 1
|
||||
if read_latency is None:
|
||||
read_latency = 10
|
||||
super().__init__(
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
iteration_interval=iteration_interval,
|
||||
read_latency=read_latency,
|
||||
ip=ip,
|
||||
loc=loc,
|
||||
)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class LoopUnrollOp(LoopUnrollOp):
|
||||
"""Extension for LoopUnrollOp."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
factor: Union[int, IntegerAttr],
|
||||
ip=None,
|
||||
loc=None,
|
||||
):
|
||||
super().__init__(
|
||||
_get_op_result_or_value(target),
|
||||
factor=factor,
|
||||
ip=ip,
|
||||
loc=loc,
|
||||
)
|
||||
|
||||
@@ -3,3 +3,118 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from .._memref_transform_ops_gen import *
|
||||
from .._memref_transform_ops_gen import _Dialect
|
||||
|
||||
try:
|
||||
from ...ir import *
|
||||
from ...dialects import transform
|
||||
from .._ods_common import _cext as _ods_cext
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Optional, overload, Union
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class MemRefAllocaToGlobalOp(MemRefAllocaToGlobalOp):
|
||||
"""Specialization for MemRefAllocaToGlobalOp class."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
get_global_type: Type,
|
||||
global_type: Type,
|
||||
alloca: Union[Operation, OpView, Value],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(self, alloca: Union[Operation, OpView, Value], *, loc=None, ip=None):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
get_global_type_or_alloca: Union[Operation, OpView, Type, Value],
|
||||
global_type_or_none: Optional[Type] = None,
|
||||
alloca_or_none: Optional[Union[Operation, OpView, Value]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if isinstance(get_global_type_or_alloca, Type):
|
||||
get_global_type = get_global_type_or_alloca
|
||||
global_type = global_type_or_none
|
||||
alloca = alloca_or_none
|
||||
else:
|
||||
get_global_type = transform.AnyOpType.get()
|
||||
global_type = transform.AnyOpType.get()
|
||||
alloca = get_global_type_or_alloca
|
||||
|
||||
super().__init__(
|
||||
get_global_type,
|
||||
global_type,
|
||||
alloca,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class MemRefMultiBufferOp(MemRefMultiBufferOp):
|
||||
"""Specialization for MemRefMultiBufferOp class."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
transformed_type: Type,
|
||||
target: Union[Operation, OpView, Value],
|
||||
factor: Union[int, IntegerAttr],
|
||||
*,
|
||||
skip_analysis: Optional[bool] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, OpView, Value],
|
||||
factor: Union[int, IntegerAttr],
|
||||
*,
|
||||
skip_analysis: Optional[bool] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformed_type_or_target: Type,
|
||||
target_or_factor: Union[int, IntegerAttr, Operation, OpView, Value] = None,
|
||||
factor_or_none: Optional[Union[int, IntegerAttr]] = None,
|
||||
*,
|
||||
skip_analysis: Optional[bool] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if isinstance(transformed_type_or_target, Type):
|
||||
transformed_type = transformed_type_or_target
|
||||
target = target_or_factor
|
||||
factor = factor_or_none
|
||||
else:
|
||||
transformed_type = transform.AnyOpType.get()
|
||||
target = transformed_type_or_target
|
||||
factor = target_or_factor
|
||||
|
||||
super().__init__(
|
||||
transformed_type,
|
||||
target,
|
||||
factor,
|
||||
skip_analysis=skip_analysis,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
@@ -3,3 +3,53 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from .._transform_pdl_extension_ops_gen import *
|
||||
from .._transform_pdl_extension_ops_gen import _Dialect
|
||||
|
||||
try:
|
||||
from ...ir import *
|
||||
from .._ods_common import (
|
||||
get_op_result_or_value as _get_op_result_or_value,
|
||||
get_op_results_or_values as _get_op_results_or_values,
|
||||
_cext as _ods_cext,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Union
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class PDLMatchOp(PDLMatchOp):
|
||||
def __init__(
|
||||
self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
pattern_name: Union[Attribute, str],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
super().__init__(
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
pattern_name,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class WithPDLPatternsOp(WithPDLPatternsOp):
|
||||
def __init__(self, target: Union[Operation, Value, Type], *, loc=None, ip=None):
|
||||
root = _get_op_result_or_value(target) if not isinstance(target, Type) else None
|
||||
root_type = target if isinstance(target, Type) else root.type
|
||||
super().__init__(root=root, loc=loc, ip=ip)
|
||||
self.regions[0].blocks.append(root_type)
|
||||
|
||||
@property
|
||||
def body(self) -> Block:
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
@property
|
||||
def bodyTarget(self) -> Value:
|
||||
return self.body.arguments[0]
|
||||
|
||||
@@ -3,4 +3,777 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from .._structured_transform_ops_gen import *
|
||||
from .._structured_transform_ops_gen import _Dialect
|
||||
from .._structured_transform_enum_gen import *
|
||||
|
||||
try:
|
||||
from ...ir import *
|
||||
from ...dialects import transform
|
||||
from .._ods_common import _cext as _ods_cext
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import List, Optional, Sequence, Tuple, Union, overload
|
||||
|
||||
StaticIntLike = Union[int, IntegerAttr]
|
||||
ValueLike = Union[Operation, OpView, Value]
|
||||
MixedInt = Union[StaticIntLike, ValueLike]
|
||||
|
||||
IntOrAttrList = Sequence[Union[IntegerAttr, int]]
|
||||
OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
|
||||
|
||||
BoolOrAttrList = Sequence[Union[BoolAttr, bool]]
|
||||
OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]]
|
||||
|
||||
MixedValues = Union[Sequence[Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]
|
||||
|
||||
DynamicIndexList = Sequence[Union[MixedInt, Sequence[MixedInt]]]
|
||||
|
||||
|
||||
def _dispatch_dynamic_index_list(
|
||||
indices: Union[DynamicIndexList, ArrayAttr],
|
||||
) -> Tuple[List[ValueLike], Union[List[int], ArrayAttr], List[bool]]:
|
||||
"""Dispatches a list of indices to the appropriate form.
|
||||
|
||||
This is similar to the custom `DynamicIndexList` directive upstream:
|
||||
provided indices may be in the form of dynamic SSA values or static values,
|
||||
and they may be scalable (i.e., as a singleton list) or not. This function
|
||||
dispatches each index into its respective form. It also extracts the SSA
|
||||
values and static indices from various similar structures, respectively.
|
||||
"""
|
||||
dynamic_indices = []
|
||||
static_indices = [ShapedType.get_dynamic_size()] * len(indices)
|
||||
scalable_indices = [False] * len(indices)
|
||||
|
||||
# ArrayAttr: Extract index values.
|
||||
if isinstance(indices, ArrayAttr):
|
||||
indices = [idx for idx in indices]
|
||||
|
||||
def process_nonscalable_index(i, index):
|
||||
"""Processes any form of non-scalable index.
|
||||
|
||||
Returns False if the given index was scalable and thus remains
|
||||
unprocessed; True otherwise.
|
||||
"""
|
||||
if isinstance(index, int):
|
||||
static_indices[i] = index
|
||||
elif isinstance(index, IntegerAttr):
|
||||
static_indices[i] = index.value # pytype: disable=attribute-error
|
||||
elif isinstance(index, (Operation, Value, OpView)):
|
||||
dynamic_indices.append(index)
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
# Process each index at a time.
|
||||
for i, index in enumerate(indices):
|
||||
if not process_nonscalable_index(i, index):
|
||||
# If it wasn't processed, it must be a scalable index, which is
|
||||
# provided as a Sequence of one value, so extract and process that.
|
||||
scalable_indices[i] = True
|
||||
assert len(index) == 1
|
||||
ret = process_nonscalable_index(i, index[0])
|
||||
assert ret
|
||||
|
||||
return dynamic_indices, static_indices, scalable_indices
|
||||
|
||||
|
||||
# Dispatches `MixedValues` that all represents integers in various forms into
|
||||
# the following three categories:
|
||||
# - `dynamic_values`: a list of `Value`s, potentially from op results;
|
||||
# - `packed_values`: a value handle, potentially from an op result, associated
|
||||
# to one or more payload operations of integer type;
|
||||
# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python
|
||||
# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`.
|
||||
# The input is in the form for `packed_values`, only that result is set and the
|
||||
# other two are empty. Otherwise, the input can be a mix of the other two forms,
|
||||
# and for each dynamic value, a special value is added to the `static_values`.
|
||||
def _dispatch_mixed_values(
|
||||
values: MixedValues,
|
||||
) -> Tuple[List[Value], Union[Operation, Value, OpView], DenseI64ArrayAttr]:
|
||||
dynamic_values = []
|
||||
packed_values = None
|
||||
static_values = None
|
||||
if isinstance(values, ArrayAttr):
|
||||
static_values = values
|
||||
elif isinstance(values, (Operation, Value, OpView)):
|
||||
packed_values = values
|
||||
else:
|
||||
static_values = []
|
||||
for size in values or []:
|
||||
if isinstance(size, int):
|
||||
static_values.append(size)
|
||||
else:
|
||||
static_values.append(ShapedType.get_dynamic_size())
|
||||
dynamic_values.append(size)
|
||||
static_values = DenseI64ArrayAttr.get(static_values)
|
||||
|
||||
return (dynamic_values, packed_values, static_values)
|
||||
|
||||
|
||||
def _get_value_or_attribute_value(
|
||||
value_or_attr: Union[any, Attribute, ArrayAttr]
|
||||
) -> any:
|
||||
if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
|
||||
return value_or_attr.value
|
||||
if isinstance(value_or_attr, ArrayAttr):
|
||||
return _get_value_list(value_or_attr)
|
||||
return value_or_attr
|
||||
|
||||
|
||||
def _get_value_list(
|
||||
sequence_or_array_attr: Union[Sequence[any], ArrayAttr]
|
||||
) -> Sequence[any]:
|
||||
return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]
|
||||
|
||||
|
||||
def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr:
|
||||
if values is None:
|
||||
return None
|
||||
|
||||
# Turn into a Python list of Python ints.
|
||||
values = _get_value_list(values)
|
||||
|
||||
# Make an ArrayAttr of IntegerAttrs out of it.
|
||||
return ArrayAttr.get(
|
||||
[IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
|
||||
)
|
||||
|
||||
|
||||
def _get_int_array_array_attr(
|
||||
values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
|
||||
) -> ArrayAttr:
|
||||
"""Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.
|
||||
|
||||
The input has to be a collection of collection of integers, where any
|
||||
Python Sequence and ArrayAttr are admissible collections and Python ints and
|
||||
any IntegerAttr are admissible integers. Both levels of collections are
|
||||
turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
|
||||
If the input is None, an empty ArrayAttr is returned.
|
||||
"""
|
||||
if values is None:
|
||||
return None
|
||||
|
||||
# Make sure the outer level is a list.
|
||||
values = _get_value_list(values)
|
||||
|
||||
# The inner level is now either invalid or a mixed sequence of ArrayAttrs and
|
||||
# Sequences. Make sure the nested values are all lists.
|
||||
values = [_get_value_list(nested) for nested in values]
|
||||
|
||||
# Turn each nested list into an ArrayAttr.
|
||||
values = [_get_int_array_attr(nested) for nested in values]
|
||||
|
||||
# Turn the outer list into an ArrayAttr.
|
||||
return ArrayAttr.get(values)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class BufferizeToAllocationOp(BufferizeToAllocationOp):
|
||||
"""Specialization for BufferizeToAllocationOp class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, OpView, Value],
|
||||
*,
|
||||
memory_space: Optional[Union[int, str, Attribute]] = None,
|
||||
memcpy_op: Optional[str] = None,
|
||||
alloc_op: Optional[str] = None,
|
||||
bufferize_destination_only: Optional[bool] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
# No other types are allowed, so hard-code those here.
|
||||
allocated_buffer_type = transform.AnyValueType.get()
|
||||
new_ops_type = transform.AnyOpType.get()
|
||||
|
||||
if isinstance(memory_space, int):
|
||||
memory_space = str(memory_space)
|
||||
if isinstance(memory_space, str):
|
||||
memory_space = Attribute.parse(memory_space)
|
||||
|
||||
super().__init__(
|
||||
allocated_buffer_type,
|
||||
new_ops_type,
|
||||
target,
|
||||
memory_space=memory_space,
|
||||
memcpy_op=memcpy_op,
|
||||
alloc_op=alloc_op,
|
||||
bufferize_destination_only=bufferize_destination_only,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class DecomposeOp(DecomposeOp):
|
||||
"""Specialization for DecomposeOp class."""
|
||||
|
||||
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
|
||||
transformed_type = transform.AnyOpType.get()
|
||||
super().__init__(transformed_type, target, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class FuseIntoContainingOp(FuseIntoContainingOp):
|
||||
"""Specialization for FuseIntoContainingOp class."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
fused_op_type: Type,
|
||||
new_containing_op_type: Type,
|
||||
producer_op: Union[Operation, OpView, Value],
|
||||
containing_op: Union[Operation, OpView, Value],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
producer_op: Union[Operation, OpView, Value],
|
||||
containing_op: Union[Operation, OpView, Value],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fused_op_type_or_producer_op: Union[Operation, OpView, Type, Value],
|
||||
new_containing_op_type_or_containing_op: Union[Operation, OpView, Type, Value],
|
||||
producer_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
|
||||
containing_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if isinstance(fused_op_type_or_producer_op, Type):
|
||||
if not isinstance(new_containing_op_type_or_containing_op, Type):
|
||||
raise TypeError(
|
||||
"If 'fused_op_type_or_producer_op' is a type, then "
|
||||
"'new_containing_op_type_or_containing_op' is expected "
|
||||
"to be one as well."
|
||||
)
|
||||
fused_op_type = fused_op_type_or_producer_op
|
||||
new_containing_op_type = new_containing_op_type_or_containing_op
|
||||
producer_op = producer_op_or_none
|
||||
containing_op = containing_op_or_none
|
||||
else:
|
||||
fused_op_type = transform.AnyOpType.get()
|
||||
new_containing_op_type = transform.AnyOpType.get()
|
||||
producer_op = fused_op_type_or_producer_op
|
||||
containing_op = new_containing_op_type_or_containing_op
|
||||
|
||||
super().__init__(
|
||||
fused_op_type,
|
||||
new_containing_op_type,
|
||||
producer_op,
|
||||
containing_op,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class GeneralizeOp(GeneralizeOp):
|
||||
"""Specialization for GeneralizeOp class."""
|
||||
|
||||
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
|
||||
transformed_type = transform.AnyOpType.get()
|
||||
super().__init__(transformed_type, target, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class InterchangeOp(InterchangeOp):
|
||||
"""Specialization for InterchangeOp class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
iterator_interchange: OptionalIntList = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
transformed_type = transform.AnyOpType.get()
|
||||
super().__init__(
|
||||
transformed_type,
|
||||
target,
|
||||
iterator_interchange=iterator_interchange,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class MapCopyToThreadsOp(MapCopyToThreadsOp):
|
||||
"""Specialization for MapCopyToThreadsOp class."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
forall_op_type: Type,
|
||||
tiled_op_type: Type,
|
||||
target: Union[Operation, OpView, Value],
|
||||
*,
|
||||
total_num_threads: Union[int, IntegerAttr],
|
||||
desired_bit_alignment: Union[int, IntegerAttr],
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, OpView, Value],
|
||||
*,
|
||||
total_num_threads: Union[int, IntegerAttr],
|
||||
desired_bit_alignment: Union[int, IntegerAttr],
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
forall_op_type_or_target: Union[Operation, OpView, Type, Value],
|
||||
tiled_op_type_or_none: Optional[Type] = None,
|
||||
target_or_none: Optional[Union[Operation, OpView, Value]] = None,
|
||||
*,
|
||||
total_num_threads: Union[int, IntegerAttr],
|
||||
desired_bit_alignment: Union[int, IntegerAttr],
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if isinstance(forall_op_type_or_target, Type):
|
||||
forall_op_type = forall_op_type_or_target
|
||||
tiled_op_type = tiled_op_type_or_none
|
||||
target = target_or_none
|
||||
else:
|
||||
forall_op_type = transform.AnyOpType.get()
|
||||
tiled_op_type = transform.AnyOpType.get()
|
||||
target = forall_op_type_or_target
|
||||
|
||||
super().__init__(
|
||||
forall_op_type,
|
||||
tiled_op_type,
|
||||
target,
|
||||
total_num_threads=total_num_threads,
|
||||
desired_bit_alignment=desired_bit_alignment,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class VectorizeOp(VectorizeOp):
|
||||
"""Specialization for VectorizeOp class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, OpView, Value],
|
||||
vector_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
|
||||
*,
|
||||
vectorize_nd_extract: Optional[bool] = None,
|
||||
scalable_sizes: OptionalBoolList = None,
|
||||
static_vector_sizes: OptionalIntList = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if (
|
||||
scalable_sizes is None
|
||||
and static_vector_sizes is None
|
||||
and vector_sizes is None
|
||||
):
|
||||
dynamic_vector_sizes = []
|
||||
elif scalable_sizes is None and static_vector_sizes is None:
|
||||
(
|
||||
dynamic_vector_sizes,
|
||||
static_vector_sizes,
|
||||
scalable_sizes,
|
||||
) = _dispatch_dynamic_index_list(vector_sizes)
|
||||
elif scalable_sizes is None or static_vector_sizes is None:
|
||||
raise TypeError(
|
||||
"'scalable_sizes' and 'static_vector_sizes' must either both "
|
||||
"be given explicitly or both be given as part of 'vector_sizes'."
|
||||
)
|
||||
else:
|
||||
dynamic_vector_sizes = vector_sizes
|
||||
|
||||
super().__init__(
|
||||
target,
|
||||
vector_sizes=dynamic_vector_sizes,
|
||||
static_vector_sizes=static_vector_sizes,
|
||||
scalable_sizes=scalable_sizes,
|
||||
vectorize_nd_extract=vectorize_nd_extract,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class MatchOp(MatchOp):
|
||||
"""Specialization for MatchOp class."""
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def match_op_names(
|
||||
cls,
|
||||
target: Union[Operation, Value],
|
||||
names: Union[str, Sequence[str]],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def match_op_names(
|
||||
cls,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
names: Union[str, Sequence[str]],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def match_op_names(
|
||||
cls,
|
||||
result_type_or_target: Union[Type, Operation, Value],
|
||||
target_or_names: Union[Operation, Value, Sequence[str], str],
|
||||
names_or_none: Optional[Union[Sequence[str], str]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if isinstance(result_type_or_target, Type):
|
||||
result_type = result_type_or_target
|
||||
target = target_or_names
|
||||
names = names_or_none
|
||||
else:
|
||||
result_type = transform.AnyOpType.get()
|
||||
target = result_type_or_target
|
||||
names = target_or_names
|
||||
|
||||
if isinstance(names, str):
|
||||
names = [names]
|
||||
|
||||
return cls(
|
||||
result_type,
|
||||
target,
|
||||
ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class MultiTileSizesOp(MultiTileSizesOp):
|
||||
"""Specialization for MultiTileSizesOp class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
dimension: Union[int, IntegerAttr],
|
||||
target_size: Union[int, IntegerAttr],
|
||||
divisor: Optional[Optional[Union[int, IntegerAttr]]] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
super().__init__(
|
||||
result_type,
|
||||
result_type,
|
||||
result_type,
|
||||
target,
|
||||
dimension=dimension,
|
||||
target_size=target_size,
|
||||
divisor=divisor,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class PadOp(PadOp):
|
||||
"""Specialization for PadOp class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, OpView, Value],
|
||||
*,
|
||||
padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
|
||||
padding_dimensions: OptionalIntList = None,
|
||||
pad_to_multiple_of: OptionalIntList = None,
|
||||
pack_paddings: OptionalIntList = None,
|
||||
transpose_paddings: Optional[
|
||||
Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]
|
||||
] = None,
|
||||
copy_back_op: Optional[Union[str, StringAttr]] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
transpose_paddings = _get_int_array_array_attr(transpose_paddings)
|
||||
|
||||
any_op_type = transform.AnyOpType.get()
|
||||
super().__init__(
|
||||
any_op_type,
|
||||
any_op_type,
|
||||
any_op_type,
|
||||
target,
|
||||
padding_values=padding_values,
|
||||
padding_dimensions=padding_dimensions,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
pack_paddings=pack_paddings,
|
||||
transpose_paddings=transpose_paddings,
|
||||
copy_back_op=copy_back_op,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class ScalarizeOp(ScalarizeOp):
|
||||
"""Specialization for ScalarizeOp class."""
|
||||
|
||||
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
|
||||
result_type = transform.AnyOpType.get()
|
||||
super().__init__(result_type, target, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class SplitOp(SplitOp):
|
||||
"""Specialization for SplitOp class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, Value],
|
||||
dimension: Union[int, Attribute],
|
||||
split_point: Union[int, Operation, Value, Attribute],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if isinstance(split_point, int):
|
||||
static_split_point = split_point
|
||||
dynamic_split_point = None
|
||||
else:
|
||||
static_split_point = ShapedType.get_dynamic_size()
|
||||
dynamic_split_point = split_point
|
||||
|
||||
super().__init__(
|
||||
target.type,
|
||||
target.type,
|
||||
target,
|
||||
dimension=dimension,
|
||||
static_split_point=static_split_point,
|
||||
dynamic_split_point=dynamic_split_point,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class TileUsingForOp(TileUsingForOp):
|
||||
"""Specialization for TileUsingForOp class."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
loop_types: Union[Type, List[Type]],
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
|
||||
interchange: OptionalIntList = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, Value, OpView],
|
||||
*,
|
||||
sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
|
||||
interchange: OptionalIntList = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
loop_types_or_target: Union[Type, List[Type], Operation, Value],
|
||||
target_or_none: Optional[Union[Operation, Value, OpView]] = None,
|
||||
*,
|
||||
sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
|
||||
interchange: OptionalIntList = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
(
|
||||
dynamic_sizes,
|
||||
static_sizes,
|
||||
scalable_sizes,
|
||||
) = _dispatch_dynamic_index_list(sizes)
|
||||
|
||||
num_loops = sum(v if v == 0 else 1 for v in static_sizes)
|
||||
|
||||
if isinstance(loop_types_or_target, (Operation, Value, OpView)):
|
||||
loop_types = [transform.AnyOpType.get()] * num_loops
|
||||
target = loop_types_or_target
|
||||
assert (
|
||||
target_or_none is None
|
||||
), "Cannot construct TileUsingForOp with two targets."
|
||||
else:
|
||||
loop_types = (
|
||||
([loop_types_or_target] * num_loops)
|
||||
if isinstance(loop_types_or_target, Type)
|
||||
else loop_types_or_target
|
||||
)
|
||||
target = target_or_none
|
||||
|
||||
super().__init__(
|
||||
target.type,
|
||||
loop_types,
|
||||
target,
|
||||
dynamic_sizes=dynamic_sizes,
|
||||
static_sizes=static_sizes,
|
||||
interchange=interchange,
|
||||
scalable_sizes=scalable_sizes,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class TileUsingForallOp(TileUsingForallOp):
|
||||
"""Specialization for TileUsingForallOp class."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
loops_type: Type,
|
||||
tiled_op_type: Type,
|
||||
target: Union[Operation, Value, OpView],
|
||||
*,
|
||||
num_threads: Optional[MixedValues] = None,
|
||||
tile_sizes: MixedValues = None,
|
||||
mapping=None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, Value, OpView],
|
||||
*,
|
||||
num_threads: Optional[MixedValues] = None,
|
||||
tile_sizes: MixedValues = None,
|
||||
mapping=None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
loops_type_or_target: Union[
|
||||
Type, Union[Operation, Value, OpView] # loops_type
|
||||
], # target
|
||||
tiled_op_type_or_none: Optional[Type] = None,
|
||||
target_or_none: Optional[Union[Operation, Value, OpView]] = None,
|
||||
*,
|
||||
num_threads: MixedValues = None,
|
||||
tile_sizes: MixedValues = None,
|
||||
mapping=None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
# `Type` arguments in the front are optional: add default values to front.
|
||||
if isinstance(loops_type_or_target, Type):
|
||||
# First overload: type arguments provided.
|
||||
if not isinstance(tiled_op_type_or_none, Type):
|
||||
raise TypeError(
|
||||
"If 'loops_type_or_target' is a type, then "
|
||||
"'tiled_op_type_or_none' is expected to be one as well."
|
||||
)
|
||||
loops_type = loops_type_or_target
|
||||
tiled_op_type = tiled_op_type_or_none
|
||||
target = target_or_none
|
||||
else:
|
||||
# Last overload: type arguments missing.
|
||||
loops_type = transform.AnyOpType.get()
|
||||
tiled_op_type = transform.AnyOpType.get()
|
||||
target = loops_type_or_target
|
||||
|
||||
# Unpack mixed num_threads.
|
||||
(
|
||||
dynamic_num_threads,
|
||||
packed_num_threads,
|
||||
num_threads_attr,
|
||||
) = _dispatch_mixed_values(num_threads)
|
||||
|
||||
# Unpack mixed tile_sizes.
|
||||
(
|
||||
dynamic_tile_sizes,
|
||||
packed_tile_sizes,
|
||||
tile_sizes_attr,
|
||||
) = _dispatch_mixed_values(tile_sizes)
|
||||
|
||||
super().__init__(
|
||||
loops_type,
|
||||
tiled_op_type,
|
||||
target=target,
|
||||
tile_sizes=dynamic_tile_sizes,
|
||||
packed_tile_sizes=packed_tile_sizes,
|
||||
static_tile_sizes=tile_sizes_attr,
|
||||
num_threads=dynamic_num_threads,
|
||||
packed_num_threads=packed_num_threads,
|
||||
static_num_threads=num_threads_attr,
|
||||
mapping=mapping,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class VectorizeChildrenAndApplyPatternsOp(VectorizeChildrenAndApplyPatternsOp):
|
||||
"""Specialization for VectorizeChildrenAndApplyPatternsOp class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
disable_multi_reduction_to_contract_patterns: bool = False,
|
||||
disable_transfer_permutation_map_lowering_patterns: bool = False,
|
||||
vectorize_nd_extract: bool = False,
|
||||
vectorize_padding: bool = False,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
transformed_type = transform.AnyOpType.get()
|
||||
super().__init__(
|
||||
transformed_type,
|
||||
target,
|
||||
disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns,
|
||||
disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns,
|
||||
vectorize_nd_extract=vectorize_nd_extract,
|
||||
vectorize_padding=vectorize_padding,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
@@ -3,3 +3,67 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from .._tensor_transform_ops_gen import *
|
||||
from .._tensor_transform_ops_gen import _Dialect
|
||||
|
||||
try:
|
||||
from ...ir import *
|
||||
from ...dialects import transform
|
||||
from .._ods_common import _cext as _ods_cext
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Optional, overload, Union
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class MakeLoopIndependentOp(MakeLoopIndependentOp):
|
||||
"""Specialization for MakeLoopIndependentOp class."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
transformed_type: Type,
|
||||
target: Union[Operation, OpView, Value],
|
||||
num_loops: Union[int, IntegerAttr],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, OpView, Value],
|
||||
num_loops: Union[int, IntegerAttr],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformed_type_or_target: Type,
|
||||
target_or_num_loops: Union[int, IntegerAttr, Operation, OpView, Value] = None,
|
||||
num_loops_or_none: Optional[Union[int, IntegerAttr]] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if isinstance(transformed_type_or_target, Type):
|
||||
transformed_type = transformed_type_or_target
|
||||
target = target_or_num_loops
|
||||
num_loops = num_loops_or_none
|
||||
else:
|
||||
transformed_type = transform.AnyOpType.get()
|
||||
target = transformed_type_or_target
|
||||
num_loops = target_or_num_loops
|
||||
|
||||
super().__init__(
|
||||
transformed_type,
|
||||
target,
|
||||
num_loops,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
@@ -114,6 +114,7 @@ def get_unranked_memref_descriptor(nparray):
|
||||
d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
|
||||
return d
|
||||
|
||||
|
||||
def move_aligned_ptr_by_offset(aligned_ptr, offset):
|
||||
"""Moves the supplied ctypes pointer ahead by `offset` elements."""
|
||||
aligned_addr = ctypes.addressof(aligned_ptr.contents)
|
||||
@@ -122,6 +123,7 @@ def move_aligned_ptr_by_offset(aligned_ptr, offset):
|
||||
content_ptr = ctypes.cast(aligned_addr + shift, type(aligned_ptr))
|
||||
return content_ptr
|
||||
|
||||
|
||||
def unranked_memref_to_numpy(unranked_memref, np_dtype):
|
||||
"""Converts unranked memrefs to numpy arrays."""
|
||||
ctp = as_ctype(np_dtype)
|
||||
@@ -139,10 +141,10 @@ def unranked_memref_to_numpy(unranked_memref, np_dtype):
|
||||
|
||||
def ranked_memref_to_numpy(ranked_memref):
|
||||
"""Converts ranked memrefs to numpy arrays."""
|
||||
content_ptr = move_aligned_ptr_by_offset(ranked_memref[0].aligned, ranked_memref[0].offset)
|
||||
np_arr = np.ctypeslib.as_array(
|
||||
content_ptr, shape=ranked_memref[0].shape
|
||||
content_ptr = move_aligned_ptr_by_offset(
|
||||
ranked_memref[0].aligned, ranked_memref[0].offset
|
||||
)
|
||||
np_arr = np.ctypeslib.as_array(content_ptr, shape=ranked_memref[0].shape)
|
||||
strided_arr = np.lib.stride_tricks.as_strided(
|
||||
np_arr,
|
||||
np.ctypeslib.as_array(ranked_memref[0].shape),
|
||||
|
||||
@@ -33,3 +33,16 @@ def testFastMathFlags():
|
||||
)
|
||||
# CHECK: %0 = arith.addf %cst, %cst fastmath<nnan,ninf> : f32
|
||||
print(r)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testArithValueBuilder
|
||||
@run
|
||||
def testArithValueBuilder():
|
||||
with Context() as ctx, Location.unknown():
|
||||
module = Module.create()
|
||||
f32_t = F32Type.get()
|
||||
|
||||
with InsertionPoint(module.body):
|
||||
a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
|
||||
# CHECK: %cst = arith.constant 4.242000e+01 : f32
|
||||
print(a)
|
||||
|
||||
@@ -30,14 +30,9 @@ constexpr const char *fileHeader = R"Py(
|
||||
# Autogenerated by mlir-tblgen; don't manually edit.
|
||||
|
||||
from ._ods_common import _cext as _ods_cext
|
||||
from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results
|
||||
from ._ods_common import segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results
|
||||
_ods_ir = _ods_cext.ir
|
||||
|
||||
try:
|
||||
from . import _{0}_ops_ext as _ods_ext_module
|
||||
except ImportError:
|
||||
_ods_ext_module = None
|
||||
|
||||
import builtins
|
||||
from typing import Sequence as _Sequence, Union as _Union
|
||||
|
||||
@@ -62,7 +57,6 @@ from ._{0}_ops_gen import _Dialect
|
||||
/// {1} is the operation name.
|
||||
constexpr const char *opClassTemplate = R"Py(
|
||||
@_ods_cext.register_operation(_Dialect)
|
||||
@_ods_extend_opview_class(_ods_ext_module)
|
||||
class {0}(_ods_ir.OpView):
|
||||
OPERATION_NAME = "{1}"
|
||||
)Py";
|
||||
@@ -301,17 +295,17 @@ static bool isODSReserved(StringRef str) {
|
||||
/// (does not change the `name` if it already is suitable) and returns the
|
||||
/// modified version.
|
||||
static std::string sanitizeName(StringRef name) {
|
||||
std::string processed_str = name.str();
|
||||
std::string processedStr = name.str();
|
||||
std::replace_if(
|
||||
processed_str.begin(), processed_str.end(),
|
||||
processedStr.begin(), processedStr.end(),
|
||||
[](char c) { return !llvm::isAlnum(c); }, '_');
|
||||
|
||||
if (llvm::isDigit(*processed_str.begin()))
|
||||
return "_" + processed_str;
|
||||
if (llvm::isDigit(*processedStr.begin()))
|
||||
return "_" + processedStr;
|
||||
|
||||
if (isPythonReserved(processed_str) || isODSReserved(processed_str))
|
||||
return processed_str + "_";
|
||||
return processed_str;
|
||||
if (isPythonReserved(processedStr) || isODSReserved(processedStr))
|
||||
return processedStr + "_";
|
||||
return processedStr;
|
||||
}
|
||||
|
||||
static std::string attrSizedTraitForKind(const char *kind) {
|
||||
@@ -853,10 +847,6 @@ populateBuilderRegions(const Operator &op,
|
||||
/// rebuild anew).
|
||||
static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
|
||||
raw_ostream &os) {
|
||||
// If we are asked to skip default builders, comply.
|
||||
if (op.skipDefaultBuilders())
|
||||
return {};
|
||||
|
||||
llvm::SmallVector<std::string> builderArgs;
|
||||
llvm::SmallVector<std::string> builderLines;
|
||||
llvm::SmallVector<std::string> operandArgNames;
|
||||
@@ -989,9 +979,6 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
|
||||
static void emitValueBuilder(const Operator &op,
|
||||
llvm::SmallVector<std::string> functionArgs,
|
||||
raw_ostream &os) {
|
||||
// If we are asked to skip default builders, comply.
|
||||
if (op.skipDefaultBuilders())
|
||||
return;
|
||||
// Params with (possibly) default args.
|
||||
auto valueBuilderParams =
|
||||
llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) {
|
||||
@@ -1010,9 +997,9 @@ static void emitValueBuilder(const Operator &op,
|
||||
auto lhs = *llvm::split(arg, "=").begin();
|
||||
return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
|
||||
});
|
||||
std::string name_without_dialect =
|
||||
std::string nameWithoutDialect =
|
||||
op.getOperationName().substr(op.getOperationName().find('.') + 1);
|
||||
os << llvm::formatv(valueBuilderTemplate, sanitizeName(name_without_dialect),
|
||||
os << llvm::formatv(valueBuilderTemplate, sanitizeName(nameWithoutDialect),
|
||||
op.getCppClassName(),
|
||||
llvm::join(valueBuilderParams, ", "),
|
||||
llvm::join(opBuilderArgs, ", "),
|
||||
@@ -1051,11 +1038,8 @@ static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
if (clDialectName.empty())
|
||||
llvm::PrintFatalError("dialect name not provided");
|
||||
|
||||
bool isExtension = !clDialectExtensionName.empty();
|
||||
os << llvm::formatv(fileHeader, isExtension
|
||||
? clDialectExtensionName.getValue()
|
||||
: clDialectName.getValue());
|
||||
if (isExtension)
|
||||
os << fileHeader;
|
||||
if (!clDialectExtensionName.empty())
|
||||
os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue());
|
||||
else
|
||||
os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
|
||||
|
||||
Reference in New Issue
Block a user