diff --git a/examples/tls/http.rs b/examples/tls/http.rs index 0688f37..82c6e52 100644 --- a/examples/tls/http.rs +++ b/examples/tls/http.rs @@ -1,9 +1,9 @@ use std::fs; use std::io::{Read, Write}; -use std::net::TcpStream; +use std::os::unix::net::UnixStream; use std::path::PathBuf; -pub(super) fn handler(mut stream: TcpStream) -> i32 { +pub(super) fn handler(mut stream: UnixStream) -> i32 { println!("entered http handler"); let mut buf = Vec::new(); diff --git a/examples/tls/main.rs b/examples/tls/main.rs index dda71b4..94b75f0 100644 --- a/examples/tls/main.rs +++ b/examples/tls/main.rs @@ -3,6 +3,7 @@ mod tls; use std::fs::File; use std::net::{TcpListener, TcpStream}; +use std::os::unix::net::UnixStream; fn main() { match std::env::args().next() { @@ -129,7 +130,7 @@ fn tls_handler_entrypoint() { )); } -fn http_handler(trigger_socket: &File, stream: TcpStream) { +fn http_handler(trigger_socket: &File, stream: UnixStream) { // imports use nix::sys::socket::{sendmsg, ControlMessage, MsgFlags}; use std::os::unix::io::AsRawFd; @@ -162,7 +163,7 @@ fn http_handler_entrypoint() { .expect("request stream required") .parse() .expect("request stream should be a file descriptor"); - let stream = unsafe { TcpStream::from_raw_fd(stream) }; + let stream = unsafe { UnixStream::from_raw_fd(stream) }; std::process::exit(http::handler(stream)); } diff --git a/examples/tls/tls.rs b/examples/tls/tls.rs index 74bf5aa..62dfd2c 100644 --- a/examples/tls/tls.rs +++ b/examples/tls/tls.rs @@ -1,12 +1,91 @@ use std::fs::File; +use std::io::{self, ErrorKind, Read, Write}; use std::net::TcpStream; +use std::os::unix::io::AsRawFd; +use std::os::unix::net::UnixStream; + +use nix::poll::{poll, PollFd, PollFlags}; + +const BUFFER_SIZE: usize = 4096; pub(crate) fn handler( http_trigger_socket: File, _cert: File, _key: File, - stream: TcpStream, + mut stream: TcpStream, ) -> i32 { - super::http_handler(&http_trigger_socket, stream); + let (mut socket, far_socket) = UnixStream::pair().unwrap(); + + super::http_handler(&http_trigger_socket, far_socket); + + stream.set_nonblocking(true).unwrap(); + socket.set_nonblocking(true).unwrap(); + + let mut to_poll = [ + PollFd::new(stream.as_raw_fd(), PollFlags::POLLIN), + PollFd::new(socket.as_raw_fd(), PollFlags::POLLIN), + ]; + + loop { + println!("starting polling"); + poll(&mut to_poll, -1).unwrap(); + + if let Some(events) = to_poll[0].revents() { + if events.contains(PollFlags::POLLIN) { + handle_encrypted_data(&mut stream, &mut socket).unwrap(); + } + } + + if let Some(events) = to_poll[1].revents() { + if events.contains(PollFlags::POLLIN) { + handle_new_data(&mut socket, &mut stream).unwrap(); + } + + if events.contains(PollFlags::POLLHUP) { + println!("response writer hung up, exiting"); + break; + } + } + } + exitcode::OK } + +fn handle_encrypted_data(stream: &mut impl Read, socket: &mut impl Write) -> io::Result<()> { + let mut buf = [0_u8; BUFFER_SIZE]; + + loop { + let read = non_blocking_read(stream, &mut buf)?; + if read == 0 { + return Ok(()); + } + + socket.write_all(&buf[0..read]).unwrap(); + } +} + +fn handle_new_data(socket: &mut impl Read, stream: &mut impl Write) -> io::Result<()> { + let mut buf = [0_u8; BUFFER_SIZE]; + + loop { + let read = non_blocking_read(socket, &mut buf)?; + if read == 0 { + return Ok(()); + } + + stream.write_all(&buf[0..read]).unwrap(); + } +} + +fn non_blocking_read(reader: &mut impl io::Read, buf: &mut [u8]) -> io::Result { + match reader.read(buf) { + Err(e) => { + if e.kind() == ErrorKind::WouldBlock { + Ok(0) + } else { + Err(e) + } + } + Ok(n) => Ok(n), + } +}