…_reduce_matmul.
This patch exposes broadcast and transpose semantics on
'batch_reduce_matmul'. This is the last one in continuation of other two
variant of matmul ops.
The broadcast and transpose semantic are as follows:
Broadcast and Transpose semantics can be appiled by specifying the
explicit attribute 'indexing_maps' as shown below. This is a list
attribute, so must include maps for all arguments if specified.
Example Transpose:
```
linalg.batch_reduce_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
]
ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>)
outs(%arg2: memref<3x7xf32>)
```
Example Broadcast:
```
linalg.batch_reduce_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
]
ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
outs(%arg2: memref<3x7xf32>)
```
Example Broadcast and Transpose:
```
linalg.batch_reduce_matmul indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose
affine_map<(d0, d1, d2, d3) -> (d1, d2)>
]
ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>)
outs(%arg2: memref<3x7xf32>)
```
RFCs and related PR:
https://discourse.llvm.org/t/rfc-linalg-opdsl-constant-list-attribute-definition/80149https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586https://github.com/llvm/llvm-project/pull/115319https://github.com/llvm/llvm-project/pull/122275
let constructor is legacy (do not use in tree!) since the tableGen
backend emits most of the glue logic to build a pass.
Note: The following constructor has been retired:
```cpp
std::unique_ptr<Pass> createExpandReallocPass(bool emitDeallocs = true);
```
To update your codebase, replace it with the new options-based API:
```cpp
memref::ExpandReallocPassOptions expandAllocPassOptions{
/*emitDeallocs=*/false};
pm.addPass(memref::createExpandReallocPass(expandAllocPassOptions));
```
Ops that are already snake case (like [`ROCDL_wmma_*`
ops](66b0b0466b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (L411)))
produce python "value-builders" that collide with the class names:
```python
class wmma_bf16_16x16x16_bf16(_ods_ir.OpView):
OPERATION_NAME = "rocdl.wmma.bf16.16x16x16.bf16"
...
def wmma_bf16_16x16x16_bf16(res, args, *, loc=None, ip=None) -> _ods_ir.Value:
return wmma_bf16_16x16x16_bf16(res=res, args=args, loc=loc, ip=ip).result
```
and thus cannot be emitted (because of recursive self-calls).
This PR fixes that by affixing `_` to the value builder names.
I would've preferred to just rename the ops but that would be a breaking
change 🤷.
This PR is mainly about exposing the python bindings for
`linalg::isaConvolutionOpInterface` and `linalg::inferConvolutionDims`.
---------
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
This PR is mainly about exposing the python bindings for`
linalg::isaContractionOpInterface` and` linalg::inferContractionDims`.
---------
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
This commit extends the MLIR vector type to support pointer-like types
such as `!llvm.ptr` and `!ptr.ptr`, as indicated by the newly added
`VectorTypeElementInterface`. This makes the LLVM dialect closer to LLVM
IR. LLVM IR already supports pointers as vector element type.
Only integers, floats, pointers and index are valid vector element types
for now. Additional vector element types may be added in the future
after further discussions. The interface is still evolving and may
eventually turn into one of the alternatives that were discussed on the
RFC.
This commit also disallows `!llvm.ptr` as an element type of
`!llvm.vec`. This type exists due to limitations of the MLIR vector
type.
RFC:
https://discourse.llvm.org/t/rfc-allow-pointers-as-element-type-of-vector/85360
This is an implementation for [RFC: Supporting Sub-Channel Quantization
in
MLIR](https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694).
In order to make the review process easier, the PR has been divided into
the following commit labels:
1. **Add implementation for sub-channel type:** Includes the class
design for `UniformQuantizedSubChannelType`, printer/parser and bytecode
read/write support. The existing types (per-tensor and per-axis) are
unaltered.
2. **Add implementation for sub-channel type:** Lowering of
`quant.qcast` and `quant.dcast` operations to Linalg operations.
3. **Adding C/Python Apis:** We first define he C-APIs and build the
Python-APIs on top of those.
4. **Add pass to normalize generic ....:** This pass normalizes
sub-channel quantized types to per-tensor per-axis types, if possible.
A design note:
- **Explicitly storing the `quantized_dimensions`, even when they can be
derived for ranked tensor.**
While it's possible to infer quantized dimensions from the static shape
of the scales (or zero-points) tensor for ranked
data tensors
([ref](https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694/3)
for background), there are cases where this can lead to ambiguity and
issues with round-tripping.
```
Consider the example: tensor<2x4x!quant.uniform<i8:f32:{0:2, 0:2}, {{s00:z00, s01:z01}}>>
```
The shape of the scales tensor is [1, 2], which might suggest that only
axis 1 is quantized. While this inference is technically correct, as the
block size for axis 0 is a degenerate case (equal to the dimension
size), it can cause problems with round-tripping. Therefore, even for
ranked tensors, we are explicitly storing the quantized dimensions.
Suggestions welcome!
PS: I understand that the upcoming holidays may impact your schedule, so
please take your time with the review. There's no rush.
* `PyRegionList` is now sliceable. The dialect bindings generator seems
to assume it is sliceable already (!), yet accessing e.g. `cases` on
`scf.IndexedSwitchOp` raises a `TypeError` at runtime.
* `PyBlockList` and `PyOperationList` support negative indexing. It is
common for containers to do that in Python, and most container in the
MLIR Python bindings already allow the index to be negative.
Updated the Python diagnostics handler to emit notes (in addition to
errors) into the output stream so that users have more context as to
where in the IR the error is occurring.
In some projects like JAX ir.Context are used with disabled multi-threading to avoid
caching multiple threading pools:
623865fe95/jax/_src/interpreters/mlir.py (L606-L611)
However, when context has enabled multithreading it also uses locks on
the StorageUniquers and this can be helpful to avoid data races in the
multi-threaded execution (for example with free-threaded cpython,
https://github.com/jax-ml/jax/issues/26272).
With this PR user can enable the multi-threading: 1) enables additional
locking and 2) set a shared threading pool such that cached contexts can
have one global pool.
This PR extends the python bindings for CallSiteLoc, FileLineColRange,
FusedLoc, NameLoc with field accessors. It also adds the missing
`value.location` accessor.
I also did some "spring cleaning" here (`cast` -> `dyn_cast`) after
running into some of my own illegal casts.
The current `write_bytecode` implementation necessarily requires the
serialized module to be duplicated in memory when the python `bytes`
object is created and sent over the binding. For modules with large
resources, we may want to avoid this in-memory copy by serializing
directly to a file instead of sending bytes across the boundary.
This PR https://github.com/llvm/llvm-project/pull/123902 broke python
bindings for `tensor.pack`/`unpack`. This PR fixes that. It also
1. adds convenience wrappers for pack/unpack
2. cleans up matmul-like ops in the linalg bindings
3. fixes linalg docs missing pack/unpack
As linalg.batch_matmul has been moved into tablegen from OpDSL, its
derived python wrapper no longer exist.This patch adds the required
python wrapper.
Also refactors the BatchmatmulOp printer to make it consistent with its
parser.
For extremely large models, it may be inefficient to load the model into
memory in Python prior to passing it to the MLIR C APIs for
deserialization. This change adds an API to parse a ModuleOp directly
from a file path.
Re-lands
[4e14b8a](4e14b8afb4).
Implement the feature about perf by stage(llvm-ir -> isa, isa->binary).
The results will be stored into the properties, then users can use them
after using GpuModuleToBinary Pass.
Now that linalg.matmul is in tablegen, "hand write" the Python wrapper
that OpDSL used to derive. Similarly, add a Python wrapper for the new
linalg.contract op.
Required following misc. fixes:
1) make linalg.matmul's parsing and printing consistent w.r.t. whether
indexing_maps occurs before or after operands, i.e. per the tests cases
it comes _before_.
2) tablegen for linalg.contract did not state it accepted an optional
cast attr.
3) In ODS's C++-generating code, expand partial support for `$_builder`
access in `Attr::defaultValue` to full support. This enables access to
the current `MlirContext` when constructing the default value (as is
required when the default value consists of affine maps).
For extremely large models, it may be inefficient to load the model into
memory in Python prior to passing it to the MLIR C APIs for
deserialization. This change adds an API to parse a ModuleOp directly
from a file path.
If the large element limit is specified, large elements are hidden from
the asm but large resources are not. This change extends the large
elements limit to apply to printed resources as well.
Model the `IndexType` as `uint64_t` when converting to a python integer.
With the python bindings,
```python
DenseIntElementsAttr(op.attributes["attr"])
```
used to `assert` when `attr` had `index` type like `dense<[1, 2, 3, 4]>
: vector<4xindex>`.
---------
Co-authored-by: Christopher McGirr <christopher.mcgirr@amd.com>
Co-authored-by: Tiago Trevisan Jost <tiago.trevisanjost@amd.com>
Use `mlir_target_link_libraries()` to link dependencies of libraries
that are not included in libMLIR, to ensure that they link to the dylib
when they are used in Flang. Otherwise, they implicitly pull in all
their static dependencies, effectively causing Flang binaries to
simultaneously link to the dylib and to static libraries, which is never
a good idea.
I have only covered the libraries that are used by Flang. If you wish, I
can extend this approach to all non-libMLIR libraries in MLIR, making
MLIR itself also link to the dylib consistently.
[v3 with more `-DBUILD_SHARED_LIBS=ON` fixes]
Use `mlir_target_link_libraries()` to link dependencies of libraries
that are not included in libMLIR, to ensure that they link to the dylib
when they are used in Flang. Otherwise, they implicitly pull in all
their static dependencies, effectively causing Flang binaries to
simultaneously link to the dylib and to static libraries, which is never
a good idea.
I have only covered the libraries that are used by Flang. If you wish, I
can extend this approach to all non-libMLIR libraries in MLIR, making
MLIR itself also link to the dylib consistently.
[v2 with fixed `-DBUILD_SHARED_LIBS=ON` build]
Use `mlir_target_link_libraries()` to link dependencies of libraries
that are not included in libMLIR, to ensure that they link to the dylib
when they are used in Flang. Otherwise, they implicitly pull in all
their static dependencies, effectively causing Flang binaries to
simultaneously link to the dylib and to static libraries, which is never
a good idea.
I have only covered the libraries that are used by Flang. If you wish, I
can extend this approach to all non-libMLIR libraries in MLIR, making
MLIR itself also link to the dylib consistently.
In order to optionally run some checks that depend on the `ml_dtypes`
python module we have to remove the `CHECK` lines for those tests or
they will be required and missed in the test output.
I've changed to use asserts as recommended in [1].
[1]:
https://github.com/llvm/llvm-project/pull/123061#issuecomment-2596116023
We noticed that `mlir/python/requirements.txt` lists `ml_dtypes` as a requirement but when looking at the code in `mlir/python`, the only `import` is guarded:
```python
try:
import ml_dtypes
except ModuleNotFoundError:
# The third-party ml_dtypes provides some optional low precision data-types for NumPy.
ml_dtypes = None
```
This makes `ml_dtypes` an optional dependency.
Some python tests however partially depend on `ml_dtypes` and should not run if that module is unavailable. That is what this change does.
This is a replacement for #123051 which was excluding tests too broadly.
Gives option post as global list as well as arg to control which
dialects are loaded during context creation. This enables setting either
a good base set or skipping in individual cases.
This is a companion to #118583, although it can be landed independently
because since #117922 dialects do not have to use the same Python
binding framework as the Python core code.
This PR ports all of the in-tree dialect and pass extensions to
nanobind, with the exception of those that remain for testing pybind11
support.
This PR also:
* removes CollectDiagnosticsToStringScope from NanobindAdaptors.h. This
was overlooked in a previous PR and it is duplicated in Diagnostics.h.
---------
Co-authored-by: Jacques Pienaar <jpienaar@google.com>
Do not run `cf-to-llvm` as part of `func-to-llvm`. This commit fixes
https://github.com/llvm/llvm-project/issues/70982.
This commit changes the way how `func.func` ops are lowered to LLVM.
Previously, the signature of the entire region (i.e., entry block and
all other blocks in the `func.func` op) was converted as part of the
`func.func` lowering pattern.
Now, only the entry block is converted. The remaining block signatures
are converted together with `cf.br` and `cf.cond_br` as part of
`cf-to-llvm`. All unstructured control flow is not converted as part of
a single pass (`cf-to-llvm`). `func-to-llvm` no longer deals with
unstructured control flow.
Also add more test cases for control flow dialect ops.
Note: This PR is in preparation of #120431, which adds an additional
GPU-specific lowering for `cf.assert`. This was a problem because
`cf.assert` used to be converted as part of `func-to-llvm`.
Note for LLVM integration: If you see failures, add
`-convert-cf-to-llvm` to your pass pipeline.
Do not run `arith-to-llvm` as part of `func-to-llvm`. This commit partly
fixes#70982.
Also simplify the pass pipeline for two math dialect integration tests.
Note for LLVM integration: If you see failures, add `arith-to-llvm` to your pass pipeline.
Relands #118583, with a fix for Python 3.8 compatibility. It was not
possible to set the buffer protocol accessers via slots in Python 3.8.
Why? https://nanobind.readthedocs.io/en/latest/why.html says it better
than I can, but my primary motivation for this change is to improve MLIR
IR construction time from JAX.
For a complicated Google-internal LLM model in JAX, this change improves
the MLIR
lowering time by around 5s (out of around 30s), which is a significant
speedup for simply switching binding frameworks.
To a large extent, this is a mechanical change, for instance changing
`pybind11::` to `nanobind::`.
Notes:
* this PR needs Nanobind 2.4.0, because it needs a bug fix
(https://github.com/wjakob/nanobind/pull/806) that landed in that
release.
* this PR does not port the in-tree dialect extension modules. They can
be ported in a future PR.
* I removed the py::sibling() annotations from def_static and def_class
in `PybindAdapters.h`. These ask pybind11 to try to form an overload
with an existing method, but it's not possible to form mixed
pybind11/nanobind overloads this ways and the parent class is now
defined in nanobind. Better solutions may be possible here.
* nanobind does not contain an exact equivalent of pybind11's buffer
protocol support. It was not hard to add a nanobind implementation of a
similar API.
* nanobind is pickier about casting to std::vector<bool>, expecting that
the input is a sequence of bool types, not truthy values. In a couple of
places I added code to support truthy values during casting.
* nanobind distinguishes bytes (`nb::bytes`) from strings (e.g.,
`std::string`). This required nb::bytes overloads in a few places.