diff --git a/mlir/docs/Dialects/OpenMPDialect/_index.md b/mlir/docs/Dialects/OpenMPDialect/_index.md index adde17675043..1df80fac2a68 100644 --- a/mlir/docs/Dialects/OpenMPDialect/_index.md +++ b/mlir/docs/Dialects/OpenMPDialect/_index.md @@ -372,6 +372,8 @@ accessed: should be located. - `getBlockArgs()`: Returns the list of entry block arguments defined by the given clause. + - `numClauseBlockArgs()`: Returns the total number of entry block arguments + defined by all clauses. - `getBlockArgsPairs()`: Returns a list of pairs where the first element is the outside value, or operand, and the second element is the corresponding entry block argument. diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td index 0766b4e8d147..3fa54d35ed09 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td @@ -136,12 +136,20 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> { !foreach(clause, clauses, clause.startMethod), !foreach(clause, clauses, clause.blockArgsMethod), [ + InterfaceMethod< + "Get the total number of clause-defined entry block arguments", + "unsigned", "numClauseBlockArgs", (ins), + "return " # !interleave( + !foreach(clause, clauses, "$_op." # clause.numArgsMethod.name # "()"), + " + ") # ";" + >, InterfaceMethod< "Populate a vector of pairs representing the matching between operands " "and entry block arguments.", "void", "getBlockArgsPairs", (ins "::llvm::SmallVectorImpl> &" : $pairs), [{ auto iface = ::llvm::cast(*$_op); + pairs.reserve(pairs.size() + iface.numClauseBlockArgs()); }] # !interleave(!foreach(clause, clauses, [{ }] # "if (iface." # clause.numArgsMethod.name # "() > 0) {" # [{ }] # " for (auto [var, arg] : ::llvm::zip_equal(" # @@ -155,11 +163,7 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> { let verify = [{ auto iface = ::llvm::cast($_op); - }] # "unsigned expectedArgs = " - # !interleave( - !foreach(clause, clauses, "iface." # clause.numArgsMethod.name # "()"), - " + " - ) # ";" # [{ + unsigned expectedArgs = iface.numClauseBlockArgs(); if ($_op->getRegion(0).getNumArguments() < expectedArgs) return $_op->emitOpError() << "expected at least " << expectedArgs << " entry block argument(s)"; diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 3373f19a006b..b9893716980f 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -550,18 +550,16 @@ convertIgnoredWrapper(omp::LoopWrapperInterface opInst, // corresponding operand. This is semantically equivalent to this wrapper not // being present. auto forwardArgs = - [&moduleTranslation](llvm::ArrayRef blockArgs, - OperandRange operands) { - for (auto [arg, var] : llvm::zip_equal(blockArgs, operands)) + [&moduleTranslation](omp::BlockArgOpenMPOpInterface blockArgIface) { + llvm::SmallVector> blockArgsPairs; + blockArgIface.getBlockArgsPairs(blockArgsPairs); + for (auto [var, arg] : blockArgsPairs) moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var)); }; return llvm::TypeSwitch(opInst) .Case([&](omp::SimdOp op) { - auto blockArgIface = cast(*op); - forwardArgs(blockArgIface.getPrivateBlockArgs(), op.getPrivateVars()); - forwardArgs(blockArgIface.getReductionBlockArgs(), - op.getReductionVars()); + forwardArgs(cast(*op)); op.emitWarning() << "simd information on composite construct discarded"; return success(); })