1use core::{cmp::Ordering, fmt};
4
5use super::extract_fn;
6use crate::{
7 alloc::Vec,
8 error::{AuxErrorInfo, Error},
9 CallContext, ErrorKind, EvalResult, NativeFn, SpannedValue, Value,
10};
11
12#[derive(Debug, Clone, Copy, Default)]
52pub struct Assert;
53
54impl<T> NativeFn<T> for Assert {
55 fn evaluate<'a>(
56 &self,
57 args: Vec<SpannedValue<T>>,
58 ctx: &mut CallContext<'_, T>,
59 ) -> EvalResult<T> {
60 ctx.check_args_count(&args, 1)?;
61 match args[0].extra {
62 Value::Bool(true) => Ok(Value::void()),
63
64 Value::Bool(false) => {
65 let err = ErrorKind::native("Assertion failed");
66 Err(ctx.call_site_error(err))
67 }
68
69 _ => {
70 let err = ErrorKind::native("`assert` requires a single boolean argument");
71 Err(ctx
72 .call_site_error(err)
73 .with_location(&args[0], AuxErrorInfo::InvalidArg))
74 }
75 }
76 }
77}
78
79fn create_error_with_values<T: fmt::Display>(
80 err: ErrorKind,
81 args: &[SpannedValue<T>],
82 ctx: &CallContext<'_, T>,
83) -> Error {
84 ctx.call_site_error(err)
85 .with_location(&args[0], AuxErrorInfo::arg_value(&args[0].extra))
86 .with_location(&args[1], AuxErrorInfo::arg_value(&args[1].extra))
87}
88
89#[derive(Debug, Clone, Copy, Default)]
129pub struct AssertEq;
130
131impl<T: fmt::Display> NativeFn<T> for AssertEq {
132 fn evaluate(&self, args: Vec<SpannedValue<T>>, ctx: &mut CallContext<'_, T>) -> EvalResult<T> {
133 ctx.check_args_count(&args, 2)?;
134
135 let is_equal = args[0]
136 .extra
137 .eq_by_arithmetic(&args[1].extra, ctx.arithmetic());
138
139 if is_equal {
140 Ok(Value::void())
141 } else {
142 let err = ErrorKind::native("Equality assertion failed");
143 Err(create_error_with_values(err, &args, ctx))
144 }
145 }
146}
147
148#[derive(Debug, Clone, Copy)]
189pub struct AssertClose<T> {
190 tolerance: T,
191}
192
193impl<T> AssertClose<T> {
194 pub const fn new(tolerance: T) -> Self {
197 Self { tolerance }
198 }
199
200 fn extract_primitive_ref<'r>(
201 ctx: &mut CallContext<'_, T>,
202 value: &'r SpannedValue<T>,
203 ) -> Result<&'r T, Error> {
204 const ARG_ERROR: &str = "Function arguments must be primitive numbers";
205
206 match &value.extra {
207 Value::Prim(value) => Ok(value),
208 _ => Err(ctx
209 .call_site_error(ErrorKind::native(ARG_ERROR))
210 .with_location(value, AuxErrorInfo::InvalidArg)),
211 }
212 }
213}
214
215impl<T: Clone + fmt::Display> NativeFn<T> for AssertClose<T> {
216 fn evaluate(&self, args: Vec<SpannedValue<T>>, ctx: &mut CallContext<'_, T>) -> EvalResult<T> {
217 ctx.check_args_count(&args, 2)?;
218 let rhs = Self::extract_primitive_ref(ctx, &args[0])?;
219 let lhs = Self::extract_primitive_ref(ctx, &args[1])?;
220
221 let arith = ctx.arithmetic();
222 let diff = match arith.partial_cmp(lhs, rhs) {
223 Some(Ordering::Less | Ordering::Equal) => arith.sub(rhs.clone(), lhs.clone()),
224 Some(Ordering::Greater) => arith.sub(lhs.clone(), rhs.clone()),
225 None => {
226 let err = ErrorKind::native("Values are not comparable");
227 return Err(create_error_with_values(err, &args, ctx));
228 }
229 };
230 let diff = diff.map_err(|err| ctx.call_site_error(ErrorKind::Arithmetic(err)))?;
231
232 match arith.partial_cmp(&diff, &self.tolerance) {
233 Some(Ordering::Less | Ordering::Equal) => Ok(Value::void()),
234 Some(Ordering::Greater) => {
235 let err = ErrorKind::native("Values are not close");
236 Err(create_error_with_values(err, &args, ctx))
237 }
238 None => {
239 let err = ErrorKind::native("Error comparing value difference to tolerance");
240 Err(ctx.call_site_error(err))
241 }
242 }
243 }
244}
245
246#[derive(Clone, Copy)]
316pub struct AssertFails {
317 error_matcher: fn(&Error) -> bool,
318}
319
320impl fmt::Debug for AssertFails {
321 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
322 formatter.debug_tuple("AssertFails").finish()
323 }
324}
325
326impl Default for AssertFails {
327 fn default() -> Self {
328 Self {
329 error_matcher: |_| true,
330 }
331 }
332}
333
334impl AssertFails {
335 pub fn new(error_matcher: fn(&Error) -> bool) -> Self {
338 Self { error_matcher }
339 }
340}
341
342impl<T: 'static + Clone> NativeFn<T> for AssertFails {
343 fn evaluate(
344 &self,
345 mut args: Vec<SpannedValue<T>>,
346 ctx: &mut CallContext<'_, T>,
347 ) -> EvalResult<T> {
348 const ARG_ERROR: &str = "Single argument must be a function";
349
350 ctx.check_args_count(&args, 1)?;
351 let closure = extract_fn(ctx, args.pop().unwrap(), ARG_ERROR)?;
352 match closure.evaluate(Vec::new(), ctx) {
353 Ok(_) => {
354 let err = ErrorKind::native("Function did not fail");
355 Err(ctx.call_site_error(err))
356 }
357 Err(err) => {
358 if (self.error_matcher)(&err) {
359 Ok(Value::void())
360 } else {
361 Err(err)
363 }
364 }
365 }
366 }
367}
368
#[cfg(test)]
mod tests {
    use arithmetic_parser::{Location, LvalueLen};
    use assert_matches::assert_matches;

    use super::*;
    use crate::{arith::CheckedArithmetic, exec::WildcardId, Environment, Object};

    /// Wraps `value` into a `SpannedValue` located at the empty string, so that it can be
    /// passed as a call argument to native functions.
    fn span_value<T>(value: Value<T>) -> SpannedValue<T> {
        Location::from_str("", ..).copy_with_extra(value)
    }

    /// Covers `Assert`: wrong argument counts, a non-boolean argument, and both boolean values.
    #[test]
    fn assert_basics() {
        let env = Environment::with_arithmetic(<CheckedArithmetic>::new());
        let mut ctx = CallContext::<u32>::mock(WildcardId, Location::from_str("", ..), &env);

        // No arguments.
        let err = Assert.evaluate(vec![], &mut ctx).unwrap_err();
        assert_matches!(err.kind(), ErrorKind::ArgsLenMismatch { .. });

        // A non-boolean argument.
        let invalid_arg = span_value(Value::Prim(1));
        let err = Assert.evaluate(vec![invalid_arg], &mut ctx).unwrap_err();
        assert_matches!(
            err.kind(),
            ErrorKind::NativeCall(s) if s.contains("requires a single boolean argument")
        );

        // `false` fails the assertion.
        let false_arg = span_value(Value::Bool(false));
        let err = Assert.evaluate(vec![false_arg], &mut ctx).unwrap_err();
        assert_matches!(
            err.kind(),
            ErrorKind::NativeCall(s) if s.contains("Assertion failed")
        );

        // `true` passes and returns void.
        let true_arg = span_value(Value::Bool(true));
        let return_value = Assert.evaluate(vec![true_arg.clone()], &mut ctx).unwrap();
        assert!(return_value.is_void(), "{return_value:?}");

        // Too many arguments, even when all of them are `true`.
        let err = Assert
            .evaluate(vec![true_arg.clone(), true_arg], &mut ctx)
            .unwrap_err();
        assert_matches!(err.kind(), ErrorKind::ArgsLenMismatch { .. });
    }

    /// Covers `AssertEq`: a wrong argument count, unequal values, and equal values.
    #[test]
    fn assert_eq_basics() {
        let env = Environment::with_arithmetic(<CheckedArithmetic>::new());
        let mut ctx = CallContext::<u32>::mock(WildcardId, Location::from_str("", ..), &env);

        let err = AssertEq.evaluate(vec![], &mut ctx).unwrap_err();
        assert_matches!(err.kind(), ErrorKind::ArgsLenMismatch { .. });

        let x = span_value(Value::Prim(1));
        let y = span_value(Value::Prim(2));
        let err = AssertEq.evaluate(vec![x.clone(), y], &mut ctx).unwrap_err();
        // Matches the tail of the "Equality assertion failed" message.
        assert_matches!(
            err.kind(),
            ErrorKind::NativeCall(s) if s.contains("assertion failed")
        );

        let return_value = AssertEq.evaluate(vec![x.clone(), x], &mut ctx).unwrap();
        assert!(return_value.is_void(), "{return_value:?}");
    }

    /// Covers `AssertClose` with `f32` values and a tolerance of `1e-3`: a wrong argument count,
    /// non-primitive arguments, distant values, NaN (incomparable) values, and close values.
    #[test]
    fn assert_close_basics() {
        let assert_close = AssertClose::new(1e-3);
        let env = Environment::new();
        let mut ctx = CallContext::<f32>::mock(WildcardId, Location::from_str("", ..), &env);

        let err = assert_close.evaluate(vec![], &mut ctx).unwrap_err();
        assert_matches!(err.kind(), ErrorKind::ArgsLenMismatch { .. });

        // Boolean, array, and object arguments are all rejected.
        let one_arg = span_value(Value::Prim(1.0));
        let invalid_args = [
            Value::Bool(true),
            vec![Value::Prim(1.0)].into(),
            Object::just("test", Value::Prim(1.0)).into(),
        ];
        for invalid_arg in invalid_args {
            let err = assert_close
                .evaluate(vec![one_arg.clone(), span_value(invalid_arg)], &mut ctx)
                .unwrap_err();
            assert_matches!(
                err.kind(),
                ErrorKind::NativeCall(s) if s.contains("must be primitive numbers")
            );
        }

        // Differences exceeding the tolerance, including an infinite difference.
        let distant_values = &[(0.0, 1.0), (1.0, 1.01), (0.0, f32::INFINITY)];
        for &(x, y) in distant_values {
            let x = span_value(Value::Prim(x));
            let y = span_value(Value::Prim(y));
            let err = assert_close.evaluate(vec![x, y], &mut ctx).unwrap_err();
            assert_matches!(
                err.kind(),
                ErrorKind::NativeCall(s) if s.contains("Values are not close")
            );
        }

        // NaN in either position makes the values incomparable.
        let non_comparable_values = &[(0.0, f32::NAN), (f32::NAN, 1.0), (f32::NAN, f32::NAN)];
        for &(x, y) in non_comparable_values {
            let x = span_value(Value::Prim(x));
            let y = span_value(Value::Prim(y));
            let err = assert_close.evaluate(vec![x, y], &mut ctx).unwrap_err();
            assert_matches!(
                err.kind(),
                ErrorKind::NativeCall(s) if s.contains("Values are not comparable")
            );
        }

        // Both argument orders and exactly equal values pass.
        let close_values = &[(1.0, 0.9999), (0.9999, 1.0), (1.0, 1.0)];
        for &(x, y) in close_values {
            let x = span_value(Value::Prim(x));
            let y = span_value(Value::Prim(y));
            let return_value = assert_close.evaluate(vec![x, y], &mut ctx).unwrap();
            assert!(return_value.is_void(), "{return_value:?}");
        }
    }

    /// Covers `AssertFails` with the default (accept-any) matcher: a wrong argument count,
    /// a non-function argument, a succeeding function, and a failing function.
    #[test]
    fn assert_fails_basics() {
        let assert_fails = AssertFails::default();
        let env = Environment::new();
        let mut ctx = CallContext::<f32>::mock(WildcardId, Location::from_str("", ..), &env);

        let err = assert_fails.evaluate(vec![], &mut ctx).unwrap_err();
        assert_matches!(err.kind(), ErrorKind::ArgsLenMismatch { .. });

        let invalid_arg = span_value(Value::Prim(1.0));
        let err = assert_fails
            .evaluate(vec![invalid_arg], &mut ctx)
            .unwrap_err();
        assert_matches!(
            err.kind(),
            ErrorKind::NativeCall(s) if s.contains("must be a function")
        );

        let successful_fn = span_value(Value::wrapped_fn(|| true));
        let err = assert_fails
            .evaluate(vec![successful_fn], &mut ctx)
            .unwrap_err();
        assert_matches!(
            err.kind(),
            ErrorKind::NativeCall(s) if s.contains("Function did not fail")
        );

        let failing_fn = Value::wrapped_fn(|| Err::<f32, _>("oops".to_owned()));
        let return_value = assert_fails
            .evaluate(vec![span_value(failing_fn)], &mut ctx)
            .unwrap();
        assert!(return_value.is_void(), "{return_value:?}");
    }

    /// Covers `AssertFails` with a matcher that only accepts a native "oops" error. Errors the
    /// matcher rejects must be propagated unchanged.
    #[test]
    fn assert_fails_with_custom_matcher() {
        let assert_fails = AssertFails::new(
            |err| matches!(err.kind(), ErrorKind::NativeCall(msg) if msg == "oops"),
        );
        let env = Environment::new();
        let mut ctx = CallContext::<f32>::mock(WildcardId, Location::from_str("", ..), &env);

        // `f32::abs` is called with zero args, so it fails with an args-length mismatch.
        // The matcher rejects that error, so it is propagated as is.
        let wrong_fn = Value::wrapped_fn(f32::abs);
        let err = assert_fails
            .evaluate(vec![span_value(wrong_fn)], &mut ctx)
            .unwrap_err();
        assert_matches!(
            err.kind(),
            ErrorKind::ArgsLenMismatch { def, call: 0 } if *def == LvalueLen::Exact(1)
        );

        let failing_fn = Value::wrapped_fn(|| Err::<f32, _>("oops".to_owned()));
        let return_value = assert_fails
            .evaluate(vec![span_value(failing_fn)], &mut ctx)
            .unwrap();
        assert!(return_value.is_void(), "{return_value:?}");
    }
}