Skip to content

Commit 9802d0a

Browse files
authored
Translate switch with default case in the middle (#30)
This PR handles code as: ```cpp // tests/unit/switch_default_middle.cpp int default_middle(int x) { int r = 0; switch (x) { case 1: r = 1; break; default: r = 99; break; case 2: r = 2; break; } return r; } ``` In rust, the default arm has to be the last arm, otherwise, arms following the default one will be ignored, making this assertion fail `assert(default_middle(2) == 2)`. Hence, the rust translation becomes: ```rs match x { v if v == 1 => { r = 1; } v if v == 2 => { r = 2; } _ => { r = 99; } } ``` This is achieved by iterating over all top-level case statements in VisitSwitchStmt and deferring the translation of the default arm until every othter arm is translated. A top-level case statement is defined as: ```cpp switch (x) { case TOP_LEVEL_STMT_1: case NOT_TOP_LEVEL_STMT_1: default: // not top level as well ... break; case TOP_LEVEL_STMT_2: case NOT_TOP_LEVEL_STMT_2: ... break } ``` Furthermore, chains of case statements that contain a default statement are reduced to default, effectively making the above code translate as: ```rs match x { v if v == TOP_LEVEL_STMT_2 || v == NOT_TOP_LEVEL_STMT_2 => {} _ => {} // TOP_LEVEL_STMT_1, NOT_TOP_LEVEL_STMT_1, default was reduced to default } ```
1 parent 370e554 commit 9802d0a

19 files changed

Lines changed: 458 additions & 61 deletions

cpp2rust/converter/converter.cpp

Lines changed: 43 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2621,63 +2621,67 @@ bool Converter::VisitImplicitValueInitExpr(clang::ImplicitValueInitExpr *expr) {
26212621
return false;
26222622
}
26232623

2624-
static std::unordered_set<clang::SwitchCase *> visited_cases;
2625-
2626-
bool Converter::VisitSwitchCase(clang::SwitchCase *stmt) {
2627-
if (visited_cases.contains(stmt)) {
2628-
return false;
2624+
bool Converter::ConvertSwitchCaseCondition(clang::SwitchCase *stmt) {
2625+
clang::Stmt *cur = stmt;
2626+
clang::SwitchCase *last = nullptr;
2627+
bool first = true;
2628+
2629+
while (auto *sc = clang::dyn_cast<clang::SwitchCase>(cur)) {
2630+
if (auto *case_stmt = clang::dyn_cast<clang::CaseStmt>(sc)) {
2631+
if (!first) {
2632+
StrCat("|| v == ");
2633+
}
2634+
Convert(case_stmt->getLHS());
2635+
}
2636+
last = sc;
2637+
first = false;
2638+
cur = sc->getSubStmt();
26292639
}
2630-
visited_cases.insert(stmt);
26312640

2632-
if (auto case_stmt = clang::dyn_cast<clang::CaseStmt>(stmt)) {
2633-
Convert(case_stmt->getLHS());
2641+
if (clang::isa<clang::CaseStmt>(last)) {
2642+
StrCat(" => {");
2643+
} else /* DefaultStmt */ {
2644+
StrCat("_ => {");
26342645
}
2646+
return false;
2647+
}
26352648

2636-
if (clang::isa<clang::CaseStmt>(stmt->getSubStmt())) {
2637-
StrCat("|| v == ");
2649+
void Converter::EmitSwitchArm(clang::CompoundStmt *body, clang::SwitchCase *sc,
2650+
bool is_default) {
2651+
if (is_default) {
2652+
StrCat("_ => {");
26382653
} else {
2639-
if (clang::isa<clang::CaseStmt>(stmt)) {
2640-
StrCat(" => {");
2641-
} else {
2642-
StrCat("_ => {");
2643-
}
2654+
StrCat("v if v == ");
2655+
ConvertSwitchCaseCondition(sc);
26442656
}
2645-
2646-
Convert(stmt->getSubStmt());
2647-
return false;
2657+
for (auto *t : GetSwitchCaseBody(body, sc)) {
2658+
Convert(t);
2659+
}
2660+
StrCat("},");
26482661
}
26492662

26502663
bool Converter::VisitSwitchStmt(clang::SwitchStmt *stmt) {
26512664
PushBreakTarget push(break_target_, BreakTarget::Switch);
2665+
auto *body = clang::dyn_cast<clang::CompoundStmt>(stmt->getBody());
2666+
assert(body);
2667+
26522668
StrCat("'switch: {");
26532669
StrCat(std::format("let __match_cond = {};", ToString(stmt->getCond())));
26542670
StrCat("match __match_cond");
26552671
StrCat("{");
26562672

2657-
bool has_default_case = false;
2658-
auto body = llvm::cast<clang::CompoundStmt>(stmt->getBody());
2659-
assert(body);
2660-
2661-
for (auto it = body->body_begin(), end = body->body_end(); it != end;) {
2662-
if (auto switch_case = clang::dyn_cast<clang::SwitchCase>(*it)) {
2663-
if (clang::isa<clang::CaseStmt>(switch_case)) {
2664-
StrCat("v if v == ");
2665-
} else {
2666-
has_default_case = true;
2667-
}
2668-
VisitSwitchCase(switch_case);
2669-
++it;
2670-
}
2671-
2672-
while (it != end && !clang::isa<clang::SwitchCase>(*it)) {
2673-
Convert(*it);
2674-
++it;
2673+
clang::SwitchCase *default_case = nullptr;
2674+
for (auto *sc : GetTopLevelSwitchCases(stmt)) {
2675+
if (SwitchCaseContainsDefault(sc)) {
2676+
default_case = sc;
2677+
continue;
26752678
}
2676-
2677-
StrCat("},");
2679+
EmitSwitchArm(body, sc, /*is_default=*/false);
26782680
}
26792681

2680-
if (!has_default_case) {
2682+
if (default_case) {
2683+
EmitSwitchArm(body, default_case, /*is_default=*/true);
2684+
} else {
26812685
StrCat(R"( _ => {})");
26822686
}
26832687

cpp2rust/converter/converter.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,10 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {
290290

291291
virtual bool VisitSwitchStmt(clang::SwitchStmt *stmt);
292292

293-
virtual bool VisitSwitchCase(clang::SwitchCase *stmt);
293+
void EmitSwitchArm(clang::CompoundStmt *body, clang::SwitchCase *sc,
294+
bool is_default);
295+
296+
bool ConvertSwitchCaseCondition(clang::SwitchCase *stmt);
294297

295298
virtual bool VisitVAArgExpr(clang::VAArgExpr *expr);
296299

cpp2rust/converter/converter_lib.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,58 @@ clang::Expr *CreateConversionToBool(clang::Expr *expr, clang::ASTContext &ctx) {
662662
/*BasePath=*/nullptr, clang::VK_PRValue, clang::FPOptionsOverride());
663663
}
664664

665+
std::vector<clang::SwitchCase *>
666+
GetTopLevelSwitchCases(clang::SwitchStmt *stmt) {
667+
std::vector<clang::SwitchCase *> cases;
668+
if (auto *body = llvm::dyn_cast<clang::CompoundStmt>(stmt->getBody())) {
669+
for (auto *s : body->body()) {
670+
if (auto *sc = clang::dyn_cast<clang::SwitchCase>(s)) {
671+
cases.push_back(sc);
672+
}
673+
}
674+
}
675+
return cases;
676+
}
677+
678+
bool SwitchCaseContainsDefault(clang::SwitchCase *c) {
679+
for (clang::Stmt *cur = c;;) {
680+
if (clang::isa<clang::DefaultStmt>(cur)) {
681+
return true;
682+
}
683+
auto *sc = clang::dyn_cast<clang::SwitchCase>(cur);
684+
if (!sc) {
685+
return false;
686+
}
687+
cur = sc->getSubStmt();
688+
}
689+
return false;
690+
}
691+
692+
static clang::Stmt *GetLastStmtOfSwitchCase(clang::SwitchCase *c) {
693+
clang::Stmt *cur = c->getSubStmt();
694+
while (auto *sc = clang::dyn_cast<clang::SwitchCase>(cur)) {
695+
cur = sc->getSubStmt();
696+
}
697+
return cur;
698+
}
699+
700+
std::vector<clang::Stmt *> GetSwitchCaseBody(clang::CompoundStmt *body,
701+
clang::SwitchCase *head) {
702+
std::vector<clang::Stmt *> out;
703+
out.push_back(GetLastStmtOfSwitchCase(head));
704+
auto it = body->body_begin(), end = body->body_end();
705+
while (it != end && *it != head) {
706+
++it;
707+
}
708+
assert(it != end);
709+
++it;
710+
while (it != end && !clang::isa<clang::SwitchCase>(*it)) {
711+
out.push_back(*it);
712+
++it;
713+
}
714+
return out;
715+
}
716+
665717
static std::string_view Trim(std::string_view s) {
666718
auto is_space = [](unsigned char c) { return std::isspace(c); };
667719
auto b = std::find_if_not(s.begin(), s.end(), is_space);

cpp2rust/converter/converter_lib.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,14 @@ bool ContainsVAArgExpr(const clang::Stmt *stmt);
155155

156156
clang::Expr *CreateConversionToBool(clang::Expr *expr, clang::ASTContext &ctx);
157157

158+
std::vector<clang::SwitchCase *>
159+
GetTopLevelSwitchCases(clang::SwitchStmt *stmt);
160+
161+
bool SwitchCaseContainsDefault(clang::SwitchCase *c);
162+
163+
std::vector<clang::Stmt *> GetSwitchCaseBody(clang::CompoundStmt *body,
164+
clang::SwitchCase *head);
165+
158166
void Unwrap(std::string &s, std::string_view prefix, std::string_view suffix);
159167

160168
} // namespace cpp2rust
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
extern crate libcc2rs;
2+
use libcc2rs::*;
3+
use std::cell::RefCell;
4+
use std::collections::BTreeMap;
5+
use std::io::prelude::*;
6+
use std::io::{Read, Seek, Write};
7+
use std::os::fd::AsFd;
8+
use std::rc::{Rc, Weak};
9+
pub fn case_then_default_0(x: i32) -> i32 {
10+
let x: Value<i32> = Rc::new(RefCell::new(x));
11+
let r: Value<i32> = Rc::new(RefCell::new(0));
12+
'switch: {
13+
let __match_cond = (*x.borrow());
14+
match __match_cond {
15+
v if v == 2 => {
16+
(*r.borrow_mut()) = 20;
17+
break 'switch;
18+
}
19+
_ => {
20+
(*r.borrow_mut()) = 10;
21+
break 'switch;
22+
}
23+
}
24+
};
25+
return (*r.borrow());
26+
}
27+
pub fn main() {
28+
std::process::exit(main_0());
29+
}
30+
fn main_0() -> i32 {
31+
assert!(
32+
(({
33+
let _x: i32 = 1;
34+
case_then_default_0(_x)
35+
}) == 10)
36+
);
37+
assert!(
38+
(({
39+
let _x: i32 = 2;
40+
case_then_default_0(_x)
41+
}) == 20)
42+
);
43+
assert!(
44+
(({
45+
let _x: i32 = 99;
46+
case_then_default_0(_x)
47+
}) == 10)
48+
);
49+
return 0;
50+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
extern crate libcc2rs;
2+
use libcc2rs::*;
3+
use std::cell::RefCell;
4+
use std::collections::BTreeMap;
5+
use std::io::prelude::*;
6+
use std::io::{Read, Seek, Write};
7+
use std::os::fd::AsFd;
8+
use std::rc::{Rc, Weak};
9+
pub fn cases_and_default_stacked_0(x: i32) -> i32 {
10+
let x: Value<i32> = Rc::new(RefCell::new(x));
11+
let r: Value<i32> = Rc::new(RefCell::new(0));
12+
'switch: {
13+
let __match_cond = (*x.borrow());
14+
match __match_cond {
15+
v if v == 3 => {
16+
(*r.borrow_mut()) = 3;
17+
break 'switch;
18+
}
19+
_ => {
20+
(*r.borrow_mut()) = 42;
21+
break 'switch;
22+
}
23+
}
24+
};
25+
return (*r.borrow());
26+
}
27+
pub fn main() {
28+
std::process::exit(main_0());
29+
}
30+
fn main_0() -> i32 {
31+
assert!(
32+
(({
33+
let _x: i32 = 1;
34+
cases_and_default_stacked_0(_x)
35+
}) == 42)
36+
);
37+
assert!(
38+
(({
39+
let _x: i32 = 2;
40+
cases_and_default_stacked_0(_x)
41+
}) == 42)
42+
);
43+
assert!(
44+
(({
45+
let _x: i32 = 3;
46+
cases_and_default_stacked_0(_x)
47+
}) == 3)
48+
);
49+
assert!(
50+
(({
51+
let _x: i32 = 99;
52+
cases_and_default_stacked_0(_x)
53+
}) == 42)
54+
);
55+
return 0;
56+
}

tests/unit/out/refcount/switch_default_first.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,6 @@ pub fn default_first_0(x: i32) -> i32 {
1212
'switch: {
1313
let __match_cond = (*x.borrow());
1414
match __match_cond {
15-
_ => {
16-
(*r.borrow_mut()) = 7;
17-
break 'switch;
18-
}
1915
v if v == 1 => {
2016
(*r.borrow_mut()) = 1;
2117
break 'switch;
@@ -24,6 +20,10 @@ pub fn default_first_0(x: i32) -> i32 {
2420
(*r.borrow_mut()) = 2;
2521
break 'switch;
2622
}
23+
_ => {
24+
(*r.borrow_mut()) = 7;
25+
break 'switch;
26+
}
2727
}
2828
};
2929
return (*r.borrow());

tests/unit/out/refcount/switch_default_middle.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@ pub fn default_middle_0(x: i32) -> i32 {
1616
(*r.borrow_mut()) = 1;
1717
break 'switch;
1818
}
19-
_ => {
20-
(*r.borrow_mut()) = 99;
21-
break 'switch;
22-
}
2319
v if v == 2 => {
2420
(*r.borrow_mut()) = 2;
2521
break 'switch;
2622
}
23+
_ => {
24+
(*r.borrow_mut()) = 99;
25+
break 'switch;
26+
}
2727
}
2828
};
2929
return (*r.borrow());

0 commit comments

Comments
 (0)