1use std::{
4 io::{self, BufRead, BufReader, LineWriter, Read},
5 iter,
6 process::{Command, Stdio},
7 sync::mpsc,
8 thread,
9 time::Duration,
10};
11
12use super::ShellOptions;
13use crate::{
14 traits::{ShellProcess, SpawnShell, SpawnedShell},
15 Captured, Interaction, Transcript, UserInput,
16};
17
18#[derive(Debug)]
19struct Timeouts {
20 inner: iter::Chain<iter::Once<Duration>, iter::Repeat<Duration>>,
21}
22
23impl Timeouts {
24 fn new<Cmd: SpawnShell>(options: &ShellOptions<Cmd>) -> Self {
25 Self {
26 inner: iter::once(options.init_timeout + options.io_timeout)
27 .chain(iter::repeat(options.io_timeout)),
28 }
29 }
30
31 fn next(&mut self) -> Duration {
32 self.inner.next().unwrap() }
34}
35
36impl Transcript {
37 #[cfg(not(windows))]
38 #[cfg_attr(
39 feature = "tracing",
40 tracing::instrument(level = "debug", skip(writer), err)
41 )]
42 fn write_line(writer: &mut impl io::Write, line: &str) -> io::Result<()> {
43 writeln!(writer, "{line}")
44 }
45
46 #[cfg(windows)]
48 #[cfg_attr(
49 feature = "tracing",
50 tracing::instrument(level = "debug", skip(writer), err)
51 )]
52 fn write_line(writer: &mut impl io::Write, line: &str) -> io::Result<()> {
53 writeln!(writer, "{line}\r")
54 }
55
56 #[cfg_attr(
57 feature = "tracing",
58 tracing::instrument(level = "debug", skip(lines_recv), err)
59 )]
60 #[cfg_attr(not(feature = "tracing"), allow(unused_variables))]
61 fn read_echo(
63 input_line: &str,
64 lines_recv: &mpsc::Receiver<Vec<u8>>,
65 io_timeout: Duration,
66 ) -> io::Result<()> {
67 if let Ok(line) = lines_recv.recv_timeout(io_timeout) {
68 #[cfg(feature = "tracing")]
69 tracing::debug!(line_utf8 = std::str::from_utf8(&line).ok(), "received line");
70 Ok(())
71 } else {
72 let err =
73 format!("could not read all input `{input_line}` back from an echoing terminal");
74 Err(io::Error::new(io::ErrorKind::BrokenPipe, err))
75 }
76 }
77
78 #[cfg_attr(
79 feature = "tracing",
80 tracing::instrument(level = "debug", skip_all, ret, err)
81 )]
82 fn read_output(
83 lines_recv: &mpsc::Receiver<Vec<u8>>,
84 mut timeouts: Timeouts,
85 line_decoder: &mut dyn FnMut(Vec<u8>) -> io::Result<String>,
86 ) -> io::Result<String> {
87 let mut output = String::new();
88
89 while let Ok(mut line) = lines_recv.recv_timeout(timeouts.next()) {
90 if line.last() == Some(&b'\r') {
91 line.pop();
93 }
94 #[cfg(feature = "tracing")]
95 tracing::debug!(line_utf8 = std::str::from_utf8(&line).ok(), "received line");
96
97 let mapped_line = line_decoder(line)?;
98 #[cfg(feature = "tracing")]
99 tracing::debug!(?mapped_line, "mapped received line");
100 output.push_str(&mapped_line);
101 output.push('\n');
102 }
103
104 if output.ends_with('\n') {
105 output.truncate(output.len() - 1);
106 }
107 Ok(output)
108 }
109
110 #[cfg_attr(
121 feature = "tracing",
122 tracing::instrument(
123 skip_all,
124 err,
125 fields(
126 options.io_timeout = ?options.io_timeout,
127 options.init_timeout = ?options.init_timeout,
128 options.path_additions = ?options.path_additions,
129 options.init_commands = ?options.init_commands
130 )
131 )
132 )]
133 pub fn from_inputs<Cmd: SpawnShell>(
134 options: &mut ShellOptions<Cmd>,
135 inputs: impl IntoIterator<Item = UserInput>,
136 ) -> io::Result<Self> {
137 let SpawnedShell {
138 mut shell,
139 reader,
140 writer,
141 } = options.spawn_shell()?;
142
143 let stdout = BufReader::new(reader);
144 let (out_lines_send, out_lines_recv) = mpsc::channel();
145
146 #[cfg(feature = "tracing")]
149 let dispatcher = tracing::dispatcher::get_default(Clone::clone);
150 let io_handle = thread::spawn(move || {
151 #[cfg(feature = "tracing")]
152 let _tracing_guard = tracing::dispatcher::set_default(&dispatcher);
153 #[cfg(feature = "tracing")]
154 let _entered = tracing::debug_span!("reader_thread").entered();
155
156 let mut lines = stdout.split(b'\n');
157 loop {
158 match lines.next() {
159 Some(Ok(line)) => {
160 #[cfg(feature = "tracing")]
161 tracing::debug!(
162 line_utf8 = std::str::from_utf8(&line).ok(),
163 "received line"
164 );
165
166 if out_lines_send.send(line).is_err() {
167 #[cfg(feature = "tracing")]
168 tracing::debug!("receiver dropped, breaking reader loop");
169 return;
170 }
171 }
172 #[cfg_attr(not(feature = "tracing"), allow(unused_variables))]
174 Some(Err(err)) => {
175 #[cfg(feature = "tracing")]
176 tracing::warn!(?err, msg = %err, "error reading shell output");
177 return;
178 }
179 None => {
180 #[cfg(feature = "tracing")]
181 tracing::debug!("input sender dropped");
182 return;
183 }
184 }
185 }
186 });
187
188 let mut stdin = LineWriter::new(writer);
189 Self::push_init_commands(options, &out_lines_recv, &mut shell, &mut stdin)?;
190
191 let mut transcript = Self::new();
192 for input in inputs {
193 let interaction =
194 Self::record_interaction(options, input, &out_lines_recv, &mut shell, &mut stdin)?;
195 transcript.interactions.push(interaction);
196 }
197
198 drop(stdin); thread::sleep(options.io_timeout / 4);
202
203 shell.terminate()?;
204 io_handle.join().ok(); Ok(transcript)
206 }
207
208 #[cfg_attr(
209 feature = "tracing",
210 tracing::instrument(
211 level = "debug",
212 skip_all,
213 err,
214 fields(options.init_commands = ?options.init_commands)
215 )
216 )]
217 fn push_init_commands<Cmd: SpawnShell>(
218 options: &ShellOptions<Cmd>,
219 lines_recv: &mpsc::Receiver<Vec<u8>>,
220 shell: &mut Cmd::ShellProcess,
221 stdin: &mut impl io::Write,
222 ) -> io::Result<()> {
223 let mut timeouts = Timeouts::new(options);
225 while lines_recv.recv_timeout(timeouts.next()).is_ok() {
226 }
228
229 for cmd in &options.init_commands {
231 Self::write_line(stdin, cmd)?;
232 if shell.is_echoing() {
233 Self::read_echo(cmd, lines_recv, options.io_timeout)?;
234 }
235
236 let mut timeouts = Timeouts::new(options);
238 while lines_recv.recv_timeout(timeouts.next()).is_ok() {
239 }
241 }
242 Ok(())
243 }
244
245 #[cfg_attr(
246 feature = "tracing",
247 tracing::instrument(level = "debug", skip(options, lines_recv, shell, stdin), ret, err)
248 )]
249 fn record_interaction<Cmd: SpawnShell>(
250 options: &mut ShellOptions<Cmd>,
251 input: UserInput,
252 lines_recv: &mpsc::Receiver<Vec<u8>>,
253 shell: &mut Cmd::ShellProcess,
254 stdin: &mut impl io::Write,
255 ) -> io::Result<Interaction> {
256 shell.check_is_alive()?;
259
260 let input_lines = input.text.split('\n');
261 for input_line in input_lines {
262 Self::write_line(stdin, input_line)?;
263 if shell.is_echoing() {
264 Self::read_echo(input_line, lines_recv, options.io_timeout)?;
265 }
266 }
267
268 let output = Self::read_output(
269 lines_recv,
270 Timeouts::new(options),
271 options.line_decoder.as_mut(),
272 )?;
273
274 let exit_status = if let Some(status_check) = &options.status_check {
275 let command = status_check.command();
276 Self::write_line(stdin, command)?;
277 if shell.is_echoing() {
278 Self::read_echo(command, lines_recv, options.io_timeout)?;
279 }
280 let response = Self::read_output(
281 lines_recv,
282 Timeouts::new(options),
283 options.line_decoder.as_mut(),
284 )?;
285 status_check.check(&Captured::from(response))
286 } else {
287 None
288 };
289
290 let mut interaction = Interaction::new(input, output);
291 interaction.exit_status = exit_status;
292 Ok(interaction)
293 }
294
295 #[cfg_attr(
305 feature = "tracing",
306 tracing::instrument(skip(self, input), err, fields(input.text = %input.text))
307 )]
308 pub fn capture_output(
309 &mut self,
310 input: UserInput,
311 command: &mut Command,
312 ) -> io::Result<&mut Self> {
313 let (mut pipe_reader, pipe_writer) = os_pipe::pipe()?;
314 #[cfg(feature = "tracing")]
315 tracing::debug!("created OS pipe");
316
317 let mut child = command
318 .stdin(Stdio::null())
319 .stdout(pipe_writer.try_clone()?)
320 .stderr(pipe_writer)
321 .spawn()?;
322 #[cfg(feature = "tracing")]
323 tracing::debug!("created child");
324
325 command.stdout(Stdio::null()).stderr(Stdio::null());
327
328 let mut output = vec![];
329 pipe_reader.read_to_end(&mut output)?;
330 child.wait()?;
331
332 let mut output = String::from_utf8(output)
333 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err.utf8_error()))?;
334 if output.ends_with('\n') {
335 output.truncate(output.len() - 1);
336 }
337 #[cfg(feature = "tracing")]
338 tracing::debug!(?output, "read command output");
339
340 self.interactions.push(Interaction::new(input, output));
341 Ok(self)
342 }
343}