Skip to content

Commit 2165500

Browse files
authored
Translate switches containing fallthrough (#38)
All switch tests pass now. This PR translates switches containing fallthrough using the new `switch!` macro defined in `libcc2rs-macros`: ``` switch!(match <condition> { <pat> [if <guard>] => { /* body; may contain break or continue */ }, ... _ => <body>, }); Desugars to a goto_block! with a synthetic dispatch arm prepended. goto_block! { '__dispatch => { match <condition> { <pat_1> => { __s = 1; continue '__sm; } ... _ => break '__sm, } }, '__c1 => { /* body_1 with `break` rewritten to `break '__sm` */ }, ... '__cN => { /* body_N with same rewrite */ }, }; __sm is the inner label used to describe the state machine inside goto_block. See goto_block! for more info. ``` It's necessary to translate the switch into a goto because later we will add support for jumping between switch arms using goto. Below is an example of how the `switch!` macro expands: ```cpp int fallthrough_one(int x) { int r = 0; switch (x) { case 1: r += 10; case 2: r += 20; break; default: r = -1; break; } return r; } ``` Will be translated as: ```rs switch!(match x { v if v == 1 => { r += 10; } v if v == 2 => { r += 20; break; } _ => { r = -1_i32; break; } }); ``` Which will expand into (see comments attached to each arm): ```rs { let mut __s: u32 = 0; #[allow(unreachable_code, unused_labels)] '__sm: loop { match __s { // First arm is the dispatch arm. It decides which is the first state: [1u32, ...] 0u32 => { #[allow(unreachable_code)] { { #[allow(unreachable_patterns)] match x { v if v == 1 => { __s = 1u32; continue '__sm; } v if v == 2 => { __s = 2u32; continue '__sm; } _ => { __s = 3u32; continue '__sm; } _ => break '__sm, } }; __s = 1u32; continue '__sm; } } 1u32 => { // First real arm, it contains the body of the original arm + fallthrough to the following state, i.e. 
2u32 #[allow(unreachable_code)] { { r += 10; }; __s = 2u32; continue '__sm; } } 2u32 => { // Second real arm, inside the body there is a break statement that stops the fallthrough #[allow(unreachable_code)] { { r += 20; break '__sm; }; __s = 3u32; continue '__sm; } } 3u32 => { // Default arm from the original match, the last arm breaks the state machine instead of continuing #[allow(unreachable_code)] { { r = -1_i32; break '__sm; }; break '__sm; } } // This is here only for match exhaustiveness, it's never used _ => break '__sm, } } }; ```
1 parent 5bdf6a1 commit 2165500

36 files changed

Lines changed: 1047 additions & 191 deletions

.github/workflows/format.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,12 @@ jobs:
4545
cargo fmt --manifest-path rules/Cargo.toml -- --check
4646
cargo fmt --manifest-path rule-preprocessor/Cargo.toml -- --check
4747
cargo fmt --manifest-path libcc2rs/Cargo.toml -- --check
48+
cargo fmt --manifest-path libcc2rs-macros/Cargo.toml -- --check
4849
find tests -name '*.rs' -print0 | xargs -0 rustfmt --check
4950
5051
- name: Check Rust lints
5152
run: |
5253
cargo clippy --manifest-path rules/Cargo.toml --all-targets --all-features -- -Dwarnings
5354
cargo +nightly clippy --manifest-path rule-preprocessor/Cargo.toml --all-targets --all-features -- -Dwarnings
5455
cargo clippy --manifest-path libcc2rs/Cargo.toml --all-targets --all-features -- -Dwarnings
56+
cargo clippy --manifest-path libcc2rs-macros/Cargo.toml --all-targets --all-features -- -Dwarnings

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ add_custom_target("format"
100100
COMMAND rustup run ${RUST_STABLE_VERSION} cargo fmt --manifest-path ${PROJECT_SOURCE_DIR}/rules/Cargo.toml
101101
COMMAND rustup run ${RUST_STABLE_VERSION} cargo fmt --manifest-path ${PROJECT_SOURCE_DIR}/rule-preprocessor/Cargo.toml
102102
COMMAND rustup run ${RUST_STABLE_VERSION} cargo fmt --manifest-path ${PROJECT_SOURCE_DIR}/libcc2rs/Cargo.toml
103+
COMMAND rustup run ${RUST_STABLE_VERSION} cargo fmt --manifest-path ${PROJECT_SOURCE_DIR}/libcc2rs-macros/Cargo.toml
103104
DEPENDS "install-rust-toolchain"
104105
COMMENT "Running clang-format and cargo fmt on all source files"
105106
VERBATIM

cpp2rust/converter/converter.cpp

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2706,14 +2706,31 @@ void Converter::EmitSwitchArm(clang::CompoundStmt *body, clang::SwitchCase *sc,
27062706
}
27072707

27082708
bool Converter::VisitSwitchStmt(clang::SwitchStmt *stmt) {
2709-
PushBreakTarget push(break_target_, BreakTarget::Switch);
2709+
bool has_fallthrough = SwitchHasFallthrough(stmt);
2710+
PushBreakTarget push(break_target_, has_fallthrough
2711+
? BreakTarget::FallthroughSwitch
2712+
: BreakTarget::Switch);
27102713
auto *body = clang::dyn_cast<clang::CompoundStmt>(stmt->getBody());
27112714
assert(body);
27122715

2713-
StrCat("'switch: {");
2714-
StrCat(std::format("let __match_cond = {};", ToString(stmt->getCond())));
2715-
StrCat("match __match_cond");
2716-
StrCat("{");
2716+
if (has_fallthrough) {
2717+
// Use the switch-with-fallthrough macro
2718+
StrCat("switch!");
2719+
} else {
2720+
StrCat("'switch:");
2721+
}
2722+
2723+
PushParen switch_macro_paren(*this, has_fallthrough);
2724+
PushBrace switch_label_brace(*this, !has_fallthrough);
2725+
2726+
if (has_fallthrough) {
2727+
StrCat("match", ToString(stmt->getCond()));
2728+
} else {
2729+
StrCat(std::format("let __match_cond = {};", ToString(stmt->getCond())));
2730+
StrCat("match __match_cond");
2731+
}
2732+
2733+
PushBrace match_brace(*this);
27172734

27182735
clang::SwitchCase *default_case = nullptr;
27192736
for (auto *sc : GetTopLevelSwitchCases(stmt)) {
@@ -2730,8 +2747,6 @@ bool Converter::VisitSwitchStmt(clang::SwitchStmt *stmt) {
27302747
StrCat(R"( _ => {})");
27312748
}
27322749

2733-
StrCat("}");
2734-
StrCat("}");
27352750
return false;
27362751
}
27372752

cpp2rust/converter/converter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,7 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {
501501
std::stack<clang::Expr *> curr_for_inc_;
502502
std::stack<clang::QualType> curr_init_type_;
503503

504-
enum class BreakTarget { Loop, Switch };
504+
enum class BreakTarget { Loop, FallthroughSwitch, Switch };
505505
std::stack<BreakTarget> break_target_;
506506

507507
bool isSwitchBreak() const {

cpp2rust/converter/converter_lib.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -744,6 +744,36 @@ std::vector<clang::Stmt *> GetSwitchCaseBody(clang::CompoundStmt *body,
744744
return out;
745745
}
746746

747+
static bool SwitchCaseHasFallthrough(clang::Stmt *stmt) {
748+
if (!stmt) {
749+
return false;
750+
}
751+
if (auto *compound = clang::dyn_cast<clang::CompoundStmt>(stmt)) {
752+
if (compound->body_empty()) {
753+
return true;
754+
}
755+
return SwitchCaseHasFallthrough(compound->body_back());
756+
}
757+
if (clang::isa<clang::BreakStmt>(stmt) ||
758+
clang::isa<clang::ContinueStmt>(stmt) ||
759+
clang::isa<clang::ReturnStmt>(stmt)) {
760+
return false;
761+
}
762+
return true;
763+
}
764+
765+
bool SwitchHasFallthrough(clang::SwitchStmt *stmt) {
766+
if (auto *body = clang::dyn_cast<clang::CompoundStmt>(stmt->getBody())) {
767+
for (auto top_level_case : GetTopLevelSwitchCases(stmt)) {
768+
auto arm = GetSwitchCaseBody(body, top_level_case);
769+
if (arm.empty() || SwitchCaseHasFallthrough(arm.back())) {
770+
return true;
771+
}
772+
}
773+
}
774+
return false;
775+
}
776+
747777
static std::string_view Trim(std::string_view s) {
748778
auto is_space = [](unsigned char c) { return std::isspace(c); };
749779
auto b = std::find_if_not(s.begin(), s.end(), is_space);

cpp2rust/converter/converter_lib.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ bool SwitchCaseContainsDefault(clang::SwitchCase *c);
164164
std::vector<clang::Stmt *> GetSwitchCaseBody(clang::CompoundStmt *body,
165165
clang::SwitchCase *head);
166166

167+
bool SwitchHasFallthrough(clang::SwitchStmt *stmt);
168+
167169
void Unwrap(std::string &s, std::string_view prefix, std::string_view suffix);
168170

169171
std::string ReplaceAll(std::string str, std::string_view from,

libcc2rs-macros/Cargo.toml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
[package]
2+
name = "libcc2rs-macros"
3+
version = "0.1.0"
4+
edition = "2021"
5+
6+
[lib]
7+
proc-macro = true
8+
9+
[dependencies]
10+
proc-macro2 = "1"
11+
quote = "1"
12+
syn = { version = "2", features = ["full", "visit-mut", "extra-traits"] }
13+
14+
[dev-dependencies]
15+
trybuild = "1"

libcc2rs-macros/src/goto.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright (c) 2022-present INESC-ID.
2+
// Distributed under the MIT license that can be found in the LICENSE file.
3+
4+
use proc_macro::TokenStream;
5+
use syn::parse::{Parse, ParseStream};
6+
use syn::{parse_macro_input, Expr, Lifetime, Token};
7+
8+
use crate::state_machine::{Arm, GotoStateMachine, StateMachine};
9+
10+
pub fn expand(input: TokenStream) -> TokenStream {
11+
let GotoBlockInput { arms } = parse_macro_input!(input as GotoBlockInput);
12+
GotoStateMachine {
13+
arms: arms
14+
.into_iter()
15+
.map(|a| Arm {
16+
label: a.label.ident.to_string(),
17+
body: a.body,
18+
})
19+
.collect(),
20+
}
21+
.emit()
22+
.into()
23+
}
24+
25+
struct GotoBlockInput {
26+
arms: Vec<GotoArm>,
27+
}
28+
29+
struct GotoArm {
30+
label: Lifetime,
31+
body: Expr,
32+
}
33+
34+
impl Parse for GotoBlockInput {
35+
fn parse(input: ParseStream) -> syn::Result<Self> {
36+
let mut arms = Vec::new();
37+
while !input.is_empty() {
38+
let label: Lifetime = input.parse()?;
39+
input.parse::<Token![=>]>()?;
40+
let body: Expr = input.parse()?;
41+
arms.push(GotoArm { label, body });
42+
if input.peek(Token![,]) {
43+
input.parse::<Token![,]>()?;
44+
}
45+
}
46+
Ok(Self { arms })
47+
}
48+
}

libcc2rs-macros/src/lib.rs

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// Copyright (c) 2022-present INESC-ID.
2+
// Distributed under the MIT license that can be found in the LICENSE file.
3+
4+
use proc_macro::TokenStream;
5+
6+
mod goto;
7+
mod state_machine;
8+
mod switch;
9+
10+
// switch!(match <condition> {
11+
// <pat> [if <guard>] => { /* body; may contain break or continue */ },
12+
// ...
13+
// _ => <body>,
14+
// });
15+
//
16+
// Desugars to a goto_block! with a synthetic dispatch arm prepended.
17+
//
18+
// goto_block! {
19+
// '__dispatch => {
20+
// match <condition> {
21+
// <pat_1> => { __s = 1; continue '__sm; }
22+
// ...
23+
// _ => break '__sm,
24+
// }
25+
// },
26+
// '__c1 => { /* body_1 with `break` rewritten to `break '__sm` */ },
27+
// ...
28+
// '__cN => { /* body_N with same rewrite */ },
29+
// };
30+
//
31+
// __sm is the inner label used to describe the state machine inside goto_block. See goto_block!
32+
// for more info.
33+
34+
#[proc_macro]
35+
pub fn switch(input: TokenStream) -> TokenStream {
36+
switch::expand(input)
37+
}
38+
39+
// goto_block! {
40+
// '<label> => { /* body; may contain `break` or `continue` */ },
41+
// ...
42+
// };
43+
//
44+
// Expands to
45+
//
46+
// {
47+
// let mut __user_break: bool = false; // only if any arm has `break`
48+
// let mut __user_continue: bool = false; // only if any arm has `continue`
49+
// let mut __s: u32 = 0;
50+
// '__sm: loop {
51+
// match __s {
52+
// 0u32 => {
53+
// /* body_0 with these rewrites (outside nested user loops): */
54+
// /* break; -> { __user_break = true; break '__sm; } */
55+
// /* continue; -> { __user_continue = true; break '__sm; } */
56+
// __s = 1; continue '__sm;
57+
// }
58+
// ...
59+
// (N-1)u32 => { /* body_N-1 with same rewrites */ break '__sm; }
60+
// _ => break '__sm, // written only for match exhaustiveness
61+
// }
62+
// }
63+
// if __user_break { break; } // only if any arm has `break;`
64+
// if __user_continue { continue; } // only if any arm has `continue;`
65+
// }
66+
//
67+
// __user_break and __user_continue propagate the `break`s and `continue`s outside the goto state
68+
// machine loop.
69+
70+
#[proc_macro]
71+
pub fn goto_block(input: TokenStream) -> TokenStream {
72+
goto::expand(input)
73+
}
74+
75+
#[proc_macro]
76+
pub fn goto(_input: TokenStream) -> TokenStream {
77+
quote::quote! {
78+
compile_error!("goto!() can only be used inside goto_block!")
79+
}
80+
.into()
81+
}

0 commit comments

Comments
 (0)