From aecf5491032d142ea3016de1828679d032d72e46 Mon Sep 17 00:00:00 2001 From: ykiko Date: Fri, 15 Nov 2024 22:46:07 +0800 Subject: [PATCH] Some fix for `PseudoInstantiator`. --- src/Compiler/Resolver.cpp | 195 +++++++++++++++++++++++-------------- unittests/AST/Resolver.cpp | 33 ++++--- 2 files changed, 146 insertions(+), 82 deletions(-) diff --git a/src/Compiler/Resolver.cpp b/src/Compiler/Resolver.cpp index 6504a413..e3b02c41 100644 --- a/src/Compiler/Resolver.cpp +++ b/src/Compiler/Resolver.cpp @@ -1,3 +1,4 @@ +#include "Support/TypeTraits.h" #include #include #include @@ -64,6 +65,10 @@ struct InstantiationStack { data.clear(); } + bool empty() const { + return data.empty(); + } + auto state() const { return data; } @@ -80,6 +85,10 @@ struct InstantiationStack { data.emplace_back(decl, arguments); } + void pop() { + data.pop_back(); + } + auto& frames() { return data; } @@ -98,6 +107,83 @@ public: Base(sema), sema(sema), context(sema.getASTContext()), resolved(resolved) {} public: + /// Deduce the template arguments for the given declaration. If deduction succeeds, push the + /// declaration and its deduced template arguments to the instantiation stack. + template + bool deduceTemplateArguments(Decl* decl, TemplateArguments arguments) { + clang::TemplateParameterList* list = nullptr; + TemplateArguments params = {}; + + if constexpr(std::is_same_v) { + const clang::ClassTemplateDecl* CTD = decl; + list = CTD->getTemplateParameters(); + params = list->getInjectedTemplateArgs(context); + } else if constexpr(std::is_same_v) { + const clang::ClassTemplatePartialSpecializationDecl* CTPSD = decl; + list = CTPSD->getTemplateParameters(); + params = CTPSD->getTemplateArgs().asArray(); + } else if constexpr(std::is_same_v) { + const clang::TypeAliasTemplateDecl* TATD = decl; + list = TATD->getTemplateParameters(); + params = list->getInjectedTemplateArgs(context); + } else { + static_assert(dependent_false, "Unknown declaration type"); + } + + assert(list && "No template parameters found"); + + TemplateDeductionInfo info = {clang::SourceLocation(), list->getDepth()}; + llvm::SmallVector deduced(list->size()); + + auto result = sema.DeduceTemplateArguments(list, params, arguments, info, deduced, false); + bool success = + result == clang::TemplateDeductionResult::Success && !info.hasSFINAEDiagnostic(); + + if(!success) { + return false; + } + + /// made up class template context. + if(stack.empty()) { + clang::Decl* D = decl; + while(true) { + auto context = llvm::dyn_cast(D->getDeclContext()); + assert(context && "No context found"); + + clang::TemplateParameterList* params = nullptr; + + if(auto TD = context->getDescribedTemplate()) { + params = TD->getTemplateParameters(); + D = TD; + } + + if(auto CTPSD = + llvm::dyn_cast(context)) { + params = CTPSD->getTemplateParameters(); + D = CTPSD; + } + + if(auto VTPSD = + llvm::dyn_cast(context)) { + params = VTPSD->getTemplateParameters(); + D = VTPSD; + } + + if(!params) { + break; + } + + stack.push_front(D, params->getInjectedTemplateArgs(this->context)); + continue; + } + } + + llvm::SmallVector output(deduced.begin(), deduced.end()); + stack.push(decl, output); + + return true; + } + /// If this class and its base class have members with the same name, `DeclContext::lookup` /// will return multiple declarations in order from the base class to the derived class, so we /// use the last declaration. @@ -133,14 +219,7 @@ public: if(auto CTD = llvm::dyn_cast(TD)) { return lookup(CTD, name, args); } else if(auto TATD = llvm::dyn_cast(TD)) { - clang::TemplateParameterList* list = TATD->getTemplateParameters(); - TemplateDeductionInfo info{clang::SourceLocation(), list->getDepth()}; - TemplateArguments params = list->getInjectedTemplateArgs(context); - llvm::SmallVector deduced(args.size()); - auto result = sema.DeduceTemplateArguments(list, params, args, info, deduced, false); - if(result == clang::TemplateDeductionResult::Success) { - llvm::SmallVector list(deduced.begin(), deduced.end()); - stack.push(TATD, list); + if(deduceTemplateArguments(TATD, args)) { return lookup(instantiate(TATD->getTemplatedDecl()->getUnderlyingType()), name); } } @@ -190,6 +269,25 @@ public: } } + /// Look up the name in the bases of the given class. Keep stack unchanged. + clang::lookup_result lookupInBases(clang::CXXRecordDecl* CRD, clang::DeclarationName name) { + if(!CRD->hasDefinition()) { + return clang::lookup_result(); + } + + for(auto base: CRD->bases()) { + if(auto type = base.getType(); type->isDependentType()) { + auto state = stack.state(); + if(auto members = lookup(instantiate(type), name); !members.empty()) { + return members; + } + stack.rewind(state); + } + } + + return clang::lookup_result(); + } + /// Look up the name in the given class template. We first search the name in the /// primary template, if failed, try dependent base classes, if still failed, try /// partial specializations. **Note that this function will be responsible for pushing @@ -197,42 +295,20 @@ public: clang::lookup_result lookup(clang::ClassTemplateDecl* CTD, clang::DeclarationName name, TemplateArguments arguments) { - clang::TemplateParameterList* list = CTD->getTemplateParameters(); - TemplateDeductionInfo info{clang::SourceLocation(), list->getDepth()}; - TemplateArguments params = list->getInjectedTemplateArgs(context); - llvm::SmallVector deduced(arguments.size()); - - if(auto result = - sema.DeduceTemplateArguments(list, params, arguments, info, deduced, false); - result == clang::TemplateDeductionResult::Success) { - llvm::SmallVector list(deduced.begin(), deduced.end()); - - auto RD = CTD->getTemplatedDecl(); + if(deduceTemplateArguments(CTD, arguments)) { + auto CRD = CTD->getTemplatedDecl(); /// First, try to find the name in the primary template. - if(auto members = CTD->getTemplatedDecl()->lookup(name); !members.empty()) { - /// FIXME: reduce copy here. - stack.push(CTD, list); + if(auto members = CRD->lookup(name); !members.empty()) { return members; } - if(RD->hasDefinition()) { - /// Try to find the member in the base class. - for(auto base: CTD->getTemplatedDecl()->bases()) { - if(auto type = base.getType(); type->isDependentType()) { - /// Because we instantiate the base class, this will clear the instantiation - /// stack. If the lookup fails, we need to rewind the stack to try the next - /// base class. - auto state = stack.state(); - stack.push(CTD, list); - - if(auto members = lookup(instantiate(type), name); !members.empty()) { - return members; - } - - stack.rewind(state); - } - } + /// If failed, try to find the name in the dependent base classes. + if(auto members = lookupInBases(CRD, name); !members.empty()) { + return members; } + + /// If failed, pop the decl and deduced template arguments. + stack.pop(); } /// Try to find the name in the partial specializations. @@ -240,21 +316,16 @@ public: CTD->getPartialSpecializations(partials); for(auto partial: partials) { - clang::TemplateParameterList* list = partial->getTemplateParameters(); - TemplateDeductionInfo info{clang::SourceLocation(), list->getDepth()}; - TemplateArguments params = partial->getTemplateArgs().asArray(); - llvm::SmallVector deduced(list->size()); - - auto result = - sema.DeduceTemplateArguments(list, params, arguments, info, deduced, false); - if(result == clang::TemplateDeductionResult::Success) { + if(deduceTemplateArguments(partial, arguments)) { if(auto members = partial->lookup(name); !members.empty()) { - llvm::SmallVector list(deduced.begin(), - deduced.end()); - stack.push(partial, list); - // FIXME: should we delete the list? return members; } + + if(auto members = lookupInBases(partial, name); !members.empty()) { + return members; + } + + stack.pop(); } } @@ -271,24 +342,8 @@ public: auto& contexts = sema.CodeSynthesisContexts; assert(contexts.empty() && "CodeSynthesisContexts should be empty"); - assert(!stack.frames().empty() && "Instantiation stack should not be empty"); - /// made up class template context. - while(true) { - auto top = stack.frames().front().first; - auto CRD = llvm::dyn_cast(top->getDeclContext()); - /// FIXME: other template context. - if(CRD && CRD->getDescribedTemplate()) { - auto TD = CRD->getDescribedTemplate(); - TemplateArguments arguments = - TD->getTemplateParameters()->getInjectedTemplateArgs(context); - stack.push_front(TD, arguments); - } else { - break; - } - } - std::ranges::for_each(stack.frames(), [&](auto& frame) { clang::Sema::CodeSynthesisContext context; context.Entity = frame.first; @@ -303,11 +358,9 @@ public: }); type = DealiasOnly(sema).TransformType(type); - llvm::outs() << "--------------------------------------\n"; - list.dump(); - type.dump(); + auto result = sema.SubstType(type, list, {}, {}); - result.dump(); + stack.clear(); contexts.clear(); diff --git a/unittests/AST/Resolver.cpp b/unittests/AST/Resolver.cpp index b38da891..b4a94289 100644 --- a/unittests/AST/Resolver.cpp +++ b/unittests/AST/Resolver.cpp @@ -6,17 +6,7 @@ namespace { using namespace clice; -namespace testing { - -std::string PrintToString(const clang::QualType& type) { - std::string str; - llvm::raw_string_ostream ss(str); - type.print(ss, clang::PrintingPolicy({})); - ss.flush(); - return str; -} - -} // namespace testing +namespace testing {} // namespace testing struct TemplateResolverTester : public clang::RecursiveASTVisitor { TemplateResolverTester(llvm::StringRef code) { @@ -345,6 +335,27 @@ struct test { )cpp"); } +TEST(TemplateResolver, InnerDependentPartialMemberClass) { + TemplateResolverTester tester(R"cpp( +template +struct type_list {}; + +template +struct test {}; + +template +struct test { + template + struct A { + using type = type_list; + }; + + using input = typename A<1, T>::type; + using expect = type_list; +}; +)cpp"); +} + TEST(TemplateResolver, PartialSpecialization) { TemplateResolverTester tester(R"cpp( template