arithmetic_eval/fns/wrapper/
mod.rs

1//! Wrapper for eloquent `NativeFn` definitions.
2
3use core::{fmt, marker::PhantomData};
4
5pub use self::traits::{
6    ErrorOutput, FromValueError, FromValueErrorKind, FromValueErrorLocation, IntoEvalResult,
7    TryFromValue,
8};
9use crate::{
10    alloc::Vec, error::AuxErrorInfo, CallContext, ErrorKind, EvalResult, NativeFn, SpannedValue,
11};
12
13mod traits;
14
15/// Wraps a function enriching it with the information about its arguments.
16/// This is a slightly shorter way to create wrappers compared to calling [`FnWrapper::new()`].
17///
18/// See [`FnWrapper`] for more details on function requirements.
19pub const fn wrap<const CTX: bool, T, F>(function: F) -> FnWrapper<T, F, CTX> {
20    FnWrapper::new(function)
21}
22
/// Wrapper of a function containing information about its arguments.
///
/// Using `FnWrapper` allows defining [native functions](NativeFn) with minimum boilerplate
/// and with increased type safety. `FnWrapper`s can be constructed explicitly or indirectly
/// via [`Environment::insert_wrapped_fn()`], [`Value::wrapped_fn()`], or [`wrap()`].
///
/// Arguments of a wrapped function must implement the [`TryFromValue`] trait for the applicable
/// grammar, and the output type must implement [`IntoEvalResult`]. If you need [`CallContext`] (e.g.,
/// to call functions provided as an argument), it should be specified as the first argument.
///
/// # Type parameters
///
/// - `T` is a tuple consisting of the function return type followed by its argument types.
/// - `F` is the type of the wrapped function.
/// - `CTX` signals whether the function takes `&mut CallContext` as its first argument;
///   it defaults to `false`.
///
/// [`Environment::insert_wrapped_fn()`]: crate::Environment::insert_wrapped_fn()
/// [`Value::wrapped_fn()`]: crate::Value::wrapped_fn()
///
/// # Examples
///
/// ## Basic function
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
///
/// # fn main() -> anyhow::Result<()> {
/// let max = fns::wrap(|x: f32, y: f32| if x > y { x } else { y });
///
/// let program = "max(1, 3) == 3 && max(-1, -3) == -1";
/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
/// let module = ExecutableModule::new("test_max", &program)?;
///
/// let mut env = Environment::new();
/// env.insert_native_fn("max", max);
/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
/// # Ok(())
/// # }
/// ```
///
/// ## Fallible function with complex args
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// # use arithmetic_eval::{fns::FnWrapper, Environment, ExecutableModule, Value};
/// fn zip_arrays(xs: Vec<f32>, ys: Vec<f32>) -> Result<Vec<(f32, f32)>, String> {
///     if xs.len() == ys.len() {
///         Ok(xs.into_iter().zip(ys).map(|(x, y)| (x, y)).collect())
///     } else {
///         Err("Arrays must have the same size".to_owned())
///     }
/// }
///
/// # fn main() -> anyhow::Result<()> {
/// let program = "zip((1, 2, 3), (4, 5, 6)) == ((1, 4), (2, 5), (3, 6))";
/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
/// let module = ExecutableModule::new("test_zip", &program)?;
///
/// let mut env = Environment::new();
/// env.insert_wrapped_fn("zip", zip_arrays);
/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
/// # Ok(())
/// # }
/// ```
///
/// ## Using `CallContext` to call functions
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// # use arithmetic_eval::{CallContext, Function, Environment, Value, ExecutableModule, Error};
/// fn map_array(
///     context: &mut CallContext<'_, f32>,
///     array: Vec<Value<f32>>,
///     map_fn: Function<f32>,
/// ) -> Result<Vec<Value<f32>>, Error> {
///     array
///         .into_iter()
///         .map(|value| {
///             let arg = context.apply_call_location(value);
///             map_fn.evaluate(vec![arg], context)
///         })
///         .collect()
/// }
///
/// # fn main() -> anyhow::Result<()> {
/// let program = "map((1, 2, 3), |x| x + 3) == (4, 5, 6)";
/// let module = Untyped::<F32Grammar>::parse_statements(program)?;
/// let module = ExecutableModule::new("test", &module)?;
///
/// let mut env = Environment::new();
/// env.insert_wrapped_fn("map", map_array);
/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
/// # Ok(())
/// # }
/// ```
pub struct FnWrapper<T, F, const CTX: bool = false> {
    // The wrapped callable.
    function: F,
    // Marker tying the otherwise unused type parameter `T` to the wrapper.
    _arg_types: PhantomData<T>,
}
117
118impl<T, F, const CTX: bool> fmt::Debug for FnWrapper<T, F, CTX>
119where
120    F: fmt::Debug,
121{
122    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
123        formatter
124            .debug_struct("FnWrapper")
125            .field("function", &self.function)
126            .field("context", &CTX)
127            .finish()
128    }
129}
130
131impl<T, F: Clone, const CTX: bool> Clone for FnWrapper<T, F, CTX> {
132    fn clone(&self) -> Self {
133        Self {
134            function: self.function.clone(),
135            _arg_types: PhantomData,
136        }
137    }
138}
139
140impl<T, F: Copy, const CTX: bool> Copy for FnWrapper<T, F, CTX> {}
141
142// Ideally, we would want to constrain `T` and `F`, but this would make it impossible to declare
143// the constructor as `const fn`; see https://github.com/rust-lang/rust/issues/57563.
144impl<T, F, const CTX: bool> FnWrapper<T, F, CTX> {
145    /// Creates a new wrapper.
146    ///
147    /// Note that the created wrapper is not guaranteed to be usable as [`NativeFn`]. For this
148    /// to be the case, `function` needs to be a function or an [`Fn`] closure,
149    /// and the `T` type argument needs to be a tuple with the function return type
150    /// and the argument types (in this order).
151    pub const fn new(function: F) -> Self {
152        Self {
153            function,
154            _arg_types: PhantomData,
155        }
156    }
157}
158
// Generates a `NativeFn` implementation for `FnWrapper` around a function with a fixed
// number of arguments.
//
// Macro inputs:
// - `$arity`: the number of script-level arguments the function accepts;
// - `$with_ctx`: value of the `CTX` const parameter of the wrapper;
// - optional `$ctx_name: $ctx_t`: name and type of the leading context argument;
// - `$arg_name: $t` pairs: variable name and generic type parameter for each argument.
macro_rules! arity_fn {
    ($arity:tt, $with_ctx:tt $(, $ctx_name:ident : $ctx_t:ty)? => $($arg_name:ident : $t:ident),*) => {
        impl<Num, F, Ret, $($t,)*> NativeFn<Num> for FnWrapper<(Ret, $($t,)*), F, $with_ctx>
        where
            F: Fn($($ctx_t,)? $($t,)*) -> Ret,
            $($t: TryFromValue<Num>,)*
            Ret: IntoEvalResult<Num>,
        {
            // `shadow_unrelated`: rebinding `$arg_name` keeps the macro simple.
            // `unused_*`: the argument iterator goes unused for 0-ary functions.
            #[allow(clippy::shadow_unrelated, unused_variables, unused_mut)]
            fn evaluate(
                &self,
                args: Vec<SpannedValue<Num>>,
                context: &mut CallContext<'_, Num>,
            ) -> EvalResult<Num> {
                context.check_args_count(&args, $arity)?;
                let mut indexed_args = args.into_iter().enumerate();

                $(
                    // The argument count has been checked above, so the iterator is
                    // expected to yield an item for each declared argument.
                    let (position, $arg_name) = indexed_args.next().unwrap();
                    // The location must be captured before `extra` is moved out.
                    let arg_span = $arg_name.with_no_extra();
                    let $arg_name = <$t as TryFromValue<Num>>::try_from_value($arg_name.extra)
                        .map_err(|mut conversion_err| {
                            conversion_err.set_arg_index(position);
                            context
                                .call_site_error(ErrorKind::Wrapper(conversion_err))
                                .with_location(&arg_span, AuxErrorInfo::InvalidArg)
                        })?;
                )*

                // Reborrow the context for the call so that it remains usable afterwards.
                $(let $ctx_name = &mut *context;)?
                let returned = (self.function)($($ctx_name,)? $($arg_name,)*);
                returned
                    .into_eval_result()
                    .map_err(|err| err.into_spanned(context))
            }
        }
    };
}
195
// `NativeFn` implementations for wrapped functions taking 0 to 10 arguments. Each arity
// is instantiated twice: without a context argument (`CTX = false`) and with a leading
// `&mut CallContext` argument (`CTX = true`).
arity_fn!(0, false =>);
arity_fn!(0, true, ctx: &mut CallContext<'_, Num> =>);
arity_fn!(1, false => x0: T);
arity_fn!(1, true, ctx: &mut CallContext<'_, Num> => x0: T);
arity_fn!(2, false => x0: T, x1: U);
arity_fn!(2, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U);
arity_fn!(3, false => x0: T, x1: U, x2: V);
arity_fn!(3, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U, x2: V);
arity_fn!(4, false => x0: T, x1: U, x2: V, x3: W);
arity_fn!(4, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U, x2: V, x3: W);
arity_fn!(5, false => x0: T, x1: U, x2: V, x3: W, x4: X);
arity_fn!(5, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U, x2: V, x3: W, x4: X);
arity_fn!(6, false => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y);
arity_fn!(6, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y);
arity_fn!(7, false => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z);
arity_fn!(7, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z);
arity_fn!(8, false => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A);
arity_fn!(8, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A);
arity_fn!(9, false => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A, x8: B);
arity_fn!(9, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A, x8: B);
arity_fn!(10, false => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A, x8: B, x9: C);
arity_fn!(10, true, ctx: &mut CallContext<'_, Num> => x0: T, x1: U, x2: V, x3: W, x4: X, x5: Y, x6: Z, x7: A, x8: B, x9: C);
218
/// Unary function wrapper: wraps a function pointer `fn(T) -> T`, i.e., a function with
/// a single argument and a return value of the same type `T`.
pub type Unary<T> = FnWrapper<(T, T), fn(T) -> T>;

/// Binary function wrapper: wraps a function pointer `fn(T, T) -> T`, i.e., a function with
/// two arguments and a return value of the same type `T`.
pub type Binary<T> = FnWrapper<(T, T, T), fn(T, T) -> T>;

/// Ternary function wrapper: wraps a function pointer `fn(T, T, T) -> T`, i.e., a function
/// with three arguments and a return value of the same type `T`.
pub type Ternary<T> = FnWrapper<(T, T, T, T), fn(T, T, T) -> T>;

/// Quaternary function wrapper: wraps a function pointer `fn(T, T, T, T) -> T`, i.e.,
/// a function with four arguments and a return value of the same type `T`.
pub type Quaternary<T> = FnWrapper<(T, T, T, T, T), fn(T, T, T, T) -> T>;
230
231#[cfg(test)]
232mod tests {
233    use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
234    use assert_matches::assert_matches;
235
236    use super::*;
237    use crate::{
238        alloc::{format, ToOwned},
239        env::{Environment, Prelude},
240        exec::{ExecutableModule, WildcardId},
241        Function, Object, Tuple, Value,
242    };
243
244    #[test]
245    fn functions_with_primitive_args() -> anyhow::Result<()> {
246        let unary_fn = Unary::new(|x: f32| x + 3.0);
247        let binary_fn = Binary::new(f32::min);
248        let ternary_fn = Ternary::new(|x: f32, y, z| if x > 0.0 { y } else { z });
249
250        let program = "
251            unary_fn(2) == 5 && binary_fn(1, -3) == -3 &&
252                ternary_fn(1, 2, 3) == 2 && ternary_fn(-1, 2, 3) == 3
253        ";
254        let block = Untyped::<F32Grammar>::parse_statements(program)?;
255        let module = ExecutableModule::new(WildcardId, &block)?;
256
257        let mut env = Environment::new();
258        env.insert_native_fn("unary_fn", unary_fn)
259            .insert_native_fn("binary_fn", binary_fn)
260            .insert_native_fn("ternary_fn", ternary_fn);
261
262        assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
263        Ok(())
264    }
265
266    fn array_min_max(values: Vec<f32>) -> (f32, f32) {
267        let mut min = f32::INFINITY;
268        let mut max = f32::NEG_INFINITY;
269
270        for value in values {
271            if value < min {
272                min = value;
273            }
274            if value > max {
275                max = value;
276            }
277        }
278        (min, max)
279    }
280
281    fn overly_convoluted_fn(xs: Vec<(f32, f32)>, ys: (Vec<f32>, f32)) -> f32 {
282        xs.into_iter().map(|(a, b)| a + b).sum::<f32>() + ys.0.into_iter().sum::<f32>() + ys.1
283    }
284
285    #[test]
286    fn functions_with_composite_args() -> anyhow::Result<()> {
287        let program = "
288            array_min_max((1, 5, -3, 2, 1)) == (-3, 5) &&
289                total_sum(((1, 2), (3, 4)), ((5, 6, 7), 8)) == 36
290        ";
291        let block = Untyped::<F32Grammar>::parse_statements(program)?;
292        let module = ExecutableModule::new(WildcardId, &block)?;
293
294        let mut env = Environment::new();
295        env.insert_wrapped_fn("array_min_max", array_min_max)
296            .insert_wrapped_fn("total_sum", overly_convoluted_fn);
297
298        assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
299        Ok(())
300    }
301
302    fn sum_arrays(xs: Vec<f32>, ys: Vec<f32>) -> Result<Vec<f32>, String> {
303        if xs.len() == ys.len() {
304            Ok(xs.into_iter().zip(ys).map(|(x, y)| x + y).collect())
305        } else {
306            Err("Summed arrays must have the same size".to_owned())
307        }
308    }
309
310    #[test]
311    fn fallible_function() -> anyhow::Result<()> {
312        let program = "sum_arrays((1, 2, 3), (4, 5, 6)) == (5, 7, 9)";
313        let block = Untyped::<F32Grammar>::parse_statements(program)?;
314        let module = ExecutableModule::new(WildcardId, &block)?;
315
316        let mut env = Environment::new();
317        env.insert_wrapped_fn("sum_arrays", sum_arrays);
318        assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
319        Ok(())
320    }
321
322    #[test]
323    fn fallible_function_with_bogus_program() -> anyhow::Result<()> {
324        let program = "sum_arrays((1, 2, 3), (4, 5))";
325        let block = Untyped::<F32Grammar>::parse_statements(program)?;
326        let module = ExecutableModule::new(WildcardId, &block)?;
327
328        let mut env = Environment::new();
329        env.insert_wrapped_fn("sum_arrays", sum_arrays);
330
331        let err = module.with_env(&env)?.run().unwrap_err();
332        assert!(err
333            .source()
334            .kind()
335            .to_short_string()
336            .contains("Summed arrays must have the same size"));
337        Ok(())
338    }
339
340    #[test]
341    fn function_with_bool_return_value() -> anyhow::Result<()> {
342        let contains = wrap(|(a, b): (f32, f32), x: f32| (a..=b).contains(&x));
343
344        let program = "contains((-1, 2), 0) && !contains((1, 3), 0)";
345        let block = Untyped::<F32Grammar>::parse_statements(program)?;
346        let module = ExecutableModule::new(WildcardId, &block)?;
347
348        let mut env = Environment::new();
349        env.insert_native_fn("contains", contains);
350        assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
351        Ok(())
352    }
353
354    #[test]
355    fn function_with_void_return_value() -> anyhow::Result<()> {
356        let program = "assert_eq(3, 1 + 2)";
357        let block = Untyped::<F32Grammar>::parse_statements(program)?;
358        let module = ExecutableModule::new(WildcardId, &block)?;
359
360        let mut env = Environment::new();
361        env.insert_wrapped_fn("assert_eq", |expected: f32, actual: f32| {
362            if (expected - actual).abs() < f32::EPSILON {
363                Ok(())
364            } else {
365                Err(format!(
366                    "Assertion failed: expected {expected}, got {actual}"
367                ))
368            }
369        });
370
371        assert!(module.with_env(&env)?.run()?.is_void());
372
373        let bogus_program = "assert_eq(3, 1 - 2)";
374        let bogus_block = Untyped::<F32Grammar>::parse_statements(bogus_program)?;
375        let err = ExecutableModule::new(WildcardId, &bogus_block)?
376            .with_env(&env)?
377            .run()
378            .unwrap_err();
379
380        assert_matches!(
381            err.source().kind(),
382            ErrorKind::NativeCall(ref msg) if msg.contains("Assertion failed")
383        );
384        Ok(())
385    }
386
387    #[test]
388    fn function_with_bool_argument() -> anyhow::Result<()> {
389        let program = "flip_sign(-1, true) == 1 && flip_sign(-1, false) == -1";
390        let block = Untyped::<F32Grammar>::parse_statements(program)?;
391        let module = ExecutableModule::new(WildcardId, &block)?;
392
393        let mut env = Environment::new();
394        env.extend(Prelude::iter());
395        env.insert_wrapped_fn(
396            "flip_sign",
397            |val: f32, flag: bool| if flag { -val } else { val },
398        );
399
400        assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
401        Ok(())
402    }
403
404    #[test]
405    #[allow(clippy::cast_precision_loss)] // fine for this test
406    fn function_with_object_and_tuple() -> anyhow::Result<()> {
407        fn test_function(tuple: Tuple<f32>) -> Object<f32> {
408            let mut obj = Object::default();
409            obj.insert("len", Value::Prim(tuple.len() as f32));
410            obj.insert("tuple", tuple.into());
411            obj
412        }
413
414        let program = "
415            { len, tuple } = test((1, 1, 1));
416            len == 3 && tuple == (1, 1, 1)
417        ";
418        let block = Untyped::<F32Grammar>::parse_statements(program)?;
419        let module = ExecutableModule::new(WildcardId, &block)?;
420
421        let test_function = Value::native_fn(wrap(test_function));
422        let mut env = Environment::new();
423        env.insert("test", test_function).extend(Prelude::iter());
424
425        assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
426        Ok(())
427    }
428
429    #[test]
430    fn error_reporting_with_destructuring() -> anyhow::Result<()> {
431        let program = "destructure(((true, 1), (2, 3)))";
432        let block = Untyped::<F32Grammar>::parse_statements(program)?;
433        let module = ExecutableModule::new(WildcardId, &block)?;
434
435        let mut env = Environment::new();
436        env.extend(Prelude::iter());
437        env.insert_wrapped_fn("destructure", |values: Vec<(bool, f32)>| {
438            values
439                .into_iter()
440                .map(|(flag, x)| if flag { x } else { 0.0 })
441                .sum::<f32>()
442        });
443
444        let err = module.with_env(&env)?.run().unwrap_err();
445        let err_message = err.source().kind().to_short_string();
446        assert!(err_message.contains("Cannot convert primitive value to bool"));
447        assert!(err_message.contains("location: arg0[1].0"));
448        Ok(())
449    }
450
451    #[test]
452    fn function_with_context() -> anyhow::Result<()> {
453        #[allow(clippy::needless_pass_by_value)] // required for wrapping to work
454        fn call(
455            ctx: &mut CallContext<'_, f32>,
456            func: Function<f32>,
457            value: f32,
458        ) -> EvalResult<f32> {
459            let args = vec![ctx.apply_call_location(Value::Prim(value))];
460            func.evaluate(args, ctx)
461        }
462
463        let program = "(|x| { x + 1 }).call(1)";
464        let block = Untyped::<F32Grammar>::parse_statements(program)?;
465        let module = ExecutableModule::new(WildcardId, &block)?;
466
467        let mut env = Environment::new();
468        env.insert_wrapped_fn("call", call);
469        assert_eq!(module.with_env(&env)?.run()?, Value::Prim(2.0));
470        Ok(())
471    }
472}