feat(API): Use mio::Token based polling

Avoid polling every single IO source to collect events,
poll those specific IO sources mio tells us about.
This commit is contained in:
Karolin Varner
2024-08-18 21:19:44 +02:00
parent 53e560191f
commit 77760d71df
13 changed files with 403 additions and 98 deletions

View File

@@ -74,6 +74,7 @@ tempfile = { workspace = true }
rustix = {workspace = true}
[features]
default = ["experiment_api"]
experiment_memfd_secret = ["rosenpass-wireguard-broker/experiment_memfd_secret"]
experiment_libcrux = ["rosenpass-ciphers/experiment_libcrux"]
experiment_api = ["hex-literal", "uds", "command-fds", "rosenpass-util/experiment_file_descriptor_passing", "rosenpass-wireguard-broker/experiment_api"]

View File

@@ -203,7 +203,7 @@ where
mio::net::UdpSocket::from_std(sock).ok()
});
let mut sock = match sock_res {
let sock = match sock_res {
Ok(sock) => sock,
Err(e) => {
log::debug!("Error processing AddListenSocket API request: {e:?}");
@@ -213,16 +213,7 @@ where
};
// Register socket
let reg_result = run(|| -> anyhow::Result<()> {
let srv = self.app_server_mut();
srv.mio_poll.registry().register(
&mut sock,
srv.mio_token_dispenser.dispense(),
mio::Interest::READABLE,
)?;
srv.sockets.push(sock);
Ok(())
});
let reg_result = self.app_server_mut().register_listen_socket(sock);
if let Err(internal_error) = reg_result {
log::warn!("Internal error processing AddListenSocket API request: {internal_error:?}");

View File

@@ -55,6 +55,7 @@ struct MioConnectionBuffers {
#[derive(Debug)]
pub struct MioConnection {
io: UnixStream,
mio_token: mio::Token,
invalid_read: bool,
buffers: Option<MioConnectionBuffers>,
api_handler: ApiHandler,
@@ -62,11 +63,11 @@ pub struct MioConnection {
impl MioConnection {
pub fn new(app_server: &mut AppServer, mut io: UnixStream) -> std::io::Result<Self> {
app_server.mio_poll.registry().register(
&mut io,
app_server.mio_token_dispenser.dispense(),
MIO_RW,
)?;
let mio_token = app_server.mio_token_dispenser.dispense();
app_server
.mio_poll
.registry()
.register(&mut io, mio_token, MIO_RW)?;
let invalid_read = false;
let read_buffer = LengthPrefixDecoder::new(SecretBuffer::new());
@@ -80,6 +81,7 @@ impl MioConnection {
let api_state = ApiHandler::new();
Ok(Self {
io,
mio_token,
invalid_read,
buffers,
api_handler: api_state,
@@ -99,6 +101,10 @@ impl MioConnection {
app_server.mio_poll.registry().deregister(&mut self.io)?;
Ok(())
}
pub fn mio_token(&self) -> mio::Token {
self.mio_token
}
}
pub trait MioConnectionContext {
@@ -250,6 +256,14 @@ pub trait MioConnectionContext {
};
}
}
fn mio_token(&self) -> mio::Token {
self.mio_connection().mio_token()
}
fn should_close(&self) -> bool {
self.mio_connection().shoud_close()
}
}
trait MioConnectionContextPrivate: MioConnectionContext {

View File

@@ -1,20 +1,25 @@
use std::{
borrow::{Borrow, BorrowMut},
io,
};
use std::{borrow::BorrowMut, io};
use mio::net::{UnixListener, UnixStream};
use rosenpass_util::{io::nonblocking_handle_io_errors, mio::interest::RW as MIO_RW};
use rosenpass_util::{
functional::ApplyExt, io::nonblocking_handle_io_errors, mio::interest::RW as MIO_RW,
};
use crate::app_server::AppServer;
use crate::app_server::{AppServer, AppServerIoSource};
use super::{MioConnection, MioConnectionContext};
#[derive(Default, Debug)]
pub struct MioManager {
listeners: Vec<UnixListener>,
connections: Vec<MioConnection>,
connections: Vec<Option<MioConnection>>,
}
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum MioManagerIoSource {
Listener(usize),
Connection(usize),
}
impl MioManager {
@@ -42,18 +47,49 @@ pub trait MioManagerContext {
fn add_listener(&mut self, mut listener: UnixListener) -> io::Result<()> {
let srv = self.app_server_mut();
srv.mio_poll.registry().register(
&mut listener,
srv.mio_token_dispenser.dispense(),
MIO_RW,
)?;
let mio_token = srv.mio_token_dispenser.dispense();
srv.mio_poll
.registry()
.register(&mut listener, mio_token, MIO_RW)?;
let io_source = self
.mio_manager()
.listeners
.len()
.apply(MioManagerIoSource::Listener)
.apply(AppServerIoSource::MioManager);
self.mio_manager_mut().listeners.push(listener);
self.app_server_mut()
.register_io_source(mio_token, io_source);
Ok(())
}
fn add_connection(&mut self, connection: UnixStream) -> io::Result<()> {
let connection = MioConnection::new(self.app_server_mut(), connection)?;
self.mio_manager_mut().connections.push(connection);
let mio_token = connection.mio_token();
let conns: &mut Vec<Option<MioConnection>> =
self.mio_manager_mut().connections.borrow_mut();
let idx = conns
.iter_mut()
.enumerate()
.find(|(_, slot)| slot.is_some())
.map(|(idx, _)| idx)
.unwrap_or(conns.len());
conns.insert(idx, Some(connection));
let io_source = idx
.apply(MioManagerIoSource::Listener)
.apply(AppServerIoSource::MioManager);
self.app_server_mut()
.register_io_source(mio_token, io_source);
Ok(())
}
fn poll_particular(&mut self, io_source: MioManagerIoSource) -> anyhow::Result<()> {
use MioManagerIoSource as S;
match io_source {
S::Listener(idx) => self.accept_from(idx)?,
S::Connection(idx) => self.poll_particular_connection(idx)?,
};
Ok(())
}
@@ -87,27 +123,38 @@ pub trait MioManagerContext {
}
fn poll_connections(&mut self) -> anyhow::Result<()> {
let mut idx = 0;
while idx < self.mio_manager().connections.len() {
if self.mio_manager().connections[idx].shoud_close() {
let conn = self.mio_manager_mut().connections.remove(idx);
if let Err(e) = conn.close(self.app_server_mut()) {
log::warn!("Error while closing API connection {e:?}");
};
continue;
}
MioConnectionFocus::new(self, idx).poll()?;
idx += 1;
for idx in 0..self.mio_manager().connections.len() {
self.poll_particular_connection(idx)?;
}
Ok(())
}
fn poll_particular_connection(&mut self, idx: usize) -> anyhow::Result<()> {
if self.mio_manager().connections[idx].is_none() {
return Ok(());
}
let mut conn = MioConnectionFocus::new(self, idx);
conn.poll()?;
if conn.should_close() {
let conn = self.mio_manager_mut().connections[idx].take().unwrap();
let mio_token = conn.mio_token();
if let Err(e) = conn.close(self.app_server_mut()) {
log::warn!("Error while closing API connection {e:?}");
};
self.app_server_mut().unregister_io_source(mio_token);
}
Ok(())
}
}
impl<T: ?Sized + MioManagerContext> MioConnectionContext for MioConnectionFocus<'_, T> {
fn mio_connection(&self) -> &MioConnection {
self.ctx.mio_manager().connections[self.conn_idx].borrow()
self.ctx.mio_manager().connections[self.conn_idx]
.as_ref()
.unwrap()
}
fn app_server(&self) -> &AppServer {
@@ -115,7 +162,9 @@ impl<T: ?Sized + MioManagerContext> MioConnectionContext for MioConnectionFocus<
}
fn mio_connection_mut(&mut self) -> &mut MioConnection {
self.ctx.mio_manager_mut().connections[self.conn_idx].borrow_mut()
self.ctx.mio_manager_mut().connections[self.conn_idx]
.as_mut()
.unwrap()
}
fn app_server_mut(&mut self) -> &mut AppServer {

View File

@@ -10,6 +10,12 @@ use rosenpass_secret_memory::Public;
use rosenpass_secret_memory::Secret;
use rosenpass_util::build::ConstructionSite;
use rosenpass_util::file::StoreValueB64;
use rosenpass_util::functional::run;
use rosenpass_util::functional::ApplyExt;
use rosenpass_util::io::IoResultKindHintExt;
use rosenpass_util::io::SubstituteForIoErrorKindExt;
use rosenpass_util::option::SomeExt;
use rosenpass_util::result::OkExt;
use rosenpass_wireguard_broker::WireguardBrokerMio;
use rosenpass_wireguard_broker::{WireguardBrokerCfg, WG_KEY_LEN};
use zerocopy::AsBytes;
@@ -17,7 +23,9 @@ use zerocopy::AsBytes;
use std::cell::Cell;
use std::collections::HashMap;
use std::collections::VecDeque;
use std::fmt::Debug;
use std::io;
use std::io::stdout;
use std::io::ErrorKind;
use std::io::Write;
@@ -143,6 +151,17 @@ pub struct AppServerTest {
pub termination_handler: Option<std::sync::mpsc::Receiver<()>>,
}
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum AppServerIoSource {
Socket(usize),
#[cfg(feature = "experiment_api")]
PskBroker(Public<BROKER_ID_BYTES>),
#[cfg(feature = "experiment_api")]
MioManager(crate::api::mio::MioManagerIoSource),
}
const EVENT_CAPACITY: usize = 20;
/// Holds the state of the application, namely the external IO
///
/// Responsible for file IO, network IO
@@ -152,6 +171,9 @@ pub struct AppServer {
pub crypto_site: ConstructionSite<BuildCryptoServer, CryptoServer>,
pub sockets: Vec<mio::net::UdpSocket>,
pub events: mio::Events,
pub short_poll_queue: VecDeque<mio::event::Event>,
pub performed_long_poll: bool,
pub io_source_index: HashMap<mio::Token, AppServerIoSource>,
pub mio_poll: mio::Poll,
pub mio_token_dispenser: MioTokenDispenser,
pub brokers: BrokerStore,
@@ -521,7 +543,7 @@ impl AppServer {
) -> anyhow::Result<Self> {
// setup mio
let mio_poll = mio::Poll::new()?;
let events = mio::Events::with_capacity(20);
let events = mio::Events::with_capacity(EVENT_CAPACITY);
let mut mio_token_dispenser = MioTokenDispenser::default();
// bind each SocketAddr to a socket
@@ -596,12 +618,14 @@ impl AppServer {
}
// register all sockets to mio
for socket in sockets.iter_mut() {
mio_poll.registry().register(
socket,
mio_token_dispenser.dispense(),
Interest::READABLE,
)?;
let mut io_source_index = HashMap::new();
for (idx, socket) in sockets.iter_mut().enumerate() {
let mio_token = mio_token_dispenser.dispense();
mio_poll
.registry()
.register(socket, mio_token, Interest::READABLE)?;
let prev = io_source_index.insert(mio_token, AppServerIoSource::Socket(idx));
assert!(prev.is_none());
}
let crypto_site = match keypair {
@@ -615,6 +639,9 @@ impl AppServer {
verbosity,
sockets,
events,
short_poll_queue: Default::default(),
performed_long_poll: false,
io_source_index,
mio_poll,
mio_token_dispenser,
brokers: BrokerStore::default(),
@@ -646,41 +673,57 @@ impl AppServer {
matches!(self.verbosity, Verbosity::Verbose)
}
pub fn register_listen_socket(&mut self, mut sock: mio::net::UdpSocket) -> anyhow::Result<()> {
let mio_token = self.mio_token_dispenser.dispense();
self.mio_poll
.registry()
.register(&mut sock, mio_token, mio::Interest::READABLE)?;
let io_source = self.sockets.len().apply(AppServerIoSource::Socket);
self.sockets.push(sock);
self.register_io_source(mio_token, io_source);
Ok(())
}
pub fn register_io_source(&mut self, token: mio::Token, io_source: AppServerIoSource) {
let prev = self.io_source_index.insert(token, io_source);
assert!(prev.is_none());
}
pub fn unregister_io_source(&mut self, token: mio::Token) {
let value = self.io_source_index.remove(&token);
assert!(value.is_some(), "Removed IO source that does not exist");
}
pub fn register_broker(
&mut self,
broker: Box<dyn WireguardBrokerMio<Error = anyhow::Error, MioError = anyhow::Error>>,
) -> Result<BrokerStorePtr> {
let ptr = Public::from_slice((self.brokers.store.len() as u64).as_bytes());
if self.brokers.store.insert(ptr, broker).is_some() {
bail!("Broker already registered");
}
let mio_token = self.mio_token_dispenser.dispense();
let io_source = ptr.apply(AppServerIoSource::PskBroker);
//Register broker
self.brokers
.store
.get_mut(&ptr)
.ok_or(anyhow::format_err!("Broker wasn't added to registry"))?
.register(
self.mio_poll.registry(),
self.mio_token_dispenser.dispense(),
)?;
.register(self.mio_poll.registry(), mio_token)?;
self.register_io_source(mio_token, io_source);
Ok(BrokerStorePtr(ptr))
}
pub fn unregister_broker(&mut self, ptr: BrokerStorePtr) -> Result<()> {
//Unregister broker
self.brokers
.store
.get_mut(&ptr.0)
.ok_or_else(|| anyhow::anyhow!("Broker not found"))?
.unregister(self.mio_poll.registry())?;
//Remove broker from store
self.brokers
let mut broker = self
.brokers
.store
.remove(&ptr.0)
.ok_or_else(|| anyhow::anyhow!("Broker not found"))?;
.context("Broker not found")?;
self.unregister_io_source(broker.mio_token().unwrap());
broker.unregister(self.mio_poll.registry())?;
Ok(())
}
@@ -998,22 +1041,33 @@ impl AppServer {
// readiness event seems to be good enough™ for now.
// only poll if we drained all sockets before
if self.all_sockets_drained {
//Non blocked polling
self.mio_poll
.poll(&mut self.events, Some(Duration::from_secs(0)))?;
if self.events.iter().peekable().peek().is_none() {
// if there are no events, then add to blocking poll count
self.blocking_polls_count += 1;
//Execute blocking poll
self.mio_poll.poll(&mut self.events, Some(timeout))?;
} else {
self.non_blocking_polls_count += 1;
run(|| -> anyhow::Result<()> {
if !self.all_sockets_drained || !self.short_poll_queue.is_empty() {
self.unpolled_count += 1;
return Ok(());
}
} else {
self.unpolled_count += 1;
}
self.perform_mio_poll_and_register_events(Duration::from_secs(0))?; // Non-blocking poll
if !self.short_poll_queue.is_empty() {
// Got some events in non-blocking mode
self.non_blocking_polls_count += 1;
return Ok(());
}
if !self.performed_long_poll {
// pass go perform a full long poll before we enter blocking poll mode
// to make sure our experimental short poll feature did not miss any events
// due to being buggy.
return Ok(());
}
// Perform and register blocking poll
self.blocking_polls_count += 1;
self.perform_mio_poll_and_register_events(timeout)?;
self.performed_long_poll = false;
Ok(())
})?;
if let Some(AppServerTest {
enable_dos_permanently: true,
@@ -1048,26 +1102,58 @@ impl AppServer {
}
}
// Focused polling i.e. actually using mio::Token is experimental for now.
// The reason for this is that we need to figure out how to integrate load detection
// and focused polling for one. Mio event-based polling also does not play nice with
// the current function signature and its reentrant design which is focused around receiving UDP socket packages
// for processing by the crypto protocol server.
// Besides that, there are also some parts of the code which intentionally block
// despite available data. This is the correct behavior; e.g. api::mio::Connection blocks
// further reads from its unix socket until the write buffer is flushed. In other words
// the connection handler makes sure that there is a buffer to put the response in while
// before reading further request.
// The potential problem with this behavior is that we end up ignoring instructions from
// epoll() to read from the particular sockets, so epoll will return information about that
// particular blocked file descriptor every call. We have only so many event slots and
// in theory, the event array could fill up entirely with intentionally blocked sockets.
// We need to figure out how to deal with this situation.
// Mio uses uses epoll in level-triggered mode, so we could handle taint-tracking for ignored
// sockets ourselves. The facilities are available in epoll and Mio, but we need to figure out how mio uses those
// facilities and how we can integrate them here.
// This will involve rewriting a lot of IO code and we should probably have integration
// tests before we approach that.
//
// This hybrid approach is not without merit though; the short poll implementation covers
// all our IO sources, so under contention, rosenpass should generally not hit the long
// poll mode below. We keep short polling and calling epoll() in non-blocking mode (timeout
// of zero) until we run out of IO events processed. Then, just before we would perform a
// blocking poll, we go through all available IO sources to see if we missed anything.
{
while let Some(ev) = self.short_poll_queue.pop_front() {
if let Some(v) = self.try_recv_from_mio_token(buf, ev.token())? {
return Ok(Some(v));
}
}
}
// drain all sockets
let mut would_block_count = 0;
for (sock_no, socket) in self.sockets.iter_mut().enumerate() {
match socket.recv_from(buf) {
Ok((n, addr)) => {
for sock_no in 0..self.sockets.len() {
match self
.try_recv_from_listen_socket(buf, sock_no)
.io_err_kind_hint()
{
Ok(None) => continue,
Ok(Some(v)) => {
// at least one socket was not drained...
self.all_sockets_drained = false;
return Ok(Some((
n,
Endpoint::SocketBoundAddress(SocketBoundEndpoint::new(
SocketPtr(sock_no),
addr,
)),
)));
return Ok(Some(v));
}
Err(e) if e.kind() == ErrorKind::WouldBlock => {
Err((_, ErrorKind::WouldBlock)) => {
would_block_count += 1;
}
// TODO if one socket continuously returns an error, then we never poll, thus we never wait for a timeout, thus we have a spin-lock
Err(e) => return Err(e.into()),
Err((e, _)) => return Err(e)?,
}
}
@@ -1087,9 +1173,87 @@ impl AppServer {
MioManagerFocus(self).poll()?;
}
self.performed_long_poll = true;
Ok(None)
}
fn perform_mio_poll_and_register_events(&mut self, timeout: Duration) -> io::Result<()> {
self.mio_poll.poll(&mut self.events, Some(timeout))?;
// Fill the short poll buffer with the acquired events
self.events
.iter()
.cloned()
.for_each(|v| self.short_poll_queue.push_back(v));
Ok(())
}
fn try_recv_from_mio_token(
&mut self,
buf: &mut [u8],
token: mio::Token,
) -> anyhow::Result<Option<(usize, Endpoint)>> {
let io_source = match self.io_source_index.get(&token) {
Some(io_source) => *io_source,
None => {
log::warn!("No IO source assiociated with mio token ({token:?}). Polling using mio tokens directly is an experimental feature and IO handler should recover when all available io sources are polled. This is a developer error. Please report it.");
return Ok(None);
}
};
self.try_recv_from_io_source(buf, io_source)
}
fn try_recv_from_io_source(
&mut self,
buf: &mut [u8],
io_source: AppServerIoSource,
) -> anyhow::Result<Option<(usize, Endpoint)>> {
use crate::api::mio::MioManagerContext;
match io_source {
AppServerIoSource::Socket(idx) => self
.try_recv_from_listen_socket(buf, idx)
.substitute_for_ioerr_wouldblock(None)?
.ok(),
#[cfg(feature = "experiment_api")]
AppServerIoSource::PskBroker(key) => self
.brokers
.store
.get_mut(&key)
.with_context(|| format!("No PSK broker under key {key:?}"))?
.process_poll()
.map(|_| None),
#[cfg(feature = "experiment_api")]
AppServerIoSource::MioManager(mmio_src) => MioManagerFocus(self)
.poll_particular(mmio_src)
.map(|_| None),
}
}
fn try_recv_from_listen_socket(
&mut self,
buf: &mut [u8],
idx: usize,
) -> io::Result<Option<(usize, Endpoint)>> {
use std::io::ErrorKind as K;
let (n, addr) = loop {
match self.sockets[idx].recv_from(buf).io_err_kind_hint() {
Ok(v) => break v,
Err((_, K::Interrupted)) => continue,
Err((e, _)) => return Err(e)?,
}
};
SocketPtr(idx)
.apply(|sp| SocketBoundEndpoint::new(sp, addr))
.apply(Endpoint::SocketBoundAddress)
.apply(|ep| (n, ep))
.some()
.ok()
}
#[cfg(feature = "experiment_api")]
pub fn add_api_connection(&mut self, connection: mio::net::UnixStream) -> std::io::Result<()> {
use crate::api::mio::MioManagerContext;

View File

@@ -141,6 +141,7 @@ fn api_integration_api_setup() -> anyhow::Result<()> {
peer_b.config_file_path.to_str().context("")?,
])
.stdin(Stdio::null())
.stderr(Stdio::null())
.stdout(Stdio::piped())
.spawn()?;

View File

@@ -293,6 +293,7 @@ struct MockBrokerInner {
#[derive(Debug, Default)]
struct MockBroker {
inner: Arc<Mutex<MockBrokerInner>>,
mio_token: Option<mio::Token>,
}
impl WireguardBrokerMio for MockBroker {
@@ -301,8 +302,9 @@ impl WireguardBrokerMio for MockBroker {
fn register(
&mut self,
_registry: &mio::Registry,
_token: mio::Token,
token: mio::Token,
) -> Result<(), Self::MioError> {
self.mio_token = Some(token);
Ok(())
}
@@ -311,8 +313,13 @@ impl WireguardBrokerMio for MockBroker {
}
fn unregister(&mut self, _registry: &mio::Registry) -> Result<(), Self::MioError> {
self.mio_token = None;
Ok(())
}
fn mio_token(&self) -> Option<mio::Token> {
self.mio_token
}
}
impl rosenpass_wireguard_broker::WireGuardBroker for MockBroker {

View File

@@ -52,6 +52,56 @@ impl<T, E: TryIoErrorKind> TryIoResultKindHintExt<T> for Result<T, E> {
}
}
pub trait SubstituteForIoErrorKindExt<T>: Sized {
type Error;
fn substitute_for_ioerr_kind_with<F: FnOnce() -> T>(
self,
kind: io::ErrorKind,
f: F,
) -> Result<T, Self::Error>;
fn substitute_for_ioerr_kind(self, kind: io::ErrorKind, v: T) -> Result<T, Self::Error> {
self.substitute_for_ioerr_kind_with(kind, || v)
}
fn substitute_for_ioerr_interrupted_with<F: FnOnce() -> T>(
self,
f: F,
) -> Result<T, Self::Error> {
self.substitute_for_ioerr_kind_with(io::ErrorKind::Interrupted, f)
}
fn substitute_for_ioerr_interrupted(self, v: T) -> Result<T, Self::Error> {
self.substitute_for_ioerr_interrupted_with(|| v)
}
fn substitute_for_ioerr_wouldblock_with<F: FnOnce() -> T>(
self,
f: F,
) -> Result<T, Self::Error> {
self.substitute_for_ioerr_kind_with(io::ErrorKind::WouldBlock, f)
}
fn substitute_for_ioerr_wouldblock(self, v: T) -> Result<T, Self::Error> {
self.substitute_for_ioerr_wouldblock_with(|| v)
}
}
impl<T, E: TryIoErrorKind> SubstituteForIoErrorKindExt<T> for Result<T, E> {
type Error = E;
fn substitute_for_ioerr_kind_with<F: FnOnce() -> T>(
self,
kind: io::ErrorKind,
f: F,
) -> Result<T, Self::Error> {
match self.try_io_err_kind_hint() {
Ok(v) => Ok(v),
Err((_, Some(k))) if k == kind => Ok(f()),
Err((e, _)) => Err(e),
}
}
}
/// Automatically handles `std::io::ErrorKind::Interrupted`.
///
/// - If there is no error (i.e. on `Ok(r)`), the function will return `Ok(Some(r))`

View File

@@ -10,6 +10,7 @@ pub mod io;
pub mod length_prefix_encoding;
pub mod mem;
pub mod mio;
pub mod option;
pub mod ord;
pub mod result;
pub mod time;

7
util/src/option.rs Normal file
View File

@@ -0,0 +1,7 @@
pub trait SomeExt: Sized {
fn some(self) -> Option<Self> {
Some(self)
}
}
impl<T> SomeExt for T {}

View File

@@ -16,6 +16,7 @@ use crate::{SerializedBrokerConfig, WireGuardBroker, WireguardBrokerMio};
#[derive(Debug)]
pub struct MioBrokerClient {
inner: BrokerClient<MioBrokerClientIo>,
mio_token: Option<mio::Token>,
}
#[derive(Debug)]
@@ -59,7 +60,10 @@ impl MioBrokerClient {
write_buffer,
};
let inner = BrokerClient::new(io);
Self { inner }
Self {
inner,
mio_token: None,
}
}
fn poll(&mut self) -> anyhow::Result<()> {
@@ -104,6 +108,7 @@ impl WireguardBrokerMio for MioBrokerClient {
registry: &mio::Registry,
token: mio::Token,
) -> Result<(), Self::MioError> {
self.mio_token = Some(token);
registry.register(
&mut self.inner.io_mut().socket,
token,
@@ -118,9 +123,14 @@ impl WireguardBrokerMio for MioBrokerClient {
}
fn unregister(&mut self, registry: &mio::Registry) -> Result<(), Self::MioError> {
self.mio_token = None;
registry.deregister(&mut self.inner.io_mut().socket)?;
Ok(())
}
fn mio_token(&self) -> Option<mio::Token> {
self.mio_token
}
}
impl BrokerClientIo for MioBrokerClientIo {

View File

@@ -16,7 +16,9 @@ const MAX_B64_KEY_SIZE: usize = WG_KEY_LEN * 5 / 3;
const MAX_B64_PEER_ID_SIZE: usize = WG_PEER_LEN * 5 / 3;
#[derive(Debug)]
pub struct NativeUnixBroker {}
pub struct NativeUnixBroker {
mio_token: Option<mio::Token>,
}
impl Default for NativeUnixBroker {
fn default() -> Self {
@@ -26,7 +28,7 @@ impl Default for NativeUnixBroker {
impl NativeUnixBroker {
pub fn new() -> Self {
Self {}
Self { mio_token: None }
}
}
@@ -88,8 +90,9 @@ impl WireguardBrokerMio for NativeUnixBroker {
fn register(
&mut self,
_registry: &mio::Registry,
_token: mio::Token,
token: mio::Token,
) -> Result<(), Self::MioError> {
self.mio_token = Some(token);
Ok(())
}
@@ -98,8 +101,13 @@ impl WireguardBrokerMio for NativeUnixBroker {
}
fn unregister(&mut self, _registry: &mio::Registry) -> Result<(), Self::MioError> {
self.mio_token = None;
Ok(())
}
fn mio_token(&self) -> Option<mio::Token> {
self.mio_token
}
}
#[derive(Debug, Builder)]

View File

@@ -28,6 +28,8 @@ pub trait WireguardBrokerMio: WireGuardBroker {
registry: &mio::Registry,
token: mio::Token,
) -> Result<(), Self::MioError>;
fn mio_token(&self) -> Option<mio::Token>;
/// Run after a mio::poll operation
fn process_poll(&mut self) -> Result<(), Self::MioError>;