Skip to content

Commit 235175d

Browse files
Copilotnunoplopes
andauthored
Changes before error encountered
Agent-Logs-Url: https://github.com/Cpp2Rust/cpp2rust/sessions/ff09d3f1-67b5-4edc-a293-23566d7a6b9f Co-authored-by: nunoplopes <2998477+nunoplopes@users.noreply.github.com>
1 parent 05f1ce7 commit 235175d

3 files changed

Lines changed: 96 additions & 43 deletions

File tree

cpp2rust/converter/converter.cpp

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ bool Converter::Convert(clang::QualType qual_type) {
6565
return false;
6666
}
6767

68-
if (Mapper::Contains(qual_type) &&
69-
Mapper::Map(qual_type) != ignore_rule_type_) {
70-
StrCat(Mapper::Map(qual_type));
68+
if (auto mapped = Mapper::Map(qual_type);
69+
!mapped.empty() && mapped != ignore_rule_type_) {
70+
StrCat(mapped);
7171
return false;
7272
}
7373

@@ -228,7 +228,8 @@ Converter::ConvertFunctionPointerType(const clang::FunctionProtoType *proto,
228228
std::string result =
229229
(kind == FnProtoType::LambdaCallOperator ? "impl Fn(" : "fn(");
230230
for (auto p_ty : proto->param_types()) {
231-
result += ToString(p_ty) + ",";
231+
result += ToString(p_ty);
232+
result += ',';
232233
}
233234
result += ")";
234235
if (!proto->getReturnType()->isVoidType()) {
@@ -1391,13 +1392,13 @@ bool Converter::VisitCallExpr(clang::CallExpr *expr) {
13911392
return false;
13921393
}
13931394

1394-
if (Mapper::Contains(expr->getCallee())) {
1395+
if (auto *tgt_ir = Mapper::GetExprTgt(expr->getCallee())) {
13951396
auto **args = expr->getArgs();
13961397
auto num_args = expr->getNumArgs();
13971398
auto ctx = CollectPrvalueToLRefArgs(expr);
13981399
auto str = [&] {
13991400
PushExprKind push(*this, ExprKind::RValue);
1400-
return GetMappedAsString(expr, args, num_args, &ctx);
1401+
return GetMappedAsString(tgt_ir, expr, args, num_args, &ctx);
14011402
}();
14021403

14031404
if ((IsReferenceType(expr) ||
@@ -1592,11 +1593,11 @@ Converter::ConvertCallExpr(clang::CallExpr *expr) {
15921593
} else if (IsBuiltinConstantP(callee)) {
15931594
StrCat(expr->getArg(0)->isCXX11ConstantExpr(ctx_) ? token::kOne
15941595
: token::kZero);
1595-
} else if (Mapper::Contains(callee)) {
1596+
} else if (auto *tgt_ir = Mapper::GetExprTgt(callee)) {
15961597
auto **args = expr->getArgs();
15971598
auto num_args = expr->getNumArgs();
15981599
auto ctx = CollectPrvalueToLRefArgs(expr);
1599-
auto mapped = GetMappedAsString(expr, args, num_args, &ctx);
1600+
auto mapped = GetMappedAsString(tgt_ir, expr, args, num_args, &ctx);
16001601
StrCat(mapped);
16011602
return ctx;
16021603
} else if (auto *opcall = clang::dyn_cast<clang::CXXOperatorCallExpr>(expr)) {
@@ -2945,10 +2946,10 @@ Converter::GetOverloadedFunctionName(const clang::FunctionDecl *decl) {
29452946
name += "_const";
29462947
}
29472948

2948-
std::vector<char> tokens = {'<', '>', ' ', ':'};
2949-
for (auto token : tokens) {
2950-
name.erase(std::remove(name.begin(), name.end(), token), name.end());
2951-
}
2949+
name.erase(
2950+
std::remove_if(name.begin(), name.end(),
2951+
[](char c) { return c == '<' || c == '>' || c == ' ' || c == ':'; }),
2952+
name.end());
29522953
std::replace(name.begin(), name.end(), '*', 'p');
29532954

29542955
return name;
@@ -3189,16 +3190,15 @@ template <typename Predicate>
31893190
void Converter::ConvertCXXMethodDecls(const clang::CXXRecordDecl *decl,
31903191
const std::string_view signature,
31913192
Predicate predicate) {
3192-
std::vector<clang::CXXMethodDecl *> methods;
3193-
std::copy_if(decl->method_begin(), decl->method_end(),
3194-
std::back_inserter(methods), predicate);
3195-
if (methods.empty()) {
3193+
if (!std::any_of(decl->method_begin(), decl->method_end(), predicate)) {
31963194
return;
31973195
}
31983196
StrCat(signature);
31993197
PushBrace brace(*this);
3200-
for (auto *method : methods) {
3201-
VisitCXXMethodDecl(method);
3198+
for (auto *method : decl->methods()) {
3199+
if (predicate(method)) {
3200+
VisitCXXMethodDecl(method);
3201+
}
32023202
}
32033203
}
32043204

@@ -3526,19 +3526,25 @@ std::string Converter::ConvertMappedMethodCall(
35263526
ConvertIRFragment(mc.body, expr, args, num_args, ctx);
35273527
}
35283528

3529-
std::string Converter::GetMappedAsString(clang::Expr *expr, clang::Expr **args,
3529+
std::string Converter::GetMappedAsString(const TranslationRule::ExprTgt *tgt_ir,
3530+
clang::Expr *expr, clang::Expr **args,
35303531
unsigned num_args,
35313532
TempMaterializationCtx *ctx) {
3532-
auto *tgt_ir = Mapper::GetExprTgt(GetCalleeOrExpr(expr));
35333533
assert(tgt_ir && "GetExprTgt failed to find a translation rule");
3534-
35353534
auto result = ConvertIRFragment(tgt_ir->body, expr, args, num_args, ctx);
35363535
if (tgt_ir->multi_statement) {
35373536
return '{' + result + '}';
35383537
}
35393538
return result;
35403539
}
35413540

3541+
std::string Converter::GetMappedAsString(clang::Expr *expr, clang::Expr **args,
3542+
unsigned num_args,
3543+
TempMaterializationCtx *ctx) {
3544+
auto *tgt_ir = Mapper::GetExprTgt(GetCalleeOrExpr(expr));
3545+
return GetMappedAsString(tgt_ir, expr, args, num_args, ctx);
3546+
}
3547+
35423548
std::string Converter::ConvertIRFragment(
35433549
const std::vector<TranslationRule::BodyFragment> &fragments,
35443550
clang::Expr *expr, clang::Expr **args, unsigned num_args,
@@ -3554,7 +3560,7 @@ std::string Converter::ConvertIRFragment(
35543560
} else if (auto *g = std::get_if<GenericFragment>(&frag)) {
35553561
result += Mapper::InstantiateTemplate(GetCalleeOrExpr(expr), g->name);
35563562
} else if (auto *ph = std::get_if<PlaceholderFragment>(&frag)) {
3557-
auto arg_idx = std::stoi(ph->arg.substr(1)); // "a0" -> 0
3563+
auto arg_idx = static_cast<int>(ph->arg[1] - '0'); // "a0" -> 0
35583564
assert(arg_idx < static_cast<int>(all_args.size()));
35593565
auto *arg = all_args[arg_idx];
35603566
bool is_receiver = HasReceiver(expr) && arg_idx == 0;

cpp2rust/converter/converter.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,12 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {
472472

473473
virtual bool ConvertCXXOperatorCallExpr(clang::CXXOperatorCallExpr *expr);
474474

475+
// Overload that skips the GetExprTgt lookup when the caller already has it.
476+
std::string GetMappedAsString(const TranslationRule::ExprTgt *tgt_ir,
477+
clang::Expr *expr, clang::Expr **args = nullptr,
478+
unsigned num_args = 0,
479+
TempMaterializationCtx *ctx = nullptr);
480+
475481
std::string GetMappedAsString(clang::Expr *expr, clang::Expr **args = nullptr,
476482
unsigned num_args = 0,
477483
TempMaterializationCtx *ctx = nullptr);

cpp2rust/converter/mapper.cpp

Lines changed: 62 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -309,12 +309,14 @@ instantiateTgt(const std::unordered_map<std::string, std::string> &types,
309309
return instantiated_template;
310310
}
311311

312+
// Returns a pair of {iterator, match_result}. The match_result is the value
313+
// returned by match_func for the winning key. When the iterator equals
314+
// container.cend(), match_result is value-initialized (e.g. std::nullopt for
315+
// std::optional).
312316
template <class Map, class MatchPred>
313-
Map::const_iterator parallel_search(const Map &container,
314-
MatchPred &&match_func) {
315-
if (container.empty()) {
316-
return container.cend();
317-
}
317+
auto parallel_search(const Map &container, MatchPred &&match_func) {
318+
using MatchResult =
319+
decltype(match_func(std::declval<const typename Map::key_type &>()));
318320

319321
auto tie_breaker = [](const std::string &a, const std::string &b) -> bool {
320322
if (a.size() != b.size()) {
@@ -324,13 +326,35 @@ Map::const_iterator parallel_search(const Map &container,
324326
return a < b; // Lexicographically
325327
};
326328

329+
std::optional<typename Map::key_type> hit_key;
330+
MatchResult hit_match;
331+
332+
if (container.empty()) {
333+
return std::make_pair(container.cend(), std::move(hit_match));
334+
}
335+
336+
// Serial fast-path: avoids thread-pool creation overhead for small maps.
337+
constexpr size_t kSerialThreshold = 16;
338+
if (container.size() <= kSerialThreshold) {
339+
for (auto it = container.cbegin(); it != container.cend(); ++it) {
340+
auto m = match_func(it->first);
341+
if (!m)
342+
continue;
343+
if (!hit_key || tie_breaker(it->first, *hit_key)) {
344+
hit_key = it->first;
345+
hit_match = std::move(m);
346+
}
347+
}
348+
return std::make_pair(hit_key ? container.find(*hit_key) : container.cend(),
349+
std::move(hit_match));
350+
}
351+
327352
const unsigned hw = std::max(1u, std::thread::hardware_concurrency());
328353
const unsigned nthreads =
329354
std::min<unsigned>(hw, std::max<size_t>(1, container.bucket_count()));
330355

331356
std::atomic<size_t> next_bucket{0};
332357
std::mutex hit_mtx;
333-
std::optional<typename Map::key_type> hit_key;
334358

335359
auto worker = [&](unsigned) {
336360
while (true) {
@@ -340,13 +364,14 @@ Map::const_iterator parallel_search(const Map &container,
340364
}
341365

342366
for (auto it = container.cbegin(b); it != container.cend(b); ++it) {
343-
if (!match_func(it->first)) {
367+
auto m = match_func(it->first);
368+
if (!m)
344369
continue;
345-
}
346370

347371
std::scoped_lock lk(hit_mtx);
348372
if (!hit_key || tie_breaker(it->first, *hit_key)) {
349373
hit_key = it->first;
374+
hit_match = std::move(m);
350375
}
351376
}
352377
}
@@ -360,12 +385,13 @@ Map::const_iterator parallel_search(const Map &container,
360385
pool.wait();
361386
}
362387

363-
return hit_key ? container.find(*hit_key) : container.cend();
388+
return std::make_pair(hit_key ? container.find(*hit_key) : container.cend(),
389+
std::move(hit_match));
364390
}
365391

366392
decltype(exprs_)::const_iterator search(const clang::Expr *expr) {
367393
auto qualified_name = ToString(expr);
368-
auto result = parallel_search(exprs_, [&](const std::string &tpl) {
394+
auto [result, match] = parallel_search(exprs_, [&](const std::string &tpl) {
369395
return matchTemplate(tpl, qualified_name);
370396
});
371397
llvm::errs() << "search expr " << qualified_name << ", result:\n";
@@ -379,7 +405,7 @@ decltype(exprs_)::const_iterator search(const clang::Expr *expr) {
379405

380406
decltype(types_)::const_iterator search(clang::QualType qual_type) {
381407
auto type = ToString(qual_type);
382-
auto result = parallel_search(
408+
auto [result, match] = parallel_search(
383409
types_, [&](const std::string &tpl) { return matchTemplate(tpl, type); });
384410
llvm::errs() << "search type " << type << ", result: "
385411
<< ((result == types_.end()) ? "None"
@@ -528,22 +554,23 @@ clang::QualType normalizeQualType(clang::QualType qual_type) {
528554
}
529555

530556
std::string mapTypeStringRecursive(const std::string &cpp_type) {
531-
auto rule = parallel_search(types_, [&](const std::string &tpl) {
557+
auto [rule, match] = parallel_search(types_, [&](const std::string &tpl) {
532558
return matchTemplate(tpl, cpp_type);
533559
});
534560
if (rule == types_.end()) {
535561
llvm::errs() << "cpp_type: " << cpp_type << '\n';
536562
assert(0 && "Type is not present in types_");
537563
}
538-
auto subs = matchTemplate(rule->first, cpp_type).value();
564+
auto subs = std::move(match).value();
539565
for (auto &kv : subs) {
540566
kv.second = mapTypeStringRecursive(kv.second);
541567
}
542568
return instantiateTgt(subs, rule->second.type_info.type);
543569
}
544570

545571
std::string normalizeTranslationRule(std::string rule) {
546-
const std::array<std::pair<std::regex, std::string>, 2> normalization_rules{{
572+
static const std::array<std::pair<std::regex, std::string>, 2>
573+
normalization_rules{{
547574
// Detach pointer from double reference. Useful for matching translation
548575
// rules.
549576
{std::regex(R"(\*\&\&)"), "* &&"},
@@ -601,20 +628,26 @@ std::string MapFunctionName(const clang::FunctionDecl *decl) {
601628

602629
std::string InstantiateTemplate(const clang::Expr *expr,
603630
const std::string &text) {
604-
auto it = search(expr);
631+
auto qualified_name = ToString(expr);
632+
auto [it, match] = parallel_search(exprs_, [&](const std::string &tpl) {
633+
return matchTemplate(tpl, qualified_name);
634+
});
605635
if (it == exprs_.end()) {
606636
return text;
607637
}
608-
auto types_map = matchTemplate(it->first, ToString(expr)).value();
638+
auto types_map = std::move(match).value();
609639
for (auto &kv : types_map) {
610640
kv.second = mapTypeStringRecursive(kv.second);
611641
}
612642
return instantiateTgt(types_map, text);
613643
}
614644

615645
std::string Map(clang::QualType qual_type) {
616-
if (auto it = search(qual_type); it != types_.end()) {
617-
auto types_map = matchTemplate(it->first, ToString(qual_type)).value();
646+
auto type_str = ToString(qual_type);
647+
auto [it, match] = parallel_search(
648+
types_, [&](const std::string &tpl) { return matchTemplate(tpl, type_str); });
649+
if (it != types_.end()) {
650+
auto types_map = std::move(match).value();
618651
for (auto &kv : types_map) {
619652
kv.second = mapTypeStringRecursive(kv.second);
620653
}
@@ -656,12 +689,20 @@ const TranslationRule::TypeInfo &GetParamInfo(const clang::Expr *expr,
656689
}
657690

658691
std::string GetParamType(const clang::Expr *expr, unsigned index) {
659-
auto &info = GetParamInfo(expr, index);
660-
auto types_map = matchTemplate(search(expr)->first, ToString(expr)).value();
692+
auto qualified_name = ToString(expr);
693+
auto [it, match] = parallel_search(exprs_, [&](const std::string &tpl) {
694+
return matchTemplate(tpl, qualified_name);
695+
});
696+
assert(it != exprs_.end() && "expression must have a translation rule");
697+
auto name = "a" + std::to_string(index);
698+
auto name_it = it->second.params.find(name);
699+
assert(name_it != it->second.params.end() &&
700+
"placeholder arg must have a corresponding param type in IR");
701+
auto types_map = std::move(match).value();
661702
for (auto &kv : types_map) {
662703
kv.second = mapTypeStringRecursive(kv.second);
663704
}
664-
return instantiateTgt(types_map, info.type);
705+
return instantiateTgt(types_map, name_it->second.type);
665706
}
666707

667708
bool ParamIsPointer(const clang::Expr *expr, unsigned index) {

0 commit comments

Comments
 (0)