test_casing/decorators/
traces.rs1use std::{env, fmt};
4
5use tracing::{level_filters::LevelFilter, Dispatch, Event, Subscriber};
6use tracing_subscriber::{
7 field::RecordFields,
8 fmt::{format, format::Writer, FmtContext, FormatEvent, FormatFields, TestWriter},
9 registry::LookupSpan,
10 EnvFilter, FmtSubscriber,
11};
12
13use crate::decorators::{DecorateTest, TestFn};
14
/// Two-variant sum type that lets a runtime choice between two formatters
/// still present a single concrete type to `tracing_subscriber`.
#[derive(Debug)]
enum Either<L, R> {
    /// First alternative (holds the pretty formatters in this module).
    Left(L),
    /// Second alternative (holds the default formatters in this module).
    Right(R),
}
20
21impl<'w, L, R> FormatFields<'w> for Either<L, R>
22where
23 L: FormatFields<'w>,
24 R: FormatFields<'w>,
25{
26 fn format_fields<F: RecordFields>(&self, writer: Writer<'w>, fields: F) -> fmt::Result {
27 match self {
28 Self::Left(formatter) => formatter.format_fields(writer, fields),
29 Self::Right(formatter) => formatter.format_fields(writer, fields),
30 }
31 }
32}
33
34impl<S, N, L, R> FormatEvent<S, N> for Either<L, R>
35where
36 S: Subscriber + for<'a> LookupSpan<'a>,
37 N: for<'a> FormatFields<'a> + 'static,
38 L: FormatEvent<S, N>,
39 R: FormatEvent<S, N>,
40{
41 fn format_event(
42 &self,
43 ctx: &FmtContext<'_, S, N>,
44 writer: Writer<'_>,
45 event: &Event<'_>,
46 ) -> fmt::Result {
47 match self {
48 Self::Left(formatter) => formatter.format_event(ctx, writer, event),
49 Self::Right(formatter) => formatter.format_event(ctx, writer, event),
50 }
51 }
52}
53
54type TestSubscriber = FmtSubscriber<
55 Either<format::Pretty, format::DefaultFields>,
56 Either<format::Format<format::Pretty>, format::Format>,
57 EnvFilter,
58 TestWriter,
59>;
60
/// Test decorator that configures a `tracing` subscriber for the decorated test.
///
/// Events are written through `TestWriter`, so their output goes through the
/// test harness's capturing. Filtering follows the `RUST_LOG` env var when it
/// applies, and otherwise the directives passed to `Trace::new`.
#[cfg_attr(docsrs, doc(cfg(feature = "tracing")))]
#[derive(Debug, Clone, Copy)]
pub struct Trace {
    /// Filter directives, used unless overridden by the `RUST_LOG` env var.
    directives: Option<&'static str>,
    /// Whether events and fields are formatted with the pretty formatter.
    pretty: bool,
    /// Whether the subscriber is installed as the global default instead of a
    /// thread-local default scoped to the test.
    global: bool,
}
87
88impl Trace {
89 pub const fn new(directives: &'static str) -> Self {
92 Self {
93 directives: Some(directives),
94 pretty: false,
95 global: false,
96 }
97 }
98
99 #[must_use]
101 pub const fn pretty(mut self) -> Self {
102 self.pretty = true;
103 self
104 }
105
106 #[must_use]
109 pub const fn global(mut self) -> Self {
110 self.global = true;
111 self
112 }
113
114 pub fn create_subscriber(self) -> impl Subscriber + for<'a> LookupSpan<'a> {
117 self.create_subscriber_inner()
118 }
119
120 fn create_subscriber_inner(self) -> TestSubscriber {
121 let env = env::var("RUST_LOG").ok();
122 let env = env.as_deref().or(self.directives).unwrap_or_default();
123 let env_filter = EnvFilter::builder()
124 .with_default_directive(LevelFilter::INFO.into())
125 .parse_lossy(env);
126 FmtSubscriber::builder()
127 .with_test_writer()
128 .with_env_filter(env_filter)
129 .fmt_fields(if self.pretty {
130 Either::Left(format::Pretty::default())
131 } else {
132 Either::Right(format::DefaultFields::default())
133 })
134 .map_event_format(|fmt| {
135 if self.pretty {
136 Either::Left(fmt.pretty())
137 } else {
138 Either::Right(fmt)
139 }
140 })
141 .finish()
142 }
143}
144
145impl<R> DecorateTest<R> for Trace {
146 fn decorate_and_test<F: TestFn<R>>(&'static self, test_fn: F) -> R {
147 let subscriber = self.create_subscriber_inner();
148 let _guard = if self.global {
149 if tracing::subscriber::set_global_default(subscriber).is_err() {
150 let is_test_subscriber =
151 tracing::dispatcher::get_default(Dispatch::is::<TestSubscriber>);
152 if !is_test_subscriber {
153 tracing::warn!("could not set up global tracing subscriber");
154 }
155 }
156 None
157 } else {
158 Some(tracing::subscriber::set_default(subscriber))
159 };
160 test_fn()
161 }
162}