1use core::{cmp::Ordering, fmt};
15
16use once_cell::unsync::OnceCell;
17
18#[cfg(feature = "std")]
19pub use self::std::Dbg;
20pub use self::{
21 array::{All, Any, Array, Filter, Fold, Len, Map, Merge, Push},
22 assertions::{Assert, AssertClose, AssertEq, AssertFails},
23 flow::{If, While},
24 wrapper::{
25 wrap, Binary, ErrorOutput, FnWrapper, FromValueError, FromValueErrorKind,
26 FromValueErrorLocation, IntoEvalResult, Quaternary, Ternary, TryFromValue, Unary,
27 },
28};
29use crate::{
30 alloc::{vec, Vec},
31 error::AuxErrorInfo,
32 CallContext, Error, ErrorKind, EvalResult, Function, NativeFn, OpaqueRef, SpannedValue, Value,
33};
34
35mod array;
36mod assertions;
37mod flow;
38#[cfg(feature = "std")]
39mod std;
40mod wrapper;
41
42fn extract_primitive<T, A>(
43 ctx: &CallContext<'_, A>,
44 value: SpannedValue<T>,
45 error_msg: &str,
46) -> Result<T, Error> {
47 match value.extra {
48 Value::Prim(value) => Ok(value),
49 _ => Err(ctx
50 .call_site_error(ErrorKind::native(error_msg))
51 .with_location(&value, AuxErrorInfo::InvalidArg)),
52 }
53}
54
55fn extract_array<T, A>(
56 ctx: &CallContext<'_, A>,
57 value: SpannedValue<T>,
58 error_msg: &str,
59) -> Result<Vec<Value<T>>, Error> {
60 if let Value::Tuple(array) = value.extra {
61 Ok(array.into())
62 } else {
63 let err = ErrorKind::native(error_msg);
64 Err(ctx
65 .call_site_error(err)
66 .with_location(&value, AuxErrorInfo::InvalidArg))
67 }
68}
69
70fn extract_fn<T, A>(
71 ctx: &CallContext<'_, A>,
72 value: SpannedValue<T>,
73 error_msg: &str,
74) -> Result<Function<T>, Error> {
75 if let Value::Function(function) = value.extra {
76 Ok(function)
77 } else {
78 let err = ErrorKind::native(error_msg);
79 Err(ctx
80 .call_site_error(err)
81 .with_location(&value, AuxErrorInfo::InvalidArg))
82 }
83}
84
/// Native comparison function taking two primitive arguments.
///
/// Both arguments must be primitive values. The comparison itself is delegated to
/// `partial_cmp` of the arithmetic associated with the call context.
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
pub enum Compare {
    /// Returns the ordering of the arguments wrapped via `Value::opaque_ref`, or a void value
    /// if the arithmetic cannot order them.
    Raw,
    /// Returns the lesser argument; if the arguments are equal, the first one is returned.
    /// Fails with `ErrorKind::CannotCompare` if the arguments cannot be ordered.
    Min,
    /// Returns the greater argument; if the arguments are equal, the first one is returned.
    /// Fails with `ErrorKind::CannotCompare` if the arguments cannot be ordered.
    Max,
}
152
153impl Compare {
154 fn extract_primitives<T>(
155 mut args: Vec<SpannedValue<T>>,
156 ctx: &mut CallContext<'_, T>,
157 ) -> Result<(T, T), Error> {
158 ctx.check_args_count(&args, 2)?;
159 let y = args.pop().unwrap();
160 let x = args.pop().unwrap();
161 let x = extract_primitive(ctx, x, COMPARE_ERROR_MSG)?;
162 let y = extract_primitive(ctx, y, COMPARE_ERROR_MSG)?;
163 Ok((x, y))
164 }
165}
166
167const COMPARE_ERROR_MSG: &str = "Compare requires 2 primitive arguments";
168
169impl<T> NativeFn<T> for Compare {
170 fn evaluate(&self, args: Vec<SpannedValue<T>>, ctx: &mut CallContext<'_, T>) -> EvalResult<T> {
171 let (x, y) = Self::extract_primitives(args, ctx)?;
172 let maybe_ordering = ctx.arithmetic().partial_cmp(&x, &y);
173
174 if let Self::Raw = self {
175 Ok(maybe_ordering.map_or_else(Value::void, Value::opaque_ref))
176 } else {
177 let ordering =
178 maybe_ordering.ok_or_else(|| ctx.call_site_error(ErrorKind::CannotCompare))?;
179 let value = match (ordering, self) {
180 (Ordering::Equal, _)
181 | (Ordering::Less, Self::Min)
182 | (Ordering::Greater, Self::Max) => x,
183 _ => y,
184 };
185 Ok(Value::Prim(value))
186 }
187 }
188}
189
/// Native function that lets its function argument refer to its own return value.
///
/// It takes exactly one argument, which must be a function. That function is called with a
/// single argument: a reference to a cell that, once the call returns, holds the call's
/// return value. The same return value is also returned by this native function.
#[derive(Clone, Copy, Debug, Default)]
pub struct Defer;
223
224impl<T: Clone + 'static> NativeFn<T> for Defer {
225 fn evaluate(
226 &self,
227 mut args: Vec<SpannedValue<T>>,
228 ctx: &mut CallContext<'_, T>,
229 ) -> EvalResult<T> {
230 const ARG_ERROR: &str = "Argument must be a function";
231
232 ctx.check_args_count(&args, 1)?;
233 let function = extract_fn(ctx, args.pop().unwrap(), ARG_ERROR)?;
234 let cell = OpaqueRef::with_identity_eq(ValueCell::<T>::default());
235 let spanned_cell = ctx.apply_call_location(Value::Ref(cell.clone()));
236 let return_value = function.evaluate(vec![spanned_cell], ctx)?;
237
238 let cell = cell.downcast_ref::<ValueCell<T>>().unwrap();
239 cell.set(return_value.clone());
241 Ok(return_value)
242 }
243}
244
245#[derive(Debug)]
246pub(crate) struct ValueCell<T> {
247 inner: OnceCell<Value<T>>,
248}
249
250impl<T> Default for ValueCell<T> {
251 fn default() -> Self {
252 Self {
253 inner: OnceCell::new(),
254 }
255 }
256}
257
258impl<T: 'static + fmt::Debug> From<ValueCell<T>> for Value<T> {
259 fn from(cell: ValueCell<T>) -> Self {
260 Self::Ref(OpaqueRef::with_identity_eq(cell))
261 }
262}
263
264impl<T> ValueCell<T> {
265 pub fn get(&self) -> Option<&Value<T>> {
267 self.inner.get()
268 }
269
270 fn set(&self, value: Value<T>) {
271 self.inner
272 .set(value)
273 .map_err(drop)
274 .expect("Repeated `ValueCell` assignment");
275 }
276}
277
#[cfg(test)]
mod tests {
    // Each test parses a small script with `Untyped<F32Grammar>`, compiles it into an
    // `ExecutableModule`, and runs it against an `Environment` populated with the native
    // functions under test.

    use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
    use assert_matches::assert_matches;

    use super::*;
    use crate::{
        env::Environment,
        exec::{ExecutableModule, WildcardId},
    };

    #[test]
    fn if_basics() -> anyhow::Result<()> {
        let block = "
            x = 1.0;
            if(x < 2, x + 5, 3 - x)
        ";
        let block = Untyped::<F32Grammar>::parse_statements(block)?;
        let module = ExecutableModule::new(WildcardId, &block)?;
        let mut env = Environment::new();
        env.insert_native_fn("if", If);
        assert_eq!(module.with_env(&env)?.run()?, Value::Prim(6.0));
        Ok(())
    }

    #[test]
    fn if_with_closures() -> anyhow::Result<()> {
        // `if` returns one of the closures, which is then called by the script.
        let block = "
            x = 4.5;
            if(x < 2, || x + 5, || 3 - x)()
        ";
        let block = Untyped::<F32Grammar>::parse_statements(block)?;
        let module = ExecutableModule::new(WildcardId, &block)?;
        let mut env = Environment::new();
        env.insert_native_fn("if", If);
        assert_eq!(module.with_env(&env)?.run()?, Value::Prim(-1.5));
        Ok(())
    }

    #[test]
    fn cmp_sugar() -> anyhow::Result<()> {
        // Comparison operators need no native functions in the environment.
        let program = "x = 1.0; x > 0 && x <= 3";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;
        assert_eq!(
            module.with_env(&Environment::new())?.run()?,
            Value::Bool(true)
        );

        // Comparing a number with a tuple fails, and the error points at the tuple.
        let bogus_program = "x = 1.0; x > (1, 2)";
        let bogus_block = Untyped::<F32Grammar>::parse_statements(bogus_program)?;
        let bogus_module = ExecutableModule::new(WildcardId, &bogus_block)?;

        let err = bogus_module
            .with_env(&Environment::new())?
            .run()
            .unwrap_err();
        let err = err.source();
        assert_matches!(err.kind(), ErrorKind::CannotCompare);
        assert_eq!(err.location().in_module().span(bogus_program), "(1, 2)");
        Ok(())
    }

    #[test]
    fn while_basic() -> anyhow::Result<()> {
        let program = "
            // Finds the greatest power of 2 lesser or equal to the value.
            discrete_log2 = |x| {
                while(0, |i| 2^i <= x, |i| i + 1) - 1
            };

            (discrete_log2(1), discrete_log2(2),
                discrete_log2(4), discrete_log2(6.5), discrete_log2(1000))
        ";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;

        let module = ExecutableModule::new(WildcardId, &block)?;
        let mut env = Environment::new();
        env.insert_native_fn("while", While)
            .insert_native_fn("if", If);

        assert_eq!(
            module.with_env(&env)?.run()?,
            Value::from(vec![
                Value::Prim(0.0),
                Value::Prim(1.0),
                Value::Prim(2.0),
                Value::Prim(2.0),
                Value::Prim(9.0),
            ])
        );
        Ok(())
    }

    #[test]
    fn max_value_with_fold() -> anyhow::Result<()> {
        let program = "
            max_value = |...xs| {
                fold(xs, -Inf, |acc, x| if(x > acc, x, acc))
            };
            max_value(1, -2, 7, 2, 5) == 7 && max_value(3, -5, 9) == 9
        ";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;

        let module = ExecutableModule::new(WildcardId, &block)?;
        let mut env = Environment::new();
        env.insert("Inf", Value::Prim(f32::INFINITY))
            .insert_native_fn("fold", Fold)
            .insert_native_fn("if", If);

        assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
        Ok(())
    }

    #[test]
    fn reverse_list_with_fold() -> anyhow::Result<()> {
        const SAMPLES: &[(&[f32], &[f32])] = &[
            (&[1.0, 2.0, 3.0], &[3.0, 2.0, 1.0]),
            (&[], &[]),
            (&[1.0], &[1.0]),
        ];

        let program = "
            reverse = |xs| {
                fold(xs, (), |acc, x| merge((x,), acc))
            };
            xs = (-4, 3, 0, 1);
            reverse(xs) == (1, 0, 3, -4)
        ";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;

        let mut env = Environment::new();
        env.insert_native_fn("merge", Merge)
            .insert_native_fn("fold", Fold);

        // Running with a mutable environment keeps `reverse` available for the module below.
        assert_eq!(module.with_mutable_env(&mut env)?.run()?, Value::Bool(true));

        let test_block = Untyped::<F32Grammar>::parse_statements("reverse(xs)")?;
        let test_module = ExecutableModule::new("test", &test_block)?;

        for &(input, expected) in SAMPLES {
            let input = input.iter().copied().map(Value::Prim).collect();
            let expected = expected.iter().copied().map(Value::Prim).collect();
            env.insert("xs", Value::Tuple(input));
            assert_eq!(test_module.with_env(&env)?.run()?, Value::Tuple(expected));
        }
        Ok(())
    }

    #[test]
    fn min_and_max_select_correct_argument() -> anyhow::Result<()> {
        // Exercises the success path of `Compare::Min` / `Compare::Max` with both argument
        // orders, so that a swapped selection branch is caught.
        let program = "
            min(1, 5) == 1 && min(5, 1) == 1 &&
                max(1, 5) == 5 && max(5, -3) == 5 && max(-3, 5) == 5
        ";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;
        let mut env = Environment::new();
        env.insert_native_fn("min", Compare::Min)
            .insert_native_fn("max", Compare::Max);

        assert_eq!(module.with_env(&env)?.run()?, Value::Bool(true));
        Ok(())
    }

    #[test]
    fn error_with_min_function_args() -> anyhow::Result<()> {
        let program = "5 - min(1, (2, 3))";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;
        let mut env = Environment::new();
        env.insert_native_fn("min", Compare::Min);

        let err = module.with_env(&env)?.run().unwrap_err();
        let err = err.source();
        assert_eq!(err.location().in_module().span(program), "min(1, (2, 3))");
        assert_matches!(
            err.kind(),
            ErrorKind::NativeCall(ref msg) if msg.contains("requires 2 primitive arguments")
        );
        Ok(())
    }

    #[test]
    fn error_with_min_function_incomparable_args() -> anyhow::Result<()> {
        let program = "5 - min(1, NAN)";
        let block = Untyped::<F32Grammar>::parse_statements(program)?;
        let module = ExecutableModule::new(WildcardId, &block)?;
        let mut env = Environment::new();
        env.insert("NAN", Value::Prim(f32::NAN))
            .insert_native_fn("min", Compare::Min);

        let err = module.with_env(&env)?.run().unwrap_err();
        let err = err.source();
        assert_eq!(err.location().in_module().span(program), "min(1, NAN)");
        assert_matches!(err.kind(), ErrorKind::CannotCompare);
        Ok(())
    }
}