arithmetic_eval/fns/
assertions.rs

1//! Assertion functions.
2
3use core::{cmp::Ordering, fmt};
4
5use super::extract_fn;
6use crate::{
7    alloc::Vec,
8    error::{AuxErrorInfo, Error},
9    CallContext, ErrorKind, EvalResult, NativeFn, SpannedValue, Value,
10};
11
12/// Assertion function.
13///
14/// # Type
15///
16/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
17///
18/// ```text
19/// (Bool) -> ()
20/// ```
21///
22/// # Examples
23///
24/// ```
25/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
26/// # use arithmetic_eval::{fns, Environment, ErrorKind, ExecutableModule};
27/// # use assert_matches::assert_matches;
28/// # fn main() -> anyhow::Result<()> {
29/// let program = "
30///     assert(1 + 2 != 5); // this assertion is fine
31///     assert(3^2 > 10); // this one will fail
32/// ";
33/// let module = Untyped::<F32Grammar>::parse_statements(program)?;
34/// let module = ExecutableModule::new("test_assert", &module)?;
35///
36/// let mut env = Environment::new();
37/// env.insert_native_fn("assert", fns::Assert);
38///
39/// let err = module.with_env(&env)?.run().unwrap_err();
40/// assert_eq!(
41///     err.source().location().in_module().span(&program),
42///     "assert(3^2 > 10)"
43/// );
44/// assert_matches!(
45///     err.source().kind(),
46///     ErrorKind::NativeCall(msg) if msg == "Assertion failed"
47/// );
48/// # Ok(())
49/// # }
50/// ```
51#[derive(Debug, Clone, Copy, Default)]
52pub struct Assert;
53
54impl<T> NativeFn<T> for Assert {
55    fn evaluate<'a>(
56        &self,
57        args: Vec<SpannedValue<T>>,
58        ctx: &mut CallContext<'_, T>,
59    ) -> EvalResult<T> {
60        ctx.check_args_count(&args, 1)?;
61        match args[0].extra {
62            Value::Bool(true) => Ok(Value::void()),
63
64            Value::Bool(false) => {
65                let err = ErrorKind::native("Assertion failed");
66                Err(ctx.call_site_error(err))
67            }
68
69            _ => {
70                let err = ErrorKind::native("`assert` requires a single boolean argument");
71                Err(ctx
72                    .call_site_error(err)
73                    .with_location(&args[0], AuxErrorInfo::InvalidArg))
74            }
75        }
76    }
77}
78
79fn create_error_with_values<T: fmt::Display>(
80    err: ErrorKind,
81    args: &[SpannedValue<T>],
82    ctx: &CallContext<'_, T>,
83) -> Error {
84    ctx.call_site_error(err)
85        .with_location(&args[0], AuxErrorInfo::arg_value(&args[0].extra))
86        .with_location(&args[1], AuxErrorInfo::arg_value(&args[1].extra))
87}
88
89/// Equality assertion function.
90///
91/// # Type
92///
93/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
94///
95/// ```text
96/// ('T, 'T) -> ()
97/// ```
98///
99/// # Examples
100///
101/// ```
102/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
103/// # use arithmetic_eval::{fns, Environment, ErrorKind, ExecutableModule};
104/// # use assert_matches::assert_matches;
105/// # fn main() -> anyhow::Result<()> {
106/// let program = "
107///     assert_eq(1 + 2, 3); // this assertion is fine
108///     assert_eq(3^2, 10); // this one will fail
109/// ";
110/// let module = Untyped::<F32Grammar>::parse_statements(program)?;
111/// let module = ExecutableModule::new("test_assert", &module)?;
112///
113/// let mut env = Environment::new();
114/// env.insert_native_fn("assert_eq", fns::AssertEq);
115///
116/// let err = module.with_env(&env)?.run().unwrap_err();
117/// assert_eq!(
118///     err.source().location().in_module().span(program),
119///     "assert_eq(3^2, 10)"
120/// );
121/// assert_matches!(
122///     err.source().kind(),
123///     ErrorKind::NativeCall(msg) if msg == "Equality assertion failed"
124/// );
125/// # Ok(())
126/// # }
127/// ```
128#[derive(Debug, Clone, Copy, Default)]
129pub struct AssertEq;
130
131impl<T: fmt::Display> NativeFn<T> for AssertEq {
132    fn evaluate(&self, args: Vec<SpannedValue<T>>, ctx: &mut CallContext<'_, T>) -> EvalResult<T> {
133        ctx.check_args_count(&args, 2)?;
134
135        let is_equal = args[0]
136            .extra
137            .eq_by_arithmetic(&args[1].extra, ctx.arithmetic());
138
139        if is_equal {
140            Ok(Value::void())
141        } else {
142            let err = ErrorKind::native("Equality assertion failed");
143            Err(create_error_with_values(err, &args, ctx))
144        }
145    }
146}
147
148/// Assertion that two values are close to each other.
149///
150/// Unlike [`AssertEq`], the arguments must be primitive. The function is parameterized by
151/// the tolerance threshold.
152///
153/// # Type
154///
155/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
156///
157/// ```text
158/// (Num, Num) -> ()
159/// ```
160///
161/// # Examples
162///
163/// ```
164/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
165/// # use arithmetic_eval::{fns, Environment, ExecutableModule};
166/// # use assert_matches::assert_matches;
167/// # fn main() -> anyhow::Result<()> {
168/// let program = "
169///     assert_close(sqrt(9), 3); // this assertion is fine
170///     assert_close(sqrt(10), 3); // this one should fail
171/// ";
172/// let module = Untyped::<F32Grammar>::parse_statements(program)?;
173/// let module = ExecutableModule::new("test_assert", &module)?;
174///
175/// let mut env = Environment::new();
176/// env.insert_native_fn("assert_close", fns::AssertClose::new(1e-4))
177///     .insert_wrapped_fn("sqrt", f32::sqrt);
178///
179/// let err = module.with_env(&env)?.run().unwrap_err();
180/// assert_eq!(
181///     err.source().location().in_module().span(program),
182///     "assert_close(sqrt(10), 3)"
183/// );
184/// # Ok(())
185/// # }
186/// ```
187// TODO: support structured values?
188#[derive(Debug, Clone, Copy)]
189pub struct AssertClose<T> {
190    tolerance: T,
191}
192
193impl<T> AssertClose<T> {
194    /// Creates a function with the specified tolerance threshold. No checks are performed
195    /// on the threshold (e.g., that it is positive).
196    pub const fn new(tolerance: T) -> Self {
197        Self { tolerance }
198    }
199
200    fn extract_primitive_ref<'r>(
201        ctx: &mut CallContext<'_, T>,
202        value: &'r SpannedValue<T>,
203    ) -> Result<&'r T, Error> {
204        const ARG_ERROR: &str = "Function arguments must be primitive numbers";
205
206        match &value.extra {
207            Value::Prim(value) => Ok(value),
208            _ => Err(ctx
209                .call_site_error(ErrorKind::native(ARG_ERROR))
210                .with_location(value, AuxErrorInfo::InvalidArg)),
211        }
212    }
213}
214
215impl<T: Clone + fmt::Display> NativeFn<T> for AssertClose<T> {
216    fn evaluate(&self, args: Vec<SpannedValue<T>>, ctx: &mut CallContext<'_, T>) -> EvalResult<T> {
217        ctx.check_args_count(&args, 2)?;
218        let rhs = Self::extract_primitive_ref(ctx, &args[0])?;
219        let lhs = Self::extract_primitive_ref(ctx, &args[1])?;
220
221        let arith = ctx.arithmetic();
222        let diff = match arith.partial_cmp(lhs, rhs) {
223            Some(Ordering::Less | Ordering::Equal) => arith.sub(rhs.clone(), lhs.clone()),
224            Some(Ordering::Greater) => arith.sub(lhs.clone(), rhs.clone()),
225            None => {
226                let err = ErrorKind::native("Values are not comparable");
227                return Err(create_error_with_values(err, &args, ctx));
228            }
229        };
230        let diff = diff.map_err(|err| ctx.call_site_error(ErrorKind::Arithmetic(err)))?;
231
232        match arith.partial_cmp(&diff, &self.tolerance) {
233            Some(Ordering::Less | Ordering::Equal) => Ok(Value::void()),
234            Some(Ordering::Greater) => {
235                let err = ErrorKind::native("Values are not close");
236                Err(create_error_with_values(err, &args, ctx))
237            }
238            None => {
239                let err = ErrorKind::native("Error comparing value difference to tolerance");
240                Err(ctx.call_site_error(err))
241            }
242        }
243    }
244}
245
246/// Assertion that the provided function raises an error. Errors can optionally be matched
247/// against a predicate.
248///
249/// If an error is raised, but does not match the predicate, it is bubbled up.
250///
251/// # Type
252///
253/// (using [`arithmetic-typing`](https://docs.rs/arithmetic-typing/) notation)
254///
255/// ```text
256/// (() -> 'T) -> ()
257/// ```
258///
259/// # Examples
260///
261/// ```
262/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
263/// # use arithmetic_eval::{fns, Environment, ExecutableModule};
264/// # use assert_matches::assert_matches;
265/// # fn main() -> anyhow::Result<()> {
266/// let program = "
267///     obj = #{ x: 3 };
268///     assert_fails(|| obj.x + obj.y); // pass: `obj.y` is not defined
269///     assert_fails(|| obj.x); // fail: function executes successfully
270/// ";
271/// let module = Untyped::<F32Grammar>::parse_statements(program)?;
272/// let module = ExecutableModule::new("test_assert", &module)?;
273///
274/// let mut env = Environment::new();
275/// env.insert_native_fn("assert_fails", fns::AssertFails::default());
276///
277/// let err = module.with_env(&env)?.run().unwrap_err();
278/// assert_eq!(
279///     err.source().location().in_module().span(program),
280///     "assert_fails(|| obj.x)"
281/// );
282/// # Ok(())
283/// # }
284/// ```
285///
286/// Custom error matching:
287///
288/// ```
289/// # use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
290/// # use arithmetic_eval::{ErrorKind, fns, Environment, ExecutableModule};
291/// # use assert_matches::assert_matches;
292/// # fn main() -> anyhow::Result<()> {
293/// let assert_fails = fns::AssertFails::new(|err| {
294///     matches!(err.kind(), ErrorKind::NativeCall(_))
295/// });
296///
297/// let program = "
298///     assert_fails(|| assert_fails(1)); // pass: native error
299///     assert_fails(assert_fails); // fail: arg len mismatch
300/// ";
301/// let module = Untyped::<F32Grammar>::parse_statements(program)?;
302/// let module = ExecutableModule::new("test_assert", &module)?;
303///
304/// let mut env = Environment::new();
305/// env.insert_native_fn("assert_fails", assert_fails);
306///
307/// let err = module.with_env(&env)?.run().unwrap_err();
308/// assert_eq!(
309///     err.source().location().in_module().span(program),
310///     "assert_fails(assert_fails)"
311/// );
312/// # Ok(())
313/// # }
314/// ```
315#[derive(Clone, Copy)]
316pub struct AssertFails {
317    error_matcher: fn(&Error) -> bool,
318}
319
320impl fmt::Debug for AssertFails {
321    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
322        formatter.debug_tuple("AssertFails").finish()
323    }
324}
325
326impl Default for AssertFails {
327    fn default() -> Self {
328        Self {
329            error_matcher: |_| true,
330        }
331    }
332}
333
334impl AssertFails {
335    /// Creates an assertion function with a custom error matcher. If the error does not match,
336    /// the assertion will fail, and the error will bubble up.
337    pub fn new(error_matcher: fn(&Error) -> bool) -> Self {
338        Self { error_matcher }
339    }
340}
341
342impl<T: 'static + Clone> NativeFn<T> for AssertFails {
343    fn evaluate(
344        &self,
345        mut args: Vec<SpannedValue<T>>,
346        ctx: &mut CallContext<'_, T>,
347    ) -> EvalResult<T> {
348        const ARG_ERROR: &str = "Single argument must be a function";
349
350        ctx.check_args_count(&args, 1)?;
351        let closure = extract_fn(ctx, args.pop().unwrap(), ARG_ERROR)?;
352        match closure.evaluate(Vec::new(), ctx) {
353            Ok(_) => {
354                let err = ErrorKind::native("Function did not fail");
355                Err(ctx.call_site_error(err))
356            }
357            Err(err) => {
358                if (self.error_matcher)(&err) {
359                    Ok(Value::void())
360                } else {
361                    // Pass the error through.
362                    Err(err)
363                }
364            }
365        }
366    }
367}
368
#[cfg(test)]
mod tests {
    use arithmetic_parser::{Location, LvalueLen};
    use assert_matches::assert_matches;

    use super::*;
    use crate::{arith::CheckedArithmetic, exec::WildcardId, Environment, Object};

    /// Attaches an empty dummy location to `value`, so it can be passed as a call argument.
    fn span_value<T>(value: Value<T>) -> SpannedValue<T> {
        Location::from_str("", ..).copy_with_extra(value)
    }

    /// Asserts that `err` is a native call error whose message contains `expected`.
    fn assert_native_error(err: &Error, expected: &str) {
        assert_matches!(
            err.kind(),
            ErrorKind::NativeCall(msg) if msg.contains(expected)
        );
    }

    /// Asserts that `err` was caused by an invalid number of call arguments.
    fn assert_args_len_mismatch(err: &Error) {
        assert_matches!(err.kind(), ErrorKind::ArgsLenMismatch { .. });
    }

    #[test]
    fn assert_basics() {
        let env = Environment::with_arithmetic(<CheckedArithmetic>::new());
        let mut ctx = CallContext::<u32>::mock(WildcardId, Location::from_str("", ..), &env);

        assert_args_len_mismatch(&Assert.evaluate(vec![], &mut ctx).unwrap_err());

        let err = Assert
            .evaluate(vec![span_value(Value::Prim(1))], &mut ctx)
            .unwrap_err();
        assert_native_error(&err, "requires a single boolean argument");

        let err = Assert
            .evaluate(vec![span_value(Value::Bool(false))], &mut ctx)
            .unwrap_err();
        assert_native_error(&err, "Assertion failed");

        let true_arg = span_value(Value::Bool(true));
        let return_value = Assert.evaluate(vec![true_arg.clone()], &mut ctx).unwrap();
        assert!(return_value.is_void(), "{return_value:?}");

        let err = Assert
            .evaluate(vec![true_arg.clone(), true_arg], &mut ctx)
            .unwrap_err();
        assert_args_len_mismatch(&err);
    }

    #[test]
    fn assert_eq_basics() {
        let env = Environment::with_arithmetic(<CheckedArithmetic>::new());
        let mut ctx = CallContext::<u32>::mock(WildcardId, Location::from_str("", ..), &env);

        assert_args_len_mismatch(&AssertEq.evaluate(vec![], &mut ctx).unwrap_err());

        let one = span_value(Value::Prim(1));
        let two = span_value(Value::Prim(2));
        let err = AssertEq
            .evaluate(vec![one.clone(), two], &mut ctx)
            .unwrap_err();
        assert_native_error(&err, "assertion failed");

        let return_value = AssertEq.evaluate(vec![one.clone(), one], &mut ctx).unwrap();
        assert!(return_value.is_void(), "{return_value:?}");
    }

    #[test]
    fn assert_close_basics() {
        let assert_close = AssertClose::new(1e-3);
        let env = Environment::new();
        let mut ctx = CallContext::<f32>::mock(WildcardId, Location::from_str("", ..), &env);
        let prim_args =
            |x: f32, y: f32| vec![span_value(Value::Prim(x)), span_value(Value::Prim(y))];

        assert_args_len_mismatch(&assert_close.evaluate(vec![], &mut ctx).unwrap_err());

        let valid_arg = span_value(Value::Prim(1.0));
        let invalid_args = [
            Value::Bool(true),
            vec![Value::Prim(1.0)].into(),
            Object::just("test", Value::Prim(1.0)).into(),
        ];
        for invalid_arg in invalid_args {
            let call_args = vec![valid_arg.clone(), span_value(invalid_arg)];
            let err = assert_close.evaluate(call_args, &mut ctx).unwrap_err();
            assert_native_error(&err, "must be primitive numbers");
        }

        for (x, y) in [(0.0, 1.0), (1.0, 1.01), (0.0, f32::INFINITY)] {
            let err = assert_close.evaluate(prim_args(x, y), &mut ctx).unwrap_err();
            assert_native_error(&err, "Values are not close");
        }

        for (x, y) in [(0.0, f32::NAN), (f32::NAN, 1.0), (f32::NAN, f32::NAN)] {
            let err = assert_close.evaluate(prim_args(x, y), &mut ctx).unwrap_err();
            assert_native_error(&err, "Values are not comparable");
        }

        for (x, y) in [(1.0, 0.9999), (0.9999, 1.0), (1.0, 1.0)] {
            let return_value = assert_close.evaluate(prim_args(x, y), &mut ctx).unwrap();
            assert!(return_value.is_void(), "{return_value:?}");
        }
    }

    #[test]
    fn assert_fails_basics() {
        let assert_fails = AssertFails::default();
        let env = Environment::new();
        let mut ctx = CallContext::<f32>::mock(WildcardId, Location::from_str("", ..), &env);

        assert_args_len_mismatch(&assert_fails.evaluate(vec![], &mut ctx).unwrap_err());

        let err = assert_fails
            .evaluate(vec![span_value(Value::Prim(1.0))], &mut ctx)
            .unwrap_err();
        assert_native_error(&err, "must be a function");

        let successful_fn = Value::wrapped_fn(|| true);
        let err = assert_fails
            .evaluate(vec![span_value(successful_fn)], &mut ctx)
            .unwrap_err();
        assert_native_error(&err, "Function did not fail");

        let failing_fn = Value::wrapped_fn(|| Err::<f32, _>("oops".to_owned()));
        let return_value = assert_fails
            .evaluate(vec![span_value(failing_fn)], &mut ctx)
            .unwrap();
        assert!(return_value.is_void(), "{return_value:?}");
    }

    #[test]
    fn assert_fails_with_custom_matcher() {
        let assert_fails = AssertFails::new(
            |err| matches!(err.kind(), ErrorKind::NativeCall(msg) if msg == "oops"),
        );
        let env = Environment::new();
        let mut ctx = CallContext::<f32>::mock(WildcardId, Location::from_str("", ..), &env);

        // `abs` requires an argument, so calling it without args raises a non-matching error,
        // which must be bubbled up as-is.
        let wrong_fn = Value::wrapped_fn(f32::abs);
        let err = assert_fails
            .evaluate(vec![span_value(wrong_fn)], &mut ctx)
            .unwrap_err();
        assert_matches!(
            err.kind(),
            ErrorKind::ArgsLenMismatch { def, call: 0 } if *def == LvalueLen::Exact(1)
        );

        let failing_fn = Value::wrapped_fn(|| Err::<f32, _>("oops".to_owned()));
        let return_value = assert_fails
            .evaluate(vec![span_value(failing_fn)], &mut ctx)
            .unwrap();
        assert!(return_value.is_void(), "{return_value:?}");
    }
}
546}