use std::{
collections::HashMap,
error::Error as StdError,
ffi::{OsStr, OsString},
io,
path::{Path, PathBuf},
};
use portable_pty::{native_pty_system, Child, CommandBuilder, PtyPair, PtySize};
use crate::{
traits::{ConfigureCommand, ShellProcess, SpawnShell, SpawnedShell},
utils::is_recoverable_kill_error,
};
/// Converts a boxed error into an [`io::Error`].
///
/// If the boxed error already is an `io::Error`, it is unwrapped and returned as-is
/// (preserving its kind); otherwise, it is wrapped into a new error of kind
/// [`io::ErrorKind::Other`].
fn into_io_error(err: Box<dyn StdError + Send + Sync>) -> io::Error {
    match err.downcast::<io::Error>() {
        Ok(io_err) => *io_err,
        Err(other_err) => io::Error::new(io::ErrorKind::Other, other_err),
    }
}
/// Command that spawns a shell process inside a pseudo-terminal (PTY) created
/// via `portable-pty`.
#[cfg_attr(docsrs, doc(cfg(feature = "portable-pty")))]
#[derive(Debug, Clone)]
pub struct PtyCommand {
/// Executable followed by its arguments; passed to `CommandBuilder::from_argv` on spawn.
args: Vec<OsString>,
/// Environment variables set on the spawned process.
env: HashMap<OsString, OsString>,
/// Working directory for the spawned process; not set on the builder if `None`.
current_dir: Option<PathBuf>,
/// PTY dimensions used when opening the PTY pair (19x80 by default, see `new()`).
pty_size: PtySize,
}
#[cfg(unix)]
impl Default for PtyCommand {
fn default() -> Self {
Self::new("sh")
}
}
#[cfg(windows)]
impl Default for PtyCommand {
    /// Creates a command running `cmd` with the PTY size chosen by [`PtyCommand::new()`].
    ///
    /// `cmd` is started with echo turned off (`/Q`) and kept running (`/K`) after executing
    /// a startup command that switches the console code page to UTF-8 (`chcp 65001`).
    fn default() -> Self {
        let mut command = PtyCommand::new("cmd");
        for arg in ["/Q", "/K", "echo off && chcp 65001"] {
            command.arg(arg);
        }
        command
    }
}
impl PtyCommand {
    /// Creates a command that runs the specified executable inside a PTY.
    ///
    /// The PTY is initially 19 rows by 80 columns (with zero pixel dimensions);
    /// use [`Self::with_size()`] to change it.
    pub fn new(command: impl Into<OsString>) -> Self {
        let pty_size = PtySize {
            rows: 19,
            cols: 80,
            pixel_width: 0,
            pixel_height: 0,
        };
        Self {
            args: vec![command.into()],
            env: HashMap::new(),
            current_dir: None,
            pty_size,
        }
    }

    /// Sets the number of rows and columns of the PTY. Pixel dimensions are left unchanged.
    pub fn with_size(&mut self, rows: u16, cols: u16) -> &mut Self {
        self.pty_size = PtySize {
            rows,
            cols,
            ..self.pty_size
        };
        self
    }

    /// Appends a command-line argument for the executable.
    pub fn arg(&mut self, arg: impl Into<OsString>) -> &mut Self {
        let arg = arg.into();
        self.args.push(arg);
        self
    }

    /// Builds a `portable-pty` command from the configured argv, environment variables,
    /// and (if set) working directory.
    fn to_command_builder(&self) -> CommandBuilder {
        let mut builder = CommandBuilder::from_argv(self.args.clone());
        self.env.iter().for_each(|(name, value)| {
            builder.env(name, value);
        });
        if let Some(dir) = self.current_dir.as_deref() {
            builder.cwd(dir);
        }
        builder
    }
}
impl ConfigureCommand for PtyCommand {
fn current_dir(&mut self, dir: &Path) {
self.current_dir = Some(dir.to_owned());
}
fn env(&mut self, name: &str, value: &OsStr) {
self.env
.insert(OsStr::new(name).to_owned(), value.to_owned());
}
}
impl SpawnShell for PtyCommand {
type ShellProcess = PtyShell;
type Reader = Box<dyn io::Read + Send>;
type Writer = Box<dyn io::Write + Send>;
#[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", err))]
fn spawn_shell(&mut self) -> io::Result<SpawnedShell<Self>> {
let pty_system = native_pty_system();
let PtyPair { master, slave } = pty_system
.openpty(self.pty_size)
.map_err(|err| into_io_error(err.into()))?;
#[cfg(feature = "tracing")]
tracing::debug!("created PTY pair");
let child = slave
.spawn_command(self.to_command_builder())
.map_err(|err| into_io_error(err.into()))?;
#[cfg(feature = "tracing")]
tracing::debug!("spawned command into PTY");
let reader = master
.try_clone_reader()
.map_err(|err| into_io_error(err.into()))?;
let writer = master
.take_writer()
.map_err(|err| into_io_error(err.into()))?;
Ok(SpawnedShell {
shell: PtyShell { child },
reader,
writer,
})
}
}
/// Shell process running inside a PTY, as returned by spawning a [`PtyCommand`].
#[cfg_attr(docsrs, doc(cfg(feature = "portable-pty")))]
#[derive(Debug)]
pub struct PtyShell {
/// Handle to the child process spawned into the PTY.
child: Box<dyn Child + Send + Sync>,
}
impl ShellProcess for PtyShell {
    /// Checks that the shell process has not exited.
    ///
    /// # Errors
    ///
    /// Returns an [`io::ErrorKind::BrokenPipe`] error if the process has already exited
    /// (the message states whether its exit status was zero or non-zero), or propagates
    /// an I/O error from polling the process status.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", err))]
    fn check_is_alive(&mut self) -> io::Result<()> {
        if let Some(exit_status) = self.child.try_wait()? {
            let status_str = if exit_status.success() {
                "zero"
            } else {
                "non-zero"
            };
            let message =
                format!("Shell process has prematurely exited with {status_str} exit status");
            Err(io::Error::new(io::ErrorKind::BrokenPipe, message))
        } else {
            Ok(())
        }
    }

    /// Terminates the shell process if it is still running.
    ///
    /// After a successful kill, the process is waited on so that it does not linger as
    /// a zombie. Kill errors classified as recoverable by `is_recoverable_kill_error`
    /// are not reported.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from polling the process status, and non-recoverable
    /// errors from killing the process.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", err))]
    fn terminate(mut self) -> io::Result<()> {
        if self.child.try_wait()?.is_some() {
            // The process has already exited, so there is nothing to kill.
            return Ok(());
        }
        match self.child.kill() {
            Ok(()) => {
                // On Unix, a killed child that is never waited on remains a zombie until
                // this process exits. Reaping is best-effort: the shell is gone either way,
                // so a failure here (e.g., the status was already collected elsewhere)
                // is deliberately not turned into a `terminate` error.
                let _ = self.child.wait();
            }
            Err(err) if is_recoverable_kill_error(&err) => {}
            Err(err) => return Err(err),
        }
        Ok(())
    }

    /// Reports that input written to this shell is echoed back in its output
    /// (as is typical for a terminal).
    fn is_echoing(&self) -> bool {
        true
    }
}
#[cfg(test)]
mod tests {
use std::{
io::{Read, Write},
thread,
time::Duration,
};
use super::*;
use crate::{ShellOptions, Transcript, UserInput};
// Exercises `SpawnShell` / `ShellProcess` directly: spawns the default shell, sends a
// command, closes the input, terminates the shell and checks the captured PTY output.
#[test]
fn pty_trait_implementation() -> anyhow::Result<()> {
let mut pty_command = PtyCommand::default();
let mut spawned = pty_command.spawn_shell()?;
// Give the shell some time to start before checking that it is alive.
thread::sleep(Duration::from_millis(100));
spawned.shell.check_is_alive()?;
writeln!(spawned.writer, "echo Hello")?;
thread::sleep(Duration::from_millis(100));
spawned.shell.check_is_alive()?;
// Closing the writer is expected to end the shell's input (presumably making it exit);
// the sleep gives it time to react before `terminate()` is called.
drop(spawned.writer); thread::sleep(Duration::from_millis(100));
spawned.shell.terminate()?;
// `terminate()` consumes only the shell handle; the reader remains usable.
let mut buffer = String::new();
spawned.reader.read_to_string(&mut buffer)?;
assert!(buffer.contains("Hello"), "Unexpected buffer: {buffer:?}");
Ok(())
}
// Records a transcript through a PTY and checks that output written to stderr
// (`echo bar >&2`) is captured together with stdout.
#[test]
fn creating_transcript_with_pty() -> anyhow::Result<()> {
let mut options = ShellOptions::new(PtyCommand::default());
let inputs = vec![
UserInput::command("echo hello"),
UserInput::command("echo foo && echo bar >&2"),
];
let transcript = Transcript::from_inputs(&mut options, inputs)?;
assert_eq!(transcript.interactions().len(), 2);
{
let interaction = &transcript.interactions()[0];
assert_eq!(interaction.input().text, "echo hello");
let output = interaction.output().as_ref();
assert_eq!(output.trim(), "hello");
}
let interaction = &transcript.interactions()[1];
assert_eq!(interaction.input().text, "echo foo && echo bar >&2");
let output = interaction.output().as_ref();
// Whitespace-insensitive comparison, so line breaks between the two outputs don't matter.
assert_eq!(
output.split_whitespace().collect::<Vec<_>>(),
["foo", "bar"]
);
Ok(())
}
// Unix-only: a single command continued onto a second line with a trailing backslash.
#[cfg(unix)]
#[test]
fn pty_transcript_with_multiline_input() -> anyhow::Result<()> {
let mut options = ShellOptions::new(PtyCommand::default());
let inputs = vec![UserInput::command("echo \\\nhello")];
let transcript = Transcript::from_inputs(&mut options, inputs)?;
assert_eq!(transcript.interactions().len(), 1);
let interaction = &transcript.interactions()[0];
let output = interaction.output().as_ref();
assert_eq!(output.trim(), "hello");
Ok(())
}
}