arithmetic_eval/fns/wrapper/
mod.rs

1//! Wrapper for eloquent `NativeFn` definitions.
2
3use core::{fmt, marker::PhantomData};
4
5pub use self::traits::{
6    ErrorOutput, FromValueError, FromValueErrorKind, FromValueErrorLocation, IntoEvalResult,
7    TryFromValue,
8};
9use crate::{
10    CallContext, ErrorKind, EvalResult, NativeFn, SpannedValue, alloc::Vec, error::AuxErrorInfo,
11};
12
13mod traits;
14
15/// Wraps a function enriching it with the information about its arguments.
16/// This is a slightly shorter way to create wrappers compared to calling [`FnWrapper::new()`].
17///
18/// See [`FnWrapper`] for more details on function requirements.
19pub const fn wrap<const CTX: bool, T, F>(function: F) -> FnWrapper<T, F, CTX> {
20    FnWrapper::new(function)
21}
22
/// Wrapper of a function containing information about its arguments.
///
/// Using `FnWrapper` makes it possible to define [native functions](NativeFn) with minimum
/// boilerplate and with increased type safety. `FnWrapper`s can be constructed explicitly or
/// indirectly via [`Environment::insert_wrapped_fn()`], [`Value::wrapped_fn()`], or [`wrap()`].
///
/// Arguments of a wrapped function must implement [`TryFromValue`] trait for the applicable
/// grammar, and the output type must implement [`IntoEvalResult`]. If you need [`CallContext`]
/// (e.g., to call functions provided as an argument), it should be specified as the first
/// argument.
///
/// # Type parameters
///
/// - `T` is a tuple consisting of the function return type followed by the argument types
///   (not including the [`CallContext`] argument, if any).
/// - `F` is the wrapped function or closure.
/// - `CTX` is `true` if the function takes a [`CallContext`] as its first argument.
///
/// [`Environment::insert_wrapped_fn()`]: crate::Environment::insert_wrapped_fn()
/// [`Value::wrapped_fn()`]: crate::Value::wrapped_fn()
///
/// # Examples
///
/// ## Basic function
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
///
/// # fn main() -> anyhow::Result<()> {
/// let max = fns::wrap(|x: f32, y: f32| if x > y { x } else { y });
///
/// let program = "max(1, 3) == 3 && max(-1, -3) == -1";
/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
/// let module = ExecutableModule::new("test_max", &program)?;
///
/// let mut env = Environment::new();
/// env.insert_native_fn("max", max);
/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
/// # Ok(())
/// # }
/// ```
///
/// ## Fallible function with complex args
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// # use arithmetic_eval::{Environment, ExecutableModule, Value};
/// fn zip_arrays(xs: Vec<f32>, ys: Vec<f32>) -> Result<Vec<(f32, f32)>, String> {
///     if xs.len() == ys.len() {
///         Ok(xs.into_iter().zip(ys).collect())
///     } else {
///         Err("Arrays must have the same size".to_owned())
///     }
/// }
///
/// # fn main() -> anyhow::Result<()> {
/// let program = "zip((1, 2, 3), (4, 5, 6)) == ((1, 4), (2, 5), (3, 6))";
/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
/// let module = ExecutableModule::new("test_zip", &program)?;
///
/// let mut env = Environment::new();
/// env.insert_wrapped_fn("zip", zip_arrays);
/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
/// # Ok(())
/// # }
/// ```
///
/// ## Using `CallContext` to call functions
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// # use arithmetic_eval::{CallContext, Function, Environment, Value, ExecutableModule, Error};
/// fn map_array(
///     context: &mut CallContext<'_, f32>,
///     array: Vec<Value<f32>>,
///     map_fn: Function<f32>,
/// ) -> Result<Vec<Value<f32>>, Error> {
///     array
///         .into_iter()
///         .map(|value| {
///             let arg = context.apply_call_location(value);
///             map_fn.evaluate(vec![arg], context)
///         })
///         .collect()
/// }
///
/// # fn main() -> anyhow::Result<()> {
/// let program = "map((1, 2, 3), |x| x + 3) == (4, 5, 6)";
/// let module = Untyped::<F32Grammar>::parse_statements(program)?;
/// let module = ExecutableModule::new("test", &module)?;
///
/// let mut env = Environment::new();
/// env.insert_wrapped_fn("map", map_array);
/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
/// # Ok(())
/// # }
/// ```
pub struct FnWrapper<T, F, const CTX: bool = false> {
    /// Wrapped function or closure.
    function: F,
    /// Type-level marker for the return / argument types; no `T` value is stored.
    _arg_types: PhantomData<T>,
}
117
118impl<T, F, const CTX: bool> fmt::Debug for FnWrapper<T, F, CTX>
119where
120    F: fmt::Debug,
121{
122    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
123        formatter
124            .debug_struct("FnWrapper")
125            .field("function", &self.function)
126            .field("context", &CTX)
127            .finish()
128    }
129}
130
131impl<T, F: Clone, const CTX: bool> Clone for FnWrapper<T, F, CTX> {
132    fn clone(&self) -> Self {
133        Self {
134            function: self.function.clone(),
135            _arg_types: PhantomData,
136        }
137    }
138}
139
140impl<T, F: Copy, const CTX: bool> Copy for FnWrapper<T, F, CTX> {}
141
142// Ideally, we would want to constrain `T` and `F`, but this would make it impossible to declare
143// the constructor as `const fn`; see https://github.com/rust-lang/rust/issues/57563.
144impl<T, F, const CTX: bool> FnWrapper<T, F, CTX> {
145    /// Creates a new wrapper.
146    ///
147    /// Note that the created wrapper is not guaranteed to be usable as [`NativeFn`]. For this
148    /// to be the case, `function` needs to be a function or an [`Fn`] closure,
149    /// and the `T` type argument needs to be a tuple with the function return type
150    /// and the argument types (in this order).
151    pub const fn new(function: F) -> Self {
152        Self {
153            function,
154            _arg_types: PhantomData,
155        }
156    }
157}
158
// Implements `NativeFn` for `FnWrapper`s around functions of a specific arity.
//
// - `$arity`: number of script-level args (the optional `CallContext` arg is not counted).
// - `$with_ctx`: value of the `CTX` const parameter of `FnWrapper` (`true` / `false`).
// - `$ctx_name: $ctx_t`: binding and type of the context arg; supplied only with `$with_ctx = true`.
// - `$arg_name: $t`: binding name and type parameter for each script-level arg.
macro_rules! arity_fn {
    ($arity:tt, $with_ctx:tt $(, $ctx_name:ident : $ctx_t:ty)? => $($arg_name:ident : $t:ident),*) => {
        impl<Num, F, Ret, $($t,)*> NativeFn<Num> for FnWrapper<(Ret, $($t,)*), F, $with_ctx>
        where
            F: Fn($($ctx_t,)? $($t,)*) -> Ret,
            $($t: TryFromValue<Num>,)*
            Ret: IntoEvalResult<Num>,
        {
            #[allow(clippy::shadow_unrelated)] // makes it easier to write macro
            #[allow(unused_variables, unused_mut)] // `args_iter` is unused for 0-ary functions
            fn evaluate(
                &self,
                args: Vec<SpannedValue<Num>>,
                context: &mut CallContext<'_, Num>,
            ) -> EvalResult<Num> {
                // `check_args_count()` is expected to reject calls whose number of args differs
                // from `$arity`; the `expect()` calls below rely on this.
                context.check_args_count(&args, $arity)?;
                let mut args_iter = args.into_iter().enumerate();

                $(
                    let (index, $arg_name) = args_iter
                        .next()
                        .expect("number of args is checked by `check_args_count()`");
                    let span = $arg_name.with_no_extra();
                    // Conversion errors are tagged with the arg index and point at the arg span.
                    let $arg_name = $t::try_from_value($arg_name.extra).map_err(|mut err| {
                        err.set_arg_index(index);
                        context
                            .call_site_error(ErrorKind::Wrapper(err))
                            .with_location(&span, AuxErrorInfo::InvalidArg)
                    })?;
                )*

                // Reborrow the context for the call, so that it remains usable afterwards
                // to attach spans to errors produced by the output conversion.
                $(let $ctx_name = &mut *context;)?
                let output = (self.function)($($ctx_name,)? $($arg_name,)*);
                output.into_eval_result().map_err(|err| err.into_spanned(context))
            }
        }
    };
}
195
196arity_fn!(0, false =>);
197arity_fn!(0, true, ctx: &mut CallContext<'_, Num> =>);
198arity_fn!(1, false => x0: T);
199arity_fn!(1, true, ctx: &mut CallContext<'_, Num> => x0: T);
200arity_fn!(2, false => x0: T, x1: U);
201arity_fn!(2, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U);
202arity_fn!(3, false => x0: T, x1: U, x2: V);
203arity_fn!(3, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U, x2: V);
204arity_fn!(4, false => x0: T, x1: U, x2: V, x3: W);
205arity_fn!(4, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U, x2: V, x3: W);
206arity_fn!(5, false => x0: T, x1: U, x2: V, x3: W, x4: X);
207arity_fn!(5, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U, x2: V, x3: W, x4: X);
208arity_fn!(6, false => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y);
209arity_fn!(6, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y);
210arity_fn!(7, false => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z);
211arity_fn!(7, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z);
212arity_fn!(8, false => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A);
213arity_fn!(8, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A);
214arity_fn!(9, false => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A, x8: B);
215arity_fn!(9, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A, x8: B);
216arity_fn!(10, false => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A, x8: B, x9: C);
217arity_fn!(10, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A, x8: B, x9: C);
218
219/// Unary function wrapper.
220pub type Unary<T> = FnWrapper<(T, T), fn(T) -> T>;
221
222/// Binary function wrapper.
223pub type Binary<T> = FnWrapper<(T, T, T), fn(T, T) -> T>;
224
225/// Ternary function wrapper.
226pub type Ternary<T> = FnWrapper<(T, T, T, T), fn(T, T, T) -> T>;
227
228/// Quaternary function wrapper.
229pub type Quaternary<T> = FnWrapper<(T, T, T, T, T), fn(T, T, T, T) -> T>;
230
#[cfg(test)]
mod tests {
    use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
    use assert_matches::assert_matches;

    use super::*;
    use crate::{
        Function, Object, Tuple, Value,
        alloc::{ToOwned, format},
        env::{Environment, Prelude},
        exec::{ExecutableModule, WildcardId},
    };

    #[test]
    fn functions_with_primitive_args() -> anyhow::Result<()> {
        let unary_fn = Unary::new(|x: f32| x + 3.0);
        let binary_fn = Binary::new(f32::min);
        let ternary_fn = Ternary::new(|x: f32, y, z| if x > 0.0 { y } else { z });

        let program = "
            unary_fn(2) == 5 && binary_fn(1, -3) == -3 &&
                ternary_fn(1, 2, 3) == 2 && ternary_fn(-1, 2, 3) == 3
        ";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;

        let mut env = Environment::new();
        env.insert_native_fn("unary_fn", unary_fn)
            .insert_native_fn("binary_fn", binary_fn)
            .insert_native_fn("ternary_fn", ternary_fn);

        assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
        Ok(())
    }

    fn array_min_max(values: Vec<f32>) -> (f32, f32) {
        let mut min = f32::INFINITY;
        let mut max = f32::NEG_INFINITY;

        for value in values {
            if value < min {
                min = value;
            }
            if value > max {
                max = value;
            }
        }
        (min, max)
    }

    fn overly_convoluted_fn(xs: Vec<(f32, f32)>, ys: (Vec<f32>, f32)) -> f32 {
        xs.into_iter().map(|(a, b)| a + b).sum::<f32>() + ys.0.into_iter().sum::<f32>() + ys.1
    }

    #[test]
    fn functions_with_composite_args() -> anyhow::Result<()> {
        let program = "
            array_min_max((1, 5, -3, 2, 1)) == (-3, 5) &&
                total_sum(((1, 2), (3, 4)), ((5, 6, 7), 8)) == 36
        ";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;

        let mut env = Environment::new();
        env.insert_wrapped_fn("array_min_max", array_min_max)
            .insert_wrapped_fn("total_sum", overly_convoluted_fn);

        assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
        Ok(())
    }

    fn sum_arrays(xs: Vec<f32>, ys: Vec<f32>) -> Result<Vec<f32>, String> {
        if xs.len() == ys.len() {
            Ok(xs.into_iter().zip(ys).map(|(x, y)| x + y).collect())
        } else {
            Err("Summed arrays must have the same size".to_owned())
        }
    }

    #[test]
    fn fallible_function() -> anyhow::Result<()> {
        let program = "sum_arrays((1, 2, 3), (4, 5, 6)) == (5, 7, 9)";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;

        let mut env = Environment::new();
        env.insert_wrapped_fn("sum_arrays", sum_arrays);
        assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
        Ok(())
    }

    #[test]
    fn fallible_function_with_bogus_program() -> anyhow::Result<()> {
        let program = "sum_arrays((1, 2, 3), (4, 5))";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;

        let mut env = Environment::new();
        env.insert_wrapped_fn("sum_arrays", sum_arrays);

        let err = module.with_env(&env)?.run().unwrap_err();
        assert!(
            err.source()
                .kind()
                .to_short_string()
                .contains("Summed arrays must have the same size")
        );
        Ok(())
    }

    #[test]
    fn function_with_wrong_args_count() -> anyhow::Result<()> {
        // `sum_arrays` expects 2 args; the wrapper must report a mismatch instead of
        // trying to extract the missing arg.
        let program = "sum_arrays((1, 2, 3))";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;

        let mut env = Environment::new();
        env.insert_wrapped_fn("sum_arrays", sum_arrays);

        let err = module.with_env(&env)?.run().unwrap_err();
        assert_matches!(
            err.source().kind(),
            ErrorKind::ArgsLenMismatch { call: 1, .. }
        );
        Ok(())
    }

    #[test]
    fn function_with_bool_return_value() -> anyhow::Result<()> {
        let contains = wrap(|(a, b): (f32, f32), x: f32| (a..=b).contains(&x));

        let program = "contains((-1, 2), 0) && !contains((1, 3), 0)";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;

        let mut env = Environment::new();
        env.insert_native_fn("contains", contains);
        assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
        Ok(())
    }

    #[test]
    fn function_with_void_return_value() -> anyhow::Result<()> {
        let program = "assert_eq(3, 1 + 2)";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;

        let mut env = Environment::new();
        env.insert_wrapped_fn("assert_eq", |expected: f32, actual: f32| {
            if (expected - actual).abs() < f32::EPSILON {
                Ok(())
            } else {
                Err(format!(
                    "Assertion failed: expected {expected}, got {actual}"
                ))
            }
        });

        assert!(module.with_env(&env)?.run()?.is_void());

        let bogus_program = "assert_eq(3, 1 - 2)";
        let bogus_block = Untyped::<F32Grammar>::parse_statements(bogus_program)?;
        let err = ExecutableModule::new(WildcardId, &bogus_block)?
            .with_env(&env)?
            .run()
            .unwrap_err();

        assert_matches!(
            err.source().kind(),
            ErrorKind::NativeCall(msg) if msg.contains("Assertion failed")
        );
        Ok(())
    }

    #[test]
    fn function_with_bool_argument() -> anyhow::Result<()> {
        let program = "flip_sign(-1, true) == 1 && flip_sign(-1, false) == -1";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;

        let mut env = Environment::new();
        env.extend(Prelude::iter());
        env.insert_wrapped_fn(
            "flip_sign",
            |val: f32, flag: bool| if flag { -val } else { val },
        );

        assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
        Ok(())
    }

    #[test]
    #[allow(clippy::cast_precision_loss)] // fine for this test
    fn function_with_object_and_tuple() -> anyhow::Result<()> {
        fn test_function(tuple: Tuple<f32>) -> Object<f32> {
            let mut obj = Object::default();
            obj.insert("len", Value::Prim(tuple.len() as f32));
            obj.insert("tuple", tuple.into());
            obj
        }

        let program = "
            { len, tuple } = test((1, 1, 1));
            len == 3 && tuple == (1, 1, 1)
        ";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;

        let test_function = Value::native_fn(wrap(test_function));
        let mut env = Environment::new();
        env.insert("test", test_function).extend(Prelude::iter());

        assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
        Ok(())
    }

    #[test]
    fn error_reporting_with_destructuring() -> anyhow::Result<()> {
        let program = "destructure(((true, 1), (2, 3)))";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;

        let mut env = Environment::new();
        env.extend(Prelude::iter());
        env.insert_wrapped_fn("destructure", |values: Vec<(bool, f32)>| {
            values
                .into_iter()
                .map(|(flag, x)| if flag { x } else { 0.0 })
                .sum::<f32>()
        });

        let err = module.with_env(&env)?.run().unwrap_err();
        let err_message = err.source().kind().to_short_string();
        assert!(err_message.contains("Cannot convert primitive value to bool"));
        assert!(err_message.contains("location: arg0[1].0"));
        Ok(())
    }

    #[test]
    fn function_with_context() -> anyhow::Result<()> {
        #[allow(clippy::needless_pass_by_value)] // required for wrapping to work
        fn call(
            ctx: &mut CallContext<'_, f32>,
            func: Function<f32>,
            value: f32,
        ) -> EvalResult<f32> {
            let args = vec![ctx.apply_call_location(Value::Prim(value))];
            func.evaluate(args, ctx)
        }

        let program = "(|x| { x + 1 }).call(1)";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;

        let mut env = Environment::new();
        env.insert_wrapped_fn("call", call);
        assert_eq!(module.with_env(&env)?.run()?, Value::Prim(2.0));
        Ok(())
    }
}