Skip to content

Commit 2276bb7

Browse files
committed
Handle default as middle case
C++ allows: switch (x) { default: ... case 1: ... } In rust, default needs to be always on the last position, otherwise, all values of x, even 1, will hit the default arm. Hence, the above C++ exmaple becomes: match x { 1 => ... _ => ... } As such, the algorithm for translating the cases becomes: for (case: GetTopLevelSwitchCases()) { if (ChainContainsDefault(case)) { defer the conversion for the end } Convert(case) Convert(GetSwitchArmBody(case)) } Convert(deferred default) ChainContainsDefault traverses the stacked case statements starting from a top level case. For example: case 2: case 3: default: ... is deferred for the end of the match arms and is translated as: _ => {} i.e., drop the case 2, case 3 from the output, only convert as if it was only default.
1 parent 4c601da commit 2276bb7

2 files changed

Lines changed: 88 additions & 36 deletions

File tree

cpp2rust/converter/converter.cpp

Lines changed: 88 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2579,65 +2579,118 @@ bool Converter::VisitImplicitValueInitExpr(clang::ImplicitValueInitExpr *expr) {
25792579
return false;
25802580
}
25812581

2582-
bool Converter::VisitSwitchCase(clang::SwitchCase *stmt) {
2583-
if (visited_switch_cases_.contains(stmt)) {
2584-
return false;
2582+
static std::vector<clang::SwitchCase *>
2583+
GetTopLevelSwitchCases(clang::SwitchStmt *stmt) {
2584+
std::vector<clang::SwitchCase *> cases;
2585+
if (auto *body = llvm::dyn_cast<clang::CompoundStmt>(stmt->getBody())) {
2586+
for (auto *s : body->body()) {
2587+
if (auto *sc = clang::dyn_cast<clang::SwitchCase>(s)) {
2588+
cases.push_back(sc);
2589+
}
2590+
}
2591+
}
2592+
return cases;
2593+
}
2594+
2595+
static bool ChainContainsDefault(clang::SwitchCase *c) {
2596+
for (clang::Stmt *cur = c;;) {
2597+
if (clang::isa<clang::DefaultStmt>(cur)) {
2598+
return true;
2599+
}
2600+
auto *sc = clang::dyn_cast<clang::SwitchCase>(cur);
2601+
if (!sc) {
2602+
return false;
2603+
}
2604+
cur = sc->getSubStmt();
25852605
}
2586-
visited_switch_cases_.insert(stmt);
2606+
return false;
2607+
}
25872608

2588-
if (auto case_stmt = clang::dyn_cast<clang::CaseStmt>(stmt)) {
2589-
Convert(case_stmt->getLHS());
2609+
static clang::Stmt *ChainLeafBody(clang::SwitchCase *c) {
2610+
clang::Stmt *cur = c->getSubStmt();
2611+
while (auto *sc = clang::dyn_cast<clang::SwitchCase>(cur)) {
2612+
cur = sc->getSubStmt();
25902613
}
2614+
return cur;
2615+
}
25912616

2592-
if (clang::isa<clang::CaseStmt>(stmt->getSubStmt())) {
2593-
StrCat("|| v == ");
2594-
} else {
2595-
if (clang::isa<clang::CaseStmt>(stmt)) {
2596-
StrCat(" => {");
2597-
} else {
2598-
StrCat("_ => {");
2617+
static std::vector<clang::Stmt *> GetSwitchArmBody(clang::CompoundStmt *body,
2618+
clang::SwitchCase *head) {
2619+
std::vector<clang::Stmt *> out;
2620+
out.push_back(ChainLeafBody(head));
2621+
auto it = body->body_begin(), end = body->body_end();
2622+
while (it != end && *it != head) {
2623+
++it;
2624+
}
2625+
assert(it != end);
2626+
++it;
2627+
while (it != end && !clang::isa<clang::SwitchCase>(*it)) {
2628+
out.push_back(*it);
2629+
++it;
2630+
}
2631+
return out;
2632+
}
2633+
2634+
bool Converter::VisitSwitchCase(clang::SwitchCase *stmt) {
2635+
clang::Stmt *cur = stmt;
2636+
clang::SwitchCase *last = nullptr;
2637+
bool first = true;
2638+
2639+
while (auto *sc = clang::dyn_cast<clang::SwitchCase>(cur)) {
2640+
if (auto *case_stmt = clang::dyn_cast<clang::CaseStmt>(sc)) {
2641+
if (!first) {
2642+
StrCat("|| v == ");
2643+
}
2644+
Convert(case_stmt->getLHS());
25992645
}
2646+
last = sc;
2647+
first = false;
2648+
cur = sc->getSubStmt();
26002649
}
26012650

2602-
Convert(stmt->getSubStmt());
2651+
if (clang::isa<clang::CaseStmt>(last)) {
2652+
StrCat(" => {");
2653+
} else /* DefaultStmt */ {
2654+
StrCat("_ => {");
2655+
}
26032656
return false;
26042657
}
26052658

26062659
bool Converter::VisitSwitchStmt(clang::SwitchStmt *stmt) {
2660+
auto *body = clang::dyn_cast<clang::CompoundStmt>(stmt->getBody());
2661+
assert(body);
2662+
26072663
StrCat("'switch: {");
26082664
StrCat(std::format("let __match_cond = {};", ToString(stmt->getCond())));
26092665
StrCat("match __match_cond");
26102666
StrCat("{");
26112667

2612-
bool has_default_case = false;
2613-
auto body = llvm::cast<clang::CompoundStmt>(stmt->getBody());
2614-
assert(body);
2615-
2616-
visited_switch_cases_ = {};
2617-
26182668
break_with_explicit_label_ = true;
2619-
for (auto it = body->body_begin(), end = body->body_end(); it != end;) {
2620-
if (auto switch_case = clang::dyn_cast<clang::SwitchCase>(*it)) {
2621-
if (clang::isa<clang::CaseStmt>(switch_case)) {
2622-
StrCat("v if v == ");
2623-
} else {
2624-
has_default_case = true;
2625-
}
2626-
VisitSwitchCase(switch_case);
2627-
++it;
2628-
}
26292669

2630-
while (it != end && !clang::isa<clang::SwitchCase>(*it)) {
2631-
Convert(*it);
2632-
++it;
2670+
clang::SwitchCase *default_case = nullptr;
2671+
for (auto *sc : GetTopLevelSwitchCases(stmt)) {
2672+
if (ChainContainsDefault(sc)) {
2673+
default_case = sc;
2674+
continue;
2675+
}
2676+
StrCat("v if v == ");
2677+
VisitSwitchCase(sc);
2678+
for (auto *t : GetSwitchArmBody(body, sc)) {
2679+
Convert(t);
26332680
}
2634-
26352681
StrCat("},");
26362682
}
26372683

2638-
if (!has_default_case) {
2684+
if (default_case) {
2685+
StrCat("_ => {");
2686+
for (auto *t : GetSwitchArmBody(body, default_case)) {
2687+
Convert(t);
2688+
}
2689+
StrCat("},");
2690+
} else {
26392691
StrCat(R"( _ => {})");
26402692
}
2693+
26412694
break_with_explicit_label_ = false;
26422695

26432696
StrCat("}");

cpp2rust/converter/converter.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,6 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {
462462
bool break_with_explicit_label_ = false;
463463
std::stack<clang::Expr *> curr_for_inc_;
464464
std::stack<clang::QualType> curr_init_type_;
465-
std::unordered_set<clang::SwitchCase *> visited_switch_cases_;
466465

467466
std::unordered_set<const clang::VarDecl *> map_iter_decls_;
468467

0 commit comments

Comments
 (0)