Skip to content

Commit 1ca5ff9

Browse files
committed
edits
1 parent 84e15b2 commit 1ca5ff9

40 files changed

Lines changed: 874 additions & 973 deletions

cpp2rust/converter/converter.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3558,7 +3558,7 @@ std::string Converter::ConvertMappedMethodCall(
35583558
std::string Converter::GetMappedAsString(clang::Expr *expr, clang::Expr **args,
35593559
unsigned num_args,
35603560
TempMaterializationCtx *ctx) {
3561-
auto *tgt_ir = Mapper::GetExprTgt(GetCalleeOrExpr(expr));
3561+
auto *tgt_ir = Mapper::GetExprRule(GetCalleeOrExpr(expr));
35623562
if (!tgt_ir)
35633563
return {};
35643564

@@ -3585,7 +3585,7 @@ std::string Converter::ConvertIRFragment(
35853585
result += Mapper::InstantiateTemplate(GetCalleeOrExpr(expr), g->n);
35863586
} else if (auto *ph = std::get_if<PlaceholderFragment>(&frag)) {
35873587
auto arg_idx = ph->n;
3588-
assert(arg_idx < static_cast<int>(all_args.size()));
3588+
assert(arg_idx < all_args.size());
35893589
auto *arg = all_args[arg_idx];
35903590
bool is_receiver = HasReceiver(expr) && arg_idx == 0;
35913591

cpp2rust/converter/mapper.cpp

Lines changed: 77 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@
77
#include <clang/Basic/SourceManager.h>
88
#include <llvm/Support/ThreadPool.h>
99

10-
#include <atomic>
1110
#include <format>
12-
#include <mutex>
1311
#include <regex>
12+
#include <unordered_map>
1413
#include <utility>
1514
#include <vector>
1615

@@ -25,8 +24,8 @@ clang::ASTContext *ctx_ = nullptr;
2524
Model model_ = Model::kUnsafe;
2625
bool translation_rules_loaded_ = false;
2726

28-
std::unordered_map<std::string, TranslationRule::ExprTgt>
29-
exprs_; // src -> ExprTgt
27+
std::unordered_multimap<std::string, TranslationRule::ExprRule>
28+
exprs_; // src -> ExprRule
3029
std::unordered_multimap<std::string, TranslationRule::TypeRule>
3130
types_; // src -> TypeRule
3231

@@ -64,8 +63,8 @@ void AddTypeRule(std::string src, TranslationRule::TypeRule &&rule) {
6463
// Example:
6564
// template_str = "std::vector<T1>::vector()"
6665
// instantiated = "std::vector<int>::vector()"
67-
// result = { {"T1", "int"} }
68-
std::optional<std::unordered_map<std::string, std::string>>
66+
// result = { "int" }
67+
std::optional<std::vector<std::string>>
6968
matchTemplate(const std::string &template_str,
7069
const std::string &instantiated) {
7170
auto matchLiteralAt = [&](const std::string &input_str, size_t pos,
@@ -193,7 +192,7 @@ matchTemplate(const std::string &template_str,
193192
return std::string::npos;
194193
};
195194

196-
std::unordered_map<std::string, std::string> captured;
195+
std::vector<std::string> captured;
197196

198197
size_t ti = 0;
199198
size_t si = 0;
@@ -207,7 +206,7 @@ matchTemplate(const std::string &template_str,
207206
tj++;
208207
}
209208

210-
std::string name = template_str.substr(ti, tj - ti);
209+
size_t type_idx = std::stoi(&template_str[ti + 1]) - 1;
211210
ti = tj;
212211

213212
std::string nextLit;
@@ -221,10 +220,11 @@ matchTemplate(const std::string &template_str,
221220
}
222221
nextLit = template_str.substr(ti, scan - ti);
223222

224-
auto [it, inserted] = captured.try_emplace(std::move(name));
225-
if (!inserted) {
223+
captured.resize(std::max(captured.size(), type_idx + 1));
224+
auto &repl = captured[type_idx];
225+
if (!repl.empty()) {
226226
size_t end_pos = 0;
227-
if (!matchLiteralAt(instantiated, si, it->second, end_pos)) {
227+
if (!matchLiteralAt(instantiated, si, repl, end_pos)) {
228228
return std::nullopt;
229229
}
230230
si = end_pos;
@@ -245,7 +245,7 @@ matchTemplate(const std::string &template_str,
245245
b--;
246246
}
247247

248-
it->second = instantiated.substr(a, b - a);
248+
repl = instantiated.substr(a, b - a);
249249
si = k;
250250
} else {
251251
size_t a = si;
@@ -258,7 +258,7 @@ matchTemplate(const std::string &template_str,
258258
b--;
259259
}
260260

261-
it->second = instantiated.substr(a, b - a);
261+
repl = instantiated.substr(a, b - a);
262262
si = instantiated.size();
263263
}
264264
}
@@ -278,7 +278,7 @@ matchTemplate(const std::string &template_str,
278278
std::isdigit(template_str[tj + 1])) {
279279
break;
280280
}
281-
tj++;
281+
++tj;
282282
}
283283

284284
std::string lit = template_str.substr(ti, tj - ti);
@@ -307,29 +307,36 @@ matchTemplate(const std::string &template_str,
307307
// corresponding instantiated type from `types`.
308308
//
309309
// Example:
310-
// types = { {"T1", "i32"} }
310+
// types = { {"i32"} }
311311
// tgt_template = "Vec<T1>"
312312
// result = "Vec<i32>"
313-
std::string
314-
instantiateTgt(const std::unordered_map<std::string, std::string> &types,
315-
const std::string &tgt_template) {
313+
std::string instantiateTgt(const std::vector<std::string> &types,
314+
const std::string &tgt_template) {
315+
assert(types.size() <= 9);
316316
std::string instantiated_template = tgt_template;
317-
for (const auto &[key, value] : types) {
318-
std::string::size_type pos = 0;
319-
while ((pos = instantiated_template.find(key, pos)) != std::string::npos) {
320-
instantiated_template.replace(pos, key.length(), value);
321-
pos += value.length();
317+
std::string::size_type pos = 0;
318+
while ((pos = instantiated_template.find('T', pos)) != std::string::npos) {
319+
if (pos + 1 >= instantiated_template.size()) {
320+
break;
322321
}
322+
if (!std::isdigit(instantiated_template[pos + 1])) {
323+
++pos;
324+
continue;
325+
}
326+
auto &repl = types.at(instantiated_template[pos + 1] - '1');
327+
instantiated_template.replace(pos, 2, repl);
328+
pos += repl.length();
323329
}
324330
return instantiated_template;
325331
}
326332

327-
std::pair<TranslationRule::TypeRule *,
328-
std::unordered_map<std::string, std::string>>
329-
search_type(const std::string &cpp_type) {
330-
auto [it, end] = types_.equal_range(GetMapKey(cpp_type));
331-
TranslationRule::TypeRule *rule = nullptr;
332-
std::unordered_map<std::string, std::string> subs;
333+
template <typename T>
334+
std::pair<T *, std::vector<std::string>>
335+
search(std::unordered_multimap<std::string, T> &map,
336+
const std::string &cpp_type) {
337+
auto [it, end] = map.equal_range(GetMapKey(cpp_type));
338+
T *rule = nullptr;
339+
std::vector<std::string> subs;
333340

334341
for (; it != end; ++it) {
335342
auto &this_rule = it->second;
@@ -346,77 +353,21 @@ search_type(const std::string &cpp_type) {
346353
return {rule, std::move(subs)};
347354
}
348355

349-
template <class Map, class MatchPred>
350-
Map::const_iterator parallel_search(const Map &container,
351-
MatchPred &&match_func) {
352-
if (container.empty()) {
353-
return container.cend();
354-
}
355-
356-
auto tie_breaker = [](const std::string &a, const std::string &b) -> bool {
357-
if (a.size() != b.size()) {
358-
// Match more specific rules first (usually the longer ones).
359-
return a.size() > b.size();
360-
}
361-
return a < b; // Lexicographically
362-
};
363-
364-
const unsigned hw = std::max(1u, std::thread::hardware_concurrency());
365-
const unsigned nthreads =
366-
std::min<unsigned>(hw, std::max<size_t>(1, container.bucket_count()));
367-
368-
std::atomic<size_t> next_bucket{0};
369-
std::mutex hit_mtx;
370-
std::optional<typename Map::key_type> hit_key;
371-
372-
auto worker = [&](unsigned) {
373-
while (true) {
374-
size_t b = next_bucket.fetch_add(1, std::memory_order_relaxed);
375-
if (b >= container.bucket_count()) {
376-
break;
377-
}
378-
379-
for (auto it = container.cbegin(b); it != container.cend(b); ++it) {
380-
if (!match_func(it->first)) {
381-
continue;
382-
}
383-
384-
std::scoped_lock lk(hit_mtx);
385-
if (!hit_key || tie_breaker(it->first, *hit_key)) {
386-
hit_key = it->first;
387-
}
388-
}
389-
}
390-
};
391-
392-
{
393-
llvm::DefaultThreadPool pool(
394-
llvm::heavyweight_hardware_concurrency(nthreads));
395-
for (unsigned t = 0; t < nthreads; ++t)
396-
pool.async(worker, t);
397-
pool.wait();
398-
}
399-
400-
return hit_key ? container.find(*hit_key) : container.cend();
401-
}
402-
403-
decltype(exprs_)::const_iterator search(const clang::Expr *expr) {
356+
TranslationRule::ExprRule *search(const clang::Expr *expr) {
404357
auto qualified_name = ToString(expr);
405-
auto result = parallel_search(exprs_, [&](const std::string &tpl) {
406-
return matchTemplate(tpl, qualified_name);
407-
});
358+
auto [rule, subs] = search(exprs_, qualified_name);
408359
llvm::errs() << "search expr " << qualified_name << ", result:\n";
409-
if (result != exprs_.end()) {
410-
result->second.dump();
360+
if (rule) {
361+
rule->dump();
411362
} else {
412363
llvm::errs() << "None\n";
413364
}
414-
return result;
365+
return rule;
415366
}
416367

417368
TranslationRule::TypeRule *search(clang::QualType qual_type) {
418369
auto type = ToString(qual_type);
419-
auto [rule, subs] = search_type(type);
370+
auto [rule, subs] = search(types_, type);
420371
llvm::errs() << "search type " << type
421372
<< ", result: " << (rule ? rule->type_info.type : "None")
422373
<< '\n';
@@ -427,22 +378,16 @@ void addRulesFromDirectory(const std::filesystem::path &dir, Model model) {
427378
for (const auto &entry : std::filesystem::recursive_directory_iterator(dir)) {
428379
auto &path = entry.path();
429380
if (entry.is_regular_file() && path.extension() == ".cpp") {
430-
auto rules = TranslationRule::Load(path, model);
431-
if (rules.empty()) {
381+
auto [expr_rules, type_rules] = TranslationRule::Load(path, model);
382+
if (expr_rules.empty() && type_rules.empty()) {
432383
llvm::errs() << "No rules found in " << path << '\n';
433384
continue;
434385
}
435-
for (auto &[_, rule] : rules) {
436-
if (auto *expr = std::get_if<TranslationRule::ExprTgt>(&rule.tgt)) {
437-
if (!exprs_.try_emplace(std::move(rule.src), std::move(*expr))
438-
.second) {
439-
llvm::errs() << "Key: " << rule.src << " already exists in exprs\n";
440-
assert(0);
441-
}
442-
} else if (auto *type =
443-
std::get_if<TranslationRule::TypeRule>(&rule.tgt)) {
444-
types_.emplace(GetMapKey(rule.src), std::move(*type));
445-
}
386+
for (auto &[_, rule] : expr_rules) {
387+
exprs_.emplace(GetMapKey(rule.src), std::move(rule));
388+
}
389+
for (auto &[_, rule] : type_rules) {
390+
types_.emplace(GetMapKey(rule.src), std::move(rule));
446391
}
447392
}
448393
}
@@ -560,22 +505,23 @@ clang::QualType normalizeQualType(clang::QualType qual_type) {
560505
}
561506

562507
std::string mapTypeStringRecursive(const std::string &cpp_type) {
563-
auto [rule, subs] = search_type(cpp_type);
508+
auto [rule, subs] = search(types_, cpp_type);
564509
if (!rule) {
565510
llvm::errs() << "cpp_type: " << cpp_type << '\n';
566511
assert(0 && "Type is not present in types_");
567512
}
568-
for (auto &kv : subs) {
569-
kv.second = mapTypeStringRecursive(kv.second);
513+
for (auto &ty : subs) {
514+
ty = mapTypeStringRecursive(ty);
570515
}
571516
return instantiateTgt(subs, rule->type_info.type);
572517
}
573518

574519
std::string normalizeTranslationRule(std::string rule) {
575-
const std::array<std::pair<std::regex, std::string>, 2> normalization_rules{{
576-
// Detach pointer from double reference. Useful for matching translation
577-
// rules.
578-
{std::regex(R"(\*\&\&)"), "* &&"},
520+
// Detach pointer from double reference. Useful for matching translation
521+
// rules.
522+
rule = ReplaceAll(rule, "(*&&)", "* &&");
523+
524+
const std::array<std::pair<std::regex, std::string>, 1> normalization_rules{{
579525
// Ignore constant template parameters, i.e. replace them with _.
580526
{std::regex(R"(\b\d+\b)"), "_"},
581527
}};
@@ -610,13 +556,10 @@ bool Contains(clang::QualType qual_type) {
610556
return search(qual_type) != nullptr;
611557
}
612558

613-
bool Contains(const clang::Expr *expr) { return search(expr) != exprs_.end(); }
559+
bool Contains(const clang::Expr *expr) { return search(expr) != nullptr; }
614560

615-
const TranslationRule::ExprTgt *GetExprTgt(const clang::Expr *expr) {
616-
if (auto it = search(expr); it != exprs_.end()) {
617-
return &it->second;
618-
}
619-
return nullptr;
561+
const TranslationRule::ExprRule *GetExprRule(const clang::Expr *expr) {
562+
return search(expr);
620563
}
621564

622565
std::string MapFunctionName(const clang::FunctionDecl *decl) {
@@ -630,22 +573,21 @@ std::string MapFunctionName(const clang::FunctionDecl *decl) {
630573

631574
std::string InstantiateTemplate(const clang::Expr *expr, unsigned n) {
632575
auto text = 'T' + std::to_string(n);
633-
auto it = search(expr);
634-
if (it == exprs_.end()) {
576+
auto [rule, subs] = search(exprs_, ToString(expr));
577+
if (!rule) {
635578
return text;
636579
}
637-
auto types_map = matchTemplate(it->first, ToString(expr)).value();
638-
for (auto &kv : types_map) {
639-
kv.second = mapTypeStringRecursive(kv.second);
580+
for (auto &ty : subs) {
581+
ty = mapTypeStringRecursive(ty);
640582
}
641-
return instantiateTgt(types_map, text);
583+
return instantiateTgt(subs, text);
642584
}
643585

644586
std::string Map(clang::QualType qual_type) {
645-
auto [rule, subs] = search_type(ToString(qual_type));
587+
auto [rule, subs] = search(types_, ToString(qual_type));
646588
if (rule) {
647-
for (auto &kv : subs) {
648-
kv.second = mapTypeStringRecursive(kv.second);
589+
for (auto &ty : subs) {
590+
ty = mapTypeStringRecursive(ty);
649591
}
650592
return instantiateTgt(subs, rule->type_info.type);
651593
}
@@ -663,30 +605,24 @@ bool MapsToRefcountPointer(clang::QualType qual_type) {
663605
}
664606

665607
bool ReturnsPointer(const clang::Expr *expr) {
666-
if (auto it = search(expr); it != exprs_.end()) {
667-
return it->second.return_type.is_pointer();
668-
}
669-
return false;
608+
auto rule = search(expr);
609+
return rule && rule->return_type.is_pointer();
670610
}
671611

672612
const TranslationRule::TypeInfo &GetParamInfo(const clang::Expr *expr,
673613
unsigned index) {
674-
auto name = "a" + std::to_string(index);
675-
auto it = search(expr);
676-
assert(it != exprs_.end() && "expression must have a translation rule");
677-
auto name_it = it->second.params.find(name);
678-
assert(name_it != it->second.params.end() &&
679-
"placeholder arg must have a corresponding param type in IR");
680-
return name_it->second;
614+
auto rule = search(expr);
615+
assert(rule && "expression must have a translation rule");
616+
return rule->params.at(index);
681617
}
682618

683619
std::string GetParamType(const clang::Expr *expr, unsigned index) {
684620
auto &info = GetParamInfo(expr, index);
685-
auto types_map = matchTemplate(search(expr)->first, ToString(expr)).value();
686-
for (auto &kv : types_map) {
687-
kv.second = mapTypeStringRecursive(kv.second);
621+
auto [rule, subs] = search(exprs_, ToString(expr));
622+
for (auto &ty : subs) {
623+
ty = mapTypeStringRecursive(ty);
688624
}
689-
return instantiateTgt(types_map, info.type);
625+
return instantiateTgt(subs, info.type);
690626
}
691627

692628
bool ParamIsPointer(const clang::Expr *expr, unsigned index) {

0 commit comments

Comments
 (0)