[MLIR][OpenMP] Minor improvements to BlockArgOpenMPOpInterface, NFC (#130789)
This patch introduces a use for the new `getBlockArgsPairs` to avoid having to manually list each applicable clause. Also, the `numClauseBlockArgs()` function is introduced, which simplifies the implementation of the interface's verifier and enables better memory handling within `getBlockArgsPairs`.
This commit is contained in:
@@ -372,6 +372,8 @@ accessed:
|
||||
should be located.
|
||||
- `get<ClauseName>BlockArgs()`: Returns the list of entry block arguments
|
||||
defined by the given clause.
|
||||
- `numClauseBlockArgs()`: Returns the total number of entry block arguments
|
||||
defined by all clauses.
|
||||
- `getBlockArgsPairs()`: Returns a list of pairs where the first element is
|
||||
the outside value, or operand, and the second element is the corresponding
|
||||
entry block argument.
|
||||
|
||||
@@ -136,12 +136,20 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
|
||||
!foreach(clause, clauses, clause.startMethod),
|
||||
!foreach(clause, clauses, clause.blockArgsMethod),
|
||||
[
|
||||
InterfaceMethod<
|
||||
"Get the total number of clause-defined entry block arguments",
|
||||
"unsigned", "numClauseBlockArgs", (ins),
|
||||
"return " # !interleave(
|
||||
!foreach(clause, clauses, "$_op." # clause.numArgsMethod.name # "()"),
|
||||
" + ") # ";"
|
||||
>,
|
||||
InterfaceMethod<
|
||||
"Populate a vector of pairs representing the matching between operands "
|
||||
"and entry block arguments.", "void", "getBlockArgsPairs",
|
||||
(ins "::llvm::SmallVectorImpl<std::pair<::mlir::Value, ::mlir::BlockArgument>> &" : $pairs),
|
||||
[{
|
||||
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
|
||||
pairs.reserve(pairs.size() + iface.numClauseBlockArgs());
|
||||
}] # !interleave(!foreach(clause, clauses, [{
|
||||
}] # "if (iface." # clause.numArgsMethod.name # "() > 0) {" # [{
|
||||
}] # " for (auto [var, arg] : ::llvm::zip_equal(" #
|
||||
@@ -155,11 +163,7 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
|
||||
|
||||
let verify = [{
|
||||
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
|
||||
}] # "unsigned expectedArgs = "
|
||||
# !interleave(
|
||||
!foreach(clause, clauses, "iface." # clause.numArgsMethod.name # "()"),
|
||||
" + "
|
||||
) # ";" # [{
|
||||
unsigned expectedArgs = iface.numClauseBlockArgs();
|
||||
if ($_op->getRegion(0).getNumArguments() < expectedArgs)
|
||||
return $_op->emitOpError() << "expected at least " << expectedArgs
|
||||
<< " entry block argument(s)";
|
||||
|
||||
@@ -550,18 +550,16 @@ convertIgnoredWrapper(omp::LoopWrapperInterface opInst,
|
||||
// corresponding operand. This is semantically equivalent to this wrapper not
|
||||
// being present.
|
||||
auto forwardArgs =
|
||||
[&moduleTranslation](llvm::ArrayRef<BlockArgument> blockArgs,
|
||||
OperandRange operands) {
|
||||
for (auto [arg, var] : llvm::zip_equal(blockArgs, operands))
|
||||
[&moduleTranslation](omp::BlockArgOpenMPOpInterface blockArgIface) {
|
||||
llvm::SmallVector<std::pair<Value, BlockArgument>> blockArgsPairs;
|
||||
blockArgIface.getBlockArgsPairs(blockArgsPairs);
|
||||
for (auto [var, arg] : blockArgsPairs)
|
||||
moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
|
||||
};
|
||||
|
||||
return llvm::TypeSwitch<Operation *, LogicalResult>(opInst)
|
||||
.Case([&](omp::SimdOp op) {
|
||||
auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(*op);
|
||||
forwardArgs(blockArgIface.getPrivateBlockArgs(), op.getPrivateVars());
|
||||
forwardArgs(blockArgIface.getReductionBlockArgs(),
|
||||
op.getReductionVars());
|
||||
forwardArgs(cast<omp::BlockArgOpenMPOpInterface>(*op));
|
||||
op.emitWarning() << "simd information on composite construct discarded";
|
||||
return success();
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user