arithmetic_eval/fns/array.rs
1//! Functions on arrays.
2
3use core::cmp::Ordering;
4
5use num_traits::{FromPrimitive, One, Zero};
6
7use crate::{
8 alloc::{format, vec, Vec},
9 error::AuxErrorInfo,
10 fns::{extract_array, extract_fn, extract_primitive},
11 CallContext, ErrorKind, EvalResult, NativeFn, SpannedValue, Tuple, Value,
12};
13
14/// Function generating an array by mapping its indexes.
15///
16/// # Type
17///
18/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
19///
20/// ```text
21/// (Num, (Num) -> 'T) -> ['T]
22/// ```
23///
24/// # Examples
25///
26/// ```
27/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
28/// # use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
29/// # fn main() -> anyhow::Result<()> {
30/// let program = "array(3, |i| 2 * i + 1) == (1, 3, 5)";
31/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
32/// let module = ExecutableModule::new("test_array", &program)?;
33///
34/// let mut env = Environment::new();
35/// env.insert_native_fn("array", fns::Array);
36/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
37/// # Ok(())
38/// # }
39/// ```
40#[derive(Debug, Clone, Copy, Default)]
41pub struct Array;
42
43impl<T> NativeFn<T> for Array
44where
45 T: 'static + Clone + Zero + One,
46{
47 fn evaluate<'a>(
48 &self,
49 mut args: Vec<SpannedValue<T>>,
50 ctx: &mut CallContext<'_, T>,
51 ) -> EvalResult<T> {
52 ctx.check_args_count(&args, 2)?;
53 let generation_fn = extract_fn(
54 ctx,
55 args.pop().unwrap(),
56 "`array` requires second arg to be a generation function",
57 )?;
58 let len = extract_primitive(
59 ctx,
60 args.pop().unwrap(),
61 "`array` requires first arg to be a number",
62 )?;
63
64 let mut index = T::zero();
65 let mut array = vec![];
66 loop {
67 let next_index = ctx
68 .arithmetic()
69 .add(index.clone(), T::one())
70 .map_err(|err| ctx.call_site_error(ErrorKind::Arithmetic(err)))?;
71
72 let cmp = ctx.arithmetic().partial_cmp(&next_index, &len);
73 if matches!(cmp, Some(Ordering::Less | Ordering::Equal)) {
74 let spanned = ctx.apply_call_location(Value::Prim(index));
75 array.push(generation_fn.evaluate(vec![spanned], ctx)?);
76 index = next_index;
77 } else {
78 break;
79 }
80 }
81 Ok(Value::Tuple(array.into()))
82 }
83}
84
85/// Function returning array / object length.
86///
87/// # Type
88///
89/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
90///
91/// ```text
92/// ([T]) -> Num
93/// ```
94///
95/// # Examples
96///
97/// ```
98/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
99/// # use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
100/// # fn main() -> anyhow::Result<()> {
101/// let program = "len(()) == 0 && len((1, 2, 3)) == 3";
102/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
103/// let module = ExecutableModule::new("tes_len", &program)?;
104///
105/// let mut env = Environment::new();
106/// env.insert_native_fn("len", fns::Len);
107/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
108/// # Ok(())
109/// # }
110/// ```
111#[derive(Debug, Clone, Copy, Default)]
112pub struct Len;
113
114impl<T: FromPrimitive> NativeFn<T> for Len {
115 fn evaluate(
116 &self,
117 mut args: Vec<SpannedValue<T>>,
118 ctx: &mut CallContext<'_, T>,
119 ) -> EvalResult<T> {
120 ctx.check_args_count(&args, 1)?;
121 let arg = args.pop().unwrap();
122
123 let len = match arg.extra {
124 Value::Tuple(array) => array.len(),
125 Value::Object(object) => object.len(),
126 _ => {
127 let err = ErrorKind::native("`len` requires object or tuple arg");
128 return Err(ctx
129 .call_site_error(err)
130 .with_location(&arg, AuxErrorInfo::InvalidArg));
131 }
132 };
133 let len = T::from_usize(len).ok_or_else(|| {
134 let err = ErrorKind::native("Cannot convert length to number");
135 ctx.call_site_error(err)
136 })?;
137 Ok(Value::Prim(len))
138 }
139}
140
141/// Map function that evaluates the provided function on each item of the tuple.
142///
143/// # Type
144///
145/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
146///
147/// ```text
148/// (['T; N], ('T) -> 'U) -> ['U; N]
149/// ```
150///
151/// # Examples
152///
153/// ```
154/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
155/// # use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
156/// # fn main() -> anyhow::Result<()> {
157/// let program = "
158/// xs = (1, -2, 3, -0.3);
159/// map(xs, |x| if(x > 0, x, 0)) == (1, 0, 3, 0)
160/// ";
161/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
162/// let module = ExecutableModule::new("test_map", &program)?;
163///
164/// let mut env = Environment::new();
165/// env.insert_native_fn("if", fns::If).insert_native_fn("map", fns::Map);
166/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
167/// # Ok(())
168/// # }
169/// ```
170#[derive(Debug, Clone, Copy, Default)]
171pub struct Map;
172
173impl<T: 'static + Clone> NativeFn<T> for Map {
174 fn evaluate(
175 &self,
176 mut args: Vec<SpannedValue<T>>,
177 ctx: &mut CallContext<'_, T>,
178 ) -> EvalResult<T> {
179 ctx.check_args_count(&args, 2)?;
180 let map_fn = extract_fn(
181 ctx,
182 args.pop().unwrap(),
183 "`map` requires second arg to be a mapping function",
184 )?;
185 let array = extract_array(
186 ctx,
187 args.pop().unwrap(),
188 "`map` requires first arg to be a tuple",
189 )?;
190
191 let mapped: Result<Tuple<_>, _> = array
192 .into_iter()
193 .map(|value| {
194 let spanned = ctx.apply_call_location(value);
195 map_fn.evaluate(vec![spanned], ctx)
196 })
197 .collect();
198 mapped.map(Value::Tuple)
199 }
200}
201
202/// Filter function that evaluates the provided function on each item of the tuple and retains
203/// only elements for which the function returned `true`.
204///
205/// # Type
206///
207/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
208///
209/// ```text
210/// (['T; N], ('T) -> Bool) -> ['T]
211/// ```
212///
213/// # Examples
214///
215/// ```
216/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
217/// # use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
218/// # fn main() -> anyhow::Result<()> {
219/// let program = "
220/// xs = (1, -2, 3, -7, -0.3);
221/// filter(xs, |x| x > -1) == (1, 3, -0.3)
222/// ";
223/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
224/// let module = ExecutableModule::new("test_filter", &program)?;
225///
226/// let mut env = Environment::new();
227/// env.insert_native_fn("filter", fns::Filter);
228/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
229/// # Ok(())
230/// # }
231/// ```
232#[derive(Debug, Clone, Copy, Default)]
233pub struct Filter;
234
235impl<T: 'static + Clone> NativeFn<T> for Filter {
236 fn evaluate(
237 &self,
238 mut args: Vec<SpannedValue<T>>,
239 ctx: &mut CallContext<'_, T>,
240 ) -> EvalResult<T> {
241 ctx.check_args_count(&args, 2)?;
242 let filter_fn = extract_fn(
243 ctx,
244 args.pop().unwrap(),
245 "`filter` requires second arg to be a filter function",
246 )?;
247 let array = extract_array(
248 ctx,
249 args.pop().unwrap(),
250 "`filter` requires first arg to be a tuple",
251 )?;
252
253 let mut filtered = vec![];
254 for value in array {
255 let spanned = ctx.apply_call_location(value.clone());
256 match filter_fn.evaluate(vec![spanned], ctx)? {
257 Value::Bool(true) => filtered.push(value),
258 Value::Bool(false) => { /* do nothing */ }
259 _ => {
260 let err = ErrorKind::native(
261 "`filter` requires filtering function to return booleans",
262 );
263 return Err(ctx.call_site_error(err));
264 }
265 }
266 }
267 Ok(Value::Tuple(filtered.into()))
268 }
269}
270
271/// Reduce (aka fold) function that reduces the provided tuple to a single value.
272///
273/// # Type
274///
275/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
276///
277/// ```text
278/// (['T], 'Acc, ('Acc, 'T) -> 'Acc) -> 'Acc
279/// ```
280///
281/// # Examples
282///
283/// ```
284/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
285/// # use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
286/// # fn main() -> anyhow::Result<()> {
287/// let program = "
288/// xs = (1, -2, 3, -7);
289/// fold(xs, 1, |acc, x| acc * x) == 42
290/// ";
291/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
292/// let module = ExecutableModule::new("test_fold", &program)?;
293///
294/// let mut env = Environment::new();
295/// env.insert_native_fn("fold", fns::Fold);
296/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
297/// # Ok(())
298/// # }
299/// ```
300#[derive(Debug, Clone, Copy, Default)]
301pub struct Fold;
302
303impl<T: 'static + Clone> NativeFn<T> for Fold {
304 fn evaluate(
305 &self,
306 mut args: Vec<SpannedValue<T>>,
307 ctx: &mut CallContext<'_, T>,
308 ) -> EvalResult<T> {
309 ctx.check_args_count(&args, 3)?;
310 let fold_fn = extract_fn(
311 ctx,
312 args.pop().unwrap(),
313 "`fold` requires third arg to be a folding function",
314 )?;
315 let acc = args.pop().unwrap().extra;
316 let array = extract_array(
317 ctx,
318 args.pop().unwrap(),
319 "`fold` requires first arg to be a tuple",
320 )?;
321
322 array.into_iter().try_fold(acc, |acc, value| {
323 let spanned_args = vec![ctx.apply_call_location(acc), ctx.apply_call_location(value)];
324 fold_fn.evaluate(spanned_args, ctx)
325 })
326 }
327}
328
329/// Function that appends a value onto a tuple.
330///
331/// # Type
332///
333/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
334///
335/// ```text
336/// (['T; N], 'T) -> ['T; N + 1]
337/// ```
338///
339/// # Examples
340///
341/// ```
342/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
343/// # use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
344/// # fn main() -> anyhow::Result<()> {
345/// let program = "
346/// repeat = |x, times| {
347/// (_, acc) = while(
348/// (0, ()),
349/// |(i, _)| i < times,
350/// |(i, acc)| (i + 1, push(acc, x)),
351/// );
352/// acc
353/// };
354/// repeat(-2, 3) == (-2, -2, -2) &&
355/// repeat((7,), 4) == ((7,), (7,), (7,), (7,))
356/// ";
357/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
358/// let module = ExecutableModule::new("test_push", &program)?;
359///
360/// let mut env = Environment::new();
361/// env.insert_native_fn("while", fns::While)
362/// .insert_native_fn("push", fns::Push);
363/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
364/// # Ok(())
365/// # }
366/// ```
367#[derive(Debug, Clone, Copy, Default)]
368pub struct Push;
369
370impl<T> NativeFn<T> for Push {
371 fn evaluate(
372 &self,
373 mut args: Vec<SpannedValue<T>>,
374 ctx: &mut CallContext<'_, T>,
375 ) -> EvalResult<T> {
376 ctx.check_args_count(&args, 2)?;
377 let elem = args.pop().unwrap().extra;
378 let mut array = extract_array(
379 ctx,
380 args.pop().unwrap(),
381 "`push` requires first arg to be a tuple",
382 )?;
383
384 array.push(elem);
385 Ok(Value::Tuple(array.into()))
386 }
387}
388
389/// Function that merges two tuples.
390///
391/// # Type
392///
393/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
394///
395/// ```text
396/// (['T], ['T]) -> ['T]
397/// ```
398///
399/// # Examples
400///
401/// ```
402/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
403/// # use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
404/// # fn main() -> anyhow::Result<()> {
405/// let program = "
406/// // Merges all arguments (which should be tuples) into a single tuple.
407/// super_merge = |...xs| fold(xs, (), merge);
408/// super_merge((1, 2), (3,), (), (4, 5, 6)) == (1, 2, 3, 4, 5, 6)
409/// ";
410/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
411/// let module = ExecutableModule::new("test_merge", &program)?;
412///
413/// let mut env = Environment::new();
414/// env.insert_native_fn("fold", fns::Fold)
415/// .insert_native_fn("merge", fns::Merge);
416/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
417/// # Ok(())
418/// # }
419/// ```
420#[derive(Debug, Clone, Copy, Default)]
421pub struct Merge;
422
423impl<T: Clone> NativeFn<T> for Merge {
424 fn evaluate(
425 &self,
426 mut args: Vec<SpannedValue<T>>,
427 ctx: &mut CallContext<'_, T>,
428 ) -> EvalResult<T> {
429 ctx.check_args_count(&args, 2)?;
430 let second = extract_array(
431 ctx,
432 args.pop().unwrap(),
433 "`merge` requires second arg to be a tuple",
434 )?;
435 let mut first = extract_array(
436 ctx,
437 args.pop().unwrap(),
438 "`merge` requires first arg to be a tuple",
439 )?;
440
441 first.extend_from_slice(&second);
442 Ok(Value::Tuple(first.into()))
443 }
444}
445
446/// Function that checks whether any of array items satisfy the provided predicate.
447///
448/// # Type
449///
450/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
451///
452/// ```text
453/// (['T], ('T) -> Bool) -> Bool
454/// ```
455///
456/// # Examples
457///
458/// ```
459/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
460/// # use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
461/// # fn main() -> anyhow::Result<()> {
462/// let program = "
463/// assert(any((1, 3, -1), |x| x < 0));
464/// assert(!any((1, 2, 3), |x| x < 0));
465/// ";
466/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
467/// let module = ExecutableModule::new("test_any", &program)?;
468///
469/// let mut env = Environment::new();
470/// env.insert_native_fn("any", fns::Any)
471/// .insert_native_fn("assert", fns::Assert);
472/// module.with_env(&env)?.run()?;
473/// # Ok(())
474/// # }
475/// ```
476#[derive(Debug, Clone, Copy, Default)]
477pub struct Any;
478
479impl<T: Clone + 'static> NativeFn<T> for Any {
480 fn evaluate(
481 &self,
482 mut args: Vec<SpannedValue<T>>,
483 ctx: &mut CallContext<'_, T>,
484 ) -> EvalResult<T> {
485 ctx.check_args_count(&args, 2)?;
486 let predicate = extract_fn(
487 ctx,
488 args.pop().unwrap(),
489 "`any` requires second arg to be a predicate function",
490 )?;
491 let array = extract_array(
492 ctx,
493 args.pop().unwrap(),
494 "`any` requires first arg to be a tuple",
495 )?;
496
497 for value in array {
498 let spanned = ctx.apply_call_location(value);
499 let result = predicate.evaluate(vec![spanned], ctx)?;
500 match result {
501 Value::Bool(false) => { /* continue */ }
502 Value::Bool(true) => return Ok(Value::Bool(true)),
503 _ => {
504 let err = ErrorKind::native(format!(
505 "Incorrect return type of a predicate: expected Boolean, got {}",
506 result.value_type()
507 ));
508 ctx.call_site_error(err);
509 }
510 }
511 }
512 Ok(Value::Bool(false))
513 }
514}
515
516/// Function that checks whether all of array items satisfy the provided predicate.
517///
518/// # Type
519///
520/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
521///
522/// ```text
523/// (['T], ('T) -> Bool) -> Bool
524/// ```
525///
526/// # Examples
527///
528/// ```
529/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
530/// # use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
531/// # fn main() -> anyhow::Result<()> {
532/// let program = "
533/// assert(all((1, 2, 3, 5), |x| x > 0));
534/// assert(!all((1, -2, 3), |x| x > 0));
535/// ";
536/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
537/// let module = ExecutableModule::new("test_all", &program)?;
538///
539/// let mut env = Environment::new();
540/// env.insert_native_fn("all", fns::All)
541/// .insert_native_fn("assert", fns::Assert);
542/// module.with_env(&env)?.run()?;
543/// # Ok(())
544/// # }
545/// ```
546#[derive(Debug, Clone, Copy, Default)]
547pub struct All;
548
549impl<T: Clone + 'static> NativeFn<T> for All {
550 fn evaluate(
551 &self,
552 mut args: Vec<SpannedValue<T>>,
553 ctx: &mut CallContext<'_, T>,
554 ) -> EvalResult<T> {
555 ctx.check_args_count(&args, 2)?;
556 let predicate = extract_fn(
557 ctx,
558 args.pop().unwrap(),
559 "`all` requires second arg to be a predicate function",
560 )?;
561 let array = extract_array(
562 ctx,
563 args.pop().unwrap(),
564 "`all` requires first arg to be a tuple",
565 )?;
566
567 for value in array {
568 let spanned = ctx.apply_call_location(value);
569 let result = predicate.evaluate(vec![spanned], ctx)?;
570 match result {
571 Value::Bool(false) => return Ok(Value::Bool(false)),
572 Value::Bool(true) => { /* continue */ }
573 _ => {
574 let err = ErrorKind::native(format!(
575 "Incorrect return type of a predicate: expected Boolean, got {}",
576 result.value_type()
577 ));
578 ctx.call_site_error(err);
579 }
580 }
581 }
582 Ok(Value::Bool(true))
583 }
584}
585
#[cfg(test)]
mod tests {
    use arithmetic_parser::grammars::{F32Grammar, NumGrammar, NumLiteral, Parse, Untyped};
    use assert_matches::assert_matches;

    use super::*;
    use crate::{
        arith::{OrdArithmetic, StdArithmetic, WrappingArithmetic},
        Environment, ExecutableModule,
    };

    /// Checks that `len` reports correct lengths of tuples and objects when evaluated
    /// with the numeric type `T` and the supplied arithmetic.
    fn test_len_function<T: NumLiteral, A>(arithmetic: A)
    where
        Len: NativeFn<T>,
        A: OrdArithmetic<T> + 'static,
    {
        let program = "
            len((1, 2, 3)) == 3 && len(()) == 0 &&
            len(#{}) == 0 && len(#{ x: 1 }) == 1 && len(#{ x: 1, y: 2 }) == 2
        ";
        let statements = Untyped::<NumGrammar<T>>::parse_statements(program).unwrap();
        let module = ExecutableModule::new("len", &statements).unwrap();

        let mut env = Environment::with_arithmetic(arithmetic);
        env.insert_native_fn("len", Len);

        assert_matches!(
            module.with_env(&env).unwrap().run().unwrap(),
            Value::Bool(true)
        );
    }

    #[test]
    fn len_function_in_floating_point_arithmetic() {
        test_len_function::<f32, _>(StdArithmetic);
        test_len_function::<f64, _>(StdArithmetic);
    }

    #[test]
    fn len_function_in_int_arithmetic() {
        test_len_function::<u8, _>(WrappingArithmetic);
        test_len_function::<i8, _>(WrappingArithmetic);
        test_len_function::<u64, _>(WrappingArithmetic);
        test_len_function::<i64, _>(WrappingArithmetic);
    }

    #[test]
    fn len_function_with_number_overflow() -> anyhow::Result<()> {
        let program = "len(xs)";
        let statements = Untyped::<NumGrammar<i8>>::parse_statements(program)?;
        let module = ExecutableModule::new("len", &statements)?;

        // 128 elements cannot be represented as `i8`, so converting the length must fail.
        let mut env = Environment::with_arithmetic(WrappingArithmetic);
        env.insert("xs", Value::from(vec![Value::Bool(true); 128]))
            .insert_native_fn("len", Len);

        let err = module.with_env(&env)?.run().unwrap_err();
        assert_matches!(
            err.source().kind(),
            ErrorKind::NativeCall(msg) if msg.contains("length to number")
        );
        Ok(())
    }

    #[test]
    fn array_function_in_floating_point_arithmetic() -> anyhow::Result<()> {
        // Non-positive and fractional lengths are covered in addition to integer ones.
        let program = "
            array(0, |_| 1) == () && array(-1, |_| 1) == () &&
            array(0.1, |_| 1) == () && array(0.999, |_| 1) == () &&
            array(1, |_| 1) == (1,) && array(1.5, |_| 1) == (1,) &&
            array(2, |_| 1) == (1, 1) && array(3, |i| i) == (0, 1, 2)
        ";
        let statements = Untyped::<NumGrammar<f32>>::parse_statements(program)?;
        let module = ExecutableModule::new("array", &statements)?;

        let mut env = Environment::new();
        env.insert_native_fn("array", Array);

        assert_matches!(module.with_env(&env)?.run()?, Value::Bool(true));
        Ok(())
    }

    #[test]
    fn array_function_in_unsigned_int_arithmetic() -> anyhow::Result<()> {
        let program = "
            array(0, |_| 1) == () && array(1, |_| 1) == (1,) && array(3, |i| i) == (0, 1, 2)
        ";
        let statements = Untyped::<NumGrammar<u32>>::parse_statements(program)?;
        let module = ExecutableModule::new("array", &statements)?;

        let mut env = Environment::with_arithmetic(WrappingArithmetic);
        env.insert_native_fn("array", Array);

        assert_matches!(module.with_env(&env)?.run()?, Value::Bool(true));
        Ok(())
    }

    #[test]
    fn all_and_any_are_short_circuit() -> anyhow::Result<()> {
        // The trailing Boolean items would make `x < 0` / `x > 0` fail if they were evaluated;
        // the test passes only if `all` / `any` stop at the first deciding item.
        let program = "
            !all((1, 5 == 5), |x| x < 0) && any((-1, 1, 5 == 4), |x| x > 0)
        ";
        let statements = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new("array", &statements)?;

        let mut env = Environment::new();
        env.insert_native_fn("all", All)
            .insert_native_fn("any", Any);

        assert_matches!(module.with_env(&env)?.run()?, Value::Bool(true));
        Ok(())
    }
}