diff --git a/src/spawner.rs b/src/spawner.rs index 5e0212b..cbf86f8 100644 --- a/src/spawner.rs +++ b/src/spawner.rs @@ -9,7 +9,7 @@ use std::collections::HashMap; use std::ffi::CString; use std::fs::File; use std::io::Read; -use std::os::unix::io::AsRawFd; +use std::os::unix::io::IntoRawFd; use std::path::PathBuf; use nix::unistd; @@ -23,6 +23,23 @@ pub struct Spawner<'a> { pub trailing: &'a Vec<&'a str>, } +enum TriggerData<'a> { + /// No data, for example a Startup trigger + None, + + /// A string sent across a pipe + Pipe(&'a str), +} + +impl<'a> TriggerData<'a> { + fn args(&mut self) -> Vec { + match self { + TriggerData::None => vec![], + TriggerData::Pipe(s) => vec![CString::new(s.to_string()).unwrap()], + } + } +} + impl<'a> Spawner<'a> { pub fn spawn(&mut self) -> Result<()> { for (name, entrypoint) in &self.spec.entrypoints { @@ -36,7 +53,9 @@ impl<'a> Spawner<'a> { builder.mount(binary, "/entrypoint"); let closure = || { - let args = self.prepare_args(name, &entrypoint.args, None); + let args = self + .prepare_args(name, &entrypoint.args, &mut TriggerData::None) + .unwrap(); if let Err(e) = unistd::execv(&CString::new("/entrypoint").unwrap(), &args) .map_err(|e| Error::Nix { @@ -92,7 +111,9 @@ impl<'a> Spawner<'a> { let closure = || { let pipe_trigger = std::str::from_utf8(&buf[0..read_bytes]).unwrap(); - let args = self.prepare_args_ref(name, &spec.args, Some(pipe_trigger)); + let args = self + .prepare_args_ref(name, &spec.args, &mut TriggerData::Pipe(pipe_trigger)) + .unwrap(); if let Err(e) = unistd::execv(&CString::new("/entrypoint").unwrap(), &args) .map_err(|e| Error::Nix { @@ -116,67 +137,67 @@ impl<'a> Spawner<'a> { &mut self, entrypoint: &str, args: &[Arg], - pipe_trigger: Option<&str>, - ) -> Vec { + trigger: &mut TriggerData, + ) -> Result> { let mut out = Vec::new(); for arg in args { - match arg { - Arg::BinaryName => out.push(CString::new(self.binary).unwrap()), - Arg::Entrypoint => out.push(CString::new(entrypoint).unwrap()), - - Arg::Pipe(p) => out.push(match p { - Pipe::Rx(s) => { - let pipe = self.pipes.get_mut(s).unwrap().take_read().unwrap(); - CString::new(pipe.as_raw_fd().to_string()).unwrap() - } - Pipe::Tx(s) => { - let pipe = self.pipes.get_mut(s).unwrap().take_write().unwrap(); - CString::new(pipe.as_raw_fd().to_string()).unwrap() - } - }), - - Arg::PipeTrigger => { - out.push(CString::new(pipe_trigger.as_ref().unwrap().to_string()).unwrap()) - } - - Arg::TcpListener { port: _port } => unimplemented!(), - - Arg::Trailing => { - out.extend(self.trailing.iter().map(|s| CString::new(*s).unwrap())) - } - } + out.extend(self.prepare_arg(entrypoint, arg, trigger)?); } - - out + Ok(out) } - fn prepare_args_ref( &self, entrypoint: &str, args: &[Arg], - pipe_trigger: Option<&str>, - ) -> Vec { + trigger: &mut TriggerData, + ) -> Result> { let mut out = Vec::new(); - for arg in args { - match arg { - Arg::BinaryName => out.push(CString::new(self.binary).unwrap()), - Arg::Entrypoint => out.push(CString::new(entrypoint).unwrap()), - - Arg::Pipe(_) => panic!("can't use pipes with an immutable reference"), - - Arg::PipeTrigger => { - out.push(CString::new(pipe_trigger.as_ref().unwrap().to_string()).unwrap()) - } - - Arg::TcpListener { port: _port } => unimplemented!(), - - Arg::Trailing => { - out.extend(self.trailing.iter().map(|s| CString::new(*s).unwrap())) - } - } + out.extend(self.prepare_arg_ref(entrypoint, arg, trigger)?); } + Ok(out) + } - out + fn prepare_arg( + &mut self, + entrypoint: &str, + arg: &Arg, + trigger: &mut TriggerData, + ) -> Result> { + match arg { + Arg::Pipe(p) => match p { + Pipe::Rx(s) => { + let pipe = self.pipes.get_mut(s).unwrap().take_read()?; + Ok(vec![CString::new(pipe.into_raw_fd().to_string()).unwrap()]) + } + Pipe::Tx(s) => { + let pipe = self.pipes.get_mut(s).unwrap().take_write()?; + Ok(vec![CString::new(pipe.into_raw_fd().to_string()).unwrap()]) + } + }, + a => self.prepare_arg_ref(entrypoint, a, trigger), + } + } + + fn prepare_arg_ref( + &self, + entrypoint: &str, + arg: &Arg, + trigger: &mut TriggerData, + ) -> Result> { + match arg { + Arg::BinaryName => Ok(vec![CString::new(self.binary).unwrap()]), + Arg::Entrypoint => Ok(vec![CString::new(entrypoint).unwrap()]), + + Arg::Pipe(p) => Err(Error::BadPipe(p.get_name().to_string())), + + Arg::Trigger => Ok(trigger.args()), + Arg::TcpListener { port: _port } => unimplemented!(), + Arg::Trailing => Ok(self + .trailing + .iter() + .map(|s| CString::new(*s).unwrap()) + .collect()), + } } } diff --git a/src/specification.rs b/src/specification.rs index ae9f95f..29c0b55 100644 --- a/src/specification.rs +++ b/src/specification.rs @@ -46,9 +46,9 @@ pub enum Arg { /// A chosen end of a named pipe Pipe(Pipe), - /// The value of a pipe trigger + /// A value specified by the trigger /// NOTE: Only valid if the trigger is of type Pipe(...) - PipeTrigger, + Trigger, /// A TCP Listener TcpListener { port: u16 }, @@ -69,6 +69,15 @@ pub enum Pipe { Tx(String), } +impl Pipe { + pub fn get_name(&self) -> &str { + match self { + Pipe::Rx(n) => n, + Pipe::Tx(n) => n, + } + } +} + #[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Debug)] pub enum Permission { Filesystem { @@ -145,7 +154,7 @@ impl Specification { // validate pipe trigger arguments make sense for entrypoint in self.entrypoints.values() { - if entrypoint.args.contains(&Arg::PipeTrigger) { + if entrypoint.args.contains(&Arg::Trigger) { match entrypoint.trigger { Trigger::Pipe(_) => {} _ => return Err(Error::BadTriggerArgument),