Skip to content

Commit d136ad3

Browse files
committed
Optimize rule loading procedure and scan expr maps sequentially using hashing
1 parent eddc30e commit d136ad3

3 files changed

Lines changed: 122 additions & 127 deletions

File tree

cpp2rust/converter/mapper.cpp

Lines changed: 82 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ bool translation_rules_loaded_ = false;
2727

2828
std::unordered_map<std::string, TranslationRule::ExprTgt>
2929
exprs_; // src -> ExprTgt
30-
std::unordered_map<std::string, TranslationRule::TypeTgt>
31-
types_; // src -> TypeTgt
30+
std::unordered_multimap<std::string, TranslationRule::TypeRule>
31+
types_; // src -> TypeRule
3232

3333
clang::PrintingPolicy getPrintPolicy() {
3434
assert(ctx_);
@@ -41,6 +41,16 @@ clang::PrintingPolicy getPrintPolicy() {
4141
return policy;
4242
}
4343

44+
// Derives the bucket key used for the types_ multimap: the prefix of `str`
// that precedes the first '<' or '[' character.
// When neither character occurs, find_first_of returns npos and substr
// therefore yields the whole string unchanged.
std::string GetMapKey(const std::string &str) {
45+
return str.substr(0, str.find_first_of("<["));
46+
}
47+
48+
// Registers `rule` for the C++ type pattern `src`.
// The full pattern is stored in rule.src and the rule is inserted into types_
// under GetMapKey(src). types_ is an unordered_multimap, so this always adds
// an entry; a previously registered rule with the same key is kept alongside
// the new one rather than being replaced.
void AddTypeRule(std::string src, TranslationRule::TypeRule &&rule) {
49+
// Compute the key first: `src` is moved from on the next statement.
auto key = GetMapKey(src);
50+
rule.src = std::move(src);
51+
types_.emplace(std::move(key), std::move(rule));
52+
}
53+
4454
// Attempts to unify an instantiated C++ type or function signature with a
4555
// corresponding template pattern. If the two match structurally, it returns
4656
// a mapping from template parameter names (e.g., "T1") to their concrete
@@ -309,6 +319,28 @@ instantiateTgt(const std::unordered_map<std::string, std::string> &types,
309319
return instantiated_template;
310320
}
311321

322+
// Looks up the type rule that best matches the C++ type string `cpp_type`.
// Only rules stored in types_ under GetMapKey(cpp_type) are examined; each
// candidate is tested with matchTemplate(rule.src, cpp_type). Among the
// candidates that match, the one with the longest `src` is selected.
// Returns the selected rule (a pointer into types_) together with the
// substitutions matchTemplate produced for it, or {nullptr, empty map} when
// no candidate matches.
// NOTE(review): the substitutions presumably map template parameter names to
// concrete types -- confirm against matchTemplate.
std::pair<TranslationRule::TypeRule *,
323+
std::unordered_map<std::string, std::string>>
324+
search_type(const std::string &cpp_type) {
325+
auto [it, end] = types_.equal_range(GetMapKey(cpp_type));
326+
TranslationRule::TypeRule *rule = nullptr;
327+
std::unordered_map<std::string, std::string> subs;
328+
329+
for (; it != end; ++it) {
330+
auto &this_rule = it->second;
331+
auto this_subs = matchTemplate(this_rule.src, cpp_type);
332+
// Pattern of this candidate does not match; try the next one.
if (!this_subs) {
333+
continue;
334+
}
335+
// tie breaker: prefer more specific rules (usually the longer ones)
// The comparison is strict, so on equal length the earlier candidate is kept.
336+
if (!rule || this_rule.src.size() > rule->src.size()) {
337+
rule = &this_rule;
338+
subs = *std::move(this_subs);
339+
}
340+
}
341+
return {rule, std::move(subs)};
342+
}
343+
312344
template <class Map, class MatchPred>
313345
Map::const_iterator parallel_search(const Map &container,
314346
MatchPred &&match_func) {
@@ -377,15 +409,13 @@ decltype(exprs_)::const_iterator search(const clang::Expr *expr) {
377409
return result;
378410
}
379411

380-
decltype(types_)::const_iterator search(clang::QualType qual_type) {
412+
TranslationRule::TypeRule *search(clang::QualType qual_type) {
381413
auto type = ToString(qual_type);
382-
auto result = parallel_search(
383-
types_, [&](const std::string &tpl) { return matchTemplate(tpl, type); });
384-
llvm::errs() << "search type " << type << ", result: "
385-
<< ((result == types_.end()) ? "None"
386-
: result->second.type_info.type)
414+
auto [rule, subs] = search_type(type);
415+
llvm::errs() << "search type " << type
416+
<< ", result: " << (rule ? rule->type_info.type : "None")
387417
<< '\n';
388-
return result;
418+
return rule;
389419
}
390420

391421
void addRulesFromDirectory(const std::filesystem::path &dir, Model model) {
@@ -397,20 +427,16 @@ void addRulesFromDirectory(const std::filesystem::path &dir, Model model) {
397427
llvm::errs() << "No rules found in " << path << '\n';
398428
continue;
399429
}
400-
for (auto &rule : rules) {
430+
for (auto &[_, rule] : rules) {
401431
if (auto *expr = std::get_if<TranslationRule::ExprTgt>(&rule.tgt)) {
402432
if (!exprs_.try_emplace(std::move(rule.src), std::move(*expr))
403433
.second) {
404434
llvm::errs() << "Key: " << rule.src << " already exists in exprs\n";
405435
assert(0);
406436
}
407437
} else if (auto *type =
408-
std::get_if<TranslationRule::TypeTgt>(&rule.tgt)) {
409-
if (!types_.try_emplace(std::move(rule.src), std::move(*type))
410-
.second) {
411-
llvm::errs() << "Key: " << rule.src << " already exists in types\n";
412-
assert(0);
413-
}
438+
std::get_if<TranslationRule::TypeRule>(&rule.tgt)) {
439+
types_.emplace(GetMapKey(type->src), std::move(*type));
414440
}
415441
}
416442
}
@@ -422,20 +448,21 @@ void addBuiltinTypes(Model model) {
422448

423449
auto add_builtin_rule = [&](clang::QualType qt, const std::string &rust) {
424450
auto cxx = ToString(qt);
425-
types_[cxx] = TranslationRule::TypeTgt::Plain(rust);
426-
types_["const " + cxx] = TranslationRule::TypeTgt::Plain(rust);
451+
AddTypeRule(cxx, TranslationRule::TypeRule::Plain(rust));
452+
AddTypeRule("const " + cxx, TranslationRule::TypeRule::Plain(rust));
427453

428454
switch (model) {
429455
case Model::kUnsafe:
430-
types_[cxx + " *"] = TranslationRule::TypeTgt::UnsafePtr("*mut " + rust);
431-
types_["const " + cxx + " *"] =
432-
TranslationRule::TypeTgt::UnsafePtr("*const " + rust);
456+
AddTypeRule(cxx + " *",
457+
TranslationRule::TypeRule::UnsafePtr("*mut " + rust));
458+
AddTypeRule("const " + cxx + " *",
459+
TranslationRule::TypeRule::UnsafePtr("*const " + rust));
433460
break;
434461
case Model::kRefCount:
435-
types_[cxx + " *"] =
436-
TranslationRule::TypeTgt::RefcountPtr("Ptr::<" + rust + ">");
437-
types_["const " + cxx + " *"] =
438-
TranslationRule::TypeTgt::RefcountPtr("Ptr::<" + rust + ">");
462+
AddTypeRule(cxx + " *", TranslationRule::TypeRule::RefcountPtr(
463+
"Ptr::<" + rust + ">"));
464+
AddTypeRule("const " + cxx + " *", TranslationRule::TypeRule::RefcountPtr(
465+
"Ptr::<" + rust + ">"));
439466
break;
440467
}
441468
};
@@ -453,16 +480,16 @@ void addBuiltinTypes(Model model) {
453480

454481
switch (model) {
455482
case Model::kUnsafe:
456-
types_[ToString(ctx_->VoidTy) + " *"] =
457-
TranslationRule::TypeTgt::UnsafePtr("*mut ::libc::c_void");
458-
types_["const " + ToString(ctx_->VoidTy) + " *"] =
459-
TranslationRule::TypeTgt::UnsafePtr("*const ::libc::c_void");
483+
AddTypeRule(ToString(ctx_->VoidTy) + " *",
484+
TranslationRule::TypeRule::UnsafePtr("*mut ::libc::c_void"));
485+
AddTypeRule("const " + ToString(ctx_->VoidTy) + " *",
486+
TranslationRule::TypeRule::UnsafePtr("*const ::libc::c_void"));
460487
break;
461488
case Model::kRefCount:
462-
types_[ToString(ctx_->VoidTy) + " *"] =
463-
TranslationRule::TypeTgt::RefcountPtr("AnyPtr");
464-
types_["const " + ToString(ctx_->VoidTy) + " *"] =
465-
TranslationRule::TypeTgt::RefcountPtr("AnyPtr");
489+
AddTypeRule(ToString(ctx_->VoidTy) + " *",
490+
TranslationRule::TypeRule::RefcountPtr("AnyPtr"));
491+
AddTypeRule("const " + ToString(ctx_->VoidTy) + " *",
492+
TranslationRule::TypeRule::RefcountPtr("AnyPtr"));
466493
break;
467494
}
468495

@@ -528,18 +555,15 @@ clang::QualType normalizeQualType(clang::QualType qual_type) {
528555
}
529556

530557
std::string mapTypeStringRecursive(const std::string &cpp_type) {
531-
auto rule = parallel_search(types_, [&](const std::string &tpl) {
532-
return matchTemplate(tpl, cpp_type);
533-
});
534-
if (rule == types_.end()) {
558+
auto [rule, subs] = search_type(cpp_type);
559+
if (!rule) {
535560
llvm::errs() << "cpp_type: " << cpp_type << '\n';
536561
assert(0 && "Type is not present in types_");
537562
}
538-
auto subs = matchTemplate(rule->first, cpp_type).value();
539563
for (auto &kv : subs) {
540564
kv.second = mapTypeStringRecursive(kv.second);
541565
}
542-
return instantiateTgt(subs, rule->second.type_info.type);
566+
return instantiateTgt(subs, rule->type_info.type);
543567
}
544568

545569
std::string normalizeTranslationRule(std::string rule) {
@@ -578,7 +602,7 @@ PushASTContext::PushASTContext(clang::ASTContext &ctx) : prev_(ctx_) {
578602
PushASTContext::~PushASTContext() { ctx_ = prev_; }
579603

580604
bool Contains(clang::QualType qual_type) {
581-
return search(qual_type) != types_.end();
605+
return search(qual_type) != nullptr;
582606
}
583607

584608
bool Contains(const clang::Expr *expr) { return search(expr) != exprs_.end(); }
@@ -613,26 +637,26 @@ std::string InstantiateTemplate(const clang::Expr *expr,
613637
}
614638

615639
std::string Map(clang::QualType qual_type) {
616-
if (auto it = search(qual_type); it != types_.end()) {
617-
auto types_map = matchTemplate(it->first, ToString(qual_type)).value();
618-
for (auto &kv : types_map) {
640+
auto [rule, subs] = search_type(ToString(qual_type));
641+
if (rule) {
642+
for (auto &kv : subs) {
619643
kv.second = mapTypeStringRecursive(kv.second);
620644
}
621-
return instantiateTgt(types_map, it->second.type_info.type);
645+
return instantiateTgt(subs, rule->type_info.type);
622646
}
623647
return {};
624648
}
625649

626650
bool MapsToPointer(clang::QualType qual_type) {
627-
if (auto it = search(qual_type); it != types_.end()) {
628-
return it->second.type_info.is_pointer();
651+
if (auto rule = search(qual_type)) {
652+
return rule->type_info.is_pointer();
629653
}
630654
return false;
631655
}
632656

633657
bool MapsToRefcountPointer(clang::QualType qual_type) {
634-
if (auto it = search(qual_type); it != types_.end()) {
635-
return it->second.type_info.is_refcount_pointer;
658+
if (auto rule = search(qual_type)) {
659+
return rule->type_info.is_refcount_pointer;
636660
}
637661
return false;
638662
}
@@ -672,10 +696,7 @@ void AddRuleForUserDefinedType(clang::NamedDecl *decl) {
672696
auto cpp_name = ToString(decl);
673697
auto rs_name = ReplaceAll(cpp_name, "::", "_");
674698

675-
if (!types_.try_emplace(cpp_name, TranslationRule::TypeTgt::Plain(rs_name))
676-
.second) {
677-
return;
678-
}
699+
AddTypeRule(cpp_name, TranslationRule::TypeRule::Plain(rs_name));
679700

680701
if (auto record_decl = llvm::dyn_cast<clang::RecordDecl>(decl)) {
681702
// Forward declaration
@@ -687,23 +708,22 @@ void AddRuleForUserDefinedType(clang::NamedDecl *decl) {
687708
if (cxx_decl->isAbstract()) {
688709
switch (model_) {
689710
case Model::kUnsafe:
690-
types_[cpp_name + " *"] =
691-
TranslationRule::TypeTgt::UnsafePtr("*mut dyn " + rs_name);
711+
AddTypeRule(cpp_name + " *", TranslationRule::TypeRule::UnsafePtr(
712+
"*mut dyn " + rs_name));
692713
break;
693714
case Model::kRefCount:
694-
types_[cpp_name + " *"] = TranslationRule::TypeTgt::RefcountPtr(
695-
"PtrDyn<dyn " + rs_name + '>');
715+
AddTypeRule(cpp_name + " *", TranslationRule::TypeRule::RefcountPtr(
716+
"PtrDyn<dyn " + rs_name + '>'));
696717
break;
697718
}
698719
} else {
699720
switch (model_) {
700721
case Model::kUnsafe:
701-
types_[cpp_name + " *"] =
702-
TranslationRule::TypeTgt::UnsafePtr("*mut " + rs_name);
722+
AddTypeRule(cpp_name + " *",
723+
TranslationRule::TypeRule::UnsafePtr("*mut " + rs_name));
703724
break;
704725
case Model::kRefCount:
705-
types_[cpp_name + " *"] =
706-
TranslationRule::TypeTgt::RefcountPtr("Ptr<" + rs_name + '>');
726+
AddTypeRule(cpp_name + " *", TranslationRule::TypeRule::RefcountPtr("Ptr<" + rs_name + '>'));
707727
break;
708728
}
709729
}
@@ -765,7 +785,7 @@ std::string ToString(const clang::NamedDecl *decl) {
765785
return normalizeTranslationRule(std::move(out));
766786
}
767787

768-
os << ToString(func_decl->getReturnType()) << " ";
788+
os << ToString(func_decl->getReturnType()) << ' ';
769789
if (const auto *method_decl =
770790
llvm::dyn_cast<clang::CXXMethodDecl>(func_decl)) {
771791
if (method_decl->getParent()->isLambda() &&
@@ -901,7 +921,7 @@ void LoadTranslationRules(Model model, clang::ASTContext &ctx,
901921
expr.dump();
902922
}
903923
for (auto &[src, type_tgt] : types_) {
904-
llvm::errs() << "Type: " << src << '\n';
924+
llvm::errs() << "Type key: " << src << '\n';
905925
type_tgt.dump();
906926
}
907927
#endif

0 commit comments

Comments
 (0)