Files
clang-p2996/mlir/lib/CAPI/Dialect/Transform.cpp
martin-luecke 681eacc1b6 [MLIR][transform][python] add sugared python abstractions for transform dialect (#75073)
This adds Python abstractions for the different handle types of the
transform dialect.

The abstractions allow for straightforward chaining of transforms by
calling their member functions.
As an initial PR for this infrastructure, only a single transform is
included: `transform.structured.match`.
With a future `tile` transform abstraction, an example of the usage is:
```Python
def script(module: OpHandle):
    module.match_ops(MatchInterfaceEnum.TilingInterface).tile(tile_sizes=[32,32])
```
to generate the following IR:
```mlir
%0 = transform.structured.match interface{TilingInterface} in %arg0
%tiled_op, %loops = transform.structured.tile_using_for %0 [32, 32]
```

These abstractions are intended to enhance the usability and flexibility
of the transform dialect by providing an accessible interface that
allows for easy assembly of complex transformation chains.
2023-12-15 13:04:43 +01:00

109 lines
3.6 KiB
C++

//===- Transform.cpp - C Interface for Transform dialect ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir-c/Dialect/Transform.h"
#include "mlir-c/Support.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
using namespace mlir;
// Invokes the C API dialect-registration macro for the Transform dialect,
// passing the C-side name `Transform`, the namespace `transform`, and the C++
// dialect class transform::TransformDialect.
// NOTE(review): the macro's expansion is not visible in this file (presumably
// it comes from mlir/CAPI/Registration.h) -- confirm which symbols it defines.
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Transform, transform,
transform::TransformDialect)
//===---------------------------------------------------------------------===//
// AnyOpType
//===---------------------------------------------------------------------===//
// Reports whether `type` is a transform::AnyOpType.
bool mlirTypeIsATransformAnyOpType(MlirType type) {
  auto cppType = unwrap(type);
  return isa<transform::AnyOpType>(cppType);
}
// Returns the TypeID of transform::AnyOpType, wrapped as its C API handle.
MlirTypeID mlirTransformAnyOpTypeGetTypeID(void) {
  auto typeID = transform::AnyOpType::getTypeID();
  return wrap(typeID);
}
// Obtains a transform::AnyOpType from the context behind `ctx` and returns it
// as a C API type handle.
MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
  auto *context = unwrap(ctx);
  auto anyOpType = transform::AnyOpType::get(context);
  return wrap(anyOpType);
}
//===---------------------------------------------------------------------===//
// AnyParamType
//===---------------------------------------------------------------------===//
// Reports whether `type` is a transform::AnyParamType.
bool mlirTypeIsATransformAnyParamType(MlirType type) {
  auto cppType = unwrap(type);
  return isa<transform::AnyParamType>(cppType);
}
// Returns the TypeID of transform::AnyParamType, wrapped as its C API handle.
MlirTypeID mlirTransformAnyParamTypeGetTypeID(void) {
  auto typeID = transform::AnyParamType::getTypeID();
  return wrap(typeID);
}
// Obtains a transform::AnyParamType from the context behind `ctx` and returns
// it as a C API type handle.
MlirType mlirTransformAnyParamTypeGet(MlirContext ctx) {
  auto *context = unwrap(ctx);
  auto anyParamType = transform::AnyParamType::get(context);
  return wrap(anyParamType);
}
//===---------------------------------------------------------------------===//
// AnyValueType
//===---------------------------------------------------------------------===//
// Reports whether `type` is a transform::AnyValueType.
bool mlirTypeIsATransformAnyValueType(MlirType type) {
  auto cppType = unwrap(type);
  return isa<transform::AnyValueType>(cppType);
}
// Returns the TypeID of transform::AnyValueType, wrapped as its C API handle.
MlirTypeID mlirTransformAnyValueTypeGetTypeID(void) {
  auto typeID = transform::AnyValueType::getTypeID();
  return wrap(typeID);
}
// Obtains a transform::AnyValueType from the context behind `ctx` and returns
// it as a C API type handle.
MlirType mlirTransformAnyValueTypeGet(MlirContext ctx) {
  auto *context = unwrap(ctx);
  auto anyValueType = transform::AnyValueType::get(context);
  return wrap(anyValueType);
}
//===---------------------------------------------------------------------===//
// OperationType
//===---------------------------------------------------------------------===//
// Reports whether `type` is a transform::OperationType.
bool mlirTypeIsATransformOperationType(MlirType type) {
  auto cppType = unwrap(type);
  return isa<transform::OperationType>(cppType);
}
// Returns the TypeID of transform::OperationType, wrapped as its C API handle.
MlirTypeID mlirTransformOperationTypeGetTypeID(void) {
  auto typeID = transform::OperationType::getTypeID();
  return wrap(typeID);
}
// Obtains a transform::OperationType for the operation name `operationName`
// from the context behind `ctx` and returns it as a C API type handle.
MlirType mlirTransformOperationTypeGet(MlirContext ctx,
                                       MlirStringRef operationName) {
  auto *context = unwrap(ctx);
  auto name = unwrap(operationName);
  auto operationType = transform::OperationType::get(context, name);
  return wrap(operationType);
}
// Returns the operation name carried by `type`.
// `type` is expected to be a transform::OperationType: the body converts it
// with `cast`, not with a checked `dyn_cast`.
MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) {
  auto operationType = cast<transform::OperationType>(unwrap(type));
  return wrap(operationType.getOperationName());
}
//===---------------------------------------------------------------------===//
// ParamType
//===---------------------------------------------------------------------===//
// Reports whether `type` is a transform::ParamType.
bool mlirTypeIsATransformParamType(MlirType type) {
  auto cppType = unwrap(type);
  return isa<transform::ParamType>(cppType);
}
// Returns the TypeID of transform::ParamType, wrapped as its C API handle.
MlirTypeID mlirTransformParamTypeGetTypeID(void) {
  auto typeID = transform::ParamType::getTypeID();
  return wrap(typeID);
}
// Obtains a transform::ParamType built from `type` in the context behind
// `ctx` and returns it as a C API type handle.
MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type) {
  auto *context = unwrap(ctx);
  auto innerType = unwrap(type);
  auto paramType = transform::ParamType::get(context, innerType);
  return wrap(paramType);
}
// Returns the type reported by `getType()` on the ParamType behind `type`.
// `type` is expected to be a transform::ParamType: the body converts it with
// `cast`, not with a checked `dyn_cast`.
MlirType mlirTransformParamTypeGetType(MlirType type) {
  auto paramType = cast<transform::ParamType>(unwrap(type));
  return wrap(paramType.getType());
}