1use std::{
4 io::{self, BufRead, BufReader, LineWriter, Read},
5 iter,
6 process::{Command, Stdio},
7 str,
8 sync::mpsc,
9 thread,
10 time::Duration,
11};
12
13use styled_str::{AnsiError, StyledString};
14
15use super::ShellOptions;
16use crate::{
17 Interaction, Transcript, UserInput,
18 traits::{ShellProcess, SpawnShell, SpawnedShell},
19};
20
21#[derive(Debug)]
22struct Timeouts {
23 inner: iter::Chain<iter::Once<Duration>, iter::Repeat<Duration>>,
24}
25
26impl Timeouts {
27 fn new<Cmd: SpawnShell>(options: &ShellOptions<Cmd>) -> Self {
28 Self {
29 inner: iter::once(options.init_timeout + options.io_timeout)
30 .chain(iter::repeat(options.io_timeout)),
31 }
32 }
33
34 fn next(&mut self) -> Duration {
35 self.inner.next().unwrap() }
37}
38
39fn map_ansi_error(err: AnsiError) -> io::Error {
40 io::Error::new(io::ErrorKind::InvalidData, err)
41}
42
43impl Transcript {
44 #[cfg(not(windows))]
45 #[cfg_attr(
46 feature = "tracing",
47 tracing::instrument(level = "debug", skip(writer), err)
48 )]
49 fn write_line(writer: &mut impl io::Write, line: &str) -> io::Result<()> {
50 writeln!(writer, "{line}")
51 }
52
53 #[cfg(windows)]
55 #[cfg_attr(
56 feature = "tracing",
57 tracing::instrument(level = "debug", skip(writer), err)
58 )]
59 fn write_line(writer: &mut impl io::Write, line: &str) -> io::Result<()> {
60 writeln!(writer, "{line}\r")
61 }
62
63 #[cfg_attr(
64 feature = "tracing",
65 tracing::instrument(level = "debug", skip(lines_recv), err)
66 )]
67 #[cfg_attr(not(feature = "tracing"), allow(unused_variables))]
68 fn read_echo(
70 input_line: &str,
71 lines_recv: &mpsc::Receiver<Vec<u8>>,
72 io_timeout: Duration,
73 ) -> io::Result<()> {
74 if let Ok(line) = lines_recv.recv_timeout(io_timeout) {
75 #[cfg(feature = "tracing")]
76 tracing::debug!(line_utf8 = std::str::from_utf8(&line).ok(), "received line");
77 Ok(())
78 } else {
79 let err =
80 format!("could not read all input `{input_line}` back from an echoing terminal");
81 Err(io::Error::new(io::ErrorKind::BrokenPipe, err))
82 }
83 }
84
85 #[cfg_attr(
86 feature = "tracing",
87 tracing::instrument(level = "debug", skip_all, ret, err)
88 )]
89 fn read_output(
90 lines_recv: &mpsc::Receiver<Vec<u8>>,
91 mut timeouts: Timeouts,
92 line_decoder: &mut dyn FnMut(Vec<u8>) -> io::Result<String>,
93 ) -> io::Result<String> {
94 let mut output = String::new();
95
96 while let Ok(mut line) = lines_recv.recv_timeout(timeouts.next()) {
97 if line.last() == Some(&b'\r') {
98 line.pop();
100 }
101 #[cfg(feature = "tracing")]
102 tracing::debug!(line_utf8 = std::str::from_utf8(&line).ok(), "received line");
103
104 let mapped_line = line_decoder(line)?;
105 #[cfg(feature = "tracing")]
106 tracing::debug!(?mapped_line, "mapped received line");
107 output.push_str(&mapped_line);
108 output.push('\n');
109 }
110
111 Ok(output)
112 }
113
114 #[cfg_attr(
125 feature = "tracing",
126 tracing::instrument(
127 skip_all,
128 err,
129 fields(
130 options.io_timeout = ?options.io_timeout,
131 options.init_timeout = ?options.init_timeout,
132 options.path_additions = ?options.path_additions,
133 options.init_commands = ?options.init_commands
134 )
135 )
136 )]
137 pub fn from_inputs<Cmd: SpawnShell>(
138 options: &mut ShellOptions<Cmd>,
139 inputs: impl IntoIterator<Item = UserInput>,
140 ) -> io::Result<Self> {
141 let SpawnedShell {
142 mut shell,
143 reader,
144 writer,
145 } = options.spawn_shell()?;
146
147 let stdout = BufReader::new(reader);
148 let (out_lines_send, out_lines_recv) = mpsc::channel();
149
150 #[cfg(feature = "tracing")]
153 let dispatcher = tracing::dispatcher::get_default(Clone::clone);
154 let io_handle = thread::spawn(move || {
155 #[cfg(feature = "tracing")]
156 let _tracing_guard = tracing::dispatcher::set_default(&dispatcher);
157 #[cfg(feature = "tracing")]
158 let _entered = tracing::debug_span!("reader_thread").entered();
159
160 let mut lines = stdout.split(b'\n');
161 loop {
162 match lines.next() {
163 Some(Ok(line)) => {
164 #[cfg(feature = "tracing")]
165 tracing::debug!(
166 line_utf8 = std::str::from_utf8(&line).ok(),
167 "received line"
168 );
169
170 if out_lines_send.send(line).is_err() {
171 #[cfg(feature = "tracing")]
172 tracing::debug!("receiver dropped, breaking reader loop");
173 return;
174 }
175 }
176 #[cfg_attr(not(feature = "tracing"), allow(unused_variables))]
178 Some(Err(err)) => {
179 #[cfg(feature = "tracing")]
180 tracing::warn!(?err, msg = %err, "error reading shell output");
181 return;
182 }
183 None => {
184 #[cfg(feature = "tracing")]
185 tracing::debug!("input sender dropped");
186 return;
187 }
188 }
189 }
190 });
191
192 let mut stdin = LineWriter::new(writer);
193 Self::push_init_commands(options, &out_lines_recv, &mut shell, &mut stdin)?;
194
195 let mut transcript = Self::new();
196 for input in inputs {
197 let interaction =
198 Self::record_interaction(options, input, &out_lines_recv, &mut shell, &mut stdin)?;
199 transcript.add_existing_interaction(interaction);
200 }
201
202 drop(stdin); thread::sleep(options.io_timeout / 4);
206
207 shell.terminate()?;
208 io_handle.join().ok(); Ok(transcript)
210 }
211
212 #[cfg_attr(
213 feature = "tracing",
214 tracing::instrument(
215 level = "debug",
216 skip_all,
217 err,
218 fields(options.init_commands = ?options.init_commands)
219 )
220 )]
221 fn push_init_commands<Cmd: SpawnShell>(
222 options: &ShellOptions<Cmd>,
223 lines_recv: &mpsc::Receiver<Vec<u8>>,
224 shell: &mut Cmd::ShellProcess,
225 stdin: &mut impl io::Write,
226 ) -> io::Result<()> {
227 let mut timeouts = Timeouts::new(options);
229 while lines_recv.recv_timeout(timeouts.next()).is_ok() {
230 }
232
233 for cmd in &options.init_commands {
235 Self::write_line(stdin, cmd)?;
236 if shell.is_echoing() {
237 Self::read_echo(cmd, lines_recv, options.io_timeout)?;
238 }
239
240 let mut timeouts = Timeouts::new(options);
242 while lines_recv.recv_timeout(timeouts.next()).is_ok() {
243 }
245 }
246 Ok(())
247 }
248
249 #[cfg_attr(
250 feature = "tracing",
251 tracing::instrument(level = "debug", skip(options, lines_recv, shell, stdin), ret, err)
252 )]
253 fn record_interaction<Cmd: SpawnShell>(
254 options: &mut ShellOptions<Cmd>,
255 input: UserInput,
256 lines_recv: &mpsc::Receiver<Vec<u8>>,
257 shell: &mut Cmd::ShellProcess,
258 stdin: &mut impl io::Write,
259 ) -> io::Result<Interaction> {
260 shell.check_is_alive()?;
263
264 let input_lines = input.as_ref().split('\n');
265 for input_line in input_lines {
266 Self::write_line(stdin, input_line)?;
267 if shell.is_echoing() {
268 Self::read_echo(input_line, lines_recv, options.io_timeout)?;
269 }
270 }
271
272 let output = Self::read_output(
273 lines_recv,
274 Timeouts::new(options),
275 options.line_decoder.as_mut(),
276 )?;
277 let output = StyledString::from_ansi(&output).map_err(map_ansi_error)?;
278
279 let exit_status = if let Some(status_check) = &options.status_check {
280 let command = status_check.command();
281 Self::write_line(stdin, command)?;
282 if shell.is_echoing() {
283 Self::read_echo(command, lines_recv, options.io_timeout)?;
284 }
285 let response = Self::read_output(
286 lines_recv,
287 Timeouts::new(options),
288 options.line_decoder.as_mut(),
289 )?;
290 let response = StyledString::from_ansi(&response).map_err(map_ansi_error)?;
291
292 status_check.check(response.as_str())
293 } else {
294 None
295 };
296
297 let mut interaction = Interaction::new(input, output);
298 interaction.set_exit_status(exit_status);
299 Ok(interaction)
300 }
301
302 #[cfg_attr(
312 feature = "tracing",
313 tracing::instrument(skip(self, input), err, fields(input.text = input.as_ref()))
314 )]
315 pub fn capture_output(
316 &mut self,
317 input: UserInput,
318 command: &mut Command,
319 ) -> io::Result<&mut Self> {
320 let (mut pipe_reader, pipe_writer) = os_pipe::pipe()?;
321 #[cfg(feature = "tracing")]
322 tracing::debug!("created OS pipe");
323
324 let mut child = command
325 .stdin(Stdio::null())
326 .stdout(pipe_writer.try_clone()?)
327 .stderr(pipe_writer)
328 .spawn()?;
329 #[cfg(feature = "tracing")]
330 tracing::debug!("created child");
331
332 command.stdout(Stdio::null()).stderr(Stdio::null());
334
335 let mut output = vec![];
336 pipe_reader.read_to_end(&mut output)?;
337 child.wait()?;
338
339 let output = str::from_utf8(&output)
340 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
341 #[cfg(feature = "tracing")]
342 tracing::debug!(?output, "read command output");
343
344 let output = StyledString::from_ansi(output).map_err(map_ansi_error)?;
345 let interaction = Interaction::new(input, output);
346 self.add_existing_interaction(interaction);
347 Ok(self)
348 }
349}