[mlir][func]: Introduce ReplaceFuncSignature tranform operation (#143381)

This transform takes a module and a function name, and replaces the
signature of the function by reordering the arguments and results
according to the interchange arrays. The function is expected to be
defined in the module, and the interchange arrays must match the number
of arguments and results of the function.
This commit is contained in:
Aviad Cohen
2025-06-24 06:35:06 +03:00
committed by GitHub
parent 37eb465710
commit 3ba7a872bf
9 changed files with 545 additions and 2 deletions

View File

@@ -1,4 +1,4 @@
//===- FuncTransformOps.h - CF transformation ops --------*- C++ -*-===//
//===- FuncTransformOps.h - Function transformation ops --------*- C++ -*--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.

View File

@@ -98,4 +98,40 @@ def CastAndCallOp : Op<Transform_Dialect,
let hasVerifier = 1;
}
def ReplaceFuncSignatureOp
: Op<Transform_Dialect, "func.replace_func_signature",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let description = [{
This transform takes a module and a function name, and replaces the
signature of the function by reordering the arguments and results
according to the interchange arrays. The function is expected to be
defined in the module, and the interchange arrays must match the number
of arguments and results of the function.
The `adjust_func_calls` attribute indicates whether the function calls
should be adjusted to match the new signature. If set to `true`, the
function calls will be adjusted to match the new signature, otherwise
they will not be adjusted.
This transform will emit a silenceable failure if:
- The function with the given name does not exist in the module.
- The interchange arrays do not match the number of arguments/results.
- The interchange arrays contain out of bound indices.
}];
let arguments = (ins TransformHandleTypeInterface:$module,
SymbolRefAttr:$function_name, DenseI32ArrayAttr:$args_interchange,
DenseI32ArrayAttr:$results_interchange, UnitAttr:$adjust_func_calls);
let results = (outs TransformHandleTypeInterface:$transformed_module,
TransformHandleTypeInterface:$transformed_function);
let assemblyFormat = [{
$function_name
`args_interchange` `=` $args_interchange
`results_interchange` `=` $results_interchange
`at` $module attr-dict `:` functional-type(operands, results)
}];
}
#endif // FUNC_TRANSFORM_OPS

View File

@@ -0,0 +1,49 @@
//===- Utils.h - General Func transformation utilities ----*- C++ -*-------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines prototypes for various transformation utilities for
// the Func dialect. These are not passes by themselves but are used
// either by passes, optimization sequences, or in turn by other transformation
// utilities.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_FUNC_UTILS_H
#define MLIR_DIALECT_FUNC_UTILS_H
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/ArrayRef.h"
namespace mlir {
namespace func {
class FuncOp;
class CallOp;
/// Creates a new function operation with the same name as the original
/// function operation, but with the arguments reordered according to
/// the `newArgsOrder` and `newResultsOrder`.
/// The `funcOp` operation must have exactly one block.
/// Returns the new function operation or failure if `funcOp` doesn't
/// have exactly one block.
FailureOr<FuncOp>
replaceFuncWithNewOrder(RewriterBase &rewriter, FuncOp funcOp,
llvm::ArrayRef<unsigned> newArgsOrder,
llvm::ArrayRef<unsigned> newResultsOrder);
/// Creates a new call operation with the values as the original
/// call operation, but with the arguments reordered according to
/// the `newArgsOrder` and `newResultsOrder`.
CallOp replaceCallOpWithNewOrder(RewriterBase &rewriter, CallOp callOp,
llvm::ArrayRef<unsigned> newArgsOrder,
llvm::ArrayRef<unsigned> newResultsOrder);
} // namespace func
} // namespace mlir
#endif // MLIR_DIALECT_FUNC_UTILS_H

View File

@@ -2,3 +2,4 @@ add_subdirectory(Extensions)
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(TransformOps)
add_subdirectory(Utils)

View File

@@ -1,4 +1,4 @@
//===- FuncTransformOps.cpp - Implementation of CF transform ops ---===//
//===- FuncTransformOps.cpp - Implementation of CF transform ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -11,10 +11,12 @@
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
@@ -226,6 +228,109 @@ void transform::CastAndCallOp::getEffects(
transform::modifiesPayload(effects);
}
//===----------------------------------------------------------------------===//
// ReplaceFuncSignatureOp
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure
transform::ReplaceFuncSignatureOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
auto payloadOps = state.getPayloadOps(getModule());
if (!llvm::hasSingleElement(payloadOps))
return emitDefiniteFailure() << "requires a single module to operate on";
auto targetModuleOp = dyn_cast<ModuleOp>(*payloadOps.begin());
if (!targetModuleOp)
return emitSilenceableFailure(getLoc())
<< "target is expected to be module operation";
func::FuncOp funcOp =
targetModuleOp.lookupSymbol<func::FuncOp>(getFunctionName());
if (!funcOp)
return emitSilenceableFailure(getLoc())
<< "function with name '" << getFunctionName() << "' not found";
unsigned numArgs = funcOp.getNumArguments();
unsigned numResults = funcOp.getNumResults();
// Check that the number of arguments and results matches the
// interchange sizes.
if (numArgs != getArgsInterchange().size())
return emitSilenceableFailure(getLoc())
<< "function with name '" << getFunctionName() << "' has " << numArgs
<< " arguments, but " << getArgsInterchange().size()
<< " args interchange were given";
if (numResults != getResultsInterchange().size())
return emitSilenceableFailure(getLoc())
<< "function with name '" << getFunctionName() << "' has "
<< numResults << " results, but " << getResultsInterchange().size()
<< " results interchange were given";
// Check that the args and results interchanges are unique.
SetVector<unsigned> argsInterchange, resultsInterchange;
argsInterchange.insert_range(getArgsInterchange());
resultsInterchange.insert_range(getResultsInterchange());
if (argsInterchange.size() != getArgsInterchange().size())
return emitSilenceableFailure(getLoc())
<< "args interchange must be unique";
if (resultsInterchange.size() != getResultsInterchange().size())
return emitSilenceableFailure(getLoc())
<< "results interchange must be unique";
// Check that the args and results interchange indices are in bounds.
for (unsigned index : argsInterchange) {
if (index >= numArgs) {
return emitSilenceableFailure(getLoc())
<< "args interchange index " << index
<< " is out of bounds for function with name '"
<< getFunctionName() << "' with " << numArgs << " arguments";
}
}
for (unsigned index : resultsInterchange) {
if (index >= numResults) {
return emitSilenceableFailure(getLoc())
<< "results interchange index " << index
<< " is out of bounds for function with name '"
<< getFunctionName() << "' with " << numResults << " results";
}
}
FailureOr<func::FuncOp> newFuncOpOrFailure = func::replaceFuncWithNewOrder(
rewriter, funcOp, argsInterchange.getArrayRef(),
resultsInterchange.getArrayRef());
if (failed(newFuncOpOrFailure))
return emitSilenceableFailure(getLoc())
<< "failed to replace function signature '" << getFunctionName()
<< "' with new order";
if (getAdjustFuncCalls()) {
SmallVector<func::CallOp> callOps;
targetModuleOp.walk([&](func::CallOp callOp) {
if (callOp.getCallee() == getFunctionName().getRootReference().getValue())
callOps.push_back(callOp);
});
for (func::CallOp callOp : callOps)
func::replaceCallOpWithNewOrder(rewriter, callOp,
argsInterchange.getArrayRef(),
resultsInterchange.getArrayRef());
}
results.set(cast<OpResult>(getTransformedModule()), {targetModuleOp});
results.set(cast<OpResult>(getTransformedFunction()), {*newFuncOpOrFailure});
return DiagnosedSilenceableFailure::success();
}
void transform::ReplaceFuncSignatureOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::consumesHandle(getModuleMutable(), effects);
transform::producesHandle(getOperation()->getOpResults(), effects);
transform::modifiesPayload(effects);
}
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

View File

@@ -0,0 +1,12 @@
add_mlir_dialect_library(MLIRFuncUtils
Utils.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Utils
LINK_LIBS PUBLIC
MLIRFuncDialect
MLIRDialect
MLIRDialectUtils
MLIRIR
)

View File

@@ -0,0 +1,121 @@
//===- Utils.cpp - Utilities to support the Func dialect ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Func dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallVector.h"
using namespace mlir;
FailureOr<func::FuncOp>
func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
ArrayRef<unsigned> newArgsOrder,
ArrayRef<unsigned> newResultsOrder) {
// Generate an empty new function operation with the same name as the
// original.
assert(funcOp.getNumArguments() == newArgsOrder.size() &&
"newArgsOrder must match the number of arguments in the function");
assert(funcOp.getNumResults() == newResultsOrder.size() &&
"newResultsOrder must match the number of results in the function");
if (!funcOp.getBody().hasOneBlock())
return rewriter.notifyMatchFailure(
funcOp, "expected function to have exactly one block");
ArrayRef<Type> origInputTypes = funcOp.getFunctionType().getInputs();
ArrayRef<Type> origOutputTypes = funcOp.getFunctionType().getResults();
SmallVector<Type> newInputTypes, newOutputTypes;
SmallVector<Location> locs;
for (unsigned int idx : newArgsOrder) {
newInputTypes.push_back(origInputTypes[idx]);
locs.push_back(funcOp.getArgument(newArgsOrder[idx]).getLoc());
}
for (unsigned int idx : newResultsOrder)
newOutputTypes.push_back(origOutputTypes[idx]);
rewriter.setInsertionPoint(funcOp);
auto newFuncOp = rewriter.create<func::FuncOp>(
funcOp.getLoc(), funcOp.getName(),
rewriter.getFunctionType(newInputTypes, newOutputTypes));
Region &newRegion = newFuncOp.getBody();
rewriter.createBlock(&newRegion, newRegion.begin(), newInputTypes, locs);
newFuncOp.setVisibility(funcOp.getVisibility());
newFuncOp->setDiscardableAttrs(funcOp->getDiscardableAttrDictionary());
// Map the arguments of the original function to the new function in
// the new order and adjust the attributes accordingly.
IRMapping operandMapper;
SmallVector<DictionaryAttr> argAttrs, resultAttrs;
funcOp.getAllArgAttrs(argAttrs);
for (unsigned int i = 0; i < newArgsOrder.size(); ++i) {
operandMapper.map(funcOp.getArgument(newArgsOrder[i]),
newFuncOp.getArgument(i));
newFuncOp.setArgAttrs(i, argAttrs[newArgsOrder[i]]);
}
funcOp.getAllResultAttrs(resultAttrs);
for (unsigned int i = 0; i < newResultsOrder.size(); ++i)
newFuncOp.setResultAttrs(i, resultAttrs[newResultsOrder[i]]);
// Clone the operations from the original function to the new function.
rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
for (Operation &op : funcOp.getOps())
rewriter.clone(op, operandMapper);
// Handle the return operation.
auto returnOp = cast<func::ReturnOp>(
newFuncOp.getFunctionBody().begin()->getTerminator());
SmallVector<Value> newReturnValues;
for (unsigned int idx : newResultsOrder)
newReturnValues.push_back(returnOp.getOperand(idx));
rewriter.setInsertionPoint(returnOp);
auto newReturnOp =
rewriter.create<func::ReturnOp>(newFuncOp.getLoc(), newReturnValues);
newReturnOp->setDiscardableAttrs(returnOp->getDiscardableAttrDictionary());
rewriter.eraseOp(returnOp);
rewriter.eraseOp(funcOp);
return newFuncOp;
}
func::CallOp
func::replaceCallOpWithNewOrder(RewriterBase &rewriter, func::CallOp callOp,
ArrayRef<unsigned> newArgsOrder,
ArrayRef<unsigned> newResultsOrder) {
assert(
callOp.getNumOperands() == newArgsOrder.size() &&
"newArgsOrder must match the number of operands in the call operation");
assert(
callOp.getNumResults() == newResultsOrder.size() &&
"newResultsOrder must match the number of results in the call operation");
SmallVector<Value> newArgsOrderValues;
for (unsigned int argIdx : newArgsOrder)
newArgsOrderValues.push_back(callOp.getOperand(argIdx));
SmallVector<Type> newResultTypes;
for (unsigned int resIdx : newResultsOrder)
newResultTypes.push_back(callOp.getResult(resIdx).getType());
// Replace the kernel call operation with a new one that has the
// reordered arguments.
rewriter.setInsertionPoint(callOp);
auto newCallOp = rewriter.create<func::CallOp>(
callOp.getLoc(), callOp.getCallee(), newResultTypes, newArgsOrderValues);
newCallOp.setNoInlineAttr(callOp.getNoInlineAttr());
for (auto &&[newIndex, origIndex] : llvm::enumerate(newResultsOrder))
rewriter.replaceAllUsesWith(callOp.getResult(origIndex),
newCallOp.getResult(newIndex));
rewriter.eraseOp(callOp);
return newCallOp;
}

View File

@@ -0,0 +1,87 @@
// RUN: mlir-opt %s --transform-interpreter -allow-unregistered-dialect --split-input-file -verify-diagnostics
module {
func.func private @func_with_reverse_order_no_result_no_calls(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
%c0 = arith.constant 0 : index
%view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
%view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
%view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
return
}
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%module = transform.get_parent_op %func : (!transform.any_op) -> !transform.any_op
// expected-error @+1 {{function with name '@func_not_in_module' not found}}
transform.func.replace_func_signature @func_not_in_module args_interchange = [0, 2, 1] results_interchange = [] at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
// -----
module {
func.func private @func_with_reverse_order_no_result_no_calls(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
%c0 = arith.constant 0 : index
%view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
%view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
%view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
return
}
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%module = transform.get_parent_op %func : (!transform.any_op) -> !transform.any_op
// expected-error @+1 {{function with name '@func_with_reverse_order_no_result_no_calls' has 3 arguments, but 2 args interchange were given}}
transform.func.replace_func_signature @func_with_reverse_order_no_result_no_calls args_interchange = [0, 2] results_interchange = [] at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
// -----
module {
func.func private @func_with_reverse_order_no_result_no_calls(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
%c0 = arith.constant 0 : index
%view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
%view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
%view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
return
}
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%module = transform.get_parent_op %func : (!transform.any_op) -> !transform.any_op
// expected-error @+1 {{function with name '@func_with_reverse_order_no_result_no_calls' has 0 results, but 1 results interchange were given}}
transform.func.replace_func_signature @func_with_reverse_order_no_result_no_calls args_interchange = [0, 2, 1] results_interchange = [0] at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
// -----
module {
func.func private @func_with_reverse_order_no_result_no_calls(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
%c0 = arith.constant 0 : index
%view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
%view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
%view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
return
}
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%module = transform.get_parent_op %func : (!transform.any_op) -> !transform.any_op
// expected-error @+1 {{args interchange must be unique}}
transform.func.replace_func_signature @func_with_reverse_order_no_result_no_calls args_interchange = [0, 2, 2] results_interchange = [] at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}

View File

@@ -118,3 +118,135 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
// -----
module {
// CHECK: func.func private @func_with_reverse_order_no_result_no_calls(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<3xi8, 1>, %[[ARG2:.*]]: memref<2xi8, 1>) {
func.func private @func_with_reverse_order_no_result_no_calls(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
// CHECK: %[[VAL_4:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0]]][] : memref<1xi8, 1> to memref<1xi8, 1>
%view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
// CHECK: %[[VAL_5:.*]] = memref.view %[[ARG2]]{{\[}}%[[C0]]][] : memref<2xi8, 1> to memref<2xi8, 1>
%view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
// CHECK: %[[VAL_6:.*]] = memref.view %[[ARG1]]{{\[}}%[[C0]]][] : memref<3xi8, 1> to memref<3xi8, 1>
%view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
return
}
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%module = transform.get_parent_op %func : (!transform.any_op) -> !transform.any_op
transform.func.replace_func_signature @func_with_reverse_order_no_result_no_calls args_interchange = [0, 2, 1] results_interchange = [] at %module : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
// -----
module {
// CHECK: func.func private @func_with_reverse_order_no_result(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<3xi8, 1>, %[[ARG2:.*]]: memref<2xi8, 1>) {
func.func private @func_with_reverse_order_no_result(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
// CHECK: %[[VAL_4:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0]]][] : memref<1xi8, 1> to memref<1xi8, 1>
%view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
// CHECK: %[[VAL_5:.*]] = memref.view %[[ARG2]]{{\[}}%[[C0]]][] : memref<2xi8, 1> to memref<2xi8, 1>
%view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
// CHECK: %[[VAL_6:.*]] = memref.view %[[ARG1]]{{\[}}%[[C0]]][] : memref<3xi8, 1> to memref<3xi8, 1>
%view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
return
}
// CHECK: func.func @func_with_reverse_order_no_result_caller(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<2xi8, 1>, %[[ARG2:.*]]: memref<3xi8, 1>) {
func.func @func_with_reverse_order_no_result_caller(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) {
// CHECK: call @func_with_reverse_order_no_result(%[[ARG0]], %[[ARG2]], %[[ARG1]]) : (memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) -> ()
call @func_with_reverse_order_no_result(%arg0, %arg1, %arg2) : (memref<1xi8, 1>, memref<2xi8, 1>, memref<3xi8, 1>) -> ()
return
}
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%module = transform.get_parent_op %f#0 : (!transform.any_op) -> !transform.any_op
transform.func.replace_func_signature @func_with_reverse_order_no_result args_interchange = [0, 2, 1] results_interchange = [] at %module {adjust_func_calls} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
// -----
module {
// CHECK: func.func private @func_with_reverse_order(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<3xi8, 1>, %[[ARG2:.*]]: memref<2xi8, 1>) -> (memref<2xi8, 1>, memref<1xi8, 1>) {
func.func private @func_with_reverse_order(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
// CHECK: %[[RET_0:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0]]][] : memref<1xi8, 1> to memref<1xi8, 1>
%view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
// CHECK: %[[RET_1:.*]] = memref.view %[[ARG2]]{{\[}}%[[C0]]][] : memref<2xi8, 1> to memref<2xi8, 1>
%view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
// CHECK: %[[VAL_6:.*]] = memref.view %[[ARG1]]{{\[}}%[[C0]]][] : memref<3xi8, 1> to memref<3xi8, 1>
%view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
// CHECK: return %[[RET_1]], %[[RET_0]] : memref<2xi8, 1>, memref<1xi8, 1>
return %view, %view0 : memref<1xi8, 1>, memref<2xi8, 1>
}
// CHECK: func.func @func_with_reverse_order_caller(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<2xi8, 1>, %[[ARG2:.*]]: memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>) {
func.func @func_with_reverse_order_caller(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>) {
// CHECK: %[[RET:.*]]:2 = call @func_with_reverse_order(%[[ARG0]], %[[ARG2]], %[[ARG1]]) : (memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) -> (memref<2xi8, 1>, memref<1xi8, 1>)
%0, %1 = call @func_with_reverse_order(%arg0, %arg1, %arg2) : (memref<1xi8, 1>, memref<2xi8, 1>, memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>)
// CHECK: return %[[RET]]#1, %[[RET]]#0 : memref<1xi8, 1>, memref<2xi8, 1>
return %0, %1 : memref<1xi8, 1>, memref<2xi8, 1>
}
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%module = transform.get_parent_op %f#0 : (!transform.any_op) -> !transform.any_op
transform.func.replace_func_signature @func_with_reverse_order args_interchange = [0, 2, 1] results_interchange = [1, 0] at %module {adjust_func_calls} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
// -----
module {
// CHECK: func.func private @func_with_reverse_order_with_attr(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<3xi8, 1>, %[[ARG2:.*]]: memref<2xi8, 1> {transform.readonly}) -> (memref<2xi8, 1>, memref<1xi8, 1>) {
func.func private @func_with_reverse_order_with_attr(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>{transform.readonly}, %arg2: memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
// CHECK: %[[RET_0:.*]] = memref.view %[[ARG0]]{{\[}}%[[C0]]][] : memref<1xi8, 1> to memref<1xi8, 1>
%view = memref.view %arg0[%c0][] : memref<1xi8, 1> to memref<1xi8, 1>
// CHECK: %[[RET_1:.*]] = memref.view %[[ARG2]]{{\[}}%[[C0]]][] : memref<2xi8, 1> to memref<2xi8, 1>
%view0 = memref.view %arg1[%c0][] : memref<2xi8, 1> to memref<2xi8, 1>
// CHECK: %[[VAL_6:.*]] = memref.view %[[ARG1]]{{\[}}%[[C0]]][] : memref<3xi8, 1> to memref<3xi8, 1>
%view1 = memref.view %arg2[%c0][] : memref<3xi8, 1> to memref<3xi8, 1>
// CHECK: return %[[RET_1]], %[[RET_0]] : memref<2xi8, 1>, memref<1xi8, 1>
return %view, %view0 : memref<1xi8, 1>, memref<2xi8, 1>
}
// CHECK: func.func @func_with_reverse_order_with_attr_caller(%[[ARG0:.*]]: memref<1xi8, 1>, %[[ARG1:.*]]: memref<2xi8, 1>, %[[ARG2:.*]]: memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>) {
func.func @func_with_reverse_order_with_attr_caller(%arg0: memref<1xi8, 1>, %arg1: memref<2xi8, 1>, %arg2: memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>) {
// CHECK: %[[RET:.*]]:2 = call @func_with_reverse_order_with_attr(%[[ARG0]], %[[ARG2]], %[[ARG1]]) : (memref<1xi8, 1>, memref<3xi8, 1>, memref<2xi8, 1>) -> (memref<2xi8, 1>, memref<1xi8, 1>)
%0, %1 = call @func_with_reverse_order_with_attr(%arg0, %arg1, %arg2) : (memref<1xi8, 1>, memref<2xi8, 1>, memref<3xi8, 1>) -> (memref<1xi8, 1>, memref<2xi8, 1>)
// CHECK: return %[[RET]]#1, %[[RET]]#0 : memref<1xi8, 1>, memref<2xi8, 1>
return %0, %1 : memref<1xi8, 1>, memref<2xi8, 1>
}
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%funcs = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%f:2 = transform.split_handle %funcs : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%module = transform.get_parent_op %f#0 : (!transform.any_op) -> !transform.any_op
transform.func.replace_func_signature @func_with_reverse_order_with_attr args_interchange = [0, 2, 1] results_interchange = [1, 0] at %module {adjust_func_calls} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}