1use std::{
18 any::Any,
19 fmt, panic,
20 sync::{
21 mpsc::{self, RecvTimeoutError},
22 Mutex, PoisonError,
23 },
24 thread,
25 time::Duration,
26};
27
/// Test function accepted by [`DecorateTest`] implementations.
///
/// The trait has no items of its own and acts as an alias for its supertraits: a callable
/// without arguments returning `R` that is also `Copy`, `Send`, `Sync`, unwind-safe
/// and `'static`. It is implemented automatically for every type satisfying these bounds,
/// which includes function pointers and non-capturing closures.
pub trait TestFn<R>:
    Fn() -> R + Copy + Send + Sync + panic::UnwindSafe + 'static
{
}

impl<R, F> TestFn<R> for F
where
    F: Fn() -> R + Copy + Send + Sync + panic::UnwindSafe + 'static,
{
}
34
35pub trait DecorateTest<R>: panic::RefUnwindSafe + Send + Sync + 'static {
79 fn decorate_and_test<F: TestFn<R>>(&'static self, test_fn: F) -> R;
81}
82
83impl<R, T: DecorateTest<R>> DecorateTest<R> for &'static T {
84 fn decorate_and_test<F: TestFn<R>>(&'static self, test_fn: F) -> R {
85 (**self).decorate_and_test(test_fn)
86 }
87}
88
/// Counterpart of [`DecorateTest`] that accepts a plain function pointer.
///
/// The single method has no generic parameters, which allows using the trait
/// as a trait object (`&dyn DecorateTestFn<R>`).
#[doc(hidden)]
pub trait DecorateTestFn<R>: Send + Sync + panic::RefUnwindSafe + 'static {
    /// Runs the `test_fn` function pointer under this decorator and returns the test output.
    fn decorate_and_test_fn(&'static self, test_fn: fn() -> R) -> R;
}
94
95impl<R: 'static, T: DecorateTest<R>> DecorateTestFn<R> for T {
96 fn decorate_and_test_fn(&'static self, test_fn: fn() -> R) -> R {
97 self.decorate_and_test(test_fn)
98 }
99}
100
/// [Test decorator](DecorateTest) that fails a wrapped test if it does not produce
/// an output within the wrapped [`Duration`].
#[derive(Clone, Copy, Debug)]
pub struct Timeout(pub Duration);

impl Timeout {
    /// Creates a timeout of the specified number of seconds.
    pub const fn secs(secs: u64) -> Self {
        let limit = Duration::from_secs(secs);
        Self(limit)
    }

    /// Creates a timeout of the specified number of milliseconds.
    pub const fn millis(millis: u64) -> Self {
        let limit = Duration::from_millis(millis);
        Self(limit)
    }
}
129
130impl<R: Send + 'static> DecorateTest<R> for Timeout {
131 #[allow(clippy::similar_names)]
132 fn decorate_and_test<F: TestFn<R>>(&self, test_fn: F) -> R {
133 let (output_sx, output_rx) = mpsc::channel();
134 let handle = thread::spawn(move || {
135 output_sx.send(test_fn()).ok();
136 });
137 match output_rx.recv_timeout(self.0) {
138 Ok(output) => {
139 handle.join().unwrap();
140 output
143 }
144 Err(RecvTimeoutError::Timeout) => {
145 panic!("Timeout {:?} expired for the test", self.0);
146 }
147 Err(RecvTimeoutError::Disconnected) => {
148 let panic_object = handle.join().unwrap_err();
149 panic::resume_unwind(panic_object)
150 }
151 }
152 }
153}
154
/// [Test decorator](DecorateTest) that re-runs a wrapped test if it panics or returns an error.
#[derive(Debug)]
pub struct Retry {
    /// Number of additional attempts allowed after the initial one.
    times: usize,
    /// Pause between consecutive attempts; no pause is made if it is zero.
    delay: Duration,
}
177
178impl Retry {
179 pub const fn times(times: usize) -> Self {
181 Self {
182 times,
183 delay: Duration::ZERO,
184 }
185 }
186
187 #[must_use]
189 pub const fn with_delay(self, delay: Duration) -> Self {
190 Self { delay, ..self }
191 }
192
193 pub const fn on_error<E>(self, matcher: fn(&E) -> bool) -> RetryErrors<E> {
195 RetryErrors {
196 inner: self,
197 matcher,
198 }
199 }
200
201 fn handle_panic(&self, attempt: usize, panic_object: Box<dyn Any + Send>) {
202 if attempt < self.times {
203 let panic_str = extract_panic_str(&panic_object).unwrap_or("");
204 let punctuation = if panic_str.is_empty() { "" } else { ": " };
205 println!("Test attempt #{attempt} panicked{punctuation}{panic_str}");
206 } else {
207 panic::resume_unwind(panic_object);
208 }
209 }
210
211 fn run_with_retries<E: fmt::Display>(
212 &self,
213 test_fn: impl TestFn<Result<(), E>>,
214 should_retry: fn(&E) -> bool,
215 ) -> Result<(), E> {
216 for attempt in 0..=self.times {
217 println!("Test attempt #{attempt}");
218 match panic::catch_unwind(test_fn) {
219 Ok(Ok(())) => return Ok(()),
220 Ok(Err(err)) => {
221 if attempt < self.times && should_retry(&err) {
222 println!("Test attempt #{attempt} errored: {err}");
223 } else {
224 return Err(err);
225 }
226 }
227 Err(panic_object) => {
228 self.handle_panic(attempt, panic_object);
229 }
230 }
231 if self.delay > Duration::ZERO {
232 thread::sleep(self.delay);
233 }
234 }
235 Ok(())
236 }
237}
238
239impl DecorateTest<()> for Retry {
240 fn decorate_and_test<F: TestFn<()>>(&self, test_fn: F) {
241 for attempt in 0..=self.times {
242 println!("Test attempt #{attempt}");
243 match panic::catch_unwind(test_fn) {
244 Ok(()) => break,
245 Err(panic_object) => {
246 self.handle_panic(attempt, panic_object);
247 }
248 }
249 if self.delay > Duration::ZERO {
250 thread::sleep(self.delay);
251 }
252 }
253 }
254}
255
256impl<E: fmt::Display> DecorateTest<Result<(), E>> for Retry {
257 fn decorate_and_test<F>(&self, test_fn: F) -> Result<(), E>
258 where
259 F: TestFn<Result<(), E>>,
260 {
261 self.run_with_retries(test_fn, |_| true)
262 }
263}
264
/// Extracts the message from a panic payload.
///
/// Returns `Some(_)` if the payload is a `&'static str` or a `String`, which are
/// the payload types produced by the `panic!` macro, and `None` for any other payload.
///
/// The payload may also be given as a reference to the `Box<dyn Any + Send>` holding it.
/// `&Box<dyn Any + Send>` silently coerces to `&(dyn Any + Send)` with the *box itself*
/// as the underlying value; such a box is looked through rather than reported as a payload
/// of an unknown type.
fn extract_panic_str(panic_object: &(dyn Any + Send)) -> Option<&str> {
    let payload: &(dyn Any + Send) = match panic_object.downcast_ref::<Box<dyn Any + Send>>() {
        Some(boxed) => &**boxed,
        None => panic_object,
    };

    if let Some(panic_str) = payload.downcast_ref::<&'static str>() {
        Some(*panic_str)
    } else {
        payload.downcast_ref::<String>().map(String::as_str)
    }
}
274
275pub struct RetryErrors<E> {
297 inner: Retry,
298 matcher: fn(&E) -> bool,
299}
300
301impl<E> fmt::Debug for RetryErrors<E> {
302 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
303 formatter
304 .debug_struct("RetryErrors")
305 .field("inner", &self.inner)
306 .finish_non_exhaustive()
307 }
308}
309
310impl<E: fmt::Display + 'static> DecorateTest<Result<(), E>> for RetryErrors<E> {
311 fn decorate_and_test<F>(&self, test_fn: F) -> Result<(), E>
312 where
313 F: TestFn<Result<(), E>>,
314 {
315 self.inner.run_with_retries(test_fn, self.matcher)
316 }
317}
318
/// [Test decorator](DecorateTest) that makes the tests sharing it run one at a time,
/// optionally skipping the remaining tests once one of them has failed.
#[derive(Debug, Default)]
pub struct Sequence {
    /// Whether the most recently run test in the sequence has failed. The mutex is held
    /// while a test runs, which is what prevents tests from running concurrently.
    failed: Mutex<bool>,
    /// Whether tests should be skipped once the failure flag is set.
    abort_on_failure: bool,
}
350
351impl Sequence {
352 pub const fn new() -> Self {
354 Self {
355 failed: Mutex::new(false),
356 abort_on_failure: false,
357 }
358 }
359
360 #[must_use]
362 pub const fn abort_on_failure(mut self) -> Self {
363 self.abort_on_failure = true;
364 self
365 }
366
367 fn decorate_inner<R, F: TestFn<R>>(
368 &self,
369 test_fn: F,
370 ok_value: R,
371 match_failure: fn(&R) -> bool,
372 ) -> R {
373 let mut guard = self.failed.lock().unwrap_or_else(PoisonError::into_inner);
374 if *guard && self.abort_on_failure {
375 println!("Skipping test because a previous test in the same sequence has failed");
376 return ok_value;
377 }
378
379 let output = panic::catch_unwind(test_fn);
380 *guard = output.as_ref().map_or(true, match_failure);
381 drop(guard);
382 output.unwrap_or_else(|panic_object| {
383 panic::resume_unwind(panic_object);
384 })
385 }
386}
387
388impl DecorateTest<()> for Sequence {
389 fn decorate_and_test<F: TestFn<()>>(&self, test_fn: F) {
390 self.decorate_inner(test_fn, (), |()| false);
391 }
392}
393
394impl<E: 'static> DecorateTest<Result<(), E>> for Sequence {
395 fn decorate_and_test<F>(&self, test_fn: F) -> Result<(), E>
396 where
397 F: TestFn<Result<(), E>>,
398 {
399 self.decorate_inner(test_fn, Ok(()), Result::is_err)
400 }
401}
402
403macro_rules! impl_decorate_test_for_tuple {
404 ($($field:ident : $ty:ident),* => $last_field:ident : $last_ty:ident) => {
405 impl<R, $($ty,)* $last_ty> DecorateTest<R> for ($($ty,)* $last_ty,)
406 where
407 $($ty: DecorateTest<R>,)*
408 $last_ty: DecorateTest<R>,
409 {
410 fn decorate_and_test<Fn: TestFn<R>>(&'static self, test_fn: Fn) -> R {
411 let ($($field,)* $last_field,) = self;
412 $(
413 let test_fn = move || $field.decorate_and_test(test_fn);
414 )*
415 $last_field.decorate_and_test(test_fn)
416 }
417 }
418 };
419}
420
421impl_decorate_test_for_tuple!(=> a: A);
422impl_decorate_test_for_tuple!(a: A => b: B);
423impl_decorate_test_for_tuple!(a: A, b: B => c: C);
424impl_decorate_test_for_tuple!(a: A, b: B, c: C => d: D);
425impl_decorate_test_for_tuple!(a: A, b: B, c: C, d: D => e: E);
426impl_decorate_test_for_tuple!(a: A, b: B, c: C, d: D, e: E => f: F);
427impl_decorate_test_for_tuple!(a: A, b: B, c: C, d: D, e: E, f: F => g: G);
428impl_decorate_test_for_tuple!(a: A, b: B, c: C, d: D, e: E, f: F, g: G => h: H);
429
#[cfg(test)]
mod tests {
    use std::{
        io,
        sync::{
            atomic::{AtomicU32, Ordering},
            Mutex,
        },
        time::Instant,
    };

    use super::*;

    /// A test sleeping for much longer than the timeout must fail with the timeout panic.
    #[test]
    #[should_panic(expected = "Timeout 100ms expired")]
    fn timeouts() {
        const TIMEOUT: Timeout = Timeout::millis(100);

        let slow_test: fn() = || thread::sleep(Duration::from_secs(1));
        TIMEOUT.decorate_and_test(slow_test);
    }

    /// The second attempt must start only after the configured delay has passed.
    #[test]
    fn retrying_with_delay() {
        const RETRY: Retry = Retry::times(1).with_delay(Duration::from_millis(100));

        fn test_fn() -> Result<(), &'static str> {
            // Moment at which the first (failing) attempt was made.
            static FIRST_ATTEMPT_AT: Mutex<Option<Instant>> = Mutex::new(None);

            let mut first_attempt_at = FIRST_ATTEMPT_AT.lock().unwrap();
            match *first_attempt_at {
                Some(started_at) => {
                    assert!(started_at.elapsed() > RETRY.delay);
                    Ok(())
                }
                None => {
                    *first_attempt_at = Some(Instant::now());
                    Err("come again?")
                }
            }
        }

        RETRY.decorate_and_test(test_fn).unwrap();
    }

    /// Decorator retrying a test up to 2 times, but only on "address in use" I/O errors.
    const RETRY: RetryErrors<io::Error> =
        Retry::times(2).on_error(|err| err.kind() == io::ErrorKind::AddrInUse);

    /// Matching errors are retried until the test succeeds or the attempts are exhausted.
    #[test]
    fn retrying_on_error() {
        static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);

        /// Succeeds on the third call only; all other calls return a retriable error.
        fn test_fn() -> io::Result<()> {
            let call_index = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
            if call_index != 2 {
                return Err(io::Error::new(
                    io::ErrorKind::AddrInUse,
                    "please retry later",
                ));
            }
            Ok(())
        }

        let test_fn: fn() -> _ = test_fn;
        // Calls #0 and #1 fail, call #2 succeeds.
        RETRY.decorate_and_test(test_fn).unwrap();
        assert_eq!(TEST_COUNTER.load(Ordering::Relaxed), 3);

        // Calls #3..=#5 all fail, which exhausts the attempts.
        let err = RETRY.decorate_and_test(test_fn).unwrap_err();
        assert!(err.to_string().contains("please retry later"));
        assert_eq!(TEST_COUNTER.load(Ordering::Relaxed), 6);
    }

    /// An error rejected by the matcher must be returned without any retries.
    #[test]
    fn retrying_on_error_failure() {
        static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);

        /// Fails on the first call with an error that the matcher does not accept.
        fn test_fn() -> io::Result<()> {
            match TEST_COUNTER.fetch_add(1, Ordering::Relaxed) {
                0 => Err(io::Error::new(io::ErrorKind::BrokenPipe, "oops")),
                _ => Ok(()),
            }
        }

        let err = RETRY.decorate_and_test(test_fn).unwrap_err();
        assert!(err.to_string().contains("oops"));
        assert_eq!(TEST_COUNTER.load(Ordering::Relaxed), 1);
    }

    /// Tests sharing a sequence must not overlap, even if one of them panics.
    #[test]
    fn sequential_tests() {
        static SEQUENCE: Sequence = Sequence::new();
        static ENTRY_COUNTER: AtomicU32 = AtomicU32::new(0);

        /// Checks that no other test is inside the sequence while the caller is paused.
        fn run_exclusively(pause: Duration) {
            let concurrent_entries = ENTRY_COUNTER.fetch_add(1, Ordering::Relaxed);
            assert_eq!(concurrent_entries, 0);
            thread::sleep(pause);
            ENTRY_COUNTER.store(0, Ordering::Relaxed);
        }

        let first_test: fn() = || {
            run_exclusively(Duration::from_millis(10));
            panic!("oops");
        };
        let second_test = || {
            run_exclusively(Duration::from_millis(20));
            Ok::<_, io::Error>(())
        };

        let first_test_handle = thread::spawn(move || SEQUENCE.decorate_and_test(first_test));
        SEQUENCE.decorate_and_test(second_test).unwrap();
        first_test_handle.join().unwrap_err();
    }

    /// After a failure, an aborting sequence must skip the following tests.
    #[test]
    fn sequential_tests_with_abort() {
        static SEQUENCE: Sequence = Sequence::new().abort_on_failure();

        let failing_test =
            || Err::<(), _>(io::Error::new(io::ErrorKind::AddrInUse, "please try later"));
        let skipped_test: fn() = || unreachable!("Second test should not be called!");

        SEQUENCE.decorate_and_test(failing_test).unwrap_err();
        SEQUENCE.decorate_and_test(skipped_test);
    }

    /// Defines a `test_fn()` function with its own call counter. The first call sleeps
    /// for a second before succeeding, the second call returns an error, and the third
    /// call succeeds immediately; the function must not be called more than 3 times.
    macro_rules! define_test_fn {
        () => {
            fn test_fn() -> Result<(), &'static str> {
                static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);

                let call_index = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
                match call_index {
                    0 => {
                        thread::sleep(Duration::from_secs(1));
                        Ok(())
                    }
                    1 => Err("oops"),
                    2 => Ok(()),
                    _ => unreachable!(),
                }
            }
        };
    }

    /// A timeout combined with retries must recover from both a timed-out
    /// and an errored attempt.
    #[test]
    fn composing_decorators() {
        define_test_fn!();

        const DECORATORS: (Timeout, Retry) = (Timeout::millis(100), Retry::times(2));

        DECORATORS.decorate_and_test(test_fn).unwrap();
    }

    /// Composed decorators must be usable via a `DecorateTestFn` trait object.
    #[test]
    fn making_decorator_into_trait_object() {
        define_test_fn!();

        static DECORATORS: &dyn DecorateTestFn<Result<(), &'static str>> =
            &(Timeout::millis(100), Retry::times(2));

        DECORATORS.decorate_and_test_fn(test_fn).unwrap();
    }

    /// A tuple holding a reference to a shared sequence must be usable as a trait object.
    #[test]
    fn making_sequence_into_trait_object() {
        static SEQUENCE: Sequence = Sequence::new();
        static DECORATORS: &dyn DecorateTestFn<()> = &(&SEQUENCE,);

        DECORATORS.decorate_and_test_fn(|| {});
    }
}