use std::{
    any,
    convert::Infallible,
    fmt, panic,
    sync::{
        mpsc::{self, RecvTimeoutError},
        Mutex, PoisonError,
    },
    thread,
    time::Duration,
};

// Tracing integration is compiled only when the `tracing` feature is enabled.
#[cfg(feature = "tracing")]
mod traces;

#[cfg(feature = "tracing")]
pub use self::traces::Trace;

/// Test function that can be wrapped by a [`DecorateTest`] decorator.
///
/// Acts as a trait alias: any `'static`, `Copy`, thread-safe and unwind-safe closure or
/// function returning `R` implements it automatically.
pub trait TestFn<R>: 'static + Copy + Send + Sync + panic::UnwindSafe + Fn() -> R {}

impl<R, F> TestFn<R> for F where
    F: 'static + Copy + Send + Sync + panic::UnwindSafe + Fn() -> R
{
}

40pub trait DecorateTest<R>: panic::RefUnwindSafe + Send + Sync + 'static + fmt::Debug {
84 fn decorate_and_test<F: TestFn<R>>(&'static self, test_fn: F) -> R;
86}
87
88impl<R, T: DecorateTest<R>> DecorateTest<R> for &'static T {
89 fn decorate_and_test<F: TestFn<R>>(&'static self, test_fn: F) -> R {
90 (**self).decorate_and_test(test_fn)
91 }
92}
93
/// Object-safe counterpart of [`DecorateTest`]. It takes a plain `fn() -> R` pointer
/// instead of a generic test function, so decorators can be used as `dyn DecorateTestFn<R>`.
#[doc(hidden)]
pub trait DecorateTestFn<R>: fmt::Debug + panic::RefUnwindSafe + Send + Sync + 'static {
    /// Decorates `test_fn`, runs it and returns its output.
    fn decorate_and_test_fn(&'static self, test_fn: fn() -> R) -> R;
}

100impl<R: 'static, T: DecorateTest<R>> DecorateTestFn<R> for T {
101 fn decorate_and_test_fn(&'static self, test_fn: fn() -> R) -> R {
102 self.decorate_and_test(move || {
103 #[cfg(feature = "tracing")]
104 tracing::info!(decorators = ?self, "running decorated test");
105 test_fn()
106 })
107 }
108}
109
/// Test decorator that fails the wrapped test if it does not finish within the given duration.
#[derive(Debug, Clone, Copy)]
pub struct Timeout(
    /// Maximum duration the decorated test may run for.
    pub Duration,
);

impl Timeout {
    /// Creates a timeout of `secs` whole seconds.
    pub const fn secs(secs: u64) -> Self {
        let duration = Duration::from_secs(secs);
        Timeout(duration)
    }

    /// Creates a timeout of `millis` milliseconds.
    pub const fn millis(millis: u64) -> Self {
        let duration = Duration::from_millis(millis);
        Timeout(duration)
    }
}

139impl<R: Send + 'static> DecorateTest<R> for Timeout {
140 #[allow(clippy::similar_names)]
141 fn decorate_and_test<F: TestFn<R>>(&self, test_fn: F) -> R {
142 let (output_sx, output_rx) = mpsc::channel();
143 let handle = thread::spawn(move || {
144 output_sx.send(test_fn()).ok();
145 });
146 match output_rx.recv_timeout(self.0) {
147 Ok(output) => {
148 handle.join().unwrap();
149 output
152 }
153 Err(RecvTimeoutError::Timeout) => {
154 panic!("Timeout {:?} expired for the test", self.0);
155 }
156 Err(RecvTimeoutError::Disconnected) => {
157 let panic_object = handle.join().unwrap_err();
158 panic::resume_unwind(panic_object)
159 }
160 }
161 }
162}
163
/// Test decorator that retries a failed test a specified number of times, optionally
/// pausing between attempts.
///
/// A test is retried after it panics or, for tests returning `Result`, after it returns
/// an error. `Retry::times(n)` allows at most `n + 1` attempts in total.
///
/// All fields are plain values, so `Clone` and `Copy` are derived, as for [`Timeout`].
/// This lets one configuration be reused, for example as the base for several
/// [`Retry::on_error()`] matchers.
#[derive(Debug, Clone, Copy)]
pub struct Retry {
    /// Number of retries allowed after the first attempt.
    times: usize,
    /// Pause between consecutive attempts; `Duration::ZERO` disables pausing.
    delay: Duration,
}

187impl Retry {
188 pub const fn times(times: usize) -> Self {
190 Self {
191 times,
192 delay: Duration::ZERO,
193 }
194 }
195
196 #[must_use]
198 pub const fn with_delay(self, delay: Duration) -> Self {
199 Self { delay, ..self }
200 }
201
202 pub const fn on_error<E>(self, matcher: fn(&E) -> bool) -> RetryErrors<E> {
204 RetryErrors {
205 inner: self,
206 matcher,
207 }
208 }
209
210 fn handle_panic(&self, attempt: usize, panic_object: Box<dyn any::Any + Send>) {
211 if attempt < self.times {
212 let panic_str = extract_panic_str(panic_object.as_ref()).unwrap_or("");
213
214 #[cfg(not(feature = "tracing"))]
215 println!(
216 "Test attempt #{attempt} panicked{punctuation}{panic_str}",
217 punctuation = if panic_str.is_empty() { "" } else { ": " }
218 );
219 #[cfg(feature = "tracing")]
220 tracing::warn!(panic = panic_str, "test attempt panicked");
221 } else {
222 panic::resume_unwind(panic_object);
223 }
224 }
225
226 fn run_with_retries<E: fmt::Display>(
227 &self,
228 test_fn: impl TestFn<Result<(), E>>,
229 should_retry: fn(&E) -> bool,
230 ) -> Result<(), E> {
231 for attempt in 0..=self.times {
232 #[cfg(not(feature = "tracing"))]
233 println!("Test attempt #{attempt}");
234 #[cfg(feature = "tracing")]
235 let _span_guard = tracing::info_span!("test_attempt", attempt).entered();
236
237 match panic::catch_unwind(test_fn) {
238 Ok(Ok(())) => return Ok(()),
239 Ok(Err(err)) => {
240 if attempt < self.times && should_retry(&err) {
241 #[cfg(not(feature = "tracing"))]
242 println!("Test attempt #{attempt} errored: {err}");
243 #[cfg(feature = "tracing")]
244 tracing::warn!(%err, "test attempt errored");
245 } else {
246 return Err(err);
247 }
248 }
249 Err(panic_object) => {
250 self.handle_panic(attempt, panic_object);
251 }
252 }
253 if self.delay > Duration::ZERO {
254 thread::sleep(self.delay);
255 }
256 }
257 Ok(())
258 }
259}
260
261impl DecorateTest<()> for Retry {
262 fn decorate_and_test<F: TestFn<()>>(&self, test_fn: F) {
263 for attempt in 0..=self.times {
264 #[cfg(not(feature = "tracing"))]
265 println!("Test attempt #{attempt}");
266 #[cfg(feature = "tracing")]
267 let _span_guard = tracing::info_span!("test_attempt", attempt).entered();
268
269 match panic::catch_unwind(test_fn) {
270 Ok(()) => break,
271 Err(panic_object) => {
272 self.handle_panic(attempt, panic_object);
273 }
274 }
275 if self.delay > Duration::ZERO {
276 thread::sleep(self.delay);
277 }
278 }
279 }
280}
281
282impl<E: fmt::Display> DecorateTest<Result<(), E>> for Retry {
283 fn decorate_and_test<F>(&self, test_fn: F) -> Result<(), E>
284 where
285 F: TestFn<Result<(), E>>,
286 {
287 self.run_with_retries(test_fn, |_| true)
288 }
289}
290
/// Extracts the message from a panic payload if the payload is a string.
///
/// Handles both `&'static str` and `String` payloads; any other payload type yields `None`.
fn extract_panic_str(panic_object: &(dyn any::Any + Send)) -> Option<&str> {
    match panic_object.downcast_ref::<&'static str>() {
        Some(&message) => Some(message),
        None => panic_object.downcast_ref::<String>().map(String::as_str),
    }
}

301pub struct RetryErrors<E> {
323 inner: Retry,
324 matcher: fn(&E) -> bool,
325}
326
327impl<E> fmt::Debug for RetryErrors<E> {
328 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
329 formatter
330 .debug_struct("RetryErrors")
331 .field("inner", &self.inner)
332 .finish_non_exhaustive()
333 }
334}
335
336impl<E: fmt::Display + 'static> DecorateTest<Result<(), E>> for RetryErrors<E> {
337 fn decorate_and_test<F>(&self, test_fn: F) -> Result<(), E>
338 where
339 F: TestFn<Result<(), E>>,
340 {
341 self.inner.run_with_retries(test_fn, self.matcher)
342 }
343}
344
345#[derive(Debug, Default)]
372pub struct Sequence {
373 failed: Mutex<bool>,
374 abort_on_failure: bool,
375}
376
377impl Sequence {
378 pub const fn new() -> Self {
380 Self {
381 failed: Mutex::new(false),
382 abort_on_failure: false,
383 }
384 }
385
386 #[must_use]
388 pub const fn abort_on_failure(mut self) -> Self {
389 self.abort_on_failure = true;
390 self
391 }
392
393 fn decorate_inner<R, F: TestFn<R>>(
394 &self,
395 test_fn: F,
396 ok_value: R,
397 match_failure: fn(&R) -> bool,
398 ) -> R {
399 let mut guard = self.failed.lock().unwrap_or_else(PoisonError::into_inner);
400 if *guard && self.abort_on_failure {
401 #[cfg(not(feature = "tracing"))]
402 println!("Skipping test because a previous test in the same sequence has failed");
403 #[cfg(feature = "tracing")]
404 tracing::info!("skipping test because a previous test in the same sequence has failed");
405
406 return ok_value;
407 }
408
409 let output = panic::catch_unwind(test_fn);
410 *guard = output.as_ref().map_or(true, match_failure);
411 drop(guard);
412 output.unwrap_or_else(|panic_object| {
413 panic::resume_unwind(panic_object);
414 })
415 }
416}
417
418impl DecorateTest<()> for Sequence {
419 fn decorate_and_test<F: TestFn<()>>(&self, test_fn: F) {
420 self.decorate_inner(test_fn, (), |()| false);
421 }
422}
423
424impl<E: 'static> DecorateTest<Result<(), E>> for Sequence {
425 fn decorate_and_test<F>(&self, test_fn: F) -> Result<(), E>
426 where
427 F: TestFn<Result<(), E>>,
428 {
429 self.decorate_inner(test_fn, Ok(()), Result::is_err)
430 }
431}
432
433macro_rules! impl_decorate_test_for_tuple {
434 ($($field:ident : $ty:ident),* => $last_field:ident : $last_ty:ident) => {
435 impl<R, $($ty,)* $last_ty> DecorateTest<R> for ($($ty,)* $last_ty,)
436 where
437 $($ty: DecorateTest<R>,)*
438 $last_ty: DecorateTest<R>,
439 {
440 fn decorate_and_test<Fn: TestFn<R>>(&'static self, test_fn: Fn) -> R {
441 let ($($field,)* $last_field,) = self;
442 $(
443 let test_fn = move || $field.decorate_and_test(test_fn);
444 )*
445 $last_field.decorate_and_test(test_fn)
446 }
447 }
448 };
449}
450
451impl_decorate_test_for_tuple!(=> a: A);
452impl_decorate_test_for_tuple!(a: A => b: B);
453impl_decorate_test_for_tuple!(a: A, b: B => c: C);
454impl_decorate_test_for_tuple!(a: A, b: B, c: C => d: D);
455impl_decorate_test_for_tuple!(a: A, b: B, c: C, d: D => e: E);
456impl_decorate_test_for_tuple!(a: A, b: B, c: C, d: D, e: E => f: F);
457impl_decorate_test_for_tuple!(a: A, b: B, c: C, d: D, e: E, f: F => g: G);
458impl_decorate_test_for_tuple!(a: A, b: B, c: C, d: D, e: E, f: F, g: G => h: H);
459
#[cfg(test)]
mod tests {
    use std::io;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Mutex;
    use std::time::Instant;

    use super::*;

    // Shared decorator: up to 2 retries, but only for `AddrInUse` I/O errors.
    const RETRY: RetryErrors<io::Error> =
        Retry::times(2).on_error(|err| matches!(err.kind(), io::ErrorKind::AddrInUse));

    #[test]
    #[should_panic(expected = "Timeout 100ms expired")]
    fn timeouts() {
        const TIMEOUT: Timeout = Timeout::millis(100);

        let sleepy_test: fn() = || thread::sleep(Duration::from_secs(1));
        TIMEOUT.decorate_and_test(sleepy_test);
    }

    #[test]
    fn retrying_with_delay() {
        const RETRY: Retry = Retry::times(1).with_delay(Duration::from_millis(100));

        fn test_fn() -> Result<(), &'static str> {
            static FIRST_ATTEMPT: Mutex<Option<Instant>> = Mutex::new(None);

            let mut first_attempt = FIRST_ATTEMPT.lock().unwrap();
            match *first_attempt {
                Some(started_at) => {
                    // The retry must only happen after the configured delay.
                    assert!(started_at.elapsed() > RETRY.delay);
                    Ok(())
                }
                None => {
                    *first_attempt = Some(Instant::now());
                    Err("come again?")
                }
            }
        }

        RETRY.decorate_and_test(test_fn).unwrap();
    }

    #[test]
    fn retrying_on_error() {
        static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);

        fn test_fn() -> io::Result<()> {
            let call_index = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
            if call_index != 2 {
                return Err(io::Error::new(
                    io::ErrorKind::AddrInUse,
                    "please retry later",
                ));
            }
            Ok(())
        }

        let test_fn_ptr: fn() -> io::Result<()> = test_fn;

        // Calls #0 and #1 fail with a retriable error; call #2 succeeds.
        RETRY.decorate_and_test(test_fn_ptr).unwrap();
        assert_eq!(TEST_COUNTER.load(Ordering::Relaxed), 3);

        // Calls #3..=#5 all fail, exhausting the retries.
        let err = RETRY.decorate_and_test(test_fn_ptr).unwrap_err();
        assert!(err.to_string().contains("please retry later"));
        assert_eq!(TEST_COUNTER.load(Ordering::Relaxed), 6);
    }

    #[test]
    fn retrying_on_error_failure() {
        static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);

        fn test_fn() -> io::Result<()> {
            match TEST_COUNTER.fetch_add(1, Ordering::Relaxed) {
                0 => Err(io::Error::new(io::ErrorKind::BrokenPipe, "oops")),
                _ => Ok(()),
            }
        }

        // `BrokenPipe` is not matched by `RETRY`, so the first error is final.
        let err = RETRY.decorate_and_test(test_fn).unwrap_err();
        assert!(err.to_string().contains("oops"));
        assert_eq!(TEST_COUNTER.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn sequential_tests() {
        static SEQUENCE: Sequence = Sequence::new();
        static ENTRY_COUNTER: AtomicU32 = AtomicU32::new(0);

        // Each test asserts that no other test from the sequence is running at the same time.
        let first_test: fn() = || {
            let entered_before = ENTRY_COUNTER.fetch_add(1, Ordering::Relaxed);
            assert_eq!(entered_before, 0);
            thread::sleep(Duration::from_millis(10));
            ENTRY_COUNTER.store(0, Ordering::Relaxed);
            panic!("oops");
        };
        let second_test = || {
            let entered_before = ENTRY_COUNTER.fetch_add(1, Ordering::Relaxed);
            assert_eq!(entered_before, 0);
            thread::sleep(Duration::from_millis(20));
            ENTRY_COUNTER.store(0, Ordering::Relaxed);
            Ok::<_, io::Error>(())
        };

        let first_test_handle = thread::spawn(move || SEQUENCE.decorate_and_test(first_test));
        SEQUENCE.decorate_and_test(second_test).unwrap();
        first_test_handle.join().unwrap_err();
    }

    #[test]
    fn sequential_tests_with_abort() {
        static SEQUENCE: Sequence = Sequence::new().abort_on_failure();

        let failing_test =
            || Err::<(), _>(io::Error::new(io::ErrorKind::AddrInUse, "please try later"));
        let skipped_test: fn() = || unreachable!("Second test should not be called!");

        SEQUENCE.decorate_and_test(failing_test).unwrap_err();
        SEQUENCE.decorate_and_test(skipped_test);
    }

    // Defines `test_fn`: its 1st call hangs for 1s, its 2nd call errors and its 3rd call
    // succeeds. Any further call is a bug.
    macro_rules! define_test_fn {
        () => {
            fn test_fn() -> Result<(), &'static str> {
                static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);

                let call_index = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
                if call_index == 0 {
                    thread::sleep(Duration::from_secs(1));
                    return Ok(());
                }
                match call_index {
                    1 => Err("oops"),
                    2 => Ok(()),
                    _ => unreachable!(),
                }
            }
        };
    }

    #[test]
    fn composing_decorators() {
        define_test_fn!();

        const DECORATORS: (Timeout, Retry) = (Timeout::millis(100), Retry::times(2));

        DECORATORS.decorate_and_test(test_fn).unwrap();
    }

    #[test]
    fn making_decorator_into_trait_object() {
        define_test_fn!();

        static DECORATORS: &dyn DecorateTestFn<Result<(), &'static str>> =
            &(Timeout(Duration::from_millis(100)), Retry::times(2));

        DECORATORS.decorate_and_test_fn(test_fn).unwrap();
    }

    #[test]
    fn making_sequence_into_trait_object() {
        static SEQUENCE: Sequence = Sequence::new();
        static DECORATORS: &dyn DecorateTestFn<()> = &(&SEQUENCE,);

        DECORATORS.decorate_and_test_fn(|| {});
    }

    #[test]
    fn extracting_panic() {
        let value = 0;

        let literal_panic = panic::catch_unwind(|| {
            assert!(value > 1, "what");
        })
        .unwrap_err();
        assert_eq!(extract_panic_str(literal_panic.as_ref()), Some("what"));

        let formatted_panic = panic::catch_unwind(|| {
            assert!(value > 1, "what: {value}");
        })
        .unwrap_err();
        assert_eq!(extract_panic_str(formatted_panic.as_ref()), Some("what: 0"));
    }
}