#include #include #include class DependentShaper { clang::ASTContext& context; clang::Sema& sema; public: DependentShaper(clang::ASTContext& context, clang::Sema& sema) : context(context), sema(sema) {} /// replace template arguments in a dependent type /// e.g. /// ```cpp /// template /// struct A { /// using reference = T1&; /// }; /// /// template /// struct B { /// using type = A::reference; /// }; /// ``` /// we want to simplify `A::reference` to `U1&`, then we need to call /// `replace(, )` to get `U1&`. clang::QualType replace(clang::QualType input, llvm::ArrayRef originalArguments) { assert(type->isDependentType() && "type is not dependent"); // if the input is still a dependent name type, we need to simplify it recursively while(auto type = llvm::dyn_cast(input)) { input = simplify(type); } clang::ElaboratedType* elaboratedType = nullptr; clang::MultiLevelTemplateArgumentList list; list.addOuterTemplateArguments(originalArguments); llvm::outs() << "-------------------------------------------------\n"; clang::Sema::CodeSynthesisContext InstContext; // InstContext.Kind = clang::Sema::CodeSynthesisContext::TemplateInstantiation; // InstContext.Entity = nullptr; // InstContext.TemplateArgs = TemplateArgs; // sema.CodeSynthesisContexts.back().dump(); auto result = sema.SubstType(input, list, {}, {}); if(auto type = llvm::dyn_cast(input)) { auto pointee = type->getPointeeType(); while(auto elaboratedType = llvm::dyn_cast(pointee)) { pointee = elaboratedType->desugar(); } // get the position of the dependent type in the template parameter list // e.g `T1` in , the index is 0 if(auto paramType = llvm::dyn_cast(pointee)) { auto index = paramType->getIndex(); auto argument = originalArguments[index]; // create a new type that fit the replacement return context.getLValueReferenceType(argument.getAsType()); } else if(auto defType = llvm::dyn_cast(pointee)) { return replace(defType->getDecl()->getUnderlyingType(), originalArguments); } else { pointee->dump(); std::terminate(); } } // TODO: handle other kinds return input; } clang::NamedDecl* lookup(const clang::ClassTemplateDecl* classTemplateDecl, const clang::IdentifierInfo* identifier) { clang::CXXRecordDecl* recordDecl = classTemplateDecl->getTemplatedDecl(); auto result = recordDecl->lookup(identifier); return result.front(); } /// for a complex dependent type: `X<...>::name::name2::...::nameN`, we can resolve it recursively. /// so we only need to handle the `X<...>::name`, whose prefix is a template specialization type. clang::QualType simplify(const clang::TemplateSpecializationType* templateType, const clang::IdentifierInfo* identifier) { // X is a class template or a type alias template auto templateDecl = templateType->getTemplateName().getAsTemplateDecl(); if(auto classTemplateDecl = llvm::dyn_cast(templateDecl)) { // lookup the identifier in the record decl auto namedDecl = lookup(classTemplateDecl, identifier); if(auto decl = llvm::dyn_cast(namedDecl)) { return replace(decl->getUnderlyingType(), templateType->template_arguments()); } else if(auto decl = llvm::dyn_cast(namedDecl)) { return replace(decl->getUnderlyingType(), templateType->template_arguments()); } else { namedDecl->dump(); } } else if(auto aliasTemplateDecl = llvm::dyn_cast(templateDecl)) { // TODO: } else { templateDecl->dump(); } } const clang::QualType simplify(const clang::NestedNameSpecifier* specifier, const clang::IdentifierInfo* identifier) { auto kind = specifier->getKind(); switch(specifier->getKind()) { case clang::NestedNameSpecifier::Identifier: { const auto prefix = simplify(specifier->getPrefix(), specifier->getAsIdentifier()); if(auto type = llvm::dyn_cast(prefix)) { return simplify(type, identifier); } else { prefix->dump(); } break; } case clang::NestedNameSpecifier::TypeSpec: { auto node = specifier->getAsType(); if(auto type = llvm::dyn_cast(node)) { // represent a direct dependent name, e.g. typename T::^ name // and can not be further simplified // node->dump(); type->dump(); } else if(auto type = node->getAs()) { // represent a dependent name that is a template specialization // e.g. typename vector::^ name, and can be further simplified return simplify(type, identifier); } else { node->dump(); } break; } case clang::NestedNameSpecifier::TypeSpecWithTemplate: { llvm::outs() << "unsupported kind: " << kind << "\n"; break; } default: { llvm::outs() << "unsupported kind: " << kind << "\n"; } } } const clang::QualType simplify(const clang::DependentNameType* type) { // llvm::outs() << "-----------------------------------------------------------" << "\n"; // type->dump(); return simplify(type->getQualifier(), type->getIdentifier()); } }; namespace clang { class DependentNameResolver { public: Sema& S; ASTContext& Ctx; clang::NamedDecl* CurrentDecl; public: DependentNameResolver(ASTContext& Ctx, Sema& S) : Ctx(Ctx), S(S) {} std::vector resolve(llvm::ArrayRef arguments) { std::vector result; for(auto arg: arguments) { if(arg.getKind() == TemplateArgument::ArgKind::Type) { if(auto type = llvm::dyn_cast(arg.getAsType())) { const TemplateTypeParmDecl* param = type->getDecl(); if(param->hasDefaultArgument()) { result.push_back(param->getDefaultArgument().getArgument()); continue; } } } result.push_back(arg); } return result; } QualType resolve(QualType T) { if(!T->isDependentType()) { return T; } while(true) { if(auto DNT = T->getAs()) { T = resolve(DNT); } else if(auto DTST = T->getAs()) { T = resolve(DTST); } else if(auto LRT = T->getAs()) { return Ctx.getLValueReferenceType(resolve(LRT->getPointeeType())); } else { return T; } } } /// resolve a dependent name type, e.g. `typename std::vector::reference` QualType resolve(const DependentNameType* DNT) { // e.g. when DNT is `typename std::vector::reference` // - qualifier: std::vector // - identifier: reference return resolve(DNT->getQualifier(), DNT->getIdentifier()); } /// resolve a dependent template specialization type. QualType resolve(const DependentTemplateSpecializationType* DTST) { // e.g. when DTST is `typename std::allocator_traits::template rebind_alloc`. // - qualifier: std::allocator_traits // - identifier: rebind_alloc // - template_arguments: return resolve(DTST->getQualifier(), DTST->getIdentifier(), DTST->template_arguments()); } QualType resolve(const NestedNameSpecifier* TST, const IdentifierInfo* II, ArrayRef arguments = {}) { switch(TST->getKind()) { case NestedNameSpecifier::SpecifierKind::Identifier: { llvm::outs() << "\n------------------ Identifier -----------------------\n"; // when the kind of TST is Identifier // e.g. std::vector>::value_type:: // resolve it recursively return resolve(resolve(TST->getPrefix(), TST->getAsIdentifier()), II, arguments); } case NestedNameSpecifier::SpecifierKind::TypeSpec: { llvm::outs() << "\n------------------ TypeSpec -----------------------\n"; TST->dump(); llvm::outs() << " " << II->getName() << "\n"; // when the kind of TST is TypeSpec, e.g. std::vector:: return resolve(QualType(TST->getAsType(), 0), II, arguments); } case NestedNameSpecifier::SpecifierKind::TypeSpecWithTemplate: { llvm::outs() << "------------------ TypeSpecWithTemplate -----------------------\n"; // when the kind of TST is TypeSpecWithTemplate, e.g. std::vector::template name:: TST->dump(); return resolve(QualType(TST->getAsType(), 0), II, arguments); } default: { llvm::outs() << "\n------------------ Unknown -----------------------\n"; TST->dump(); std::terminate(); } } } QualType substitute(ClassTemplateDecl* CTD, const IdentifierInfo* II, ArrayRef arguments = {}) { Sema::CodeSynthesisContext context; auto args = resolve(arguments); context.Entity = CTD; context.TemplateArgs = args.data(); context.Kind = Sema::CodeSynthesisContext::TemplateInstantiation; S.pushCodeSynthesisContext(context); MultiLevelTemplateArgumentList list; auto recordDecl = CTD->getTemplatedDecl(); auto member = recordDecl->lookup(II).front(); QualType type; if(auto TAD = llvm::dyn_cast(member)) { type = TAD->getUnderlyingType(); } else if(auto TD = llvm::dyn_cast(member)) { type = TD->getUnderlyingType(); } else if(auto TATD = llvm::dyn_cast(member)) { auto args2 = resolve(arguments); context.Entity = TATD; context.TemplateArgs = args2.data(); context.Kind = Sema::CodeSynthesisContext::TypeAliasTemplateInstantiation; S.pushCodeSynthesisContext(context); MultiLevelTemplateArgumentList list; list.addOuterTemplateArguments(TATD, args2, false); auto TAD = TATD->getTemplatedDecl(); type = TAD->getUnderlyingType(); } else if(auto CTD = llvm::dyn_cast(member)) { return substitute(CTD, II, arguments); } else { member->dump(); std::terminate(); } list.addOuterTemplateArguments(CTD, args, true); return S.SubstType(type, list, {}, {}); } /// typename A::template B::template C::type:: QualType resolve(const DependentTemplateSpecializationType* DTST, const IdentifierInfo* II) { // auto prefix; MultiLevelTemplateArgumentList list; while(true) { // list.addOuterTemplateArguments() } } QualType resolve(QualType T, const IdentifierInfo* II, ArrayRef arguments = {}) { if(!T->isDependentType() && arguments.size() == 0) { // TODO: } llvm::outs() << "\n"; T.dump(); Sema::CodeSynthesisContext context; if(auto TTPT = T->getAs()) { llvm::outs() << "\n-------------------------------------------------\n"; T->dump(); llvm::outs() << " \n" << II->getName(); // e.g. when T is `T` // - index: 0 } else if(auto TST = T->getAs()) { auto TemplateName = TST->getTemplateName(); auto TemplateDecl = TemplateName.getAsTemplateDecl(); auto TemplatedDecl = TemplateDecl->getTemplatedDecl(); if(auto CTD = llvm::dyn_cast(TemplateDecl)) { return substitute(CTD, II, TST->template_arguments()); } else { TemplateDecl->dump(); std::terminate(); } auto TemplateArgs = TST->template_arguments(); } else if(auto DTST = T->getAs()) { auto TST = DTST->getQualifier()->getAsType()->getAs(); auto TemplateName = TST->getTemplateName(); auto TemplateDecl = TemplateName.getAsTemplateDecl(); auto TemplatedDecl = TemplateDecl->getTemplatedDecl(); if(auto CTD = llvm::dyn_cast(TemplateDecl)) { auto args = resolve(TST->template_arguments()); context.Entity = CTD; context.TemplateArgs = args.data(); context.Kind = Sema::CodeSynthesisContext::TemplateInstantiation; S.pushCodeSynthesisContext(context); MultiLevelTemplateArgumentList list; list.addOuterTemplateArguments(CTD, args, true); auto recordDecl = CTD->getTemplatedDecl(); auto member = recordDecl->lookup(DTST->getIdentifier()).front(); if(auto CTD2 = llvm::dyn_cast(member)) { auto args2 = resolve(DTST->template_arguments()); context.Entity = CTD2; context.TemplateArgs = args2.data(); context.Kind = Sema::CodeSynthesisContext::TemplateInstantiation; S.pushCodeSynthesisContext(context); MultiLevelTemplateArgumentList list; list.addOuterTemplateArguments(CTD2, args2, true); list.addOuterTemplateArguments(CTD, args, true); auto CRD = CTD2->getTemplatedDecl(); return S.SubstType( llvm::dyn_cast(CRD->lookup(II).front())->getUnderlyingType(), list, {}, {}); } else if(auto TATD = llvm::dyn_cast(member)) { auto args2 = resolve(arguments); context.Entity = TATD; context.TemplateArgs = args2.data(); context.Kind = Sema::CodeSynthesisContext::TypeAliasTemplateInstantiation; S.pushCodeSynthesisContext(context); MultiLevelTemplateArgumentList list; list.addOuterTemplateArguments(TATD, args2, false); list.addOuterTemplateArguments(CTD, args, true); auto TAD = TATD->getTemplatedDecl(); return S.SubstType(TAD->getUnderlyingType(), list, {}, {}); } else { member->dump(); std::terminate(); } } else { T->dump(); std::terminate(); } // S.SubstType() } } }; class DependentNameResolverV2 { public: Sema& S; ASTContext& Ctx; std::vector>*> arguments; public: DependentNameResolverV2(ASTContext& Ctx, Sema& S) : Ctx(Ctx), S(S) {} std::vector resolve(llvm::ArrayRef arguments) { std::vector result; for(auto arg: arguments) { if(arg.getKind() == TemplateArgument::ArgKind::Type) { if(auto type = llvm::dyn_cast(arg.getAsType())) { const TemplateTypeParmDecl* param = type->getDecl(); if(param && param->hasDefaultArgument()) { result.push_back(param->getDefaultArgument().getArgument()); continue; } } } result.push_back(arg); } return result; } QualType dealias(QualType type) { if(auto DNT = type->getAs()) { return QualType(DNT, 0); } else if(auto DTST = type->getAs()) { auto NNS = NestedNameSpecifier::Create( Ctx, nullptr, false, dealias(QualType(DTST->getQualifier()->getAsType(), 0)).getTypePtr()); return Ctx.getDependentTemplateSpecializationType(DTST->getKeyword(), NNS, DTST->getIdentifier(), resolve(DTST->template_arguments())); } else { return type; } } QualType resolve(QualType type) { while(true) { // llvm::outs() << "--------------------------------------------------------------------\n"; // type.dump(); MultiLevelTemplateArgumentList list; if(auto DNT = type->getAs()) { type = resolve(resolve(DNT->getQualifier(), DNT->getIdentifier())); for(auto begin = arguments.rbegin(), end = arguments.rend(); begin != end; ++begin) { list.addOuterTemplateArguments((*begin)->first, (*begin)->second, true); } type = S.SubstType(dealias(type), list, {}, {}); arguments.clear(); } else if(auto DTST = type->getAs()) { auto ND = resolve(DTST->getQualifier(), DTST->getIdentifier()); if(auto TATD = llvm::dyn_cast(ND)) { auto args = resolve(DTST->template_arguments()); Sema::CodeSynthesisContext context; context.Entity = TATD; context.Kind = Sema::CodeSynthesisContext::TypeAliasTemplateInstantiation; context.TemplateArgs = args.data(); S.pushCodeSynthesisContext(context); list.addOuterTemplateArguments(TATD, args, true); for(auto begin = arguments.rbegin(), end = arguments.rend(); begin != end; ++begin) { list.addOuterTemplateArguments((*begin)->first, (*begin)->second, true); } // llvm::outs() << "before: // ----------------------------------------------------------------\n"; // TATD->getTemplatedDecl()->getUnderlyingType().dump(); type = dealias(TATD->getTemplatedDecl()->getUnderlyingType()); // llvm::outs() << "arguments: // -------------------------------------------------------------\n"; list.dump(); type = S.SubstType(type, list, {}, {}); // type.dump(); arguments.clear(); } else { ND->dump(); std::terminate(); } // return resolve(DTST); } else if(auto LRT = type->getAs()) { type = Ctx.getLValueReferenceType(resolve(LRT->getPointeeType())); } else { return type; } } } QualType resolve(NamedDecl* ND) { if(auto TD = llvm::dyn_cast(ND)) { return TD->getUnderlyingType(); } else if(auto TAD = llvm::dyn_cast(ND)) { return TAD->getUnderlyingType(); } else { ND->dump(); std::terminate(); } } NamedDecl* resolve(const NestedNameSpecifier* NNS, const IdentifierInfo* II) { switch(NNS->getKind()) { // prefix is an identifier, e.g. <...>::name:: case NestedNameSpecifier::SpecifierKind::Identifier: { return lookup(resolve(resolve(NNS->getPrefix(), NNS->getAsIdentifier())), II); } // prefix is a type, e.g. <...>::typename name:: case NestedNameSpecifier::SpecifierKind::TypeSpec: case NestedNameSpecifier::SpecifierKind::TypeSpecWithTemplate: { return lookup(QualType(NNS->getAsType(), 0), II); } default: { NNS->dump(); std::terminate(); } } } NamedDecl* lookup(QualType Type, const IdentifierInfo* Name) { NamedDecl* TemplateDecl; ArrayRef arguments; llvm::outs() << "--------------------------------------------------------------------\n"; Type.dump(); if(auto TTPT = Type->getAs()) { Type->dump(); std::terminate(); } else if(auto TST = Type->getAs()) { auto TemplateName = TST->getTemplateName(); TemplateDecl = TemplateName.getAsTemplateDecl(); arguments = TST->template_arguments(); } else if(auto DTST = Type->getAs()) { TemplateDecl = resolve(DTST->getQualifier(), DTST->getIdentifier()); arguments = DTST->template_arguments(); } else if(auto RT = Type->getAs()) { return RT->getDecl()->lookup(Name).front(); } else { Type->dump(); std::terminate(); } this->arguments.push_back( new std::pair>{TemplateDecl, resolve(arguments)}); NamedDecl* result; Sema::CodeSynthesisContext context; context.Entity = TemplateDecl; context.Kind = Sema::CodeSynthesisContext::TemplateInstantiation; context.TemplateArgs = this->arguments.back()->second.data(); S.pushCodeSynthesisContext(context); if(auto CTD = llvm::dyn_cast(TemplateDecl)) { llvm::outs() << "--------------------------------------------------------------------\n"; llvm::SmallVector paritals; CTD->getPartialSpecializations(paritals); for(auto partial: paritals) { partial->getInjectedSpecializationType().dump(); } llvm::outs() << "--------------------------------------------------------------------\n"; // CTD->findPartialSpecialization() auto partial = CTD->findPartialSpecialization(Type); if(partial) { result = partial->lookup(Name).front(); } if(!result) { result = CTD->getTemplatedDecl()->lookup(Name).front(); } } else if(auto TATD = llvm::dyn_cast(TemplateDecl)) { result = lookup(TATD->getTemplatedDecl()->getUnderlyingType(), Name); } if(result == nullptr) { Type.dump(); std::terminate(); } return result; } }; } // namespace clang class ASTVistor : public clang::RecursiveASTVisitor { private: clang::Preprocessor& preprocessor; clang::SourceManager& sourceManager; clang::syntax::TokenBuffer& buffer; clang::ASTContext& context; clang::Sema& sema; public: ASTVistor(clang::Preprocessor& preprocessor, clang::syntax::TokenBuffer& buffer, clang::ASTContext& context, clang::Sema& sema) : preprocessor(preprocessor), sourceManager(preprocessor.getSourceManager()), buffer(buffer), context(context), sema(sema) {} bool VisitTypeAliasDecl(clang::TypeAliasDecl* decl) { auto& sm = context.getSourceManager(); if(sm.isInMainFile(decl->getLocation()) && decl->getName() == "result") { auto type = decl->getUnderlyingType(); type.dump(); llvm::outs() << "--------------------------------- Result ------------------------------\n"; clang::DependentNameResolverV2 shaper{context, sema}; auto result = shaper.resolve(type); result.dump(); } return true; } }; int main(int argc, const char** argv) { assert(argc == 2 && "Usage: Preprocessor "); llvm::outs() << "running ASTVisitor...\n"; auto instance = std::make_unique(); clang::DiagnosticIDs* ids = new clang::DiagnosticIDs(); clang::DiagnosticOptions* diag_opts = new clang::DiagnosticOptions(); clang::DiagnosticConsumer* consumer = new clang::IgnoringDiagConsumer(); clang::DiagnosticsEngine* engine = new clang::DiagnosticsEngine(ids, diag_opts, consumer); instance->setDiagnostics(engine); auto invocation = std::make_shared(); std::vector args = { "/usr/local/bin/clang++", "-Xclang", "-no-round-trip-args", "-std=c++20", argv[1], }; invocation = clang::createInvocation(args, {}); // clang::CompilerInvocation::CreateFromArgs(*invocation, args, instance->getDiagnostics()); instance->setInvocation(std::move(invocation)); if(!instance->createTarget()) { llvm::errs() << "Failed to create target\n"; std::terminate(); } clang::SyntaxOnlyAction action; if(!action.BeginSourceFile(*instance, instance->getFrontendOpts().Inputs[0])) { llvm::errs() << "Failed to begin source file\n"; std::terminate(); } clang::syntax::TokenCollector collector{instance->getPreprocessor()}; if(auto error = action.Execute()) { llvm::errs() << "Failed to execute action: " << error << "\n"; std::terminate(); } clang::syntax::TokenBuffer buffer = std::move(collector).consume(); auto tu = instance->getASTContext().getTranslationUnitDecl(); ASTVistor visitor{instance->getPreprocessor(), buffer, instance->getASTContext(), instance->getSema()}; visitor.TraverseDecl(tu); action.EndSourceFile(); };