Skip to main content

rml/
lean_export.rs

1//! Lean 4 exporter for the typed RML fragment (issue #60).
2//!
3//! The exporter translates declaration syntax into a Lean-checkable artifact
4//! and rejects probabilistic forms instead of assigning them a Lean meaning.
5
6use crate::{
7    compute_form_spans, is_num, key_of, parse_lino, parse_one, tokenize_one, Diagnostic, Node, Span,
8};
9use std::collections::{HashMap, HashSet};
10
/// Banner placed at the top of every generated Lean file.
///
/// `export_lean` joins these entries with newlines; the final empty entry
/// separates the banner from the declarations that follow it.
const HEADER: &[&str] = &[
    "-- Generated by RML Lean exporter.",
    "-- Supported subset: typed declarations, Pi/lambda/apply terms, and inductive declarations.",
    "",
];
16
17/// Structured result returned by [`export_lean`].
18#[derive(Debug, Clone)]
19pub struct LeanExportResult {
20    pub source: String,
21    pub diagnostics: Vec<Diagnostic>,
22}
23
24#[derive(Debug, Clone)]
25struct Binding {
26    param_name: String,
27    param_type: Node,
28}
29
30#[derive(Debug, Default)]
31struct ExportCtx {
32    types: HashMap<String, Node>,
33    lines: Vec<String>,
34    blocks: Vec<String>,
35}
36
37/// Convert an RML identifier to a Lean-safe identifier.
38pub fn lean_ident(raw: &str) -> String {
39    if raw == "_" {
40        return "_".to_string();
41    }
42    let mut out: String = raw
43        .chars()
44        .map(|c| {
45            if c.is_ascii_alphanumeric() || c == '_' {
46                c
47            } else {
48                '_'
49            }
50        })
51        .collect();
52    if out.is_empty() {
53        out = "rml".to_string();
54    }
55    if !out
56        .chars()
57        .next()
58        .map(|c| c.is_ascii_alphabetic() || c == '_')
59        .unwrap_or(false)
60    {
61        out = format!("rml_{}", out);
62    }
63    if reserved().contains(out.as_str()) {
64        out = format!("rml_{}", out);
65    }
66    out
67}
68
/// Words that `lean_ident` never emits verbatim; a sanitised identifier equal
/// to one of them is prefixed with `rml_` instead.
fn reserved() -> HashSet<&'static str> {
    const WORDS: &[&str] = &[
        "Type",
        "Prop",
        "Sort",
        "axiom",
        "def",
        "fun",
        "inductive",
        "where",
        "match",
        "with",
        "let",
        "in",
        "if",
        "then",
        "else",
        "forall",
        "by",
        "theorem",
        "example",
        "namespace",
        "open",
        "import",
        "true",
        "false",
    ];
    WORDS.iter().copied().collect()
}
99
/// Prefix operator names that the exporter refuses to translate, both as
/// application heads and as definition names.
fn probabilistic_heads() -> HashSet<&'static str> {
    HashSet::from([
        "range", "valence", "=", "!=", "and", "or", "not", "both", "neither",
    ])
}
107
108fn diagnostic(message: impl Into<String>, span: &Span) -> Diagnostic {
109    Diagnostic::new("E050", message, span.clone())
110}
111
112fn parse_forms(text: &str) -> Result<Vec<Node>, String> {
113    let mut out = Vec::new();
114    for link in parse_lino(text) {
115        let trimmed = link.trim();
116        if trimmed.starts_with("(# ") {
117            continue;
118        }
119        let toks = tokenize_one(&link);
120        out.push(parse_one(&toks)?);
121    }
122    Ok(out)
123}
124
125fn unwrap_form(mut form: Node) -> Node {
126    loop {
127        match form {
128            Node::List(children) if children.len() == 1 && matches!(children[0], Node::List(_)) => {
129                form = children[0].clone();
130            }
131            _ => return form,
132        }
133    }
134}
135
136fn leaf(node: &Node) -> Option<&str> {
137    match node {
138        Node::Leaf(s) => Some(s),
139        _ => None,
140    }
141}
142
143fn list(node: &Node) -> Option<&[Node]> {
144    match node {
145        Node::List(children) => Some(children),
146        _ => None,
147    }
148}
149
150fn parse_binding_node(binding: &Node) -> Option<Binding> {
151    let children = list(binding)?;
152    if children.len() != 2 {
153        return None;
154    }
155    if let Some(s) = leaf(&children[0]) {
156        if let Some(name) = s.strip_suffix(':') {
157            return Some(Binding {
158                param_name: name.to_string(),
159                param_type: children[1].clone(),
160            });
161        }
162    }
163    if let (Some(type_name), Some(var_name)) = (leaf(&children[0]), leaf(&children[1])) {
164        if type_name
165            .chars()
166            .next()
167            .map(|c| c.is_uppercase())
168            .unwrap_or(false)
169            && !var_name.ends_with(':')
170        {
171            return Some(Binding {
172                param_name: var_name.to_string(),
173                param_type: Node::Leaf(type_name.to_string()),
174            });
175        }
176    }
177    if matches!(children[0], Node::List(_)) {
178        if let Some(var_name) = leaf(&children[1]) {
179            if !var_name.ends_with(':') {
180                return Some(Binding {
181                    param_name: var_name.to_string(),
182                    param_type: children[0].clone(),
183                });
184            }
185        }
186    }
187    None
188}
189
190fn subst_node(node: &Node, name: &str, replacement: &Node) -> Node {
191    match node {
192        Node::Leaf(s) if s == name => replacement.clone(),
193        Node::Leaf(_) => node.clone(),
194        Node::List(children) => Node::List(
195            children
196                .iter()
197                .map(|child| subst_node(child, name, replacement))
198                .collect(),
199        ),
200    }
201}
202
203fn is_simple_atom(expr: &str) -> bool {
204    if is_num(expr) {
205        return true;
206    }
207    let mut chars = expr.chars();
208    let Some(first) = chars.next() else {
209        return false;
210    };
211    if !(first.is_ascii_alphabetic() || first == '_') {
212        return false;
213    }
214    chars.all(|c| c.is_ascii_alphanumeric() || c == '_')
215}
216
217fn wrap_lean(expr: &str) -> String {
218    if is_simple_atom(expr) || (expr.starts_with('(') && expr.ends_with(')')) {
219        expr.to_string()
220    } else {
221        format!("({})", expr)
222    }
223}
224
/// Return a copy of `scope` extended with the mapping `from -> to`; an
/// existing entry for `from` is overwritten. `scope` itself is left untouched.
fn with_scope(scope: &HashMap<String, String>, from: &str, to: &str) -> HashMap<String, String> {
    scope
        .iter()
        .map(|(key, value)| (key.clone(), value.clone()))
        .chain(std::iter::once((from.to_owned(), to.to_owned())))
        .collect()
}
230
231fn resolve_name(name: &str, scope: &HashMap<String, String>) -> String {
232    if let Some(mapped) = scope.get(name) {
233        return mapped.clone();
234    }
235    if name == "Type" || name == "Prop" {
236        return name.to_string();
237    }
238    lean_ident(name)
239}
240
241fn type_to_lean(
242    node: &Node,
243    ctx: &ExportCtx,
244    scope: &HashMap<String, String>,
245    span: &Span,
246) -> Result<String, String> {
247    match node {
248        Node::Leaf(s) => {
249            if is_num(s) {
250                return Err(format!(
251                    "Lean export cannot use numeric literal `{}` as a type",
252                    s
253                ));
254            }
255            Ok(resolve_name(s, scope))
256        }
257        Node::List(children) => {
258            if children.is_empty() {
259                return Err(format!(
260                    "Lean export cannot translate malformed type `{}`",
261                    key_of(node)
262                ));
263            }
264            if children.len() == 2 && leaf(&children[0]) == Some("Type") {
265                if let Some(level) = leaf(&children[1]) {
266                    if level.chars().all(|c| c.is_ascii_digit()) {
267                        return Ok(format!("Type {}", level));
268                    }
269                }
270                return Err(format!(
271                    "Lean export requires numeric Type levels, got `{}`",
272                    key_of(node)
273                ));
274            }
275            if children.len() == 1 && leaf(&children[0]) == Some("Prop") {
276                return Ok("Prop".to_string());
277            }
278            if children.len() == 3
279                && (leaf(&children[0]) == Some("Pi") || leaf(&children[0]) == Some("forall"))
280            {
281                let binding = parse_binding_node(&children[1]).ok_or_else(|| {
282                    format!(
283                        "Lean export cannot translate malformed binder in `{}`",
284                        key_of(node)
285                    )
286                })?;
287                let domain = type_to_lean(&binding.param_type, ctx, scope, span)?;
288                let param = lean_ident(&binding.param_name);
289                let next_scope = if binding.param_name == "_" {
290                    scope.clone()
291                } else {
292                    with_scope(scope, &binding.param_name, &param)
293                };
294                let codomain = type_to_lean(&children[2], ctx, &next_scope, span)?;
295                if binding.param_name == "_" {
296                    Ok(format!("{} -> {}", wrap_lean(&domain), codomain))
297                } else {
298                    Ok(format!("({} : {}) -> {}", param, domain, codomain))
299                }
300            } else {
301                term_to_lean(node, ctx, scope, span)
302            }
303        }
304    }
305}
306
307fn term_to_lean(
308    node: &Node,
309    ctx: &ExportCtx,
310    scope: &HashMap<String, String>,
311    span: &Span,
312) -> Result<String, String> {
313    match node {
314        Node::Leaf(s) => {
315            if is_num(s) {
316                Ok(s.clone())
317            } else {
318                Ok(resolve_name(s, scope))
319            }
320        }
321        Node::List(children) => {
322            if children.is_empty() {
323                return Err(format!(
324                    "Lean export cannot translate malformed term `{}`",
325                    key_of(node)
326                ));
327            }
328            if children.len() == 4
329                && leaf(&children[1]) == Some("has")
330                && leaf(&children[2]) == Some("probability")
331            {
332                return Err(
333                    "Lean export does not support probabilistic `has probability` forms"
334                        .to_string(),
335                );
336            }
337            if children.len() == 2 && leaf(&children[0]) == Some("Type") {
338                return type_to_lean(node, ctx, scope, span);
339            }
340            if children.len() == 1 && leaf(&children[0]) == Some("Prop") {
341                return Ok("Prop".to_string());
342            }
343            if children.len() == 3
344                && (leaf(&children[0]) == Some("Pi") || leaf(&children[0]) == Some("forall"))
345            {
346                return type_to_lean(node, ctx, scope, span);
347            }
348            if children.len() == 3 && leaf(&children[0]) == Some("lambda") {
349                let binding = parse_binding_node(&children[1]).ok_or_else(|| {
350                    format!(
351                        "Lean export cannot translate malformed lambda binder in `{}`",
352                        key_of(node)
353                    )
354                })?;
355                let param = lean_ident(&binding.param_name);
356                let domain = type_to_lean(&binding.param_type, ctx, scope, span)?;
357                let next_scope = with_scope(scope, &binding.param_name, &param);
358                return Ok(format!(
359                    "fun ({} : {}) => {}",
360                    param,
361                    domain,
362                    term_to_lean(&children[2], ctx, &next_scope, span)?
363                ));
364            }
365            if children.len() == 3 && leaf(&children[0]) == Some("apply") {
366                let f = term_to_lean(&children[1], ctx, scope, span)?;
367                let arg = term_to_lean(&children[2], ctx, scope, span)?;
368                return Ok(format!("{} {}", wrap_lean(&f), wrap_lean(&arg)));
369            }
370            if children.len() == 3 && leaf(&children[1]) == Some("=") {
371                let left = term_to_lean(&children[0], ctx, scope, span)?;
372                let right = term_to_lean(&children[2], ctx, scope, span)?;
373                return Ok(format!("{} = {}", wrap_lean(&left), wrap_lean(&right)));
374            }
375            if children.len() == 3 && leaf(&children[1]) == Some("!=") {
376                let left = term_to_lean(&children[0], ctx, scope, span)?;
377                let right = term_to_lean(&children[2], ctx, scope, span)?;
378                return Ok(format!("{} != {}", wrap_lean(&left), wrap_lean(&right)));
379            }
380            let Some(head) = leaf(&children[0]) else {
381                return Err(format!(
382                    "Lean export requires prefix application heads to be symbols, got `{}`",
383                    key_of(node)
384                ));
385            };
386            if probabilistic_heads().contains(head) {
387                return Err(format!(
388                    "Lean export does not support probabilistic operator `{}`",
389                    head
390                ));
391            }
392            let f = resolve_name(head, scope);
393            if children.len() == 1 {
394                return Ok(f);
395            }
396            let mut parts = vec![f];
397            for arg in &children[1..] {
398                parts.push(wrap_lean(&term_to_lean(arg, ctx, scope, span)?));
399            }
400            Ok(parts.join(" "))
401        }
402    }
403}
404
405fn infer_type(node: &Node, ctx: &ExportCtx, bindings: &HashMap<String, Node>) -> Option<Node> {
406    match node {
407        Node::Leaf(s) => bindings
408            .get(s)
409            .cloned()
410            .or_else(|| ctx.types.get(s).cloned()),
411        Node::List(children) if children.len() == 3 && leaf(&children[0]) == Some("lambda") => {
412            let binding = parse_binding_node(&children[1])?;
413            let mut next = bindings.clone();
414            next.insert(binding.param_name, binding.param_type.clone());
415            let body_type = infer_type(&children[2], ctx, &next)?;
416            Some(Node::List(vec![
417                Node::Leaf("Pi".to_string()),
418                children[1].clone(),
419                body_type,
420            ]))
421        }
422        Node::List(children) if children.len() == 3 && leaf(&children[0]) == Some("apply") => {
423            let fn_type = infer_type(&children[1], ctx, bindings)?;
424            let fn_children = list(&fn_type)?;
425            if fn_children.len() != 3 || leaf(&fn_children[0]) != Some("Pi") {
426                return None;
427            }
428            let binding = parse_binding_node(&fn_children[1])?;
429            Some(subst_node(
430                &fn_children[2],
431                &binding.param_name,
432                &children[2],
433            ))
434        }
435        Node::List(children) if !children.is_empty() => {
436            let head = leaf(&children[0])?;
437            let mut current = ctx.types.get(head)?.clone();
438            for arg in &children[1..] {
439                let current_children = list(&current)?;
440                if current_children.len() != 3 || leaf(&current_children[0]) != Some("Pi") {
441                    return None;
442                }
443                let binding = parse_binding_node(&current_children[1])?;
444                current = subst_node(&current_children[2], &binding.param_name, arg);
445            }
446            Some(current)
447        }
448        _ => None,
449    }
450}
451
452fn declare_type(ctx: &mut ExportCtx, name: &str, typ: Node) {
453    ctx.types.insert(name.to_string(), typ);
454}
455
456fn export_definition(form: &[Node], ctx: &mut ExportCtx, span: &Span) -> Result<(), String> {
457    let Some(raw_head) = leaf(&form[0]) else {
458        return Err(format!(
459            "Lean export supports typed declarations, got `{}`",
460            key_of(&Node::List(form.to_vec()))
461        ));
462    };
463    let head = raw_head.trim_end_matches(':');
464    let rhs = &form[1..];
465    if head == "range" || head == "valence" {
466        return Err(format!(
467            "Lean export does not support probabilistic configuration `{}:`",
468            head
469        ));
470    }
471    if probabilistic_heads().contains(head) {
472        return Err(format!(
473            "Lean export does not support probabilistic operator definition `{}:`",
474            head
475        ));
476    }
477
478    if rhs.len() == 3
479        && leaf(&rhs[0]) == Some(head)
480        && leaf(&rhs[1]) == Some("is")
481        && leaf(&rhs[2]) == Some(head)
482    {
483        return Ok(());
484    }
485
486    if rhs.len() == 2 && leaf(&rhs[1]) == Some(head) {
487        let type_node = rhs[0].clone();
488        declare_type(ctx, head, type_node.clone());
489        ctx.lines.push(format!(
490            "axiom {} : {}",
491            lean_ident(head),
492            type_to_lean(&type_node, ctx, &HashMap::new(), span)?
493        ));
494        return Ok(());
495    }
496
497    if rhs.len() == 1 && matches!(rhs[0], Node::List(_)) {
498        let type_node = rhs[0].clone();
499        declare_type(ctx, head, type_node.clone());
500        ctx.lines.push(format!(
501            "axiom {} : {}",
502            lean_ident(head),
503            type_to_lean(&type_node, ctx, &HashMap::new(), span)?
504        ));
505        return Ok(());
506    }
507
508    if rhs.len() == 3 && leaf(&rhs[0]) == Some("lambda") && matches!(rhs[1], Node::List(_)) {
509        let lambda_node = Node::List(vec![
510            Node::Leaf("lambda".to_string()),
511            rhs[1].clone(),
512            rhs[2].clone(),
513        ]);
514        let binding = parse_binding_node(&rhs[1]).ok_or_else(|| {
515            format!(
516                "Lean export cannot translate malformed lambda binder in `{}`",
517                key_of(&Node::List(form.to_vec()))
518            )
519        })?;
520        let type_node = infer_type(&lambda_node, ctx, &HashMap::new())
521            .ok_or_else(|| format!("Lean export could not infer a Lean type for `{}`", head))?;
522        declare_type(ctx, head, type_node.clone());
523        let param = lean_ident(&binding.param_name);
524        let scope = with_scope(&HashMap::new(), &binding.param_name, &param);
525        ctx.lines.push(format!(
526            "def {} : {} := fun {} => {}",
527            lean_ident(head),
528            type_to_lean(&type_node, ctx, &HashMap::new(), span)?,
529            param,
530            term_to_lean(&rhs[2], ctx, &scope, span)?
531        ));
532        return Ok(());
533    }
534
535    Err(format!(
536        "Lean export supports typed declarations and lambda definitions, got `{}`",
537        key_of(&Node::List(form.to_vec()))
538    ))
539}
540
541fn export_inductive(form: &[Node], ctx: &mut ExportCtx, span: &Span) -> Result<(), String> {
542    if form.len() < 3 {
543        return Err("Lean export requires `(inductive Name (constructor ...) ...)`".to_string());
544    }
545    let Some(type_name) = leaf(&form[1]) else {
546        return Err("Lean export requires `(inductive Name (constructor ...) ...)`".to_string());
547    };
548    declare_type(
549        ctx,
550        type_name,
551        Node::List(vec![
552            Node::Leaf("Type".to_string()),
553            Node::Leaf("0".to_string()),
554        ]),
555    );
556    let mut lines = vec![format!(
557        "inductive {} : Type 0 where",
558        lean_ident(type_name)
559    )];
560    for clause in &form[2..] {
561        let Some(parts) = list(clause) else {
562            return Err(format!(
563                "Lean export cannot translate malformed constructor clause `{}`",
564                key_of(clause)
565            ));
566        };
567        if parts.len() != 2 || leaf(&parts[0]) != Some("constructor") {
568            return Err(format!(
569                "Lean export cannot translate malformed constructor clause `{}`",
570                key_of(clause)
571            ));
572        }
573        if let Some(ctor_name) = leaf(&parts[1]) {
574            declare_type(ctx, ctor_name, Node::Leaf(type_name.to_string()));
575            lines.push(format!(
576                "  | {} : {}",
577                lean_ident(ctor_name),
578                lean_ident(type_name)
579            ));
580            continue;
581        }
582        if let Some(body) = list(&parts[1]) {
583            if body.len() == 2 {
584                if let Some(ctor_name) = leaf(&body[0]) {
585                    let ctor_type = body[1].clone();
586                    declare_type(ctx, ctor_name, ctor_type.clone());
587                    lines.push(format!(
588                        "  | {} : {}",
589                        lean_ident(ctor_name),
590                        type_to_lean(&ctor_type, ctx, &HashMap::new(), span)?
591                    ));
592                    continue;
593                }
594            }
595        }
596        return Err(format!(
597            "Lean export cannot translate constructor clause `{}`",
598            key_of(clause)
599        ));
600    }
601    ctx.blocks.push(lines.join("\n"));
602    Ok(())
603}
604
605fn export_form(form: Node, ctx: &mut ExportCtx, span: &Span) -> Result<(), String> {
606    let form = unwrap_form(form);
607    let Some(children) = list(&form) else {
608        return Err(format!(
609            "Lean export cannot translate top-level form `{}`",
610            key_of(&form)
611        ));
612    };
613    if children.is_empty() {
614        return Err(format!(
615            "Lean export cannot translate top-level form `{}`",
616            key_of(&form)
617        ));
618    }
619    if children.len() == 4
620        && leaf(&children[1]) == Some("has")
621        && leaf(&children[2]) == Some("probability")
622    {
623        return Err(
624            "Lean export does not support probabilistic `has probability` forms".to_string(),
625        );
626    }
627    if leaf(&children[0]) == Some("?")
628        || leaf(&children[0]) == Some("Type")
629        || leaf(&children[0]) == Some("Prop")
630    {
631        return Ok(());
632    }
633    if let Some(head) = leaf(&children[0]) {
634        if head.ends_with(':') {
635            return export_definition(children, ctx, span);
636        }
637        if head == "inductive" {
638            return export_inductive(children, ctx, span);
639        }
640        if head == "namespace" || head == "import" || head == "template" {
641            return Err(format!("Lean export does not yet support `{}` forms", head));
642        }
643    }
644    Err(format!(
645        "Lean export supports typed declarations and inductives, got `{}`",
646        key_of(&form)
647    ))
648}
649
650/// Export the supported typed RML fragment to Lean 4 source.
651pub fn export_lean(text: &str, file: Option<&str>) -> LeanExportResult {
652    let spans = compute_form_spans(text, file);
653    let forms = match parse_forms(text) {
654        Ok(forms) => forms,
655        Err(e) => {
656            return LeanExportResult {
657                source: String::new(),
658                diagnostics: vec![Diagnostic::new(
659                    "E006",
660                    format!("LiNo parse failure: {}", e),
661                    Span::new(file.map(|s| s.to_string()), 1, 1, 0),
662                )],
663            };
664        }
665    };
666    let mut ctx = ExportCtx::default();
667    let mut diagnostics = Vec::new();
668    for (idx, form) in forms.into_iter().enumerate() {
669        let span = spans
670            .get(idx)
671            .cloned()
672            .unwrap_or_else(|| Span::new(file.map(|s| s.to_string()), 1, 1, 0));
673        if let Err(message) = export_form(form, &mut ctx, &span) {
674            diagnostics.push(diagnostic(message, &span));
675        }
676    }
677    if !diagnostics.is_empty() {
678        return LeanExportResult {
679            source: String::new(),
680            diagnostics,
681        };
682    }
683    let mut chunks: Vec<String> = HEADER.iter().map(|s| s.to_string()).collect();
684    if !ctx.lines.is_empty() {
685        chunks.push(ctx.lines.join("\n"));
686    }
687    if !ctx.lines.is_empty() && !ctx.blocks.is_empty() {
688        chunks.push(String::new());
689    }
690    if !ctx.blocks.is_empty() {
691        chunks.push(ctx.blocks.join("\n\n"));
692    }
693    let mut source = chunks.join("\n");
694    while source.ends_with('\n') {
695        source.pop();
696    }
697    source.push('\n');
698    LeanExportResult {
699        source,
700        diagnostics: Vec::new(),
701    }
702}