arithmetic_eval/fns/
array.rs

1//! Functions on arrays.
2
3use core::cmp::Ordering;
4
5use num_traits::{FromPrimitive, One, Zero};
6
7use crate::{
8    alloc::{format, vec, Vec},
9    error::AuxErrorInfo,
10    fns::{extract_array, extract_fn, extract_primitive},
11    CallContext, ErrorKind, EvalResult, NativeFn, SpannedValue, Tuple, Value,
12};
13
14/// Function generating an array by mapping its indexes.
15///
16/// # Type
17///
18/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
19///
20/// ```text
21/// (Num, (Num) -> 'T) -> ['T]
22/// ```
23///
24/// # Examples
25///
26/// ```
27/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
28/// # use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
29/// # fn main() -> anyhow::Result<()> {
30/// let program = "array(3, |i| 2 * i + 1) == (1, 3, 5)";
31/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
32/// let module = ExecutableModule::new("test_array", &program)?;
33///
34/// let mut env = Environment::new();
35/// env.insert_native_fn("array", fns::Array);
36/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
37/// # Ok(())
38/// # }
39/// ```
40#[derive(Debug, Clone, Copy, Default)]
41pub struct Array;
42
43impl<T> NativeFn<T> for Array
44where
45    T: 'static + Clone + Zero + One,
46{
47    fn evaluate<'a>(
48        &self,
49        mut args: Vec<SpannedValue<T>>,
50        ctx: &mut CallContext<'_, T>,
51    ) -> EvalResult<T> {
52        ctx.check_args_count(&args, 2)?;
53        let generation_fn = extract_fn(
54            ctx,
55            args.pop().unwrap(),
56            "`array` requires second arg to be a generation function",
57        )?;
58        let len = extract_primitive(
59            ctx,
60            args.pop().unwrap(),
61            "`array` requires first arg to be a number",
62        )?;
63
64        let mut index = T::zero();
65        let mut array = vec![];
66        loop {
67            let next_index = ctx
68                .arithmetic()
69                .add(index.clone(), T::one())
70                .map_err(|err| ctx.call_site_error(ErrorKind::Arithmetic(err)))?;
71
72            let cmp = ctx.arithmetic().partial_cmp(&next_index, &len);
73            if matches!(cmp, Some(Ordering::Less | Ordering::Equal)) {
74                let spanned = ctx.apply_call_location(Value::Prim(index));
75                array.push(generation_fn.evaluate(vec![spanned], ctx)?);
76                index = next_index;
77            } else {
78                break;
79            }
80        }
81        Ok(Value::Tuple(array.into()))
82    }
83}
84
85/// Function returning array / object length.
86///
87/// # Type
88///
89/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
90///
91/// ```text
92/// ([T]) -> Num
93/// ```
94///
95/// # Examples
96///
97/// ```
98/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
99/// # use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
100/// # fn main() -> anyhow::Result<()> {
101/// let program = "len(()) == 0 && len((1, 2, 3)) == 3";
102/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
103/// let module = ExecutableModule::new("tes_len", &program)?;
104///
105/// let mut env = Environment::new();
106/// env.insert_native_fn("len", fns::Len);
107/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
108/// # Ok(())
109/// # }
110/// ```
111#[derive(Debug, Clone, Copy, Default)]
112pub struct Len;
113
114impl<T: FromPrimitive> NativeFn<T> for Len {
115    fn evaluate(
116        &self,
117        mut args: Vec<SpannedValue<T>>,
118        ctx: &mut CallContext<'_, T>,
119    ) -> EvalResult<T> {
120        ctx.check_args_count(&args, 1)?;
121        let arg = args.pop().unwrap();
122
123        let len = match arg.extra {
124            Value::Tuple(array) => array.len(),
125            Value::Object(object) => object.len(),
126            _ => {
127                let err = ErrorKind::native("`len` requires object or tuple arg");
128                return Err(ctx
129                    .call_site_error(err)
130                    .with_location(&arg, AuxErrorInfo::InvalidArg));
131            }
132        };
133        let len = T::from_usize(len).ok_or_else(|| {
134            let err = ErrorKind::native("Cannot convert length to number");
135            ctx.call_site_error(err)
136        })?;
137        Ok(Value::Prim(len))
138    }
139}
140
141/// Map function that evaluates the provided function on each item of the tuple.
142///
143/// # Type
144///
145/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
146///
147/// ```text
148/// (['T; N], ('T) -> 'U) -> ['U; N]
149/// ```
150///
151/// # Examples
152///
153/// ```
154/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
155/// # use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
156/// # fn main() -> anyhow::Result<()> {
157/// let program = "
158///     xs = (1, -2, 3, -0.3);
159///     map(xs, |x| if(x > 0, x, 0)) == (1, 0, 3, 0)
160/// ";
161/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
162/// let module = ExecutableModule::new("test_map", &program)?;
163///
164/// let mut env = Environment::new();
165/// env.insert_native_fn("if", fns::If).insert_native_fn("map", fns::Map);
166/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
167/// # Ok(())
168/// # }
169/// ```
170#[derive(Debug, Clone, Copy, Default)]
171pub struct Map;
172
173impl<T: 'static + Clone> NativeFn<T> for Map {
174    fn evaluate(
175        &self,
176        mut args: Vec<SpannedValue<T>>,
177        ctx: &mut CallContext<'_, T>,
178    ) -> EvalResult<T> {
179        ctx.check_args_count(&args, 2)?;
180        let map_fn = extract_fn(
181            ctx,
182            args.pop().unwrap(),
183            "`map` requires second arg to be a mapping function",
184        )?;
185        let array = extract_array(
186            ctx,
187            args.pop().unwrap(),
188            "`map` requires first arg to be a tuple",
189        )?;
190
191        let mapped: Result<Tuple<_>, _> = array
192            .into_iter()
193            .map(|value| {
194                let spanned = ctx.apply_call_location(value);
195                map_fn.evaluate(vec![spanned], ctx)
196            })
197            .collect();
198        mapped.map(Value::Tuple)
199    }
200}
201
202/// Filter function that evaluates the provided function on each item of the tuple and retains
203/// only elements for which the function returned `true`.
204///
205/// # Type
206///
207/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
208///
209/// ```text
210/// (['T; N], ('T) -> Bool) -> ['T]
211/// ```
212///
213/// # Examples
214///
215/// ```
216/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
217/// # use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
218/// # fn main() -> anyhow::Result<()> {
219/// let program = "
220///     xs = (1, -2, 3, -7, -0.3);
221///     filter(xs, |x| x > -1) == (1, 3, -0.3)
222/// ";
223/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
224/// let module = ExecutableModule::new("test_filter", &program)?;
225///
226/// let mut env = Environment::new();
227/// env.insert_native_fn("filter", fns::Filter);
228/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
229/// # Ok(())
230/// # }
231/// ```
232#[derive(Debug, Clone, Copy, Default)]
233pub struct Filter;
234
235impl<T: 'static + Clone> NativeFn<T> for Filter {
236    fn evaluate(
237        &self,
238        mut args: Vec<SpannedValue<T>>,
239        ctx: &mut CallContext<'_, T>,
240    ) -> EvalResult<T> {
241        ctx.check_args_count(&args, 2)?;
242        let filter_fn = extract_fn(
243            ctx,
244            args.pop().unwrap(),
245            "`filter` requires second arg to be a filter function",
246        )?;
247        let array = extract_array(
248            ctx,
249            args.pop().unwrap(),
250            "`filter` requires first arg to be a tuple",
251        )?;
252
253        let mut filtered = vec![];
254        for value in array {
255            let spanned = ctx.apply_call_location(value.clone());
256            match filter_fn.evaluate(vec![spanned], ctx)? {
257                Value::Bool(true) => filtered.push(value),
258                Value::Bool(false) => { /* do nothing */ }
259                _ => {
260                    let err = ErrorKind::native(
261                        "`filter` requires filtering function to return booleans",
262                    );
263                    return Err(ctx.call_site_error(err));
264                }
265            }
266        }
267        Ok(Value::Tuple(filtered.into()))
268    }
269}
270
271/// Reduce (aka fold) function that reduces the provided tuple to a single value.
272///
273/// # Type
274///
275/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
276///
277/// ```text
278/// (['T], 'Acc, ('Acc, 'T) -> 'Acc) -> 'Acc
279/// ```
280///
281/// # Examples
282///
283/// ```
284/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
285/// # use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
286/// # fn main() -> anyhow::Result<()> {
287/// let program = "
288///     xs = (1, -2, 3, -7);
289///     fold(xs, 1, |acc, x| acc * x) == 42
290/// ";
291/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
292/// let module = ExecutableModule::new("test_fold", &program)?;
293///
294/// let mut env = Environment::new();
295/// env.insert_native_fn("fold", fns::Fold);
296/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
297/// # Ok(())
298/// # }
299/// ```
300#[derive(Debug, Clone, Copy, Default)]
301pub struct Fold;
302
303impl<T: 'static + Clone> NativeFn<T> for Fold {
304    fn evaluate(
305        &self,
306        mut args: Vec<SpannedValue<T>>,
307        ctx: &mut CallContext<'_, T>,
308    ) -> EvalResult<T> {
309        ctx.check_args_count(&args, 3)?;
310        let fold_fn = extract_fn(
311            ctx,
312            args.pop().unwrap(),
313            "`fold` requires third arg to be a folding function",
314        )?;
315        let acc = args.pop().unwrap().extra;
316        let array = extract_array(
317            ctx,
318            args.pop().unwrap(),
319            "`fold` requires first arg to be a tuple",
320        )?;
321
322        array.into_iter().try_fold(acc, |acc, value| {
323            let spanned_args = vec![ctx.apply_call_location(acc), ctx.apply_call_location(value)];
324            fold_fn.evaluate(spanned_args, ctx)
325        })
326    }
327}
328
329/// Function that appends a value onto a tuple.
330///
331/// # Type
332///
333/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
334///
335/// ```text
336/// (['T; N], 'T) -> ['T; N + 1]
337/// ```
338///
339/// # Examples
340///
341/// ```
342/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
343/// # use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
344/// # fn main() -> anyhow::Result<()> {
345/// let program = "
346///     repeat = |x, times| {
347///         (_, acc) = while(
348///             (0, ()),
349///             |(i, _)| i < times,
350///             |(i, acc)| (i + 1, push(acc, x)),
351///         );
352///         acc
353///     };
354///     repeat(-2, 3) == (-2, -2, -2) &&
355///         repeat((7,), 4) == ((7,), (7,), (7,), (7,))
356/// ";
357/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
358/// let module = ExecutableModule::new("test_push", &program)?;
359///
360/// let mut env = Environment::new();
361/// env.insert_native_fn("while", fns::While)
362///     .insert_native_fn("push", fns::Push);
363/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
364/// # Ok(())
365/// # }
366/// ```
367#[derive(Debug, Clone, Copy, Default)]
368pub struct Push;
369
370impl<T> NativeFn<T> for Push {
371    fn evaluate(
372        &self,
373        mut args: Vec<SpannedValue<T>>,
374        ctx: &mut CallContext<'_, T>,
375    ) -> EvalResult<T> {
376        ctx.check_args_count(&args, 2)?;
377        let elem = args.pop().unwrap().extra;
378        let mut array = extract_array(
379            ctx,
380            args.pop().unwrap(),
381            "`push` requires first arg to be a tuple",
382        )?;
383
384        array.push(elem);
385        Ok(Value::Tuple(array.into()))
386    }
387}
388
389/// Function that merges two tuples.
390///
391/// # Type
392///
393/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
394///
395/// ```text
396/// (['T], ['T]) -> ['T]
397/// ```
398///
399/// # Examples
400///
401/// ```
402/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
403/// # use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
404/// # fn main() -> anyhow::Result<()> {
405/// let program = "
406///     // Merges all arguments (which should be tuples) into a single tuple.
407///     super_merge = |...xs| fold(xs, (), merge);
408///     super_merge((1, 2), (3,), (), (4, 5, 6)) == (1, 2, 3, 4, 5, 6)
409/// ";
410/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
411/// let module = ExecutableModule::new("test_merge", &program)?;
412///
413/// let mut env = Environment::new();
414/// env.insert_native_fn("fold", fns::Fold)
415///     .insert_native_fn("merge", fns::Merge);
416/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
417/// # Ok(())
418/// # }
419/// ```
420#[derive(Debug, Clone, Copy, Default)]
421pub struct Merge;
422
423impl<T: Clone> NativeFn<T> for Merge {
424    fn evaluate(
425        &self,
426        mut args: Vec<SpannedValue<T>>,
427        ctx: &mut CallContext<'_, T>,
428    ) -> EvalResult<T> {
429        ctx.check_args_count(&args, 2)?;
430        let second = extract_array(
431            ctx,
432            args.pop().unwrap(),
433            "`merge` requires second arg to be a tuple",
434        )?;
435        let mut first = extract_array(
436            ctx,
437            args.pop().unwrap(),
438            "`merge` requires first arg to be a tuple",
439        )?;
440
441        first.extend_from_slice(&second);
442        Ok(Value::Tuple(first.into()))
443    }
444}
445
446/// Function that checks whether any of array items satisfy the provided predicate.
447///
448/// # Type
449///
450/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
451///
452/// ```text
453/// (['T], ('T) -> Bool) -> Bool
454/// ```
455///
456/// # Examples
457///
458/// ```
459/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
460/// # use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
461/// # fn main() -> anyhow::Result<()> {
462/// let program = "
463///     assert(any((1, 3, -1), |x| x < 0));
464///     assert(!any((1, 2, 3), |x| x < 0));
465/// ";
466/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
467/// let module = ExecutableModule::new("test_any", &program)?;
468///
469/// let mut env = Environment::new();
470/// env.insert_native_fn("any", fns::Any)
471///     .insert_native_fn("assert", fns::Assert);
472/// module.with_env(&env)?.run()?;
473/// # Ok(())
474/// # }
475/// ```
476#[derive(Debug, Clone, Copy, Default)]
477pub struct Any;
478
479impl<T: Clone + 'static> NativeFn<T> for Any {
480    fn evaluate(
481        &self,
482        mut args: Vec<SpannedValue<T>>,
483        ctx: &mut CallContext<'_, T>,
484    ) -> EvalResult<T> {
485        ctx.check_args_count(&args, 2)?;
486        let predicate = extract_fn(
487            ctx,
488            args.pop().unwrap(),
489            "`any` requires second arg to be a predicate function",
490        )?;
491        let array = extract_array(
492            ctx,
493            args.pop().unwrap(),
494            "`any` requires first arg to be a tuple",
495        )?;
496
497        for value in array {
498            let spanned = ctx.apply_call_location(value);
499            let result = predicate.evaluate(vec![spanned], ctx)?;
500            match result {
501                Value::Bool(false) => { /* continue */ }
502                Value::Bool(true) => return Ok(Value::Bool(true)),
503                _ => {
504                    let err = ErrorKind::native(format!(
505                        "Incorrect return type of a predicate: expected Boolean, got {}",
506                        result.value_type()
507                    ));
508                    ctx.call_site_error(err);
509                }
510            }
511        }
512        Ok(Value::Bool(false))
513    }
514}
515
516/// Function that checks whether all of array items satisfy the provided predicate.
517///
518/// # Type
519///
520/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
521///
522/// ```text
523/// (['T], ('T) -> Bool) -> Bool
524/// ```
525///
526/// # Examples
527///
528/// ```
529/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
530/// # use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
531/// # fn main() -> anyhow::Result<()> {
532/// let program = "
533///     assert(all((1, 2, 3, 5), |x| x > 0));
534///     assert(!all((1, -2, 3), |x| x > 0));
535/// ";
536/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
537/// let module = ExecutableModule::new("test_all", &program)?;
538///
539/// let mut env = Environment::new();
540/// env.insert_native_fn("all", fns::All)
541///     .insert_native_fn("assert", fns::Assert);
542/// module.with_env(&env)?.run()?;
543/// # Ok(())
544/// # }
545/// ```
546#[derive(Debug, Clone, Copy, Default)]
547pub struct All;
548
549impl<T: Clone + 'static> NativeFn<T> for All {
550    fn evaluate(
551        &self,
552        mut args: Vec<SpannedValue<T>>,
553        ctx: &mut CallContext<'_, T>,
554    ) -> EvalResult<T> {
555        ctx.check_args_count(&args, 2)?;
556        let predicate = extract_fn(
557            ctx,
558            args.pop().unwrap(),
559            "`all` requires second arg to be a predicate function",
560        )?;
561        let array = extract_array(
562            ctx,
563            args.pop().unwrap(),
564            "`all` requires first arg to be a tuple",
565        )?;
566
567        for value in array {
568            let spanned = ctx.apply_call_location(value);
569            let result = predicate.evaluate(vec![spanned], ctx)?;
570            match result {
571                Value::Bool(false) => return Ok(Value::Bool(false)),
572                Value::Bool(true) => { /* continue */ }
573                _ => {
574                    let err = ErrorKind::native(format!(
575                        "Incorrect return type of a predicate: expected Boolean, got {}",
576                        result.value_type()
577                    ));
578                    ctx.call_site_error(err);
579                }
580            }
581        }
582        Ok(Value::Bool(true))
583    }
584}
585
#[cfg(test)]
mod tests {
    use arithmetic_parser::grammars::{F32Grammar, NumGrammar, NumLiteral, Parse, Untyped};
    use assert_matches::assert_matches;

    use super::*;
    use crate::{
        arith::{OrdArithmetic, StdArithmetic, WrappingArithmetic},
        Environment, ExecutableModule,
    };

    /// Runs `len` on tuples and objects of several sizes using the given arithmetic
    /// and checks that every length is computed as expected.
    fn test_len_function<T: NumLiteral, A>(arithmetic: A)
    where
        Len: NativeFn<T>,
        A: OrdArithmetic<T> + 'static,
    {
        let program = "
            len((1, 2, 3)) == 3 && len(()) == 0 &&
            len(#{}) == 0 && len(#{ x: 1 }) == 1 && len(#{ x: 1, y: 2 }) == 2
        ";
        let statements = Untyped::<NumGrammar<T>>::parse_statements(program).unwrap();
        let module = ExecutableModule::new("len", &statements).unwrap();
        let mut env = Environment::with_arithmetic(arithmetic);
        env.insert_native_fn("len", Len);

        let outcome = module.with_env(&env).unwrap().run().unwrap();
        assert_matches!(outcome, Value::Bool(true));
    }

    #[test]
    fn len_function_in_floating_point_arithmetic() {
        test_len_function::<f32, _>(StdArithmetic);
        test_len_function::<f64, _>(StdArithmetic);
    }

    #[test]
    fn len_function_in_int_arithmetic() {
        test_len_function::<u8, _>(WrappingArithmetic);
        test_len_function::<i8, _>(WrappingArithmetic);
        test_len_function::<u64, _>(WrappingArithmetic);
        test_len_function::<i64, _>(WrappingArithmetic);
    }

    /// A tuple with 128 items has a length not representable as `i8`.
    #[test]
    fn len_function_with_number_overflow() -> anyhow::Result<()> {
        let statements = Untyped::<NumGrammar<i8>>::parse_statements("len(xs)")?;
        let module = ExecutableModule::new("len", &statements)?;

        let mut env = Environment::with_arithmetic(WrappingArithmetic);
        let xs = Value::from(vec![Value::Bool(true); 128]);
        env.insert("xs", xs).insert_native_fn("len", Len);

        let err = module.with_env(&env)?.run().unwrap_err();
        assert_matches!(
            err.source().kind(),
            ErrorKind::NativeCall(msg) if msg.contains("length to number")
        );
        Ok(())
    }

    /// Covers zero, negative and fractional lengths, which are rounded down.
    #[test]
    fn array_function_in_floating_point_arithmetic() -> anyhow::Result<()> {
        let program = "
            array(0, |_| 1) == () && array(-1, |_| 1) == () &&
            array(0.1, |_| 1) == () && array(0.999, |_| 1) == () &&
            array(1, |_| 1) == (1,) && array(1.5, |_| 1) == (1,) &&
            array(2, |_| 1) == (1, 1) && array(3, |i| i) == (0, 1, 2)
        ";
        let statements = Untyped::<NumGrammar<f32>>::parse_statements(program)?;
        let module = ExecutableModule::new("array", &statements)?;

        let mut env = Environment::new();
        env.insert_native_fn("array", Array);

        let outcome = module.with_env(&env)?.run()?;
        assert_matches!(outcome, Value::Bool(true));
        Ok(())
    }

    #[test]
    fn array_function_in_unsigned_int_arithmetic() -> anyhow::Result<()> {
        let program = "
            array(0, |_| 1) == () && array(1, |_| 1) == (1,) && array(3, |i| i) == (0, 1, 2)
        ";
        let statements = Untyped::<NumGrammar<u32>>::parse_statements(program)?;
        let module = ExecutableModule::new("array", &statements)?;

        let mut env = Environment::with_arithmetic(WrappingArithmetic);
        env.insert_native_fn("array", Array);

        let outcome = module.with_env(&env)?.run()?;
        assert_matches!(outcome, Value::Bool(true));
        Ok(())
    }

    /// The trailing Boolean items are never passed to the numeric predicates
    /// if `all` / `any` stop at the first decisive item.
    #[test]
    fn all_and_any_are_short_circuit() -> anyhow::Result<()> {
        let program = "
            !all((1, 5 == 5), |x| x < 0) && any((-1, 1, 5 == 4), |x| x > 0)
        ";
        let statements = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new("array", &statements)?;

        let mut env = Environment::new();
        env.insert_native_fn("all", All)
            .insert_native_fn("any", Any);

        let outcome = module.with_env(&env)?.run()?;
        assert_matches!(outcome, Value::Bool(true));
        Ok(())
    }
}