test_casing/
decorators.rs

1//! Test decorator trait and implementations.
2//!
3//! # Overview
4//!
5//! A [test decorator](DecorateTest) takes a [tested function](TestFn) and calls it zero or more times,
6//! perhaps with additional logic spliced between calls. Examples of decorators include [retries](Retry),
7//! [`Timeout`]s and test [`Sequence`]s.
8//!
//! Decorators are composable: `DecorateTest` is automatically implemented for a tuple with
//! 1..=8 elements where each element implements `DecorateTest`. The decorators in a tuple
//! are applied in the order of their appearance in the tuple: the first decorator wraps
//! the tested function directly, and the last one is the outermost.
12//!
13//! # Examples
14//!
15//! See [`decorate`](crate::decorate) macro docs for the examples of usage.
16
17use std::{
18    any::Any,
19    fmt, panic,
20    sync::{
21        mpsc::{self, RecvTimeoutError},
22        Mutex, PoisonError,
23    },
24    thread,
25    time::Duration,
26};
27
/// Tested function or closure.
///
/// Every zero-argument function or closure that is unwind-safe, `Send + Sync`, `Copy`
/// and `'static` implements this trait through the blanket impl right below.
pub trait TestFn<R>: Fn() -> R + panic::UnwindSafe + Send + Sync + Copy + 'static {}

impl<R, F: Fn() -> R + panic::UnwindSafe + Send + Sync + Copy + 'static> TestFn<R> for F {}
34
35/// Test decorator.
36///
37/// See [module docs](index.html#overview) for the extended description.
38///
39/// # Examples
40///
41/// The following decorator implements a `#[should_panic]` analogue for errors.
42///
43/// ```no_run
44/// use test_casing::decorators::{DecorateTest, TestFn};
45///
46/// #[derive(Debug, Clone, Copy)]
47/// pub struct ShouldError(pub &'static str);
48///
49/// impl<E: ToString> DecorateTest<Result<(), E>> for ShouldError {
50///     fn decorate_and_test<F: TestFn<Result<(), E>>>(
51///         &self,
52///         test_fn: F,
53///     ) -> Result<(), E> {
54///         let Err(err) = test_fn() else {
55///             panic!("Expected test to error, but it completed successfully");
56///         };
57///         let err = err.to_string();
58///         if err.contains(self.0) {
59///             Ok(())
60///         } else {
61///             panic!(
62///                 "Expected error message to contain `{}`, but it was: {err}",
63///                 self.0
64///             );
65///         }
66///     }
67/// }
68///
69/// // Usage:
70/// # use test_casing::decorate;
71/// # use std::error::Error;
72/// #[test]
73/// #[decorate(ShouldError("oops"))]
74/// fn test_with_an_error() -> Result<(), Box<dyn Error>> {
75///     Err("oops, this test failed".into())
76/// }
77/// ```
78pub trait DecorateTest<R>: panic::RefUnwindSafe + Send + Sync + 'static {
79    /// Decorates the provided test function and runs the test.
80    fn decorate_and_test<F: TestFn<R>>(&'static self, test_fn: F) -> R;
81}
82
83impl<R, T: DecorateTest<R>> DecorateTest<R> for &'static T {
84    fn decorate_and_test<F: TestFn<R>>(&'static self, test_fn: F) -> R {
85        (**self).decorate_and_test(test_fn)
86    }
87}
88
89/// Object-safe version of [`DecorateTest`].
90#[doc(hidden)] // used in the `decorate` proc macro; logically private
91pub trait DecorateTestFn<R>: panic::RefUnwindSafe + Send + Sync + 'static {
92    fn decorate_and_test_fn(&'static self, test_fn: fn() -> R) -> R;
93}
94
95impl<R: 'static, T: DecorateTest<R>> DecorateTestFn<R> for T {
96    fn decorate_and_test_fn(&'static self, test_fn: fn() -> R) -> R {
97        self.decorate_and_test(test_fn)
98    }
99}
100
/// [Test decorator](DecorateTest) that fails the wrapped test if it does not finish
/// within the given [`Duration`].
///
/// # Examples
///
/// ```no_run
/// use test_casing::{decorate, decorators::Timeout};
///
/// #[test]
/// #[decorate(Timeout::secs(5))]
/// fn test_with_timeout() {
///     // test logic
/// }
/// ```
#[derive(Debug, Clone, Copy)]
pub struct Timeout(pub Duration);

impl Timeout {
    /// Creates a timeout lasting `secs` whole seconds.
    pub const fn secs(secs: u64) -> Self {
        let limit = Duration::from_secs(secs);
        Self(limit)
    }

    /// Creates a timeout lasting `millis` milliseconds.
    pub const fn millis(millis: u64) -> Self {
        let limit = Duration::from_millis(millis);
        Self(limit)
    }
}
129
130impl<R: Send + 'static> DecorateTest<R> for Timeout {
131    #[allow(clippy::similar_names)]
132    fn decorate_and_test<F: TestFn<R>>(&self, test_fn: F) -> R {
133        let (output_sx, output_rx) = mpsc::channel();
134        let handle = thread::spawn(move || {
135            output_sx.send(test_fn()).ok();
136        });
137        match output_rx.recv_timeout(self.0) {
138            Ok(output) => {
139                handle.join().unwrap();
140                // ^ `unwrap()` is safe; the thread didn't panic before `send`ing the output,
141                // and there's nowhere to panic after that.
142                output
143            }
144            Err(RecvTimeoutError::Timeout) => {
145                panic!("Timeout {:?} expired for the test", self.0);
146            }
147            Err(RecvTimeoutError::Disconnected) => {
148                let panic_object = handle.join().unwrap_err();
149                panic::resume_unwind(panic_object)
150            }
151        }
152    }
153}
154
/// [Test decorator](DecorateTest) that retries a wrapped test the specified number of times,
/// potentially with a delay between retries.
///
/// Both errors returned by the test and panics trigger a retry; use [`Retry::on_error()`]
/// to retry only selected errors.
///
/// # Examples
///
/// ```no_run
/// use test_casing::{decorate, decorators::Retry};
/// use std::time::Duration;
///
/// const RETRY_DELAY: Duration = Duration::from_millis(200);
///
/// #[test]
/// #[decorate(Retry::times(3).with_delay(RETRY_DELAY))]
/// fn test_with_retries() {
///     // test logic
/// }
/// ```
// `Clone` / `Copy` are derived for consistency with `Timeout`: both fields are `Copy`,
// and this lets one `Retry` value seed several builders (e.g., `with_delay` / `on_error`).
#[derive(Debug, Clone, Copy)]
pub struct Retry {
    /// Number of retries performed after the initial attempt.
    times: usize,
    /// Delay between consecutive attempts.
    delay: Duration,
}
177
178impl Retry {
179    /// Specified the number of retries. The delay between retries is zero.
180    pub const fn times(times: usize) -> Self {
181        Self {
182            times,
183            delay: Duration::ZERO,
184        }
185    }
186
187    /// Specifies the delay between retries.
188    #[must_use]
189    pub const fn with_delay(self, delay: Duration) -> Self {
190        Self { delay, ..self }
191    }
192
193    /// Converts this retry specification to only retry specific errors.
194    pub const fn on_error<E>(self, matcher: fn(&E) -> bool) -> RetryErrors<E> {
195        RetryErrors {
196            inner: self,
197            matcher,
198        }
199    }
200
201    fn handle_panic(&self, attempt: usize, panic_object: Box<dyn Any + Send>) {
202        if attempt < self.times {
203            let panic_str = extract_panic_str(&panic_object).unwrap_or("");
204            let punctuation = if panic_str.is_empty() { "" } else { ": " };
205            println!("Test attempt #{attempt} panicked{punctuation}{panic_str}");
206        } else {
207            panic::resume_unwind(panic_object);
208        }
209    }
210
211    fn run_with_retries<E: fmt::Display>(
212        &self,
213        test_fn: impl TestFn<Result<(), E>>,
214        should_retry: fn(&E) -> bool,
215    ) -> Result<(), E> {
216        for attempt in 0..=self.times {
217            println!("Test attempt #{attempt}");
218            match panic::catch_unwind(test_fn) {
219                Ok(Ok(())) => return Ok(()),
220                Ok(Err(err)) => {
221                    if attempt < self.times && should_retry(&err) {
222                        println!("Test attempt #{attempt} errored: {err}");
223                    } else {
224                        return Err(err);
225                    }
226                }
227                Err(panic_object) => {
228                    self.handle_panic(attempt, panic_object);
229                }
230            }
231            if self.delay > Duration::ZERO {
232                thread::sleep(self.delay);
233            }
234        }
235        Ok(())
236    }
237}
238
239impl DecorateTest<()> for Retry {
240    fn decorate_and_test<F: TestFn<()>>(&self, test_fn: F) {
241        for attempt in 0..=self.times {
242            println!("Test attempt #{attempt}");
243            match panic::catch_unwind(test_fn) {
244                Ok(()) => break,
245                Err(panic_object) => {
246                    self.handle_panic(attempt, panic_object);
247                }
248            }
249            if self.delay > Duration::ZERO {
250                thread::sleep(self.delay);
251            }
252        }
253    }
254}
255
256impl<E: fmt::Display> DecorateTest<Result<(), E>> for Retry {
257    fn decorate_and_test<F>(&self, test_fn: F) -> Result<(), E>
258    where
259        F: TestFn<Result<(), E>>,
260    {
261        self.run_with_retries(test_fn, |_| true)
262    }
263}
264
/// Returns the message carried by a panic payload, if it is a `&'static str` or a `String`
/// (the payload types produced by `panic!`); returns `None` for any other payload type.
///
/// Callers holding a `Box<dyn Any + Send>` must pass `&*boxed`: `&boxed` coerces the box
/// itself into `&dyn Any`, and neither downcast below would ever match it.
fn extract_panic_str(panic_object: &(dyn Any + Send)) -> Option<&str> {
    match panic_object.downcast_ref::<&'static str>() {
        Some(&message) => Some(message),
        None => panic_object.downcast_ref::<String>().map(String::as_str),
    }
}
274
275/// [Test decorator](DecorateTest) that retries a wrapped test a certain number of times
276/// only if an error matches the specified predicate.
277///
278/// Constructed using [`Retry::on_error()`].
279///
280/// # Examples
281///
282/// ```no_run
283/// use test_casing::{decorate, decorators::{Retry, RetryErrors}};
284/// use std::error::Error;
285///
286/// const RETRY: RetryErrors<Box<dyn Error>> = Retry::times(3)
287///     .on_error(|err| err.to_string().contains("retry please"));
288///
289/// #[test]
290/// #[decorate(RETRY)]
291/// fn test_with_retries() -> Result<(), Box<dyn Error>> {
292///     // test logic
293/// #    Ok(())
294/// }
295/// ```
296pub struct RetryErrors<E> {
297    inner: Retry,
298    matcher: fn(&E) -> bool,
299}
300
301impl<E> fmt::Debug for RetryErrors<E> {
302    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
303        formatter
304            .debug_struct("RetryErrors")
305            .field("inner", &self.inner)
306            .finish_non_exhaustive()
307    }
308}
309
310impl<E: fmt::Display + 'static> DecorateTest<Result<(), E>> for RetryErrors<E> {
311    fn decorate_and_test<F>(&self, test_fn: F) -> Result<(), E>
312    where
313        F: TestFn<Result<(), E>>,
314    {
315        self.inner.run_with_retries(test_fn, self.matcher)
316    }
317}
318
/// [Test decorator](DecorateTest) that serializes runs of the decorated tests, so that no two
/// tests from the same sequence execute concurrently. Optionally, the remaining tests can be
/// skipped once a test in the sequence fails.
///
/// The order in which tests of a sequence run is not deterministic: depending on
/// the command-line args the test binary was launched with, some tests in the sequence
/// may not run at all.
///
/// # Examples
///
/// ```no_run
/// use test_casing::{decorate, decorators::{Sequence, Timeout}};
///
/// static SEQUENCE: Sequence = Sequence::new().abort_on_failure();
///
/// #[test]
/// #[decorate(&SEQUENCE)]
/// fn sequential_test() {
///     // test logic
/// }
///
/// #[test]
/// #[decorate(Timeout::secs(1), &SEQUENCE)]
/// fn other_sequential_test() {
///     // test logic
/// }
/// ```
#[derive(Debug, Default)]
pub struct Sequence {
    /// Whether the most recently run test in the sequence failed. Holding this mutex while
    /// a test runs is also what keeps tests in the sequence from overlapping.
    failed: Mutex<bool>,
    /// Whether to skip tests once a previous test in the sequence has failed.
    abort_on_failure: bool,
}
350
351impl Sequence {
352    /// Creates a new test sequence.
353    pub const fn new() -> Self {
354        Self {
355            failed: Mutex::new(false),
356            abort_on_failure: false,
357        }
358    }
359
360    /// Makes the decorated tests abort immediately if one test from the sequence fails.
361    #[must_use]
362    pub const fn abort_on_failure(mut self) -> Self {
363        self.abort_on_failure = true;
364        self
365    }
366
367    fn decorate_inner<R, F: TestFn<R>>(
368        &self,
369        test_fn: F,
370        ok_value: R,
371        match_failure: fn(&R) -> bool,
372    ) -> R {
373        let mut guard = self.failed.lock().unwrap_or_else(PoisonError::into_inner);
374        if *guard && self.abort_on_failure {
375            println!("Skipping test because a previous test in the same sequence has failed");
376            return ok_value;
377        }
378
379        let output = panic::catch_unwind(test_fn);
380        *guard = output.as_ref().map_or(true, match_failure);
381        drop(guard);
382        output.unwrap_or_else(|panic_object| {
383            panic::resume_unwind(panic_object);
384        })
385    }
386}
387
388impl DecorateTest<()> for Sequence {
389    fn decorate_and_test<F: TestFn<()>>(&self, test_fn: F) {
390        self.decorate_inner(test_fn, (), |()| false);
391    }
392}
393
394impl<E: 'static> DecorateTest<Result<(), E>> for Sequence {
395    fn decorate_and_test<F>(&self, test_fn: F) -> Result<(), E>
396    where
397        F: TestFn<Result<(), E>>,
398    {
399        self.decorate_inner(test_fn, Ok(()), Result::is_err)
400    }
401}
402
403macro_rules! impl_decorate_test_for_tuple {
404    ($($field:ident : $ty:ident),* => $last_field:ident : $last_ty:ident) => {
405        impl<R, $($ty,)* $last_ty> DecorateTest<R> for ($($ty,)* $last_ty,)
406        where
407            $($ty: DecorateTest<R>,)*
408            $last_ty: DecorateTest<R>,
409        {
410            fn decorate_and_test<Fn: TestFn<R>>(&'static self, test_fn: Fn) -> R {
411                let ($($field,)* $last_field,) = self;
412                $(
413                let test_fn = move || $field.decorate_and_test(test_fn);
414                )*
415                $last_field.decorate_and_test(test_fn)
416            }
417        }
418    };
419}
420
421impl_decorate_test_for_tuple!(=> a: A);
422impl_decorate_test_for_tuple!(a: A => b: B);
423impl_decorate_test_for_tuple!(a: A, b: B => c: C);
424impl_decorate_test_for_tuple!(a: A, b: B, c: C => d: D);
425impl_decorate_test_for_tuple!(a: A, b: B, c: C, d: D => e: E);
426impl_decorate_test_for_tuple!(a: A, b: B, c: C, d: D, e: E => f: F);
427impl_decorate_test_for_tuple!(a: A, b: B, c: C, d: D, e: E, f: F => g: G);
428impl_decorate_test_for_tuple!(a: A, b: B, c: C, d: D, e: E, f: F, g: G => h: H);
429
#[cfg(test)]
mod tests {
    use std::{
        io,
        sync::{
            atomic::{AtomicU32, Ordering},
            Mutex,
        },
        time::Instant,
    };

    use super::*;

    #[test]
    #[should_panic(expected = "Timeout 100ms expired")]
    fn timeouts() {
        const TIMEOUT: Timeout = Timeout(Duration::from_millis(100));

        let test_fn: fn() = || thread::sleep(Duration::from_secs(1));
        TIMEOUT.decorate_and_test(test_fn);
    }

    #[test]
    fn retrying_with_delay() {
        const RETRY: Retry = Retry::times(1).with_delay(Duration::from_millis(100));

        fn test_fn() -> Result<(), &'static str> {
            static TEST_START: Mutex<Option<Instant>> = Mutex::new(None);

            let mut test_start = TEST_START.lock().unwrap();
            if let Some(test_start) = *test_start {
                // `thread::sleep` waits *at least* the requested duration, so `>=` is the
                // strongest bound that is guaranteed to hold.
                assert!(test_start.elapsed() >= RETRY.delay);
                Ok(())
            } else {
                *test_start = Some(Instant::now());
                Err("come again?")
            }
        }

        RETRY.decorate_and_test(test_fn).unwrap();
    }

    #[test]
    fn retrying_after_panic() {
        const RETRY_ONCE: Retry = Retry::times(1);
        static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);

        // Panics on the first attempt only; the retry should then succeed.
        let test_fn: fn() = || {
            if TEST_COUNTER.fetch_add(1, Ordering::Relaxed) == 0 {
                panic!("first attempt panics");
            }
        };
        RETRY_ONCE.decorate_and_test(test_fn);
        assert_eq!(TEST_COUNTER.load(Ordering::Relaxed), 2);
    }

    #[test]
    fn extracting_panic_messages() {
        // Boxed payloads must be dereferenced: `&boxed` would coerce the box itself to `&dyn Any`.
        let static_message: Box<dyn Any + Send> = Box::new("static oops");
        assert_eq!(extract_panic_str(&*static_message), Some("static oops"));

        let owned_message: Box<dyn Any + Send> = Box::new(String::from("owned oops"));
        assert_eq!(extract_panic_str(&*owned_message), Some("owned oops"));

        let other_payload: Box<dyn Any + Send> = Box::new(42_u32);
        assert_eq!(extract_panic_str(&*other_payload), None);

        // A formatted `panic!` produces a `String` payload.
        let caught = panic::catch_unwind(|| panic!("formatted {}", 42)).unwrap_err();
        assert_eq!(extract_panic_str(&*caught), Some("formatted 42"));
    }

    const RETRY: RetryErrors<io::Error> =
        Retry::times(2).on_error(|err| matches!(err.kind(), io::ErrorKind::AddrInUse));

    #[test]
    fn retrying_on_error() {
        static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);

        fn test_fn() -> io::Result<()> {
            if TEST_COUNTER.fetch_add(1, Ordering::Relaxed) == 2 {
                Ok(())
            } else {
                Err(io::Error::new(
                    io::ErrorKind::AddrInUse,
                    "please retry later",
                ))
            }
        }

        let test_fn: fn() -> _ = test_fn;
        RETRY.decorate_and_test(test_fn).unwrap();
        assert_eq!(TEST_COUNTER.load(Ordering::Relaxed), 3);

        // The counter is past 2 now, so every attempt fails and retries are exhausted.
        let err = RETRY.decorate_and_test(test_fn).unwrap_err();
        assert!(err.to_string().contains("please retry later"));
        assert_eq!(TEST_COUNTER.load(Ordering::Relaxed), 6);
    }

    #[test]
    fn retrying_on_error_failure() {
        static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);

        fn test_fn() -> io::Result<()> {
            if TEST_COUNTER.fetch_add(1, Ordering::Relaxed) == 0 {
                Err(io::Error::new(io::ErrorKind::BrokenPipe, "oops"))
            } else {
                Ok(())
            }
        }

        // `BrokenPipe` is rejected by the matcher, so the error is returned without retrying.
        let err = RETRY.decorate_and_test(test_fn).unwrap_err();
        assert!(err.to_string().contains("oops"));
        assert_eq!(TEST_COUNTER.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn sequential_tests() {
        static SEQUENCE: Sequence = Sequence::new();
        static ENTRY_COUNTER: AtomicU32 = AtomicU32::new(0);

        // Each test asserts that it is the only one currently running.
        let first_test: fn() = || {
            let counter = ENTRY_COUNTER.fetch_add(1, Ordering::Relaxed);
            assert_eq!(counter, 0);
            thread::sleep(Duration::from_millis(10));
            ENTRY_COUNTER.store(0, Ordering::Relaxed);
            panic!("oops");
        };
        let second_test = || {
            let counter = ENTRY_COUNTER.fetch_add(1, Ordering::Relaxed);
            assert_eq!(counter, 0);
            thread::sleep(Duration::from_millis(20));
            ENTRY_COUNTER.store(0, Ordering::Relaxed);
            Ok::<_, io::Error>(())
        };

        let first_test_handle = thread::spawn(move || SEQUENCE.decorate_and_test(first_test));
        SEQUENCE.decorate_and_test(second_test).unwrap();
        first_test_handle.join().unwrap_err();
    }

    #[test]
    fn sequential_tests_with_abort() {
        static SEQUENCE: Sequence = Sequence::new().abort_on_failure();

        let failing_test =
            || Err::<(), _>(io::Error::new(io::ErrorKind::AddrInUse, "please try later"));
        let second_test: fn() = || unreachable!("Second test should not be called!");

        SEQUENCE.decorate_and_test(failing_test).unwrap_err();
        SEQUENCE.decorate_and_test(second_test);
    }

    // We need independent test counters for different tests, hence defining a function
    // via a macro.
    macro_rules! define_test_fn {
        () => {
            fn test_fn() -> Result<(), &'static str> {
                static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);
                match TEST_COUNTER.fetch_add(1, Ordering::Relaxed) {
                    0 => {
                        thread::sleep(Duration::from_secs(1));
                        Ok(())
                    }
                    1 => Err("oops"),
                    2 => Ok(()),
                    _ => unreachable!(),
                }
            }
        };
    }

    #[test]
    fn composing_decorators() {
        define_test_fn!();

        // Timeout is innermost: the first (slow) attempt times out and is retried by `Retry`.
        const DECORATORS: (Timeout, Retry) = (Timeout(Duration::from_millis(100)), Retry::times(2));

        DECORATORS.decorate_and_test(test_fn).unwrap();
    }

    #[test]
    fn making_decorator_into_trait_object() {
        define_test_fn!();

        static DECORATORS: &dyn DecorateTestFn<Result<(), &'static str>> =
            &(Timeout(Duration::from_millis(100)), Retry::times(2));

        DECORATORS.decorate_and_test_fn(test_fn).unwrap();
    }

    #[test]
    fn making_sequence_into_trait_object() {
        static SEQUENCE: Sequence = Sequence::new();
        static DECORATORS: &dyn DecorateTestFn<()> = &(&SEQUENCE,);

        DECORATORS.decorate_and_test_fn(|| {});
    }
}