77#include < clang/Basic/SourceManager.h>
88#include < llvm/Support/ThreadPool.h>
99
10- #include < atomic>
1110#include < format>
12- #include < mutex>
1311#include < regex>
12+ #include < unordered_map>
1413#include < utility>
1514#include < vector>
1615
@@ -25,8 +24,8 @@ clang::ASTContext *ctx_ = nullptr;
// File-level state of the translation-rule mapper.
// Both rule tables below are std::unordered_multimap, so several rules may be
// stored under the same key. In this diff `exprs_` changes from an
// unordered_map of ExprTgt to an unordered_multimap of ExprRule, giving it the
// same container shape as `types_`.
2524Model model_ = Model::kUnsafe ;
2625bool translation_rules_loaded_ = false ;
2726
28- std::unordered_map <std::string, TranslationRule::ExprTgt >
29- exprs_; // src -> ExprTgt
27+ std::unordered_multimap <std::string, TranslationRule::ExprRule >
28+ exprs_; // src -> ExprRule
3029std::unordered_multimap<std::string, TranslationRule::TypeRule>
3130 types_; // src -> TypeRule
3231
@@ -64,8 +63,8 @@ void AddTypeRule(std::string src, TranslationRule::TypeRule &&rule) {
6463// Example:
6564// template_str = "std::vector<T1>::vector()"
6665// instantiated = "std::vector<int>::vector()"
67- // result = { {"T1", " int"} }
68- std::optional<std::unordered_map<std::string, std::string>>
66+ // result = { " int" }
67+ std::optional<std::vector< std::string>>
6968matchTemplate (const std::string &template_str,
7069 const std::string &instantiated) {
7170 auto matchLiteralAt = [&](const std::string &input_str, size_t pos,
@@ -193,7 +192,7 @@ matchTemplate(const std::string &template_str,
193192 return std::string::npos;
194193 };
195194
196- std::unordered_map<std::string, std::string> captured;
195+ std::vector< std::string> captured;
197196
198197 size_t ti = 0 ;
199198 size_t si = 0 ;
@@ -207,7 +206,7 @@ matchTemplate(const std::string &template_str,
207206 tj++;
208207 }
209208
210- std::string name = template_str. substr (ti, tj - ti) ;
209+ size_t type_idx = std::stoi (&template_str[ti + 1 ]) - 1 ;
211210 ti = tj;
212211
213212 std::string nextLit;
@@ -221,10 +220,11 @@ matchTemplate(const std::string &template_str,
221220 }
222221 nextLit = template_str.substr (ti, scan - ti);
223222
224- auto [it, inserted] = captured.try_emplace (std::move (name));
225- if (!inserted) {
223+ captured.resize (std::max (captured.size (), type_idx + 1 ));
224+ auto &repl = captured[type_idx];
225+ if (!repl.empty ()) {
226226 size_t end_pos = 0 ;
227- if (!matchLiteralAt (instantiated, si, it-> second , end_pos)) {
227+ if (!matchLiteralAt (instantiated, si, repl , end_pos)) {
228228 return std::nullopt ;
229229 }
230230 si = end_pos;
@@ -245,7 +245,7 @@ matchTemplate(const std::string &template_str,
245245 b--;
246246 }
247247
248- it-> second = instantiated.substr (a, b - a);
248+ repl = instantiated.substr (a, b - a);
249249 si = k;
250250 } else {
251251 size_t a = si;
@@ -258,7 +258,7 @@ matchTemplate(const std::string &template_str,
258258 b--;
259259 }
260260
261- it-> second = instantiated.substr (a, b - a);
261+ repl = instantiated.substr (a, b - a);
262262 si = instantiated.size ();
263263 }
264264 }
@@ -278,7 +278,7 @@ matchTemplate(const std::string &template_str,
278278 std::isdigit (template_str[tj + 1 ])) {
279279 break ;
280280 }
281- tj++ ;
281+ ++tj ;
282282 }
283283
284284 std::string lit = template_str.substr (ti, tj - ti);
@@ -307,29 +307,36 @@ matchTemplate(const std::string &template_str,
307307// corresponding instantiated type from `types`.
308308//
309309// Example:
310- // types = { {"T1", " i32"} }
310+ // types = { "i32" }
311311// tgt_template = "Vec<T1>"
312312// result = "Vec<i32>"
//
// Implementation notes:
//   * Only single-digit placeholders `T1`..`T9` are recognised; `Tn` is
//     replaced by `types[n - 1]` (hence the assert on `types.size()`).
//   * A `T` that is not followed by a digit in '1'..'9', or whose index has no
//     entry in `types` (including `T0`), is copied through unchanged instead
//     of indexing out of range.
//   * Text inserted by a substitution is never rescanned, so a replacement
//     that itself contains `T1` is not expanded again.
std::string instantiateTgt(const std::vector<std::string> &types,
                           const std::string &tgt_template) {
  assert(types.size() <= 9);
  std::string instantiated_template = tgt_template;
  std::string::size_type pos = 0;
  while ((pos = instantiated_template.find('T', pos)) != std::string::npos) {
    if (pos + 1 >= instantiated_template.size()) {
      // Trailing 'T' with nothing after it: cannot be a placeholder.
      break;
    }
    const char digit = instantiated_template[pos + 1];
    if (digit < '1' || digit > '9') {
      // Ordinary 'T' (e.g. in "Type") or the invalid placeholder "T0".
      ++pos;
      continue;
    }
    const auto index = static_cast<std::size_t>(digit - '1');
    if (index >= types.size()) {
      // No captured type for this placeholder: keep it verbatim.
      ++pos;
      continue;
    }
    const std::string &repl = types[index];
    instantiated_template.replace(pos, 2, repl);
    // Skip past the inserted text so it is not scanned for placeholders.
    pos += repl.length();
  }
  return instantiated_template;
}
326332
327- std::pair<TranslationRule::TypeRule *,
328- std::unordered_map<std::string, std::string>>
329- search_type (const std::string &cpp_type) {
330- auto [it, end] = types_.equal_range (GetMapKey (cpp_type));
331- TranslationRule::TypeRule *rule = nullptr ;
332- std::unordered_map<std::string, std::string> subs;
333+ template <typename T>
334+ std::pair<T *, std::vector<std::string>>
335+ search (std::unordered_multimap<std::string, T> &map,
336+ const std::string &cpp_type) {
337+ auto [it, end] = map.equal_range (GetMapKey (cpp_type));
338+ T *rule = nullptr ;
339+ std::vector<std::string> subs;
333340
334341 for (; it != end; ++it) {
335342 auto &this_rule = it->second ;
@@ -346,77 +353,21 @@ search_type(const std::string &cpp_type) {
346353 return {rule, std::move (subs)};
347354}
348355
349- template <class Map , class MatchPred >
350- Map::const_iterator parallel_search (const Map &container,
351- MatchPred &&match_func) {
352- if (container.empty ()) {
353- return container.cend ();
354- }
355-
356- auto tie_breaker = [](const std::string &a, const std::string &b) -> bool {
357- if (a.size () != b.size ()) {
358- // Match more specific rules first (usually the longer ones).
359- return a.size () > b.size ();
360- }
361- return a < b; // Lexicographically
362- };
363-
364- const unsigned hw = std::max (1u , std::thread::hardware_concurrency ());
365- const unsigned nthreads =
366- std::min<unsigned >(hw, std::max<size_t >(1 , container.bucket_count ()));
367-
368- std::atomic<size_t > next_bucket{0 };
369- std::mutex hit_mtx;
370- std::optional<typename Map::key_type> hit_key;
371-
372- auto worker = [&](unsigned ) {
373- while (true ) {
374- size_t b = next_bucket.fetch_add (1 , std::memory_order_relaxed);
375- if (b >= container.bucket_count ()) {
376- break ;
377- }
378-
379- for (auto it = container.cbegin (b); it != container.cend (b); ++it) {
380- if (!match_func (it->first )) {
381- continue ;
382- }
383-
384- std::scoped_lock lk (hit_mtx);
385- if (!hit_key || tie_breaker (it->first , *hit_key)) {
386- hit_key = it->first ;
387- }
388- }
389- }
390- };
391-
392- {
393- llvm::DefaultThreadPool pool (
394- llvm::heavyweight_hardware_concurrency (nthreads));
395- for (unsigned t = 0 ; t < nthreads; ++t)
396- pool.async (worker, t);
397- pool.wait ();
398- }
399-
400- return hit_key ? container.find (*hit_key) : container.cend ();
401- }
402-
// Looks up the expression rule for `expr`.
//
// The expression is stringified with ToString() and passed to the templated
// search() overload over `exprs_`. The outcome is always written to
// llvm::errs(): the matched rule's dump(), or "None".
//
// Returns a pointer to the matching ExprRule, or a null pointer when nothing
// matched (the `else` branch below). The captured substitutions (`subs`) are
// not used here; callers that need them call search(exprs_, ...) themselves.
403- decltype (exprs_)::const_iterator search (const clang::Expr *expr) {
356+ TranslationRule::ExprRule *search (const clang::Expr *expr) {
404357 auto qualified_name = ToString (expr);
405- auto result = parallel_search (exprs_, [&](const std::string &tpl) {
406- return matchTemplate (tpl, qualified_name);
407- });
358+ auto [rule, subs] = search (exprs_, qualified_name);
408359 llvm::errs () << " search expr " << qualified_name << " , result:\n " ;
409- if (result != exprs_. end () ) {
410- result-> second . dump ();
360+ if (rule ) {
361+ rule-> dump ();
411362 } else {
412363 llvm::errs () << " None\n " ;
413364 }
414- return result ;
365+ return rule ;
415366}
416367
417368TranslationRule::TypeRule *search (clang::QualType qual_type) {
418369 auto type = ToString (qual_type);
419- auto [rule, subs] = search_type ( type);
370+ auto [rule, subs] = search (types_, type);
420371 llvm::errs () << " search type " << type
421372 << " , result: " << (rule ? rule->type_info .type : " None" )
422373 << ' \n ' ;
@@ -427,22 +378,16 @@ void addRulesFromDirectory(const std::filesystem::path &dir, Model model) {
427378 for (const auto &entry : std::filesystem::recursive_directory_iterator (dir)) {
428379 auto &path = entry.path ();
429380 if (entry.is_regular_file () && path.extension () == " .cpp" ) {
430- auto rules = TranslationRule::Load (path, model);
431- if (rules .empty ()) {
381+ auto [expr_rules, type_rules] = TranslationRule::Load (path, model);
382+ if (expr_rules. empty () && type_rules .empty ()) {
432383 llvm::errs () << " No rules found in " << path << ' \n ' ;
433384 continue ;
434385 }
435- for (auto &[_, rule] : rules) {
436- if (auto *expr = std::get_if<TranslationRule::ExprTgt>(&rule.tgt )) {
437- if (!exprs_.try_emplace (std::move (rule.src ), std::move (*expr))
438- .second ) {
439- llvm::errs () << " Key: " << rule.src << " already exists in exprs\n " ;
440- assert (0 );
441- }
442- } else if (auto *type =
443- std::get_if<TranslationRule::TypeRule>(&rule.tgt )) {
444- types_.emplace (GetMapKey (rule.src ), std::move (*type));
445- }
386+ for (auto &[_, rule] : expr_rules) {
387+ exprs_.emplace (GetMapKey (rule.src ), std::move (rule));
388+ }
389+ for (auto &[_, rule] : type_rules) {
390+ types_.emplace (GetMapKey (rule.src ), std::move (rule));
446391 }
447392 }
448393 }
@@ -560,22 +505,23 @@ clang::QualType normalizeQualType(clang::QualType qual_type) {
560505}
561506
// Translates the C++ type string `cpp_type` into its target-language spelling.
//
// The type is looked up in `types_`; every captured substitution is translated
// by a recursive call, and the results are spliced into the rule's
// `type_info.type` template via instantiateTgt().
//
// NOTE(review): a missing rule is only guarded by assert(0). When asserts are
// compiled out (NDEBUG) execution continues and `rule->type_info` dereferences
// a null pointer.
562507std::string mapTypeStringRecursive (const std::string &cpp_type) {
563- auto [rule, subs] = search_type ( cpp_type);
508+ auto [rule, subs] = search (types_, cpp_type);
564509 if (!rule) {
565510 llvm::errs () << " cpp_type: " << cpp_type << ' \n ' ;
566511 assert (0 && " Type is not present in types_" );
567512 }
// Translate each captured type argument before substituting it.
568- for (auto &kv : subs) {
569- kv. second = mapTypeStringRecursive (kv. second );
513+ for (auto &ty : subs) {
514+ ty = mapTypeStringRecursive (ty );
570515 }
571516 return instantiateTgt (subs, rule->type_info .type );
572517}
573518
574519std::string normalizeTranslationRule (std::string rule) {
575- const std::array<std::pair<std::regex, std::string>, 2 > normalization_rules{{
576- // Detach pointer from double reference. Useful for matching translation
577- // rules.
578- {std::regex (R"( \*\&\&)" ), " * &&" },
520+ // Detach pointer from double reference. Useful for matching translation
521+ // rules.
522+ rule = ReplaceAll (rule, " (*&&)" , " * &&" );
523+
524+ const std::array<std::pair<std::regex, std::string>, 1 > normalization_rules{{
579525 // Ignore constant template parameters, i.e. replace them with _.
580526 {std::regex (R"( \b\d+\b)" ), " _" },
581527 }};
@@ -610,13 +556,10 @@ bool Contains(clang::QualType qual_type) {
610556 return search (qual_type) != nullptr ;
611557}
612558
// Returns true when search(expr) yields a rule for `expr`. After this change
// search() returns a pointer, so the check compares against nullptr instead of
// exprs_.end().
613- bool Contains (const clang::Expr *expr) { return search (expr) != exprs_. end () ; }
559+ bool Contains (const clang::Expr *expr) { return search (expr) != nullptr ; }
614560
// Returns the expression rule matching `expr`, or a null pointer when
// search(expr) finds none. This replaces the former GetExprTgt(), which
// returned the address of the mapped ExprTgt.
615- const TranslationRule::ExprTgt *GetExprTgt (const clang::Expr *expr) {
616- if (auto it = search (expr); it != exprs_.end ()) {
617- return &it->second ;
618- }
619- return nullptr ;
561+ const TranslationRule::ExprRule *GetExprRule (const clang::Expr *expr) {
562+ return search (expr);
620563}
621564
622565std::string MapFunctionName (const clang::FunctionDecl *decl) {
@@ -630,22 +573,21 @@ std::string MapFunctionName(const clang::FunctionDecl *decl) {
630573
// Resolves the template placeholder number `n` for `expr`.
//
// Builds the placeholder text 'T' + n, looks `expr` up in `exprs_`, translates
// each captured type with mapTypeStringRecursive(), and substitutes the
// placeholder through instantiateTgt(). If no rule matches, the placeholder
// text itself is returned.
//
// NOTE(review): instantiateTgt() inspects only one digit after 'T', so a
// value of `n` with two or more digits is not substituted as a whole.
631574std::string InstantiateTemplate (const clang::Expr *expr, unsigned n) {
632575 auto text = ' T' + std::to_string (n);
633- auto it = search (expr);
634- if (it == exprs_. end () ) {
576+ auto [rule, subs] = search (exprs_, ToString ( expr) );
577+ if (!rule ) {
635578 return text;
636579 }
637- auto types_map = matchTemplate (it->first , ToString (expr)).value ();
638- for (auto &kv : types_map) {
639- kv.second = mapTypeStringRecursive (kv.second );
580+ for (auto &ty : subs) {
581+ ty = mapTypeStringRecursive (ty);
640582 }
641- return instantiateTgt (types_map , text);
583+ return instantiateTgt (subs , text);
642584}
643585
644586std::string Map (clang::QualType qual_type) {
645- auto [rule, subs] = search_type ( ToString (qual_type));
587+ auto [rule, subs] = search (types_, ToString (qual_type));
646588 if (rule) {
647- for (auto &kv : subs) {
648- kv. second = mapTypeStringRecursive (kv. second );
589+ for (auto &ty : subs) {
590+ ty = mapTypeStringRecursive (ty );
649591 }
650592 return instantiateTgt (subs, rule->type_info .type );
651593 }
@@ -663,30 +605,24 @@ bool MapsToRefcountPointer(clang::QualType qual_type) {
663605}
664606
// Returns true when a rule exists for `expr` and that rule's
// return_type.is_pointer() is true; returns false when no rule is found.
665607bool ReturnsPointer (const clang::Expr *expr) {
666- if (auto it = search (expr); it != exprs_.end ()) {
667- return it->second .return_type .is_pointer ();
668- }
669- return false ;
608+ auto rule = search (expr);
// The short-circuit && keeps the null-rule case from being dereferenced.
609+ return rule && rule->return_type .is_pointer ();
670610}
671611
// Returns the TypeInfo of parameter `index` from the rule matching `expr`.
//
// After this change the parameter is fetched with params.at(index) rather than
// by looking up the name "a<index>" in a map.
//
// NOTE(review): the null check on `rule` is an assert only; with asserts
// compiled out (NDEBUG) a missing rule is dereferenced.
// NOTE(review): presumably the returned reference points into the rule stored
// in `exprs_` and stays valid while that container is unchanged — confirm
// against the templated search().
672612const TranslationRule::TypeInfo &GetParamInfo (const clang::Expr *expr,
673613 unsigned index) {
674- auto name = " a" + std::to_string (index);
675- auto it = search (expr);
676- assert (it != exprs_.end () && " expression must have a translation rule" );
677- auto name_it = it->second .params .find (name);
678- assert (name_it != it->second .params .end () &&
679- " placeholder arg must have a corresponding param type in IR" );
680- return name_it->second ;
614+ auto rule = search (expr);
615+ assert (rule && " expression must have a translation rule" );
616+ return rule->params .at (index);
681617}
682618
// Returns the translated type string of parameter `index` for `expr`.
//
// The parameter's TypeInfo comes from GetParamInfo(). The types captured when
// matching `expr` against `exprs_` are each translated with
// mapTypeStringRecursive() and substituted into `info.type` through
// instantiateTgt().
//
// NOTE(review): `rule` from the structured binding is never read here; a
// missing rule is only caught by the assert inside GetParamInfo().
// NOTE(review): the expression is looked up twice, once inside GetParamInfo()
// and once below.
683619std::string GetParamType (const clang::Expr *expr, unsigned index) {
684620 auto &info = GetParamInfo (expr, index);
685- auto types_map = matchTemplate ( search (expr)-> first , ToString (expr)). value ( );
686- for (auto &kv : types_map ) {
687- kv. second = mapTypeStringRecursive (kv. second );
621+ auto [rule, subs] = search (exprs_ , ToString (expr));
622+ for (auto &ty : subs ) {
623+ ty = mapTypeStringRecursive (ty );
688624 }
689- return instantiateTgt (subs , info.type );
625+ return instantiateTgt (subs , info.type );
690626}
691627
692628bool ParamIsPointer (const clang::Expr *expr, unsigned index) {
0 commit comments