@@ -27,8 +27,8 @@ bool translation_rules_loaded_ = false;
2727
2828std::unordered_map<std::string, TranslationRule::ExprTgt>
2929 exprs_; // src -> ExprTgt
30- std::unordered_map <std::string, TranslationRule::TypeTgt >
31- types_; // src -> TypeTgt
30+ std::unordered_multimap <std::string, TranslationRule::TypeRule >
31+ types_; // src -> TypeRule
3232
3333clang::PrintingPolicy getPrintPolicy () {
3434 assert (ctx_);
@@ -41,6 +41,16 @@ clang::PrintingPolicy getPrintPolicy() {
4141 return policy;
4242}
4343
44+ std::string GetMapKey (const std::string &str) {
45+ return str.substr (0 , str.find_first_of (" <[" ));
46+ }
47+
48+ void AddTypeRule (std::string src, TranslationRule::TypeRule &&rule) {
49+ auto key = GetMapKey (src);
50+ rule.src = std::move (src);
51+ types_.emplace (std::move (key), std::move (rule));
52+ }
53+
4454// Attempts to unify an instantiated C++ type or function signature with a
4555// corresponding template pattern. If the two match structurally, it returns
4656// a mapping from template parameter names (e.g., "T1") to their concrete
@@ -309,6 +319,28 @@ instantiateTgt(const std::unordered_map<std::string, std::string> &types,
309319 return instantiated_template;
310320}
311321
322+ std::pair<TranslationRule::TypeRule *,
323+ std::unordered_map<std::string, std::string>>
324+ search_type (const std::string &cpp_type) {
325+ auto [it, end] = types_.equal_range (GetMapKey (cpp_type));
326+ TranslationRule::TypeRule *rule = nullptr ;
327+ std::unordered_map<std::string, std::string> subs;
328+
329+ for (; it != end; ++it) {
330+ auto &this_rule = it->second ;
331+ auto this_subs = matchTemplate (this_rule.src , cpp_type);
332+ if (!this_subs) {
333+ continue ;
334+ }
335+ // tie breaker: prefer more specific rules (usually the longer ones)
336+ if (!rule || this_rule.src .size () > rule->src .size ()) {
337+ rule = &this_rule;
338+ subs = *std::move (this_subs);
339+ }
340+ }
341+ return {rule, std::move (subs)};
342+ }
343+
312344template <class Map , class MatchPred >
313345Map::const_iterator parallel_search (const Map &container,
314346 MatchPred &&match_func) {
@@ -377,15 +409,13 @@ decltype(exprs_)::const_iterator search(const clang::Expr *expr) {
377409 return result;
378410}
379411
380- decltype (types_)::const_iterator search (clang::QualType qual_type) {
412+ TranslationRule::TypeRule * search (clang::QualType qual_type) {
381413 auto type = ToString (qual_type);
382- auto result = parallel_search (
383- types_, [&](const std::string &tpl) { return matchTemplate (tpl, type); });
384- llvm::errs () << " search type " << type << " , result: "
385- << ((result == types_.end ()) ? " None"
386- : result->second .type_info .type )
414+ auto [rule, subs] = search_type (type);
415+ llvm::errs () << " search type " << type
416+ << " , result: " << (rule ? rule->type_info .type : " None" )
387417 << ' \n ' ;
388- return result ;
418+ return rule ;
389419}
390420
391421void addRulesFromDirectory (const std::filesystem::path &dir, Model model) {
@@ -397,20 +427,16 @@ void addRulesFromDirectory(const std::filesystem::path &dir, Model model) {
397427 llvm::errs () << " No rules found in " << path << ' \n ' ;
398428 continue ;
399429 }
400- for (auto &rule : rules) {
430+ for (auto &[_, rule] : rules) {
401431 if (auto *expr = std::get_if<TranslationRule::ExprTgt>(&rule.tgt )) {
402432 if (!exprs_.try_emplace (std::move (rule.src ), std::move (*expr))
403433 .second ) {
404434 llvm::errs () << " Key: " << rule.src << " already exists in exprs\n " ;
405435 assert (0 );
406436 }
407437 } else if (auto *type =
408- std::get_if<TranslationRule::TypeTgt>(&rule.tgt )) {
409- if (!types_.try_emplace (std::move (rule.src ), std::move (*type))
410- .second ) {
411- llvm::errs () << " Key: " << rule.src << " already exists in types\n " ;
412- assert (0 );
413- }
438+ std::get_if<TranslationRule::TypeRule>(&rule.tgt )) {
439+ types_.emplace (GetMapKey (type->src ), std::move (*type));
414440 }
415441 }
416442 }
@@ -422,20 +448,21 @@ void addBuiltinTypes(Model model) {
422448
423449 auto add_builtin_rule = [&](clang::QualType qt, const std::string &rust) {
424450 auto cxx = ToString (qt);
425- types_[ cxx] = TranslationRule::TypeTgt ::Plain (rust);
426- types_[ " const " + cxx] = TranslationRule::TypeTgt ::Plain (rust);
451+ AddTypeRule ( cxx, TranslationRule::TypeRule ::Plain (rust) );
452+ AddTypeRule ( " const " + cxx, TranslationRule::TypeRule ::Plain (rust) );
427453
428454 switch (model) {
429455 case Model::kUnsafe :
430- types_[cxx + " *" ] = TranslationRule::TypeTgt::UnsafePtr (" *mut " + rust);
431- types_[" const " + cxx + " *" ] =
432- TranslationRule::TypeTgt::UnsafePtr (" *const " + rust);
456+ AddTypeRule (cxx + " *" ,
457+ TranslationRule::TypeRule::UnsafePtr (" *mut " + rust));
458+ AddTypeRule (" const " + cxx + " *" ,
459+ TranslationRule::TypeRule::UnsafePtr (" *const " + rust));
433460 break ;
434461 case Model::kRefCount :
435- types_[ cxx + " *" ] =
436- TranslationRule::TypeTgt::RefcountPtr ( " Ptr::<" + rust + " >" );
437- types_[ " const " + cxx + " *" ] =
438- TranslationRule::TypeTgt::RefcountPtr ( " Ptr::<" + rust + " >" );
462+ AddTypeRule ( cxx + " *" , TranslationRule::TypeRule::RefcountPtr (
463+ " Ptr::<" + rust + " >" ) );
464+ AddTypeRule ( " const " + cxx + " *" , TranslationRule::TypeRule::RefcountPtr (
465+ " Ptr::<" + rust + " >" ) );
439466 break ;
440467 }
441468 };
@@ -453,16 +480,16 @@ void addBuiltinTypes(Model model) {
453480
454481 switch (model) {
455482 case Model::kUnsafe :
456- types_[ ToString (ctx_->VoidTy ) + " *" ] =
457- TranslationRule::TypeTgt ::UnsafePtr (" *mut ::libc::c_void" );
458- types_[ " const " + ToString (ctx_->VoidTy ) + " *" ] =
459- TranslationRule::TypeTgt ::UnsafePtr (" *const ::libc::c_void" );
483+ AddTypeRule ( ToString (ctx_->VoidTy ) + " *" ,
484+ TranslationRule::TypeRule ::UnsafePtr (" *mut ::libc::c_void" ) );
485+ AddTypeRule ( " const " + ToString (ctx_->VoidTy ) + " *" ,
486+ TranslationRule::TypeRule ::UnsafePtr (" *const ::libc::c_void" ) );
460487 break ;
461488 case Model::kRefCount :
462- types_[ ToString (ctx_->VoidTy ) + " *" ] =
463- TranslationRule::TypeTgt ::RefcountPtr (" AnyPtr" );
464- types_[ " const " + ToString (ctx_->VoidTy ) + " *" ] =
465- TranslationRule::TypeTgt ::RefcountPtr (" AnyPtr" );
489+ AddTypeRule ( ToString (ctx_->VoidTy ) + " *" ,
490+ TranslationRule::TypeRule ::RefcountPtr (" AnyPtr" ) );
491+ AddTypeRule ( " const " + ToString (ctx_->VoidTy ) + " *" ,
492+ TranslationRule::TypeRule ::RefcountPtr (" AnyPtr" ) );
466493 break ;
467494 }
468495
@@ -528,18 +555,15 @@ clang::QualType normalizeQualType(clang::QualType qual_type) {
528555}
529556
530557std::string mapTypeStringRecursive (const std::string &cpp_type) {
531- auto rule = parallel_search (types_, [&](const std::string &tpl) {
532- return matchTemplate (tpl, cpp_type);
533- });
534- if (rule == types_.end ()) {
558+ auto [rule, subs] = search_type (cpp_type);
559+ if (!rule) {
535560 llvm::errs () << " cpp_type: " << cpp_type << ' \n ' ;
536561 assert (0 && " Type is not present in types_" );
537562 }
538- auto subs = matchTemplate (rule->first , cpp_type).value ();
539563 for (auto &kv : subs) {
540564 kv.second = mapTypeStringRecursive (kv.second );
541565 }
542- return instantiateTgt (subs, rule->second . type_info .type );
566+ return instantiateTgt (subs, rule->type_info .type );
543567}
544568
545569std::string normalizeTranslationRule (std::string rule) {
@@ -578,7 +602,7 @@ PushASTContext::PushASTContext(clang::ASTContext &ctx) : prev_(ctx_) {
578602PushASTContext::~PushASTContext () { ctx_ = prev_; }
579603
580604bool Contains (clang::QualType qual_type) {
581- return search (qual_type) != types_. end () ;
605+ return search (qual_type) != nullptr ;
582606}
583607
584608bool Contains (const clang::Expr *expr) { return search (expr) != exprs_.end (); }
@@ -613,26 +637,26 @@ std::string InstantiateTemplate(const clang::Expr *expr,
613637}
614638
615639std::string Map (clang::QualType qual_type) {
616- if ( auto it = search ( qual_type); it != types_. end ()) {
617- auto types_map = matchTemplate (it-> first , ToString (qual_type)). value ();
618- for (auto &kv : types_map ) {
640+ auto [rule, subs] = search_type ( ToString ( qual_type));
641+ if (rule) {
642+ for (auto &kv : subs ) {
619643 kv.second = mapTypeStringRecursive (kv.second );
620644 }
621- return instantiateTgt (types_map, it-> second . type_info .type );
645+ return instantiateTgt (subs, rule-> type_info .type );
622646 }
623647 return {};
624648}
625649
626650bool MapsToPointer (clang::QualType qual_type) {
627- if (auto it = search (qual_type); it != types_. end ( )) {
628- return it-> second . type_info .is_pointer ();
651+ if (auto rule = search (qual_type)) {
652+ return rule-> type_info .is_pointer ();
629653 }
630654 return false ;
631655}
632656
633657bool MapsToRefcountPointer (clang::QualType qual_type) {
634- if (auto it = search (qual_type); it != types_. end ( )) {
635- return it-> second . type_info .is_refcount_pointer ;
658+ if (auto rule = search (qual_type)) {
659+ return rule-> type_info .is_refcount_pointer ;
636660 }
637661 return false ;
638662}
@@ -672,10 +696,7 @@ void AddRuleForUserDefinedType(clang::NamedDecl *decl) {
672696 auto cpp_name = ToString (decl);
673697 auto rs_name = ReplaceAll (cpp_name, " ::" , " _" );
674698
675- if (!types_.try_emplace (cpp_name, TranslationRule::TypeTgt::Plain (rs_name))
676- .second ) {
677- return ;
678- }
699+ AddTypeRule (cpp_name, TranslationRule::TypeRule::Plain (rs_name));
679700
680701 if (auto record_decl = llvm::dyn_cast<clang::RecordDecl>(decl)) {
681702 // Forward declaration
@@ -687,23 +708,22 @@ void AddRuleForUserDefinedType(clang::NamedDecl *decl) {
687708 if (cxx_decl->isAbstract ()) {
688709 switch (model_) {
689710 case Model::kUnsafe :
690- types_[ cpp_name + " *" ] =
691- TranslationRule::TypeTgt::UnsafePtr ( " *mut dyn " + rs_name);
711+ AddTypeRule ( cpp_name + " *" , TranslationRule::TypeRule::UnsafePtr (
712+ " *mut dyn " + rs_name) );
692713 break ;
693714 case Model::kRefCount :
694- types_[ cpp_name + " *" ] = TranslationRule::TypeTgt ::RefcountPtr (
695- " PtrDyn<dyn " + rs_name + ' >' );
715+ AddTypeRule ( cpp_name + " *" , TranslationRule::TypeRule ::RefcountPtr (
716+ " PtrDyn<dyn " + rs_name + ' >' ) );
696717 break ;
697718 }
698719 } else {
699720 switch (model_) {
700721 case Model::kUnsafe :
701- types_[ cpp_name + " *" ] =
702- TranslationRule::TypeTgt ::UnsafePtr (" *mut " + rs_name);
722+ AddTypeRule ( cpp_name + " *" ,
723+ TranslationRule::TypeRule ::UnsafePtr (" *mut " + rs_name) );
703724 break ;
704725 case Model::kRefCount :
705- types_[cpp_name + " *" ] =
706- TranslationRule::TypeTgt::RefcountPtr (" Ptr<" + rs_name + ' >' );
726+ AddTypeRule (cpp_name + " *" , TranslationRule::TypeRule::RefcountPtr (" Ptr<" + rs_name + ' >' ));
707727 break ;
708728 }
709729 }
@@ -765,7 +785,7 @@ std::string ToString(const clang::NamedDecl *decl) {
765785 return normalizeTranslationRule (std::move (out));
766786 }
767787
768- os << ToString (func_decl->getReturnType ()) << " " ;
788+ os << ToString (func_decl->getReturnType ()) << ' ' ;
769789 if (const auto *method_decl =
770790 llvm::dyn_cast<clang::CXXMethodDecl>(func_decl)) {
771791 if (method_decl->getParent ()->isLambda () &&
@@ -901,7 +921,7 @@ void LoadTranslationRules(Model model, clang::ASTContext &ctx,
901921 expr.dump();
902922 }
903923 for (auto &[src, type_tgt] : types_) {
904- llvm::errs() << "Type: " << src << '\n';
924+ llvm::errs() << "Type key : " << src << '\n';
905925 type_tgt.dump();
906926 }
907927#endif
0 commit comments