Skip to content

Commit 2f6a163

Browse files
committed
Add GotoStateMachine and SwitchStateMachine
1 parent 88ffb77 commit 2f6a163

3 files changed

Lines changed: 243 additions & 119 deletions

File tree

libcc2rs-macros/src/goto.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,20 @@ use proc_macro::TokenStream;
55
use syn::parse::{Parse, ParseStream};
66
use syn::{parse_macro_input, Expr, Lifetime, Token};
77

8-
use crate::state_machine::{emit_state_machine, Arm, ArmEntry};
8+
use crate::state_machine::{Arm, GotoStateMachine, StateMachine};
99

1010
pub fn expand(input: TokenStream) -> TokenStream {
1111
let GotoBlockInput { arms } = parse_macro_input!(input as GotoBlockInput);
12-
emit_state_machine(
13-
None,
14-
arms.into_iter()
12+
GotoStateMachine {
13+
arms: arms
14+
.into_iter()
1515
.map(|a| Arm {
1616
label: a.label.ident.to_string(),
17-
entry: ArmEntry::LabelOnly,
1817
body: a.body,
1918
})
2019
.collect(),
21-
)
20+
}
21+
.emit()
2222
.into()
2323
}
2424

libcc2rs-macros/src/state_machine.rs

Lines changed: 216 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,147 +1,265 @@
11
// Copyright (c) 2022-present INESC-ID.
22
// Distributed under the MIT license that can be found in the LICENSE file.
33

4-
use proc_macro2::TokenStream as TokenStream2;
4+
use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
55
use quote::{format_ident, quote};
66
use syn::visit_mut::{self, VisitMut};
7-
use syn::{Expr, Lifetime, Pat};
7+
use syn::{Expr, ExprBreak, ExprContinue, Lifetime, Pat};
88

99
pub struct Arm {
1010
pub label: String,
11-
pub entry: ArmEntry,
1211
pub body: Expr,
1312
}
1413

15-
pub enum ArmEntry {
16-
Dispatch { pat: Pat, guard: Option<Expr> },
17-
LabelOnly,
14+
pub struct DispatchCase {
15+
pub pat: Pat,
16+
pub guard: Option<Expr>,
17+
pub target: String,
1818
}
1919

20-
pub fn emit_state_machine(condition: Option<Expr>, arms: Vec<Arm>) -> TokenStream2 {
21-
let lbl = Lifetime::new("'__sm", proc_macro2::Span::call_site());
22-
let s = format_ident!("__s");
20+
pub trait StateMachine {
21+
fn emit(self) -> TokenStream2;
22+
}
2323

24-
let base = if condition.is_some() { 1u32 } else { 0u32 };
25-
let rewrite_switch_break = condition.is_some();
24+
fn sm_label() -> Lifetime {
25+
Lifetime::new("'__sm", Span::call_site())
26+
}
2627

27-
let dispatch_arm = condition.map(|scrut| {
28-
let case_arms = arms
29-
.iter()
30-
.enumerate()
31-
.filter_map(|(i, arm)| match &arm.entry {
32-
ArmEntry::Dispatch { pat, guard } => {
33-
let idx = base + i as u32;
34-
let guard = guard.as_ref().map(|g| quote! { if #g });
35-
Some(quote! { #pat #guard => { #s = #idx; continue #lbl; } })
36-
}
37-
ArmEntry::LabelOnly => None,
38-
});
39-
quote! {
40-
0u32 => {
41-
#[allow(unreachable_patterns)]
42-
match #scrut {
43-
#(#case_arms,)*
44-
_ => break #lbl,
45-
}
46-
}
47-
}
48-
});
28+
// Collection of labeled arms that fall through to the next arm by default
29+
pub struct GotoStateMachine {
30+
pub arms: Vec<Arm>,
31+
}
4932

50-
let n = arms.len();
51-
let body_arms = arms.iter().enumerate().map(|(i, arm)| {
52-
let idx = base + i as u32;
53-
let body = rewrite_body(&arm.body, &lbl, rewrite_switch_break);
54-
let tail = if i + 1 < n {
55-
let next = idx + 1;
56-
quote! { #s = #next; continue #lbl; }
33+
impl GotoStateMachine {
34+
// Rewrites unlabeled break / continue to { flag = true; break '__sm; }
35+
fn propagate_rewrite(
36+
body: &mut Expr,
37+
label: &Lifetime,
38+
break_flag: &Ident,
39+
cont_flag: &Ident,
40+
) -> (bool, bool) {
41+
let mut br = PropagateRewriter::for_break(label.clone(), break_flag.clone());
42+
br.visit_expr_mut(body);
43+
let mut cr = PropagateRewriter::for_continue(label.clone(), cont_flag.clone());
44+
cr.visit_expr_mut(body);
45+
(br.found, cr.found)
46+
}
47+
48+
// idx => { body; tail }
49+
fn emit_body_arm(
50+
idx: u32,
51+
body: &Expr,
52+
is_last: bool,
53+
label: &Lifetime,
54+
state: &Ident,
55+
) -> TokenStream2 {
56+
let tail = if is_last {
57+
quote! { break #label; }
5758
} else {
58-
quote! { break #lbl; }
59+
let next = idx + 1;
60+
quote! { #state = #next; continue #label; }
5961
};
6062
quote! {
6163
#idx => {
6264
#[allow(unreachable_code)]
6365
{ #body; #tail }
6466
}
6567
}
66-
});
67-
68-
quote! {{
69-
let mut #s: u32 = 0;
70-
#[allow(unreachable_code, unused_labels)]
71-
#lbl: loop {
72-
match #s {
73-
#dispatch_arm
74-
#(#body_arms)*
75-
_ => break #lbl,
76-
}
68+
}
69+
70+
fn bailout(any: bool, flag: &Ident, stmt: TokenStream2) -> (TokenStream2, TokenStream2) {
71+
if !any {
72+
return (TokenStream2::new(), TokenStream2::new());
7773
}
78-
}}
74+
(
75+
quote! { let mut #flag: bool = false; },
76+
quote! {
77+
#[allow(unreachable_code)]
78+
if #flag { #stmt }
79+
},
80+
)
81+
}
7982
}
8083

81-
fn rewrite_body(body: &Expr, label: &Lifetime, rewrite_switch_break: bool) -> TokenStream2 {
82-
let mut body = body.clone();
83-
LoopControlForbidden.visit_expr_mut(&mut body);
84-
if rewrite_switch_break {
85-
SwitchBreakRewriter {
86-
label: label.clone(),
87-
}
88-
.visit_expr_mut(&mut body);
84+
impl StateMachine for GotoStateMachine {
85+
fn emit(self) -> TokenStream2 {
86+
let lbl = sm_label();
87+
let s = format_ident!("__s");
88+
let break_flag = format_ident!("__user_break");
89+
let cont_flag = format_ident!("__user_continue");
90+
91+
let n = self.arms.len();
92+
let mut any_break = false;
93+
let mut any_continue = false;
94+
let body_arms: Vec<_> = self
95+
.arms
96+
.iter()
97+
.enumerate()
98+
.map(|(i, arm)| {
99+
let mut body = arm.body.clone();
100+
let (had_br, had_cn) =
101+
Self::propagate_rewrite(&mut body, &lbl, &break_flag, &cont_flag);
102+
any_break |= had_br;
103+
any_continue |= had_cn;
104+
Self::emit_body_arm(i as u32, &body, i + 1 == n, &lbl, &s)
105+
})
106+
.collect();
107+
108+
let (brk_decl, brk_bailout) = Self::bailout(any_break, &break_flag, quote! { break; });
109+
let (cnt_decl, cnt_bailout) = Self::bailout(any_continue, &cont_flag, quote! { continue; });
110+
111+
quote! {{
112+
#brk_decl
113+
#cnt_decl
114+
let mut #s: u32 = 0;
115+
#[allow(unreachable_code, unused_labels)]
116+
#lbl: loop {
117+
match #s {
118+
#(#body_arms)*
119+
_ => break #lbl,
120+
}
121+
}
122+
#brk_bailout
123+
#cnt_bailout
124+
}}
89125
}
90-
quote! { #body }
91126
}
92127

93-
// Rewrites top-level switch_break!() into break '__sm. switch_break!() inside a loop is a compile
94-
// error.
95-
struct SwitchBreakRewriter {
96-
label: Lifetime,
128+
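// A GotoStateMachine preceded by a dispatch arm: the dispatch arm matches `condition`
// against `cases` and jumps to the arm whose label equals the matching case's `target`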
129+
pub struct SwitchStateMachine {
130+
pub goto: GotoStateMachine,
131+
pub condition: Expr,
132+
pub cases: Vec<DispatchCase>,
97133
}
98134

99-
impl SwitchBreakRewriter {
100-
fn replacement(&self) -> Expr {
101-
let lbl = &self.label;
102-
syn::parse_quote! { break #lbl }
135+
impl SwitchStateMachine {
136+
// Rewrite break into break '__sm
137+
fn convert_break_to_switch_exit(arms: &Vec<Arm>, label: &Lifetime) -> Vec<Arm> {
138+
arms.into_iter()
139+
.map(|a| {
140+
let mut body = a.body.clone();
141+
ExitSwitchRewriter {
142+
label: label.clone(),
143+
}
144+
.visit_expr_mut(&mut body);
145+
Arm {
146+
label: a.label.clone(),
147+
body,
148+
}
149+
})
150+
.collect()
103151
}
104-
}
105152

106-
impl VisitMut for SwitchBreakRewriter {
107-
fn visit_stmt_mut(&mut self, stmt: &mut syn::Stmt) {
108-
if let syn::Stmt::Macro(sm) = stmt {
109-
if sm.mac.path.is_ident("switch_break") {
110-
*stmt = syn::Stmt::Expr(self.replacement(), sm.semi_token);
111-
return;
153+
fn build_dispatch_arm(&self, user_arms: &[Arm], label: &Lifetime, state: &Ident) -> Arm {
154+
let cond = &self.condition;
155+
let case_arms = self.cases.iter().map(|c| {
156+
let target_pos = user_arms
157+
.iter()
158+
.position(|a| a.label == c.target)
159+
.expect("dispatch target must reference an arm label");
160+
let idx = (target_pos as u32) + 1;
161+
let pat = &c.pat;
162+
let guard = c.guard.as_ref().map(|g| quote! { if #g });
163+
quote! { #pat #guard => { #state = #idx; continue #label; } }
164+
});
165+
let body: Expr = syn::parse_quote! {
166+
{
167+
#[allow(unreachable_patterns)]
168+
match #cond {
169+
#(#case_arms,)*
170+
_ => break #label,
171+
}
112172
}
173+
};
174+
Arm {
175+
label: "__dispatch".into(),
176+
body,
177+
}
178+
}
179+
}
180+
181+
impl StateMachine for SwitchStateMachine {
182+
fn emit(self) -> TokenStream2 {
183+
let lbl = sm_label();
184+
let s = format_ident!("__s");
185+
186+
let user_arms = Self::convert_break_to_switch_exit(&self.goto.arms, &lbl);
187+
let dispatch = self.build_dispatch_arm(&user_arms, &lbl, &s);
188+
189+
let mut arms = Vec::new();
190+
arms.push(dispatch);
191+
arms.extend(user_arms);
192+
193+
GotoStateMachine { arms }.emit()
194+
}
195+
}
196+
197+
// Rewrites unlabeled `break` into `break '__sm`; nested loop/while/for bodies are not visited
198+
struct ExitSwitchRewriter {
199+
label: Lifetime,
200+
}
201+
202+
impl VisitMut for ExitSwitchRewriter {
203+
fn visit_expr_break_mut(&mut self, node: &mut ExprBreak) {
204+
if node.label.is_none() {
205+
node.label = Some(self.label.clone());
113206
}
114-
visit_mut::visit_stmt_mut(self, stmt);
115207
}
116208
fn visit_expr_loop_mut(&mut self, _: &mut syn::ExprLoop) {}
117209
fn visit_expr_while_mut(&mut self, _: &mut syn::ExprWhile) {}
118210
fn visit_expr_for_loop_mut(&mut self, _: &mut syn::ExprForLoop) {}
119211
}
120212

121-
// Forbid user-written break/continue. We want to hide the fact that switch! and goto_block!
122-
// are loops behind the scenes.
123-
struct LoopControlForbidden;
213+
enum ControlKind {
214+
Break,
215+
Continue,
216+
}
217+
218+
// Rewrites unlabeled break / continue to { flag = true; break '__sm; }
219+
struct PropagateRewriter {
220+
label: Lifetime,
221+
flag: Ident,
222+
kind: ControlKind,
223+
found: bool,
224+
}
124225

125-
impl VisitMut for LoopControlForbidden {
226+
impl PropagateRewriter {
227+
fn for_break(label: Lifetime, flag: Ident) -> Self {
228+
Self {
229+
label,
230+
flag,
231+
kind: ControlKind::Break,
232+
found: false,
233+
}
234+
}
235+
fn for_continue(label: Lifetime, flag: Ident) -> Self {
236+
Self {
237+
label,
238+
flag,
239+
kind: ControlKind::Continue,
240+
found: false,
241+
}
242+
}
243+
}
244+
245+
impl VisitMut for PropagateRewriter {
126246
fn visit_expr_mut(&mut self, expr: &mut Expr) {
127-
match expr {
128-
Expr::Break(_) => {
129-
*expr = syn::parse_quote! {
130-
compile_error!(
131-
"break is not allowed at the top level of a switch!/goto_block! arm. Inside a switch! use switch_break!()",
132-
)
133-
};
134-
return;
247+
let hit = match self.kind {
248+
ControlKind::Break => {
249+
matches!(expr, Expr::Break(ExprBreak { label: None, .. }))
135250
}
136-
Expr::Continue(_) => {
137-
*expr = syn::parse_quote! {
138-
compile_error!(
139-
"continue is not allowed at the top level of a switch!/goto_block! arm"
140-
)
141-
};
142-
return;
251+
ControlKind::Continue => {
252+
matches!(expr, Expr::Continue(ExprContinue { label: None, .. }))
143253
}
144-
_ => {}
254+
};
255+
if hit {
256+
self.found = true;
257+
let flag = &self.flag;
258+
let lbl = &self.label;
259+
*expr = syn::parse_quote! {
260+
{ #flag = true; break #lbl; }
261+
};
262+
return;
145263
}
146264
visit_mut::visit_expr_mut(self, expr);
147265
}

0 commit comments

Comments
 (0)