term_transcript/shell/
transcript_impl.rs

1//! Shell-related `Transcript` methods.
2
3use std::{
4    io::{self, BufRead, BufReader, LineWriter, Read},
5    iter,
6    process::{Command, Stdio},
7    sync::mpsc,
8    thread,
9    time::Duration,
10};
11
12use super::ShellOptions;
13use crate::{
14    traits::{ShellProcess, SpawnShell, SpawnedShell},
15    Captured, Interaction, Transcript, UserInput,
16};
17
18#[derive(Debug)]
19struct Timeouts {
20    inner: iter::Chain<iter::Once<Duration>, iter::Repeat<Duration>>,
21}
22
23impl Timeouts {
24    fn new<Cmd: SpawnShell>(options: &ShellOptions<Cmd>) -> Self {
25        Self {
26            inner: iter::once(options.init_timeout + options.io_timeout)
27                .chain(iter::repeat(options.io_timeout)),
28        }
29    }
30
31    fn next(&mut self) -> Duration {
32        self.inner.next().unwrap() // safe by construction; the iterator is indefinite
33    }
34}
35
36impl Transcript {
37    #[cfg(not(windows))]
38    #[cfg_attr(
39        feature = "tracing",
40        tracing::instrument(level = "debug", skip(writer), err)
41    )]
42    fn write_line(writer: &mut impl io::Write, line: &str) -> io::Result<()> {
43        writeln!(writer, "{line}")
44    }
45
46    // Lines need to end with `\r\n` to be properly processed, at least when writing to a PTY.
47    #[cfg(windows)]
48    #[cfg_attr(
49        feature = "tracing",
50        tracing::instrument(level = "debug", skip(writer), err)
51    )]
52    fn write_line(writer: &mut impl io::Write, line: &str) -> io::Result<()> {
53        writeln!(writer, "{line}\r")
54    }
55
56    #[cfg_attr(
57        feature = "tracing",
58        tracing::instrument(level = "debug", skip(lines_recv), err)
59    )]
60    #[cfg_attr(not(feature = "tracing"), allow(unused_variables))]
61    // ^ The received `line` is used only for debug purposes
62    fn read_echo(
63        input_line: &str,
64        lines_recv: &mpsc::Receiver<Vec<u8>>,
65        io_timeout: Duration,
66    ) -> io::Result<()> {
67        if let Ok(line) = lines_recv.recv_timeout(io_timeout) {
68            #[cfg(feature = "tracing")]
69            tracing::debug!(line_utf8 = std::str::from_utf8(&line).ok(), "received line");
70            Ok(())
71        } else {
72            let err =
73                format!("could not read all input `{input_line}` back from an echoing terminal");
74            Err(io::Error::new(io::ErrorKind::BrokenPipe, err))
75        }
76    }
77
78    #[cfg_attr(
79        feature = "tracing",
80        tracing::instrument(level = "debug", skip_all, ret, err)
81    )]
82    fn read_output(
83        lines_recv: &mpsc::Receiver<Vec<u8>>,
84        mut timeouts: Timeouts,
85        line_decoder: &mut dyn FnMut(Vec<u8>) -> io::Result<String>,
86    ) -> io::Result<String> {
87        let mut output = String::new();
88
89        while let Ok(mut line) = lines_recv.recv_timeout(timeouts.next()) {
90            if line.last() == Some(&b'\r') {
91                // Normalize `\r\n` line ending to `\n`.
92                line.pop();
93            }
94            #[cfg(feature = "tracing")]
95            tracing::debug!(line_utf8 = std::str::from_utf8(&line).ok(), "received line");
96
97            let mapped_line = line_decoder(line)?;
98            #[cfg(feature = "tracing")]
99            tracing::debug!(?mapped_line, "mapped received line");
100            output.push_str(&mapped_line);
101            output.push('\n');
102        }
103
104        if output.ends_with('\n') {
105            output.truncate(output.len() - 1);
106        }
107        Ok(output)
108    }
109
110    /// Constructs a transcript from the sequence of given user `input`s.
111    ///
112    /// The inputs are executed in the shell specified in `options`. A single shell is shared
113    /// among all commands.
114    ///
115    /// # Errors
116    ///
117    /// - Returns an error if spawning the shell or any operations with it fail (such as reading
118    ///   stdout / stderr, or writing commands to stdin), or if the shell exits before all commands
119    ///   are executed.
120    #[cfg_attr(
121        feature = "tracing",
122        tracing::instrument(
123            skip_all,
124            err,
125            fields(
126                options.io_timeout = ?options.io_timeout,
127                options.init_timeout = ?options.init_timeout,
128                options.path_additions = ?options.path_additions,
129                options.init_commands = ?options.init_commands
130            )
131        )
132    )]
133    pub fn from_inputs<Cmd: SpawnShell>(
134        options: &mut ShellOptions<Cmd>,
135        inputs: impl IntoIterator<Item = UserInput>,
136    ) -> io::Result<Self> {
137        let SpawnedShell {
138            mut shell,
139            reader,
140            writer,
141        } = options.spawn_shell()?;
142
143        let stdout = BufReader::new(reader);
144        let (out_lines_send, out_lines_recv) = mpsc::channel();
145
146        // Propagate the dispatcher for the current thread to the spawned one. Mainly useful for integration tests
147        // that don't set the global dispatcher.
148        #[cfg(feature = "tracing")]
149        let dispatcher = tracing::dispatcher::get_default(Clone::clone);
150        let io_handle = thread::spawn(move || {
151            #[cfg(feature = "tracing")]
152            let _tracing_guard = tracing::dispatcher::set_default(&dispatcher);
153            #[cfg(feature = "tracing")]
154            let _entered = tracing::debug_span!("reader_thread").entered();
155
156            let mut lines = stdout.split(b'\n');
157            loop {
158                match lines.next() {
159                    Some(Ok(line)) => {
160                        #[cfg(feature = "tracing")]
161                        tracing::debug!(
162                            line_utf8 = std::str::from_utf8(&line).ok(),
163                            "received line"
164                        );
165
166                        if out_lines_send.send(line).is_err() {
167                            #[cfg(feature = "tracing")]
168                            tracing::debug!("receiver dropped, breaking reader loop");
169                            return;
170                        }
171                    }
172                    // `err` is only used in the log message
173                    #[cfg_attr(not(feature = "tracing"), allow(unused_variables))]
174                    Some(Err(err)) => {
175                        #[cfg(feature = "tracing")]
176                        tracing::warn!(?err, msg = %err, "error reading shell output");
177                        return;
178                    }
179                    None => {
180                        #[cfg(feature = "tracing")]
181                        tracing::debug!("input sender dropped");
182                        return;
183                    }
184                }
185            }
186        });
187
188        let mut stdin = LineWriter::new(writer);
189        Self::push_init_commands(options, &out_lines_recv, &mut shell, &mut stdin)?;
190
191        let mut transcript = Self::new();
192        for input in inputs {
193            let interaction =
194                Self::record_interaction(options, input, &out_lines_recv, &mut shell, &mut stdin)?;
195            transcript.interactions.push(interaction);
196        }
197
198        drop(stdin); // signals to shell that we're done
199
200        // Give a chance for the shell process to exit. This will reduce kill errors later.
201        thread::sleep(options.io_timeout / 4);
202
203        shell.terminate()?;
204        io_handle.join().ok(); // the I/O thread should not panic, so we ignore errors here
205        Ok(transcript)
206    }
207
208    #[cfg_attr(
209        feature = "tracing",
210        tracing::instrument(
211            level = "debug",
212            skip_all,
213            err,
214            fields(options.init_commands = ?options.init_commands)
215        )
216    )]
217    fn push_init_commands<Cmd: SpawnShell>(
218        options: &ShellOptions<Cmd>,
219        lines_recv: &mpsc::Receiver<Vec<u8>>,
220        shell: &mut Cmd::ShellProcess,
221        stdin: &mut impl io::Write,
222    ) -> io::Result<()> {
223        // Drain all output left after commands and let the shell get fully initialized.
224        let mut timeouts = Timeouts::new(options);
225        while lines_recv.recv_timeout(timeouts.next()).is_ok() {
226            // Intentionally empty.
227        }
228
229        // Push initialization commands.
230        for cmd in &options.init_commands {
231            Self::write_line(stdin, cmd)?;
232            if shell.is_echoing() {
233                Self::read_echo(cmd, lines_recv, options.io_timeout)?;
234            }
235
236            // Drain all other output as well.
237            let mut timeouts = Timeouts::new(options);
238            while lines_recv.recv_timeout(timeouts.next()).is_ok() {
239                // Intentionally empty.
240            }
241        }
242        Ok(())
243    }
244
245    #[cfg_attr(
246        feature = "tracing",
247        tracing::instrument(level = "debug", skip(options, lines_recv, shell, stdin), ret, err)
248    )]
249    fn record_interaction<Cmd: SpawnShell>(
250        options: &mut ShellOptions<Cmd>,
251        input: UserInput,
252        lines_recv: &mpsc::Receiver<Vec<u8>>,
253        shell: &mut Cmd::ShellProcess,
254        stdin: &mut impl io::Write,
255    ) -> io::Result<Interaction> {
256        // Check if the shell is still alive. It seems that older Rust versions allow
257        // to write to `stdin` even after the shell exits.
258        shell.check_is_alive()?;
259
260        let input_lines = input.text.split('\n');
261        for input_line in input_lines {
262            Self::write_line(stdin, input_line)?;
263            if shell.is_echoing() {
264                Self::read_echo(input_line, lines_recv, options.io_timeout)?;
265            }
266        }
267
268        let output = Self::read_output(
269            lines_recv,
270            Timeouts::new(options),
271            options.line_decoder.as_mut(),
272        )?;
273
274        let exit_status = if let Some(status_check) = &options.status_check {
275            let command = status_check.command();
276            Self::write_line(stdin, command)?;
277            if shell.is_echoing() {
278                Self::read_echo(command, lines_recv, options.io_timeout)?;
279            }
280            let response = Self::read_output(
281                lines_recv,
282                Timeouts::new(options),
283                options.line_decoder.as_mut(),
284            )?;
285            status_check.check(&Captured::from(response))
286        } else {
287            None
288        };
289
290        let mut interaction = Interaction::new(input, output);
291        interaction.exit_status = exit_status;
292        Ok(interaction)
293    }
294
295    /// Captures stdout / stderr of the provided `command` and adds it to [`Self::interactions()`].
296    ///
297    /// The `command` is spawned with the closed stdin. This method blocks until the command exits.
298    /// The method succeeds regardless of the exit status of the `command`.
299    ///
300    /// # Errors
301    ///
302    /// - Returns an error if spawning the `command` or any operations with it fail (such as reading
303    ///   stdout / stderr).
304    #[cfg_attr(
305        feature = "tracing",
306        tracing::instrument(skip(self, input), err, fields(input.text = %input.text))
307    )]
308    pub fn capture_output(
309        &mut self,
310        input: UserInput,
311        command: &mut Command,
312    ) -> io::Result<&mut Self> {
313        let (mut pipe_reader, pipe_writer) = os_pipe::pipe()?;
314        #[cfg(feature = "tracing")]
315        tracing::debug!("created OS pipe");
316
317        let mut child = command
318            .stdin(Stdio::null())
319            .stdout(pipe_writer.try_clone()?)
320            .stderr(pipe_writer)
321            .spawn()?;
322        #[cfg(feature = "tracing")]
323        tracing::debug!("created child");
324
325        // Drop pipe writers. This is necessary for the pipe reader to receive EOF.
326        command.stdout(Stdio::null()).stderr(Stdio::null());
327
328        let mut output = vec![];
329        pipe_reader.read_to_end(&mut output)?;
330        child.wait()?;
331
332        let mut output = String::from_utf8(output)
333            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err.utf8_error()))?;
334        if output.ends_with('\n') {
335            output.truncate(output.len() - 1);
336        }
337        #[cfg(feature = "tracing")]
338        tracing::debug!(?output, "read command output");
339
340        self.interactions.push(Interaction::new(input, output));
341        Ok(self)
342    }
343}