Files
clang-p2996/mlir/lib/AsmParser/AsmParserState.cpp
Tres Popp 5550c82189 [mlir] Move casting calls from methods to function calls
The MLIR classes Type/Attribute/Operation/Op/Value support
cast/dyn_cast/isa/dyn_cast_or_null functionality through llvm's doCast
functionality in addition to defining methods with the same name.
This change begins the migration of uses of the method to the
corresponding function call as has been decided as more consistent.

Note that there still exist classes that only define methods directly,
such as AffineExpr, and this does not include work currently to support
a functional cast/isa call.

Caveats include:
- This clang-tidy script probably has more problems.
- This only touches C++ code, so nothing that is being generated.

Context:
- https://mlir.llvm.org/deprecation/ at "Use the free function variants
  for dyn_cast/cast/isa/…"
- Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443

Implementation:
This first patch was created with the following steps. The intention is
to only do automated changes at first, so I waste less time if it's
reverted, and so the first mass change is more clear as an example to
other teams that will need to follow similar steps.

Steps are described per line, as comments are removed by git:
0. Retrieve the change from the following to build clang-tidy with an
   additional check:
   https://github.com/llvm/llvm-project/compare/main...tpopp:llvm-project:tidy-cast-check
1. Build clang-tidy
2. Run clang-tidy over your entire codebase while disabling all checks
   and enabling the one relevant one. Run on all header files also.
3. Delete .inc files that were also modified, so the next build rebuilds
   them to a pure state.
4. Some changes have been deleted for the following reasons:
   - Some files had a variable also named cast
   - Some files had not included a header file that defines the cast
     functions
   - Some files are definitions of the classes that have the casting
     methods, so the code still refers to the method instead of the
     function without adding a prefix or removing the method declaration
     at the same time.

```
ninja -C $BUILD_DIR clang-tidy

run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy -checks='-*,misc-cast-functions'\
               -header-filter=mlir/ mlir/* -fix

rm -rf $BUILD_DIR/tools/mlir/**/*.inc

git restore mlir/lib/IR mlir/lib/Dialect/DLTI/DLTI.cpp\
            mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp\
            mlir/lib/**/IR/\
            mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp\
            mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp\
            mlir/test/lib/Dialect/Test/TestTypes.cpp\
            mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp\
            mlir/test/lib/Dialect/Test/TestAttributes.cpp\
            mlir/unittests/TableGen/EnumsGenTest.cpp\
            mlir/test/python/lib/PythonTestCAPI.cpp\
            mlir/include/mlir/IR/
```

Differential Revision: https://reviews.llvm.org/D150123
2023-05-12 11:21:25 +02:00

345 lines
13 KiB
C++

//===- AsmParserState.cpp -------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/AsmParser/AsmParserState.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/StringExtras.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// AsmParserState::Impl
//===----------------------------------------------------------------------===//
struct AsmParserState::Impl {
/// A map from a SymbolRefAttr to a range of uses.
using SymbolUseMap =
DenseMap<Attribute, SmallVector<SmallVector<SMRange>, 0>>;
struct PartialOpDef {
explicit PartialOpDef(const OperationName &opName) {
if (opName.hasTrait<OpTrait::SymbolTable>())
symbolTable = std::make_unique<SymbolUseMap>();
}
/// Return if this operation is a symbol table.
bool isSymbolTable() const { return symbolTable.get(); }
/// If this operation is a symbol table, the following contains symbol uses
/// within this operation.
std::unique_ptr<SymbolUseMap> symbolTable;
};
/// Resolve any symbol table uses in the IR.
void resolveSymbolUses();
/// A mapping from operations in the input source file to their parser state.
SmallVector<std::unique_ptr<OperationDefinition>> operations;
DenseMap<Operation *, unsigned> operationToIdx;
/// A mapping from blocks in the input source file to their parser state.
SmallVector<std::unique_ptr<BlockDefinition>> blocks;
DenseMap<Block *, unsigned> blocksToIdx;
/// A set of value definitions that are placeholders for forward references.
/// This map should be empty if the parser finishes successfully.
DenseMap<Value, SmallVector<SMLoc>> placeholderValueUses;
/// The symbol table operations within the IR.
SmallVector<std::pair<Operation *, std::unique_ptr<SymbolUseMap>>>
symbolTableOperations;
/// A stack of partial operation definitions that have been started but not
/// yet finalized.
SmallVector<PartialOpDef> partialOperations;
/// A stack of symbol use scopes. This is used when collecting symbol table
/// uses during parsing.
SmallVector<SymbolUseMap *> symbolUseScopes;
/// A symbol table containing all of the symbol table operations in the IR.
SymbolTableCollection symbolTable;
};
void AsmParserState::Impl::resolveSymbolUses() {
SmallVector<Operation *> symbolOps;
for (auto &opAndUseMapIt : symbolTableOperations) {
for (auto &it : *opAndUseMapIt.second) {
symbolOps.clear();
if (failed(symbolTable.lookupSymbolIn(
opAndUseMapIt.first, cast<SymbolRefAttr>(it.first), symbolOps)))
continue;
for (ArrayRef<SMRange> useRange : it.second) {
for (const auto &symIt : llvm::zip(symbolOps, useRange)) {
auto opIt = operationToIdx.find(std::get<0>(symIt));
if (opIt != operationToIdx.end())
operations[opIt->second]->symbolUses.push_back(std::get<1>(symIt));
}
}
}
}
}
//===----------------------------------------------------------------------===//
// AsmParserState
//===----------------------------------------------------------------------===//
AsmParserState::AsmParserState() : impl(std::make_unique<Impl>()) {}
AsmParserState::~AsmParserState() = default;
AsmParserState &AsmParserState::operator=(AsmParserState &&other) {
impl = std::move(other.impl);
return *this;
}
//===----------------------------------------------------------------------===//
// Access State
auto AsmParserState::getBlockDefs() const -> iterator_range<BlockDefIterator> {
return llvm::make_pointee_range(llvm::ArrayRef(impl->blocks));
}
auto AsmParserState::getBlockDef(Block *block) const
-> const BlockDefinition * {
auto it = impl->blocksToIdx.find(block);
return it == impl->blocksToIdx.end() ? nullptr : &*impl->blocks[it->second];
}
auto AsmParserState::getOpDefs() const -> iterator_range<OperationDefIterator> {
return llvm::make_pointee_range(llvm::ArrayRef(impl->operations));
}
auto AsmParserState::getOpDef(Operation *op) const
-> const OperationDefinition * {
auto it = impl->operationToIdx.find(op);
return it == impl->operationToIdx.end() ? nullptr
: &*impl->operations[it->second];
}
/// Lex a string token whose contents start at the given `curPtr`. Returns the
/// position at the end of the string, after a terminal or invalid character
/// (e.g. `"` or `\0`).
static const char *lexLocStringTok(const char *curPtr) {
while (char c = *curPtr++) {
// Check for various terminal characters.
if (StringRef("\"\n\v\f").contains(c))
return curPtr;
// Check for escape sequences.
if (c == '\\') {
// Check a few known escapes and \xx hex digits.
if (*curPtr == '"' || *curPtr == '\\' || *curPtr == 'n' || *curPtr == 't')
++curPtr;
else if (llvm::isHexDigit(*curPtr) && llvm::isHexDigit(curPtr[1]))
curPtr += 2;
else
return curPtr;
}
}
// If we hit this point, we've reached the end of the buffer. Update the end
// pointer to not point past the buffer.
return curPtr - 1;
}
SMRange AsmParserState::convertIdLocToRange(SMLoc loc) {
if (!loc.isValid())
return SMRange();
const char *curPtr = loc.getPointer();
// Check if this is a string token.
if (*curPtr == '"') {
curPtr = lexLocStringTok(curPtr + 1);
// Otherwise, default to handling an identifier.
} else {
// Return if the given character is a valid identifier character.
auto isIdentifierChar = [](char c) {
return isalnum(c) || c == '$' || c == '.' || c == '_' || c == '-';
};
while (*curPtr && isIdentifierChar(*(++curPtr)))
continue;
}
return SMRange(loc, SMLoc::getFromPointer(curPtr));
}
//===----------------------------------------------------------------------===//
// Populate State
void AsmParserState::initialize(Operation *topLevelOp) {
startOperationDefinition(topLevelOp->getName());
// If the top-level operation is a symbol table, push a new symbol scope.
Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
if (partialOpDef.isSymbolTable())
impl->symbolUseScopes.push_back(partialOpDef.symbolTable.get());
}
void AsmParserState::finalize(Operation *topLevelOp) {
assert(!impl->partialOperations.empty() &&
"expected valid partial operation definition");
Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();
// If this operation is a symbol table, resolve any symbol uses.
if (partialOpDef.isSymbolTable()) {
impl->symbolTableOperations.emplace_back(
topLevelOp, std::move(partialOpDef.symbolTable));
}
impl->resolveSymbolUses();
}
void AsmParserState::startOperationDefinition(const OperationName &opName) {
impl->partialOperations.emplace_back(opName);
}
void AsmParserState::finalizeOperationDefinition(
Operation *op, SMRange nameLoc, SMLoc endLoc,
ArrayRef<std::pair<unsigned, SMLoc>> resultGroups) {
assert(!impl->partialOperations.empty() &&
"expected valid partial operation definition");
Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();
// Build the full operation definition.
std::unique_ptr<OperationDefinition> def =
std::make_unique<OperationDefinition>(op, nameLoc, endLoc);
for (auto &resultGroup : resultGroups)
def->resultGroups.emplace_back(resultGroup.first,
convertIdLocToRange(resultGroup.second));
impl->operationToIdx.try_emplace(op, impl->operations.size());
impl->operations.emplace_back(std::move(def));
// If this operation is a symbol table, resolve any symbol uses.
if (partialOpDef.isSymbolTable()) {
impl->symbolTableOperations.emplace_back(
op, std::move(partialOpDef.symbolTable));
}
}
void AsmParserState::startRegionDefinition() {
assert(!impl->partialOperations.empty() &&
"expected valid partial operation definition");
// If the parent operation of this region is a symbol table, we also push a
// new symbol scope.
Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
if (partialOpDef.isSymbolTable())
impl->symbolUseScopes.push_back(partialOpDef.symbolTable.get());
}
void AsmParserState::finalizeRegionDefinition() {
assert(!impl->partialOperations.empty() &&
"expected valid partial operation definition");
// If the parent operation of this region is a symbol table, pop the symbol
// scope for this region.
Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
if (partialOpDef.isSymbolTable())
impl->symbolUseScopes.pop_back();
}
void AsmParserState::addDefinition(Block *block, SMLoc location) {
auto it = impl->blocksToIdx.find(block);
if (it == impl->blocksToIdx.end()) {
impl->blocksToIdx.try_emplace(block, impl->blocks.size());
impl->blocks.emplace_back(std::make_unique<BlockDefinition>(
block, convertIdLocToRange(location)));
return;
}
// If an entry already exists, this was a forward declaration that now has a
// proper definition.
impl->blocks[it->second]->definition.loc = convertIdLocToRange(location);
}
void AsmParserState::addDefinition(BlockArgument blockArg, SMLoc location) {
auto it = impl->blocksToIdx.find(blockArg.getOwner());
assert(it != impl->blocksToIdx.end() &&
"expected owner block to have an entry");
BlockDefinition &def = *impl->blocks[it->second];
unsigned argIdx = blockArg.getArgNumber();
if (def.arguments.size() <= argIdx)
def.arguments.resize(argIdx + 1);
def.arguments[argIdx] = SMDefinition(convertIdLocToRange(location));
}
void AsmParserState::addUses(Value value, ArrayRef<SMLoc> locations) {
// Handle the case where the value is an operation result.
if (OpResult result = dyn_cast<OpResult>(value)) {
// Check to see if a definition for the parent operation has been recorded.
// If one hasn't, we treat the provided value as a placeholder value that
// will be refined further later.
Operation *parentOp = result.getOwner();
auto existingIt = impl->operationToIdx.find(parentOp);
if (existingIt == impl->operationToIdx.end()) {
impl->placeholderValueUses[value].append(locations.begin(),
locations.end());
return;
}
// If a definition does exist, locate the value's result group and add the
// use. The result groups are ordered by increasing start index, so we just
// need to find the last group that has a smaller/equal start index.
unsigned resultNo = result.getResultNumber();
OperationDefinition &def = *impl->operations[existingIt->second];
for (auto &resultGroup : llvm::reverse(def.resultGroups)) {
if (resultNo >= resultGroup.startIndex) {
for (SMLoc loc : locations)
resultGroup.definition.uses.push_back(convertIdLocToRange(loc));
return;
}
}
llvm_unreachable("expected valid result group for value use");
}
// Otherwise, this is a block argument.
BlockArgument arg = cast<BlockArgument>(value);
auto existingIt = impl->blocksToIdx.find(arg.getOwner());
assert(existingIt != impl->blocksToIdx.end() &&
"expected valid block definition for block argument");
BlockDefinition &blockDef = *impl->blocks[existingIt->second];
SMDefinition &argDef = blockDef.arguments[arg.getArgNumber()];
for (SMLoc loc : locations)
argDef.uses.emplace_back(convertIdLocToRange(loc));
}
void AsmParserState::addUses(Block *block, ArrayRef<SMLoc> locations) {
auto it = impl->blocksToIdx.find(block);
if (it == impl->blocksToIdx.end()) {
it = impl->blocksToIdx.try_emplace(block, impl->blocks.size()).first;
impl->blocks.emplace_back(std::make_unique<BlockDefinition>(block));
}
BlockDefinition &def = *impl->blocks[it->second];
for (SMLoc loc : locations)
def.definition.uses.push_back(convertIdLocToRange(loc));
}
void AsmParserState::addUses(SymbolRefAttr refAttr,
ArrayRef<SMRange> locations) {
// Ignore this symbol if no scopes are active.
if (impl->symbolUseScopes.empty())
return;
assert((refAttr.getNestedReferences().size() + 1) == locations.size() &&
"expected the same number of references as provided locations");
(*impl->symbolUseScopes.back())[refAttr].emplace_back(locations.begin(),
locations.end());
}
void AsmParserState::refineDefinition(Value oldValue, Value newValue) {
auto it = impl->placeholderValueUses.find(oldValue);
assert(it != impl->placeholderValueUses.end() &&
"expected `oldValue` to be a placeholder");
addUses(newValue, it->second);
impl->placeholderValueUses.erase(oldValue);
}