1use crate::cst::{dialects::RUST, print_cst, CstNode};
9
10pub fn parse_rust(src: &str) -> CstNode {
12 let children = tokenise(src);
13 CstNode::list(format!("{}.source_file", RUST), children)
14}
15
16pub fn print_rust(node: &CstNode) -> String {
18 print_cst(node)
19}
20
21fn tokenise(src: &str) -> Vec<CstNode> {
22 let chars: Vec<char> = src.chars().collect();
23 let mut out: Vec<CstNode> = Vec::new();
24 let mut i = 0usize;
25
26 if chars.len() >= 2 && chars[0] == '#' && chars[1] == '!' && chars.get(2) != Some(&'[') {
27 let mut j = i;
28 while j < chars.len() && chars[j] != '\n' {
29 j += 1;
30 }
31 out.push(CstNode::trivia(
32 chars[i..j].iter().collect::<String>(),
33 Some(&format!("{}.shebang", RUST)),
34 ));
35 i = j;
36 }
37
38 while i < chars.len() {
39 let c = chars[i];
40
41 if c == ' ' || c == '\t' || c == '\r' || c == '\n' {
42 let mut j = i;
43 while j < chars.len()
44 && (chars[j] == ' ' || chars[j] == '\t' || chars[j] == '\r' || chars[j] == '\n')
45 {
46 j += 1;
47 }
48 out.push(CstNode::trivia(
49 chars[i..j].iter().collect::<String>(),
50 Some(&format!("{}.whitespace", RUST)),
51 ));
52 i = j;
53 continue;
54 }
55
56 if c == '/' && chars.get(i + 1) == Some(&'/') {
57 let mut j = i + 2;
58 while j < chars.len() && chars[j] != '\n' {
59 j += 1;
60 }
61 out.push(CstNode::trivia(
62 chars[i..j].iter().collect::<String>(),
63 Some(&format!("{}.comment.line", RUST)),
64 ));
65 i = j;
66 continue;
67 }
68
69 if c == '/' && chars.get(i + 1) == Some(&'*') {
70 let j = scan_block_comment(&chars, i);
71 out.push(CstNode::trivia(
72 chars[i..j].iter().collect::<String>(),
73 Some(&format!("{}.comment.block", RUST)),
74 ));
75 i = j;
76 continue;
77 }
78
79 if c == '"' {
80 let j = scan_string(&chars, i + 1, '"');
81 out.push(CstNode::token(
82 chars[i..j].iter().collect::<String>(),
83 Some(&format!("{}.string_literal", RUST)),
84 ));
85 i = j;
86 continue;
87 }
88
89 if (c == 'b' || c == 'r')
91 && (chars.get(i + 1) == Some(&'"')
92 || (c == 'r' && chars.get(i + 1) == Some(&'#'))
93 || (c == 'b'
94 && chars.get(i + 1) == Some(&'r')
95 && (chars.get(i + 2) == Some(&'"') || chars.get(i + 2) == Some(&'#'))))
96 {
97 if let Some(j) = scan_raw_or_prefixed_string(&chars, i) {
98 out.push(CstNode::token(
99 chars[i..j].iter().collect::<String>(),
100 Some(&format!("{}.string_literal", RUST)),
101 ));
102 i = j;
103 continue;
104 }
105 }
106
107 if c == '\'' {
108 let lifetime_end = scan_lifetime(&chars, i);
109 if lifetime_end > i + 1 {
110 out.push(CstNode::token(
111 chars[i..lifetime_end].iter().collect::<String>(),
112 Some(&format!("{}.lifetime", RUST)),
113 ));
114 i = lifetime_end;
115 continue;
116 }
117 let j = scan_string(&chars, i + 1, '\'');
118 out.push(CstNode::token(
119 chars[i..j].iter().collect::<String>(),
120 Some(&format!("{}.char_literal", RUST)),
121 ));
122 i = j;
123 continue;
124 }
125
126 if c == 'b' && chars.get(i + 1) == Some(&'\'') {
127 let j = scan_string(&chars, i + 2, '\'');
128 out.push(CstNode::token(
129 chars[i..j].iter().collect::<String>(),
130 Some(&format!("{}.byte_literal", RUST)),
131 ));
132 i = j;
133 continue;
134 }
135
136 if c.is_ascii_digit() {
137 let j = scan_number(&chars, i);
138 out.push(CstNode::token(
139 chars[i..j].iter().collect::<String>(),
140 Some(&format!("{}.numeric_literal", RUST)),
141 ));
142 i = j;
143 continue;
144 }
145
146 if c == 'r'
147 && chars.get(i + 1) == Some(&'#')
148 && chars
149 .get(i + 2)
150 .map(|c| is_ident_start(*c))
151 .unwrap_or(false)
152 {
153 let mut j = i + 2;
154 while j < chars.len() && is_ident_continue(chars[j]) {
155 j += 1;
156 }
157 out.push(CstNode::token(
158 chars[i..j].iter().collect::<String>(),
159 Some(&format!("{}.raw_ident", RUST)),
160 ));
161 i = j;
162 continue;
163 }
164
165 if is_ident_start(c) {
166 let mut j = i + 1;
167 while j < chars.len() && is_ident_continue(chars[j]) {
168 j += 1;
169 }
170 out.push(CstNode::token(
171 chars[i..j].iter().collect::<String>(),
172 Some(&format!("{}.ident", RUST)),
173 ));
174 i = j;
175 continue;
176 }
177
178 out.push(CstNode::token(
179 c.to_string(),
180 Some(&format!("{}.punct", RUST)),
181 ));
182 i += 1;
183 }
184
185 out
186}
187
/// Finds the end of a (possibly nested) block comment whose `/*` opener is at
/// `chars[i]`.
///
/// Each inner `/*` raises the nesting level and each `*/` lowers it; the scan
/// stops once the level returns to zero. Returns the index just past the
/// matching `*/`, or `chars.len()` if the comment is never closed.
fn scan_block_comment(chars: &[char], i: usize) -> usize {
    let mut pos = i + 2;
    let mut nesting = 1usize;
    while nesting != 0 && pos < chars.len() {
        match (chars[pos], chars.get(pos + 1).copied()) {
            ('/', Some('*')) => {
                nesting += 1;
                pos += 2;
            }
            ('*', Some('/')) => {
                nesting -= 1;
                pos += 2;
            }
            _ => pos += 1,
        }
    }
    pos
}
204
/// Scans a quoted literal body starting at `j` (just past the opening quote)
/// and returns the index one past the closing `quote`.
///
/// A backslash escapes the character after it, so `\"` or `\'` never closes
/// the literal. If the input ends before the closing quote, returns
/// `chars.len()`. This also holds when the input ends in a dangling backslash,
/// so callers can always slice `chars[start..end]` safely.
fn scan_string(chars: &[char], mut j: usize, quote: char) -> usize {
    while let Some(&c) = chars.get(j) {
        if c == '\\' {
            // Skip the escaped char too; this may step past the end (clamped below).
            j += 2;
        } else if c == quote {
            return j + 1;
        } else {
            j += 1;
        }
    }
    // After a trailing backslash `j` is `len + 1`; clamp so slicing stays in bounds.
    j.min(chars.len())
}
219
220fn scan_raw_or_prefixed_string(chars: &[char], i: usize) -> Option<usize> {
221 let mut j = i;
222 if chars.get(j) == Some(&'b') {
223 j += 1;
224 }
225 if chars.get(j) == Some(&'r') {
226 j += 1;
227 let mut hashes = 0;
228 while chars.get(j) == Some(&'#') {
229 hashes += 1;
230 j += 1;
231 }
232 if chars.get(j) != Some(&'"') {
233 return None;
234 }
235 j += 1;
236 let terminator: String =
237 std::iter::once('"').chain(std::iter::repeat('#').take(hashes)).collect();
238 let rest: String = chars[j..].iter().collect();
240 match rest.find(&terminator) {
241 Some(rel) => Some(j + rel + terminator.chars().count()),
242 None => Some(chars.len()),
243 }
244 } else if chars.get(j) == Some(&'"') {
245 Some(scan_string(chars, j + 1, '"'))
246 } else {
247 None
248 }
249}
250
251fn scan_lifetime(chars: &[char], i: usize) -> usize {
252 let mut j = i + 1;
253 if j < chars.len() && is_ident_start(chars[j]) {
254 j += 1;
255 while j < chars.len() && is_ident_continue(chars[j]) {
256 j += 1;
257 }
258 if chars.get(j) == Some(&'\'') {
259 return i;
260 }
261 return j;
262 }
263 i
264}
265
266fn scan_number(chars: &[char], i: usize) -> usize {
267 let mut j = i;
268 if chars.get(j) == Some(&'0') && matches!(chars.get(j + 1), Some('x') | Some('X')) {
269 j += 2;
270 while j < chars.len() && (chars[j].is_ascii_hexdigit() || chars[j] == '_') {
271 j += 1;
272 }
273 } else if chars.get(j) == Some(&'0') && matches!(chars.get(j + 1), Some('o') | Some('O')) {
274 j += 2;
275 while j < chars.len() && matches!(chars[j], '0'..='7' | '_') {
276 j += 1;
277 }
278 } else if chars.get(j) == Some(&'0') && matches!(chars.get(j + 1), Some('b') | Some('B')) {
279 j += 2;
280 while j < chars.len() && matches!(chars[j], '0' | '1' | '_') {
281 j += 1;
282 }
283 } else {
284 while j < chars.len() && (chars[j].is_ascii_digit() || chars[j] == '_') {
285 j += 1;
286 }
287 if chars.get(j) == Some(&'.')
288 && chars.get(j + 1).map(|c| c.is_ascii_digit()).unwrap_or(false)
289 {
290 j += 1;
291 while j < chars.len() && (chars[j].is_ascii_digit() || chars[j] == '_') {
292 j += 1;
293 }
294 }
295 if matches!(chars.get(j), Some('e') | Some('E')) {
296 j += 1;
297 if matches!(chars.get(j), Some('+') | Some('-')) {
298 j += 1;
299 }
300 while j < chars.len() && (chars[j].is_ascii_digit() || chars[j] == '_') {
301 j += 1;
302 }
303 }
304 }
305 if j < chars.len() && is_ident_start(chars[j]) {
306 while j < chars.len() && is_ident_continue(chars[j]) {
307 j += 1;
308 }
309 }
310 j
311}
312
/// Whether `c` may begin an identifier: `_`, an ASCII letter, or any
/// non-ASCII char. The non-ASCII rule is a permissive superset of the
/// `XID_Start` set rustc actually requires.
fn is_ident_start(c: char) -> bool {
    match c {
        '_' | 'a'..='z' | 'A'..='Z' => true,
        other => !other.is_ascii(),
    }
}
316
/// Whether `c` may appear after the first char of an identifier: `_`, an
/// ASCII letter or digit, or any non-ASCII char (a permissive superset of
/// `XID_Continue`).
fn is_ident_continue(c: char) -> bool {
    match c {
        '_' | 'a'..='z' | 'A'..='Z' | '0'..='9' => true,
        other => !other.is_ascii(),
    }
}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src`, prints it back, and asserts the text is unchanged.
    fn rt(src: &str) {
        let node = parse_rust(src);
        let back = print_rust(&node);
        assert_eq!(back, src, "round-trip mismatch");
    }

    #[test]
    fn empty_string() {
        rt("");
    }

    #[test]
    fn simple_fn() {
        rt("fn main() {}\n");
    }

    #[test]
    fn line_and_block_comments() {
        rt("// hi\nfn f() {\n /* mid */ 1\n}\n");
    }

    #[test]
    fn raw_and_byte_strings() {
        rt("let s = r#\"raw\"#;\nlet b = b\"abc\";\n");
    }

    #[test]
    fn lifetime_vs_char() {
        rt("let c = 'a';\nlet lt: &'static str = \"x\";\n");
    }

    #[test]
    fn numeric_with_suffix() {
        rt("let n = 0xFF_FFu32;\nlet f = 3.14e10_f64;\n");
    }

    #[test]
    fn raw_ident() {
        rt("let r#match = 1;\n");
    }

    #[test]
    fn shebang_and_inner_attributes() {
        rt("#!/usr/bin/env rust-script\nfn main() {}\n");
        rt("#![allow(dead_code)]\nfn f() {}\n");
        rt("#! [allow(dead_code)]\nfn f() {}\n");
    }

    #[test]
    fn nested_block_comment() {
        rt("/* outer /* inner */ still outer */ fn f() {}\n");
    }

    #[test]
    fn ranges_fields_and_floats() {
        rt("for i in 0..10 { t.0.max(1); let f = 1.; let g = 2.5E-3; }\n");
    }

    #[test]
    fn byte_char_and_hashed_raw_strings() {
        rt("let b = b'x'; let e = '\\n'; let q = '\\''; let s = br##\"a\"#b\"##;\n");
    }

    #[test]
    fn labels_and_lifetimes() {
        rt("'outer: loop { break 'outer; }\nfn f<'a>(x: &'a str) {}\n");
    }

    #[test]
    fn non_ascii_identifiers_and_strings() {
        rt("let café = \"ü\";\n");
    }

    #[test]
    fn unusual_whitespace() {
        rt("let a = 1;\u{000C}\u{000B}let b = 2;\r\n");
    }

    #[test]
    fn unterminated_input_is_preserved() {
        rt("let s = \"never closed");
        rt("/* never closed");
    }
}