arithmetic_eval/compiler/
mod.rs

//! Transformation of the AST output by the parser into a non-recursive format.
2
3use arithmetic_parser::{
4    grammars::Grammar, BinaryOp, Block, Destructure, FnDefinition, InputSpan, Lvalue,
5    ObjectDestructure, Spanned, SpannedLvalue, UnaryOp,
6};
7
8pub(crate) use self::captures::Captures;
9use self::captures::{CapturesExtractor, CompilerExtTarget};
10use crate::{
11    alloc::{Arc, HashMap, String, ToOwned},
12    exec::{Atom, Command, CompiledExpr, Executable, ExecutableModule, FieldName, ModuleId},
13    Error, ErrorKind,
14};
15
16mod captures;
17mod expr;
18
19#[derive(Debug)]
20pub(crate) struct Compiler {
21    /// Mapping between registers and named variables.
22    vars_to_registers: HashMap<String, usize>,
23    scope_depth: usize,
24    register_count: usize,
25    module_id: Arc<dyn ModuleId>,
26}
27
28impl Compiler {
29    fn new(module_id: Arc<dyn ModuleId>) -> Self {
30        Self {
31            vars_to_registers: HashMap::new(),
32            scope_depth: 0,
33            register_count: 0,
34            module_id,
35        }
36    }
37
38    fn from_env(module_id: Arc<dyn ModuleId>, env: &Captures) -> Self {
39        Self {
40            vars_to_registers: env.variables_map().clone(),
41            register_count: env.len(),
42            scope_depth: 0,
43            module_id,
44        }
45    }
46
47    /// Backups this instance. This effectively clones all fields.
48    fn backup(&mut self) -> Self {
49        Self {
50            vars_to_registers: self.vars_to_registers.clone(),
51            scope_depth: self.scope_depth,
52            register_count: self.register_count,
53            module_id: self.module_id.clone(),
54        }
55    }
56
57    fn create_error<T>(&self, span: &Spanned<'_, T>, err: ErrorKind) -> Error {
58        Error::new(self.module_id.clone(), span, err)
59    }
60
61    fn check_unary_op(&self, op: &Spanned<'_, UnaryOp>) -> Result<UnaryOp, Error> {
62        match op.extra {
63            UnaryOp::Neg | UnaryOp::Not => Ok(op.extra),
64            _ => Err(self.create_error(op, ErrorKind::unsupported(op.extra))),
65        }
66    }
67
68    fn check_binary_op(&self, op: &Spanned<'_, BinaryOp>) -> Result<BinaryOp, Error> {
69        match op.extra {
70            BinaryOp::Add
71            | BinaryOp::Sub
72            | BinaryOp::Mul
73            | BinaryOp::Div
74            | BinaryOp::Power
75            | BinaryOp::And
76            | BinaryOp::Or
77            | BinaryOp::Eq
78            | BinaryOp::NotEq
79            | BinaryOp::Gt
80            | BinaryOp::Ge
81            | BinaryOp::Lt
82            | BinaryOp::Le => Ok(op.extra),
83
84            _ => Err(self.create_error(op, ErrorKind::unsupported(op.extra))),
85        }
86    }
87
88    fn get_var(&self, name: &str) -> usize {
89        *self
90            .vars_to_registers
91            .get(name)
92            .expect("Captures must created during module compilation")
93    }
94
95    fn push_assignment<T, U>(
96        &mut self,
97        executable: &mut Executable<T>,
98        rhs: CompiledExpr<T>,
99        rhs_span: &Spanned<'_, U>,
100    ) -> usize {
101        let register = self.register_count;
102        let command = Command::Push(rhs);
103        executable.push_command(rhs_span.copy_with_extra(command));
104        self.register_count += 1;
105        register
106    }
107
108    pub fn compile_module<Id: ModuleId, T: Grammar>(
109        module_id: Id,
110        block: &Block<'_, T>,
111    ) -> Result<ExecutableModule<T::Lit>, Error> {
112        let module_id = Arc::new(module_id) as Arc<dyn ModuleId>;
113        let captures = Self::extract_captures(module_id.clone(), block)?;
114        let mut compiler = Self::from_env(module_id.clone(), &captures);
115
116        let mut executable = Executable::new(module_id);
117        let empty_span = InputSpan::new("");
118        let last_atom = compiler
119            .compile_block_inner(&mut executable, block)?
120            .map_or(Atom::Void, |spanned| spanned.extra);
121        // Push the last variable to a register to be popped during execution.
122        compiler.push_assignment(
123            &mut executable,
124            CompiledExpr::Atom(last_atom),
125            &empty_span.into(),
126        );
127
128        executable.finalize_block(compiler.register_count);
129        Ok(ExecutableModule::from_parts(executable, captures))
130    }
131
132    fn extract_captures<T: Grammar>(
133        module_id: Arc<dyn ModuleId>,
134        block: &Block<'_, T>,
135    ) -> Result<Captures, Error> {
136        let mut extractor = CapturesExtractor::new(module_id);
137        extractor.eval_block(block)?;
138        Ok(extractor.into_captures())
139    }
140
141    fn assign<T, Ty>(
142        &mut self,
143        executable: &mut Executable<T>,
144        lhs: &SpannedLvalue<'_, Ty>,
145        rhs_register: usize,
146    ) -> Result<(), Error> {
147        match &lhs.extra {
148            Lvalue::Variable { .. } => {
149                self.insert_var(executable, lhs.with_no_extra(), rhs_register);
150            }
151
152            Lvalue::Tuple(destructure) => {
153                let span = lhs.with_no_extra();
154                self.destructure(executable, destructure, span, rhs_register)?;
155            }
156
157            Lvalue::Object(destructure) => {
158                let span = lhs.with_no_extra();
159                self.destructure_object(executable, destructure, span, rhs_register)?;
160            }
161
162            _ => {
163                let err = ErrorKind::unsupported(lhs.extra.ty());
164                return Err(self.create_error(lhs, err));
165            }
166        }
167
168        Ok(())
169    }
170
171    fn insert_var<T>(
172        &mut self,
173        executable: &mut Executable<T>,
174        var_span: Spanned<'_>,
175        register: usize,
176    ) {
177        let var_name = *var_span.fragment();
178        if var_name != "_" {
179            self.vars_to_registers.insert(var_name.to_owned(), register);
180
181            // It does not make sense to annotate vars in the inner scopes, since
182            // they cannot be accessed externally.
183            if self.scope_depth == 0 {
184                let command = Command::Annotate {
185                    register,
186                    name: var_name.to_owned(),
187                };
188                executable.push_command(var_span.copy_with_extra(command));
189            }
190        }
191    }
192
193    fn destructure<'a, T, Ty>(
194        &mut self,
195        executable: &mut Executable<T>,
196        destructure: &Destructure<'a, Ty>,
197        span: Spanned<'a>,
198        rhs_register: usize,
199    ) -> Result<(), Error> {
200        let command = Command::Destructure {
201            source: rhs_register,
202            start_len: destructure.start.len(),
203            end_len: destructure.end.len(),
204            lvalue_len: destructure.len(),
205            unchecked: false,
206        };
207        executable.push_command(span.copy_with_extra(command));
208        let start_register = self.register_count;
209        self.register_count += destructure.start.len() + destructure.end.len() + 1;
210
211        for (i, lvalue) in (start_register..).zip(&destructure.start) {
212            self.assign(executable, lvalue, i)?;
213        }
214
215        let start_register = start_register + destructure.start.len();
216        if let Some(middle) = &destructure.middle {
217            if let Some(lvalue) = middle.extra.to_lvalue() {
218                self.assign(executable, &lvalue, start_register)?;
219            }
220        }
221
222        let start_register = start_register + 1;
223        for (i, lvalue) in (start_register..).zip(&destructure.end) {
224            self.assign(executable, lvalue, i)?;
225        }
226
227        Ok(())
228    }
229
230    fn destructure_object<'a, T, Ty>(
231        &mut self,
232        executable: &mut Executable<T>,
233        destructure: &ObjectDestructure<'a, Ty>,
234        span: Spanned<'a>,
235        rhs_register: usize,
236    ) -> Result<(), Error> {
237        for field in &destructure.fields {
238            let field_name = FieldName::Name((*field.field_name.fragment()).to_owned());
239            let field_access = CompiledExpr::FieldAccess {
240                receiver: span.copy_with_extra(Atom::Register(rhs_register)).into(),
241                field: field_name,
242            };
243            let register = self.push_assignment(executable, field_access, &field.field_name);
244            if let Some(binding) = &field.binding {
245                self.assign(executable, binding, register)?;
246            } else {
247                self.insert_var(executable, field.field_name, register);
248            }
249        }
250        Ok(())
251    }
252}
253
254/// Compiler extensions defined for some AST nodes, most notably, `Block`.
255///
256/// # Examples
257///
258/// ```
259/// use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
260/// use arithmetic_eval::exec::CompilerExt;
261/// # use std::{collections::HashSet, iter::FromIterator};
262///
263/// # fn main() -> anyhow::Result<()> {
264/// let block = "x = sin(0.5) / PI; y = x * E; (x, y)";
265/// let block = Untyped::<F32Grammar>::parse_statements(block)?;
266/// let undefined_vars = block.undefined_variables()?;
267/// assert_eq!(
268///     undefined_vars.keys().copied().collect::<HashSet<_>>(),
269///     HashSet::from_iter(vec!["sin", "PI", "E"])
270/// );
271/// assert_eq!(undefined_vars["PI"].location_offset(), 15);
272/// # Ok(())
273/// # }
274/// ```
275pub trait CompilerExt<'a> {
276    /// Returns variables not defined within the AST node, together with the span of their first
277    /// occurrence.
278    ///
279    /// # Errors
280    ///
281    /// - Returns an error if the AST is intrinsically malformed. This may be the case if it
282    ///   contains destructuring with the same variable on left-hand side,
283    ///   such as `(x, x) = ...`.
284    ///
285    /// The fact that an error is *not* returned does not guarantee that the AST node will evaluate
286    /// successfully if all variables are assigned.
287    fn undefined_variables(&self) -> Result<HashMap<&'a str, Spanned<'a>>, Error>;
288}
289
290impl<'a, T: Grammar> CompilerExt<'a> for Block<'a, T> {
291    fn undefined_variables(&self) -> Result<HashMap<&'a str, Spanned<'a>>, Error> {
292        CompilerExtTarget::Block(self).get_undefined_variables()
293    }
294}
295
296impl<'a, T: Grammar> CompilerExt<'a> for FnDefinition<'a, T> {
297    fn undefined_variables(&self) -> Result<HashMap<&'a str, Spanned<'a>>, Error> {
298        CompilerExtTarget::FnDefinition(self).get_undefined_variables()
299    }
300}
301
#[cfg(test)]
mod tests {
    use arithmetic_parser::{
        grammars::{F32Grammar, Parse, ParseLiteral, Typed, Untyped},
        Expr, Location, NomResult,
    };
    use nom::Parser as _;

    use super::*;
    use crate::{exec::WildcardId, Environment, Value};

    /// Parses `program` with the untyped `F32Grammar`, compiles it as a module,
    /// runs it in an empty environment and asserts that the result is `true`.
    fn assert_evaluates_to_true(program: &str) {
        let parsed = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let module = Compiler::compile_module(WildcardId, &parsed).unwrap();
        let outcome = module.with_env(&Environment::new()).unwrap().run().unwrap();
        assert_eq!(outcome, Value::Bool(true));
    }

    #[test]
    fn compilation_basics() {
        assert_evaluates_to_true("x = 3; 1 + { y = 2; y * x } == 7");
    }

    #[test]
    fn compiled_function() {
        assert_evaluates_to_true("add = |x, y| x + y; add(2, 3) == 5");
    }

    #[test]
    fn compiled_function_with_capture() {
        assert_evaluates_to_true("A = 2; add = |x, y| x + y / A; add(2, 3) == 3.5");
    }

    #[test]
    fn variable_extraction() {
        let source = "|a, b| ({ x = a * b + y; x - 2 }, a / b)";
        let def = Untyped::<F32Grammar>::parse_statements(source)
            .unwrap()
            .return_value
            .unwrap();
        let Expr::FnDefinition(def) = def.extra else {
            panic!("Unexpected function parsing result: {def:?}");
        };

        let undefined = def.undefined_variables().unwrap();
        assert_eq!(undefined["y"].location_offset(), 22);
        assert!(!undefined.contains_key("x"));
    }

    #[test]
    fn variable_extraction_with_scoping() {
        let source = "|a, b| ({ x = a * b + y; x - 2 }, a / x)";
        let def = Untyped::<F32Grammar>::parse_statements(source)
            .unwrap()
            .return_value
            .unwrap();
        let Expr::FnDefinition(def) = def.extra else {
            panic!("Unexpected function parsing result: {def:?}");
        };

        let undefined = def.undefined_variables().unwrap();
        assert_eq!(undefined["y"].location_offset(), 22);
        // `x` defined inside the block is out of scope in the second tuple element.
        assert_eq!(undefined["x"].location_offset(), 38);
    }

    #[test]
    fn extracting_captures() {
        let program = "y = 5 * x; y - 3 + x";
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let captures = Compiler::extract_captures(Arc::new(WildcardId), &block).unwrap();

        let captured: Vec<_> = captures.iter().collect();
        assert_eq!(captured, [("x", &Location::from_str(program, 8..9))]);
    }

    #[test]
    fn extracting_captures_with_inner_fns() {
        let program = "
            y = 5 * x;          // x is a capture
            fun = |z| {         // z is not a capture
                z * x + y * PI  // y is not a capture for the entire module, PI is
            };
        ";
        let block = Untyped::<F32Grammar>::parse_statements(program).unwrap();

        let captures = Compiler::extract_captures(Arc::new(WildcardId), &block).unwrap();
        assert_eq!(captures.len(), 2);
        assert!(captures.contains("PI"));

        let first_x = captures.location("x").unwrap();
        assert_eq!(first_x.location_line(), 2); // should be the first mention
    }

    #[test]
    fn type_casts_are_ignored() -> anyhow::Result<()> {
        struct TypedGrammar;

        impl ParseLiteral for TypedGrammar {
            type Lit = f32;

            fn parse_literal(input: InputSpan<'_>) -> NomResult<'_, Self::Lit> {
                F32Grammar::parse_literal(input)
            }
        }

        impl Grammar for TypedGrammar {
            type Type<'a> = ();

            fn parse_type(input: InputSpan<'_>) -> NomResult<'_, Self::Type<'_>> {
                use nom::{bytes::complete::tag, combinator::map};
                map(tag("Num"), drop).parse(input)
            }
        }

        let program = "x = 3 as Num; 1 + { y = 2; y * x as Num } == 7";
        let parsed = Typed::<TypedGrammar>::parse_statements(program)?;
        let module = Compiler::compile_module(WildcardId, &parsed)?;
        let outcome = module.with_env(&Environment::new())?.run()?;
        assert_eq!(outcome, Value::Bool(true));
        Ok(())
    }
}