arithmetic_eval/fns/
flow.rs

//! Flow control functions.
2
3use crate::{
4    alloc::{vec, Vec},
5    error::AuxErrorInfo,
6    fns::extract_fn,
7    CallContext, ErrorKind, EvalResult, NativeFn, SpannedValue, Value,
8};
9
10/// `if` function that eagerly evaluates "if" / "else" terms.
11///
12/// # Type
13///
14/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
15///
16/// ```text
17/// (Bool, 'T, 'T) -> 'T
18/// ```
19///
20/// # Examples
21///
22/// ```
23/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
24/// # use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
25/// # fn main() -> anyhow::Result<()> {
26/// let program = "x = 3; if(x == 2, -1, x + 1)";
27/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
28/// let module = ExecutableModule::new("test_if", &program)?;
29///
30/// let mut env = Environment::new();
31/// env.insert_native_fn("if", fns::If);
32/// assert_eq!(module.with_env(&env)?.run()?, Value::Prim(4.0));
33/// # Ok(())
34/// # }
35/// ```
36///
37/// You can also use the lazy evaluation by returning a function and evaluating it
38/// afterwards:
39///
40/// ```
41/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
42/// # use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
43/// # fn main() -> anyhow::Result<()> {
44/// let program = "x = 3; if(x == 2, || -1, || x + 1)()";
45/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
46/// let module = ExecutableModule::new("test_if", &program)?;
47///
48/// let mut env = Environment::new();
49/// env.insert_native_fn("if", fns::If);
50/// assert_eq!(module.with_env(&env)?.run()?, Value::Prim(4.0));
51/// # Ok(())
52/// # }
53/// ```
54#[derive(Debug, Clone, Copy, Default)]
55pub struct If;
56
57impl<T> NativeFn<T> for If {
58    fn evaluate(
59        &self,
60        mut args: Vec<SpannedValue<T>>,
61        ctx: &mut CallContext<'_, T>,
62    ) -> EvalResult<T> {
63        ctx.check_args_count(&args, 3)?;
64        let else_val = args.pop().unwrap().extra;
65        let then_val = args.pop().unwrap().extra;
66
67        if let Value::Bool(condition) = &args[0].extra {
68            Ok(if *condition { then_val } else { else_val })
69        } else {
70            let err = ErrorKind::native("`if` requires first arg to be boolean");
71            Err(ctx
72                .call_site_error(err)
73                .with_location(&args[0], AuxErrorInfo::InvalidArg))
74        }
75    }
76}
77
78/// Loop function that evaluates the provided closure while a certain condition is true.
79/// Returns the loop state afterwards.
80///
81/// # Type
82///
83/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
84///
85/// ```text
86/// ('T, ('T) -> Bool, ('T) -> 'T) -> 'T
87/// ```
88///
89/// # Examples
90///
91/// ```
92/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
93/// # use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
94/// # fn main() -> anyhow::Result<()> {
95/// let program = "
96///     factorial = |x| {
97///         (_, acc) = while(
98///             (x, 1),
99///             |(i, _)| i >= 1,
100///             |(i, acc)| (i - 1, acc * i),
101///         );
102///         acc
103///     };
104///     factorial(5) == 120 && factorial(10) == 3628800
105/// ";
106/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
107/// let module = ExecutableModule::new("test_while", &program)?;
108///
109/// let mut env = Environment::new();
110/// env.insert_native_fn("while", fns::While);
111/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
112/// # Ok(())
113/// # }
114/// ```
115#[derive(Debug, Clone, Copy, Default)]
116pub struct While;
117
118impl<T: 'static + Clone> NativeFn<T> for While {
119    fn evaluate(
120        &self,
121        mut args: Vec<SpannedValue<T>>,
122        ctx: &mut CallContext<'_, T>,
123    ) -> EvalResult<T> {
124        ctx.check_args_count(&args, 3)?;
125
126        let step_fn = extract_fn(
127            ctx,
128            args.pop().unwrap(),
129            "`while` requires third arg to be a step function",
130        )?;
131        let condition_fn = extract_fn(
132            ctx,
133            args.pop().unwrap(),
134            "`while` requires second arg to be a condition function",
135        )?;
136        let mut state = args.pop().unwrap();
137        let state_span = state.copy_with_extra(());
138
139        loop {
140            let condition_value = condition_fn.evaluate(vec![state.clone()], ctx)?;
141            match condition_value {
142                Value::Bool(true) => {
143                    let new_state = step_fn.evaluate(vec![state], ctx)?;
144                    state = state_span.copy_with_extra(new_state);
145                }
146                Value::Bool(false) => break Ok(state.extra),
147                _ => {
148                    let err =
149                        ErrorKind::native("`while` requires condition function to return booleans");
150                    return Err(ctx.call_site_error(err));
151                }
152            }
153        }
154    }
155}