1use std::{
6 collections::HashMap,
7 error::Error as StdError,
8 ffi::{OsStr, OsString},
9 fmt, io,
10 path::{Path, PathBuf},
11};
12
13use portable_pty::{native_pty_system, Child, CommandBuilder, MasterPty, PtyPair, PtySize};
14
15use crate::{
16 traits::{ConfigureCommand, ShellProcess, SpawnShell, SpawnedShell},
17 utils::is_recoverable_kill_error,
18};
19
/// Converts a boxed error (as produced by `portable_pty` operations) into an [`io::Error`].
///
/// If the boxed error is an [`io::Error`], it is unboxed and returned as-is, so its
/// [`io::ErrorKind`] is retained. Any other error is wrapped with [`io::Error::other()`].
fn into_io_error(err: Box<dyn StdError + Send + Sync>) -> io::Error {
    match err.downcast::<io::Error>() {
        Ok(io_err) => *io_err,
        Err(other_err) => io::Error::other(other_err),
    }
}
24
25#[cfg_attr(docsrs, doc(cfg(feature = "portable-pty")))]
46#[derive(Debug, Clone)]
47pub struct PtyCommand {
48 args: Vec<OsString>,
49 env: HashMap<OsString, OsString>,
50 current_dir: Option<PathBuf>,
51 pty_size: PtySize,
52}
53
54#[cfg(unix)]
55impl Default for PtyCommand {
56 fn default() -> Self {
57 Self::new("sh")
58 }
59}
60
#[cfg(windows)]
impl Default for PtyCommand {
    /// Creates a command launching `cmd` with echoing disabled (`/Q`, `echo off`)
    /// and the console code page switched to UTF-8 (`chcp 65001`). `/K` keeps the
    /// shell running after the initial command.
    fn default() -> Self {
        let mut command = Self::new("cmd");
        for arg in ["/Q", "/K", "echo off && chcp 65001"] {
            command.arg(arg);
        }
        command
    }
}
69
70impl PtyCommand {
71 pub fn new(command: impl Into<OsString>) -> Self {
75 Self {
76 args: vec![command.into()],
77 env: HashMap::new(),
78 current_dir: None,
79 pty_size: PtySize {
80 rows: 19,
81 cols: 80,
82 pixel_width: 0,
83 pixel_height: 0,
84 },
85 }
86 }
87
88 pub fn with_size(&mut self, rows: u16, cols: u16) -> &mut Self {
90 self.pty_size.rows = rows;
91 self.pty_size.cols = cols;
92 self
93 }
94
95 pub fn arg(&mut self, arg: impl Into<OsString>) -> &mut Self {
97 self.args.push(arg.into());
98 self
99 }
100
101 fn to_command_builder(&self) -> CommandBuilder {
102 let mut builder = CommandBuilder::from_argv(self.args.clone());
103 for (name, value) in &self.env {
104 builder.env(name, value);
105 }
106 if let Some(current_dir) = &self.current_dir {
107 builder.cwd(current_dir);
108 }
109 builder
110 }
111}
112
113impl ConfigureCommand for PtyCommand {
114 fn current_dir(&mut self, dir: &Path) {
115 self.current_dir = Some(dir.to_owned());
116 }
117
118 fn env(&mut self, name: &str, value: &OsStr) {
119 self.env
120 .insert(OsStr::new(name).to_owned(), value.to_owned());
121 }
122}
123
124impl SpawnShell for PtyCommand {
125 type ShellProcess = PtyShell;
126 type Reader = Box<dyn io::Read + Send>;
127 type Writer = Box<dyn io::Write + Send>;
128
129 #[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", err))]
130 fn spawn_shell(&mut self) -> io::Result<SpawnedShell<Self>> {
131 let pty_system = native_pty_system();
132 let PtyPair { master, slave } = pty_system
133 .openpty(self.pty_size)
134 .map_err(|err| into_io_error(err.into()))?;
135 #[cfg(feature = "tracing")]
136 tracing::debug!("created PTY pair");
137
138 let child = slave
139 .spawn_command(self.to_command_builder())
140 .map_err(|err| into_io_error(err.into()))?;
141 #[cfg(feature = "tracing")]
142 tracing::debug!("spawned command into PTY");
143
144 let reader = master
145 .try_clone_reader()
146 .map_err(|err| into_io_error(err.into()))?;
147 let writer = master
148 .take_writer()
149 .map_err(|err| into_io_error(err.into()))?;
150 Ok(SpawnedShell {
151 shell: PtyShell {
152 child,
153 _master: master,
154 },
155 reader,
156 writer,
157 })
158 }
159}
160
161#[cfg_attr(docsrs, doc(cfg(feature = "portable-pty")))]
163pub struct PtyShell {
164 child: Box<dyn Child + Send + Sync>,
165 _master: Box<dyn MasterPty + Send>,
166}
167
168impl fmt::Debug for PtyShell {
169 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
170 formatter
171 .debug_struct("PtyShell")
172 .field("child", &self.child)
173 .finish_non_exhaustive()
174 }
175}
176
177impl ShellProcess for PtyShell {
178 #[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", err))]
179 fn check_is_alive(&mut self) -> io::Result<()> {
180 if let Some(exit_status) = self.child.try_wait()? {
181 #[cfg(feature = "tracing")]
182 tracing::error!(?exit_status, "shell exited prematurely");
183
184 let status_str = if exit_status.success() {
185 "zero"
186 } else {
187 "non-zero"
188 };
189 let message =
190 format!("Shell process has prematurely exited with {status_str} exit status");
191 Err(io::Error::new(io::ErrorKind::BrokenPipe, message))
192 } else {
193 Ok(())
194 }
195 }
196
197 #[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", err))]
198 fn terminate(mut self) -> io::Result<()> {
199 if self.child.try_wait()?.is_none() {
200 self.child.kill().or_else(|err| {
201 if is_recoverable_kill_error(&err) {
202 Ok(())
204 } else {
205 Err(err)
206 }
207 })?;
208 }
209 Ok(())
210 }
211
212 fn is_echoing(&self) -> bool {
213 true
214 }
215}
216
#[cfg(test)]
mod tests {
    use std::{
        io::{Read, Write},
        thread,
        time::Duration,
    };

    use test_casing::{
        decorate,
        decorators::{Retry, RetryErrors},
    };

    use super::*;
    use crate::{ShellOptions, Transcript, UserInput};

    /// Pause between test steps, giving the shell time to process input.
    const SETTLE_DELAY: Duration = Duration::from_millis(100);

    /// Drives the `SpawnShell` / `ShellProcess` implementation directly: spawns the
    /// default shell, feeds it a command and checks that its output is readable.
    #[test]
    fn pty_trait_implementation() -> anyhow::Result<()> {
        let mut pty_command = PtyCommand::default();
        let mut spawned = pty_command.spawn_shell()?;

        thread::sleep(SETTLE_DELAY);
        spawned.shell.check_is_alive()?;

        writeln!(spawned.writer, "echo Hello")?;
        thread::sleep(SETTLE_DELAY);
        spawned.shell.check_is_alive()?;

        // Close the input side before terminating the shell and draining its output.
        drop(spawned.writer);
        thread::sleep(SETTLE_DELAY);

        spawned.shell.terminate()?;
        let mut buffer = String::new();
        spawned.reader.read_to_string(&mut buffer)?;

        assert!(buffer.contains("Hello"), "Unexpected buffer: {buffer:?}");
        Ok(())
    }

    /// Records a transcript with two single-line commands, one of which writes to stderr.
    #[test]
    fn creating_transcript_with_pty() -> anyhow::Result<()> {
        let mut options = ShellOptions::new(PtyCommand::default());
        let inputs = vec![
            UserInput::command("echo hello"),
            UserInput::command("echo foo && echo bar >&2"),
        ];
        let transcript = Transcript::from_inputs(&mut options, inputs)?;

        let interactions = transcript.interactions();
        assert_eq!(interactions.len(), 2);

        let first = &interactions[0];
        assert_eq!(first.input().text, "echo hello");
        let first_output = first.output().as_ref();
        assert_eq!(first_output.trim(), "hello");

        let second = &interactions[1];
        assert_eq!(second.input().text, "echo foo && echo bar >&2");
        let second_output = second.output().as_ref();
        let words: Vec<_> = second_output.split_whitespace().collect();
        assert_eq!(words, ["foo", "bar"]);
        Ok(())
    }

    /// Retry policy (`Retry::times(3)`) applied only to errors whose message contains
    /// "Unexpected PTY output", i.e. the output mismatch reported by the test below.
    const RETRIES: RetryErrors<anyhow::Error> =
        Retry::times(3).on_error(|err| err.to_string().contains("Unexpected PTY output"));

    /// Records a transcript for a single input that spans multiple lines.
    #[decorate(RETRIES)]
    #[test]
    fn pty_transcript_with_multiline_input() -> anyhow::Result<()> {
        let mut options = ShellOptions::new(PtyCommand::default());
        let inputs = vec![UserInput::command("(echo Hello\necho world)")];
        let transcript = Transcript::from_inputs(&mut options, inputs)?;

        let interactions = transcript.interactions();
        assert_eq!(interactions.len(), 1);
        let output = interactions[0].output().as_ref();
        anyhow::ensure!(
            output.trim() == "Hello\nworld",
            "Unexpected PTY output: {output}"
        );
        Ok(())
    }
}