1use std::{
4 convert::Infallible,
5 env, error,
6 ffi::OsStr,
7 fmt, io,
8 path::{Path, PathBuf},
9 process::Command,
10 time::Duration,
11};
12
13use styled_str::StyledStr;
14
15mod standard;
16mod transcript_impl;
17
18pub use self::standard::StdShell;
19use crate::{
20 ExitStatus,
21 traits::{ConfigureCommand, Echoing, SpawnShell, SpawnedShell},
22};
23
24type StatusCheckerFn = dyn Fn(StyledStr<'_>) -> Option<ExitStatus>;
25
/// A command that queries an exit status from the shell, paired with a checker that
/// parses the shell's response to that command.
///
/// Created by `ShellOptions::with_status_check()`.
pub(crate) struct StatusCheck {
    /// Command sent to the shell; `with_status_check()` asserts it contains no '\n' / '\r'.
    command: String,
    /// Converts the shell's response to `command` into an exit status, if recognized.
    response_checker: Box<StatusCheckerFn>,
}
30
31impl fmt::Debug for StatusCheck {
32 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
33 formatter
34 .debug_struct("StatusCheck")
35 .field("command", &self.command)
36 .finish_non_exhaustive()
37 }
38}
39
40impl StatusCheck {
41 pub(crate) fn command(&self) -> &str {
42 &self.command
43 }
44
45 pub(crate) fn check(&self, response: StyledStr<'_>) -> Option<ExitStatus> {
46 (self.response_checker)(response)
47 }
48}
49
/// Options controlling how a shell is spawned and interacted with
/// (e.g., by `Transcript::from_inputs()`).
///
/// `Cmd` is the command used to spawn the shell; it defaults to [`Command`].
/// Options are configured with the builder-style `with_*` methods.
pub struct ShellOptions<Cmd = Command> {
    /// Command spawning the shell (`sh` on Unix / `cmd` on Windows for `Default`).
    command: Cmd,
    /// Directories appended to the `PATH` of the spawned shell.
    path_additions: Vec<PathBuf>,
    /// I/O timeout for interacting with the shell (500 ms by default).
    io_timeout: Duration,
    /// Timeout related to shell initialization (1.5 s by default).
    /// NOTE(review): exact semantics are defined where it is consumed — confirm.
    init_timeout: Duration,
    /// Initialization commands, in the order they were added.
    /// NOTE(review): presumably run before user inputs — confirm in `transcript_impl`.
    init_commands: Vec<String>,
    /// Converts a raw output line (bytes) into a `String`; strict UTF-8 by default.
    line_decoder: Box<dyn FnMut(Vec<u8>) -> io::Result<String>>,
    /// Optional exit status check; `None` by default.
    status_check: Option<StatusCheck>,
}
67
68impl<Cmd: fmt::Debug> fmt::Debug for ShellOptions<Cmd> {
69 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
70 formatter
71 .debug_struct("ShellOptions")
72 .field("command", &self.command)
73 .field("path_additions", &self.path_additions)
74 .field("io_timeout", &self.io_timeout)
75 .field("init_timeout", &self.init_timeout)
76 .field("init_commands", &self.init_commands)
77 .field("status_check", &self.status_check)
78 .finish_non_exhaustive()
79 }
80}
81
82#[cfg(any(unix, windows))]
83impl Default for ShellOptions {
84 fn default() -> Self {
85 Self::new(Self::default_shell())
86 }
87}
88
89impl<Cmd: ConfigureCommand> From<Cmd> for ShellOptions<Cmd> {
90 fn from(command: Cmd) -> Self {
91 Self::new(command)
92 }
93}
94
95impl<Cmd: ConfigureCommand> ShellOptions<Cmd> {
96 #[cfg(unix)]
97 fn default_shell() -> Command {
98 Command::new("sh")
99 }
100
101 #[cfg(windows)]
102 fn default_shell() -> Command {
103 let mut command = Command::new("cmd");
104 command.arg("/Q").arg("/K").arg("echo off && chcp 65001");
106 command
107 }
108
109 pub fn new(command: Cmd) -> Self {
111 Self {
112 command,
113 path_additions: vec![],
114 io_timeout: Duration::from_millis(500),
115 init_timeout: Duration::from_millis(1_500),
116 init_commands: vec![],
117 line_decoder: Box::new(|line| {
118 String::from_utf8(line)
119 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err.utf8_error()))
120 }),
121 status_check: None,
122 }
123 }
124
125 #[must_use]
127 pub fn echoing(self, is_echoing: bool) -> ShellOptions<Echoing<Cmd>> {
128 ShellOptions {
129 command: Echoing::new(self.command, is_echoing),
130 path_additions: self.path_additions,
131 io_timeout: self.io_timeout,
132 init_timeout: self.init_timeout,
133 init_commands: self.init_commands,
134 line_decoder: self.line_decoder,
135 status_check: self.status_check,
136 }
137 }
138
139 #[must_use]
141 pub fn with_current_dir(mut self, current_dir: impl AsRef<Path>) -> Self {
142 self.command.current_dir(current_dir.as_ref());
143 self
144 }
145
146 #[must_use]
155 pub fn with_io_timeout(mut self, io_timeout: Duration) -> Self {
156 self.io_timeout = io_timeout;
157 self
158 }
159
160 #[must_use]
166 pub fn with_init_timeout(mut self, init_timeout: Duration) -> Self {
167 self.init_timeout = init_timeout;
168 self
169 }
170
171 #[must_use]
174 pub fn with_init_command(mut self, command: impl Into<String>) -> Self {
175 self.init_commands.push(command.into());
176 self
177 }
178
179 #[must_use]
181 pub fn with_env(mut self, name: impl AsRef<str>, value: impl AsRef<OsStr>) -> Self {
182 self.command.env(name.as_ref(), value.as_ref());
183 self
184 }
185
186 #[must_use]
192 pub fn with_line_decoder<E, F>(mut self, mut mapper: F) -> Self
193 where
194 E: Into<Box<dyn error::Error + Send + Sync>>,
195 F: FnMut(Vec<u8>) -> Result<String, E> + 'static,
196 {
197 self.line_decoder = Box::new(move |line| {
198 mapper(line).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))
199 });
200 self
201 }
202
203 #[must_use]
206 pub fn with_lossy_utf8_decoder(self) -> Self {
207 self.with_line_decoder::<Infallible, _>(|line| {
208 Ok(String::from_utf8_lossy(&line).into_owned())
209 })
210 }
211
212 #[must_use]
237 pub fn with_status_check<F>(mut self, command: impl Into<String>, checker: F) -> Self
238 where
239 F: Fn(StyledStr<'_>) -> Option<ExitStatus> + 'static,
240 {
241 let command = command.into();
242 assert!(
243 command.bytes().all(|ch| ch != b'\n' && ch != b'\r'),
244 "`command` contains a newline character ('\\n' or '\\r')"
245 );
246
247 self.status_check = Some(StatusCheck {
248 command,
249 response_checker: Box::new(checker),
250 });
251 self
252 }
253
254 fn target_path() -> PathBuf {
258 let mut path = env::current_exe().expect("Cannot obtain path to the executing file");
259 path.pop();
260 if path.ends_with("deps") {
261 path.pop();
262 }
263 path
264 }
265
266 #[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", ret))]
267 fn legacy_cargo_path(binary_name: &str) -> Option<PathBuf> {
268 let target_path = Self::target_path();
269 let binary_path = target_path.join(format!("{binary_name}{}", env::consts::EXE_SUFFIX));
270 let exists = binary_path.try_exists();
271
272 #[cfg(feature = "tracing")]
273 tracing::debug!(?binary_path, ?exists, "checked binary path");
274 exists.ok()?.then_some(binary_path)
275 }
276
277 fn panic_on_missing_cargo_path(binary_name: &str) -> ! {
278 let binaries: Vec<_> = env::vars_os()
279 .filter_map(|(name, _)| {
280 let name = name.into_string().ok()?;
281 Some(name.strip_prefix("CARGO_BIN_EXE_")?.to_owned())
282 })
283 .collect();
284 if binaries.is_empty() {
285 panic!(
286 "`CARGO_BIN_EXE_{binary_name}` env variable is unset, and {binary_name} is not in the default cargo target dir.\n\
287 help: If this is run in a unit test, move it to an integration test to gain access to `CARGO_BIN_EXE_` vars (requires Rust 1.94+)"
288 );
289 } else {
290 panic!(
291 "`{binary_name}` does not look like a valid cargo binary in the workspace.\n\
292 help: Available binaries: {binaries:?}"
293 );
294 }
295 }
296
297 #[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip(self)))]
306 #[must_use]
307 #[allow(clippy::missing_panics_doc)] pub fn with_cargo_path_for(mut self, binary_name: &str) -> Self {
309 let env_var_name = format!("CARGO_BIN_EXE_{binary_name}");
310 let binary_path = env::var_os(&env_var_name).map(PathBuf::from);
311
312 #[cfg(feature = "tracing")]
313 tracing::debug!(?binary_path, "got Rust 1.94+ path to binary");
314
315 let binary_path = binary_path
316 .or_else(|| Self::legacy_cargo_path(binary_name))
317 .unwrap_or_else(|| Self::panic_on_missing_cargo_path(binary_name));
318
319 #[cfg(feature = "tracing")]
320 tracing::debug!(?binary_path, "got path to binary");
321
322 let parent_path = binary_path
323 .parent()
324 .expect("invalid binary path")
325 .to_owned();
326 if !self.path_additions.contains(&parent_path) {
328 self.path_additions.push(parent_path);
329 }
330
331 self
332 }
333
334 #[must_use]
338 pub fn with_additional_path(mut self, path: impl Into<PathBuf>) -> Self {
339 let path = path.into();
340 self.path_additions.push(path);
341 self
342 }
343}
344
345impl<Cmd: SpawnShell> ShellOptions<Cmd> {
346 #[cfg_attr(
347 feature = "tracing",
348 tracing::instrument(
349 level = "debug",
350 skip(self),
351 err,
352 fields(self.path_additions = ?self.path_additions)
353 )
354 )]
355 fn spawn_shell(&mut self) -> io::Result<SpawnedShell<Cmd>> {
356 #[cfg(unix)]
357 const PATH_SEPARATOR: &str = ":";
358 #[cfg(windows)]
359 const PATH_SEPARATOR: &str = ";";
360
361 if !self.path_additions.is_empty() {
362 let mut path_var = env::var_os("PATH").unwrap_or_default();
363 if !path_var.is_empty() {
364 path_var.push(PATH_SEPARATOR);
365 }
366 for (i, addition) in self.path_additions.iter().enumerate() {
367 path_var.push(addition);
368 if i + 1 < self.path_additions.len() {
369 path_var.push(PATH_SEPARATOR);
370 }
371 }
372 self.command.env("PATH", &path_var);
373 }
374 self.command.spawn_shell()
375 }
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Transcript, UserInput};

    /// Runs two commands through the default shell (the second one also writes to
    /// stderr) and checks that inputs and outputs are recorded in the transcript.
    #[cfg(any(unix, windows))]
    #[test]
    fn creating_transcript_basics() -> anyhow::Result<()> {
        let mut options = ShellOptions::default();
        let transcript = Transcript::from_inputs(
            &mut options,
            vec![
                UserInput::command("echo hello"),
                UserInput::command("echo foo && echo bar >&2"),
            ],
        )?;

        let interactions = transcript.interactions();
        assert_eq!(interactions.len(), 2);

        let first = &interactions[0];
        assert_eq!(first.input().as_ref(), "echo hello");
        assert_eq!(first.output().text().trim(), "hello");

        let second = &interactions[1];
        assert_eq!(second.input().as_ref(), "echo foo && echo bar >&2");
        // Both stdout and stderr lines are expected in the captured output.
        assert_eq!(
            second.output().text().split_whitespace().collect::<Vec<_>>(),
            ["foo", "bar"]
        );
        Ok(())
    }

    /// A backslash-newline continuation in the input should be joined by the shell
    /// into a single command.
    #[cfg(unix)]
    #[test]
    fn transcript_with_multiline_input() -> anyhow::Result<()> {
        let inputs = vec![UserInput::command("echo \\\nhello")];
        let transcript = Transcript::from_inputs(&mut ShellOptions::default(), inputs)?;
        assert_eq!(transcript.interactions().len(), 1);

        let only = &transcript.interactions()[0];
        let captured = only.output();
        assert_eq!(captured.text().trim(), "hello");
        Ok(())
    }
}