mirror of
https://github.com/rosenpass/rosenpass.git
synced 2025-12-12 07:40:30 -08:00
fix: PSK broker integration did not work
This commit resolves multiple issues with the PSK broker integration. - The manual testing procedure never actually utilized the brokers due to the use of the outfile option, this led to issues with the broker being hidden. - The manual testing procedure omitted checking whether a PSK was actually sent to WireGuard entirely. This was fixed by writing an entirely new manual integration testing shell-script that can serve as a blueprint for future integration tests. - Many parts of the PSK broker code did not report (log) errors accurately; added error logging - BrokerServer set message.payload.return_code to the msg_type value, this led to crashes - The PSK broker commands all omitted to set the memfd policy, this led to immediate crashes once secrets where actually allocated - The MioBrokerClient IO state machine was broken and the design was too obtuse to debug. The state machine returned the length prefix as a message instead of actually interpreting it as a state machine. Seems the code was integrated but never actually tested. This was fixed by rewriting the entire state machine code using the new LengthPrefixEncoder/Decoder facilities. A write-buffer that was not being flushed is now handled by flushing the buffer in blocking-io mode.
This commit is contained in:
@@ -26,6 +26,8 @@ env_logger = { workspace = true }
|
||||
log = { workspace = true }
|
||||
derive_builder = {workspace = true}
|
||||
postcard = {workspace = true}
|
||||
rustix = { worspace = true, optional = true }
|
||||
libc = { worspace = true, optional = true }
|
||||
|
||||
# Mio broker client
|
||||
mio = { workspace = true }
|
||||
@@ -36,7 +38,8 @@ rand = {workspace = true}
|
||||
procspawn = {workspace = true}
|
||||
|
||||
[features]
|
||||
experimental_broker_api = []
|
||||
experimental_broker_api = ["rustix", "libc"]
|
||||
experiment_memfd_secret = []
|
||||
|
||||
[[bin]]
|
||||
name = "rosenpass-wireguard-broker-privileged"
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::{SerializedBrokerConfig, WG_KEY_LEN, WG_PEER_LEN};
|
||||
use derive_builder::Builder;
|
||||
use rosenpass_secret_memory::{Public, Secret};
|
||||
|
||||
#[derive(Builder)]
|
||||
#[derive(Builder, Debug)]
|
||||
#[builder(pattern = "mutable")]
|
||||
//TODO: Use generics for iface, add additional params
|
||||
pub struct NetworkBrokerConfig<'a> {
|
||||
|
||||
@@ -36,6 +36,7 @@ impl<Err, Inner> BrokerServer<Err, Inner>
|
||||
where
|
||||
Inner: WireGuardBroker<Error = Err>,
|
||||
msgs::SetPskError: From<Err>,
|
||||
Err: std::fmt::Debug,
|
||||
{
|
||||
pub fn new(inner: Inner) -> Self {
|
||||
Self { inner }
|
||||
@@ -56,9 +57,9 @@ where
|
||||
.ok_or(BrokerServerError::InvalidMessage)?;
|
||||
let mut res = zerocopy::Ref::<&mut [u8], Envelope<SetPskResponse>>::new(res)
|
||||
.ok_or(BrokerServerError::InvalidMessage)?;
|
||||
|
||||
res.payload.return_code = msgs::MsgType::SetPsk as u8;
|
||||
res.msg_type = msgs::MsgType::SetPsk as u8;
|
||||
self.handle_set_psk(&req.payload, &mut res.payload)?;
|
||||
|
||||
Ok(res.bytes().len())
|
||||
}
|
||||
|
||||
@@ -83,6 +84,10 @@ where
|
||||
.build()
|
||||
.unwrap();
|
||||
let r: Result<(), Err> = self.inner.borrow_mut().set_psk(config.into());
|
||||
if let Err(e) = &r {
|
||||
eprintln!("Error setting PSK: {e:?}"); // TODO: Use rust log
|
||||
}
|
||||
|
||||
let r: msgs::SetPskResult = r.map_err(|e| e.into());
|
||||
let r: msgs::SetPskResponseReturnCode = r.into();
|
||||
res.return_code = r as u8;
|
||||
|
||||
@@ -27,6 +27,14 @@ pub mod linux {
|
||||
}
|
||||
|
||||
pub fn main() -> Result<(), BrokerAppError> {
|
||||
{
|
||||
use rosenpass_secret_memory as SM;
|
||||
#[cfg(feature = "experiment_memfd_secret")]
|
||||
SM::secret_policy_try_use_memfd_secrets();
|
||||
#[cfg(not(feature = "experiment_memfd_secret"))]
|
||||
SM::secret_policy_use_only_malloc_secrets();
|
||||
}
|
||||
|
||||
let mut broker = BrokerServer::new(wg::NetlinkWireGuardBroker::new()?);
|
||||
|
||||
let mut stdin = stdin().lock();
|
||||
|
||||
@@ -148,6 +148,14 @@ async fn listen_for_clients(queue: mpsc::Sender<BrokerRequest>, sock: UnixListen
|
||||
async fn on_accept(queue: mpsc::Sender<BrokerRequest>, mut stream: UnixStream) -> Result<()> {
|
||||
let mut req_buf = Vec::new();
|
||||
|
||||
{
|
||||
use rosenpass_secret_memory as SM;
|
||||
#[cfg(feature = "experiment_memfd_secret")]
|
||||
SM::secret_policy_try_use_memfd_secrets();
|
||||
#[cfg(not(feature = "experiment_memfd_secret"))]
|
||||
SM::secret_policy_use_only_malloc_secrets();
|
||||
}
|
||||
|
||||
loop {
|
||||
stream.readable().await?;
|
||||
|
||||
|
||||
@@ -1,59 +1,79 @@
|
||||
use anyhow::{bail, ensure};
|
||||
use anyhow::{bail, Context};
|
||||
use mio::Interest;
|
||||
use rosenpass_util::ord::max_usize;
|
||||
use std::collections::VecDeque;
|
||||
use std::io::{ErrorKind, Read, Write};
|
||||
|
||||
use crate::{SerializedBrokerConfig, WireGuardBroker, WireguardBrokerMio};
|
||||
use rosenpass_secret_memory::Secret;
|
||||
use rosenpass_to::{ops::copy_slice_least_src, To};
|
||||
use rosenpass_util::io::{IoResultKindHintExt, TryIoResultKindHintExt};
|
||||
use rosenpass_util::length_prefix_encoding::decoder::LengthPrefixDecoder;
|
||||
use rosenpass_util::length_prefix_encoding::encoder::LengthPrefixEncoder;
|
||||
use rustix::fd::AsFd;
|
||||
use std::borrow::{Borrow, BorrowMut};
|
||||
|
||||
use crate::api::client::{
|
||||
BrokerClient, BrokerClientIo, BrokerClientPollResponseError, BrokerClientSetPskError,
|
||||
};
|
||||
use crate::api::msgs::{self, RESPONSE_MSG_BUFFER_SIZE};
|
||||
use crate::{SerializedBrokerConfig, WireGuardBroker, WireguardBrokerMio};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MioBrokerClient {
|
||||
inner: BrokerClient<MioBrokerClientIo>,
|
||||
}
|
||||
|
||||
const LEN_SIZE: usize = 8;
|
||||
const RECV_BUF_SIZE: usize = max_usize(LEN_SIZE, RESPONSE_MSG_BUFFER_SIZE);
|
||||
#[derive(Debug)]
|
||||
struct SecretBuffer<const N: usize>(pub Secret<N>);
|
||||
|
||||
impl<const N: usize> SecretBuffer<N> {
|
||||
fn new() -> Self {
|
||||
Self(Secret::zero())
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize> Borrow<[u8]> for SecretBuffer<N> {
|
||||
fn borrow(&self) -> &[u8] {
|
||||
self.0.secret()
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize> BorrowMut<[u8]> for SecretBuffer<N> {
|
||||
fn borrow_mut(&mut self) -> &mut [u8] {
|
||||
self.0.secret_mut()
|
||||
}
|
||||
}
|
||||
|
||||
type ReadBuffer = LengthPrefixDecoder<SecretBuffer<4096>>;
|
||||
type WriteBuffer = LengthPrefixEncoder<SecretBuffer<4096>>;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct MioBrokerClientIo {
|
||||
socket: mio::net::UnixStream,
|
||||
send_buf: VecDeque<u8>,
|
||||
recv_state: RxState,
|
||||
expected_state: RxState,
|
||||
recv_buf: [u8; RECV_BUF_SIZE],
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
enum RxState {
|
||||
//Recieving size with buffer offset
|
||||
RxSize(usize),
|
||||
RxBuffer(usize),
|
||||
read_buffer: ReadBuffer,
|
||||
write_buffer: WriteBuffer,
|
||||
}
|
||||
|
||||
impl MioBrokerClient {
|
||||
pub fn new(socket: mio::net::UnixStream) -> Self {
|
||||
let read_buffer = LengthPrefixDecoder::new(SecretBuffer::new());
|
||||
let write_buffer = LengthPrefixEncoder::from_buffer(SecretBuffer::new());
|
||||
let io = MioBrokerClientIo {
|
||||
socket,
|
||||
send_buf: VecDeque::new(),
|
||||
recv_state: RxState::RxSize(0),
|
||||
recv_buf: [0u8; RECV_BUF_SIZE],
|
||||
expected_state: RxState::RxSize(LEN_SIZE),
|
||||
read_buffer,
|
||||
write_buffer,
|
||||
};
|
||||
let inner = BrokerClient::new(io);
|
||||
Self { inner }
|
||||
}
|
||||
|
||||
fn poll(&mut self) -> anyhow::Result<Option<msgs::SetPskResult>> {
|
||||
fn poll(&mut self) -> anyhow::Result<()> {
|
||||
self.inner.io_mut().flush()?;
|
||||
|
||||
// This sucks
|
||||
match self.inner.poll_response() {
|
||||
Ok(res) => Ok(res),
|
||||
let res = self.inner.poll_response();
|
||||
match res {
|
||||
Ok(None) => Ok(()),
|
||||
Ok(Some(Ok(()))) => Ok(()),
|
||||
Ok(Some(Err(e))) => {
|
||||
log::warn!("Error from PSK broker: {e:?}");
|
||||
Ok(())
|
||||
}
|
||||
Err(BrokerClientPollResponseError::IoError(e)) => Err(e),
|
||||
Err(BrokerClientPollResponseError::InvalidMessage) => bail!("Invalid message"),
|
||||
}
|
||||
@@ -108,154 +128,101 @@ impl BrokerClientIo for MioBrokerClientIo {
|
||||
type RecvError = anyhow::Error;
|
||||
|
||||
fn send_msg(&mut self, buf: &[u8]) -> Result<(), Self::SendError> {
|
||||
self.flush()?;
|
||||
self.send_or_buffer(&(buf.len() as u64).to_le_bytes())?;
|
||||
self.send_or_buffer(buf)?;
|
||||
// Clear write buffer (blocking write)
|
||||
self.flush_blocking()?;
|
||||
assert!(self.write_buffer.exhausted(), "flush_blocking() should have put the write buffer in exhausted state. Developer error!");
|
||||
|
||||
// Emplace new message in write buffer
|
||||
copy_slice_least_src(buf).to(self.write_buffer.buffer_bytes_mut());
|
||||
self.write_buffer
|
||||
.restart_write_with_new_message(buf.len())?;
|
||||
|
||||
// Give the write buffer a chance to clear
|
||||
self.flush()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn recv_msg(&mut self) -> Result<Option<&[u8]>, Self::RecvError> {
|
||||
use std::io::ErrorKind as K;
|
||||
loop {
|
||||
match (self.recv_state, self.expected_state) {
|
||||
//Stale Buffer state or recieved everything
|
||||
(RxState::RxSize(x), RxState::RxSize(y))
|
||||
| (RxState::RxBuffer(x), RxState::RxBuffer(y))
|
||||
if x == y =>
|
||||
{
|
||||
match self.recv_state {
|
||||
RxState::RxSize(s) => {
|
||||
let len: &[u8; LEN_SIZE] = self.recv_buf[0..s].try_into().unwrap();
|
||||
let len: usize = u64::from_le_bytes(*len) as usize;
|
||||
match self
|
||||
.read_buffer
|
||||
.read_from_stdio(&self.socket)
|
||||
.try_io_err_kind_hint()
|
||||
{
|
||||
Ok(_) => {} // Moved down in the loop
|
||||
Err((_, Some(K::WouldBlock))) => break Ok(None),
|
||||
Err((_, Some(K::Interrupted))) => continue,
|
||||
Err((e, _)) => break Err(e)?,
|
||||
}
|
||||
|
||||
ensure!(
|
||||
len <= msgs::RESPONSE_MSG_BUFFER_SIZE,
|
||||
"Oversized buffer ({len}) in psk buffer response."
|
||||
);
|
||||
|
||||
self.recv_state = RxState::RxBuffer(0);
|
||||
self.expected_state = RxState::RxBuffer(len);
|
||||
continue;
|
||||
}
|
||||
RxState::RxBuffer(s) => {
|
||||
self.recv_state = RxState::RxSize(0);
|
||||
self.expected_state = RxState::RxSize(LEN_SIZE);
|
||||
return Ok(Some(&self.recv_buf[0..s]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//Recieve if x < y
|
||||
(RxState::RxSize(x), RxState::RxSize(y))
|
||||
| (RxState::RxBuffer(x), RxState::RxBuffer(y))
|
||||
if x < y =>
|
||||
{
|
||||
let bytes = raw_recv(&self.socket, &mut self.recv_buf[x..y])?;
|
||||
|
||||
// If we've received nothing so far, and raw_recv came up empty,
|
||||
// then let the broker client know nothing came
|
||||
if self.recv_state == RxState::RxSize(0) && bytes == 0 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if x + bytes == y {
|
||||
return Ok(Some(&self.recv_buf[0..y]));
|
||||
}
|
||||
|
||||
// We didn't recieve everything so let's assume something went wrong
|
||||
self.recv_state = RxState::RxSize(0);
|
||||
self.expected_state = RxState::RxSize(LEN_SIZE);
|
||||
bail!("Invalid state");
|
||||
}
|
||||
_ => {
|
||||
//Reset states
|
||||
self.recv_state = RxState::RxSize(0);
|
||||
self.expected_state = RxState::RxSize(LEN_SIZE);
|
||||
bail!("Invalid state");
|
||||
}
|
||||
};
|
||||
// OK case moved here to appease borrow checker
|
||||
break Ok(self.read_buffer.message()?);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MioBrokerClientIo {
|
||||
fn flush(&mut self) -> anyhow::Result<()> {
|
||||
let (fst, snd) = self.send_buf.as_slices();
|
||||
fn flush_blocking(&mut self) -> anyhow::Result<()> {
|
||||
self.flush()?;
|
||||
if self.write_buffer.exhausted() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let (written, res) = match raw_send(&self.socket, fst) {
|
||||
Ok(w1) if w1 >= fst.len() => match raw_send(&self.socket, snd) {
|
||||
Ok(w2) => (w1 + w2, Ok(())),
|
||||
Err(e) => (w1, Err(e)),
|
||||
},
|
||||
Ok(w1) => (w1, Ok(())),
|
||||
Err(e) => (0, Err(e)),
|
||||
log::warn!("Could not flush PSK broker write buffer in non-blocking mode. Flushing in blocking mode!");
|
||||
use rustix::io::{fcntl_getfd, fcntl_setfd, FdFlags};
|
||||
|
||||
// Build O_NONBLOCK
|
||||
let o_nonblock = {
|
||||
let v = libc::O_NONBLOCK;
|
||||
let v = v.try_into().context(
|
||||
"Could not cast O_NONBLOCK (`{v}`) from libc int (i32?) to rustix int (u32?)",
|
||||
)?;
|
||||
FdFlags::from_bits(v).context(
|
||||
"Could not cast O_NONBLOCK (`{v}`) from rustix int to rustix::io::FdFlags",
|
||||
)?
|
||||
};
|
||||
|
||||
self.send_buf.drain(..written);
|
||||
// Determine previous and new file descriptor flags
|
||||
let flags_orig = fcntl_getfd(self.socket.as_fd())?;
|
||||
let mut flags_blocking = flags_orig;
|
||||
flags_blocking.insert(o_nonblock);
|
||||
|
||||
self.socket.try_io(|| (&self.socket).flush())?;
|
||||
// Set file descriptor flags
|
||||
fcntl_setfd(self.socket.as_fd(), flags_blocking)?;
|
||||
|
||||
res
|
||||
// Blocking write
|
||||
let res = loop {
|
||||
if self.write_buffer.exhausted() {
|
||||
break Ok(());
|
||||
}
|
||||
|
||||
match self.flush() {
|
||||
Ok(_) => {}
|
||||
Err(e) => break Err(e),
|
||||
}
|
||||
};
|
||||
|
||||
// Restore file descriptor flags
|
||||
fcntl_setfd(self.socket.as_fd(), flags_orig)?;
|
||||
|
||||
Ok(res?)
|
||||
}
|
||||
|
||||
fn send_or_buffer(&mut self, buf: &[u8]) -> anyhow::Result<()> {
|
||||
let mut off = 0;
|
||||
|
||||
if self.send_buf.is_empty() {
|
||||
off += raw_send(&self.socket, buf)?;
|
||||
fn flush(&mut self) -> std::io::Result<()> {
|
||||
use std::io::ErrorKind as K;
|
||||
loop {
|
||||
match self
|
||||
.write_buffer
|
||||
.write_to_stdio(&self.socket)
|
||||
.io_err_kind_hint()
|
||||
{
|
||||
Ok(_) => break Ok(()),
|
||||
Err((_, K::WouldBlock)) => break Ok(()),
|
||||
Err((_, K::Interrupted)) => continue,
|
||||
Err((e, _)) => return Err(e)?,
|
||||
}
|
||||
}
|
||||
|
||||
self.send_buf.extend(buf[off..].iter());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn raw_send(mut socket: &mio::net::UnixStream, data: &[u8]) -> anyhow::Result<usize> {
|
||||
let mut off = 0;
|
||||
|
||||
socket.try_io(|| {
|
||||
loop {
|
||||
if off == data.len() {
|
||||
return Ok(());
|
||||
}
|
||||
match socket.write(&data[off..]) {
|
||||
Ok(n) => {
|
||||
off += n;
|
||||
}
|
||||
Err(e) if e.kind() == ErrorKind::Interrupted => {
|
||||
// pass – retry
|
||||
}
|
||||
Err(e) if off > 0 || e.kind() == ErrorKind::WouldBlock => return Ok(()),
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(off)
|
||||
}
|
||||
|
||||
fn raw_recv(mut socket: &mio::net::UnixStream, out: &mut [u8]) -> anyhow::Result<usize> {
|
||||
let mut off = 0;
|
||||
|
||||
socket.try_io(|| {
|
||||
loop {
|
||||
if off == out.len() {
|
||||
return Ok(());
|
||||
}
|
||||
match socket.read(&mut out[off..]) {
|
||||
Ok(n) => {
|
||||
off += n;
|
||||
}
|
||||
Err(e) if e.kind() == ErrorKind::Interrupted => {
|
||||
// pass – retry
|
||||
}
|
||||
Err(e) if off > 0 || e.kind() == ErrorKind::WouldBlock => return Ok(()),
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(off)
|
||||
}
|
||||
|
||||
@@ -79,6 +79,7 @@ impl WireGuardBroker for NetlinkWireGuardBroker {
|
||||
fn set_psk(&mut self, config: SerializedBrokerConfig) -> Result<(), Self::Error> {
|
||||
let config: NetworkBrokerConfig = config
|
||||
.try_into()
|
||||
// TODO: I think this is the wrong error
|
||||
.map_err(|_e| SetPskError::NoSuchInterface)?;
|
||||
// Ensure that the peer exists by querying the device configuration
|
||||
// TODO: Use InvalidInterfaceError
|
||||
|
||||
Reference in New Issue
Block a user