From 01f8f096c6bdb0eb5442cadce43d0faa32610937 Mon Sep 17 00:00:00 2001 From: Jake Hillion Date: Tue, 10 May 2022 23:34:27 +0100 Subject: [PATCH] allowing multiple processes to share socket --- src/error.rs | 2 +- src/lib.rs | 24 ++++++++++++++---------- src/spawner/args.rs | 28 ++++++++++++++++++---------- src/spawner/mod.rs | 4 ++-- src/specification.rs | 34 +++++++++++++++++++++++++--------- 5 files changed, 60 insertions(+), 32 deletions(-) diff --git a/src/error.rs b/src/error.rs index 700ecd5..f8db402 100644 --- a/src/error.rs +++ b/src/error.rs @@ -18,7 +18,7 @@ pub enum Error { #[error("bad pipe specification: a pipe must have exactly one reader and one writer: {0}")] BadPipe(String), - #[error("bad socket specification: a socket must have exactly one reader and one writer: {0}")] + #[error("bad socket specification: a socket must have exactly one reader and one or more writers: {0}")] BadFileSocket(String), #[error("bad specification type: only .json files are supported")] diff --git a/src/lib.rs b/src/lib.rs index 049b24c..5162885 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,7 +12,7 @@ use specification::Specification; use std::collections::HashMap; use std::fs::File; -use std::os::unix::io::FromRawFd; +use std::os::unix::io::{AsRawFd, FromRawFd}; use std::path::Path; use nix::fcntl::OFlag; @@ -135,10 +135,11 @@ impl PipePair { src: e, })?; - // safe to create files given the successful return of pipe(2) Ok(PipePair { name: name.to_string(), + // SAFETY: valid new fd as pipe2(2) returned successfully read: Some(unsafe { File::from_raw_fd(read) }), + // SAFETY: valid new fd as pipe2(2) returned successfully write: Some(unsafe { File::from_raw_fd(write) }), }) } @@ -160,7 +161,7 @@ pub struct SocketPair { name: String, read: Option, - write: Option, + write: File, } impl SocketPair { @@ -176,23 +177,26 @@ impl SocketPair { src: e, })?; - // safe to create files given the successful return of socketpair(2) Ok(SocketPair { name: name.to_string(), + // SAFETY: valid new fd as socketpair(2) returned successfully read: Some(unsafe { File::from_raw_fd(read) }), - write: Some(unsafe { File::from_raw_fd(write) }), + // SAFETY: valid new fd as socketpair(2) returned successfully + write: unsafe { File::from_raw_fd(write) }, }) } fn take_read(&mut self) -> Result { self.read .take() - .ok_or_else(|| Error::BadPipe(self.name.to_string())) + .ok_or_else(|| Error::BadFileSocket(self.name.to_string())) } - fn take_write(&mut self) -> Result { - self.write - .take() - .ok_or_else(|| Error::BadPipe(self.name.to_string())) + fn write(&self) -> Result { + let dup_fd = nix::unistd::dup(self.write.as_raw_fd()) + .map_err(|e| Error::Nix { msg: "dup", src: e })?; + + // SAFETY: valid new fd as dup(2) returned successfully + Ok(unsafe { File::from_raw_fd(dup_fd) }) } } diff --git a/src/spawner/args.rs b/src/spawner/args.rs index e3aba55..be8b918 100644 --- a/src/spawner/args.rs +++ b/src/spawner/args.rs @@ -37,11 +37,15 @@ impl PreparedArgs { * for things like network sockets. update the builder * with newly passed fds. */ - pub fn prepare_ambient(builder: &mut VoidBuilder, args: &[Arg]) -> Result { + pub fn prepare_ambient( + spawner: &Spawner, + builder: &mut VoidBuilder, + args: &[Arg], + ) -> Result { let mut v = Vec::with_capacity(args.len()); for arg in args { - v.push(PreparedArg::prepare_ambient(builder, arg)?); + v.push(PreparedArg::prepare_ambient(spawner, builder, arg)?); } Ok(PreparedArgs(v)) @@ -113,24 +117,28 @@ impl PreparedArg { PreparedArg::Pipe(pipe) } - Arg::FileSocket(s) => { - let socket = match s { - FileSocket::Rx(s) => spawner.sockets.get_mut(s).unwrap().take_read(), - FileSocket::Tx(s) => spawner.sockets.get_mut(s).unwrap().take_write(), - }?; + Arg::FileSocket(FileSocket::Rx(s)) => { + let socket = spawner.sockets.get_mut(s).unwrap().take_read()?; builder.keep_fd(&socket); PreparedArg::FileSocket(socket) } - arg => Self::prepare_ambient(builder, arg)?, + arg => Self::prepare_ambient(spawner, builder, arg)?, }) } - fn prepare_ambient(builder: &mut VoidBuilder, arg: &Arg) -> Result { + fn prepare_ambient(spawner: &Spawner, builder: &mut VoidBuilder, arg: &Arg) -> Result { Ok(match arg { Arg::Pipe(p) => return Err(Error::BadPipe(p.get_name().to_string())), - Arg::FileSocket(s) => return Err(Error::BadFileSocket(s.get_name().to_string())), + Arg::FileSocket(FileSocket::Rx(s)) => return Err(Error::BadFileSocket(s.to_string())), + + Arg::FileSocket(FileSocket::Tx(s)) => { + let socket = spawner.sockets.get(s).unwrap().write()?; + + builder.keep_fd(&socket); + PreparedArg::FileSocket(socket) + } Arg::File(path) => { let fd = File::open(path)?; diff --git a/src/spawner/mod.rs b/src/spawner/mod.rs index cdc0f96..3c81942 100644 --- a/src/spawner/mod.rs +++ b/src/spawner/mod.rs @@ -169,7 +169,7 @@ impl<'a> Spawner<'a> { self.prepare_env(&mut builder, &spec.environment); - let args = PreparedArgs::prepare_ambient(&mut builder, &spec.args)?; + let args = PreparedArgs::prepare_ambient(self, &mut builder, &spec.args)?; let closure = || { @@ -234,7 +234,7 @@ impl<'a> Spawner<'a> { self.prepare_env(&mut builder, &spec.environment); - let args = PreparedArgs::prepare_ambient(&mut builder, &spec.args)?; + let args = PreparedArgs::prepare_ambient(self, &mut builder, &spec.args)?; let closure = || { if self.debug { diff --git a/src/specification.rs b/src/specification.rs index 18d9de7..5bd936b 100644 --- a/src/specification.rs +++ b/src/specification.rs @@ -99,15 +99,6 @@ pub enum FileSocket { Tx(String), } -impl FileSocket { - pub fn get_name(&self) -> &str { - match self { - FileSocket::Rx(n) => n, - FileSocket::Tx(n) => n, - } - } -} - #[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Debug)] pub enum Environment { Filesystem { @@ -206,6 +197,31 @@ impl Specification { return Err(Error::BadPipe(pipe.to_string())); } + // validate sockets match + let (read, write) = self.sockets(); + let mut read_set = HashSet::with_capacity(read.len()); + + for socket in read { + if !read_set.insert(socket) { + return Err(Error::BadFileSocket(socket.to_string())); + } + } + + let mut write_set = HashSet::with_capacity(write.len()); + for socket in write { + write_set.insert(socket); + } + + for socket in &read_set { + if !write_set.contains(socket) { + return Err(Error::BadFileSocket(socket.to_string())); + } + } + + if let Some(socket) = (&write_set - &read_set).into_iter().next() { + return Err(Error::BadFileSocket(socket.to_string())); + } + // validate trigger arguments make sense for entrypoint in self.entrypoints.values() { if entrypoint.args.contains(&Arg::Trigger) { -- 2.47.0