forward tls handler

This commit is contained in:
Jake Hillion 2022-05-23 18:26:55 +01:00
parent a2cbe9dc54
commit 20c2f8d5f7
3 changed files with 86 additions and 6 deletions

View File

@ -1,9 +1,9 @@
use std::fs;
use std::io::{Read, Write};
use std::net::TcpStream;
use std::os::unix::net::UnixStream;
use std::path::PathBuf;
pub(super) fn handler(mut stream: TcpStream) -> i32 {
pub(super) fn handler(mut stream: UnixStream) -> i32 {
println!("entered http handler");
let mut buf = Vec::new();

View File

@ -3,6 +3,7 @@ mod tls;
use std::fs::File;
use std::net::{TcpListener, TcpStream};
use std::os::unix::net::UnixStream;
fn main() {
match std::env::args().next() {
@ -129,7 +130,7 @@ fn tls_handler_entrypoint() {
));
}
fn http_handler(trigger_socket: &File, stream: TcpStream) {
fn http_handler(trigger_socket: &File, stream: UnixStream) {
// imports
use nix::sys::socket::{sendmsg, ControlMessage, MsgFlags};
use std::os::unix::io::AsRawFd;
@ -162,7 +163,7 @@ fn http_handler_entrypoint() {
.expect("request stream required")
.parse()
.expect("request stream should be a file descriptor");
let stream = unsafe { TcpStream::from_raw_fd(stream) };
let stream = unsafe { UnixStream::from_raw_fd(stream) };
std::process::exit(http::handler(stream));
}

View File

@ -1,12 +1,91 @@
use std::fs::File;
use std::io::{self, ErrorKind, Read, Write};
use std::net::TcpStream;
use std::os::unix::io::AsRawFd;
use std::os::unix::net::UnixStream;
use nix::poll::{poll, PollFd, PollFlags};
const BUFFER_SIZE: usize = 4096;
pub(crate) fn handler(
http_trigger_socket: File,
_cert: File,
_key: File,
stream: TcpStream,
mut stream: TcpStream,
) -> i32 {
super::http_handler(&http_trigger_socket, stream);
let (mut socket, far_socket) = UnixStream::pair().unwrap();
super::http_handler(&http_trigger_socket, far_socket);
stream.set_nonblocking(true).unwrap();
socket.set_nonblocking(true).unwrap();
let mut to_poll = [
PollFd::new(stream.as_raw_fd(), PollFlags::POLLIN),
PollFd::new(socket.as_raw_fd(), PollFlags::POLLIN),
];
loop {
println!("starting polling");
poll(&mut to_poll, -1).unwrap();
if let Some(events) = to_poll[0].revents() {
if events.contains(PollFlags::POLLIN) {
handle_encrypted_data(&mut stream, &mut socket).unwrap();
}
}
if let Some(events) = to_poll[1].revents() {
if events.contains(PollFlags::POLLIN) {
handle_new_data(&mut socket, &mut stream).unwrap();
}
if events.contains(PollFlags::POLLHUP) {
println!("response writer hung up, exiting");
break;
}
}
}
exitcode::OK
}
fn handle_encrypted_data(stream: &mut impl Read, socket: &mut impl Write) -> io::Result<()> {
let mut buf = [0_u8; BUFFER_SIZE];
loop {
let read = non_blocking_read(stream, &mut buf)?;
if read == 0 {
return Ok(());
}
socket.write_all(&buf[0..read]).unwrap();
}
}
fn handle_new_data(socket: &mut impl Read, stream: &mut impl Write) -> io::Result<()> {
let mut buf = [0_u8; BUFFER_SIZE];
loop {
let read = non_blocking_read(socket, &mut buf)?;
if read == 0 {
return Ok(());
}
stream.write_all(&buf[0..read]).unwrap();
}
}
fn non_blocking_read(reader: &mut impl io::Read, buf: &mut [u8]) -> io::Result<usize> {
match reader.read(buf) {
Err(e) => {
if e.kind() == ErrorKind::WouldBlock {
Ok(0)
} else {
Err(e)
}
}
Ok(n) => Ok(n),
}
}