Skip to main content

rml/
cst_lean.rs

//! Lean 4 ↔ `.lino` CST converter (issue #138).
//!
//! Token-level lossless converter for Lean 4 source. Produces a
//! `lino-cst.lean.*` flat CST whose round-trip is byte-faithful:
//! `print_lean(&parse_lean(src)) == src`. Mirrors `js/src/cst-lean.mjs`
//! line for line.
7
8use crate::cst::{dialects::LEAN, print_cst, CstNode};
9
10/// Parse Lean 4 source into a `lino-cst.lean.*` CST.
11pub fn parse_lean(src: &str) -> CstNode {
12    let children = tokenise(src);
13    CstNode::list(format!("{}.module", LEAN), children)
14}
15
16/// Print a Lean CST back to source.
17pub fn print_lean(node: &CstNode) -> String {
18    print_cst(node)
19}
20
21fn tokenise(src: &str) -> Vec<CstNode> {
22    let chars: Vec<char> = src.chars().collect();
23    let mut out: Vec<CstNode> = Vec::new();
24    let mut i = 0usize;
25
26    while i < chars.len() {
27        let c = chars[i];
28
29        if c == ' ' || c == '\t' || c == '\r' || c == '\n' {
30            let mut j = i;
31            while j < chars.len()
32                && (chars[j] == ' ' || chars[j] == '\t' || chars[j] == '\r' || chars[j] == '\n')
33            {
34                j += 1;
35            }
36            out.push(CstNode::trivia(
37                chars[i..j].iter().collect::<String>(),
38                Some(&format!("{}.whitespace", LEAN)),
39            ));
40            i = j;
41            continue;
42        }
43
44        if c == '-' && chars.get(i + 1) == Some(&'-') {
45            let mut j = i + 2;
46            while j < chars.len() && chars[j] != '\n' {
47                j += 1;
48            }
49            out.push(CstNode::trivia(
50                chars[i..j].iter().collect::<String>(),
51                Some(&format!("{}.comment.line", LEAN)),
52            ));
53            i = j;
54            continue;
55        }
56
57        if c == '/' && chars.get(i + 1) == Some(&'-') {
58            let j = scan_block_comment(&chars, i);
59            let tag = if chars.get(i + 2) == Some(&'-') {
60                format!("{}.doc.block", LEAN)
61            } else {
62                format!("{}.comment.block", LEAN)
63            };
64            out.push(CstNode::trivia(
65                chars[i..j].iter().collect::<String>(),
66                Some(&tag),
67            ));
68            i = j;
69            continue;
70        }
71
72        if c == '"' {
73            let j = scan_string(&chars, i + 1, '"');
74            out.push(CstNode::token(
75                chars[i..j].iter().collect::<String>(),
76                Some(&format!("{}.string_literal", LEAN)),
77            ));
78            i = j;
79            continue;
80        }
81
82        if c == 'r' && chars.get(i + 1) == Some(&'"') {
83            let j = scan_string(&chars, i + 2, '"');
84            out.push(CstNode::token(
85                chars[i..j].iter().collect::<String>(),
86                Some(&format!("{}.raw_string_literal", LEAN)),
87            ));
88            i = j;
89            continue;
90        }
91
92        if c == '\'' {
93            let mut j = i + 1;
94            if chars.get(j) == Some(&'\\') {
95                j += 2;
96            } else {
97                j += 1;
98            }
99            if chars.get(j) == Some(&'\'') {
100                j += 1;
101                out.push(CstNode::token(
102                    chars[i..j].iter().collect::<String>(),
103                    Some(&format!("{}.char_literal", LEAN)),
104                ));
105                i = j;
106                continue;
107            }
108            // Not a valid char literal; fall through to punctuation.
109        }
110
111        if c.is_ascii_digit() {
112            let j = scan_number(&chars, i);
113            out.push(CstNode::token(
114                chars[i..j].iter().collect::<String>(),
115                Some(&format!("{}.numeric_literal", LEAN)),
116            ));
117            i = j;
118            continue;
119        }
120
121        if is_ident_start(c) {
122            let mut j = i + 1;
123            while j < chars.len() && is_ident_continue(chars[j]) {
124                j += 1;
125            }
126            // Dotted hierarchical name: `Nat.succ`, `List.foldr`.
127            while chars.get(j) == Some(&'.')
128                && chars
129                    .get(j + 1)
130                    .map(|c| is_ident_start(*c))
131                    .unwrap_or(false)
132            {
133                j += 1;
134                while j < chars.len() && is_ident_continue(chars[j]) {
135                    j += 1;
136                }
137            }
138            out.push(CstNode::token(
139                chars[i..j].iter().collect::<String>(),
140                Some(&format!("{}.ident", LEAN)),
141            ));
142            i = j;
143            continue;
144        }
145
146        // Multi-byte / other punctuation: emit one codepoint.
147        out.push(CstNode::token(
148            c.to_string(),
149            Some(&format!("{}.punct", LEAN)),
150        ));
151        i += 1;
152    }
153
154    out
155}
156
/// Find the end of a (possibly nested) block comment.
///
/// `i` is the index of the opening `/-`. Scanning starts just after it and
/// tracks nesting: every further `/-` opens a level and every `-/` closes one.
/// Returns the index one past the `-/` that closes the outermost level, or the
/// index where scanning ran out of input if the comment is unterminated.
fn scan_block_comment(chars: &[char], i: usize) -> usize {
    let mut cursor = i + 2;
    let mut open = 1;
    while open > 0 && cursor < chars.len() {
        match (chars[cursor], chars.get(cursor + 1)) {
            ('/', Some('-')) => {
                open += 1;
                cursor += 2;
            }
            ('-', Some('/')) => {
                open -= 1;
                cursor += 2;
            }
            _ => cursor += 1,
        }
    }
    cursor
}
173
/// Find the end of a quoted literal.
///
/// `j` is the index of the first char after the opening quote and `quote` is
/// the closing delimiter to look for. A backslash escapes the char that
/// follows it, so an escaped quote does not end the literal.
///
/// Returns the index one past the closing quote. For an unterminated literal
/// the result is clamped to `chars.len()`, so it is always a valid exclusive
/// slice bound even when the input ends in a lone backslash.
fn scan_string(chars: &[char], mut j: usize, quote: char) -> usize {
    while j < chars.len() {
        match chars[j] {
            // Skip the backslash together with the char it escapes.
            '\\' => j += 2,
            c if c == quote => return j + 1,
            _ => j += 1,
        }
    }
    // A trailing backslash advances `j` one past the end; without the clamp
    // the caller's `chars[i..j]` slice would panic.
    j.min(chars.len())
}
188
/// Find the end of a numeric literal starting at index `i`.
///
/// Recognised shapes:
/// - `0x`/`0X`, `0o`/`0O`, `0b`/`0B` followed by any number of hex, octal or
///   binary digits respectively (possibly none);
/// - a run of decimal digits, optionally followed by `.` and at least one
///   digit, optionally followed by `e`/`E`, an optional sign and digits.
///
/// An exponent is only considered after a fractional part. Returns the index
/// one past the last char of the literal.
fn scan_number(chars: &[char], i: usize) -> usize {
    /// Index of the first char at or after `from` rejected by `accept`.
    fn run_end(chars: &[char], from: usize, accept: impl Fn(char) -> bool) -> usize {
        (from..chars.len())
            .find(|&k| !accept(chars[k]))
            .unwrap_or_else(|| from.max(chars.len()))
    }
    let is_dec = |c: char| c.is_ascii_digit();

    // Radix-prefixed literals.
    if chars.get(i) == Some(&'0') {
        match chars.get(i + 1) {
            Some('x') | Some('X') => return run_end(chars, i + 2, |c| c.is_ascii_hexdigit()),
            Some('o') | Some('O') => return run_end(chars, i + 2, |c| matches!(c, '0'..='7')),
            Some('b') | Some('B') => return run_end(chars, i + 2, |c| matches!(c, '0' | '1')),
            _ => {}
        }
    }

    // Integer part.
    let int_end = run_end(chars, i, is_dec);

    // A `.` only belongs to the number when a digit follows it.
    let has_fraction = chars.get(int_end) == Some(&'.')
        && chars.get(int_end + 1).map_or(false, |c| c.is_ascii_digit());
    if !has_fraction {
        return int_end;
    }
    let frac_end = run_end(chars, int_end + 1, is_dec);

    // Optional exponent: marker, optional sign, then digits.
    if !matches!(chars.get(frac_end), Some('e') | Some('E')) {
        return frac_end;
    }
    let mut exp_start = frac_end + 1;
    if matches!(chars.get(exp_start), Some('+') | Some('-')) {
        exp_start += 1;
    }
    run_end(chars, exp_start, is_dec)
}
234
235fn is_ident_start(c: char) -> bool {
236    if c == '_' || c.is_ascii_alphabetic() {
237        return true;
238    }
239    if (c as u32) > 0x7F {
240        return !is_lean_punct_char(c);
241    }
242    false
243}
244
245fn is_ident_continue(c: char) -> bool {
246    if c == '_' || c == '\'' || c == '!' || c == '?' || c.is_ascii_alphanumeric() {
247        return true;
248    }
249    if (c as u32) > 0x7F {
250        return !is_lean_punct_char(c);
251    }
252    false
253}
254
/// Whether `c` is one of the non-ASCII symbols that are kept out of
/// identifiers: arrows, angle and semantic brackets, and guillemets.
fn is_lean_punct_char(c: char) -> bool {
    const LEAN_PUNCT: [char; 11] = ['→', '←', '↦', '⟨', '⟩', '⟦', '⟧', '«', '»', '‹', '›'];
    LEAN_PUNCT.contains(&c)
}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that parsing `src` and printing the result reproduces `src`
    /// exactly.
    fn rt(src: &str) {
        let node = parse_lean(src);
        let back = print_lean(&node);
        assert_eq!(back, src, "round-trip mismatch");
    }

    #[test]
    fn empty_string() {
        rt("");
    }

    #[test]
    fn simple_def() {
        rt("def f : Nat := 1\n");
    }

    #[test]
    fn line_comment() {
        rt("-- comment\ndef f : Nat := 1\n");
    }

    #[test]
    fn block_and_doc_comments() {
        rt("/- block -/\n/-- doc -/\n/-! module -/\n");
    }

    #[test]
    fn nested_block_comment() {
        rt("/- outer /- inner -/ still outer -/\ndef x := 1\n");
    }

    #[test]
    fn unicode_ident_and_arrow() {
        rt("def id {α : Type} (x : α) : α := x\n");
        // `→` is excluded from identifiers, so this input also exercises the
        // single-codepoint punctuation path for a multi-byte char.
        rt("def const {α β : Type} : α → β → α := fun a _ => a\n");
    }

    #[test]
    fn dotted_ident() {
        rt("#check Nat.succ\n");
    }

    #[test]
    fn char_literal() {
        rt("def c : Char := 'a'\n");
    }
}