1use crate::cst::CstNode;
13use crate::cst_js::{parse_js, print_js};
14use crate::cst_lean::{parse_lean, print_lean};
15use crate::cst_rocq::{parse_rocq, print_rocq};
16use crate::cst_rust::{parse_rust, print_rust};
17
/// Every `lang` string accepted by [`parse_to_cst`] and [`print_from_cst`].
///
/// `"js"` and `"javascript"` are both listed because they are aliases for the
/// same backend. The dispatch `match` arms are written out separately, so keep
/// this list and those arms in sync when adding a language.
pub const SUPPORTED_LANGUAGES: &[&str] = &[
    "rust",
    "js",
    "javascript",
    "lean",
    "rocq",
];
20
21pub fn parse_to_cst(src: &str, lang: &str) -> Result<CstNode, String> {
23 match lang {
24 "rust" => Ok(parse_rust(src)),
25 "js" | "javascript" => Ok(parse_js(src)),
26 "lean" => Ok(parse_lean(src)),
27 "rocq" => Ok(parse_rocq(src)),
28 other => Err(format!("unsupported language for parse_to_cst: {}", other)),
29 }
30}
31
32pub fn print_from_cst(node: &CstNode, lang: &str) -> Result<String, String> {
34 match lang {
35 "rust" => Ok(print_rust(node)),
36 "js" | "javascript" => Ok(print_js(node)),
37 "lean" => Ok(print_lean(node)),
38 "rocq" => Ok(print_rocq(node)),
39 other => Err(format!("unsupported language for print_from_cst: {}", other)),
40 }
41}
42
/// Outcome of a parse-then-print round trip, as produced by [`round_trip`].
///
/// Derives the common traits so callers can log it (`Debug`), keep copies
/// (`Clone`), and compare results in tests (`PartialEq`/`Eq`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RoundTripResult {
    /// Set by [`round_trip`] to `true` when `round_tripped` is byte-for-byte
    /// identical to `source`.
    pub ok: bool,
    /// The input text that was parsed.
    pub source: String,
    /// The text printed back from the parsed tree.
    pub round_tripped: String,
}
49
50pub fn round_trip(src: &str, lang: &str) -> Result<RoundTripResult, String> {
52 let cst = parse_to_cst(src, lang)?;
53 let round_tripped = print_from_cst(&cst, lang)?;
54 Ok(RoundTripResult {
55 ok: round_tripped == src,
56 source: src.to_string(),
57 round_tripped,
58 })
59}
60
#[cfg(test)]
mod tests {
    use super::*;

    /// Smallest input each backend is expected to reproduce verbatim. Rocq gets
    /// a complete `Definition ... .` sentence; the other languages use a bare
    /// identifier line.
    fn sample_for(lang: &str) -> &'static str {
        if lang == "rocq" {
            "Definition x := 1.\n"
        } else {
            "x\n"
        }
    }

    #[test]
    fn dispatches_by_language_name() {
        // Drive the test from the public list, which includes the "javascript"
        // alias. A name added to SUPPORTED_LANGUAGES without a matching dispatch
        // arm then fails here instead of going untested.
        for &lang in SUPPORTED_LANGUAGES {
            let sample = sample_for(lang);
            let cst = parse_to_cst(sample, lang)
                .unwrap_or_else(|e| panic!("parse_to_cst failed for {}: {}", lang, e));
            let out = print_from_cst(&cst, lang)
                .unwrap_or_else(|e| panic!("print_from_cst failed for {}: {}", lang, e));
            assert_eq!(out, sample, "round trip mismatch for {}", lang);
        }
    }

    #[test]
    fn javascript_alias_works_like_js() {
        let src = "const x = 1;\n";
        let out =
            print_from_cst(&parse_to_cst(src, "javascript").unwrap(), "javascript").unwrap();
        assert_eq!(out, src);
    }

    #[test]
    fn rejects_unsupported_language() {
        assert!(parse_to_cst("x", "python").is_err());
        let n = CstNode::list("x", vec![]);
        assert!(print_from_cst(&n, "python").is_err());
        // round_trip must surface the same rejection rather than report ok=false.
        assert!(round_trip("x", "python").is_err());
    }

    #[test]
    fn round_trip_helper_reports_ok() {
        let src = "fn main() {}\n";
        let r = round_trip(src, "rust").unwrap();
        assert!(r.ok);
        assert_eq!(r.source, src);
        assert_eq!(r.round_tripped, src);
    }
}