//===- SubElementInterfaces.cpp - Attr and Type SubElement Interfaces -----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/IR/SubElementInterfaces.h" #include "mlir/IR/Operation.h" #include "llvm/ADT/DenseSet.h" using namespace mlir; //===----------------------------------------------------------------------===// // SubElementInterface //===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===// // WalkSubElements template static void walkSubElementsImpl(InterfaceT interface, function_ref walkAttrsFn, function_ref walkTypesFn, DenseSet &visitedAttrs, DenseSet &visitedTypes) { interface.walkImmediateSubElements( [&](Attribute attr) { // Guard against potentially null inputs. This removes the need for the // derived attribute/type to do it. if (!attr) return; // Avoid infinite recursion when visiting sub attributes later, if this // is a mutable attribute. if (LLVM_UNLIKELY(attr.hasTrait())) { if (!visitedAttrs.insert(attr).second) return; } // Walk any sub elements first. if (auto interface = attr.dyn_cast()) walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs, visitedTypes); // Walk this attribute. walkAttrsFn(attr); }, [&](Type type) { // Guard against potentially null inputs. This removes the need for the // derived attribute/type to do it. if (!type) return; // Avoid infinite recursion when visiting sub types later, if this // is a mutable type. if (LLVM_UNLIKELY(type.hasTrait())) { if (!visitedTypes.insert(type).second) return; } // Walk any sub elements first. if (auto interface = type.dyn_cast()) walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs, visitedTypes); // Walk this type. walkTypesFn(type); }); } void SubElementAttrInterface::walkSubElements( function_ref walkAttrsFn, function_ref walkTypesFn) { assert(walkAttrsFn && walkTypesFn && "expected valid walk functions"); DenseSet visitedAttrs; DenseSet visitedTypes; walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs, visitedTypes); } void SubElementTypeInterface::walkSubElements( function_ref walkAttrsFn, function_ref walkTypesFn) { assert(walkAttrsFn && walkTypesFn && "expected valid walk functions"); DenseSet visitedAttrs; DenseSet visitedTypes; walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs, visitedTypes); } //===----------------------------------------------------------------------===// /// AttrTypeReplacer //===----------------------------------------------------------------------===// void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) { // Functor that replaces the given element if the new value is different, // otherwise returns nullptr. auto replaceIfDifferent = [&](auto element) { auto replacement = replace(element); return (replacement && replacement != element) ? replacement : nullptr; }; // Update the attribute dictionary. if (replaceAttrs) { if (auto newAttrs = replaceIfDifferent(op->getAttrDictionary())) op->setAttrs(cast(newAttrs)); } // If we aren't updating locations or types, we're done. if (!replaceTypes && !replaceLocs) return; // Update the location. if (replaceLocs) { if (Attribute newLoc = replaceIfDifferent(op->getLoc())) op->setLoc(cast(newLoc)); } // Update the result types. if (replaceTypes) { for (OpResult result : op->getResults()) if (Type newType = replaceIfDifferent(result.getType())) result.setType(newType); } // Update any nested block arguments. for (Region ®ion : op->getRegions()) { for (Block &block : region) { for (BlockArgument &arg : block.getArguments()) { if (replaceLocs) { if (Attribute newLoc = replaceIfDifferent(arg.getLoc())) arg.setLoc(cast(newLoc)); } if (replaceTypes) { if (Type newType = replaceIfDifferent(arg.getType())) arg.setType(newType); } } } } } void AttrTypeReplacer::recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) { op->walk([&](Operation *nestedOp) { replaceElementsIn(nestedOp, replaceAttrs, replaceLocs, replaceTypes); }); } template static void updateSubElementImpl(T element, AttrTypeReplacer &replacer, DenseMap &elementMap, SmallVectorImpl &newElements, FailureOr &changed) { // Bail early if we failed at any point. if (failed(changed)) return; // Guard against potentially null inputs. We always map null to null. if (!element) { newElements.push_back(nullptr); return; } // Replace the element. if (T result = replacer.replace(element)) { newElements.push_back(result); if (result != element) changed = true; } else { changed = failure(); } } template T AttrTypeReplacer::replaceSubElements(InterfaceT interface, DenseMap &interfaceMap) { // Walk the current sub-elements, replacing them as necessary. SmallVector newAttrs; SmallVector newTypes; FailureOr changed = false; interface.walkImmediateSubElements( [&](Attribute element) { updateSubElementImpl(element, *this, attrMap, newAttrs, changed); }, [&](Type element) { updateSubElementImpl(element, *this, typeMap, newTypes, changed); }); if (failed(changed)) return nullptr; // If any sub-elements changed, use the new elements during the replacement. T result = interface; if (*changed) result = interface.replaceImmediateSubElements(newAttrs, newTypes); return result; } /// Shared implementation of replacing a given attribute or type element. template T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns, DenseMap &map) { auto [it, inserted] = map.try_emplace(element, element); if (!inserted) return it->second; T result = element; WalkResult walkResult = WalkResult::advance(); for (auto &replaceFn : llvm::reverse(replaceFns)) { if (Optional> newRes = replaceFn(element)) { std::tie(result, walkResult) = *newRes; break; } } // If an error occurred, return nullptr to indicate failure. if (walkResult.wasInterrupted() || !result) return map[element] = nullptr; // Handle replacing sub-elements if this element is also a container. if (!walkResult.wasSkipped()) { if (auto interface = dyn_cast(result)) { // Replace the sub elements of this element, bailing if we fail. if (!(result = replaceSubElements(interface, map))) return map[element] = nullptr; } } return map[element] = result; } Attribute AttrTypeReplacer::replace(Attribute attr) { return replaceImpl(attr, attrReplacementFns, attrMap); } Type AttrTypeReplacer::replace(Type type) { return replaceImpl(type, typeReplacementFns, typeMap); } //===----------------------------------------------------------------------===// // SubElementInterface Tablegen definitions //===----------------------------------------------------------------------===// #include "mlir/IR/SubElementAttrInterfaces.cpp.inc" #include "mlir/IR/SubElementTypeInterfaces.cpp.inc"