test_casing/decorators/
mod.rs

1//! Test decorator trait and implementations.
2//!
3//! # Overview
4//!
5//! A [test decorator](DecorateTest) takes a [tested function](TestFn) and calls it zero or more times,
6//! perhaps with additional logic spliced between calls. Examples of decorators include [retries](Retry),
7//! [`Timeout`]s and test [`Sequence`]s.
8//!
//! Decorators are composable: `DecorateTest` is automatically implemented for a tuple with
//! 1..=8 elements where each element implements `DecorateTest`. The decorators in a tuple
//! are applied in the order of their appearance in the tuple: the first decorator wraps
//! the test function directly, and each subsequent decorator wraps the previous ones.
12//!
13//! # Examples
14//!
15//! See [`decorate`](crate::decorate) macro docs for the examples of usage.
16
17use std::{
18    any, fmt, panic,
19    sync::{
20        mpsc::{self, RecvTimeoutError},
21        Mutex, PoisonError,
22    },
23    thread,
24    time::Duration,
25};
26
27#[cfg(feature = "tracing")]
28pub use self::traces::Trace;
29
30#[cfg(feature = "tracing")]
31mod traces;
32
/// Tested function or closure.
///
/// Implemented automatically for every `Fn() -> R` that is also [`UnwindSafe`](panic::UnwindSafe),
/// `Send`, `Sync`, `Copy` and `'static`: e.g., function items, function pointers, and closures
/// whose captured values satisfy these bounds.
pub trait TestFn<R>: Fn() -> R + panic::UnwindSafe + Send + Sync + Copy + 'static {}

impl<R, F> TestFn<R> for F
where
    F: Fn() -> R + panic::UnwindSafe + Send + Sync + Copy + 'static,
{
}
39
40/// Test decorator.
41///
42/// See [module docs](index.html#overview) for the extended description.
43///
44/// # Examples
45///
46/// The following decorator implements a `#[should_panic]` analogue for errors.
47///
48/// ```no_run
49/// use test_casing::decorators::{DecorateTest, TestFn};
50///
51/// #[derive(Debug, Clone, Copy)]
52/// pub struct ShouldError(pub &'static str);
53///
54/// impl<E: ToString> DecorateTest<Result<(), E>> for ShouldError {
55///     fn decorate_and_test<F: TestFn<Result<(), E>>>(
56///         &self,
57///         test_fn: F,
58///     ) -> Result<(), E> {
59///         let Err(err) = test_fn() else {
60///             panic!("Expected test to error, but it completed successfully");
61///         };
62///         let err = err.to_string();
63///         if err.contains(self.0) {
64///             Ok(())
65///         } else {
66///             panic!(
67///                 "Expected error message to contain `{}`, but it was: {err}",
68///                 self.0
69///             );
70///         }
71///     }
72/// }
73///
74/// // Usage:
75/// # use test_casing::decorate;
76/// # use std::error::Error;
77/// #[test]
78/// #[decorate(ShouldError("oops"))]
79/// fn test_with_an_error() -> Result<(), Box<dyn Error>> {
80///     Err("oops, this test failed".into())
81/// }
82/// ```
83pub trait DecorateTest<R>: panic::RefUnwindSafe + Send + Sync + 'static + fmt::Debug {
84    /// Decorates the provided test function and runs the test.
85    fn decorate_and_test<F: TestFn<R>>(&'static self, test_fn: F) -> R;
86}
87
88impl<R, T: DecorateTest<R>> DecorateTest<R> for &'static T {
89    fn decorate_and_test<F: TestFn<R>>(&'static self, test_fn: F) -> R {
90        (**self).decorate_and_test(test_fn)
91    }
92}
93
/// Object-safe counterpart of [`DecorateTest`] that accepts a plain function pointer
/// instead of a generic [`TestFn`].
#[doc(hidden)] // used in the `decorate` proc macro; logically private
pub trait DecorateTestFn<R>: fmt::Debug + panic::RefUnwindSafe + Send + Sync + 'static {
    /// Decorates `test_fn` and runs the test, returning its output.
    fn decorate_and_test_fn(&'static self, test_fn: fn() -> R) -> R;
}
99
100impl<R: 'static, T: DecorateTest<R>> DecorateTestFn<R> for T {
101    fn decorate_and_test_fn(&'static self, test_fn: fn() -> R) -> R {
102        self.decorate_and_test(move || {
103            #[cfg(feature = "tracing")]
104            tracing::info!(decorators = ?self, "running decorated test");
105            test_fn()
106        })
107    }
108}
109
/// [Test decorator](DecorateTest) that fails a wrapped test if it doesn't complete
/// in the specified [`Duration`].
///
/// # Examples
///
/// ```no_run
/// use test_casing::{decorate, decorators::Timeout};
///
/// #[test]
/// #[decorate(Timeout::secs(5))]
/// fn test_with_timeout() {
///     // test logic
/// }
/// ```
#[derive(Debug, Clone, Copy)]
pub struct Timeout(
    /// Maximum time the wrapped test is allowed to run.
    pub Duration,
);

impl Timeout {
    /// Creates a timeout of `secs` whole seconds.
    pub const fn secs(secs: u64) -> Self {
        let duration = Duration::from_secs(secs);
        Timeout(duration)
    }

    /// Creates a timeout of `millis` milliseconds.
    pub const fn millis(millis: u64) -> Self {
        let duration = Duration::from_millis(millis);
        Timeout(duration)
    }
}
138
139impl<R: Send + 'static> DecorateTest<R> for Timeout {
140    #[allow(clippy::similar_names)]
141    fn decorate_and_test<F: TestFn<R>>(&self, test_fn: F) -> R {
142        let (output_sx, output_rx) = mpsc::channel();
143        let handle = thread::spawn(move || {
144            output_sx.send(test_fn()).ok();
145        });
146        match output_rx.recv_timeout(self.0) {
147            Ok(output) => {
148                handle.join().unwrap();
149                // ^ `unwrap()` is safe; the thread didn't panic before `send`ing the output,
150                // and there's nowhere to panic after that.
151                output
152            }
153            Err(RecvTimeoutError::Timeout) => {
154                panic!("Timeout {:?} expired for the test", self.0);
155            }
156            Err(RecvTimeoutError::Disconnected) => {
157                let panic_object = handle.join().unwrap_err();
158                panic::resume_unwind(panic_object)
159            }
160        }
161    }
162}
163
/// [Test decorator](DecorateTest) that retries a wrapped test the specified number of times,
/// potentially with a delay between retries.
///
/// # Examples
///
/// ```no_run
/// use test_casing::{decorate, decorators::Retry};
/// use std::time::Duration;
///
/// const RETRY_DELAY: Duration = Duration::from_millis(200);
///
/// #[test]
/// #[decorate(Retry::times(3).with_delay(RETRY_DELAY))]
/// fn test_with_retries() {
///     // test logic
/// }
/// ```
// `Retry` only holds `Copy` data, so it is `Copy` like `Timeout`; this lets a configured
// value be reused (e.g., both directly and via `on_error()`).
#[derive(Debug, Clone, Copy)]
pub struct Retry {
    /// Number of retries after the initial attempt.
    times: usize,
    /// Delay before each retry.
    delay: Duration,
}
186
187impl Retry {
188    /// Specified the number of retries. The delay between retries is zero.
189    pub const fn times(times: usize) -> Self {
190        Self {
191            times,
192            delay: Duration::ZERO,
193        }
194    }
195
196    /// Specifies the delay between retries.
197    #[must_use]
198    pub const fn with_delay(self, delay: Duration) -> Self {
199        Self { delay, ..self }
200    }
201
202    /// Converts this retry specification to only retry specific errors.
203    pub const fn on_error<E>(self, matcher: fn(&E) -> bool) -> RetryErrors<E> {
204        RetryErrors {
205            inner: self,
206            matcher,
207        }
208    }
209
210    fn handle_panic(&self, attempt: usize, panic_object: Box<dyn any::Any + Send>) {
211        if attempt < self.times {
212            let panic_str = extract_panic_str(panic_object.as_ref()).unwrap_or("");
213
214            #[cfg(not(feature = "tracing"))]
215            println!(
216                "Test attempt #{attempt} panicked{punctuation}{panic_str}",
217                punctuation = if panic_str.is_empty() { "" } else { ": " }
218            );
219            #[cfg(feature = "tracing")]
220            tracing::warn!(panic = panic_str, "test attempt panicked");
221        } else {
222            panic::resume_unwind(panic_object);
223        }
224    }
225
226    fn run_with_retries<E: fmt::Display>(
227        &self,
228        test_fn: impl TestFn<Result<(), E>>,
229        should_retry: fn(&E) -> bool,
230    ) -> Result<(), E> {
231        for attempt in 0..=self.times {
232            #[cfg(not(feature = "tracing"))]
233            println!("Test attempt #{attempt}");
234            #[cfg(feature = "tracing")]
235            let _span_guard = tracing::info_span!("test_attempt", attempt).entered();
236
237            match panic::catch_unwind(test_fn) {
238                Ok(Ok(())) => return Ok(()),
239                Ok(Err(err)) => {
240                    if attempt < self.times && should_retry(&err) {
241                        #[cfg(not(feature = "tracing"))]
242                        println!("Test attempt #{attempt} errored: {err}");
243                        #[cfg(feature = "tracing")]
244                        tracing::warn!(%err, "test attempt errored");
245                    } else {
246                        return Err(err);
247                    }
248                }
249                Err(panic_object) => {
250                    self.handle_panic(attempt, panic_object);
251                }
252            }
253            if self.delay > Duration::ZERO {
254                thread::sleep(self.delay);
255            }
256        }
257        Ok(())
258    }
259}
260
261impl DecorateTest<()> for Retry {
262    fn decorate_and_test<F: TestFn<()>>(&self, test_fn: F) {
263        for attempt in 0..=self.times {
264            #[cfg(not(feature = "tracing"))]
265            println!("Test attempt #{attempt}");
266            #[cfg(feature = "tracing")]
267            let _span_guard = tracing::info_span!("test_attempt", attempt).entered();
268
269            match panic::catch_unwind(test_fn) {
270                Ok(()) => break,
271                Err(panic_object) => {
272                    self.handle_panic(attempt, panic_object);
273                }
274            }
275            if self.delay > Duration::ZERO {
276                thread::sleep(self.delay);
277            }
278        }
279    }
280}
281
282impl<E: fmt::Display> DecorateTest<Result<(), E>> for Retry {
283    fn decorate_and_test<F>(&self, test_fn: F) -> Result<(), E>
284    where
285        F: TestFn<Result<(), E>>,
286    {
287        self.run_with_retries(test_fn, |_| true)
288    }
289}
290
/// Extracts the message from a panic payload, if the payload is a string.
///
/// Handles both `&'static str` and `String` payloads; returns `None` for any other
/// payload type.
fn extract_panic_str(panic_object: &(dyn any::Any + Send)) -> Option<&str> {
    if let Some(&message) = panic_object.downcast_ref::<&'static str>() {
        return Some(message);
    }
    panic_object.downcast_ref::<String>().map(String::as_str)
}
300
301/// [Test decorator](DecorateTest) that retries a wrapped test a certain number of times
302/// only if an error matches the specified predicate.
303///
304/// Constructed using [`Retry::on_error()`].
305///
306/// # Examples
307///
308/// ```no_run
309/// use test_casing::{decorate, decorators::{Retry, RetryErrors}};
310/// use std::error::Error;
311///
312/// const RETRY: RetryErrors<Box<dyn Error>> = Retry::times(3)
313///     .on_error(|err| err.to_string().contains("retry please"));
314///
315/// #[test]
316/// #[decorate(RETRY)]
317/// fn test_with_retries() -> Result<(), Box<dyn Error>> {
318///     // test logic
319/// #    Ok(())
320/// }
321/// ```
322pub struct RetryErrors<E> {
323    inner: Retry,
324    matcher: fn(&E) -> bool,
325}
326
327impl<E> fmt::Debug for RetryErrors<E> {
328    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
329        formatter
330            .debug_struct("RetryErrors")
331            .field("inner", &self.inner)
332            .finish_non_exhaustive()
333    }
334}
335
336impl<E: fmt::Display + 'static> DecorateTest<Result<(), E>> for RetryErrors<E> {
337    fn decorate_and_test<F>(&self, test_fn: F) -> Result<(), E>
338    where
339        F: TestFn<Result<(), E>>,
340    {
341        self.inner.run_with_retries(test_fn, self.matcher)
342    }
343}
344
/// [Test decorator](DecorateTest) that makes runs of decorated tests sequential. The sequence
/// can optionally be aborted if a test in it fails.
///
/// The run ordering of tests in the sequence is not deterministic. This is because depending
/// on the command-line args that the test was launched with, not all tests in the sequence may run
/// at all.
///
/// # Examples
///
/// ```no_run
/// use test_casing::{decorate, decorators::{Sequence, Timeout}};
///
/// static SEQUENCE: Sequence = Sequence::new().abort_on_failure();
///
/// #[test]
/// #[decorate(&SEQUENCE)]
/// fn sequential_test() {
///     // test logic
/// }
///
/// #[test]
/// #[decorate(Timeout::secs(1), &SEQUENCE)]
/// fn other_sequential_test() {
///     // test logic
/// }
/// ```
#[derive(Default, Debug)]
pub struct Sequence {
    /// Whether the most recently run test in the sequence failed. The mutex is held for
    /// the whole duration of a test run, which is what serializes the tests.
    failed: Mutex<bool>,
    /// Whether to skip the remaining tests once a test in the sequence has failed.
    abort_on_failure: bool,
}
376
377impl Sequence {
378    /// Creates a new test sequence.
379    pub const fn new() -> Self {
380        Self {
381            failed: Mutex::new(false),
382            abort_on_failure: false,
383        }
384    }
385
386    /// Makes the decorated tests abort immediately if one test from the sequence fails.
387    #[must_use]
388    pub const fn abort_on_failure(mut self) -> Self {
389        self.abort_on_failure = true;
390        self
391    }
392
393    fn decorate_inner<R, F: TestFn<R>>(
394        &self,
395        test_fn: F,
396        ok_value: R,
397        match_failure: fn(&R) -> bool,
398    ) -> R {
399        let mut guard = self.failed.lock().unwrap_or_else(PoisonError::into_inner);
400        if *guard && self.abort_on_failure {
401            #[cfg(not(feature = "tracing"))]
402            println!("Skipping test because a previous test in the same sequence has failed");
403            #[cfg(feature = "tracing")]
404            tracing::info!("skipping test because a previous test in the same sequence has failed");
405
406            return ok_value;
407        }
408
409        let output = panic::catch_unwind(test_fn);
410        *guard = output.as_ref().map_or(true, match_failure);
411        drop(guard);
412        output.unwrap_or_else(|panic_object| {
413            panic::resume_unwind(panic_object);
414        })
415    }
416}
417
418impl DecorateTest<()> for Sequence {
419    fn decorate_and_test<F: TestFn<()>>(&self, test_fn: F) {
420        self.decorate_inner(test_fn, (), |()| false);
421    }
422}
423
424impl<E: 'static> DecorateTest<Result<(), E>> for Sequence {
425    fn decorate_and_test<F>(&self, test_fn: F) -> Result<(), E>
426    where
427        F: TestFn<Result<(), E>>,
428    {
429        self.decorate_inner(test_fn, Ok(()), Result::is_err)
430    }
431}
432
// Implements `DecorateTest` for a tuple of decorators. The first tuple element wraps the
// test function directly, each following element wraps the function composed so far,
// and the last element is the outermost decorator.
macro_rules! impl_decorate_test_for_tuple {
    ($($field:ident : $ty:ident),* => $last_field:ident : $last_ty:ident) => {
        impl<R, $($ty,)* $last_ty> DecorateTest<R> for ($($ty,)* $last_ty,)
        where
            $($ty: DecorateTest<R>,)*
            $last_ty: DecorateTest<R>,
        {
            // The generic is named `Test` rather than `F` (which is one of the tuple
            // element types) or `Fn` (which would shadow the `Fn` trait).
            fn decorate_and_test<Test: TestFn<R>>(&'static self, test_fn: Test) -> R {
                let ($($field,)* $last_field,) = self;
                $(
                    let test_fn = move || $field.decorate_and_test(test_fn);
                )*
                $last_field.decorate_and_test(test_fn)
            }
        }
    };
}
450
451impl_decorate_test_for_tuple!(=> a: A);
452impl_decorate_test_for_tuple!(a: A => b: B);
453impl_decorate_test_for_tuple!(a: A, b: B => c: C);
454impl_decorate_test_for_tuple!(a: A, b: B, c: C => d: D);
455impl_decorate_test_for_tuple!(a: A, b: B, c: C, d: D => e: E);
456impl_decorate_test_for_tuple!(a: A, b: B, c: C, d: D, e: E => f: F);
457impl_decorate_test_for_tuple!(a: A, b: B, c: C, d: D, e: E, f: F => g: G);
458impl_decorate_test_for_tuple!(a: A, b: B, c: C, d: D, e: E, f: F, g: G => h: H);
459
// Unit tests for the decorators defined in this module.
#[cfg(test)]
mod tests {
    use std::{
        io,
        sync::{
            atomic::{AtomicU32, Ordering},
            Mutex,
        },
        time::Instant,
    };

    use super::*;

    // The test sleeps for 1s, so the 100ms timeout must fire and panic.
    #[test]
    #[should_panic(expected = "Timeout 100ms expired")]
    fn timeouts() {
        const TIMEOUT: Timeout = Timeout(Duration::from_millis(100));

        let test_fn: fn() = || thread::sleep(Duration::from_secs(1));
        TIMEOUT.decorate_and_test(test_fn);
    }

    // The first attempt records its start time and errors; the retry checks that
    // at least `RETRY.delay` has elapsed since then.
    #[test]
    fn retrying_with_delay() {
        // NOTE: this function-local `RETRY` shadows the module-level `RETRY` const below.
        const RETRY: Retry = Retry::times(1).with_delay(Duration::from_millis(100));

        fn test_fn() -> Result<(), &'static str> {
            static TEST_START: Mutex<Option<Instant>> = Mutex::new(None);

            let mut test_start = TEST_START.lock().unwrap();
            if let Some(test_start) = *test_start {
                assert!(test_start.elapsed() > RETRY.delay);
                Ok(())
            } else {
                *test_start = Some(Instant::now());
                Err("come again?")
            }
        }

        RETRY.decorate_and_test(test_fn).unwrap();
    }

    // Up to 2 retries (3 attempts), but only for `AddrInUse` errors.
    const RETRY: RetryErrors<io::Error> =
        Retry::times(2).on_error(|err| matches!(err.kind(), io::ErrorKind::AddrInUse));

    #[test]
    fn retrying_on_error() {
        static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);

        // Succeeds only on the call where the counter was 2 (i.e., the 3rd call overall).
        fn test_fn() -> io::Result<()> {
            if TEST_COUNTER.fetch_add(1, Ordering::Relaxed) == 2 {
                Ok(())
            } else {
                Err(io::Error::new(
                    io::ErrorKind::AddrInUse,
                    "please retry later",
                ))
            }
        }

        // First run: calls #0 and #1 error and are retried; call #2 succeeds.
        let test_fn: fn() -> _ = test_fn;
        RETRY.decorate_and_test(test_fn).unwrap();
        assert_eq!(TEST_COUNTER.load(Ordering::Relaxed), 3);

        // Second run: all 3 attempts error, so the last error is returned.
        let err = RETRY.decorate_and_test(test_fn).unwrap_err();
        assert!(err.to_string().contains("please retry later"));
        assert_eq!(TEST_COUNTER.load(Ordering::Relaxed), 6);
    }

    // `BrokenPipe` is not matched by `RETRY`, so the error is returned without retries.
    #[test]
    fn retrying_on_error_failure() {
        static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);

        fn test_fn() -> io::Result<()> {
            if TEST_COUNTER.fetch_add(1, Ordering::Relaxed) == 0 {
                Err(io::Error::new(io::ErrorKind::BrokenPipe, "oops"))
            } else {
                Ok(())
            }
        }

        let err = RETRY.decorate_and_test(test_fn).unwrap_err();
        assert!(err.to_string().contains("oops"));
        assert_eq!(TEST_COUNTER.load(Ordering::Relaxed), 1);
    }

    // Runs two tests of the same sequence concurrently; `ENTRY_COUNTER` asserts that
    // they never overlap. The panic of the first test must be propagated to its thread.
    #[test]
    fn sequential_tests() {
        static SEQUENCE: Sequence = Sequence::new();
        static ENTRY_COUNTER: AtomicU32 = AtomicU32::new(0);

        let first_test: fn() = || {
            let counter = ENTRY_COUNTER.fetch_add(1, Ordering::Relaxed);
            assert_eq!(counter, 0);
            thread::sleep(Duration::from_millis(10));
            ENTRY_COUNTER.store(0, Ordering::Relaxed);
            panic!("oops");
        };
        let second_test = || {
            let counter = ENTRY_COUNTER.fetch_add(1, Ordering::Relaxed);
            assert_eq!(counter, 0);
            thread::sleep(Duration::from_millis(20));
            ENTRY_COUNTER.store(0, Ordering::Relaxed);
            Ok::<_, io::Error>(())
        };

        let first_test_handle = thread::spawn(move || SEQUENCE.decorate_and_test(first_test));
        SEQUENCE.decorate_and_test(second_test).unwrap();
        first_test_handle.join().unwrap_err();
    }

    // After the first test fails, the aborted sequence must skip the second test.
    #[test]
    fn sequential_tests_with_abort() {
        static SEQUENCE: Sequence = Sequence::new().abort_on_failure();

        let failing_test =
            || Err::<(), _>(io::Error::new(io::ErrorKind::AddrInUse, "please try later"));
        let second_test: fn() = || unreachable!("Second test should not be called!");

        SEQUENCE.decorate_and_test(failing_test).unwrap_err();
        SEQUENCE.decorate_and_test(second_test);
    }

    // We need independent test counters for different tests, hence defining a function
    // via a macro. Call #0 sleeps for 1s, call #1 errors, call #2 succeeds.
    macro_rules! define_test_fn {
        () => {
            fn test_fn() -> Result<(), &'static str> {
                static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);
                match TEST_COUNTER.fetch_add(1, Ordering::Relaxed) {
                    0 => {
                        thread::sleep(Duration::from_secs(1));
                        Ok(())
                    }
                    1 => Err("oops"),
                    2 => Ok(()),
                    _ => unreachable!(),
                }
            }
        };
    }

    // `Timeout` (the first tuple element) wraps the test directly and `Retry` wraps it:
    // the timed-out call #0 and the failed call #1 are both retried.
    #[test]
    fn composing_decorators() {
        define_test_fn!();

        const DECORATORS: (Timeout, Retry) = (Timeout(Duration::from_millis(100)), Retry::times(2));

        DECORATORS.decorate_and_test(test_fn).unwrap();
    }

    // Same as `composing_decorators`, but via the object-safe `DecorateTestFn`.
    #[test]
    fn making_decorator_into_trait_object() {
        define_test_fn!();

        static DECORATORS: &dyn DecorateTestFn<Result<(), &'static str>> =
            &(Timeout(Duration::from_millis(100)), Retry::times(2));

        DECORATORS.decorate_and_test_fn(test_fn).unwrap();
    }

    // A 1-tuple holding a `&'static Sequence` can be used as a trait object.
    #[test]
    fn making_sequence_into_trait_object() {
        static SEQUENCE: Sequence = Sequence::new();
        static DECORATORS: &dyn DecorateTestFn<()> = &(&SEQUENCE,);

        DECORATORS.decorate_and_test_fn(|| {});
    }

    // Covers a literal panic message and a formatted one (presumably `&'static str`
    // and `String` payloads, respectively).
    #[test]
    fn extracting_panic() {
        let value = 0;
        let panic_obj = panic::catch_unwind(|| {
            assert!(value > 1, "what");
        })
        .unwrap_err();

        assert_eq!(extract_panic_str(panic_obj.as_ref()), Some("what"));

        let panic_obj = panic::catch_unwind(|| {
            assert!(value > 1, "what: {value}");
        })
        .unwrap_err();

        assert_eq!(extract_panic_str(panic_obj.as_ref()), Some("what: 0"));
    }
}