Introduce the initial support for operation interfaces in C API and Python bindings. Interfaces are a key component of MLIR's extensibility and should be available in bindings to make use of full potential of MLIR. This initial implementation exposes InferTypeOpInterface all the way to the Python bindings since it can be later used to simplify the operation construction methods by inferring their return types instead of requiring the user to do so. The general infrastructure for binding interfaces is defined and InferTypeOpInterface can be used as an example for binding other interfaces. Reviewed By: gysit Differential Revision: https://reviews.llvm.org/D111656
83 lines
3.4 KiB
C++
83 lines
3.4 KiB
C++
//===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir-c/Interfaces.h"
|
|
|
|
#include "mlir/CAPI/IR.h"
|
|
#include "mlir/CAPI/Wrap.h"
|
|
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
|
#include "llvm/ADT/ScopeExit.h"
|
|
|
|
using namespace mlir;
|
|
|
|
bool mlirOperationImplementsInterface(MlirOperation operation,
|
|
MlirTypeID interfaceTypeID) {
|
|
const AbstractOperation *abstractOp =
|
|
unwrap(operation)->getAbstractOperation();
|
|
return abstractOp && abstractOp->hasInterface(unwrap(interfaceTypeID));
|
|
}
|
|
|
|
bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,
|
|
MlirContext context,
|
|
MlirTypeID interfaceTypeID) {
|
|
const AbstractOperation *abstractOp = AbstractOperation::lookup(
|
|
StringRef(operationName.data, operationName.length), unwrap(context));
|
|
return abstractOp && abstractOp->hasInterface(unwrap(interfaceTypeID));
|
|
}
|
|
|
|
MlirTypeID mlirInferTypeOpInterfaceTypeID() {
|
|
return wrap(InferTypeOpInterface::getInterfaceID());
|
|
}
|
|
|
|
MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
|
|
MlirStringRef opName, MlirContext context, MlirLocation location,
|
|
intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
|
|
intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback,
|
|
void *userData) {
|
|
StringRef name(opName.data, opName.length);
|
|
const AbstractOperation *abstractOp =
|
|
AbstractOperation::lookup(name, unwrap(context));
|
|
if (!abstractOp)
|
|
return mlirLogicalResultFailure();
|
|
|
|
llvm::Optional<Location> maybeLocation = llvm::None;
|
|
if (!mlirLocationIsNull(location))
|
|
maybeLocation = unwrap(location);
|
|
SmallVector<Value> unwrappedOperands;
|
|
(void)unwrapList(nOperands, operands, unwrappedOperands);
|
|
DictionaryAttr attributeDict;
|
|
if (!mlirAttributeIsNull(attributes))
|
|
attributeDict = unwrap(attributes).cast<DictionaryAttr>();
|
|
|
|
// Create a vector of unique pointers to regions and make sure they are not
|
|
// deleted when exiting the scope. This is a hack caused by C++ API expecting
|
|
// an list of unique pointers to regions (without ownership transfer
|
|
// semantics) and C API making ownership transfer explicit.
|
|
SmallVector<std::unique_ptr<Region>> unwrappedRegions;
|
|
unwrappedRegions.reserve(nRegions);
|
|
for (intptr_t i = 0; i < nRegions; ++i)
|
|
unwrappedRegions.emplace_back(unwrap(*(regions + i)));
|
|
auto cleaner = llvm::make_scope_exit([&]() {
|
|
for (auto ®ion : unwrappedRegions)
|
|
region.release();
|
|
});
|
|
|
|
SmallVector<Type> inferredTypes;
|
|
if (failed(abstractOp->getInterface<InferTypeOpInterface>()->inferReturnTypes(
|
|
unwrap(context), maybeLocation, unwrappedOperands, attributeDict,
|
|
unwrappedRegions, inferredTypes)))
|
|
return mlirLogicalResultFailure();
|
|
|
|
SmallVector<MlirType> wrappedInferredTypes;
|
|
wrappedInferredTypes.reserve(inferredTypes.size());
|
|
for (Type t : inferredTypes)
|
|
wrappedInferredTypes.push_back(wrap(t));
|
|
callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData);
|
|
return mlirLogicalResultSuccess();
|
|
}
|