|
1 | 1 | // Copyright (c) 2022-present INESC-ID. |
2 | 2 | // Distributed under the MIT license that can be found in the LICENSE file. |
3 | 3 |
|
4 | | -use proc_macro2::TokenStream as TokenStream2; |
| 4 | +use proc_macro2::{Ident, Span, TokenStream as TokenStream2}; |
5 | 5 | use quote::{format_ident, quote}; |
6 | 6 | use syn::visit_mut::{self, VisitMut}; |
7 | | -use syn::{Expr, Lifetime, Pat}; |
| 7 | +use syn::{Expr, ExprBreak, ExprContinue, Lifetime, Pat}; |
8 | 8 |
|
9 | 9 | pub struct Arm { |
10 | 10 | pub label: String, |
11 | | - pub entry: ArmEntry, |
12 | 11 | pub body: Expr, |
13 | 12 | } |
14 | 13 |
|
15 | | -pub enum ArmEntry { |
16 | | - Dispatch { pat: Pat, guard: Option<Expr> }, |
17 | | - LabelOnly, |
| 14 | +pub struct DispatchCase { |
| 15 | + pub pat: Pat, |
| 16 | + pub guard: Option<Expr>, |
| 17 | + pub target: String, |
18 | 18 | } |
19 | 19 |
|
20 | | -pub fn emit_state_machine(condition: Option<Expr>, arms: Vec<Arm>) -> TokenStream2 { |
21 | | - let lbl = Lifetime::new("'__sm", proc_macro2::Span::call_site()); |
22 | | - let s = format_ident!("__s"); |
| 20 | +pub trait StateMachine { |
| 21 | + fn emit(self) -> TokenStream2; |
| 22 | +} |
23 | 23 |
|
24 | | - let base = if condition.is_some() { 1u32 } else { 0u32 }; |
25 | | - let rewrite_switch_break = condition.is_some(); |
| 24 | +fn sm_label() -> Lifetime { |
| 25 | + Lifetime::new("'__sm", Span::call_site()) |
| 26 | +} |
26 | 27 |
|
27 | | - let dispatch_arm = condition.map(|scrut| { |
28 | | - let case_arms = arms |
29 | | - .iter() |
30 | | - .enumerate() |
31 | | - .filter_map(|(i, arm)| match &arm.entry { |
32 | | - ArmEntry::Dispatch { pat, guard } => { |
33 | | - let idx = base + i as u32; |
34 | | - let guard = guard.as_ref().map(|g| quote! { if #g }); |
35 | | - Some(quote! { #pat #guard => { #s = #idx; continue #lbl; } }) |
36 | | - } |
37 | | - ArmEntry::LabelOnly => None, |
38 | | - }); |
39 | | - quote! { |
40 | | - 0u32 => { |
41 | | - #[allow(unreachable_patterns)] |
42 | | - match #scrut { |
43 | | - #(#case_arms,)* |
44 | | - _ => break #lbl, |
45 | | - } |
46 | | - } |
47 | | - } |
48 | | - }); |
| 28 | +// Collection of labeled arms that fall-through by default |
| 29 | +pub struct GotoStateMachine { |
| 30 | + pub arms: Vec<Arm>, |
| 31 | +} |
49 | 32 |
|
50 | | - let n = arms.len(); |
51 | | - let body_arms = arms.iter().enumerate().map(|(i, arm)| { |
52 | | - let idx = base + i as u32; |
53 | | - let body = rewrite_body(&arm.body, &lbl, rewrite_switch_break); |
54 | | - let tail = if i + 1 < n { |
55 | | - let next = idx + 1; |
56 | | - quote! { #s = #next; continue #lbl; } |
| 33 | +impl GotoStateMachine { |
| 34 | + // Rewrites unlabeled break / continue to { flag = true; break '__sm; } |
| 35 | + fn propagate_rewrite( |
| 36 | + body: &mut Expr, |
| 37 | + label: &Lifetime, |
| 38 | + break_flag: &Ident, |
| 39 | + cont_flag: &Ident, |
| 40 | + ) -> (bool, bool) { |
| 41 | + let mut br = PropagateRewriter::for_break(label.clone(), break_flag.clone()); |
| 42 | + br.visit_expr_mut(body); |
| 43 | + let mut cr = PropagateRewriter::for_continue(label.clone(), cont_flag.clone()); |
| 44 | + cr.visit_expr_mut(body); |
| 45 | + (br.found, cr.found) |
| 46 | + } |
| 47 | + |
| 48 | + // idx => { body; tail } |
| 49 | + fn emit_body_arm( |
| 50 | + idx: u32, |
| 51 | + body: &Expr, |
| 52 | + is_last: bool, |
| 53 | + label: &Lifetime, |
| 54 | + state: &Ident, |
| 55 | + ) -> TokenStream2 { |
| 56 | + let tail = if is_last { |
| 57 | + quote! { break #label; } |
57 | 58 | } else { |
58 | | - quote! { break #lbl; } |
| 59 | + let next = idx + 1; |
| 60 | + quote! { #state = #next; continue #label; } |
59 | 61 | }; |
60 | 62 | quote! { |
61 | 63 | #idx => { |
62 | 64 | #[allow(unreachable_code)] |
63 | 65 | { #body; #tail } |
64 | 66 | } |
65 | 67 | } |
66 | | - }); |
67 | | - |
68 | | - quote! {{ |
69 | | - let mut #s: u32 = 0; |
70 | | - #[allow(unreachable_code, unused_labels)] |
71 | | - #lbl: loop { |
72 | | - match #s { |
73 | | - #dispatch_arm |
74 | | - #(#body_arms)* |
75 | | - _ => break #lbl, |
76 | | - } |
| 68 | + } |
| 69 | + |
| 70 | + fn bailout(any: bool, flag: &Ident, stmt: TokenStream2) -> (TokenStream2, TokenStream2) { |
| 71 | + if !any { |
| 72 | + return (TokenStream2::new(), TokenStream2::new()); |
77 | 73 | } |
78 | | - }} |
| 74 | + ( |
| 75 | + quote! { let mut #flag: bool = false; }, |
| 76 | + quote! { |
| 77 | + #[allow(unreachable_code)] |
| 78 | + if #flag { #stmt } |
| 79 | + }, |
| 80 | + ) |
| 81 | + } |
79 | 82 | } |
80 | 83 |
|
81 | | -fn rewrite_body(body: &Expr, label: &Lifetime, rewrite_switch_break: bool) -> TokenStream2 { |
82 | | - let mut body = body.clone(); |
83 | | - LoopControlForbidden.visit_expr_mut(&mut body); |
84 | | - if rewrite_switch_break { |
85 | | - SwitchBreakRewriter { |
86 | | - label: label.clone(), |
87 | | - } |
88 | | - .visit_expr_mut(&mut body); |
| 84 | +impl StateMachine for GotoStateMachine { |
| 85 | + fn emit(self) -> TokenStream2 { |
| 86 | + let lbl = sm_label(); |
| 87 | + let s = format_ident!("__s"); |
| 88 | + let break_flag = format_ident!("__user_break"); |
| 89 | + let cont_flag = format_ident!("__user_continue"); |
| 90 | + |
| 91 | + let n = self.arms.len(); |
| 92 | + let mut any_break = false; |
| 93 | + let mut any_continue = false; |
| 94 | + let body_arms: Vec<_> = self |
| 95 | + .arms |
| 96 | + .iter() |
| 97 | + .enumerate() |
| 98 | + .map(|(i, arm)| { |
| 99 | + let mut body = arm.body.clone(); |
| 100 | + let (had_br, had_cn) = |
| 101 | + Self::propagate_rewrite(&mut body, &lbl, &break_flag, &cont_flag); |
| 102 | + any_break |= had_br; |
| 103 | + any_continue |= had_cn; |
| 104 | + Self::emit_body_arm(i as u32, &body, i + 1 == n, &lbl, &s) |
| 105 | + }) |
| 106 | + .collect(); |
| 107 | + |
| 108 | + let (brk_decl, brk_bailout) = Self::bailout(any_break, &break_flag, quote! { break; }); |
| 109 | + let (cnt_decl, cnt_bailout) = Self::bailout(any_continue, &cont_flag, quote! { continue; }); |
| 110 | + |
| 111 | + quote! {{ |
| 112 | + #brk_decl |
| 113 | + #cnt_decl |
| 114 | + let mut #s: u32 = 0; |
| 115 | + #[allow(unreachable_code, unused_labels)] |
| 116 | + #lbl: loop { |
| 117 | + match #s { |
| 118 | + #(#body_arms)* |
| 119 | + _ => break #lbl, |
| 120 | + } |
| 121 | + } |
| 122 | + #brk_bailout |
| 123 | + #cnt_bailout |
| 124 | + }} |
89 | 125 | } |
90 | | - quote! { #body } |
91 | 126 | } |
92 | 127 |
|
93 | | -// Rewrites top-level switch_break!() into break '__sm. switch_break!() inside a loop is a compile |
94 | | -// error. |
95 | | -struct SwitchBreakRewriter { |
96 | | - label: Lifetime, |
| 128 | +// GotoStateMachine(dispatch arm + cases) |
| 129 | +pub struct SwitchStateMachine { |
| 130 | + pub goto: GotoStateMachine, |
| 131 | + pub condition: Expr, |
| 132 | + pub cases: Vec<DispatchCase>, |
97 | 133 | } |
98 | 134 |
|
99 | | -impl SwitchBreakRewriter { |
100 | | - fn replacement(&self) -> Expr { |
101 | | - let lbl = &self.label; |
102 | | - syn::parse_quote! { break #lbl } |
| 135 | +impl SwitchStateMachine { |
| 136 | + // Rewrite break into break '__sm |
| 137 | + fn convert_break_to_switch_exit(arms: &Vec<Arm>, label: &Lifetime) -> Vec<Arm> { |
| 138 | + arms.into_iter() |
| 139 | + .map(|a| { |
| 140 | + let mut body = a.body.clone(); |
| 141 | + ExitSwitchRewriter { |
| 142 | + label: label.clone(), |
| 143 | + } |
| 144 | + .visit_expr_mut(&mut body); |
| 145 | + Arm { |
| 146 | + label: a.label.clone(), |
| 147 | + body, |
| 148 | + } |
| 149 | + }) |
| 150 | + .collect() |
103 | 151 | } |
104 | | -} |
105 | 152 |
|
106 | | -impl VisitMut for SwitchBreakRewriter { |
107 | | - fn visit_stmt_mut(&mut self, stmt: &mut syn::Stmt) { |
108 | | - if let syn::Stmt::Macro(sm) = stmt { |
109 | | - if sm.mac.path.is_ident("switch_break") { |
110 | | - *stmt = syn::Stmt::Expr(self.replacement(), sm.semi_token); |
111 | | - return; |
| 153 | + fn build_dispatch_arm(&self, user_arms: &[Arm], label: &Lifetime, state: &Ident) -> Arm { |
| 154 | + let cond = &self.condition; |
| 155 | + let case_arms = self.cases.iter().map(|c| { |
| 156 | + let target_pos = user_arms |
| 157 | + .iter() |
| 158 | + .position(|a| a.label == c.target) |
| 159 | + .expect("dispatch target must reference an arm label"); |
| 160 | + let idx = (target_pos as u32) + 1; |
| 161 | + let pat = &c.pat; |
| 162 | + let guard = c.guard.as_ref().map(|g| quote! { if #g }); |
| 163 | + quote! { #pat #guard => { #state = #idx; continue #label; } } |
| 164 | + }); |
| 165 | + let body: Expr = syn::parse_quote! { |
| 166 | + { |
| 167 | + #[allow(unreachable_patterns)] |
| 168 | + match #cond { |
| 169 | + #(#case_arms,)* |
| 170 | + _ => break #label, |
| 171 | + } |
112 | 172 | } |
| 173 | + }; |
| 174 | + Arm { |
| 175 | + label: "__dispatch".into(), |
| 176 | + body, |
| 177 | + } |
| 178 | + } |
| 179 | +} |
| 180 | + |
| 181 | +impl StateMachine for SwitchStateMachine { |
| 182 | + fn emit(self) -> TokenStream2 { |
| 183 | + let lbl = sm_label(); |
| 184 | + let s = format_ident!("__s"); |
| 185 | + |
| 186 | + let user_arms = Self::convert_break_to_switch_exit(&self.goto.arms, &lbl); |
| 187 | + let dispatch = self.build_dispatch_arm(&user_arms, &lbl, &s); |
| 188 | + |
| 189 | + let mut arms = Vec::new(); |
| 190 | + arms.push(dispatch); |
| 191 | + arms.extend(user_arms); |
| 192 | + |
| 193 | + GotoStateMachine { arms }.emit() |
| 194 | + } |
| 195 | +} |
| 196 | + |
| 197 | +// Rewrite break into break '__sm |
| 198 | +struct ExitSwitchRewriter { |
| 199 | + label: Lifetime, |
| 200 | +} |
| 201 | + |
| 202 | +impl VisitMut for ExitSwitchRewriter { |
| 203 | + fn visit_expr_break_mut(&mut self, node: &mut ExprBreak) { |
| 204 | + if node.label.is_none() { |
| 205 | + node.label = Some(self.label.clone()); |
113 | 206 | } |
114 | | - visit_mut::visit_stmt_mut(self, stmt); |
115 | 207 | } |
116 | 208 | fn visit_expr_loop_mut(&mut self, _: &mut syn::ExprLoop) {} |
117 | 209 | fn visit_expr_while_mut(&mut self, _: &mut syn::ExprWhile) {} |
118 | 210 | fn visit_expr_for_loop_mut(&mut self, _: &mut syn::ExprForLoop) {} |
119 | 211 | } |
120 | 212 |
|
121 | | -// Forbid user-written break/continue. We want to hide the fact that switch! and goto_block! |
122 | | -// are loops behind the scenes. |
123 | | -struct LoopControlForbidden; |
| 213 | +enum ControlKind { |
| 214 | + Break, |
| 215 | + Continue, |
| 216 | +} |
| 217 | + |
| 218 | +// Rewrites unlabeled break / continue to { flag = true; break '__sm; } |
| 219 | +struct PropagateRewriter { |
| 220 | + label: Lifetime, |
| 221 | + flag: Ident, |
| 222 | + kind: ControlKind, |
| 223 | + found: bool, |
| 224 | +} |
124 | 225 |
|
125 | | -impl VisitMut for LoopControlForbidden { |
| 226 | +impl PropagateRewriter { |
| 227 | + fn for_break(label: Lifetime, flag: Ident) -> Self { |
| 228 | + Self { |
| 229 | + label, |
| 230 | + flag, |
| 231 | + kind: ControlKind::Break, |
| 232 | + found: false, |
| 233 | + } |
| 234 | + } |
| 235 | + fn for_continue(label: Lifetime, flag: Ident) -> Self { |
| 236 | + Self { |
| 237 | + label, |
| 238 | + flag, |
| 239 | + kind: ControlKind::Continue, |
| 240 | + found: false, |
| 241 | + } |
| 242 | + } |
| 243 | +} |
| 244 | + |
| 245 | +impl VisitMut for PropagateRewriter { |
126 | 246 | fn visit_expr_mut(&mut self, expr: &mut Expr) { |
127 | | - match expr { |
128 | | - Expr::Break(_) => { |
129 | | - *expr = syn::parse_quote! { |
130 | | - compile_error!( |
131 | | - "break is not allowed at the top level of a switch!/goto_block! arm. Inside a switch! use switch_break!()", |
132 | | - ) |
133 | | - }; |
134 | | - return; |
| 247 | + let hit = match self.kind { |
| 248 | + ControlKind::Break => { |
| 249 | + matches!(expr, Expr::Break(ExprBreak { label: None, .. })) |
135 | 250 | } |
136 | | - Expr::Continue(_) => { |
137 | | - *expr = syn::parse_quote! { |
138 | | - compile_error!( |
139 | | - "continue is not allowed at the top level of a switch!/goto_block! arm" |
140 | | - ) |
141 | | - }; |
142 | | - return; |
| 251 | + ControlKind::Continue => { |
| 252 | + matches!(expr, Expr::Continue(ExprContinue { label: None, .. })) |
143 | 253 | } |
144 | | - _ => {} |
| 254 | + }; |
| 255 | + if hit { |
| 256 | + self.found = true; |
| 257 | + let flag = &self.flag; |
| 258 | + let lbl = &self.label; |
| 259 | + *expr = syn::parse_quote! { |
| 260 | + { #flag = true; break #lbl; } |
| 261 | + }; |
| 262 | + return; |
145 | 263 | } |
146 | 264 | visit_mut::visit_expr_mut(self, expr); |
147 | 265 | } |
|
0 commit comments