fix: PSK broker integration did not work

This commit resolves multiple issues with the PSK broker integration.

- The manual testing procedure never actually utilized the brokers
  due to the use of the outfile option, this led to issues with the
  broker being hidden.
- The manual testing procedure omitted checking whether a PSK was
  actually sent to WireGuard entirely. This was fixed by writing an
  entirely new manual integration testing shell-script that can serve
  as a blueprint for future integration tests.
- Many parts of the PSK broker code did not report (log) errors
  accurately; added error logging
- BrokerServer set message.payload.return_code to the msg_type value,
  this led to crashes
- The PSK broker commands all omitted to set the memfd policy, this led
  to immediate crashes once secrets where actually allocated
- The MioBrokerClient IO state machine was broken and the design was
  too obtuse to debug. The state machine returned the length prefix as
  a message instead of actually interpreting it as a state machine.
  Seems the code was integrated but never actually tested. This was
  fixed by rewriting the entire state machine code using the new
  LengthPrefixEncoder/Decoder facilities. A write-buffer that was not
  being flushed is now handled by flushing the buffer in blocking-io
  mode.
This commit is contained in:
Karolin Varner
2024-08-15 22:14:48 +02:00
parent fd0f35b279
commit 258efe408c
15 changed files with 398 additions and 210 deletions

View File

@@ -26,6 +26,8 @@ env_logger = { workspace = true }
log = { workspace = true }
derive_builder = {workspace = true}
postcard = {workspace = true}
rustix = { worspace = true, optional = true }
libc = { worspace = true, optional = true }
# Mio broker client
mio = { workspace = true }
@@ -36,7 +38,8 @@ rand = {workspace = true}
procspawn = {workspace = true}
[features]
experimental_broker_api = []
experimental_broker_api = ["rustix", "libc"]
experiment_memfd_secret = []
[[bin]]
name = "rosenpass-wireguard-broker-privileged"

View File

@@ -2,7 +2,7 @@ use crate::{SerializedBrokerConfig, WG_KEY_LEN, WG_PEER_LEN};
use derive_builder::Builder;
use rosenpass_secret_memory::{Public, Secret};
#[derive(Builder)]
#[derive(Builder, Debug)]
#[builder(pattern = "mutable")]
//TODO: Use generics for iface, add additional params
pub struct NetworkBrokerConfig<'a> {

View File

@@ -36,6 +36,7 @@ impl<Err, Inner> BrokerServer<Err, Inner>
where
Inner: WireGuardBroker<Error = Err>,
msgs::SetPskError: From<Err>,
Err: std::fmt::Debug,
{
pub fn new(inner: Inner) -> Self {
Self { inner }
@@ -56,9 +57,9 @@ where
.ok_or(BrokerServerError::InvalidMessage)?;
let mut res = zerocopy::Ref::<&mut [u8], Envelope<SetPskResponse>>::new(res)
.ok_or(BrokerServerError::InvalidMessage)?;
res.payload.return_code = msgs::MsgType::SetPsk as u8;
res.msg_type = msgs::MsgType::SetPsk as u8;
self.handle_set_psk(&req.payload, &mut res.payload)?;
Ok(res.bytes().len())
}
@@ -83,6 +84,10 @@ where
.build()
.unwrap();
let r: Result<(), Err> = self.inner.borrow_mut().set_psk(config.into());
if let Err(e) = &r {
eprintln!("Error setting PSK: {e:?}"); // TODO: Use rust log
}
let r: msgs::SetPskResult = r.map_err(|e| e.into());
let r: msgs::SetPskResponseReturnCode = r.into();
res.return_code = r as u8;

View File

@@ -27,6 +27,14 @@ pub mod linux {
}
pub fn main() -> Result<(), BrokerAppError> {
{
use rosenpass_secret_memory as SM;
#[cfg(feature = "experiment_memfd_secret")]
SM::secret_policy_try_use_memfd_secrets();
#[cfg(not(feature = "experiment_memfd_secret"))]
SM::secret_policy_use_only_malloc_secrets();
}
let mut broker = BrokerServer::new(wg::NetlinkWireGuardBroker::new()?);
let mut stdin = stdin().lock();

View File

@@ -148,6 +148,14 @@ async fn listen_for_clients(queue: mpsc::Sender<BrokerRequest>, sock: UnixListen
async fn on_accept(queue: mpsc::Sender<BrokerRequest>, mut stream: UnixStream) -> Result<()> {
let mut req_buf = Vec::new();
{
use rosenpass_secret_memory as SM;
#[cfg(feature = "experiment_memfd_secret")]
SM::secret_policy_try_use_memfd_secrets();
#[cfg(not(feature = "experiment_memfd_secret"))]
SM::secret_policy_use_only_malloc_secrets();
}
loop {
stream.readable().await?;

View File

@@ -1,59 +1,79 @@
use anyhow::{bail, ensure};
use anyhow::{bail, Context};
use mio::Interest;
use rosenpass_util::ord::max_usize;
use std::collections::VecDeque;
use std::io::{ErrorKind, Read, Write};
use crate::{SerializedBrokerConfig, WireGuardBroker, WireguardBrokerMio};
use rosenpass_secret_memory::Secret;
use rosenpass_to::{ops::copy_slice_least_src, To};
use rosenpass_util::io::{IoResultKindHintExt, TryIoResultKindHintExt};
use rosenpass_util::length_prefix_encoding::decoder::LengthPrefixDecoder;
use rosenpass_util::length_prefix_encoding::encoder::LengthPrefixEncoder;
use rustix::fd::AsFd;
use std::borrow::{Borrow, BorrowMut};
use crate::api::client::{
BrokerClient, BrokerClientIo, BrokerClientPollResponseError, BrokerClientSetPskError,
};
use crate::api::msgs::{self, RESPONSE_MSG_BUFFER_SIZE};
use crate::{SerializedBrokerConfig, WireGuardBroker, WireguardBrokerMio};
#[derive(Debug)]
pub struct MioBrokerClient {
inner: BrokerClient<MioBrokerClientIo>,
}
const LEN_SIZE: usize = 8;
const RECV_BUF_SIZE: usize = max_usize(LEN_SIZE, RESPONSE_MSG_BUFFER_SIZE);
#[derive(Debug)]
struct SecretBuffer<const N: usize>(pub Secret<N>);
impl<const N: usize> SecretBuffer<N> {
fn new() -> Self {
Self(Secret::zero())
}
}
impl<const N: usize> Borrow<[u8]> for SecretBuffer<N> {
fn borrow(&self) -> &[u8] {
self.0.secret()
}
}
impl<const N: usize> BorrowMut<[u8]> for SecretBuffer<N> {
fn borrow_mut(&mut self) -> &mut [u8] {
self.0.secret_mut()
}
}
type ReadBuffer = LengthPrefixDecoder<SecretBuffer<4096>>;
type WriteBuffer = LengthPrefixEncoder<SecretBuffer<4096>>;
#[derive(Debug)]
struct MioBrokerClientIo {
socket: mio::net::UnixStream,
send_buf: VecDeque<u8>,
recv_state: RxState,
expected_state: RxState,
recv_buf: [u8; RECV_BUF_SIZE],
}
#[derive(Debug, Clone, Copy, PartialEq)]
enum RxState {
//Recieving size with buffer offset
RxSize(usize),
RxBuffer(usize),
read_buffer: ReadBuffer,
write_buffer: WriteBuffer,
}
impl MioBrokerClient {
pub fn new(socket: mio::net::UnixStream) -> Self {
let read_buffer = LengthPrefixDecoder::new(SecretBuffer::new());
let write_buffer = LengthPrefixEncoder::from_buffer(SecretBuffer::new());
let io = MioBrokerClientIo {
socket,
send_buf: VecDeque::new(),
recv_state: RxState::RxSize(0),
recv_buf: [0u8; RECV_BUF_SIZE],
expected_state: RxState::RxSize(LEN_SIZE),
read_buffer,
write_buffer,
};
let inner = BrokerClient::new(io);
Self { inner }
}
fn poll(&mut self) -> anyhow::Result<Option<msgs::SetPskResult>> {
fn poll(&mut self) -> anyhow::Result<()> {
self.inner.io_mut().flush()?;
// This sucks
match self.inner.poll_response() {
Ok(res) => Ok(res),
let res = self.inner.poll_response();
match res {
Ok(None) => Ok(()),
Ok(Some(Ok(()))) => Ok(()),
Ok(Some(Err(e))) => {
log::warn!("Error from PSK broker: {e:?}");
Ok(())
}
Err(BrokerClientPollResponseError::IoError(e)) => Err(e),
Err(BrokerClientPollResponseError::InvalidMessage) => bail!("Invalid message"),
}
@@ -108,154 +128,101 @@ impl BrokerClientIo for MioBrokerClientIo {
type RecvError = anyhow::Error;
fn send_msg(&mut self, buf: &[u8]) -> Result<(), Self::SendError> {
self.flush()?;
self.send_or_buffer(&(buf.len() as u64).to_le_bytes())?;
self.send_or_buffer(buf)?;
// Clear write buffer (blocking write)
self.flush_blocking()?;
assert!(self.write_buffer.exhausted(), "flush_blocking() should have put the write buffer in exhausted state. Developer error!");
// Emplace new message in write buffer
copy_slice_least_src(buf).to(self.write_buffer.buffer_bytes_mut());
self.write_buffer
.restart_write_with_new_message(buf.len())?;
// Give the write buffer a chance to clear
self.flush()?;
Ok(())
}
fn recv_msg(&mut self) -> Result<Option<&[u8]>, Self::RecvError> {
use std::io::ErrorKind as K;
loop {
match (self.recv_state, self.expected_state) {
//Stale Buffer state or recieved everything
(RxState::RxSize(x), RxState::RxSize(y))
| (RxState::RxBuffer(x), RxState::RxBuffer(y))
if x == y =>
{
match self.recv_state {
RxState::RxSize(s) => {
let len: &[u8; LEN_SIZE] = self.recv_buf[0..s].try_into().unwrap();
let len: usize = u64::from_le_bytes(*len) as usize;
match self
.read_buffer
.read_from_stdio(&self.socket)
.try_io_err_kind_hint()
{
Ok(_) => {} // Moved down in the loop
Err((_, Some(K::WouldBlock))) => break Ok(None),
Err((_, Some(K::Interrupted))) => continue,
Err((e, _)) => break Err(e)?,
}
ensure!(
len <= msgs::RESPONSE_MSG_BUFFER_SIZE,
"Oversized buffer ({len}) in psk buffer response."
);
self.recv_state = RxState::RxBuffer(0);
self.expected_state = RxState::RxBuffer(len);
continue;
}
RxState::RxBuffer(s) => {
self.recv_state = RxState::RxSize(0);
self.expected_state = RxState::RxSize(LEN_SIZE);
return Ok(Some(&self.recv_buf[0..s]));
}
}
}
//Recieve if x < y
(RxState::RxSize(x), RxState::RxSize(y))
| (RxState::RxBuffer(x), RxState::RxBuffer(y))
if x < y =>
{
let bytes = raw_recv(&self.socket, &mut self.recv_buf[x..y])?;
// If we've received nothing so far, and raw_recv came up empty,
// then let the broker client know nothing came
if self.recv_state == RxState::RxSize(0) && bytes == 0 {
return Ok(None);
}
if x + bytes == y {
return Ok(Some(&self.recv_buf[0..y]));
}
// We didn't recieve everything so let's assume something went wrong
self.recv_state = RxState::RxSize(0);
self.expected_state = RxState::RxSize(LEN_SIZE);
bail!("Invalid state");
}
_ => {
//Reset states
self.recv_state = RxState::RxSize(0);
self.expected_state = RxState::RxSize(LEN_SIZE);
bail!("Invalid state");
}
};
// OK case moved here to appease borrow checker
break Ok(self.read_buffer.message()?);
}
}
}
impl MioBrokerClientIo {
fn flush(&mut self) -> anyhow::Result<()> {
let (fst, snd) = self.send_buf.as_slices();
fn flush_blocking(&mut self) -> anyhow::Result<()> {
self.flush()?;
if self.write_buffer.exhausted() {
return Ok(());
}
let (written, res) = match raw_send(&self.socket, fst) {
Ok(w1) if w1 >= fst.len() => match raw_send(&self.socket, snd) {
Ok(w2) => (w1 + w2, Ok(())),
Err(e) => (w1, Err(e)),
},
Ok(w1) => (w1, Ok(())),
Err(e) => (0, Err(e)),
log::warn!("Could not flush PSK broker write buffer in non-blocking mode. Flushing in blocking mode!");
use rustix::io::{fcntl_getfd, fcntl_setfd, FdFlags};
// Build O_NONBLOCK
let o_nonblock = {
let v = libc::O_NONBLOCK;
let v = v.try_into().context(
"Could not cast O_NONBLOCK (`{v}`) from libc int (i32?) to rustix int (u32?)",
)?;
FdFlags::from_bits(v).context(
"Could not cast O_NONBLOCK (`{v}`) from rustix int to rustix::io::FdFlags",
)?
};
self.send_buf.drain(..written);
// Determine previous and new file descriptor flags
let flags_orig = fcntl_getfd(self.socket.as_fd())?;
let mut flags_blocking = flags_orig;
flags_blocking.insert(o_nonblock);
self.socket.try_io(|| (&self.socket).flush())?;
// Set file descriptor flags
fcntl_setfd(self.socket.as_fd(), flags_blocking)?;
res
// Blocking write
let res = loop {
if self.write_buffer.exhausted() {
break Ok(());
}
match self.flush() {
Ok(_) => {}
Err(e) => break Err(e),
}
};
// Restore file descriptor flags
fcntl_setfd(self.socket.as_fd(), flags_orig)?;
Ok(res?)
}
fn send_or_buffer(&mut self, buf: &[u8]) -> anyhow::Result<()> {
let mut off = 0;
if self.send_buf.is_empty() {
off += raw_send(&self.socket, buf)?;
fn flush(&mut self) -> std::io::Result<()> {
use std::io::ErrorKind as K;
loop {
match self
.write_buffer
.write_to_stdio(&self.socket)
.io_err_kind_hint()
{
Ok(_) => break Ok(()),
Err((_, K::WouldBlock)) => break Ok(()),
Err((_, K::Interrupted)) => continue,
Err((e, _)) => return Err(e)?,
}
}
self.send_buf.extend(buf[off..].iter());
Ok(())
}
}
fn raw_send(mut socket: &mio::net::UnixStream, data: &[u8]) -> anyhow::Result<usize> {
let mut off = 0;
socket.try_io(|| {
loop {
if off == data.len() {
return Ok(());
}
match socket.write(&data[off..]) {
Ok(n) => {
off += n;
}
Err(e) if e.kind() == ErrorKind::Interrupted => {
// pass retry
}
Err(e) if off > 0 || e.kind() == ErrorKind::WouldBlock => return Ok(()),
Err(e) => return Err(e),
}
}
})?;
Ok(off)
}
fn raw_recv(mut socket: &mio::net::UnixStream, out: &mut [u8]) -> anyhow::Result<usize> {
let mut off = 0;
socket.try_io(|| {
loop {
if off == out.len() {
return Ok(());
}
match socket.read(&mut out[off..]) {
Ok(n) => {
off += n;
}
Err(e) if e.kind() == ErrorKind::Interrupted => {
// pass retry
}
Err(e) if off > 0 || e.kind() == ErrorKind::WouldBlock => return Ok(()),
Err(e) => return Err(e),
}
}
})?;
Ok(off)
}

View File

@@ -79,6 +79,7 @@ impl WireGuardBroker for NetlinkWireGuardBroker {
fn set_psk(&mut self, config: SerializedBrokerConfig) -> Result<(), Self::Error> {
let config: NetworkBrokerConfig = config
.try_into()
// TODO: I think this is the wrong error
.map_err(|_e| SetPskError::NoSuchInterface)?;
// Ensure that the peer exists by querying the device configuration
// TODO: Use InvalidInterfaceError