1use arithmetic_parser::{
4 grammars::Grammar, BinaryOp, Block, Destructure, FnDefinition, InputSpan, Lvalue,
5 ObjectDestructure, Spanned, SpannedLvalue, UnaryOp,
6};
7
8pub(crate) use self::captures::Captures;
9use self::captures::{CapturesExtractor, CompilerExtTarget};
10use crate::{
11 alloc::{Arc, HashMap, String, ToOwned},
12 exec::{Atom, Command, CompiledExpr, Executable, ExecutableModule, FieldName, ModuleId},
13 Error, ErrorKind,
14};
15
16mod captures;
17mod expr;
18
19#[derive(Debug)]
20pub(crate) struct Compiler {
21 vars_to_registers: HashMap<String, usize>,
23 scope_depth: usize,
24 register_count: usize,
25 module_id: Arc<dyn ModuleId>,
26}
27
28impl Compiler {
29 fn new(module_id: Arc<dyn ModuleId>) -> Self {
30 Self {
31 vars_to_registers: HashMap::new(),
32 scope_depth: 0,
33 register_count: 0,
34 module_id,
35 }
36 }
37
38 fn from_env(module_id: Arc<dyn ModuleId>, env: &Captures) -> Self {
39 Self {
40 vars_to_registers: env.variables_map().clone(),
41 register_count: env.len(),
42 scope_depth: 0,
43 module_id,
44 }
45 }
46
47 fn backup(&mut self) -> Self {
49 Self {
50 vars_to_registers: self.vars_to_registers.clone(),
51 scope_depth: self.scope_depth,
52 register_count: self.register_count,
53 module_id: self.module_id.clone(),
54 }
55 }
56
57 fn create_error<T>(&self, span: &Spanned<'_, T>, err: ErrorKind) -> Error {
58 Error::new(self.module_id.clone(), span, err)
59 }
60
61 fn check_unary_op(&self, op: &Spanned<'_, UnaryOp>) -> Result<UnaryOp, Error> {
62 match op.extra {
63 UnaryOp::Neg | UnaryOp::Not => Ok(op.extra),
64 _ => Err(self.create_error(op, ErrorKind::unsupported(op.extra))),
65 }
66 }
67
68 fn check_binary_op(&self, op: &Spanned<'_, BinaryOp>) -> Result<BinaryOp, Error> {
69 match op.extra {
70 BinaryOp::Add
71 | BinaryOp::Sub
72 | BinaryOp::Mul
73 | BinaryOp::Div
74 | BinaryOp::Power
75 | BinaryOp::And
76 | BinaryOp::Or
77 | BinaryOp::Eq
78 | BinaryOp::NotEq
79 | BinaryOp::Gt
80 | BinaryOp::Ge
81 | BinaryOp::Lt
82 | BinaryOp::Le => Ok(op.extra),
83
84 _ => Err(self.create_error(op, ErrorKind::unsupported(op.extra))),
85 }
86 }
87
88 fn get_var(&self, name: &str) -> usize {
89 *self
90 .vars_to_registers
91 .get(name)
92 .expect("Captures must created during module compilation")
93 }
94
95 fn push_assignment<T, U>(
96 &mut self,
97 executable: &mut Executable<T>,
98 rhs: CompiledExpr<T>,
99 rhs_span: &Spanned<'_, U>,
100 ) -> usize {
101 let register = self.register_count;
102 let command = Command::Push(rhs);
103 executable.push_command(rhs_span.copy_with_extra(command));
104 self.register_count += 1;
105 register
106 }
107
108 pub fn compile_module<Id: ModuleId, T: Grammar>(
109 module_id: Id,
110 block: &Block<'_, T>,
111 ) -> Result<ExecutableModule<T::Lit>, Error> {
112 let module_id = Arc::new(module_id) as Arc<dyn ModuleId>;
113 let captures = Self::extract_captures(module_id.clone(), block)?;
114 let mut compiler = Self::from_env(module_id.clone(), &captures);
115
116 let mut executable = Executable::new(module_id);
117 let empty_span = InputSpan::new("");
118 let last_atom = compiler
119 .compile_block_inner(&mut executable, block)?
120 .map_or(Atom::Void, |spanned| spanned.extra);
121 compiler.push_assignment(
123 &mut executable,
124 CompiledExpr::Atom(last_atom),
125 &empty_span.into(),
126 );
127
128 executable.finalize_block(compiler.register_count);
129 Ok(ExecutableModule::from_parts(executable, captures))
130 }
131
132 fn extract_captures<T: Grammar>(
133 module_id: Arc<dyn ModuleId>,
134 block: &Block<'_, T>,
135 ) -> Result<Captures, Error> {
136 let mut extractor = CapturesExtractor::new(module_id);
137 extractor.eval_block(block)?;
138 Ok(extractor.into_captures())
139 }
140
141 fn assign<T, Ty>(
142 &mut self,
143 executable: &mut Executable<T>,
144 lhs: &SpannedLvalue<'_, Ty>,
145 rhs_register: usize,
146 ) -> Result<(), Error> {
147 match &lhs.extra {
148 Lvalue::Variable { .. } => {
149 self.insert_var(executable, lhs.with_no_extra(), rhs_register);
150 }
151
152 Lvalue::Tuple(destructure) => {
153 let span = lhs.with_no_extra();
154 self.destructure(executable, destructure, span, rhs_register)?;
155 }
156
157 Lvalue::Object(destructure) => {
158 let span = lhs.with_no_extra();
159 self.destructure_object(executable, destructure, span, rhs_register)?;
160 }
161
162 _ => {
163 let err = ErrorKind::unsupported(lhs.extra.ty());
164 return Err(self.create_error(lhs, err));
165 }
166 }
167
168 Ok(())
169 }
170
171 fn insert_var<T>(
172 &mut self,
173 executable: &mut Executable<T>,
174 var_span: Spanned<'_>,
175 register: usize,
176 ) {
177 let var_name = *var_span.fragment();
178 if var_name != "_" {
179 self.vars_to_registers.insert(var_name.to_owned(), register);
180
181 if self.scope_depth == 0 {
184 let command = Command::Annotate {
185 register,
186 name: var_name.to_owned(),
187 };
188 executable.push_command(var_span.copy_with_extra(command));
189 }
190 }
191 }
192
193 fn destructure<'a, T, Ty>(
194 &mut self,
195 executable: &mut Executable<T>,
196 destructure: &Destructure<'a, Ty>,
197 span: Spanned<'a>,
198 rhs_register: usize,
199 ) -> Result<(), Error> {
200 let command = Command::Destructure {
201 source: rhs_register,
202 start_len: destructure.start.len(),
203 end_len: destructure.end.len(),
204 lvalue_len: destructure.len(),
205 unchecked: false,
206 };
207 executable.push_command(span.copy_with_extra(command));
208 let start_register = self.register_count;
209 self.register_count += destructure.start.len() + destructure.end.len() + 1;
210
211 for (i, lvalue) in (start_register..).zip(&destructure.start) {
212 self.assign(executable, lvalue, i)?;
213 }
214
215 let start_register = start_register + destructure.start.len();
216 if let Some(middle) = &destructure.middle {
217 if let Some(lvalue) = middle.extra.to_lvalue() {
218 self.assign(executable, &lvalue, start_register)?;
219 }
220 }
221
222 let start_register = start_register + 1;
223 for (i, lvalue) in (start_register..).zip(&destructure.end) {
224 self.assign(executable, lvalue, i)?;
225 }
226
227 Ok(())
228 }
229
230 fn destructure_object<'a, T, Ty>(
231 &mut self,
232 executable: &mut Executable<T>,
233 destructure: &ObjectDestructure<'a, Ty>,
234 span: Spanned<'a>,
235 rhs_register: usize,
236 ) -> Result<(), Error> {
237 for field in &destructure.fields {
238 let field_name = FieldName::Name((*field.field_name.fragment()).to_owned());
239 let field_access = CompiledExpr::FieldAccess {
240 receiver: span.copy_with_extra(Atom::Register(rhs_register)).into(),
241 field: field_name,
242 };
243 let register = self.push_assignment(executable, field_access, &field.field_name);
244 if let Some(binding) = &field.binding {
245 self.assign(executable, binding, register)?;
246 } else {
247 self.insert_var(executable, field.field_name, register);
248 }
249 }
250 Ok(())
251 }
252}
253
254pub trait CompilerExt<'a> {
276 fn undefined_variables(&self) -> Result<HashMap<&'a str, Spanned<'a>>, Error>;
288}
289
290impl<'a, T: Grammar> CompilerExt<'a> for Block<'a, T> {
291 fn undefined_variables(&self) -> Result<HashMap<&'a str, Spanned<'a>>, Error> {
292 CompilerExtTarget::Block(self).get_undefined_variables()
293 }
294}
295
296impl<'a, T: Grammar> CompilerExt<'a> for FnDefinition<'a, T> {
297 fn undefined_variables(&self) -> Result<HashMap<&'a str, Spanned<'a>>, Error> {
298 CompilerExtTarget::FnDefinition(self).get_undefined_variables()
299 }
300}
301
#[cfg(test)]
mod tests {
    use arithmetic_parser::{
        grammars::{F32Grammar, Parse, ParseLiteral, Typed, Untyped},
        Expr, Location, NomResult,
    };
    use nom::Parser as _;

    use super::*;
    use crate::{exec::WildcardId, Environment, Value};

    /// Parses `program` with the untyped `f32` grammar, compiles it, runs it in an empty
    /// environment and checks that the result is `true`.
    fn assert_untyped_program_is_true(program: &str) {
        let statements = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let module = Compiler::compile_module(WildcardId, &statements).unwrap();
        let env = Environment::new();
        let outcome = module.with_env(&env).unwrap().run().unwrap();
        assert_eq!(outcome, Value::Bool(true));
    }

    /// Parses `code`, whose return value must be a function definition, and returns
    /// the variables left undefined by that definition.
    fn undefined_vars_in_fn(code: &str) -> HashMap<&str, Spanned<'_>> {
        let return_value = Untyped::<F32Grammar>::parse_statements(code)
            .unwrap()
            .return_value
            .unwrap();
        let Expr::FnDefinition(def) = return_value.extra else {
            panic!("Unexpected function parsing result: {return_value:?}");
        };
        def.undefined_variables().unwrap()
    }

    #[test]
    fn compilation_basics() {
        assert_untyped_program_is_true("x = 3; 1 + { y = 2; y * x } == 7");
    }

    #[test]
    fn compiled_function() {
        assert_untyped_program_is_true("add = |x, y| x + y; add(2, 3) == 5");
    }

    #[test]
    fn compiled_function_with_capture() {
        assert_untyped_program_is_true("A = 2; add = |x, y| x + y / A; add(2, 3) == 3.5");
    }

    #[test]
    fn variable_extraction() {
        let undefined = undefined_vars_in_fn("|a, b| ({ x = a * b + y; x - 2 }, a / b)");
        assert_eq!(undefined["y"].location_offset(), 22);
        assert!(!undefined.contains_key("x"));
    }

    #[test]
    fn variable_extraction_with_scoping() {
        let undefined = undefined_vars_in_fn("|a, b| ({ x = a * b + y; x - 2 }, a / x)");
        assert_eq!(undefined["y"].location_offset(), 22);
        assert_eq!(undefined["x"].location_offset(), 38);
    }

    #[test]
    fn extracting_captures() {
        let program = "y = 5 * x; y - 3 + x";
        let statements = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let captures = Compiler::extract_captures(Arc::new(WildcardId), &statements).unwrap();

        let entries: Vec<_> = captures.iter().collect();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0], ("x", &Location::from_str(program, 8..9)));
    }

    #[test]
    fn extracting_captures_with_inner_fns() {
        let program = "
            y = 5 * x; // x is a capture
            fun = |z| { // z is not a capture
                z * x + y * PI // y is not a capture for the entire module, PI is
            };
        ";
        let statements = Untyped::<F32Grammar>::parse_statements(program).unwrap();

        let captures = Compiler::extract_captures(Arc::new(WildcardId), &statements).unwrap();
        assert_eq!(captures.len(), 2);

        assert!(captures.contains("PI"));
        let x_location = captures.location("x").unwrap();
        assert_eq!(x_location.location_line(), 2);
    }

    #[test]
    fn type_casts_are_ignored() -> anyhow::Result<()> {
        struct TypedGrammar;

        impl ParseLiteral for TypedGrammar {
            type Lit = f32;

            fn parse_literal(input: InputSpan<'_>) -> NomResult<'_, Self::Lit> {
                F32Grammar::parse_literal(input)
            }
        }

        impl Grammar for TypedGrammar {
            type Type<'a> = ();

            fn parse_type(input: InputSpan<'_>) -> NomResult<'_, Self::Type<'_>> {
                use nom::{bytes::complete::tag, combinator::map};
                map(tag("Num"), drop).parse(input)
            }
        }

        let program = "x = 3 as Num; 1 + { y = 2; y * x as Num } == 7";
        let statements = Typed::<TypedGrammar>::parse_statements(program)?;
        let module = Compiler::compile_module(WildcardId, &statements)?;
        let env = Environment::new();
        let outcome = module.with_env(&env)?.run()?;
        assert_eq!(outcome, Value::Bool(true));
        Ok(())
    }
}