arithmetic_eval/fns/flow.rs
1//! Flow control functions.
2
3use crate::{
4 alloc::{vec, Vec},
5 error::AuxErrorInfo,
6 fns::extract_fn,
7 CallContext, ErrorKind, EvalResult, NativeFn, SpannedValue, Value,
8};
9
/// Conditional function that picks one of two already-computed values.
///
/// Both branch arguments are evaluated before `if` is called (i.e., eagerly); the function
/// only chooses which of them to return. If the first argument is `true`, the second
/// argument is returned. If it is `false`, the third argument is returned.
///
/// # Type
///
/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
///
/// ```text
/// (Bool, 'T, 'T) -> 'T
/// ```
///
/// # Errors
///
/// An error is returned in two cases:
///
/// - The argument count check (`check_args_count` with 3) fails.
/// - The first argument is not a boolean. The error points at that argument.
///
/// # Examples
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// # use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
/// # fn main() -> anyhow::Result<()> {
/// let program = "x = 3; if(x == 2, -1, x + 1)";
/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
/// let module = ExecutableModule::new("test_if", &program)?;
///
/// let mut env = Environment::new();
/// env.insert_native_fn("if", fns::If);
/// assert_eq!(module.with_env(&env)?.run()?, Value::Prim(4.0));
/// # Ok(())
/// # }
/// ```
///
/// You can get lazy evaluation by passing closures as branches and calling the one
/// that `if` returns:
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// # use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
/// # fn main() -> anyhow::Result<()> {
/// let program = "x = 3; if(x == 2, || -1, || x + 1)()";
/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
/// let module = ExecutableModule::new("test_if", &program)?;
///
/// let mut env = Environment::new();
/// env.insert_native_fn("if", fns::If);
/// assert_eq!(module.with_env(&env)?.run()?, Value::Prim(4.0));
/// # Ok(())
/// # }
/// ```
#[derive(Clone, Copy, Debug, Default)]
pub struct If;
56
57impl<T> NativeFn<T> for If {
58 fn evaluate(
59 &self,
60 mut args: Vec<SpannedValue<T>>,
61 ctx: &mut CallContext<'_, T>,
62 ) -> EvalResult<T> {
63 ctx.check_args_count(&args, 3)?;
64 let else_val = args.pop().unwrap().extra;
65 let then_val = args.pop().unwrap().extra;
66
67 if let Value::Bool(condition) = &args[0].extra {
68 Ok(if *condition { then_val } else { else_val })
69 } else {
70 let err = ErrorKind::native("`if` requires first arg to be boolean");
71 Err(ctx
72 .call_site_error(err)
73 .with_location(&args[0], AuxErrorInfo::InvalidArg))
74 }
75 }
76}
77
/// Loop function that repeatedly transforms a state while a condition holds.
/// The final state is returned.
///
/// Each iteration calls the condition function (second argument) with the current state:
///
/// - If it returns `true`, the step function (third argument) is called with the state,
///   and its result becomes the new state.
/// - If it returns `false`, the loop stops and the current state is returned.
///
/// The first argument is the initial state.
///
/// # Type
///
/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
///
/// ```text
/// ('T, ('T) -> Bool, ('T) -> 'T) -> 'T
/// ```
///
/// # Errors
///
/// An error is returned in any of these cases:
///
/// - The argument count check (`check_args_count` with 3) fails.
/// - The third or second argument is not a function. The third argument is checked first.
/// - The condition function returns a non-boolean value.
/// - Calling either function fails.
///
/// # Examples
///
/// ```
/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// # use arithmetic_eval::{fns, Environment, ExecutableModule, Value};
/// # fn main() -> anyhow::Result<()> {
/// let program = "
///     factorial = |x| {
///         (_, acc) = while(
///             (x, 1),
///             |(i, _)| i >= 1,
///             |(i, acc)| (i - 1, acc * i),
///         );
///         acc
///     };
///     factorial(5) == 120 && factorial(10) == 3628800
/// ";
/// let program = Untyped::<F32Grammar>::parse_statements(program)?;
/// let module = ExecutableModule::new("test_while", &program)?;
///
/// let mut env = Environment::new();
/// env.insert_native_fn("while", fns::While);
/// assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
/// # Ok(())
/// # }
/// ```
#[derive(Clone, Copy, Debug, Default)]
pub struct While;
117
118impl<T: 'static + Clone> NativeFn<T> for While {
119 fn evaluate(
120 &self,
121 mut args: Vec<SpannedValue<T>>,
122 ctx: &mut CallContext<'_, T>,
123 ) -> EvalResult<T> {
124 ctx.check_args_count(&args, 3)?;
125
126 let step_fn = extract_fn(
127 ctx,
128 args.pop().unwrap(),
129 "`while` requires third arg to be a step function",
130 )?;
131 let condition_fn = extract_fn(
132 ctx,
133 args.pop().unwrap(),
134 "`while` requires second arg to be a condition function",
135 )?;
136 let mut state = args.pop().unwrap();
137 let state_span = state.copy_with_extra(());
138
139 loop {
140 let condition_value = condition_fn.evaluate(vec![state.clone()], ctx)?;
141 match condition_value {
142 Value::Bool(true) => {
143 let new_state = step_fn.evaluate(vec![state], ctx)?;
144 state = state_span.copy_with_extra(new_state);
145 }
146 Value::Bool(false) => break Ok(state.extra),
147 _ => {
148 let err =
149 ErrorKind::native("`while` requires condition function to return booleans");
150 return Err(ctx.call_site_error(err));
151 }
152 }
153 }
154 }
155}