arithmetic_eval/fns/
mod.rs

//! Standard functions for the interpreter, and the tools to define new native functions.
//!
//! # Defining native functions
//!
//! The main ways to define new native functions are:
//!
//! - Implement [`NativeFn`] manually. This is the most versatile approach, but it can be overly
//!   verbose.
//! - Use [`FnWrapper`] or the [`wrap`] function. This allows specifying the arguments and output
//!   with custom types (such as `bool` or a [`Number`]).
//!
//! [`Number`]: crate::Number
13
14use core::{cmp::Ordering, fmt};
15
16use once_cell::unsync::OnceCell;
17
18#[cfg(feature = "std")]
19pub use self::std::Dbg;
20pub use self::{
21    array::{All, Any, Array, Filter, Fold, Len, Map, Merge, Push},
22    assertions::{Assert, AssertClose, AssertEq, AssertFails},
23    flow::{If, While},
24    wrapper::{
25        wrap, Binary, ErrorOutput, FnWrapper, FromValueError, FromValueErrorKind,
26        FromValueErrorLocation, IntoEvalResult, Quaternary, Ternary, TryFromValue, Unary,
27    },
28};
29use crate::{
30    alloc::{vec, Vec},
31    error::AuxErrorInfo,
32    CallContext, Error, ErrorKind, EvalResult, Function, NativeFn, OpaqueRef, SpannedValue, Value,
33};
34
35mod array;
36mod assertions;
37mod flow;
38#[cfg(feature = "std")]
39mod std;
40mod wrapper;
41
42fn extract_primitive<T, A>(
43    ctx: &CallContext<'_, A>,
44    value: SpannedValue<T>,
45    error_msg: &str,
46) -> Result<T, Error> {
47    match value.extra {
48        Value::Prim(value) => Ok(value),
49        _ => Err(ctx
50            .call_site_error(ErrorKind::native(error_msg))
51            .with_location(&value, AuxErrorInfo::InvalidArg)),
52    }
53}
54
55fn extract_array<T, A>(
56    ctx: &CallContext<'_, A>,
57    value: SpannedValue<T>,
58    error_msg: &str,
59) -> Result<Vec<Value<T>>, Error> {
60    if let Value::Tuple(array) = value.extra {
61        Ok(array.into())
62    } else {
63        let err = ErrorKind::native(error_msg);
64        Err(ctx
65            .call_site_error(err)
66            .with_location(&value, AuxErrorInfo::InvalidArg))
67    }
68}
69
70fn extract_fn<T, A>(
71    ctx: &CallContext<'_, A>,
72    value: SpannedValue<T>,
73    error_msg: &str,
74) -> Result<Function<T>, Error> {
75    if let Value::Function(function) = value.extra {
76        Ok(function)
77    } else {
78        let err = ErrorKind::native(error_msg);
79        Err(ctx
80            .call_site_error(err)
81            .with_location(&value, AuxErrorInfo::InvalidArg))
82    }
83}
84
/// Comparator functions on two primitive arguments. All functions use [`Arithmetic`] to determine
/// ordering between the args.
///
/// # Type
///
/// ```text
/// fn(Num, Num) -> Ordering // for `Compare::Raw`
/// fn(Num, Num) -> Num // for `Compare::Min` and `Compare::Max`
/// ```
///
/// [`Arithmetic`]: crate::arith::Arithmetic
///
/// # Errors
///
/// All variants return an error if they are not called with two primitive arguments.
/// `Min` and `Max` additionally return [`ErrorKind::CannotCompare`] if the arguments
/// are not comparable (e.g., if one of them is NaN in floating-point arithmetic).
///
/// # Examples
///
/// Using `min` function:
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// # use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
/// # fn main() -> anyhow::Result<()> {
/// let program = "
///     // Finds a minimum number in an array.
///     extended_min = |...xs| fold(xs, INFINITY, min);
///     extended_min(2, -3, 7, 1, 3) == -3
/// ";
/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
/// let module = ExecutableModule::new("test_min", &program)?;
///
/// let mut env = Environment::new();
/// env.insert("INFINITY", Value::Prim(f32::INFINITY))
///     .insert_native_fn("fold", fns::Fold)
///     .insert_native_fn("min", fns::Compare::Min);
/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
/// # Ok(())
/// # }
/// ```
///
/// Using `cmp` function with [`Comparisons`](crate::env::Comparisons).
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// # use arithmetic_eval::{fns, env::Comparisons, Environment, ExecutableModule, Value};
/// # fn main() -> anyhow::Result<()> {
/// let program = "
///     map((1, -7, 0, 2), |x| cmp(x, 0)) == (GREATER, LESS, EQUAL, GREATER)
/// ";
/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
/// let module = ExecutableModule::new("test_cmp", &program)?;
///
/// let mut env = Environment::new();
/// env.extend(Comparisons::iter());
/// env.insert_native_fn("map", fns::Map);
/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum Compare {
    /// Returns an [`Ordering`] wrapped into an [`OpaqueRef`](crate::OpaqueRef),
    /// or [`Value::void()`] if the provided values are not comparable.
    Raw,
    /// Returns the minimum of the two values. If the values are equal, returns the first one.
    /// If the values are not comparable, returns an error.
    Min,
    /// Returns the maximum of the two values. If the values are equal, returns the first one.
    /// If the values are not comparable, returns an error.
    Max,
}
152
153impl Compare {
154    fn extract_primitives<T>(
155        mut args: Vec<SpannedValue<T>>,
156        ctx: &mut CallContext<'_, T>,
157    ) -> Result<(T, T), Error> {
158        ctx.check_args_count(&args, 2)?;
159        let y = args.pop().unwrap();
160        let x = args.pop().unwrap();
161        let x = extract_primitive(ctx, x, COMPARE_ERROR_MSG)?;
162        let y = extract_primitive(ctx, y, COMPARE_ERROR_MSG)?;
163        Ok((x, y))
164    }
165}
166
167const COMPARE_ERROR_MSG: &str = "Compare requires 2 primitive arguments";
168
169impl<T> NativeFn<T> for Compare {
170    fn evaluate(&self, args: Vec<SpannedValue<T>>, ctx: &mut CallContext<'_, T>) -> EvalResult<T> {
171        let (x, y) = Self::extract_primitives(args, ctx)?;
172        let maybe_ordering = ctx.arithmetic().partial_cmp(&x, &y);
173
174        if let Self::Raw = self {
175            Ok(maybe_ordering.map_or_else(Value::void, Value::opaque_ref))
176        } else {
177            let ordering =
178                maybe_ordering.ok_or_else(|| ctx.call_site_error(ErrorKind::CannotCompare))?;
179            let value = match (ordering, self) {
180                (Ordering::Equal, _)
181                | (Ordering::Less, Self::Min)
182                | (Ordering::Greater, Self::Max) => x,
183                _ => y,
184            };
185            Ok(Value::Prim(value))
186        }
187    }
188}
189
190/// Allows to define a value recursively, by referencing a value being created.
191///
192/// It works like this:
193///
194/// - Provide a function as the only argument. The (only) argument of this function is the value
195///   being created.
196/// - Do not use the uninitialized value synchronously; only use it in inner function definitions.
197/// - Return the created value from a function.
198///
199/// # Examples
200///
201/// ```
202/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
203/// # use arithmetic_eval::{fns, env::Comparisons, Environment, ExecutableModule, Value};
204/// # fn main() -> anyhow::Result<()> {
205/// let program = "
206///     recursive_fib = defer(|fib| {
207///         |n| if(n >= 0 && n <= 1, || n, || fib(n - 1) + fib(n - 2))()
208///     });
209///     recursive_fib(10)
210/// ";
211/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
212/// let module = ExecutableModule::new("test_defer", &program)?;
213///
214/// let mut env = Environment::new();
215/// env.extend(Comparisons::iter());
216/// env.insert_native_fn("if", fns::If).insert_native_fn("defer", fns::Defer);
217/// assert_eq!(module.with_env(&env)?.run()?, Value::Prim(55.0));
218/// # Ok(())
219/// # }
220/// ```
221#[derive(Debug, Clone, Copy, Default)]
222pub struct Defer;
223
224impl<T: Clone + 'static> NativeFn<T> for Defer {
225    fn evaluate(
226        &self,
227        mut args: Vec<SpannedValue<T>>,
228        ctx: &mut CallContext<'_, T>,
229    ) -> EvalResult<T> {
230        const ARG_ERROR: &str = "Argument must be a function";
231
232        ctx.check_args_count(&args, 1)?;
233        let function = extract_fn(ctx, args.pop().unwrap(), ARG_ERROR)?;
234        let cell = OpaqueRef::with_identity_eq(ValueCell::<T>::default());
235        let spanned_cell = ctx.apply_call_location(Value::Ref(cell.clone()));
236        let return_value = function.evaluate(vec![spanned_cell], ctx)?;
237
238        let cell = cell.downcast_ref::<ValueCell<T>>().unwrap();
239        // ^ `unwrap()` is safe by construction
240        cell.set(return_value.clone());
241        Ok(return_value)
242    }
243}
244
245#[derive(Debug)]
246pub(crate) struct ValueCell<T> {
247    inner: OnceCell<Value<T>>,
248}
249
250impl<T> Default for ValueCell<T> {
251    fn default() -> Self {
252        Self {
253            inner: OnceCell::new(),
254        }
255    }
256}
257
258impl<T: 'static + fmt::Debug> From<ValueCell<T>> for Value<T> {
259    fn from(cell: ValueCell<T>) -> Self {
260        Self::Ref(OpaqueRef::with_identity_eq(cell))
261    }
262}
263
264impl<T> ValueCell<T> {
265    /// Gets the internally stored value, or `None` if the cell was not initialized yet.
266    pub fn get(&self) -> Option<&Value<T>> {
267        self.inner.get()
268    }
269
270    fn set(&self, value: Value<T>) {
271        self.inner
272            .set(value)
273            .map_err(drop)
274            .expect("Repeated `ValueCell` assignment");
275    }
276}
277
#[cfg(test)]
mod tests {
    use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
    use assert_matches::assert_matches;

    use super::*;
    use crate::{
        env::Environment,
        exec::{ExecutableModule, WildcardId},
    };

    /// `if` with plain values as branches returns the selected value.
    #[test]
    fn if_basics() -> anyhow::Result<()> {
        let block = "
            x = 1.0;
            if(x < 2, x + 5, 3 - x)
        ";
        let block = Untyped::<F32Grammar>::parse_statements(block)?;
        let module = ExecutableModule::new(WildcardId, &block)?;
        let mut env = Environment::new();
        env.insert_native_fn("if", If);
        assert_eq!(module.with_env(&env)?.run()?, Value::Prim(6.0));
        Ok(())
    }

    /// `if` with closures as branches returns the selected closure, which is then called.
    #[test]
    fn if_with_closures() -> anyhow::Result<()> {
        let block = "
            x = 4.5;
            if(x < 2, || x + 5, || 3 - x)()
        ";
        let block = Untyped::<F32Grammar>::parse_statements(block)?;
        let module = ExecutableModule::new(WildcardId, &block)?;
        let mut env = Environment::new();
        env.insert_native_fn("if", If);
        assert_eq!(module.with_env(&env)?.run()?, Value::Prim(-1.5));
        Ok(())
    }

    /// Comparison operators work without native functions; comparing a number with a tuple
    /// fails with an error pointing at the tuple.
    #[test]
    fn cmp_sugar() -> anyhow::Result<()> {
        let program = "x = 1.0; x > 0 && x <= 3";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;
        assert_eq!(
            module.with_env(&Environment::new())?.run()?,
            Value::Bool(true)
        );

        let bogus_program = "x = 1.0; x > (1, 2)";
        let bogus_block = Untyped::<F32Grammar>::parse_statements(bogus_program)?;
        let bogus_module = ExecutableModule::new(WildcardId, &bogus_block)?;

        let err = bogus_module
            .with_env(&Environment::new())?
            .run()
            .unwrap_err();
        let err = err.source();
        assert_matches!(err.kind(), ErrorKind::CannotCompare);
        assert_eq!(err.location().in_module().span(bogus_program), "(1, 2)");
        Ok(())
    }

    #[test]
    fn while_basic() -> anyhow::Result<()> {
        let program = "
            // Finds the greatest power of 2 lesser or equal to the value.
            discrete_log2 = |x| {
                while(0, |i| 2^i <= x, |i| i + 1) - 1
            };

            (discrete_log2(1), discrete_log2(2),
                discrete_log2(4), discrete_log2(6.5), discrete_log2(1000))
        ";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;

        let module = ExecutableModule::new(WildcardId, &block)?;
        let mut env = Environment::new();
        // Only `while` is referenced by the program.
        env.insert_native_fn("while", While);

        assert_eq!(
            module.with_env(&env)?.run()?,
            Value::from(vec![
                Value::Prim(0.0),
                Value::Prim(1.0),
                Value::Prim(2.0),
                Value::Prim(2.0),
                Value::Prim(9.0),
            ])
        );
        Ok(())
    }

    #[test]
    fn max_value_with_fold() -> anyhow::Result<()> {
        let program = "
            max_value = |...xs| {
                fold(xs, -Inf, |acc, x| if(x > acc, x, acc))
            };
            max_value(1, -2, 7, 2, 5) == 7 && max_value(3, -5, 9) == 9
        ";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;

        let module = ExecutableModule::new(WildcardId, &block)?;
        let mut env = Environment::new();
        env.insert("Inf", Value::Prim(f32::INFINITY))
            .insert_native_fn("fold", Fold)
            .insert_native_fn("if", If);

        assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
        Ok(())
    }

    #[test]
    fn reverse_list_with_fold() -> anyhow::Result<()> {
        const SAMPLES: &[(&[f32], &[f32])] = &[
            (&[1.0, 2.0, 3.0], &[3.0, 2.0, 1.0]),
            (&[], &[]),
            (&[1.0], &[1.0]),
        ];

        let program = "
            reverse = |xs| {
                fold(xs, (), |acc, x| merge((x,), acc))
            };
            xs = (-4, 3, 0, 1);
            reverse(xs) == (1, 0, 3, -4)
        ";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;

        let mut env = Environment::new();
        env.insert_native_fn("merge", Merge)
            .insert_native_fn("fold", Fold);

        // Run with a mutable env so that the `reverse` definition is available to the
        // follow-up module below.
        assert_eq!(module.with_mutable_env(&mut env)?.run()?, Value::Bool(true));

        let test_block = Untyped::<F32Grammar>::parse_statements("reverse(xs)")?;
        let test_module = ExecutableModule::new("test", &test_block)?;

        for &(input, expected) in SAMPLES {
            let input = input.iter().copied().map(Value::Prim).collect();
            let expected = expected.iter().copied().map(Value::Prim).collect();
            env.insert("xs", Value::Tuple(input));
            assert_eq!(test_module.with_env(&env)?.run()?, Value::Tuple(expected));
        }
        Ok(())
    }

    /// `max` returns the larger argument.
    #[test]
    fn max_function_basics() -> anyhow::Result<()> {
        let program = "max(1, 3) == 3 && max(4, -2) == 4 && max(-5, -5) == -5";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;
        let mut env = Environment::new();
        env.insert_native_fn("max", Compare::Max);

        assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
        Ok(())
    }

    /// `Compare::Raw` maps incomparable arguments to void instead of raising an error.
    #[test]
    fn raw_comparison_with_incomparable_args() -> anyhow::Result<()> {
        let program = "cmp(1, NAN)";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;
        let mut env = Environment::new();
        env.insert("NAN", Value::Prim(f32::NAN))
            .insert_native_fn("cmp", Compare::Raw);

        assert_eq!(module.with_env(&env)?.run()?, Value::void());
        Ok(())
    }

    #[test]
    fn error_with_min_function_args() -> anyhow::Result<()> {
        let program = "5 - min(1, (2, 3))";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;
        let mut env = Environment::new();
        env.insert_native_fn("min", Compare::Min);

        let err = module.with_env(&env)?.run().unwrap_err();
        let err = err.source();
        assert_eq!(err.location().in_module().span(program), "min(1, (2, 3))");
        assert_matches!(
            err.kind(),
            ErrorKind::NativeCall(ref msg) if msg.contains("requires 2 primitive arguments")
        );
        Ok(())
    }

    #[test]
    fn error_with_min_function_incomparable_args() -> anyhow::Result<()> {
        let program = "5 - min(1, NAN)";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;
        let mut env = Environment::new();
        env.insert("NAN", Value::Prim(f32::NAN))
            .insert_native_fn("min", Compare::Min);

        let err = module.with_env(&env)?.run().unwrap_err();
        let err = err.source();
        assert_eq!(err.location().in_module().span(program), "min(1, NAN)");
        assert_matches!(err.kind(), ErrorKind::CannotCompare);
        Ok(())
    }
}