[mlir][python] remove mixins (#68853)

This PR replaces the mixin `OpView` extension mechanism with the
standard inheritance mechanism.

Why? Firstly, mixins are not very pythonic (inheritance is usually used
for this), a little convoluted, and too "tight" (can only be used in the
immediately adjacent `_ext.py`). Secondly, it (mixins) are now blocking
are correct implementation of "value builders" (see
[here](https://github.com/llvm/llvm-project/pull/68764)) where the
problem becomes how to choose the correct base class that the value
builder should call.

This PR looks big/complicated but appearances are deceiving; 4 things
were needed to make this work:

1. Drop `skipDefaultBuilders` in
`OpPythonBindingGen::emitDefaultOpBuilders`
2. Former mixin extension classes are converted to inherit from the
generated `OpView` instead of being "mixins"
a. extension classes that simply were calling into an already generated
`super().__init__` continue to do so
b. (almost all) extension classes that were calling `self.build_generic`
because of a lack of default builder being generated can now also just
call `super().__init__`
3. To handle the [lone single
use-case](https://sourcegraph.com/search?q=context%3Aglobal+select_opview_mixin&patternType=standard&sm=1&groupBy=repo)
of `select_opview_mixin`, namely
[linalg](https://github.com/llvm/llvm-project/blob/main/mlir/python/mlir/dialects/_linalg_ops_ext.py#L38),
only a small change was necessary in `opdsl/lang/emitter.py` (thanks to
the emission/generation of default builders/`__init__`s)
4. since the `extend_opview_class` decorator is removed, we need a way
to register extension classes as the desired `OpView` that `op.opview`
conjures into existence; so we do the standard thing and just enable
replacing the existing registered `OpView` i.e.,
`register_operation(_Dialect, replace=True)`.

Note, the upgrade path for the common case is to change an extension to
inherit from the generated builder and decorate it with
`register_operation(_Dialect, replace=True)`. In the slightly more
complicated case where `super().__init(self.build_generic(...))` is
called in the extension's `__init__`, this needs to be updated to call
`__init__` in `OpView`, i.e., the grandparent (see updated docs). 
Note, also `<DIALECT>_ext.py` files/modules will no longer be automatically loaded.

Note, the PR has 3 base commits that look funny but this was done for
the purpose of tracking the line history of moving the
`<DIALECT>_ops_ext.py` class into `<DIALECT>.py` and updating (commit
labeled "fix").
This commit is contained in:
Maksim Levental
2023-10-19 16:20:14 -05:00
committed by GitHub
parent a30095a1e4
commit a2288a8944
49 changed files with 2815 additions and 2921 deletions

View File

@@ -1017,90 +1017,79 @@ very generic signature.
#### Extending Generated Op Classes
Note that this is a rather complex mechanism and this section errs on the side
of explicitness. Users are encouraged to find an example and duplicate it if
they don't feel the need to understand the subtlety. The `builtin` dialect
provides some relatively simple examples.
As mentioned above, the build system generates Python sources like
`_{DIALECT_NAMESPACE}_ops_gen.py` for each dialect with Python bindings. It is
often desirable to to use these generated classes as a starting point for
further customization, so an extension mechanism is provided to make this easy
(you are always free to do ad-hoc patching in your `{DIALECT_NAMESPACE}.py` file
but we prefer a more standard mechanism that is applied uniformly).
To provide extensions, add a `_{DIALECT_NAMESPACE}_ops_ext.py` file to the
`dialects` module (i.e. adjacent to your `{DIALECT_NAMESPACE}.py` top-level and
the `*_ops_gen.py` file). Using the `builtin` dialect and `FuncOp` as an
example, the generated code will include an import like this:
often desirable to use these generated classes as a starting point for
further customization, so an extension mechanism is provided to make this easy.
This mechanism uses conventional inheritance combined with `OpView` registration.
For example, the default builder for `arith.constant`
```python
try:
from . import _builtin_ops_ext as _ods_ext_module
except ImportError:
_ods_ext_module = None
class ConstantOp(_ods_ir.OpView):
OPERATION_NAME = "arith.constant"
_ODS_REGIONS = (0, True)
def __init__(self, value, *, loc=None, ip=None):
...
```
Then for each generated concrete `OpView` subclass, it will apply a decorator
like:
expects `value` to be a `TypedAttr` (e.g., `IntegerAttr` or `FloatAttr`).
Thus, a natural extension is a builder that accepts a MLIR type and a Python value and instantiates the appropriate `TypedAttr`:
```python
@_ods_cext.register_operation(_Dialect)
@_ods_extend_opview_class(_ods_ext_module)
class FuncOp(_ods_ir.OpView):
from typing import Union
from mlir.ir import Type, IntegerAttr, FloatAttr
from mlir.dialects._arith_ops_gen import _Dialect, ConstantOp
from mlir.dialects._ods_common import _cext
@_cext.register_operation(_Dialect, replace=True)
class ConstantOpExt(ConstantOp):
def __init__(
self, result: Type, value: Union[int, float], *, loc=None, ip=None
):
if isinstance(value, int):
super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
elif isinstance(value, float):
super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
else:
raise NotImplementedError(f"Building `arith.constant` not supported for {result=} {value=}")
```
See the `_ods_common.py` `extend_opview_class` function for details of the
mechanism. At a high level:
* If the extension module exists, locate an extension class for the op (in
this example, `FuncOp`):
* First by looking for an attribute with the exact name in the extension
module.
* Falling back to calling a `select_opview_mixin(parent_opview_cls)`
function defined in the extension module.
* If a mixin class is found, a new subclass is dynamically created that
multiply inherits from `({_builtin_ops_ext.FuncOp},
_builtin_ops_gen.FuncOp)`.
The mixin class should not inherit from anything (i.e. directly extends `object`
only). The facility is typically used to define custom `__init__` methods,
properties, instance methods and static methods. Due to the inheritance
ordering, the mixin class can act as though it extends the generated `OpView`
subclass in most contexts (i.e. `issubclass(_builtin_ops_ext.FuncOp, OpView)`
will return `False` but usage generally allows you treat it as duck typed as an
`OpView`).
There are a couple of recommendations, given how the class hierarchy is defined:
* For static methods that need to instantiate the actual "leaf" op (which is
dynamically generated and would result in circular dependencies to try to
reference by name), prefer to use `@classmethod` and the concrete subclass
will be provided as your first `cls` argument. See
`_builtin_ops_ext.FuncOp.from_py_func` as an example.
* If seeking to replace the generated `__init__` method entirely, you may
actually want to invoke the super-super-class `mlir.ir.OpView` constructor
directly, as it takes an `mlir.ir.Operation`, which is likely what you are
constructing (i.e. the generated `__init__` method likely adds more API
constraints than you want to expose in a custom builder).
A pattern that comes up frequently is wanting to provide a sugared `__init__`
method which has optional or type-polymorphism/implicit conversions but to
otherwise want to invoke the default op building logic. For such cases, it is
recommended to use an idiom such as:
which enables building an instance of `arith.constant` like so:
```python
def __init__(self, sugar, spice, *, loc=None, ip=None):
... massage into result_type, operands, attributes ...
OpView.__init__(self, self.build_generic(
results=[result_type],
operands=operands,
attributes=attributes,
loc=loc,
ip=ip))
from mlir.ir import F32Type
a = ConstantOpExt(F32Type.get(), 42.42)
b = ConstantOpExt(IntegerType.get_signless(32), 42)
```
Refer to the documentation for `build_generic` for more information.
Note, three key aspects of the extension mechanism in this example:
1. `ConstantOpExt` directly inherits from the generated `ConstantOp`;
2. in this, simplest, case all that's required is a call to the super class' initializer, i.e., `super().__init__(...)`;
3. in order to register `ConstantOpExt` as the preferred `OpView` that is returned by `mlir.ir.Operation.opview` (see [Operations, Regions and Blocks](#operations-regions-and-blocks))
we decorate the class with `@_cext.register_operation(_Dialect, replace=True)`, **where the `replace=True` must be used**.
In some more complex cases it might be necessary to explicitly build the `OpView` through `OpView.build_generic` (see [Default Builder](#default-builder)), just as is performed by the generated builders.
I.e., we must call `OpView.build_generic` **and pass the result to `OpView.__init__`**, where the small issue becomes that the latter is already overridden by the generated builder.
Thus, we must call a method of a super class' super class (the "grandparent"); for example:
```python
from mlir.dialects._scf_ops_gen import _Dialect, ForOp
from mlir.dialects._ods_common import _cext
@_cext.register_operation(_Dialect, replace=True)
class ForOpExt(ForOp):
def __init__(self, lower_bound, upper_bound, step, iter_args, *, loc=None, ip=None):
...
super(ForOp, self).__init__(self.build_generic(...))
```
where `OpView.__init__` is called via `super(ForOp, self).__init__`.
Note, there are alternatives ways to implement this (e.g., explicitly writing `OpView.__init__`); see any discussion on Python inheritance.
## Providing Python bindings for a dialect

View File

@@ -77,10 +77,10 @@ public:
pybind11::object pyClass);
/// Adds a concrete implementation operation class.
/// Raises an exception if the mapping already exists.
/// Raises an exception if the mapping already exists and replace == false.
/// This is intended to be called by implementation code.
void registerOperationImpl(const std::string &operationName,
pybind11::object pyClass);
pybind11::object pyClass, bool replace = false);
/// Returns the custom Attribute builder for Attribute kind.
std::optional<pybind11::function>

View File

@@ -96,9 +96,9 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
}
void PyGlobals::registerOperationImpl(const std::string &operationName,
py::object pyClass) {
py::object pyClass, bool replace) {
py::object &found = operationClassMap[operationName];
if (found) {
if (found && !replace) {
throw std::runtime_error((llvm::Twine("Operation '") + operationName +
"' is already registered.")
.str());

View File

@@ -41,7 +41,7 @@ PYBIND11_MODULE(_mlir, m) {
"dialect_namespace"_a, "dialect_class"_a,
"Testing hook for directly registering a dialect")
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
"operation_name"_a, "operation_class"_a,
"operation_name"_a, "operation_class"_a, "replace"_a = false,
"Testing hook for directly registering an operation");
// Aside from making the globals accessible to python, having python manage
@@ -63,12 +63,13 @@ PYBIND11_MODULE(_mlir, m) {
"Class decorator for registering a custom Dialect wrapper");
m.def(
"register_operation",
[](const py::object &dialectClass) -> py::cpp_function {
[](const py::object &dialectClass, bool replace) -> py::cpp_function {
return py::cpp_function(
[dialectClass](py::object opClass) -> py::object {
[dialectClass, replace](py::object opClass) -> py::object {
std::string operationName =
opClass.attr("OPERATION_NAME").cast<std::string>();
PyGlobals::get().registerOperationImpl(operationName, opClass);
PyGlobals::get().registerOperationImpl(operationName, opClass,
replace);
// Dict-stuff the new opClass by name onto the dialect class.
py::object opClassName = opClass.attr("__name__");
@@ -76,7 +77,7 @@ PYBIND11_MODULE(_mlir, m) {
return opClass;
});
},
"dialect_class"_a,
"dialect_class"_a, "replace"_a = false,
"Produce a class decorator for registering an Operation class as part of "
"a dialect");
m.def(

View File

@@ -52,7 +52,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/AffineOps.td
SOURCES
dialects/affine.py
dialects/_affine_ops_ext.py
DIALECT_NAME affine
GEN_ENUM_BINDINGS)
@@ -78,7 +77,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/BufferizationOps.td
SOURCES
dialects/bufferization.py
dialects/_bufferization_ops_ext.py
DIALECT_NAME bufferization
GEN_ENUM_BINDINGS_TD_FILE
"../../include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td"
@@ -90,7 +88,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/BuiltinOps.td
SOURCES
dialects/builtin.py
dialects/_builtin_ops_ext.py
DIALECT_NAME builtin)
declare_mlir_dialect_python_bindings(
@@ -115,7 +112,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/FuncOps.td
SOURCES
dialects/func.py
dialects/_func_ops_ext.py
DIALECT_NAME func)
declare_mlir_dialect_python_bindings(
@@ -131,7 +127,6 @@ declare_mlir_dialect_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/LinalgOps.td
SOURCES
dialects/_linalg_ops_ext.py
SOURCES_GLOB
dialects/linalg/*.py
DIALECT_NAME linalg
@@ -152,7 +147,6 @@ ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TransformPDLExtensionOps.td
SOURCES
dialects/_transform_pdl_extension_ops_ext.py
dialects/transform/pdl.py
DIALECT_NAME transform
EXTENSION_NAME transform_pdl_extension)
@@ -162,7 +156,6 @@ declare_mlir_dialect_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TransformOps.td
SOURCES
dialects/_transform_ops_ext.py
dialects/transform/__init__.py
_mlir_libs/_mlir/dialects/transform/__init__.pyi
DIALECT_NAME transform
@@ -175,7 +168,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/BufferizationTransformOps.td
SOURCES
dialects/_bufferization_transform_ops_ext.py
dialects/transform/bufferization.py
DIALECT_NAME transform
EXTENSION_NAME bufferization_transform)
@@ -185,7 +177,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/GPUTransformOps.td
SOURCES
dialects/_gpu_transform_ops_ext.py
dialects/transform/gpu.py
DIALECT_NAME transform
EXTENSION_NAME gpu_transform)
@@ -195,7 +186,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/SCFLoopTransformOps.td
SOURCES
dialects/_loop_transform_ops_ext.py
dialects/transform/loop.py
DIALECT_NAME transform
EXTENSION_NAME loop_transform)
@@ -205,7 +195,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/MemRefTransformOps.td
SOURCES
dialects/_memref_transform_ops_ext.py
dialects/transform/memref.py
DIALECT_NAME transform
EXTENSION_NAME memref_transform)
@@ -224,7 +213,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/LinalgStructuredTransformOps.td
SOURCES
dialects/_structured_transform_ops_ext.py
dialects/transform/structured.py
DIALECT_NAME transform
EXTENSION_NAME structured_transform
@@ -246,7 +234,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TensorTransformOps.td
SOURCES
dialects/_tensor_transform_ops_ext.py
dialects/transform/tensor.py
DIALECT_NAME transform
EXTENSION_NAME tensor_transform)
@@ -276,7 +263,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/ArithOps.td
SOURCES
dialects/arith.py
dialects/_arith_ops_ext.py
DIALECT_NAME arith
GEN_ENUM_BINDINGS)
@@ -286,7 +272,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/MemRefOps.td
SOURCES
dialects/memref.py
dialects/_memref_ops_ext.py
DIALECT_NAME memref)
declare_mlir_dialect_python_bindings(
@@ -295,7 +280,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/MLProgramOps.td
SOURCES
dialects/ml_program.py
dialects/_ml_program_ops_ext.py
DIALECT_NAME ml_program)
declare_mlir_dialect_python_bindings(
@@ -339,7 +323,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/PDLOps.td
SOURCES
dialects/pdl.py
dialects/_pdl_ops_ext.py
_mlir_libs/_mlir/dialects/pdl.pyi
DIALECT_NAME pdl)
@@ -357,7 +340,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/SCFOps.td
SOURCES
dialects/scf.py
dialects/_scf_ops_ext.py
DIALECT_NAME scf)
declare_mlir_dialect_python_bindings(
@@ -383,7 +365,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/TensorOps.td
SOURCES
dialects/tensor.py
dialects/_tensor_ops_ext.py
DIALECT_NAME tensor)
declare_mlir_dialect_python_bindings(

View File

@@ -1,56 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
from ._ods_common import get_op_results_or_values as _get_op_results_or_values
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, Sequence, Union
class AffineStoreOp:
"""Specialization for the Affine store operation."""
def __init__(
self,
value: Union[Operation, OpView, Value],
memref: Union[Operation, OpView, Value],
map: AffineMap=None,
*,
map_operands=None,
loc=None,
ip=None
):
"""Creates an affine store operation.
- `value`: the value to store into the memref.
- `memref`: the buffer to store into.
- `map`: the affine map that maps the map_operands to the index of the
memref.
- `map_operands`: the list of arguments to substitute the dimensions,
then symbols in the affine map, in increasing order.
"""
map = map if map is not None else []
map_operands = map_operands if map_operands is not None else []
operands = [
_get_op_result_or_value(value),
_get_op_result_or_value(memref),
*[_get_op_result_or_value(op) for op in map_operands]
]
results = []
attributes = {"map": AffineMapAttr.get(map)}
regions = None
_ods_successors = None
super().__init__(self.build_generic(
attributes=attributes,
results=results,
operands=operands,
successors=_ods_successors,
regions=regions,
loc=loc,
ip=ip
))

View File

@@ -1,69 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ._ods_common import get_default_loc_context as _get_default_loc_context
from typing import Any, List, Union
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
def _isa(obj: Any, cls: type):
try:
cls(obj)
except ValueError:
return False
return True
def _is_any_of(obj: Any, classes: List[type]):
return any(_isa(obj, cls) for cls in classes)
def _is_integer_like_type(type: Type):
return _is_any_of(type, [IntegerType, IndexType])
def _is_float_type(type: Type):
return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
class ConstantOp:
"""Specialization for the constant op class."""
def __init__(
self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
):
if isinstance(value, int):
super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
elif isinstance(value, float):
super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
else:
super().__init__(value, loc=loc, ip=ip)
@classmethod
def create_index(cls, value: int, *, loc=None, ip=None):
"""Create an index-typed constant."""
return cls(
IndexType.get(context=_get_default_loc_context(loc)), value, loc=loc, ip=ip
)
@property
def type(self):
return self.results[0].type
@property
def value(self):
return Attribute(self.operation.attributes["value"])
@property
def literal_value(self) -> Union[int, float]:
if _is_integer_like_type(self.type):
return IntegerAttr(self.value).value
elif _is_float_type(self.type):
return FloatAttr(self.value).value
else:
raise ValueError("only integer and float constants have literal values")

View File

@@ -1,41 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from typing import Sequence, Union
from ..ir import *
from ._ods_common import get_default_loc_context
from typing import Any, List, Union
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
class AllocTensorOp:
"""Extends the bufferization.alloc_tensor op."""
def __init__(
self,
tensor_type: Type,
dynamic_sizes: Sequence[Value],
copy: Value,
size_hint: Value,
escape: BoolAttr,
*,
loc=None,
ip=None
):
"""Constructs an `alloc_tensor` with static and/or dynamic sizes."""
context = get_default_loc_context(loc)
attributes = {}
if escape:
attributes["escape"] = escape
op = self.build_generic(
results=[tensor_type],
operands=[dynamic_sizes, copy, size_hint],
attributes=attributes,
loc=loc,
ip=ip,
)
OpView.__init__(self, op)

View File

@@ -1,128 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ..dialects import transform
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from enum import Enum
from typing import Optional, overload, Union
class EmptyTensorToAllocTensorOp:
"""Specialization for EmptyTensorToAllocTensorOp class."""
@overload
def __init__(
self,
transformed_type: Type,
target: Union[Operation, OpView, Value],
*,
loc=None,
ip=None
):
...
@overload
def __init__(self, target: Union[Operation, OpView, Value], *, loc=None, ip=None):
...
def __init__(
self,
transformed_type_or_target: Type,
target_or_none: Optional[Union[Operation, OpView, Value]] = None,
*,
loc=None,
ip=None
):
if isinstance(transformed_type_or_target, Type):
transformed_type = transformed_type_or_target
target = target_or_none
else:
transformed_type = transform.OperationType.get("bufferization.alloc_tensor")
target = transformed_type_or_target
super().__init__(
transformed_type,
target,
loc=loc,
ip=ip,
)
class OneShotBufferizeOp:
"""Specialization for OneShotBufferizeOp class."""
@overload
def __init__(
self,
transformed_type: Type,
target: Union[Operation, OpView, Value],
*,
allow_return_allocs_from_loops: Optional[bool] = None,
allow_unknown_ops: Optional[bool] = None,
bufferize_function_boundaries: Optional[bool] = None,
function_boundary_type_conversion: Optional[Enum] = None,
memcpy_op: Optional[str] = None,
print_conflicts: Optional[bool] = None,
test_analysis_only: Optional[bool] = None,
loc=None,
ip=None
):
...
@overload
def __init__(
self,
target: Union[Operation, OpView, Value],
*,
allow_return_allocs_from_loops: Optional[bool] = None,
allow_unknown_ops: Optional[bool] = None,
bufferize_function_boundaries: Optional[bool] = None,
function_boundary_type_conversion: Optional[Enum] = None,
memcpy_op: Optional[str] = None,
print_conflicts: Optional[bool] = None,
test_analysis_only: Optional[bool] = None,
loc=None,
ip=None
):
...
def __init__(
self,
transformed_type_or_target: Type,
target_or_none: Optional[Union[Operation, OpView, Value]] = None,
*,
allow_return_allocs_from_loops: Optional[bool] = None,
allow_unknown_ops: Optional[bool] = None,
bufferize_function_boundaries: Optional[bool] = None,
function_boundary_type_conversion: Optional[Enum] = None,
memcpy_op: Optional[str] = None,
print_conflicts: Optional[bool] = None,
test_analysis_only: Optional[bool] = None,
loc=None,
ip=None
):
if isinstance(transformed_type_or_target, Type):
transformed_type = transformed_type_or_target
target = target_or_none
else:
transformed_type = transform.AnyOpType.get()
target = transformed_type_or_target
super().__init__(
transformed_type,
target,
allow_return_allocs_from_loops=allow_return_allocs_from_loops,
allow_unknown_ops=allow_unknown_ops,
bufferize_function_boundaries=bufferize_function_boundaries,
function_boundary_type_conversion=function_boundary_type_conversion,
memcpy_op=memcpy_op,
print_conflicts=print_conflicts,
test_analysis_only=test_analysis_only,
loc=loc,
ip=ip,
)

View File

@@ -1,20 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
class ModuleOp:
"""Specialization for the module op class."""
def __init__(self, *, loc=None, ip=None):
super().__init__(self.build_generic(results=[], operands=[], loc=loc, ip=ip))
body = self.regions[0].blocks.append()
@property
def body(self):
return self.regions[0].blocks[0]

View File

@@ -1,319 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ._ods_common import get_default_loc_context as _get_default_loc_context
import inspect
from typing import Any, List, Optional, Sequence, Union
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
RESULT_ATTRIBUTE_NAME = "res_attrs"
class ConstantOp:
"""Specialization for the constant op class."""
def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
super().__init__(result, value, loc=loc, ip=ip)
@property
def type(self):
return self.results[0].type
class FuncOp:
"""Specialization for the func op class."""
def __init__(
self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
):
"""
Create a FuncOp with the provided `name`, `type`, and `visibility`.
- `name` is a string representing the function name.
- `type` is either a FunctionType or a pair of list describing inputs and
results.
- `visibility` is a string matching `public`, `private`, or `nested`. None
implies private visibility.
- `body_builder` is an optional callback, when provided a new entry block
is created and the callback is invoked with the new op as argument within
an InsertionPoint context already set for the block. The callback is
expected to insert a terminator in the block.
"""
sym_name = StringAttr.get(str(name))
# If the type is passed as a tuple, build a FunctionType on the fly.
if isinstance(type, tuple):
type = FunctionType.get(inputs=type[0], results=type[1])
type = TypeAttr.get(type)
sym_visibility = (
StringAttr.get(str(visibility)) if visibility is not None else None
)
super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
if body_builder:
entry_block = self.add_entry_block()
with InsertionPoint(entry_block):
body_builder(self)
@property
def is_external(self):
return len(self.regions[0].blocks) == 0
@property
def body(self):
return self.regions[0]
@property
def type(self):
return FunctionType(TypeAttr(self.attributes["function_type"]).value)
@property
def visibility(self):
return self.attributes["sym_visibility"]
@property
def name(self) -> StringAttr:
return StringAttr(self.attributes["sym_name"])
@property
def entry_block(self):
if self.is_external:
raise IndexError("External function does not have a body")
return self.regions[0].blocks[0]
def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
"""
Add an entry block to the function body using the function signature to
infer block arguments.
Returns the newly created block
"""
if not self.is_external:
raise IndexError("The function already has an entry block!")
self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs)
return self.body.blocks[0]
@property
def arg_attrs(self):
return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
@arg_attrs.setter
def arg_attrs(self, attribute: Union[ArrayAttr, list]):
if isinstance(attribute, ArrayAttr):
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
else:
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
attribute, context=self.context
)
@property
def arguments(self):
return self.entry_block.arguments
@property
def result_attrs(self):
return self.attributes[RESULT_ATTRIBUTE_NAME]
@result_attrs.setter
def result_attrs(self, attribute: ArrayAttr):
self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
@classmethod
def from_py_func(
FuncOp,
*inputs: Type,
results: Optional[Sequence[Type]] = None,
name: Optional[str] = None,
):
"""Decorator to define an MLIR FuncOp specified as a python function.
Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
active for the current thread (i.e. established in a `with` block).
When applied as a decorator to a Python function, an entry block will
be constructed for the FuncOp with types as specified in `*inputs`. The
block arguments will be passed positionally to the Python function. In
addition, if the Python function accepts keyword arguments generally or
has a corresponding keyword argument, the following will be passed:
* `func_op`: The `func` op being defined.
By default, the function name will be the Python function `__name__`. This
can be overriden by passing the `name` argument to the decorator.
If `results` is not specified, then the decorator will implicitly
insert a `ReturnOp` with the `Value`'s returned from the decorated
function. It will also set the `FuncOp` type with the actual return
value types. If `results` is specified, then the decorated function
must return `None` and no implicit `ReturnOp` is added (nor are the result
types updated). The implicit behavior is intended for simple, single-block
cases, and users should specify result types explicitly for any complicated
cases.
The decorated function can further be called from Python and will insert
a `CallOp` at the then-current insertion point, returning either None (
if no return values), a unary Value (for one result), or a list of Values).
This mechanism cannot be used to emit recursive calls (by construction).
"""
def decorator(f):
from . import func
# Introspect the callable for optional features.
sig = inspect.signature(f)
has_arg_func_op = False
for param in sig.parameters.values():
if param.kind == param.VAR_KEYWORD:
has_arg_func_op = True
if param.name == "func_op" and (
param.kind == param.POSITIONAL_OR_KEYWORD
or param.kind == param.KEYWORD_ONLY
):
has_arg_func_op = True
# Emit the FuncOp.
implicit_return = results is None
symbol_name = name or f.__name__
function_type = FunctionType.get(
inputs=inputs, results=[] if implicit_return else results
)
func_op = FuncOp(name=symbol_name, type=function_type)
with InsertionPoint(func_op.add_entry_block()):
func_args = func_op.entry_block.arguments
func_kwargs = {}
if has_arg_func_op:
func_kwargs["func_op"] = func_op
return_values = f(*func_args, **func_kwargs)
if not implicit_return:
return_types = list(results)
assert return_values is None, (
"Capturing a python function with explicit `results=` "
"requires that the wrapped function returns None."
)
else:
# Coerce return values, add ReturnOp and rewrite func type.
if return_values is None:
return_values = []
elif isinstance(return_values, tuple):
return_values = list(return_values)
elif isinstance(return_values, Value):
# Returning a single value is fine, coerce it into a list.
return_values = [return_values]
elif isinstance(return_values, OpView):
# Returning a single operation is fine, coerce its results a list.
return_values = return_values.operation.results
elif isinstance(return_values, Operation):
# Returning a single operation is fine, coerce its results a list.
return_values = return_values.results
else:
return_values = list(return_values)
func.ReturnOp(return_values)
# Recompute the function type.
return_types = [v.type for v in return_values]
function_type = FunctionType.get(
inputs=inputs, results=return_types
)
func_op.attributes["function_type"] = TypeAttr.get(function_type)
def emit_call_op(*call_args):
call_op = func.CallOp(
return_types, FlatSymbolRefAttr.get(symbol_name), call_args
)
if return_types is None:
return None
elif len(return_types) == 1:
return call_op.result
else:
return call_op.results
wrapped = emit_call_op
wrapped.__name__ = f.__name__
wrapped.func_op = func_op
return wrapped
return decorator
class CallOp:
"""Specialization for the call op class."""
def __init__(
self,
calleeOrResults: Union[FuncOp, List[Type]],
argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
arguments: Optional[List] = None,
*,
loc=None,
ip=None,
):
"""Creates an call operation.
The constructor accepts three different forms:
1. A function op to be called followed by a list of arguments.
2. A list of result types, followed by the name of the function to be
called as string, following by a list of arguments.
3. A list of result types, followed by the name of the function to be
called as symbol reference attribute, followed by a list of arguments.
For example
f = func.FuncOp("foo", ...)
func.CallOp(f, [args])
func.CallOp([result_types], "foo", [args])
In all cases, the location and insertion point may be specified as keyword
arguments if not provided by the surrounding context managers.
"""
# TODO: consider supporting constructor "overloads", e.g., through a custom
# or pybind-provided metaclass.
if isinstance(calleeOrResults, FuncOp):
if not isinstance(argumentsOrCallee, list):
raise ValueError(
"when constructing a call to a function, expected "
+ "the second argument to be a list of call arguments, "
+ f"got {type(argumentsOrCallee)}"
)
if arguments is not None:
raise ValueError(
"unexpected third argument when constructing a call"
+ "to a function"
)
super().__init__(
calleeOrResults.type.results,
FlatSymbolRefAttr.get(
calleeOrResults.name.value, context=_get_default_loc_context(loc)
),
argumentsOrCallee,
loc=loc,
ip=ip,
)
return
if isinstance(argumentsOrCallee, list):
raise ValueError(
"when constructing a call to a function by name, "
+ "expected the second argument to be a string or a "
+ f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}"
)
if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
super().__init__(
calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip
)
elif isinstance(argumentsOrCallee, str):
super().__init__(
calleeOrResults,
FlatSymbolRefAttr.get(
argumentsOrCallee, context=_get_default_loc_context(loc)
),
arguments,
loc=loc,
ip=ip,
)

View File

@@ -1,124 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ..dialects import transform
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, Sequence, Union, overload
class MapForallToBlocks:
"""Specialization for MapForallToBlocks class."""
@overload
def __init__(
self,
result_type: Type,
target: Union[Operation, OpView, Value],
*,
grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
loc=None,
ip=None
):
...
@overload
def __init__(
self,
target: Union[Operation, OpView, Value],
*,
grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
loc=None,
ip=None
):
...
def __init__(
self,
result_type_or_target: Union[Operation, OpView, Type, Value],
target_or_none: Optional[Union[Operation, OpView, Value]] = None,
*,
grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
loc=None,
ip=None
):
if isinstance(result_type_or_target, Type):
result_type = result_type_or_target
target = target_or_none
else:
result_type = transform.AnyOpType.get()
target = result_type_or_target
super().__init__(
result_type,
target,
grid_dims=grid_dims,
generate_gpu_launch=generate_gpu_launch,
loc=loc,
ip=ip,
)
class MapNestedForallToThreads:
"""Specialization for MapNestedForallToThreads class."""
@overload
def __init__(
self,
result_type: Type,
target: Union[Operation, OpView, Value],
*,
block_dims: Optional[Sequence[int]] = None,
warp_size: Optional[Sequence[int]] = None,
sync_after_distribute: Optional[bool] = None,
loc=None,
ip=None
):
...
@overload
def __init__(
self,
target: Union[Operation, OpView, Value],
*,
block_dims: Optional[Sequence[int]] = None,
warp_size: Optional[Sequence[int]] = None,
sync_after_distribute: Optional[bool] = None,
loc=None,
ip=None
):
...
def __init__(
self,
result_type_or_target: Union[Operation, OpView, Value, Type],
target_or_none: Optional[Union[Operation, OpView, Value]] = None,
*,
block_dims: Optional[Union[Sequence[int], Attribute]] = None,
warp_size: Optional[Union[Sequence[int], Attribute]] = None,
sync_after_distribute: Optional[bool] = None,
loc=None,
ip=None
):
if isinstance(result_type_or_target, Type):
result_type = result_type_or_target
target = target_or_none
else:
result_type = result_type_or_target.type
target = result_type_or_target
super().__init__(
result_type,
target,
block_dims=block_dims,
warp_size=warp_size,
sync_after_distribute=sync_after_distribute,
loc=loc,
ip=ip,
)

View File

@@ -1,47 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from typing import Optional, Sequence, Union
from ..ir import *
from ._ods_common import get_default_loc_context
from .._mlir_libs._mlirDialectsLinalg import fill_builtin_region
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
def isa(cls: Type, ty: Type):
try:
cls(ty)
return True
except ValueError:
return False
class StructuredOpMixin:
"""All structured ops use the same mixin class."""
def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
super().__init__(
self.build_generic(
results=list(results),
operands=[list(inputs), list(outputs)],
loc=loc,
ip=ip,
)
)
def select_opview_mixin(parent_opview_cls):
# TODO: This shouldn't be a heuristic: we should have a way to annotate
# the OpView to note that it is a structured op.
if (
"__init__" not in parent_opview_cls.__dict__
and hasattr(parent_opview_cls, "inputs")
and hasattr(parent_opview_cls, "outputs")
and hasattr(parent_opview_cls, "result_tensors")
):
return StructuredOpMixin

View File

@@ -1,134 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, Union
class GetParentForOp:
"""Extension for GetParentForOp."""
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
*,
num_loops: Optional[int] = None,
ip=None,
loc=None,
):
if num_loops is None:
num_loops = 1
super().__init__(
result_type,
_get_op_result_or_value(target),
num_loops=num_loops,
ip=ip,
loc=loc,
)
class LoopOutlineOp:
"""Extension for LoopOutlineOp."""
def __init__(
self,
function_type: Type,
call_type: Type,
target: Union[Operation, Value],
*,
func_name: Union[str, StringAttr],
ip=None,
loc=None,
):
super().__init__(
function_type,
call_type,
_get_op_result_or_value(target),
func_name=(
func_name
if isinstance(func_name, StringAttr)
else StringAttr.get(func_name)
),
ip=ip,
loc=loc,
)
class LoopPeelOp:
"""Extension for LoopPeelOp."""
def __init__(
self,
main_loop_type: Type,
remainder_loop_type: Type,
target: Union[Operation, Value],
*,
fail_if_already_divisible: Union[bool, BoolAttr] = False,
ip=None,
loc=None,
):
super().__init__(
main_loop_type,
remainder_loop_type,
_get_op_result_or_value(target),
fail_if_already_divisible=(
fail_if_already_divisible
if isinstance(fail_if_already_divisible, BoolAttr)
else BoolAttr.get(fail_if_already_divisible)
),
ip=ip,
loc=loc,
)
class LoopPipelineOp:
"""Extension for LoopPipelineOp."""
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
*,
iteration_interval: Optional[Union[int, IntegerAttr]] = None,
read_latency: Optional[Union[int, IntegerAttr]] = None,
ip=None,
loc=None,
):
if iteration_interval is None:
iteration_interval = 1
if read_latency is None:
read_latency = 10
super().__init__(
result_type,
_get_op_result_or_value(target),
iteration_interval=iteration_interval,
read_latency=read_latency,
ip=ip,
loc=loc,
)
class LoopUnrollOp:
"""Extension for LoopUnrollOp."""
def __init__(
self,
target: Union[Operation, Value],
*,
factor: Union[int, IntegerAttr],
ip=None,
loc=None,
):
super().__init__(
_get_op_result_or_value(target),
factor=factor,
ip=ip,
loc=loc,
)

View File

@@ -1,36 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
from ._ods_common import get_op_results_or_values as _get_op_results_or_values
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, Sequence, Union
class LoadOp:
"""Specialization for the MemRef load operation."""
def __init__(
self,
memref: Union[Operation, OpView, Value],
indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
*,
loc=None,
ip=None
):
"""Creates a memref load operation.
Args:
memref: the buffer to load from.
indices: the list of subscripts, may be empty for zero-dimensional
buffers.
loc: user-visible location of the operation.
ip: insertion point.
"""
indices_resolved = [] if indices is None else _get_op_results_or_values(indices)
super().__init__(memref, indices_resolved, loc=loc, ip=ip)

View File

@@ -1,114 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ..dialects import transform
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, overload, Union
class MemRefAllocaToGlobalOp:
"""Specialization for MemRefAllocaToGlobalOp class."""
@overload
def __init__(
self,
get_global_type: Type,
global_type: Type,
alloca: Union[Operation, OpView, Value],
*,
loc=None,
ip=None
):
...
@overload
def __init__(self, alloca: Union[Operation, OpView, Value], *, loc=None, ip=None):
...
def __init__(
self,
get_global_type_or_alloca: Union[Operation, OpView, Type, Value],
global_type_or_none: Optional[Type] = None,
alloca_or_none: Optional[Union[Operation, OpView, Value]] = None,
*,
loc=None,
ip=None
):
if isinstance(get_global_type_or_alloca, Type):
get_global_type = get_global_type_or_alloca
global_type = global_type_or_none
alloca = alloca_or_none
else:
get_global_type = transform.AnyOpType.get()
global_type = transform.AnyOpType.get()
alloca = get_global_type_or_alloca
super().__init__(
get_global_type,
global_type,
alloca,
loc=loc,
ip=ip,
)
class MemRefMultiBufferOp:
"""Specialization for MemRefMultiBufferOp class."""
@overload
def __init__(
self,
transformed_type: Type,
target: Union[Operation, OpView, Value],
factor: Union[int, IntegerAttr],
*,
skip_analysis: Optional[bool] = None,
loc=None,
ip=None
):
...
@overload
def __init__(
self,
target: Union[Operation, OpView, Value],
factor: Union[int, IntegerAttr],
*,
skip_analysis: Optional[bool] = None,
loc=None,
ip=None
):
...
def __init__(
self,
transformed_type_or_target: Type,
target_or_factor: Union[int, IntegerAttr, Operation, OpView, Value] = None,
factor_or_none: Optional[Union[int, IntegerAttr]] = None,
*,
skip_analysis: Optional[bool] = None,
loc=None,
ip=None
):
if isinstance(transformed_type_or_target, Type):
transformed_type = transformed_type_or_target
target = target_or_factor
factor = factor_or_none
else:
transformed_type = transform.AnyOpType.get()
target = transformed_type_or_target
factor = target_or_factor
super().__init__(
transformed_type,
target,
factor,
skip_analysis=skip_analysis,
loc=loc,
ip=ip,
)

View File

@@ -1,113 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from typing import Union
from ..ir import *
from ._ods_common import get_default_loc_context as _get_default_loc_context
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from ._ml_program_ops_gen import *
ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
RESULT_ATTRIBUTE_NAME = "res_attrs"
class FuncOp:
"""Specialization for the func op class."""
def __init__(
self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
):
"""
Create a FuncOp with the provided `name`, `type`, and `visibility`.
- `name` is a string representing the function name.
- `type` is either a FunctionType or a pair of list describing inputs and
results.
- `visibility` is a string matching `public`, `private`, or `nested`. None
implies private visibility.
- `body_builder` is an optional callback, when provided a new entry block
is created and the callback is invoked with the new op as argument within
an InsertionPoint context already set for the block. The callback is
expected to insert a terminator in the block.
"""
sym_name = StringAttr.get(str(name))
# If the type is passed as a tuple, build a FunctionType on the fly.
if isinstance(type, tuple):
type = FunctionType.get(inputs=type[0], results=type[1])
type = TypeAttr.get(type)
sym_visibility = (
StringAttr.get(str(visibility)) if visibility is not None else None
)
super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
if body_builder:
entry_block = self.add_entry_block()
with InsertionPoint(entry_block):
body_builder(self)
@property
def is_external(self):
return len(self.regions[0].blocks) == 0
@property
def body(self):
return self.regions[0]
@property
def type(self):
return FunctionType(TypeAttr(self.attributes["function_type"]).value)
@property
def visibility(self):
return self.attributes["sym_visibility"]
@property
def name(self) -> StringAttr:
return StringAttr(self.attributes["sym_name"])
@property
def entry_block(self):
if self.is_external:
raise IndexError("External function does not have a body")
return self.regions[0].blocks[0]
def add_entry_block(self):
"""
Add an entry block to the function body using the function signature to
infer block arguments.
Returns the newly created block
"""
if not self.is_external:
raise IndexError("The function already has an entry block!")
self.body.blocks.append(*self.type.inputs)
return self.body.blocks[0]
@property
def arg_attrs(self):
return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
@arg_attrs.setter
def arg_attrs(self, attribute: Union[ArrayAttr, list]):
if isinstance(attribute, ArrayAttr):
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
else:
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
attribute, context=self.context
)
@property
def arguments(self):
return self.entry_block.arguments
@property
def result_attrs(self):
return self.attributes[RESULT_ATTRIBUTE_NAME]
@result_attrs.setter
def result_attrs(self, attribute: ArrayAttr):
self.attributes[RESULT_ATTRIBUTE_NAME] = attribute

View File

@@ -9,7 +9,6 @@ from typing import Sequence as _Sequence, Union as _Union
__all__ = [
"equally_sized_accessor",
"extend_opview_class",
"get_default_loc_context",
"get_op_result_or_value",
"get_op_results_or_values",
@@ -18,64 +17,6 @@ __all__ = [
]
def extend_opview_class(ext_module):
"""Decorator to extend an OpView class from an extension module.
Extension modules can expose various entry-points:
Stand-alone class with the same name as a parent OpView class (i.e.
"ReturnOp"). A name-based match is attempted first before falling back
to a below mechanism.
def select_opview_mixin(parent_opview_cls):
If defined, allows an appropriate mixin class to be selected dynamically
based on the parent OpView class. Should return NotImplemented if a
decision is not made.
Args:
ext_module: A module from which to locate extensions. Can be None if not
available.
Returns:
A decorator that takes an OpView subclass and further extends it as
needed.
"""
def class_decorator(parent_opview_cls: type):
if ext_module is None:
return parent_opview_cls
mixin_cls = NotImplemented
# First try to resolve by name.
try:
mixin_cls = getattr(ext_module, parent_opview_cls.__name__)
except AttributeError:
# Fall back to a select_opview_mixin hook.
try:
select_mixin = getattr(ext_module, "select_opview_mixin")
except AttributeError:
pass
else:
mixin_cls = select_mixin(parent_opview_cls)
if mixin_cls is NotImplemented or mixin_cls is None:
return parent_opview_cls
# Have a mixin_cls. Create an appropriate subclass.
try:
class LocalOpView(mixin_cls, parent_opview_cls):
pass
except TypeError as e:
raise TypeError(
f"Could not mixin {mixin_cls} into {parent_opview_cls}"
) from e
LocalOpView.__name__ = parent_opview_cls.__name__
LocalOpView.__qualname__ = parent_opview_cls.__qualname__
return LocalOpView
return class_decorator
def segmented_accessor(elements, raw_segments, idx):
"""
Returns a slice of elements corresponding to the idx-th segment.

View File

@@ -1,271 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ..dialects import pdl
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Union, Optional, Sequence, Mapping
from ._ods_common import (
get_op_result_or_value as _get_value,
get_op_results_or_values as _get_values,
)
class ApplyNativeConstraintOp:
"""Specialization for PDL apply native constraint op class."""
def __init__(
self,
name: Union[str, StringAttr],
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
*,
loc=None,
ip=None,
):
if args is None:
args = []
args = _get_values(args)
super().__init__(name, args, loc=loc, ip=ip)
class ApplyNativeRewriteOp:
"""Specialization for PDL apply native rewrite op class."""
def __init__(
self,
results: Sequence[Type],
name: Union[str, StringAttr],
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
*,
loc=None,
ip=None,
):
if args is None:
args = []
args = _get_values(args)
super().__init__(results, name, args, loc=loc, ip=ip)
class AttributeOp:
"""Specialization for PDL attribute op class."""
def __init__(
self,
valueType: Optional[Union[OpView, Operation, Value]] = None,
value: Optional[Attribute] = None,
*,
loc=None,
ip=None,
):
valueType = valueType if valueType is None else _get_value(valueType)
result = pdl.AttributeType.get()
super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
class EraseOp:
"""Specialization for PDL erase op class."""
def __init__(
self,
operation: Optional[Union[OpView, Operation, Value]] = None,
*,
loc=None,
ip=None,
):
operation = _get_value(operation)
super().__init__(operation, loc=loc, ip=ip)
class OperandOp:
"""Specialization for PDL operand op class."""
def __init__(
self,
type: Optional[Union[OpView, Operation, Value]] = None,
*,
loc=None,
ip=None,
):
type = type if type is None else _get_value(type)
result = pdl.ValueType.get()
super().__init__(result, valueType=type, loc=loc, ip=ip)
class OperandsOp:
"""Specialization for PDL operands op class."""
def __init__(
self,
types: Optional[Union[OpView, Operation, Value]] = None,
*,
loc=None,
ip=None,
):
types = types if types is None else _get_value(types)
result = pdl.RangeType.get(pdl.ValueType.get())
super().__init__(result, valueType=types, loc=loc, ip=ip)
class OperationOp:
"""Specialization for PDL operand op class."""
def __init__(
self,
name: Optional[Union[str, StringAttr]] = None,
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
attributes: Optional[Mapping[str, Union[OpView, Operation, Value]]] = None,
types: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
*,
loc=None,
ip=None,
):
if types is None:
types = []
if attributes is None:
attributes = {}
if args is None:
args = []
args = _get_values(args)
attrNames = []
attrValues = []
for attrName, attrValue in attributes.items():
attrNames.append(StringAttr.get(attrName))
attrValues.append(_get_value(attrValue))
attrNames = ArrayAttr.get(attrNames)
types = _get_values(types)
result = pdl.OperationType.get()
super().__init__(
result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip
)
class PatternOp:
"""Specialization for PDL pattern op class."""
def __init__(
self,
benefit: Union[IntegerAttr, int],
name: Optional[Union[StringAttr, str]] = None,
*,
loc=None,
ip=None,
):
"""Creates an PDL `pattern` operation."""
super().__init__(benefit, sym_name=name, loc=loc, ip=ip)
self.regions[0].blocks.append()
@property
def body(self):
"""Return the body (block) of the pattern."""
return self.regions[0].blocks[0]
class ReplaceOp:
"""Specialization for PDL replace op class."""
def __init__(
self,
op: Union[OpView, Operation, Value],
*,
with_op: Optional[Union[OpView, Operation, Value]] = None,
with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
loc=None,
ip=None,
):
if with_values is None:
with_values = []
op = _get_value(op)
with_op = with_op if with_op is None else _get_value(with_op)
with_values = _get_values(with_values)
super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip)
class ResultOp:
"""Specialization for PDL result op class."""
def __init__(
self,
parent: Union[OpView, Operation, Value],
index: Union[IntegerAttr, int],
*,
loc=None,
ip=None,
):
parent = _get_value(parent)
result = pdl.ValueType.get()
super().__init__(result, parent, index, loc=loc, ip=ip)
class ResultsOp:
"""Specialization for PDL results op class."""
def __init__(
self,
result: Type,
parent: Union[OpView, Operation, Value],
index: Optional[Union[IntegerAttr, int]] = None,
*,
loc=None,
ip=None,
):
parent = _get_value(parent)
super().__init__(result, parent, index=index, loc=loc, ip=ip)
class RewriteOp:
"""Specialization for PDL rewrite op class."""
def __init__(
self,
root: Optional[Union[OpView, Operation, Value]] = None,
name: Optional[Union[StringAttr, str]] = None,
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
*,
loc=None,
ip=None,
):
if args is None:
args = []
root = root if root is None else _get_value(root)
args = _get_values(args)
super().__init__(args, root=root, name=name, loc=loc, ip=ip)
def add_body(self):
"""Add body (block) to the rewrite."""
self.regions[0].blocks.append()
return self.body
@property
def body(self):
"""Return the body (block) of the rewrite."""
return self.regions[0].blocks[0]
class TypeOp:
"""Specialization for PDL type op class."""
def __init__(
self, constantType: Optional[Union[TypeAttr, Type]] = None, *, loc=None, ip=None
):
result = pdl.TypeType.get()
super().__init__(result, constantType=constantType, loc=loc, ip=ip)
class TypesOp:
"""Specialization for PDL types op class."""
def __init__(
self,
constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None,
*,
loc=None,
ip=None,
):
if constantTypes is None:
constantTypes = []
result = pdl.RangeType.get(pdl.TypeType.get())
super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)

View File

@@ -1,107 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, Sequence, Union
from ._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
)
class ForOp:
"""Specialization for the SCF for op class."""
def __init__(
self,
lower_bound,
upper_bound,
step,
iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
*,
loc=None,
ip=None,
):
"""Creates an SCF `for` operation.
- `lower_bound` is the value to use as lower bound of the loop.
- `upper_bound` is the value to use as upper bound of the loop.
- `step` is the value to use as loop step.
- `iter_args` is a list of additional loop-carried arguments or an operation
producing them as results.
"""
if iter_args is None:
iter_args = []
iter_args = _get_op_results_or_values(iter_args)
results = [arg.type for arg in iter_args]
super().__init__(
self.build_generic(
regions=1,
results=results,
operands=[
_get_op_result_or_value(o) for o in [lower_bound, upper_bound, step]
]
+ list(iter_args),
loc=loc,
ip=ip,
)
)
self.regions[0].blocks.append(self.operands[0].type, *results)
@property
def body(self):
"""Returns the body (block) of the loop."""
return self.regions[0].blocks[0]
@property
def induction_variable(self):
"""Returns the induction variable of the loop."""
return self.body.arguments[0]
@property
def inner_iter_args(self):
"""Returns the loop-carried arguments usable within the loop.
To obtain the loop-carried operands, use `iter_args`.
"""
return self.body.arguments[1:]
class IfOp:
"""Specialization for the SCF if op class."""
def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None):
"""Creates an SCF `if` operation.
- `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed.
- `hasElse` determines whether the if operation has the else branch.
"""
operands = []
operands.append(cond)
results = []
results.extend(results_)
super().__init__(
self.build_generic(
regions=2, results=results, operands=operands, loc=loc, ip=ip
)
)
self.regions[0].blocks.append(*[])
if hasElse:
self.regions[1].blocks.append(*[])
@property
def then_block(self):
"""Returns the then block of the if operation."""
return self.regions[0].blocks[0]
@property
def else_block(self):
"""Returns the else block of the if operation."""
return self.regions[1].blocks[0]

View File

@@ -1,759 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ..dialects import transform
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import List, Optional, Sequence, Tuple, Union, overload
StaticIntLike = Union[int, IntegerAttr]
ValueLike = Union[Operation, OpView, Value]
MixedInt = Union[StaticIntLike, ValueLike]
IntOrAttrList = Sequence[Union[IntegerAttr, int]]
OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
BoolOrAttrList = Sequence[Union[BoolAttr, bool]]
OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]]
MixedValues = Union[Sequence[Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]
DynamicIndexList = Sequence[Union[MixedInt, Sequence[MixedInt]]]
def _dispatch_dynamic_index_list(
indices: Union[DynamicIndexList, ArrayAttr],
) -> Tuple[List[ValueLike], Union[List[int], ArrayAttr], List[bool]]:
"""Dispatches a list of indices to the appropriate form.
This is similar to the custom `DynamicIndexList` directive upstream:
provided indices may be in the form of dynamic SSA values or static values,
and they may be scalable (i.e., as a singleton list) or not. This function
dispatches each index into its respective form. It also extracts the SSA
values and static indices from various similar structures, respectively.
"""
dynamic_indices = []
static_indices = [ShapedType.get_dynamic_size()] * len(indices)
scalable_indices = [False] * len(indices)
# ArrayAttr: Extract index values.
if isinstance(indices, ArrayAttr):
indices = [idx for idx in indices]
def process_nonscalable_index(i, index):
"""Processes any form of non-scalable index.
Returns False if the given index was scalable and thus remains
unprocessed; True otherwise.
"""
if isinstance(index, int):
static_indices[i] = index
elif isinstance(index, IntegerAttr):
static_indices[i] = index.value # pytype: disable=attribute-error
elif isinstance(index, (Operation, Value, OpView)):
dynamic_indices.append(index)
else:
return False
return True
# Process each index at a time.
for i, index in enumerate(indices):
if not process_nonscalable_index(i, index):
# If it wasn't processed, it must be a scalable index, which is
# provided as a Sequence of one value, so extract and process that.
scalable_indices[i] = True
assert len(index) == 1
ret = process_nonscalable_index(i, index[0])
assert ret
return dynamic_indices, static_indices, scalable_indices
# Dispatches `MixedValues` that all represents integers in various forms into
# the following three categories:
# - `dynamic_values`: a list of `Value`s, potentially from op results;
# - `packed_values`: a value handle, potentially from an op result, associated
# to one or more payload operations of integer type;
# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python
# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`.
# The input is in the form for `packed_values`, only that result is set and the
# other two are empty. Otherwise, the input can be a mix of the other two forms,
# and for each dynamic value, a special value is added to the `static_values`.
def _dispatch_mixed_values(
values: MixedValues,
) -> Tuple[List[Value], Union[Operation, Value, OpView], DenseI64ArrayAttr]:
dynamic_values = []
packed_values = None
static_values = None
if isinstance(values, ArrayAttr):
static_values = values
elif isinstance(values, (Operation, Value, OpView)):
packed_values = values
else:
static_values = []
for size in values or []:
if isinstance(size, int):
static_values.append(size)
else:
static_values.append(ShapedType.get_dynamic_size())
dynamic_values.append(size)
static_values = DenseI64ArrayAttr.get(static_values)
return (dynamic_values, packed_values, static_values)
def _get_value_or_attribute_value(
value_or_attr: Union[any, Attribute, ArrayAttr]
) -> any:
if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
return value_or_attr.value
if isinstance(value_or_attr, ArrayAttr):
return _get_value_list(value_or_attr)
return value_or_attr
def _get_value_list(
sequence_or_array_attr: Union[Sequence[any], ArrayAttr]
) -> Sequence[any]:
return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]
def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr:
if values is None:
return None
# Turn into a Python list of Python ints.
values = _get_value_list(values)
# Make an ArrayAttr of IntegerAttrs out of it.
return ArrayAttr.get(
[IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
)
def _get_int_array_array_attr(
values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
) -> ArrayAttr:
"""Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.
The input has to be a collection of collection of integers, where any
Python Sequence and ArrayAttr are admissible collections and Python ints and
any IntegerAttr are admissible integers. Both levels of collections are
turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
If the input is None, an empty ArrayAttr is returned.
"""
if values is None:
return None
# Make sure the outer level is a list.
values = _get_value_list(values)
# The inner level is now either invalid or a mixed sequence of ArrayAttrs and
# Sequences. Make sure the nested values are all lists.
values = [_get_value_list(nested) for nested in values]
# Turn each nested list into an ArrayAttr.
values = [_get_int_array_attr(nested) for nested in values]
# Turn the outer list into an ArrayAttr.
return ArrayAttr.get(values)
class BufferizeToAllocationOp:
"""Specialization for BufferizeToAllocationOp class."""
def __init__(
self,
target: Union[Operation, OpView, Value],
*,
memory_space: Optional[Union[int, str, Attribute]] = None,
memcpy_op: Optional[str] = None,
alloc_op: Optional[str] = None,
bufferize_destination_only: Optional[bool] = None,
loc=None,
ip=None,
):
# No other types are allowed, so hard-code those here.
allocated_buffer_type = transform.AnyValueType.get()
new_ops_type = transform.AnyOpType.get()
if isinstance(memory_space, int):
memory_space = str(memory_space)
if isinstance(memory_space, str):
memory_space = Attribute.parse(memory_space)
super().__init__(
allocated_buffer_type,
new_ops_type,
target,
memory_space=memory_space,
memcpy_op=memcpy_op,
alloc_op=alloc_op,
bufferize_destination_only=bufferize_destination_only,
loc=loc,
ip=ip,
)
class DecomposeOp:
"""Specialization for DecomposeOp class."""
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
transformed_type = transform.AnyOpType.get()
super().__init__(transformed_type, target, loc=loc, ip=ip)
class FuseIntoContainingOp:
"""Specialization for FuseIntoContainingOp class."""
@overload
def __init__(
self,
fused_op_type: Type,
new_containing_op_type: Type,
producer_op: Union[Operation, OpView, Value],
containing_op: Union[Operation, OpView, Value],
*,
loc=None,
ip=None,
):
...
@overload
def __init__(
self,
producer_op: Union[Operation, OpView, Value],
containing_op: Union[Operation, OpView, Value],
*,
loc=None,
ip=None,
):
...
def __init__(
self,
fused_op_type_or_producer_op: Union[Operation, OpView, Type, Value],
new_containing_op_type_or_containing_op: Union[Operation, OpView, Type, Value],
producer_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
containing_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
*,
loc=None,
ip=None,
):
if isinstance(fused_op_type_or_producer_op, Type):
if not isinstance(new_containing_op_type_or_containing_op, Type):
raise TypeError(
"If 'fused_op_type_or_producer_op' is a type, then "
"'new_containing_op_type_or_containing_op' is expected "
"to be one as well."
)
fused_op_type = fused_op_type_or_producer_op
new_containing_op_type = new_containing_op_type_or_containing_op
producer_op = producer_op_or_none
containing_op = containing_op_or_none
else:
fused_op_type = transform.AnyOpType.get()
new_containing_op_type = transform.AnyOpType.get()
producer_op = fused_op_type_or_producer_op
containing_op = new_containing_op_type_or_containing_op
super().__init__(
fused_op_type,
new_containing_op_type,
producer_op,
containing_op,
loc=loc,
ip=ip,
)
class GeneralizeOp:
"""Specialization for GeneralizeOp class."""
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
transformed_type = transform.AnyOpType.get()
super().__init__(transformed_type, target, loc=loc, ip=ip)
class InterchangeOp:
"""Specialization for InterchangeOp class."""
def __init__(
self,
target: Union[Operation, Value],
*,
iterator_interchange: OptionalIntList = None,
loc=None,
ip=None,
):
transformed_type = transform.AnyOpType.get()
super().__init__(
transformed_type,
target,
iterator_interchange=iterator_interchange,
loc=loc,
ip=ip,
)
class MapCopyToThreadsOp:
"""Specialization for MapCopyToThreadsOp class."""
@overload
def __init__(
self,
forall_op_type: Type,
tiled_op_type: Type,
target: Union[Operation, OpView, Value],
*,
total_num_threads: Union[int, IntegerAttr],
desired_bit_alignment: Union[int, IntegerAttr],
loc=None,
ip=None,
):
...
@overload
def __init__(
self,
target: Union[Operation, OpView, Value],
*,
total_num_threads: Union[int, IntegerAttr],
desired_bit_alignment: Union[int, IntegerAttr],
loc=None,
ip=None,
):
...
def __init__(
self,
forall_op_type_or_target: Union[Operation, OpView, Type, Value],
tiled_op_type_or_none: Optional[Type] = None,
target_or_none: Optional[Union[Operation, OpView, Value]] = None,
*,
total_num_threads: Union[int, IntegerAttr],
desired_bit_alignment: Union[int, IntegerAttr],
loc=None,
ip=None,
):
if isinstance(forall_op_type_or_target, Type):
forall_op_type = forall_op_type_or_target
tiled_op_type = tiled_op_type_or_none
target = target_or_none
else:
forall_op_type = transform.AnyOpType.get()
tiled_op_type = transform.AnyOpType.get()
target = forall_op_type_or_target
super().__init__(
forall_op_type,
tiled_op_type,
target,
total_num_threads=total_num_threads,
desired_bit_alignment=desired_bit_alignment,
loc=loc,
ip=ip,
)
class VectorizeOp:
"""Specialization for VectorizeOp class."""
def __init__(
self,
target: Union[Operation, OpView, Value],
vector_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
*,
vectorize_nd_extract: Optional[bool] = None,
scalable_sizes: OptionalBoolList = None,
static_vector_sizes: OptionalIntList = None,
loc=None,
ip=None,
):
if (
scalable_sizes is None
and static_vector_sizes is None
and vector_sizes is None
):
dynamic_vector_sizes = []
elif scalable_sizes is None and static_vector_sizes is None:
(
dynamic_vector_sizes,
static_vector_sizes,
scalable_sizes,
) = _dispatch_dynamic_index_list(vector_sizes)
elif scalable_sizes is None or static_vector_sizes is None:
raise TypeError(
"'scalable_sizes' and 'static_vector_sizes' must either both "
"be given explicitly or both be given as part of 'vector_sizes'."
)
else:
dynamic_vector_sizes = vector_sizes
super().__init__(
target,
vector_sizes=dynamic_vector_sizes,
static_vector_sizes=static_vector_sizes,
scalable_sizes=scalable_sizes,
vectorize_nd_extract=vectorize_nd_extract,
loc=loc,
ip=ip,
)
class MatchOp:
"""Specialization for MatchOp class."""
@overload
@classmethod
def match_op_names(
cls,
target: Union[Operation, Value],
names: Union[str, Sequence[str]],
*,
loc=None,
ip=None,
):
...
@overload
@classmethod
def match_op_names(
cls,
result_type: Type,
target: Union[Operation, Value],
names: Union[str, Sequence[str]],
*,
loc=None,
ip=None,
):
...
@classmethod
def match_op_names(
cls,
result_type_or_target: Union[Type, Operation, Value],
target_or_names: Union[Operation, Value, Sequence[str], str],
names_or_none: Optional[Union[Sequence[str], str]] = None,
*,
loc=None,
ip=None,
):
if isinstance(result_type_or_target, Type):
result_type = result_type_or_target
target = target_or_names
names = names_or_none
else:
result_type = transform.AnyOpType.get()
target = result_type_or_target
names = target_or_names
if isinstance(names, str):
names = [names]
return cls(
result_type,
target,
ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
loc=loc,
ip=ip,
)
class MultiTileSizesOp:
"""Specialization for MultiTileSizesOp class."""
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
*,
dimension: Union[int, IntegerAttr],
target_size: Union[int, IntegerAttr],
divisor: Optional[Optional[Union[int, IntegerAttr]]] = None,
loc=None,
ip=None,
):
super().__init__(
result_type,
result_type,
result_type,
target,
dimension=dimension,
target_size=target_size,
divisor=divisor,
loc=loc,
ip=ip,
)
class PadOp:
"""Specialization for PadOp class."""
def __init__(
self,
target: Union[Operation, OpView, Value],
*,
padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
padding_dimensions: OptionalIntList = None,
pad_to_multiple_of: OptionalIntList = None,
pack_paddings: OptionalIntList = None,
transpose_paddings: Optional[
Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]
] = None,
copy_back_op: Optional[Union[str, StringAttr]] = None,
loc=None,
ip=None,
):
transpose_paddings = _get_int_array_array_attr(transpose_paddings)
any_op_type = transform.AnyOpType.get()
super().__init__(
any_op_type,
any_op_type,
any_op_type,
target,
padding_values=padding_values,
padding_dimensions=padding_dimensions,
pad_to_multiple_of=pad_to_multiple_of,
pack_paddings=pack_paddings,
transpose_paddings=transpose_paddings,
copy_back_op=copy_back_op,
loc=loc,
ip=ip,
)
class ScalarizeOp:
"""Specialization for ScalarizeOp class."""
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
result_type = transform.AnyOpType.get()
super().__init__(result_type, target, loc=loc, ip=ip)
class SplitOp:
"""Specialization for SplitOp class."""
def __init__(
self,
target: Union[Operation, Value],
dimension: Union[int, Attribute],
split_point: Union[int, Operation, Value, Attribute],
*,
loc=None,
ip=None,
):
if isinstance(split_point, int):
static_split_point = split_point
dynamic_split_point = None
else:
static_split_point = ShapedType.get_dynamic_size()
dynamic_split_point = split_point
super().__init__(
target.type,
target.type,
target,
dimension=dimension,
static_split_point=static_split_point,
dynamic_split_point=dynamic_split_point,
loc=loc,
ip=ip,
)
class TileUsingForOp:
"""Specialization for TileUsingForOp class."""
@overload
def __init__(
self,
loop_types: Union[Type, List[Type]],
target: Union[Operation, Value],
*,
sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
interchange: OptionalIntList = None,
loc=None,
ip=None,
):
...
@overload
def __init__(
self,
target: Union[Operation, Value, OpView],
*,
sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
interchange: OptionalIntList = None,
loc=None,
ip=None,
):
...
def __init__(
self,
loop_types_or_target: Union[Type, List[Type], Operation, Value],
target_or_none: Optional[Union[Operation, Value, OpView]] = None,
*,
sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
interchange: OptionalIntList = None,
loc=None,
ip=None,
):
(
dynamic_sizes,
static_sizes,
scalable_sizes,
) = _dispatch_dynamic_index_list(sizes)
num_loops = sum(v if v == 0 else 1 for v in static_sizes)
if isinstance(loop_types_or_target, (Operation, Value, OpView)):
loop_types = [transform.AnyOpType.get()] * num_loops
target = loop_types_or_target
assert (
target_or_none is None
), "Cannot construct TileUsingForOp with two targets."
else:
loop_types = (
([loop_types_or_target] * num_loops)
if isinstance(loop_types_or_target, Type)
else loop_types_or_target
)
target = target_or_none
super().__init__(
target.type,
loop_types,
target,
dynamic_sizes=dynamic_sizes,
static_sizes=static_sizes,
interchange=interchange,
scalable_sizes=scalable_sizes,
loc=loc,
ip=ip,
)
class TileUsingForallOp:
"""Specialization for TileUsingForallOp class."""
@overload
def __init__(
self,
loops_type: Type,
tiled_op_type: Type,
target: Union[Operation, Value, OpView],
*,
num_threads: Optional[MixedValues] = None,
tile_sizes: MixedValues = None,
mapping=None,
loc=None,
ip=None,
):
...
@overload
def __init__(
self,
target: Union[Operation, Value, OpView],
*,
num_threads: Optional[MixedValues] = None,
tile_sizes: MixedValues = None,
mapping=None,
loc=None,
ip=None,
):
...
def __init__(
self,
loops_type_or_target: Union[
Type, Union[Operation, Value, OpView] # loops_type
], # target
tiled_op_type_or_none: Optional[Type] = None,
target_or_none: Optional[Union[Operation, Value, OpView]] = None,
*,
num_threads: MixedValues = None,
tile_sizes: MixedValues = None,
mapping=None,
loc=None,
ip=None,
):
# `Type` arguments in the front are optional: add default values to front.
if isinstance(loops_type_or_target, Type):
# First overload: type arguments provided.
if not isinstance(tiled_op_type_or_none, Type):
raise TypeError(
"If 'loops_type_or_target' is a type, then "
"'tiled_op_type_or_none' is expected to be one as well."
)
loops_type = loops_type_or_target
tiled_op_type = tiled_op_type_or_none
target = target_or_none
else:
# Last overload: type arguments missing.
loops_type = transform.AnyOpType.get()
tiled_op_type = transform.AnyOpType.get()
target = loops_type_or_target
# Unpack mixed num_threads.
(
dynamic_num_threads,
packed_num_threads,
num_threads_attr,
) = _dispatch_mixed_values(num_threads)
# Unpack mixed tile_sizes.
(
dynamic_tile_sizes,
packed_tile_sizes,
tile_sizes_attr,
) = _dispatch_mixed_values(tile_sizes)
super().__init__(
loops_type,
tiled_op_type,
target=target,
tile_sizes=dynamic_tile_sizes,
packed_tile_sizes=packed_tile_sizes,
static_tile_sizes=tile_sizes_attr,
num_threads=dynamic_num_threads,
packed_num_threads=packed_num_threads,
static_num_threads=num_threads_attr,
mapping=mapping,
loc=loc,
ip=ip,
)
class VectorizeChildrenAndApplyPatternsOp:
"""Specialization for VectorizeChildrenAndApplyPatternsOp class."""
def __init__(
self,
target: Union[Operation, Value],
*,
disable_multi_reduction_to_contract_patterns: bool = False,
disable_transfer_permutation_map_lowering_patterns: bool = False,
vectorize_nd_extract: bool = False,
vectorize_padding: bool = False,
loc=None,
ip=None,
):
transformed_type = transform.AnyOpType.get()
super().__init__(
transformed_type,
target,
disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns,
disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns,
vectorize_nd_extract=vectorize_nd_extract,
vectorize_padding=vectorize_padding,
loc=loc,
ip=ip,
)

View File

@@ -1,44 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Any, Optional, Sequence, Union
from ._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
)
class EmptyOp:
"""Extends the tensor.empty op."""
def __init__(
self,
sizes: Sequence[Union[int, Value]],
element_type: Type,
*,
loc=None,
ip=None
):
"""Constructs an `empty` with mixed static/dynamic sizes."""
# TODO: Refactor the EmptyOp to take an element type attribute and
# then use normal result type inference, unifying the Python and C++ side
# with a standard mechanism (versus stashing that in builders).
dynamic_sizes = []
static_sizes = []
for s in sizes:
if isinstance(s, int):
static_sizes.append(s)
else:
static_sizes.append(ShapedType.get_dynamic_size())
dynamic_sizes.append(s)
result_type = RankedTensorType.get(static_sizes, element_type)
op = self.build_generic(
results=[result_type], operands=dynamic_sizes, attributes={}, loc=loc, ip=ip
)
OpView.__init__(self, op)

View File

@@ -1,64 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ..dialects import transform
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, overload, Union
class MakeLoopIndependentOp:
"""Specialization for MakeLoopIndependentOp class."""
@overload
def __init__(
self,
transformed_type: Type,
target: Union[Operation, OpView, Value],
num_loops: Union[int, IntegerAttr],
*,
loc=None,
ip=None
):
...
@overload
def __init__(
self,
target: Union[Operation, OpView, Value],
num_loops: Union[int, IntegerAttr],
*,
loc=None,
ip=None
):
...
def __init__(
self,
transformed_type_or_target: Type,
target_or_num_loops: Union[int, IntegerAttr, Operation, OpView, Value] = None,
num_loops_or_none: Optional[Union[int, IntegerAttr]] = None,
*,
loc=None,
ip=None
):
if isinstance(transformed_type_or_target, Type):
transformed_type = transformed_type_or_target
target = target_or_num_loops
num_loops = num_loops_or_none
else:
transformed_type = transform.AnyOpType.get()
target = transformed_type_or_target
num_loops = target_or_num_loops
super().__init__(
transformed_type,
target,
num_loops,
loc=loc,
ip=ip,
)

View File

@@ -1,176 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
)
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, Sequence, Union
class CastOp:
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
*,
loc=None,
ip=None,
):
super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)
class ApplyPatternsOp:
def __init__(
self,
target: Union[Operation, Value, OpView],
*,
loc=None,
ip=None,
):
operands = []
operands.append(_get_op_result_or_value(target))
super().__init__(
self.build_generic(
attributes={},
results=[],
operands=operands,
successors=None,
regions=None,
loc=loc,
ip=ip,
)
)
self.regions[0].blocks.append()
@property
def patterns(self) -> Block:
return self.regions[0].blocks[0]
class testGetParentOp:
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
*,
isolated_from_above: bool = False,
op_name: Optional[str] = None,
deduplicate: bool = False,
loc=None,
ip=None,
):
super().__init__(
result_type,
_get_op_result_or_value(target),
isolated_from_above=isolated_from_above,
op_name=op_name,
deduplicate=deduplicate,
loc=loc,
ip=ip,
)
class MergeHandlesOp:
def __init__(
self,
handles: Sequence[Union[Operation, Value]],
*,
deduplicate: bool = False,
loc=None,
ip=None,
):
super().__init__(
[_get_op_result_or_value(h) for h in handles],
deduplicate=deduplicate,
loc=loc,
ip=ip,
)
class ReplicateOp:
def __init__(
self,
pattern: Union[Operation, Value],
handles: Sequence[Union[Operation, Value]],
*,
loc=None,
ip=None,
):
super().__init__(
[_get_op_result_or_value(h).type for h in handles],
_get_op_result_or_value(pattern),
[_get_op_result_or_value(h) for h in handles],
loc=loc,
ip=ip,
)
class SequenceOp:
def __init__(
self,
failure_propagation_mode,
results: Sequence[Type],
target: Union[Operation, Value, Type],
extra_bindings: Optional[
Union[Sequence[Value], Sequence[Type], Operation, OpView]
] = None,
):
root = (
_get_op_result_or_value(target)
if isinstance(target, (Operation, Value))
else None
)
root_type = root.type if not isinstance(target, Type) else target
if extra_bindings is None:
extra_bindings = []
if isinstance(extra_bindings, (Operation, OpView)):
extra_bindings = _get_op_results_or_values(extra_bindings)
extra_binding_types = []
if len(extra_bindings) != 0:
if isinstance(extra_bindings[0], Type):
extra_binding_types = extra_bindings
extra_bindings = []
else:
extra_binding_types = [v.type for v in extra_bindings]
super().__init__(
results_=results,
failure_propagation_mode=failure_propagation_mode,
root=root,
extra_bindings=extra_bindings,
)
self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types))
@property
def body(self) -> Block:
return self.regions[0].blocks[0]
@property
def bodyTarget(self) -> Value:
return self.body.arguments[0]
@property
def bodyExtraArgs(self) -> BlockArgumentList:
return self.body.arguments[1:]
class YieldOp:
def __init__(
self,
operands: Optional[Union[Operation, Sequence[Value]]] = None,
*,
loc=None,
ip=None,
):
if operands is None:
operands = []
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)

View File

@@ -1,55 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
)
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Union
class PDLMatchOp:
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
pattern_name: Union[Attribute, str],
*,
loc=None,
ip=None,
):
super().__init__(
result_type,
_get_op_result_or_value(target),
pattern_name,
loc=loc,
ip=ip,
)
class WithPDLPatternsOp:
def __init__(self,
target: Union[Operation, Value, Type],
*,
loc=None,
ip=None):
root = _get_op_result_or_value(target) if not isinstance(target,
Type) else None
root_type = target if isinstance(target, Type) else root.type
super().__init__(root=root, loc=loc, ip=ip)
self.regions[0].blocks.append(root_type)
@property
def body(self) -> Block:
return self.regions[0].blocks[0]
@property
def bodyTarget(self) -> Value:
return self.body.arguments[0]

View File

@@ -1,5 +1,50 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._affine_ops_gen import *
from ._affine_ops_gen import _Dialect
try:
from ..ir import *
from ._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
_cext as _ods_cext,
)
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, Sequence, Union
@_ods_cext.register_operation(_Dialect, replace=True)
class AffineStoreOp(AffineStoreOp):
"""Specialization for the Affine store operation."""
def __init__(
self,
value: Union[Operation, OpView, Value],
memref: Union[Operation, OpView, Value],
map: AffineMap = None,
*,
map_operands=None,
loc=None,
ip=None,
):
"""Creates an affine store operation.
- `value`: the value to store into the memref.
- `memref`: the buffer to store into.
- `map`: the affine map that maps the map_operands to the index of the
memref.
- `map_operands`: the list of arguments to substitute the dimensions,
then symbols in the affine map, in increasing order.
"""
map = map if map is not None else []
map_operands = map_operands if map_operands is not None else []
indicies = [_get_op_result_or_value(op) for op in map_operands]
_ods_successors = None
super().__init__(
value, memref, indicies, AffineMapAttr.get(map), loc=loc, ip=ip
)

View File

@@ -3,4 +3,75 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._arith_ops_gen import *
from ._arith_ops_gen import _Dialect
from ._arith_enum_gen import *
try:
from ..ir import *
from ._ods_common import (
get_default_loc_context as _get_default_loc_context,
_cext as _ods_cext,
)
from typing import Any, List, Union
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
def _isa(obj: Any, cls: type):
try:
cls(obj)
except ValueError:
return False
return True
def _is_any_of(obj: Any, classes: List[type]):
return any(_isa(obj, cls) for cls in classes)
def _is_integer_like_type(type: Type):
return _is_any_of(type, [IntegerType, IndexType])
def _is_float_type(type: Type):
return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
@_ods_cext.register_operation(_Dialect, replace=True)
class ConstantOp(ConstantOp):
"""Specialization for the constant op class."""
def __init__(
self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
):
if isinstance(value, int):
super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
elif isinstance(value, float):
super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
else:
super().__init__(value, loc=loc, ip=ip)
@classmethod
def create_index(cls, value: int, *, loc=None, ip=None):
"""Create an index-typed constant."""
return cls(
IndexType.get(context=_get_default_loc_context(loc)), value, loc=loc, ip=ip
)
@property
def type(self):
return self.results[0].type
@property
def value(self):
return Attribute(self.operation.attributes["value"])
@property
def literal_value(self) -> Union[int, float]:
if _is_integer_like_type(self.type):
return IntegerAttr(self.value).value
elif _is_float_type(self.type):
return FloatAttr(self.value).value
else:
raise ValueError("only integer and float constants have literal values")

View File

@@ -3,4 +3,40 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._bufferization_ops_gen import *
from ._bufferization_ops_gen import _Dialect
from ._bufferization_enum_gen import *
try:
from typing import Sequence, Union
from ..ir import *
from ._ods_common import get_default_loc_context, _cext as _ods_cext
from typing import Any, List, Union
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
@_ods_cext.register_operation(_Dialect, replace=True)
class AllocTensorOp(AllocTensorOp):
"""Extends the bufferization.alloc_tensor op."""
def __init__(
self,
tensor_type: Type,
dynamic_sizes: Sequence[Value],
copy: Value,
size_hint: Value,
escape: BoolAttr,
*,
loc=None,
ip=None,
):
"""Constructs an `alloc_tensor` with static and/or dynamic sizes."""
super().__init__(
tensor_type,
dynamic_sizes,
copy=copy,
size_hint=size_hint,
loc=loc,
ip=ip,
)

View File

@@ -3,3 +3,23 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._builtin_ops_gen import *
from ._builtin_ops_gen import _Dialect
try:
from ..ir import *
from ._ods_common import _cext as _ods_cext
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
@_ods_cext.register_operation(_Dialect, replace=True)
class ModuleOp(ModuleOp):
"""Specialization for the module op class."""
def __init__(self, *, loc=None, ip=None):
super().__init__(loc=loc, ip=ip)
body = self.regions[0].blocks.append()
@property
def body(self):
return self.regions[0].blocks[0]

View File

@@ -3,3 +3,326 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._func_ops_gen import *
from ._func_ops_gen import _Dialect
try:
from ..ir import *
from ._ods_common import (
get_default_loc_context as _get_default_loc_context,
_cext as _ods_cext,
)
import inspect
from typing import Any, List, Optional, Sequence, Union
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
RESULT_ATTRIBUTE_NAME = "res_attrs"
@_ods_cext.register_operation(_Dialect, replace=True)
class ConstantOp(ConstantOp):
"""Specialization for the constant op class."""
def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
super().__init__(result, value, loc=loc, ip=ip)
@property
def type(self):
return self.results[0].type
@_ods_cext.register_operation(_Dialect, replace=True)
class FuncOp(FuncOp):
"""Specialization for the func op class."""
def __init__(
self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
):
"""
Create a FuncOp with the provided `name`, `type`, and `visibility`.
- `name` is a string representing the function name.
- `type` is either a FunctionType or a pair of list describing inputs and
results.
- `visibility` is a string matching `public`, `private`, or `nested`. None
implies private visibility.
- `body_builder` is an optional callback, when provided a new entry block
is created and the callback is invoked with the new op as argument within
an InsertionPoint context already set for the block. The callback is
expected to insert a terminator in the block.
"""
sym_name = StringAttr.get(str(name))
# If the type is passed as a tuple, build a FunctionType on the fly.
if isinstance(type, tuple):
type = FunctionType.get(inputs=type[0], results=type[1])
type = TypeAttr.get(type)
sym_visibility = (
StringAttr.get(str(visibility)) if visibility is not None else None
)
super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
if body_builder:
entry_block = self.add_entry_block()
with InsertionPoint(entry_block):
body_builder(self)
@property
def is_external(self):
return len(self.regions[0].blocks) == 0
@property
def body(self):
return self.regions[0]
@property
def type(self):
return FunctionType(TypeAttr(self.attributes["function_type"]).value)
@property
def visibility(self):
return self.attributes["sym_visibility"]
@property
def name(self) -> StringAttr:
return StringAttr(self.attributes["sym_name"])
@property
def entry_block(self):
if self.is_external:
raise IndexError("External function does not have a body")
return self.regions[0].blocks[0]
def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
"""
Add an entry block to the function body using the function signature to
infer block arguments.
Returns the newly created block
"""
if not self.is_external:
raise IndexError("The function already has an entry block!")
self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs)
return self.body.blocks[0]
@property
def arg_attrs(self):
return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
@arg_attrs.setter
def arg_attrs(self, attribute: Union[ArrayAttr, list]):
if isinstance(attribute, ArrayAttr):
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
else:
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
attribute, context=self.context
)
@property
def arguments(self):
return self.entry_block.arguments
@property
def result_attrs(self):
return self.attributes[RESULT_ATTRIBUTE_NAME]
@result_attrs.setter
def result_attrs(self, attribute: ArrayAttr):
self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
@classmethod
def from_py_func(
FuncOp,
*inputs: Type,
results: Optional[Sequence[Type]] = None,
name: Optional[str] = None,
):
"""Decorator to define an MLIR FuncOp specified as a python function.
Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
active for the current thread (i.e. established in a `with` block).
When applied as a decorator to a Python function, an entry block will
be constructed for the FuncOp with types as specified in `*inputs`. The
block arguments will be passed positionally to the Python function. In
addition, if the Python function accepts keyword arguments generally or
has a corresponding keyword argument, the following will be passed:
* `func_op`: The `func` op being defined.
By default, the function name will be the Python function `__name__`. This
can be overriden by passing the `name` argument to the decorator.
If `results` is not specified, then the decorator will implicitly
insert a `ReturnOp` with the `Value`'s returned from the decorated
function. It will also set the `FuncOp` type with the actual return
value types. If `results` is specified, then the decorated function
must return `None` and no implicit `ReturnOp` is added (nor are the result
types updated). The implicit behavior is intended for simple, single-block
cases, and users should specify result types explicitly for any complicated
cases.
The decorated function can further be called from Python and will insert
a `CallOp` at the then-current insertion point, returning either None (
if no return values), a unary Value (for one result), or a list of Values).
This mechanism cannot be used to emit recursive calls (by construction).
"""
def decorator(f):
from . import func
# Introspect the callable for optional features.
sig = inspect.signature(f)
has_arg_func_op = False
for param in sig.parameters.values():
if param.kind == param.VAR_KEYWORD:
has_arg_func_op = True
if param.name == "func_op" and (
param.kind == param.POSITIONAL_OR_KEYWORD
or param.kind == param.KEYWORD_ONLY
):
has_arg_func_op = True
# Emit the FuncOp.
implicit_return = results is None
symbol_name = name or f.__name__
function_type = FunctionType.get(
inputs=inputs, results=[] if implicit_return else results
)
func_op = FuncOp(name=symbol_name, type=function_type)
with InsertionPoint(func_op.add_entry_block()):
func_args = func_op.entry_block.arguments
func_kwargs = {}
if has_arg_func_op:
func_kwargs["func_op"] = func_op
return_values = f(*func_args, **func_kwargs)
if not implicit_return:
return_types = list(results)
assert return_values is None, (
"Capturing a python function with explicit `results=` "
"requires that the wrapped function returns None."
)
else:
# Coerce return values, add ReturnOp and rewrite func type.
if return_values is None:
return_values = []
elif isinstance(return_values, tuple):
return_values = list(return_values)
elif isinstance(return_values, Value):
# Returning a single value is fine, coerce it into a list.
return_values = [return_values]
elif isinstance(return_values, OpView):
# Returning a single operation is fine, coerce its results a list.
return_values = return_values.operation.results
elif isinstance(return_values, Operation):
# Returning a single operation is fine, coerce its results a list.
return_values = return_values.results
else:
return_values = list(return_values)
func.ReturnOp(return_values)
# Recompute the function type.
return_types = [v.type for v in return_values]
function_type = FunctionType.get(
inputs=inputs, results=return_types
)
func_op.attributes["function_type"] = TypeAttr.get(function_type)
def emit_call_op(*call_args):
call_op = func.CallOp(
return_types, FlatSymbolRefAttr.get(symbol_name), call_args
)
if return_types is None:
return None
elif len(return_types) == 1:
return call_op.result
else:
return call_op.results
wrapped = emit_call_op
wrapped.__name__ = f.__name__
wrapped.func_op = func_op
return wrapped
return decorator
@_ods_cext.register_operation(_Dialect, replace=True)
class CallOp(CallOp):
"""Specialization for the call op class."""
def __init__(
self,
calleeOrResults: Union[FuncOp, List[Type]],
argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
arguments: Optional[List] = None,
*,
loc=None,
ip=None,
):
"""Creates an call operation.
The constructor accepts three different forms:
1. A function op to be called followed by a list of arguments.
2. A list of result types, followed by the name of the function to be
called as string, following by a list of arguments.
3. A list of result types, followed by the name of the function to be
called as symbol reference attribute, followed by a list of arguments.
For example
f = func.FuncOp("foo", ...)
func.CallOp(f, [args])
func.CallOp([result_types], "foo", [args])
In all cases, the location and insertion point may be specified as keyword
arguments if not provided by the surrounding context managers.
"""
# TODO: consider supporting constructor "overloads", e.g., through a custom
# or pybind-provided metaclass.
if isinstance(calleeOrResults, FuncOp):
if not isinstance(argumentsOrCallee, list):
raise ValueError(
"when constructing a call to a function, expected "
+ "the second argument to be a list of call arguments, "
+ f"got {type(argumentsOrCallee)}"
)
if arguments is not None:
raise ValueError(
"unexpected third argument when constructing a call"
+ "to a function"
)
super().__init__(
calleeOrResults.type.results,
FlatSymbolRefAttr.get(
calleeOrResults.name.value, context=_get_default_loc_context(loc)
),
argumentsOrCallee,
loc=loc,
ip=ip,
)
return
if isinstance(argumentsOrCallee, list):
raise ValueError(
"when constructing a call to a function by name, "
+ "expected the second argument to be a string or a "
+ f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}"
)
if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
super().__init__(
calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip
)
elif isinstance(argumentsOrCallee, str):
super().__init__(
calleeOrResults,
FlatSymbolRefAttr.get(
argumentsOrCallee, context=_get_default_loc_context(loc)
),
arguments,
loc=loc,
ip=ip,
)

View File

@@ -310,7 +310,7 @@ def emit_named_structured_op(
)
# Set the index attributes used to compute the indexing maps.
named_op = getattr(linalg, op_class_name)(ins, outs, result_types)
named_op = getattr(linalg, op_class_name)(result_types, ins, outs)
for name, value in index_attrs.items():
named_op.operation.attributes[name] = value

View File

@@ -296,35 +296,39 @@ def quantized_matmul(
@linalg_structured_op
def matmul_transpose_a(A=TensorDef(T1, S.K, S.N),
B=TensorDef(T2, S.K, S.M),
C=TensorDef(U, S.M, S.N, output=True),
cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
"""Performs a matrix multiplication of two 2D inputs with lhs operand
transposed.
def matmul_transpose_a(
A=TensorDef(T1, S.K, S.N),
B=TensorDef(T2, S.K, S.M),
C=TensorDef(U, S.M, S.N, output=True),
cast=TypeFnAttrDef(default=TypeFn.cast_signed),
):
"""Performs a matrix multiplication of two 2D inputs with lhs operand
transposed.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
domain(D.m, D.n, D.k)
implements(ContractionOpInterface)
C[D.m, D.n] += cast(U, A[D.k, D.m]) * cast(U, B[D.k, D.n])
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
domain(D.m, D.n, D.k)
implements(ContractionOpInterface)
C[D.m, D.n] += cast(U, A[D.k, D.m]) * cast(U, B[D.k, D.n])
@linalg_structured_op
def matmul_transpose_b(A=TensorDef(T1, S.M, S.K),
B=TensorDef(T2, S.N, S.K),
C=TensorDef(U, S.M, S.N, output=True),
cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
"""Performs a matrix multiplication of two 2D inputs with rhs operand
transposed.
def matmul_transpose_b(
A=TensorDef(T1, S.M, S.K),
B=TensorDef(T2, S.N, S.K),
C=TensorDef(U, S.M, S.N, output=True),
cast=TypeFnAttrDef(default=TypeFn.cast_signed),
):
"""Performs a matrix multiplication of two 2D inputs with rhs operand
transposed.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
domain(D.m, D.n, D.k)
implements(ContractionOpInterface)
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.n, D.k])
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
domain(D.m, D.n, D.k)
implements(ContractionOpInterface)
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.n, D.k])
@linalg_structured_op
@@ -390,36 +394,41 @@ def batch_matmul(
@linalg_structured_op
def batch_matmul_transpose_a(A=TensorDef(T1, Batch, S.K, S.M),
B=TensorDef(T2, Batch, S.K, S.N),
C=TensorDef(U, Batch, S.M, S.N, output=True)):
"""Performs a batched matrix multiplication of two 3D inputs where lhs operand
has its non-batch dimensions transposed.
def batch_matmul_transpose_a(
A=TensorDef(T1, Batch, S.K, S.M),
B=TensorDef(T2, Batch, S.K, S.N),
C=TensorDef(U, Batch, S.M, S.N, output=True),
):
"""Performs a batched matrix multiplication of two 3D inputs where lhs operand
has its non-batch dimensions transposed.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
domain(D.b, D.m, D.n, D.k)
implements(ContractionOpInterface)
C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) \
* TypeFn.cast_signed(U, B[D.b, D.k, D.n])
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
domain(D.b, D.m, D.n, D.k)
implements(ContractionOpInterface)
C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) * TypeFn.cast_signed(
U, B[D.b, D.k, D.n]
)
@linalg_structured_op
def batch_matmul_transpose_b(A=TensorDef(T1, Batch, S.M, S.K),
B=TensorDef(T2, Batch, S.N, S.K),
C=TensorDef(U, Batch, S.M, S.N, output=True)):
"""Performs a batched matrix multiplication of two 3D inputs where rhs operand
has its non-batch dimensions transposed.
def batch_matmul_transpose_b(
A=TensorDef(T1, Batch, S.M, S.K),
B=TensorDef(T2, Batch, S.N, S.K),
C=TensorDef(U, Batch, S.M, S.N, output=True),
):
"""Performs a batched matrix multiplication of two 3D inputs where rhs operand
has its non-batch dimensions transposed.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
domain(D.b, D.m, D.n, D.k)
implements(ContractionOpInterface)
C[D.b, D.m,
D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
U, B[D.b, D.n, D.k])
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
domain(D.b, D.m, D.n, D.k)
implements(ContractionOpInterface)
C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
U, B[D.b, D.n, D.k]
)
@linalg_structured_op

View File

@@ -3,3 +3,41 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._memref_ops_gen import *
from ._memref_ops_gen import _Dialect
try:
from ..ir import *
from ._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
_cext as _ods_cext,
)
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, Sequence, Union
@_ods_cext.register_operation(_Dialect, replace=True)
class LoadOp(LoadOp):
"""Specialization for the MemRef load operation."""
def __init__(
self,
memref: Union[Operation, OpView, Value],
indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
*,
loc=None,
ip=None,
):
"""Creates a memref load operation.
Args:
memref: the buffer to load from.
indices: the list of subscripts, may be empty for zero-dimensional
buffers.
loc: user-visible location of the operation.
ip: insertion point.
"""
indices_resolved = [] if indices is None else _get_op_results_or_values(indices)
super().__init__(memref, indices_resolved, loc=loc, ip=ip)

View File

@@ -2,4 +2,118 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from typing import Union
from ._ml_program_ops_gen import *
from ._ml_program_ops_gen import _Dialect
try:
from ..ir import *
from ._ods_common import (
get_default_loc_context as _get_default_loc_context,
_cext as _ods_cext,
)
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
RESULT_ATTRIBUTE_NAME = "res_attrs"
@_ods_cext.register_operation(_Dialect, replace=True)
class FuncOp(FuncOp):
"""Specialization for the func op class."""
def __init__(
self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
):
"""
Create a FuncOp with the provided `name`, `type`, and `visibility`.
- `name` is a string representing the function name.
- `type` is either a FunctionType or a pair of list describing inputs and
results.
- `visibility` is a string matching `public`, `private`, or `nested`. None
implies private visibility.
- `body_builder` is an optional callback, when provided a new entry block
is created and the callback is invoked with the new op as argument within
an InsertionPoint context already set for the block. The callback is
expected to insert a terminator in the block.
"""
sym_name = StringAttr.get(str(name))
# If the type is passed as a tuple, build a FunctionType on the fly.
if isinstance(type, tuple):
type = FunctionType.get(inputs=type[0], results=type[1])
type = TypeAttr.get(type)
sym_visibility = (
StringAttr.get(str(visibility)) if visibility is not None else None
)
super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
if body_builder:
entry_block = self.add_entry_block()
with InsertionPoint(entry_block):
body_builder(self)
@property
def is_external(self):
return len(self.regions[0].blocks) == 0
@property
def body(self):
return self.regions[0]
@property
def type(self):
return FunctionType(TypeAttr(self.attributes["function_type"]).value)
@property
def visibility(self):
return self.attributes["sym_visibility"]
@property
def name(self) -> StringAttr:
return StringAttr(self.attributes["sym_name"])
@property
def entry_block(self):
if self.is_external:
raise IndexError("External function does not have a body")
return self.regions[0].blocks[0]
def add_entry_block(self):
"""
Add an entry block to the function body using the function signature to
infer block arguments.
Returns the newly created block
"""
if not self.is_external:
raise IndexError("The function already has an entry block!")
self.body.blocks.append(*self.type.inputs)
return self.body.blocks[0]
@property
def arg_attrs(self):
return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
@arg_attrs.setter
def arg_attrs(self, attribute: Union[ArrayAttr, list]):
if isinstance(attribute, ArrayAttr):
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
else:
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
attribute, context=self.context
)
@property
def arguments(self):
return self.entry_block.arguments
@property
def result_attrs(self):
return self.attributes[RESULT_ATTRIBUTE_NAME]
@result_attrs.setter
def result_attrs(self, attribute: ArrayAttr):
self.attributes[RESULT_ATTRIBUTE_NAME] = attribute

View File

@@ -3,4 +3,289 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._pdl_ops_gen import *
from ._pdl_ops_gen import _Dialect
from .._mlir_libs._mlirDialectsPDL import *
try:
from ..ir import *
from ..dialects import pdl
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Union, Optional, Sequence, Mapping
from ._ods_common import (
get_op_result_or_value as _get_value,
get_op_results_or_values as _get_values,
_cext as _ods_cext,
)
@_ods_cext.register_operation(_Dialect, replace=True)
class ApplyNativeConstraintOp(ApplyNativeConstraintOp):
"""Specialization for PDL apply native constraint op class."""
def __init__(
self,
name: Union[str, StringAttr],
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
*,
loc=None,
ip=None,
):
if args is None:
args = []
args = _get_values(args)
super().__init__(name, args, loc=loc, ip=ip)
@_ods_cext.register_operation(_Dialect, replace=True)
class ApplyNativeRewriteOp(ApplyNativeRewriteOp):
"""Specialization for PDL apply native rewrite op class."""
def __init__(
self,
results: Sequence[Type],
name: Union[str, StringAttr],
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
*,
loc=None,
ip=None,
):
if args is None:
args = []
args = _get_values(args)
super().__init__(results, name, args, loc=loc, ip=ip)
@_ods_cext.register_operation(_Dialect, replace=True)
class AttributeOp(AttributeOp):
"""Specialization for PDL attribute op class."""
def __init__(
self,
valueType: Optional[Union[OpView, Operation, Value]] = None,
value: Optional[Attribute] = None,
*,
loc=None,
ip=None,
):
valueType = valueType if valueType is None else _get_value(valueType)
result = pdl.AttributeType.get()
super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
@_ods_cext.register_operation(_Dialect, replace=True)
class EraseOp(EraseOp):
"""Specialization for PDL erase op class."""
def __init__(
self,
operation: Optional[Union[OpView, Operation, Value]] = None,
*,
loc=None,
ip=None,
):
operation = _get_value(operation)
super().__init__(operation, loc=loc, ip=ip)
@_ods_cext.register_operation(_Dialect, replace=True)
class OperandOp(OperandOp):
"""Specialization for PDL operand op class."""
def __init__(
self,
type: Optional[Union[OpView, Operation, Value]] = None,
*,
loc=None,
ip=None,
):
type = type if type is None else _get_value(type)
result = pdl.ValueType.get()
super().__init__(result, valueType=type, loc=loc, ip=ip)
@_ods_cext.register_operation(_Dialect, replace=True)
class OperandsOp(OperandsOp):
"""Specialization for PDL operands op class."""
def __init__(
self,
types: Optional[Union[OpView, Operation, Value]] = None,
*,
loc=None,
ip=None,
):
types = types if types is None else _get_value(types)
result = pdl.RangeType.get(pdl.ValueType.get())
super().__init__(result, valueType=types, loc=loc, ip=ip)
@_ods_cext.register_operation(_Dialect, replace=True)
class OperationOp(OperationOp):
"""Specialization for PDL operand op class."""
def __init__(
self,
name: Optional[Union[str, StringAttr]] = None,
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
attributes: Optional[Mapping[str, Union[OpView, Operation, Value]]] = None,
types: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
*,
loc=None,
ip=None,
):
if types is None:
types = []
if attributes is None:
attributes = {}
if args is None:
args = []
args = _get_values(args)
attrNames = []
attrValues = []
for attrName, attrValue in attributes.items():
attrNames.append(StringAttr.get(attrName))
attrValues.append(_get_value(attrValue))
attrNames = ArrayAttr.get(attrNames)
types = _get_values(types)
result = pdl.OperationType.get()
super().__init__(
result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip
)
@_ods_cext.register_operation(_Dialect, replace=True)
class PatternOp(PatternOp):
"""Specialization for PDL pattern op class."""
def __init__(
self,
benefit: Union[IntegerAttr, int],
name: Optional[Union[StringAttr, str]] = None,
*,
loc=None,
ip=None,
):
"""Creates an PDL `pattern` operation."""
super().__init__(benefit, sym_name=name, loc=loc, ip=ip)
self.regions[0].blocks.append()
@property
def body(self):
"""Return the body (block) of the pattern."""
return self.regions[0].blocks[0]
@_ods_cext.register_operation(_Dialect, replace=True)
class ReplaceOp(ReplaceOp):
"""Specialization for PDL replace op class."""
def __init__(
self,
op: Union[OpView, Operation, Value],
*,
with_op: Optional[Union[OpView, Operation, Value]] = None,
with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
loc=None,
ip=None,
):
if with_values is None:
with_values = []
op = _get_value(op)
with_op = with_op if with_op is None else _get_value(with_op)
with_values = _get_values(with_values)
super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip)
@_ods_cext.register_operation(_Dialect, replace=True)
class ResultOp(ResultOp):
"""Specialization for PDL result op class."""
def __init__(
self,
parent: Union[OpView, Operation, Value],
index: Union[IntegerAttr, int],
*,
loc=None,
ip=None,
):
parent = _get_value(parent)
result = pdl.ValueType.get()
super().__init__(result, parent, index, loc=loc, ip=ip)
@_ods_cext.register_operation(_Dialect, replace=True)
class ResultsOp(ResultsOp):
"""Specialization for PDL results op class."""
def __init__(
self,
result: Type,
parent: Union[OpView, Operation, Value],
index: Optional[Union[IntegerAttr, int]] = None,
*,
loc=None,
ip=None,
):
parent = _get_value(parent)
super().__init__(result, parent, index=index, loc=loc, ip=ip)
@_ods_cext.register_operation(_Dialect, replace=True)
class RewriteOp(RewriteOp):
"""Specialization for PDL rewrite op class."""
def __init__(
self,
root: Optional[Union[OpView, Operation, Value]] = None,
name: Optional[Union[StringAttr, str]] = None,
args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
*,
loc=None,
ip=None,
):
if args is None:
args = []
root = root if root is None else _get_value(root)
args = _get_values(args)
super().__init__(args, root=root, name=name, loc=loc, ip=ip)
def add_body(self):
"""Add body (block) to the rewrite."""
self.regions[0].blocks.append()
return self.body
@property
def body(self):
"""Return the body (block) of the rewrite."""
return self.regions[0].blocks[0]
@_ods_cext.register_operation(_Dialect, replace=True)
class TypeOp(TypeOp):
"""Specialization for PDL type op class."""
def __init__(
self, constantType: Optional[Union[TypeAttr, Type]] = None, *, loc=None, ip=None
):
result = pdl.TypeType.get()
super().__init__(result, constantType=constantType, loc=loc, ip=ip)
@_ods_cext.register_operation(_Dialect, replace=True)
class TypesOp(TypesOp):
"""Specialization for PDL types op class."""
def __init__(
self,
constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None,
*,
loc=None,
ip=None,
):
if constantTypes is None:
constantTypes = []
result = pdl.RangeType.get(pdl.TypeType.get())
super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)

View File

@@ -3,7 +3,12 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._python_test_ops_gen import *
from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestIntegerRankedTensorType
from .._mlir_libs._mlirPythonTest import (
TestAttr,
TestType,
TestTensorValue,
TestIntegerRankedTensorType,
)
def register_python_test_dialect(context, load=True):

View File

@@ -2,11 +2,122 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from typing import Optional, Sequence
from ._scf_ops_gen import *
from ._scf_ops_gen import _Dialect
from .arith import constant
from ..ir import *
try:
from ..ir import *
from ._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
_cext as _ods_cext,
)
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, Sequence, Union
_ForOp = ForOp
@_ods_cext.register_operation(_Dialect, replace=True)
class ForOp(_ForOp):
"""Specialization for the SCF for op class."""
def __init__(
self,
lower_bound,
upper_bound,
step,
iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
*,
loc=None,
ip=None,
):
"""Creates an SCF `for` operation.
- `lower_bound` is the value to use as lower bound of the loop.
- `upper_bound` is the value to use as upper bound of the loop.
- `step` is the value to use as loop step.
- `iter_args` is a list of additional loop-carried arguments or an operation
producing them as results.
"""
if iter_args is None:
iter_args = []
iter_args = _get_op_results_or_values(iter_args)
results = [arg.type for arg in iter_args]
super(_ForOp, self).__init__(
self.build_generic(
regions=1,
results=results,
operands=[
_get_op_result_or_value(o) for o in [lower_bound, upper_bound, step]
]
+ list(iter_args),
loc=loc,
ip=ip,
)
)
self.regions[0].blocks.append(self.operands[0].type, *results)
@property
def body(self):
"""Returns the body (block) of the loop."""
return self.regions[0].blocks[0]
@property
def induction_variable(self):
"""Returns the induction variable of the loop."""
return self.body.arguments[0]
@property
def inner_iter_args(self):
"""Returns the loop-carried arguments usable within the loop.
To obtain the loop-carried operands, use `iter_args`.
"""
return self.body.arguments[1:]
_IfOp = IfOp
@_ods_cext.register_operation(_Dialect, replace=True)
class IfOp(_IfOp):
"""Specialization for the SCF if op class."""
def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None):
"""Creates an SCF `if` operation.
- `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed.
- `hasElse` determines whether the if operation has the else branch.
"""
operands = []
operands.append(cond)
results = []
results.extend(results_)
super(_IfOp, self).__init__(
self.build_generic(
regions=2, results=results, operands=operands, loc=loc, ip=ip
)
)
self.regions[0].blocks.append(*[])
if hasElse:
self.regions[1].blocks.append(*[])
@property
def then_block(self):
"""Returns the then block of the if operation."""
return self.regions[0].blocks[0]
@property
def else_block(self):
"""Returns the else block of the if operation."""
return self.regions[1].blocks[0]
def for_(

View File

@@ -3,3 +3,40 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._tensor_ops_gen import *
from ._tensor_ops_gen import _Dialect
try:
from ..ir import *
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Sequence, Union
from ._ods_common import _cext as _ods_cext
@_ods_cext.register_operation(_Dialect, replace=True)
class EmptyOp(EmptyOp):
"""Extends the tensor.empty op."""
def __init__(
self,
sizes: Sequence[Union[int, Value]],
element_type: Type,
*,
loc=None,
ip=None,
):
"""Constructs an `empty` with mixed static/dynamic sizes."""
# TODO: Refactor the EmptyOp to take an element type attribute and
# then use normal result type inference, unifying the Python and C++ side
# with a standard mechanism (versus stashing that in builders).
dynamic_sizes = []
static_sizes = []
for s in sizes:
if isinstance(s, int):
static_sizes.append(s)
else:
static_sizes.append(ShapedType.get_dynamic_size())
dynamic_sizes.append(s)
result_type = RankedTensorType.get(static_sizes, element_type)
super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip)

View File

@@ -4,4 +4,174 @@
from .._transform_enum_gen import *
from .._transform_ops_gen import *
from .._transform_ops_gen import _Dialect
from ..._mlir_libs._mlirDialectsTransform import *
try:
from ...ir import *
from .._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
_cext as _ods_cext,
)
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, Sequence, Union
@_ods_cext.register_operation(_Dialect, replace=True)
class CastOp(CastOp):
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
*,
loc=None,
ip=None,
):
super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)
@_ods_cext.register_operation(_Dialect, replace=True)
class ApplyPatternsOp(ApplyPatternsOp):
def __init__(
self,
target: Union[Operation, Value, OpView],
*,
loc=None,
ip=None,
):
super().__init__(target, loc=loc, ip=ip)
self.regions[0].blocks.append()
@property
def patterns(self) -> Block:
return self.regions[0].blocks[0]
@_ods_cext.register_operation(_Dialect, replace=True)
class GetParentOp(GetParentOp):
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
*,
isolated_from_above: bool = False,
op_name: Optional[str] = None,
deduplicate: bool = False,
loc=None,
ip=None,
):
super().__init__(
result_type,
_get_op_result_or_value(target),
isolated_from_above=isolated_from_above,
op_name=op_name,
deduplicate=deduplicate,
loc=loc,
ip=ip,
)
@_ods_cext.register_operation(_Dialect, replace=True)
class MergeHandlesOp(MergeHandlesOp):
def __init__(
self,
handles: Sequence[Union[Operation, Value]],
*,
deduplicate: bool = False,
loc=None,
ip=None,
):
super().__init__(
[_get_op_result_or_value(h) for h in handles],
deduplicate=deduplicate,
loc=loc,
ip=ip,
)
@_ods_cext.register_operation(_Dialect, replace=True)
class ReplicateOp(ReplicateOp):
def __init__(
self,
pattern: Union[Operation, Value],
handles: Sequence[Union[Operation, Value]],
*,
loc=None,
ip=None,
):
super().__init__(
[_get_op_result_or_value(h).type for h in handles],
_get_op_result_or_value(pattern),
[_get_op_result_or_value(h) for h in handles],
loc=loc,
ip=ip,
)
@_ods_cext.register_operation(_Dialect, replace=True)
class SequenceOp(SequenceOp):
def __init__(
self,
failure_propagation_mode,
results: Sequence[Type],
target: Union[Operation, Value, Type],
extra_bindings: Optional[
Union[Sequence[Value], Sequence[Type], Operation, OpView]
] = None,
):
root = (
_get_op_result_or_value(target)
if isinstance(target, (Operation, Value))
else None
)
root_type = root.type if not isinstance(target, Type) else target
if extra_bindings is None:
extra_bindings = []
if isinstance(extra_bindings, (Operation, OpView)):
extra_bindings = _get_op_results_or_values(extra_bindings)
extra_binding_types = []
if len(extra_bindings) != 0:
if isinstance(extra_bindings[0], Type):
extra_binding_types = extra_bindings
extra_bindings = []
else:
extra_binding_types = [v.type for v in extra_bindings]
super().__init__(
results_=results,
failure_propagation_mode=failure_propagation_mode,
root=root,
extra_bindings=extra_bindings,
)
self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types))
@property
def body(self) -> Block:
return self.regions[0].blocks[0]
@property
def bodyTarget(self) -> Value:
return self.body.arguments[0]
@property
def bodyExtraArgs(self) -> BlockArgumentList:
return self.body.arguments[1:]
@_ods_cext.register_operation(_Dialect, replace=True)
class YieldOp(YieldOp):
def __init__(
self,
operands: Optional[Union[Operation, Sequence[Value]]] = None,
*,
loc=None,
ip=None,
):
if operands is None:
operands = []
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)

View File

@@ -3,3 +3,132 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._bufferization_transform_ops_gen import *
from .._bufferization_transform_ops_gen import _Dialect
try:
from ...ir import *
from ...dialects import transform
from .._ods_common import _cext as _ods_cext
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from enum import Enum
from typing import Optional, overload, Union
@_ods_cext.register_operation(_Dialect, replace=True)
class EmptyTensorToAllocTensorOp(EmptyTensorToAllocTensorOp):
"""Specialization for EmptyTensorToAllocTensorOp class."""
@overload
def __init__(
self,
transformed_type: Type,
target: Union[Operation, OpView, Value],
*,
loc=None,
ip=None,
):
...
@overload
def __init__(self, target: Union[Operation, OpView, Value], *, loc=None, ip=None):
...
def __init__(
self,
transformed_type_or_target: Type,
target_or_none: Optional[Union[Operation, OpView, Value]] = None,
*,
loc=None,
ip=None,
):
if isinstance(transformed_type_or_target, Type):
transformed_type = transformed_type_or_target
target = target_or_none
else:
transformed_type = transform.OperationType.get("bufferization.alloc_tensor")
target = transformed_type_or_target
super().__init__(
transformed_type,
target,
loc=loc,
ip=ip,
)
@_ods_cext.register_operation(_Dialect, replace=True)
class OneShotBufferizeOp(OneShotBufferizeOp):
"""Specialization for OneShotBufferizeOp class."""
@overload
def __init__(
self,
transformed_type: Type,
target: Union[Operation, OpView, Value],
*,
allow_return_allocs_from_loops: Optional[bool] = None,
allow_unknown_ops: Optional[bool] = None,
bufferize_function_boundaries: Optional[bool] = None,
function_boundary_type_conversion: Optional[Enum] = None,
memcpy_op: Optional[str] = None,
print_conflicts: Optional[bool] = None,
test_analysis_only: Optional[bool] = None,
loc=None,
ip=None,
):
...
@overload
def __init__(
self,
target: Union[Operation, OpView, Value],
*,
allow_return_allocs_from_loops: Optional[bool] = None,
allow_unknown_ops: Optional[bool] = None,
bufferize_function_boundaries: Optional[bool] = None,
function_boundary_type_conversion: Optional[Enum] = None,
memcpy_op: Optional[str] = None,
print_conflicts: Optional[bool] = None,
test_analysis_only: Optional[bool] = None,
loc=None,
ip=None,
):
...
def __init__(
self,
transformed_type_or_target: Type,
target_or_none: Optional[Union[Operation, OpView, Value]] = None,
*,
allow_return_allocs_from_loops: Optional[bool] = None,
allow_unknown_ops: Optional[bool] = None,
bufferize_function_boundaries: Optional[bool] = None,
function_boundary_type_conversion: Optional[Enum] = None,
memcpy_op: Optional[str] = None,
print_conflicts: Optional[bool] = None,
test_analysis_only: Optional[bool] = None,
loc=None,
ip=None,
):
if isinstance(transformed_type_or_target, Type):
transformed_type = transformed_type_or_target
target = target_or_none
else:
transformed_type = transform.AnyOpType.get()
target = transformed_type_or_target
super().__init__(
transformed_type,
target,
allow_return_allocs_from_loops=allow_return_allocs_from_loops,
allow_unknown_ops=allow_unknown_ops,
bufferize_function_boundaries=bufferize_function_boundaries,
function_boundary_type_conversion=function_boundary_type_conversion,
memcpy_op=memcpy_op,
print_conflicts=print_conflicts,
test_analysis_only=test_analysis_only,
loc=loc,
ip=ip,
)

View File

@@ -3,3 +3,128 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._gpu_transform_ops_gen import *
from .._gpu_transform_ops_gen import _Dialect
try:
from ...ir import *
from ...dialects import transform
from .._ods_common import _cext as _ods_cext
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, Sequence, Union, overload
@_ods_cext.register_operation(_Dialect, replace=True)
class MapForallToBlocks(MapForallToBlocks):
"""Specialization for MapForallToBlocks class."""
@overload
def __init__(
self,
result_type: Type,
target: Union[Operation, OpView, Value],
*,
grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
loc=None,
ip=None,
):
...
@overload
def __init__(
self,
target: Union[Operation, OpView, Value],
*,
grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
loc=None,
ip=None,
):
...
def __init__(
self,
result_type_or_target: Union[Operation, OpView, Type, Value],
target_or_none: Optional[Union[Operation, OpView, Value]] = None,
*,
grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
loc=None,
ip=None,
):
if isinstance(result_type_or_target, Type):
result_type = result_type_or_target
target = target_or_none
else:
result_type = transform.AnyOpType.get()
target = result_type_or_target
super().__init__(
result_type,
target,
grid_dims=grid_dims,
generate_gpu_launch=generate_gpu_launch,
loc=loc,
ip=ip,
)
@_ods_cext.register_operation(_Dialect, replace=True)
class MapNestedForallToThreads(MapNestedForallToThreads):
"""Specialization for MapNestedForallToThreads class."""
@overload
def __init__(
self,
result_type: Type,
target: Union[Operation, OpView, Value],
*,
block_dims: Optional[Sequence[int]] = None,
warp_size: Optional[Sequence[int]] = None,
sync_after_distribute: Optional[bool] = None,
loc=None,
ip=None,
):
...
@overload
def __init__(
self,
target: Union[Operation, OpView, Value],
*,
block_dims: Optional[Sequence[int]] = None,
warp_size: Optional[Sequence[int]] = None,
sync_after_distribute: Optional[bool] = None,
loc=None,
ip=None,
):
...
def __init__(
self,
result_type_or_target: Union[Operation, OpView, Value, Type],
target_or_none: Optional[Union[Operation, OpView, Value]] = None,
*,
block_dims: Optional[Union[Sequence[int], Attribute]] = None,
warp_size: Optional[Union[Sequence[int], Attribute]] = None,
sync_after_distribute: Optional[bool] = None,
loc=None,
ip=None,
):
if isinstance(result_type_or_target, Type):
result_type = result_type_or_target
target = target_or_none
else:
result_type = result_type_or_target.type
target = result_type_or_target
super().__init__(
result_type,
target,
block_dims=block_dims,
warp_size=warp_size,
sync_after_distribute=sync_after_distribute,
loc=loc,
ip=ip,
)

View File

@@ -3,3 +3,143 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._loop_transform_ops_gen import *
from .._loop_transform_ops_gen import _Dialect
try:
from ...ir import *
from .._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
_cext as _ods_cext,
)
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, Union
@_ods_cext.register_operation(_Dialect, replace=True)
class GetParentForOp(GetParentForOp):
"""Extension for GetParentForOp."""
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
*,
num_loops: Optional[int] = None,
ip=None,
loc=None,
):
if num_loops is None:
num_loops = 1
super().__init__(
result_type,
_get_op_result_or_value(target),
num_loops=num_loops,
ip=ip,
loc=loc,
)
@_ods_cext.register_operation(_Dialect, replace=True)
class LoopOutlineOp(LoopOutlineOp):
"""Extension for LoopOutlineOp."""
def __init__(
self,
function_type: Type,
call_type: Type,
target: Union[Operation, Value],
*,
func_name: Union[str, StringAttr],
ip=None,
loc=None,
):
super().__init__(
function_type,
call_type,
_get_op_result_or_value(target),
func_name=(
func_name
if isinstance(func_name, StringAttr)
else StringAttr.get(func_name)
),
ip=ip,
loc=loc,
)
@_ods_cext.register_operation(_Dialect, replace=True)
class LoopPeelOp(LoopPeelOp):
"""Extension for LoopPeelOp."""
def __init__(
self,
main_loop_type: Type,
remainder_loop_type: Type,
target: Union[Operation, Value],
*,
fail_if_already_divisible: Union[bool, BoolAttr] = False,
ip=None,
loc=None,
):
super().__init__(
main_loop_type,
remainder_loop_type,
_get_op_result_or_value(target),
fail_if_already_divisible=(
fail_if_already_divisible
if isinstance(fail_if_already_divisible, BoolAttr)
else BoolAttr.get(fail_if_already_divisible)
),
ip=ip,
loc=loc,
)
@_ods_cext.register_operation(_Dialect, replace=True)
class LoopPipelineOp(LoopPipelineOp):
"""Extension for LoopPipelineOp."""
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
*,
iteration_interval: Optional[Union[int, IntegerAttr]] = None,
read_latency: Optional[Union[int, IntegerAttr]] = None,
ip=None,
loc=None,
):
if iteration_interval is None:
iteration_interval = 1
if read_latency is None:
read_latency = 10
super().__init__(
result_type,
_get_op_result_or_value(target),
iteration_interval=iteration_interval,
read_latency=read_latency,
ip=ip,
loc=loc,
)
@_ods_cext.register_operation(_Dialect, replace=True)
class LoopUnrollOp(LoopUnrollOp):
"""Extension for LoopUnrollOp."""
def __init__(
self,
target: Union[Operation, Value],
*,
factor: Union[int, IntegerAttr],
ip=None,
loc=None,
):
super().__init__(
_get_op_result_or_value(target),
factor=factor,
ip=ip,
loc=loc,
)

View File

@@ -3,3 +3,118 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._memref_transform_ops_gen import *
from .._memref_transform_ops_gen import _Dialect
try:
from ...ir import *
from ...dialects import transform
from .._ods_common import _cext as _ods_cext
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, overload, Union
@_ods_cext.register_operation(_Dialect, replace=True)
class MemRefAllocaToGlobalOp(MemRefAllocaToGlobalOp):
"""Specialization for MemRefAllocaToGlobalOp class."""
@overload
def __init__(
self,
get_global_type: Type,
global_type: Type,
alloca: Union[Operation, OpView, Value],
*,
loc=None,
ip=None,
):
...
@overload
def __init__(self, alloca: Union[Operation, OpView, Value], *, loc=None, ip=None):
...
def __init__(
self,
get_global_type_or_alloca: Union[Operation, OpView, Type, Value],
global_type_or_none: Optional[Type] = None,
alloca_or_none: Optional[Union[Operation, OpView, Value]] = None,
*,
loc=None,
ip=None,
):
if isinstance(get_global_type_or_alloca, Type):
get_global_type = get_global_type_or_alloca
global_type = global_type_or_none
alloca = alloca_or_none
else:
get_global_type = transform.AnyOpType.get()
global_type = transform.AnyOpType.get()
alloca = get_global_type_or_alloca
super().__init__(
get_global_type,
global_type,
alloca,
loc=loc,
ip=ip,
)
@_ods_cext.register_operation(_Dialect, replace=True)
class MemRefMultiBufferOp(MemRefMultiBufferOp):
"""Specialization for MemRefMultiBufferOp class."""
@overload
def __init__(
self,
transformed_type: Type,
target: Union[Operation, OpView, Value],
factor: Union[int, IntegerAttr],
*,
skip_analysis: Optional[bool] = None,
loc=None,
ip=None,
):
...
@overload
def __init__(
self,
target: Union[Operation, OpView, Value],
factor: Union[int, IntegerAttr],
*,
skip_analysis: Optional[bool] = None,
loc=None,
ip=None,
):
...
def __init__(
self,
transformed_type_or_target: Type,
target_or_factor: Union[int, IntegerAttr, Operation, OpView, Value] = None,
factor_or_none: Optional[Union[int, IntegerAttr]] = None,
*,
skip_analysis: Optional[bool] = None,
loc=None,
ip=None,
):
if isinstance(transformed_type_or_target, Type):
transformed_type = transformed_type_or_target
target = target_or_factor
factor = factor_or_none
else:
transformed_type = transform.AnyOpType.get()
target = transformed_type_or_target
factor = target_or_factor
super().__init__(
transformed_type,
target,
factor,
skip_analysis=skip_analysis,
loc=loc,
ip=ip,
)

View File

@@ -3,3 +3,53 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._transform_pdl_extension_ops_gen import *
from .._transform_pdl_extension_ops_gen import _Dialect
try:
from ...ir import *
from .._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
_cext as _ods_cext,
)
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Union
@_ods_cext.register_operation(_Dialect, replace=True)
class PDLMatchOp(PDLMatchOp):
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
pattern_name: Union[Attribute, str],
*,
loc=None,
ip=None,
):
super().__init__(
result_type,
_get_op_result_or_value(target),
pattern_name,
loc=loc,
ip=ip,
)
@_ods_cext.register_operation(_Dialect, replace=True)
class WithPDLPatternsOp(WithPDLPatternsOp):
def __init__(self, target: Union[Operation, Value, Type], *, loc=None, ip=None):
root = _get_op_result_or_value(target) if not isinstance(target, Type) else None
root_type = target if isinstance(target, Type) else root.type
super().__init__(root=root, loc=loc, ip=ip)
self.regions[0].blocks.append(root_type)
@property
def body(self) -> Block:
return self.regions[0].blocks[0]
@property
def bodyTarget(self) -> Value:
return self.body.arguments[0]

View File

@@ -3,4 +3,777 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._structured_transform_ops_gen import *
from .._structured_transform_ops_gen import _Dialect
from .._structured_transform_enum_gen import *
try:
from ...ir import *
from ...dialects import transform
from .._ods_common import _cext as _ods_cext
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import List, Optional, Sequence, Tuple, Union, overload
StaticIntLike = Union[int, IntegerAttr]
ValueLike = Union[Operation, OpView, Value]
MixedInt = Union[StaticIntLike, ValueLike]
IntOrAttrList = Sequence[Union[IntegerAttr, int]]
OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
BoolOrAttrList = Sequence[Union[BoolAttr, bool]]
OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]]
MixedValues = Union[Sequence[Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]
DynamicIndexList = Sequence[Union[MixedInt, Sequence[MixedInt]]]
def _dispatch_dynamic_index_list(
indices: Union[DynamicIndexList, ArrayAttr],
) -> Tuple[List[ValueLike], Union[List[int], ArrayAttr], List[bool]]:
"""Dispatches a list of indices to the appropriate form.
This is similar to the custom `DynamicIndexList` directive upstream:
provided indices may be in the form of dynamic SSA values or static values,
and they may be scalable (i.e., as a singleton list) or not. This function
dispatches each index into its respective form. It also extracts the SSA
values and static indices from various similar structures, respectively.
"""
dynamic_indices = []
static_indices = [ShapedType.get_dynamic_size()] * len(indices)
scalable_indices = [False] * len(indices)
# ArrayAttr: Extract index values.
if isinstance(indices, ArrayAttr):
indices = [idx for idx in indices]
def process_nonscalable_index(i, index):
"""Processes any form of non-scalable index.
Returns False if the given index was scalable and thus remains
unprocessed; True otherwise.
"""
if isinstance(index, int):
static_indices[i] = index
elif isinstance(index, IntegerAttr):
static_indices[i] = index.value # pytype: disable=attribute-error
elif isinstance(index, (Operation, Value, OpView)):
dynamic_indices.append(index)
else:
return False
return True
# Process each index at a time.
for i, index in enumerate(indices):
if not process_nonscalable_index(i, index):
# If it wasn't processed, it must be a scalable index, which is
# provided as a Sequence of one value, so extract and process that.
scalable_indices[i] = True
assert len(index) == 1
ret = process_nonscalable_index(i, index[0])
assert ret
return dynamic_indices, static_indices, scalable_indices
# Dispatches `MixedValues` that all represents integers in various forms into
# the following three categories:
# - `dynamic_values`: a list of `Value`s, potentially from op results;
# - `packed_values`: a value handle, potentially from an op result, associated
# to one or more payload operations of integer type;
# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python
# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`.
# The input is in the form for `packed_values`, only that result is set and the
# other two are empty. Otherwise, the input can be a mix of the other two forms,
# and for each dynamic value, a special value is added to the `static_values`.
def _dispatch_mixed_values(
values: MixedValues,
) -> Tuple[List[Value], Union[Operation, Value, OpView], DenseI64ArrayAttr]:
dynamic_values = []
packed_values = None
static_values = None
if isinstance(values, ArrayAttr):
static_values = values
elif isinstance(values, (Operation, Value, OpView)):
packed_values = values
else:
static_values = []
for size in values or []:
if isinstance(size, int):
static_values.append(size)
else:
static_values.append(ShapedType.get_dynamic_size())
dynamic_values.append(size)
static_values = DenseI64ArrayAttr.get(static_values)
return (dynamic_values, packed_values, static_values)
def _get_value_or_attribute_value(
value_or_attr: Union[any, Attribute, ArrayAttr]
) -> any:
if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
return value_or_attr.value
if isinstance(value_or_attr, ArrayAttr):
return _get_value_list(value_or_attr)
return value_or_attr
def _get_value_list(
sequence_or_array_attr: Union[Sequence[any], ArrayAttr]
) -> Sequence[any]:
return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]
def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr:
if values is None:
return None
# Turn into a Python list of Python ints.
values = _get_value_list(values)
# Make an ArrayAttr of IntegerAttrs out of it.
return ArrayAttr.get(
[IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
)
def _get_int_array_array_attr(
values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
) -> ArrayAttr:
"""Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.
The input has to be a collection of collection of integers, where any
Python Sequence and ArrayAttr are admissible collections and Python ints and
any IntegerAttr are admissible integers. Both levels of collections are
turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
If the input is None, an empty ArrayAttr is returned.
"""
if values is None:
return None
# Make sure the outer level is a list.
values = _get_value_list(values)
# The inner level is now either invalid or a mixed sequence of ArrayAttrs and
# Sequences. Make sure the nested values are all lists.
values = [_get_value_list(nested) for nested in values]
# Turn each nested list into an ArrayAttr.
values = [_get_int_array_attr(nested) for nested in values]
# Turn the outer list into an ArrayAttr.
return ArrayAttr.get(values)
@_ods_cext.register_operation(_Dialect, replace=True)
class BufferizeToAllocationOp(BufferizeToAllocationOp):
"""Specialization for BufferizeToAllocationOp class."""
def __init__(
self,
target: Union[Operation, OpView, Value],
*,
memory_space: Optional[Union[int, str, Attribute]] = None,
memcpy_op: Optional[str] = None,
alloc_op: Optional[str] = None,
bufferize_destination_only: Optional[bool] = None,
loc=None,
ip=None,
):
# No other types are allowed, so hard-code those here.
allocated_buffer_type = transform.AnyValueType.get()
new_ops_type = transform.AnyOpType.get()
if isinstance(memory_space, int):
memory_space = str(memory_space)
if isinstance(memory_space, str):
memory_space = Attribute.parse(memory_space)
super().__init__(
allocated_buffer_type,
new_ops_type,
target,
memory_space=memory_space,
memcpy_op=memcpy_op,
alloc_op=alloc_op,
bufferize_destination_only=bufferize_destination_only,
loc=loc,
ip=ip,
)
@_ods_cext.register_operation(_Dialect, replace=True)
class DecomposeOp(DecomposeOp):
"""Specialization for DecomposeOp class."""
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
transformed_type = transform.AnyOpType.get()
super().__init__(transformed_type, target, loc=loc, ip=ip)
@_ods_cext.register_operation(_Dialect, replace=True)
class FuseIntoContainingOp(FuseIntoContainingOp):
"""Specialization for FuseIntoContainingOp class."""
@overload
def __init__(
self,
fused_op_type: Type,
new_containing_op_type: Type,
producer_op: Union[Operation, OpView, Value],
containing_op: Union[Operation, OpView, Value],
*,
loc=None,
ip=None,
):
...
@overload
def __init__(
self,
producer_op: Union[Operation, OpView, Value],
containing_op: Union[Operation, OpView, Value],
*,
loc=None,
ip=None,
):
...
def __init__(
self,
fused_op_type_or_producer_op: Union[Operation, OpView, Type, Value],
new_containing_op_type_or_containing_op: Union[Operation, OpView, Type, Value],
producer_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
containing_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
*,
loc=None,
ip=None,
):
if isinstance(fused_op_type_or_producer_op, Type):
if not isinstance(new_containing_op_type_or_containing_op, Type):
raise TypeError(
"If 'fused_op_type_or_producer_op' is a type, then "
"'new_containing_op_type_or_containing_op' is expected "
"to be one as well."
)
fused_op_type = fused_op_type_or_producer_op
new_containing_op_type = new_containing_op_type_or_containing_op
producer_op = producer_op_or_none
containing_op = containing_op_or_none
else:
fused_op_type = transform.AnyOpType.get()
new_containing_op_type = transform.AnyOpType.get()
producer_op = fused_op_type_or_producer_op
containing_op = new_containing_op_type_or_containing_op
super().__init__(
fused_op_type,
new_containing_op_type,
producer_op,
containing_op,
loc=loc,
ip=ip,
)
@_ods_cext.register_operation(_Dialect, replace=True)
class GeneralizeOp(GeneralizeOp):
"""Specialization for GeneralizeOp class."""
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
transformed_type = transform.AnyOpType.get()
super().__init__(transformed_type, target, loc=loc, ip=ip)
@_ods_cext.register_operation(_Dialect, replace=True)
class InterchangeOp(InterchangeOp):
"""Specialization for InterchangeOp class."""
def __init__(
self,
target: Union[Operation, Value],
*,
iterator_interchange: OptionalIntList = None,
loc=None,
ip=None,
):
transformed_type = transform.AnyOpType.get()
super().__init__(
transformed_type,
target,
iterator_interchange=iterator_interchange,
loc=loc,
ip=ip,
)
@_ods_cext.register_operation(_Dialect, replace=True)
class MapCopyToThreadsOp(MapCopyToThreadsOp):
"""Specialization for MapCopyToThreadsOp class."""
@overload
def __init__(
self,
forall_op_type: Type,
tiled_op_type: Type,
target: Union[Operation, OpView, Value],
*,
total_num_threads: Union[int, IntegerAttr],
desired_bit_alignment: Union[int, IntegerAttr],
loc=None,
ip=None,
):
...
@overload
def __init__(
self,
target: Union[Operation, OpView, Value],
*,
total_num_threads: Union[int, IntegerAttr],
desired_bit_alignment: Union[int, IntegerAttr],
loc=None,
ip=None,
):
...
def __init__(
self,
forall_op_type_or_target: Union[Operation, OpView, Type, Value],
tiled_op_type_or_none: Optional[Type] = None,
target_or_none: Optional[Union[Operation, OpView, Value]] = None,
*,
total_num_threads: Union[int, IntegerAttr],
desired_bit_alignment: Union[int, IntegerAttr],
loc=None,
ip=None,
):
if isinstance(forall_op_type_or_target, Type):
forall_op_type = forall_op_type_or_target
tiled_op_type = tiled_op_type_or_none
target = target_or_none
else:
forall_op_type = transform.AnyOpType.get()
tiled_op_type = transform.AnyOpType.get()
target = forall_op_type_or_target
super().__init__(
forall_op_type,
tiled_op_type,
target,
total_num_threads=total_num_threads,
desired_bit_alignment=desired_bit_alignment,
loc=loc,
ip=ip,
)
@_ods_cext.register_operation(_Dialect, replace=True)
class VectorizeOp(VectorizeOp):
"""Specialization for VectorizeOp class."""
def __init__(
self,
target: Union[Operation, OpView, Value],
vector_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
*,
vectorize_nd_extract: Optional[bool] = None,
scalable_sizes: OptionalBoolList = None,
static_vector_sizes: OptionalIntList = None,
loc=None,
ip=None,
):
if (
scalable_sizes is None
and static_vector_sizes is None
and vector_sizes is None
):
dynamic_vector_sizes = []
elif scalable_sizes is None and static_vector_sizes is None:
(
dynamic_vector_sizes,
static_vector_sizes,
scalable_sizes,
) = _dispatch_dynamic_index_list(vector_sizes)
elif scalable_sizes is None or static_vector_sizes is None:
raise TypeError(
"'scalable_sizes' and 'static_vector_sizes' must either both "
"be given explicitly or both be given as part of 'vector_sizes'."
)
else:
dynamic_vector_sizes = vector_sizes
super().__init__(
target,
vector_sizes=dynamic_vector_sizes,
static_vector_sizes=static_vector_sizes,
scalable_sizes=scalable_sizes,
vectorize_nd_extract=vectorize_nd_extract,
loc=loc,
ip=ip,
)
@_ods_cext.register_operation(_Dialect, replace=True)
class MatchOp(MatchOp):
"""Specialization for MatchOp class."""
@overload
@classmethod
def match_op_names(
cls,
target: Union[Operation, Value],
names: Union[str, Sequence[str]],
*,
loc=None,
ip=None,
):
...
@overload
@classmethod
def match_op_names(
cls,
result_type: Type,
target: Union[Operation, Value],
names: Union[str, Sequence[str]],
*,
loc=None,
ip=None,
):
...
@classmethod
def match_op_names(
cls,
result_type_or_target: Union[Type, Operation, Value],
target_or_names: Union[Operation, Value, Sequence[str], str],
names_or_none: Optional[Union[Sequence[str], str]] = None,
*,
loc=None,
ip=None,
):
if isinstance(result_type_or_target, Type):
result_type = result_type_or_target
target = target_or_names
names = names_or_none
else:
result_type = transform.AnyOpType.get()
target = result_type_or_target
names = target_or_names
if isinstance(names, str):
names = [names]
return cls(
result_type,
target,
ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
loc=loc,
ip=ip,
)
@_ods_cext.register_operation(_Dialect, replace=True)
class MultiTileSizesOp(MultiTileSizesOp):
"""Specialization for MultiTileSizesOp class."""
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
*,
dimension: Union[int, IntegerAttr],
target_size: Union[int, IntegerAttr],
divisor: Optional[Optional[Union[int, IntegerAttr]]] = None,
loc=None,
ip=None,
):
super().__init__(
result_type,
result_type,
result_type,
target,
dimension=dimension,
target_size=target_size,
divisor=divisor,
loc=loc,
ip=ip,
)
@_ods_cext.register_operation(_Dialect, replace=True)
class PadOp(PadOp):
"""Specialization for PadOp class."""
def __init__(
self,
target: Union[Operation, OpView, Value],
*,
padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
padding_dimensions: OptionalIntList = None,
pad_to_multiple_of: OptionalIntList = None,
pack_paddings: OptionalIntList = None,
transpose_paddings: Optional[
Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]
] = None,
copy_back_op: Optional[Union[str, StringAttr]] = None,
loc=None,
ip=None,
):
transpose_paddings = _get_int_array_array_attr(transpose_paddings)
any_op_type = transform.AnyOpType.get()
super().__init__(
any_op_type,
any_op_type,
any_op_type,
target,
padding_values=padding_values,
padding_dimensions=padding_dimensions,
pad_to_multiple_of=pad_to_multiple_of,
pack_paddings=pack_paddings,
transpose_paddings=transpose_paddings,
copy_back_op=copy_back_op,
loc=loc,
ip=ip,
)
@_ods_cext.register_operation(_Dialect, replace=True)
class ScalarizeOp(ScalarizeOp):
"""Specialization for ScalarizeOp class."""
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
result_type = transform.AnyOpType.get()
super().__init__(result_type, target, loc=loc, ip=ip)
@_ods_cext.register_operation(_Dialect, replace=True)
class SplitOp(SplitOp):
"""Specialization for SplitOp class."""
def __init__(
self,
target: Union[Operation, Value],
dimension: Union[int, Attribute],
split_point: Union[int, Operation, Value, Attribute],
*,
loc=None,
ip=None,
):
if isinstance(split_point, int):
static_split_point = split_point
dynamic_split_point = None
else:
static_split_point = ShapedType.get_dynamic_size()
dynamic_split_point = split_point
super().__init__(
target.type,
target.type,
target,
dimension=dimension,
static_split_point=static_split_point,
dynamic_split_point=dynamic_split_point,
loc=loc,
ip=ip,
)
@_ods_cext.register_operation(_Dialect, replace=True)
class TileUsingForOp(TileUsingForOp):
"""Specialization for TileUsingForOp class."""
@overload
def __init__(
self,
loop_types: Union[Type, List[Type]],
target: Union[Operation, Value],
*,
sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
interchange: OptionalIntList = None,
loc=None,
ip=None,
):
...
@overload
def __init__(
self,
target: Union[Operation, Value, OpView],
*,
sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
interchange: OptionalIntList = None,
loc=None,
ip=None,
):
...
def __init__(
self,
loop_types_or_target: Union[Type, List[Type], Operation, Value],
target_or_none: Optional[Union[Operation, Value, OpView]] = None,
*,
sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
interchange: OptionalIntList = None,
loc=None,
ip=None,
):
(
dynamic_sizes,
static_sizes,
scalable_sizes,
) = _dispatch_dynamic_index_list(sizes)
num_loops = sum(v if v == 0 else 1 for v in static_sizes)
if isinstance(loop_types_or_target, (Operation, Value, OpView)):
loop_types = [transform.AnyOpType.get()] * num_loops
target = loop_types_or_target
assert (
target_or_none is None
), "Cannot construct TileUsingForOp with two targets."
else:
loop_types = (
([loop_types_or_target] * num_loops)
if isinstance(loop_types_or_target, Type)
else loop_types_or_target
)
target = target_or_none
super().__init__(
target.type,
loop_types,
target,
dynamic_sizes=dynamic_sizes,
static_sizes=static_sizes,
interchange=interchange,
scalable_sizes=scalable_sizes,
loc=loc,
ip=ip,
)
@_ods_cext.register_operation(_Dialect, replace=True)
class TileUsingForallOp(TileUsingForallOp):
"""Specialization for TileUsingForallOp class."""
@overload
def __init__(
self,
loops_type: Type,
tiled_op_type: Type,
target: Union[Operation, Value, OpView],
*,
num_threads: Optional[MixedValues] = None,
tile_sizes: MixedValues = None,
mapping=None,
loc=None,
ip=None,
):
...
@overload
def __init__(
self,
target: Union[Operation, Value, OpView],
*,
num_threads: Optional[MixedValues] = None,
tile_sizes: MixedValues = None,
mapping=None,
loc=None,
ip=None,
):
...
def __init__(
self,
loops_type_or_target: Union[
Type, Union[Operation, Value, OpView] # loops_type
], # target
tiled_op_type_or_none: Optional[Type] = None,
target_or_none: Optional[Union[Operation, Value, OpView]] = None,
*,
num_threads: MixedValues = None,
tile_sizes: MixedValues = None,
mapping=None,
loc=None,
ip=None,
):
# `Type` arguments in the front are optional: add default values to front.
if isinstance(loops_type_or_target, Type):
# First overload: type arguments provided.
if not isinstance(tiled_op_type_or_none, Type):
raise TypeError(
"If 'loops_type_or_target' is a type, then "
"'tiled_op_type_or_none' is expected to be one as well."
)
loops_type = loops_type_or_target
tiled_op_type = tiled_op_type_or_none
target = target_or_none
else:
# Last overload: type arguments missing.
loops_type = transform.AnyOpType.get()
tiled_op_type = transform.AnyOpType.get()
target = loops_type_or_target
# Unpack mixed num_threads.
(
dynamic_num_threads,
packed_num_threads,
num_threads_attr,
) = _dispatch_mixed_values(num_threads)
# Unpack mixed tile_sizes.
(
dynamic_tile_sizes,
packed_tile_sizes,
tile_sizes_attr,
) = _dispatch_mixed_values(tile_sizes)
super().__init__(
loops_type,
tiled_op_type,
target=target,
tile_sizes=dynamic_tile_sizes,
packed_tile_sizes=packed_tile_sizes,
static_tile_sizes=tile_sizes_attr,
num_threads=dynamic_num_threads,
packed_num_threads=packed_num_threads,
static_num_threads=num_threads_attr,
mapping=mapping,
loc=loc,
ip=ip,
)
@_ods_cext.register_operation(_Dialect, replace=True)
class VectorizeChildrenAndApplyPatternsOp(VectorizeChildrenAndApplyPatternsOp):
"""Specialization for VectorizeChildrenAndApplyPatternsOp class."""
def __init__(
self,
target: Union[Operation, Value],
*,
disable_multi_reduction_to_contract_patterns: bool = False,
disable_transfer_permutation_map_lowering_patterns: bool = False,
vectorize_nd_extract: bool = False,
vectorize_padding: bool = False,
loc=None,
ip=None,
):
transformed_type = transform.AnyOpType.get()
super().__init__(
transformed_type,
target,
disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns,
disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns,
vectorize_nd_extract=vectorize_nd_extract,
vectorize_padding=vectorize_padding,
loc=loc,
ip=ip,
)

View File

@@ -3,3 +3,67 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._tensor_transform_ops_gen import *
from .._tensor_transform_ops_gen import _Dialect
try:
from ...ir import *
from ...dialects import transform
from .._ods_common import _cext as _ods_cext
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, overload, Union
@_ods_cext.register_operation(_Dialect, replace=True)
class MakeLoopIndependentOp(MakeLoopIndependentOp):
"""Specialization for MakeLoopIndependentOp class."""
@overload
def __init__(
self,
transformed_type: Type,
target: Union[Operation, OpView, Value],
num_loops: Union[int, IntegerAttr],
*,
loc=None,
ip=None,
):
...
@overload
def __init__(
self,
target: Union[Operation, OpView, Value],
num_loops: Union[int, IntegerAttr],
*,
loc=None,
ip=None,
):
...
def __init__(
self,
transformed_type_or_target: Type,
target_or_num_loops: Union[int, IntegerAttr, Operation, OpView, Value] = None,
num_loops_or_none: Optional[Union[int, IntegerAttr]] = None,
*,
loc=None,
ip=None,
):
if isinstance(transformed_type_or_target, Type):
transformed_type = transformed_type_or_target
target = target_or_num_loops
num_loops = num_loops_or_none
else:
transformed_type = transform.AnyOpType.get()
target = transformed_type_or_target
num_loops = target_or_num_loops
super().__init__(
transformed_type,
target,
num_loops,
loc=loc,
ip=ip,
)

View File

@@ -114,6 +114,7 @@ def get_unranked_memref_descriptor(nparray):
d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
return d
def move_aligned_ptr_by_offset(aligned_ptr, offset):
"""Moves the supplied ctypes pointer ahead by `offset` elements."""
aligned_addr = ctypes.addressof(aligned_ptr.contents)
@@ -122,6 +123,7 @@ def move_aligned_ptr_by_offset(aligned_ptr, offset):
content_ptr = ctypes.cast(aligned_addr + shift, type(aligned_ptr))
return content_ptr
def unranked_memref_to_numpy(unranked_memref, np_dtype):
"""Converts unranked memrefs to numpy arrays."""
ctp = as_ctype(np_dtype)
@@ -139,10 +141,10 @@ def unranked_memref_to_numpy(unranked_memref, np_dtype):
def ranked_memref_to_numpy(ranked_memref):
"""Converts ranked memrefs to numpy arrays."""
content_ptr = move_aligned_ptr_by_offset(ranked_memref[0].aligned, ranked_memref[0].offset)
np_arr = np.ctypeslib.as_array(
content_ptr, shape=ranked_memref[0].shape
content_ptr = move_aligned_ptr_by_offset(
ranked_memref[0].aligned, ranked_memref[0].offset
)
np_arr = np.ctypeslib.as_array(content_ptr, shape=ranked_memref[0].shape)
strided_arr = np.lib.stride_tricks.as_strided(
np_arr,
np.ctypeslib.as_array(ranked_memref[0].shape),

View File

@@ -33,3 +33,16 @@ def testFastMathFlags():
)
# CHECK: %0 = arith.addf %cst, %cst fastmath<nnan,ninf> : f32
print(r)
# CHECK-LABEL: TEST: testArithValueBuilder
@run
def testArithValueBuilder():
with Context() as ctx, Location.unknown():
module = Module.create()
f32_t = F32Type.get()
with InsertionPoint(module.body):
a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
# CHECK: %cst = arith.constant 4.242000e+01 : f32
print(a)

View File

@@ -30,14 +30,9 @@ constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.
from ._ods_common import _cext as _ods_cext
from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results
from ._ods_common import segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results
_ods_ir = _ods_cext.ir
try:
from . import _{0}_ops_ext as _ods_ext_module
except ImportError:
_ods_ext_module = None
import builtins
from typing import Sequence as _Sequence, Union as _Union
@@ -62,7 +57,6 @@ from ._{0}_ops_gen import _Dialect
/// {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
@_ods_extend_opview_class(_ods_ext_module)
class {0}(_ods_ir.OpView):
OPERATION_NAME = "{1}"
)Py";
@@ -301,17 +295,17 @@ static bool isODSReserved(StringRef str) {
/// (does not change the `name` if it already is suitable) and returns the
/// modified version.
static std::string sanitizeName(StringRef name) {
std::string processed_str = name.str();
std::string processedStr = name.str();
std::replace_if(
processed_str.begin(), processed_str.end(),
processedStr.begin(), processedStr.end(),
[](char c) { return !llvm::isAlnum(c); }, '_');
if (llvm::isDigit(*processed_str.begin()))
return "_" + processed_str;
if (llvm::isDigit(*processedStr.begin()))
return "_" + processedStr;
if (isPythonReserved(processed_str) || isODSReserved(processed_str))
return processed_str + "_";
return processed_str;
if (isPythonReserved(processedStr) || isODSReserved(processedStr))
return processedStr + "_";
return processedStr;
}
static std::string attrSizedTraitForKind(const char *kind) {
@@ -853,10 +847,6 @@ populateBuilderRegions(const Operator &op,
/// rebuild anew).
static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
raw_ostream &os) {
// If we are asked to skip default builders, comply.
if (op.skipDefaultBuilders())
return {};
llvm::SmallVector<std::string> builderArgs;
llvm::SmallVector<std::string> builderLines;
llvm::SmallVector<std::string> operandArgNames;
@@ -989,9 +979,6 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
static void emitValueBuilder(const Operator &op,
llvm::SmallVector<std::string> functionArgs,
raw_ostream &os) {
// If we are asked to skip default builders, comply.
if (op.skipDefaultBuilders())
return;
// Params with (possibly) default args.
auto valueBuilderParams =
llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) {
@@ -1010,9 +997,9 @@ static void emitValueBuilder(const Operator &op,
auto lhs = *llvm::split(arg, "=").begin();
return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
});
std::string name_without_dialect =
std::string nameWithoutDialect =
op.getOperationName().substr(op.getOperationName().find('.') + 1);
os << llvm::formatv(valueBuilderTemplate, sanitizeName(name_without_dialect),
os << llvm::formatv(valueBuilderTemplate, sanitizeName(nameWithoutDialect),
op.getCppClassName(),
llvm::join(valueBuilderParams, ", "),
llvm::join(opBuilderArgs, ", "),
@@ -1051,11 +1038,8 @@ static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
if (clDialectName.empty())
llvm::PrintFatalError("dialect name not provided");
bool isExtension = !clDialectExtensionName.empty();
os << llvm::formatv(fileHeader, isExtension
? clDialectExtensionName.getValue()
: clDialectName.getValue());
if (isExtension)
os << fileHeader;
if (!clDialectExtensionName.empty())
os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue());
else
os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());