diff --git a/Cargo.lock b/Cargo.lock index 7efbdde..44e4884 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2070,6 +2070,7 @@ dependencies = [ "serde", "serial_test", "signal-hook", + "signal-hook-mio", "stacker", "static_assertions", "tempfile", @@ -2152,6 +2153,38 @@ dependencies = [ "rosenpass-util", ] +[[package]] +name = "rosenpass-rp" +version = "0.2.1" +dependencies = [ + "anyhow", + "base64ct", + "ctrlc-async", + "env_logger", + "futures", + "futures-util", + "genetlink", + "libc", + "log", + "netlink-packet-core", + "netlink-packet-generic", + "netlink-packet-wireguard", + "rosenpass", + "rosenpass-cipher-traits", + "rosenpass-ciphers", + "rosenpass-secret-memory", + "rosenpass-util", + "rosenpass-wireguard-broker", + "rtnetlink", + "serde", + "stacker", + "tempfile", + "tokio", + "toml", + "x25519-dalek", + "zeroize", +] + [[package]] name = "rosenpass-secret-memory" version = "0.1.0" @@ -2184,11 +2217,13 @@ dependencies = [ "anyhow", "base64ct", "libcrux-test-utils", + "log", "mio", "rustix", "static_assertions", "tempfile", "thiserror 1.0.69", + "tokio", "typenum", "uds", "zerocopy 0.7.35", @@ -2219,35 +2254,6 @@ dependencies = [ "zerocopy 0.7.35", ] -[[package]] -name = "rp" -version = "0.2.1" -dependencies = [ - "anyhow", - "base64ct", - "ctrlc-async", - "futures", - "futures-util", - "genetlink", - "netlink-packet-core", - "netlink-packet-generic", - "netlink-packet-wireguard", - "rosenpass", - "rosenpass-cipher-traits", - "rosenpass-ciphers", - "rosenpass-secret-memory", - "rosenpass-util", - "rosenpass-wireguard-broker", - "rtnetlink", - "serde", - "stacker", - "tempfile", - "tokio", - "toml", - "x25519-dalek", - "zeroize", -] - [[package]] name = "rtnetlink" version = "0.14.1" @@ -2432,14 +2438,25 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "signal-hook" -version = "0.3.17" +version = "0.3.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801" +checksum = "d881a16cf4426aa584979d30bd82cb33429027e42122b169753d6ef1085ed6e2" dependencies = [ "libc", "signal-hook-registry", ] +[[package]] +name = "signal-hook-mio" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34db1a06d485c9142248b7a054f034b349b212551f3dfd19c94d45a754a217cd" +dependencies = [ + "libc", + "mio", + "signal-hook", +] + [[package]] name = "signal-hook-registry" version = "1.4.2" diff --git a/Cargo.toml b/Cargo.toml index 9b94447..8e9b19c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,6 +54,8 @@ serde = { version = "1.0.217", features = ["derive"] } arbitrary = { version = "1.4.1", features = ["derive"] } anyhow = { version = "1.0.95", features = ["backtrace", "std"] } mio = { version = "1.0.3", features = ["net", "os-poll"] } +signal-hook-mio = { version = "0.2.4", features = ["support-v1_0"] } +signal-hook = "0.3.17" oqs-sys = { version = "0.9.1", default-features = false, features = [ 'classic_mceliece', 'kyber', @@ -79,7 +81,6 @@ hex = { version = "0.4.3" } heck = { version = "0.5.0" } libc = { version = "0.2" } uds = { git = "https://github.com/rosenpass/uds" } -signal-hook = "0.3.17" lazy_static = "1.5" #Dev dependencies diff --git a/cipher-traits/src/primitives/keyed_hash.rs b/cipher-traits/src/primitives/keyed_hash.rs index 93ecaf1..426cc78 100644 --- a/cipher-traits/src/primitives/keyed_hash.rs +++ b/cipher-traits/src/primitives/keyed_hash.rs @@ -40,7 +40,7 @@ pub struct InferKeyedHash where Static: KeyedHash, { 
- pub _phantom_keyed_hasher: PhantomData<*const Static>, + pub _phantom_keyed_hasher: PhantomData, } impl InferKeyedHash diff --git a/rosenpass/Cargo.toml b/rosenpass/Cargo.toml index c23d1f3..9ce2686 100644 --- a/rosenpass/Cargo.toml +++ b/rosenpass/Cargo.toml @@ -64,6 +64,8 @@ clap = { workspace = true } clap_complete = { workspace = true } clap_mangen = { workspace = true } mio = { workspace = true } +signal-hook = { workspace = true } +signal-hook-mio = { workspace = true } rand = { workspace = true } zerocopy = { workspace = true } home = { workspace = true } @@ -76,7 +78,6 @@ heck = { workspace = true, optional = true } command-fds = { workspace = true, optional = true } rustix = { workspace = true, optional = true } uds = { workspace = true, optional = true, features = ["mio_1xx"] } -signal-hook = { workspace = true, optional = true } libcrux-test-utils = { workspace = true, optional = true } [build-dependencies] @@ -109,7 +110,6 @@ experiment_api = [ "rosenpass-util/experiment_file_descriptor_passing", "rosenpass-wireguard-broker/experiment_api", ] -internal_signal_handling_for_coverage_reports = ["signal-hook"] internal_testing = [] internal_bin_gen_ipc_msg_types = ["hex", "heck"] trace_bench = ["rosenpass-util/trace_bench", "dep:libcrux-test-utils"] diff --git a/rosenpass/src/app_server.rs b/rosenpass/src/app_server.rs index cd406f9..eeff163 100644 --- a/rosenpass/src/app_server.rs +++ b/rosenpass/src/app_server.rs @@ -7,17 +7,20 @@ use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSoc use std::time::{Duration, Instant}; use std::{cell::Cell, fmt::Debug, io, path::PathBuf, slice}; +use mio::{Interest, Token}; +use signal_hook_mio::v1_0 as signal_hook_mio; + use anyhow::{bail, Context, Result}; use derive_builder::Builder; use log::{error, info, warn}; -use mio::{Interest, Token}; use zerocopy::AsBytes; use rosenpass_util::attempt; +use rosenpass_util::fmt::debug::NullDebug; use rosenpass_util::functional::{run, ApplyExt}; use rosenpass_util::io::{IoResultKindHintExt, SubstituteForIoErrorKindExt}; use rosenpass_util::{ - b64::B64Display, build::ConstructionSite, file::StoreValueB64, option::SomeExt, result::OkExt, + b64::B64Display, build::ConstructionSite, file::StoreValueB64, result::OkExt, }; use rosenpass_secret_memory::{Public, Secret}; @@ -129,7 +132,7 @@ pub struct BrokerStore { /// The collection of WireGuard brokers. See [Self]. pub store: HashMap< Public, - Box>, + Box + Send>, >, } @@ -146,12 +149,12 @@ pub struct BrokerPeer { /// /// This is woefully overengineered and there is very little reason why the broker /// configuration should not live in the particular WireGuard broker. 
- peer_cfg: Box, + peer_cfg: Box, } impl BrokerPeer { /// Create a broker peer - pub fn new(ptr: BrokerStorePtr, peer_cfg: Box) -> Self { + pub fn new(ptr: BrokerStorePtr, peer_cfg: Box) -> Self { Self { ptr, peer_cfg } } @@ -286,12 +289,20 @@ pub enum AppServerIoSource { Socket(usize), /// IO source refers to a PSK broker in [AppServer::brokers] PskBroker(Public), /// IO source refers to our signal handlers + SignalHandler, /// IO source refers to some IO sources used in the API; /// see [AppServer::api_manager] #[cfg(feature = "experiment_api")] MioManager(crate::api::mio::MioManagerIoSource), } +/// Result of [AppServer::try_recv] and its internal helpers +pub enum AppServerTryRecvResult { + /// No message was received + None, + /// Received a request to terminate the application + Terminate, + /// Received a network message of the given length from the given [Endpoint] + NetworkMessage(usize, Endpoint), +} + /// Number of epoll(7) events Rosenpass can receive at a time const EVENT_CAPACITY: usize = 20; @@ -332,6 +343,8 @@ pub struct AppServer { /// MIO associates IO sources with numeric tokens. This struct takes care of generating these /// tokens pub mio_token_dispenser: MioTokenDispenser, + /// Mio-based handler for signals + pub signal_handler: NullDebug<signal_hook_mio::Signals>, /// Helpers handling communication with WireGuard; these take a generated key and forward it to /// WireGuard pub brokers: BrokerStore, @@ -357,16 +370,6 @@ pub struct AppServer { /// Used by integration tests to force [Self] into DoS condition /// and to terminate the AppServer after the test is complete pub test_helpers: Option<AppServerTest>, - /// Helper for integration tests running rosenpass as a subprocess - /// to terminate properly upon receiving an appropriate system signal. - /// - /// This is primarily needed for coverage testing, since llvm-cov does not - /// write coverage reports to disk when a process is stopped by the default - /// signal handler. - /// - /// See - #[cfg(feature = "internal_signal_handling_for_coverage_reports")] - pub term_signal: terminate::TerminateRequested, #[cfg(feature = "experiment_api")] /// The Rosenpass unix socket API handler; this is an experimental /// feature that can be used to embed Rosenpass in external applications @@ -456,6 +459,8 @@ impl AppPeerPtr { /// Instructs [AppServer::event_loop_without_error_handling] on how to proceed. #[derive(Debug)] pub enum AppPollResult { + /// Received request to terminate the application + Terminate, /// Erase the key for a given peer. Corresponds to [crate::protocol::PollResult::DeleteKey] DeleteKey(AppPeerPtr), /// Send an initiation to the given peer.
Corresponds to [crate::protocol::PollResult::SendInitiation] @@ -802,10 +807,27 @@ impl AppServer { verbosity: Verbosity, test_helpers: Option, ) -> anyhow::Result { - // setup mio + // Setup Mio itself let mio_poll = mio::Poll::new()?; let events = mio::Events::with_capacity(EVENT_CAPACITY); + + // And helpers to map mio tokens to internal event types let mut mio_token_dispenser = MioTokenDispenser::default(); + let mut io_source_index = HashMap::new(); + + // Setup signal handling + let signal_handler = attempt!({ + let mut handler = + signal_hook_mio::Signals::new(signal_hook::consts::TERM_SIGNALS.iter())?; + let mio_token = mio_token_dispenser.dispense(); + mio_poll + .registry() + .register(&mut handler, mio_token, Interest::READABLE)?; + let prev = io_source_index.insert(mio_token, AppServerIoSource::SignalHandler); + assert!(prev.is_none()); + Ok(NullDebug(handler)) + }) + .context("Failed to set up signal (user triggered program termination) handler")?; // bind each SocketAddr to a socket let maybe_sockets: Result, _> = @@ -879,7 +901,6 @@ impl AppServer { } // register all sockets to mio - let mut io_source_index = HashMap::new(); for (idx, socket) in sockets.iter_mut().enumerate() { let mio_token = mio_token_dispenser.dispense(); mio_poll @@ -895,8 +916,6 @@ impl AppServer { }; Ok(Self { - #[cfg(feature = "internal_signal_handling_for_coverage_reports")] - term_signal: terminate::TerminateRequested::new()?, crypto_site, peers: Vec::new(), verbosity, @@ -907,6 +926,7 @@ impl AppServer { io_source_index, mio_poll, mio_token_dispenser, + signal_handler, brokers: BrokerStore::default(), all_sockets_drained: false, under_load: DoSOperation::Normal, @@ -977,7 +997,7 @@ impl AppServer { /// Register a new WireGuard PSK broker pub fn register_broker( &mut self, - broker: Box>, + broker: Box + Send>, ) -> Result { let ptr = Public::from_slice((self.brokers.store.len() as u64).as_bytes()); if self.brokers.store.insert(ptr, broker).is_some() { @@ -1049,7 +1069,7 @@ impl AppServer { Ok(AppPeerPtr(pn)) } - /// Main IO handler; this generally does not terminate + /// Main IO handler; this generally does not terminate other than through unix signals /// /// # Examples /// @@ -1066,23 +1086,6 @@ impl AppServer { Err(e) => e, }; - #[cfg(feature = "internal_signal_handling_for_coverage_reports")] - { - let terminated_by_signal = err - .downcast_ref::() - .filter(|e| e.kind() == std::io::ErrorKind::Interrupted) - .filter(|_| self.term_signal.value()) - .is_some(); - if terminated_by_signal { - log::warn!( - "\ - Terminated by signal; this signal handler is correct during coverage testing \ - but should be otherwise disabled" - ); - return Ok(()); - } - } - // This should not happen… failure_cnt = if msgs_processed > 0 { 0 @@ -1135,6 +1138,7 @@ impl AppServer { use AppPollResult::*; use KeyOutputReason::*; + // TODO: We should read from this using a mio channel if let Some(AppServerTest { termination_handler: Some(terminate), .. @@ -1158,6 +1162,8 @@ impl AppServer { #[allow(clippy::redundant_closure_call)] match (have_crypto, poll_result) { + (_, Terminate) => return Ok(()), + (CryptoSrv::Missing, SendInitiation(_)) => {} (CryptoSrv::Avail, SendInitiation(peer)) => tx_maybe_with!(peer, || self .crypto_server_mut()? 
@@ -1305,6 +1311,7 @@ impl AppServer { pub fn poll(&mut self, rx_buf: &mut [u8]) -> anyhow::Result { use crate::protocol::PollResult as C; use AppPollResult as A; + use AppServerTryRecvResult as R; let res = loop { // Call CryptoServer's poll (if available) let crypto_poll = self @@ -1325,8 +1332,10 @@ impl AppServer { }; // Perform IO (look for a message) - if let Some((len, addr)) = self.try_recv(rx_buf, io_poll_timeout)? { - break A::ReceivedMessage(len, addr); + match self.try_recv(rx_buf, io_poll_timeout)? { + R::None => {} + R::Terminate => break A::Terminate, + R::NetworkMessage(len, addr) => break A::ReceivedMessage(len, addr), } }; @@ -1344,12 +1353,12 @@ impl AppServer { &mut self, buf: &mut [u8], timeout: Timing, - ) -> anyhow::Result> { + ) -> anyhow::Result { let timeout = Duration::from_secs_f64(timeout); // if there is no time to wait on IO, well, then, lets not waste any time! if timeout.is_zero() { - return Ok(None); + return Ok(AppServerTryRecvResult::None); } // NOTE when using mio::Poll, there are some particularities (taken from @@ -1459,12 +1468,19 @@ impl AppServer { // blocking poll, we go through all available IO sources to see if we missed anything. { while let Some(ev) = self.short_poll_queue.pop_front() { - if let Some(v) = self.try_recv_from_mio_token(buf, ev.token())? { - return Ok(Some(v)); + match self.try_recv_from_mio_token(buf, ev.token())? { + AppServerTryRecvResult::None => continue, + res => return Ok(res), } } } + // Drain operating system signals + match self.try_recv_from_signal_handler()? { + AppServerTryRecvResult::None => {} // Nop + res => return Ok(res), + } + // drain all sockets let mut would_block_count = 0; for sock_no in 0..self.sockets.len() { @@ -1472,11 +1488,11 @@ impl AppServer { .try_recv_from_listen_socket(buf, sock_no) .io_err_kind_hint() { - Ok(None) => continue, - Ok(Some(v)) => { + Ok(AppServerTryRecvResult::None) => continue, + Ok(res) => { // at least one socket was not drained... self.all_sockets_drained = false; - return Ok(Some(v)); + return Ok(res); } Err((_, ErrorKind::WouldBlock)) => { would_block_count += 1; @@ -1504,12 +1520,24 @@ impl AppServer { self.performed_long_poll = true; - Ok(None) + Ok(AppServerTryRecvResult::None) } /// Internal helper for [Self::try_recv] fn perform_mio_poll_and_register_events(&mut self, timeout: Duration) -> io::Result<()> { - self.mio_poll.poll(&mut self.events, Some(timeout))?; + loop { + use std::io::ErrorKind as IOE; + match self + .mio_poll + .poll(&mut self.events, Some(timeout)) + .io_err_kind_hint() + { + Ok(()) => break, + Err((_, IOE::Interrupted)) => continue, + Err((err, _)) => return Err(err), + } + } + // Fill the short poll buffer with the acquired events self.events .iter() @@ -1523,12 +1551,12 @@ impl AppServer { &mut self, buf: &mut [u8], token: mio::Token, - ) -> anyhow::Result> { + ) -> anyhow::Result { let io_source = match self.io_source_index.get(&token) { Some(io_source) => *io_source, None => { log::warn!("No IO source assiociated with mio token ({token:?}). Polling using mio tokens directly is an experimental feature and IO handler should recover when all available io sources are polled. This is a developer error. 
Please report it."); - return Ok(None); + return Ok(AppServerTryRecvResult::None); } }; @@ -1540,11 +1568,13 @@ impl AppServer { &mut self, buf: &mut [u8], io_source: AppServerIoSource, - ) -> anyhow::Result<Option<(usize, Endpoint)>> { + ) -> anyhow::Result<AppServerTryRecvResult> { match io_source { + AppServerIoSource::SignalHandler => self.try_recv_from_signal_handler()?.ok(), + AppServerIoSource::Socket(idx) => self .try_recv_from_listen_socket(buf, idx) - .substitute_for_ioerr_wouldblock(None)? + .substitute_for_ioerr_wouldblock(AppServerTryRecvResult::None)? .ok(), AppServerIoSource::PskBroker(key) => self @@ -1553,7 +1583,7 @@ impl AppServer { .get_mut(&key) .with_context(|| format!("No PSK broker under key {key:?}"))? .process_poll() - .map(|_| None), + .map(|_| AppServerTryRecvResult::None), #[cfg(feature = "experiment_api")] AppServerIoSource::MioManager(mmio_src) => { @@ -1561,17 +1591,28 @@ impl AppServer { MioManagerFocus(self) .poll_particular(mmio_src) - .map(|_| None) + .map(|_| AppServerTryRecvResult::None) } } } + /// Internal helper for [Self::try_recv] + fn try_recv_from_signal_handler(&mut self) -> io::Result<AppServerTryRecvResult> { + #[allow(clippy::never_loop)] + for signal in self.signal_handler.pending() { + log::debug!("Received operating system signal no {signal}."); + log::info!("Received termination request; exiting."); + return Ok(AppServerTryRecvResult::Terminate); + } + Ok(AppServerTryRecvResult::None) + } + /// Internal helper for [Self::try_recv] fn try_recv_from_listen_socket( &mut self, buf: &mut [u8], idx: usize, - ) -> io::Result<Option<(usize, Endpoint)>> { + ) -> io::Result<AppServerTryRecvResult> { use std::io::ErrorKind as K; let (n, addr) = loop { match self.sockets[idx].recv_from(buf).io_err_kind_hint() { @@ -1583,8 +1624,7 @@ impl AppServer { SocketPtr(idx) .apply(|sp| SocketBoundEndpoint::new(sp, addr)) .apply(Endpoint::SocketBoundAddress) - .apply(|ep| (n, ep)) - .some() + .apply(|ep| AppServerTryRecvResult::NetworkMessage(n, ep)) .ok() } @@ -1636,48 +1676,3 @@ impl crate::api::mio::MioManagerContext for MioManagerFocus<'_> { self.0 } } - -/// These signal handlers are used exclusively used during coverage testing -/// to ensure that the llvm-cov can produce reports during integration tests -/// with multiple processes where subprocesses are terminated via kill(2). -/// -/// llvm-cov does not support producing coverage reports when the process exits -/// through a signal, so this is necessary. -/// -/// The functionality of exiting gracefully upon reception of a terminating signal -/// is desired for the production variant of Rosenpass, but we should make sure -/// to use a higher quality implementation; in particular, we should use signalfd(2). -/// -#[cfg(feature = "internal_signal_handling_for_coverage_reports")] -mod terminate { - use signal_hook::flag::register as sig_register; - use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }; - - /// Automatically register a signal handler for common termination signals; - /// whether one of these signals was issued can be polled using [Self::value]. - /// - /// The signal handler is not removed when this struct goes out of scope.
- #[derive(Debug)] - pub struct TerminateRequested { - value: Arc, - } - - impl TerminateRequested { - /// Register signal handlers watching for common termination signals - pub fn new() -> anyhow::Result { - let value = Arc::new(AtomicBool::new(false)); - for sig in signal_hook::consts::TERM_SIGNALS.iter().copied() { - sig_register(sig, Arc::clone(&value))?; - } - Ok(Self { value }) - } - - /// Check whether a termination signal has been set - pub fn value(&self) -> bool { - self.value.load(Ordering::Relaxed) - } - } -} diff --git a/rosenpass/src/cli.rs b/rosenpass/src/cli.rs index 186f617..f242a92 100644 --- a/rosenpass/src/cli.rs +++ b/rosenpass/src/cli.rs @@ -490,7 +490,7 @@ impl CliArgs { cfg_peer.key_out, broker_peer, cfg_peer.endpoint.clone(), - cfg_peer.protocol_version.into(), + cfg_peer.protocol_version, cfg_peer.osk_domain_separator.try_into()?, )?; } @@ -515,7 +515,7 @@ impl CliArgs { fn create_broker( broker_interface: Option, ) -> Result< - Box>, + Box + Send>, anyhow::Error, > { if let Some(interface) = broker_interface { diff --git a/rosenpass/src/config.rs b/rosenpass/src/config.rs index a4f494b..31a4b5c 100644 --- a/rosenpass/src/config.rs +++ b/rosenpass/src/config.rs @@ -200,7 +200,7 @@ impl RosenpassPeerOskDomainSeparator { pub fn org_and_label(&self) -> anyhow::Result)>> { match (&self.osk_organization, &self.osk_label) { (None, None) => Ok(None), - (Some(org), Some(label)) => Ok(Some((&org, &label))), + (Some(org), Some(label)) => Ok(Some((org, label))), (Some(_), None) => bail!("Specified osk_organization but not osk_label in config file. You need to specify both, or none."), (None, Some(_)) => bail!("Specified osk_label but not osk_organization in config file. You need to specify both, or none."), } diff --git a/rosenpass/src/protocol/protocol.rs b/rosenpass/src/protocol/protocol.rs index 778a1be..568ef08 100644 --- a/rosenpass/src/protocol/protocol.rs +++ b/rosenpass/src/protocol/protocol.rs @@ -1943,7 +1943,7 @@ impl CryptoServer { &mut self, rx_buf: &[u8], tx_buf: &mut [u8], - host_identification: &H, + _host_identification: &H, ) -> Result { self.handle_msg(rx_buf, tx_buf) } @@ -3231,7 +3231,7 @@ impl HandshakeState { let k = bk.get(srv).value.secret(); let pt = biscuit.as_bytes(); - XAead.encrypt_with_nonce_in_ctxt(biscuit_ct, k, &*n, &ad, pt)?; + XAead.encrypt_with_nonce_in_ctxt(biscuit_ct, k, &n, &ad, pt)?; self.mix(biscuit_ct) } @@ -3421,7 +3421,7 @@ impl CryptoServer { // IHI3 protocol_section!("IHI3", { - EphemeralKem.keygen(hs.eski.secret_mut(), &mut *hs.epki)?; + EphemeralKem.keygen(hs.eski.secret_mut(), &mut hs.epki)?; ih.epki.copy_from_slice(&hs.epki.value); }); diff --git a/rosenpass/tests/api-integration-tests-api-setup.rs b/rosenpass/tests/api-integration-tests-api-setup.rs index ee9341a..54ae051 100644 --- a/rosenpass/tests/api-integration-tests-api-setup.rs +++ b/rosenpass/tests/api-integration-tests-api-setup.rs @@ -105,7 +105,7 @@ fn api_integration_api_setup(protocol_version: ProtocolVersion) -> anyhow::Resul peer: format!("{}", peer_b_wg_peer_id.fmt_b64::<8129>()), extra_params: vec![], }), - protocol_version: protocol_version.clone(), + protocol_version: protocol_version, osk_domain_separator: Default::default(), }], }; @@ -127,7 +127,7 @@ fn api_integration_api_setup(protocol_version: ProtocolVersion) -> anyhow::Resul endpoint: Some(peer_a_endpoint.to_owned()), pre_shared_key: None, wg: None, - protocol_version: protocol_version.clone(), + protocol_version: protocol_version, osk_domain_separator: Default::default(), }], }; diff --git 
a/rosenpass/tests/api-integration-tests.rs b/rosenpass/tests/api-integration-tests.rs index 18380c2..ef54d28 100644 --- a/rosenpass/tests/api-integration-tests.rs +++ b/rosenpass/tests/api-integration-tests.rs @@ -82,7 +82,7 @@ fn api_integration_test(protocol_version: ProtocolVersion) -> anyhow::Result<()> endpoint: None, pre_shared_key: None, wg: None, - protocol_version: protocol_version.clone(), + protocol_version: protocol_version, osk_domain_separator: Default::default(), }], }; @@ -104,7 +104,7 @@ fn api_integration_test(protocol_version: ProtocolVersion) -> anyhow::Result<()> endpoint: Some(peer_a_endpoint.to_owned()), pre_shared_key: None, wg: None, - protocol_version: protocol_version.clone(), + protocol_version: protocol_version, osk_domain_separator: Default::default(), }], }; diff --git a/rosenpass/tests/integration_test.rs b/rosenpass/tests/integration_test.rs index d5b5f22..6d16ac9 100644 --- a/rosenpass/tests/integration_test.rs +++ b/rosenpass/tests/integration_test.rs @@ -144,7 +144,7 @@ fn check_example_config() { let tmp_dir = tempdir().unwrap(); let config_path = tmp_dir.path().join("config.toml"); - let mut config_file = File::create(config_path.to_owned()).unwrap(); + let mut config_file = File::create(&config_path).unwrap(); config_file .write_all( diff --git a/rp/Cargo.toml b/rp/Cargo.toml index 5b3091f..9eb308c 100644 --- a/rp/Cargo.toml +++ b/rp/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "rp" +name = "rosenpass-rp" version = "0.2.1" edition = "2021" license = "MIT OR Apache-2.0" @@ -8,7 +8,9 @@ homepage = "https://rosenpass.eu/" repository = "https://github.com/rosenpass/rosenpass" rust-version = "1.77.0" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[[bin]] +name = "rp" +path = "src/main.rs" [dependencies] anyhow = { workspace = true } @@ -17,12 +19,15 @@ serde = { workspace = true } toml = { workspace = true } x25519-dalek = { workspace = true, features = ["static_secrets"] } zeroize = { workspace = true } +libc = { workspace = true } +log = { workspace = true } +env_logger = { workspace = true } rosenpass = { workspace = true } rosenpass-ciphers = { workspace = true } rosenpass-cipher-traits = { workspace = true } rosenpass-secret-memory = { workspace = true } -rosenpass-util = { workspace = true } +rosenpass-util = { workspace = true, features = ["tokio"] } rosenpass-wireguard-broker = { workspace = true } tokio = { workspace = true } diff --git a/rp/src/exchange.rs b/rp/src/exchange.rs index d3f729c..fa5a4b2 100644 --- a/rp/src/exchange.rs +++ b/rp/src/exchange.rs @@ -1,16 +1,63 @@ -use std::{ - future::Future, net::SocketAddr, ops::DerefMut, path::PathBuf, pin::Pin, process::Command, - sync::Arc, -}; +use std::any::type_name; +use std::{borrow::Borrow, net::SocketAddr, path::PathBuf}; -use anyhow::{Error, Result}; +use tokio::process::Command; + +use anyhow::{bail, ensure, Context, Result}; +use futures_util::TryStreamExt as _; use serde::Deserialize; use rosenpass::config::ProtocolVersion; +use rosenpass::{ + app_server::{AppServer, BrokerPeer}, + config::Verbosity, + protocol::{ + basic_types::{SPk, SSk, SymKey}, + osk_domain_separator::OskDomainSeparator, + }, +}; +use rosenpass_secret_memory::Secret; +use rosenpass_util::file::{LoadValue as _, LoadValueB64}; +use rosenpass_util::functional::{ApplyExt, MutatingExt}; +use rosenpass_util::result::OkExt; +use rosenpass_util::tokio::janitor::{spawn_cleanup_job, try_spawn_daemon}; +use rosenpass_wireguard_broker::brokers::native_unix::{ + 
NativeUnixBroker, NativeUnixBrokerConfigBaseBuilder, +}; +use tokio::task::spawn_blocking; -#[cfg(any(target_os = "linux", target_os = "freebsd"))] use crate::key::WG_B64_LEN; +/// Extra-special measure to structure imports from the various +/// netlink related crates used in [super] +mod netlink { + /// Re-exports from [::netlink_packet_core] + pub mod core { + pub use ::netlink_packet_core::{NetlinkMessage, NLM_F_ACK, NLM_F_REQUEST}; + } + + /// Re-exports from [::rtnetlink] + pub mod rtnl { + pub use ::rtnetlink::Error; + pub use ::rtnetlink::Handle; + } + + /// Re-exports from [::genetlink] and [::netlink_packet_generic] + pub mod genl { + pub use ::genetlink::GenetlinkHandle as Handle; + pub use ::netlink_packet_generic::GenlMessage as Message; + } + + /// Re-exports from [::netlink_packet_wireguard] + pub mod wg { + pub use ::netlink_packet_wireguard::constants::WG_KEY_LEN as KEY_LEN; + pub use ::netlink_packet_wireguard::nlas::WgDeviceAttrs as DeviceAttrs; + pub use ::netlink_packet_wireguard::{Wireguard, WireguardCmd}; + } +} + +type WgSecretKey = Secret<{ netlink::wg::KEY_LEN }>; + /// Used to define a peer for the rosenpass connection that consists of /// a directory for storing public keys and optionally an IP address and port of the endpoint, /// for how long the connection should be kept alive and a list of allowed IPs for the peer. @@ -43,286 +90,401 @@ pub struct ExchangeOptions { pub dev: Option<String>, /// The IP-address rosenpass should run under. pub ip: Option<String>, - /// The IP-address and port that the rosenpass [AppServer](rosenpass::app_server::AppServer) + /// The IP-address and port that the rosenpass [AppServer] /// should use. pub listen: Option<SocketAddr>, /// Other peers a connection should be initialized to pub peers: Vec<ExchangePeer>, } -#[cfg(not(any(target_os = "linux", target_os = "freebsd")))] -pub async fn exchange(_: ExchangeOptions) -> Result<()> { - use anyhow::anyhow; - - Err(anyhow!( - "Your system {} is not yet supported. We are happy to receive patches to address this :)", - std::env::consts::OS - )) +/// Manage the lifetime of WireGuard devices used by rp +#[derive(Debug, Default)] +struct WireGuardDeviceImpl { + // TODO: Can we merge these two somehow?
+ rtnl_netlink_handle_cache: Option<netlink::rtnl::Handle>, + genl_netlink_handle_cache: Option<netlink::genl::Handle>, + /// Handle and name of the device + device: Option<(u32, String)>, } -#[cfg(any(target_os = "linux", target_os = "freebsd"))] -mod netlink { - use anyhow::Result; - use futures_util::{StreamExt as _, TryStreamExt as _}; - use genetlink::GenetlinkHandle; - use netlink_packet_core::{NLM_F_ACK, NLM_F_REQUEST}; - use netlink_packet_wireguard::nlas::WgDeviceAttrs; - use rtnetlink::Handle; +impl WireGuardDeviceImpl { + fn take(&mut self) -> WireGuardDeviceImpl { + Self::default().mutating(|nu| std::mem::swap(self, nu)) + } + async fn open(&mut self, device_name: String) -> anyhow::Result<()> { + let mut rtnl_link = self.rtnl_netlink_handle()?.link(); + let device_name_ref = &device_name; + + // Make sure that there is no device called `device_name` before we start + rtnl_link + .get() + .match_name(device_name.to_owned()) + .execute() + // Count the number of occurrences + .try_fold(0, |acc, _val| async move { + Ok(acc + 1) + }).await + // Extract the error's raw system error code + .map_err(|e| { + use netlink::rtnl::Error as E; + match e { + E::NetlinkError(msg) => { + let raw_code = -msg.raw_code(); + (E::NetlinkError(msg), Some(raw_code)) + }, + _ => (e, None), + } + }) + .apply(|r| { + match r { + // No such device, which is exactly what we are expecting + Ok(0) | Err((_, Some(libc::ENODEV))) => Ok(()), + // Device already exists + Ok(_) => bail!("\ + Trying to create a network device for Rosenpass under the name \"{device_name}\", \ + but at least one device under the name already exists."), + // Other error + Err((e, _)) => bail!(e), + } + })?; + + // Add the link, equivalent to `ip link add type wireguard`. + rtnl_link + .add() + .wireguard(device_name.to_owned()) + .execute() + .await?; + log::info!("Created network device!"); + + // Retrieve a handle for the newly created device + let device_handle = rtnl_link + .get() + .match_name(device_name.to_owned()) + .execute() + .err_into::<anyhow::Error>() + .try_fold(Option::None, |acc, val| async move { + ensure!(acc.is_none(), "\ + Created a network device for Rosenpass under the name \"{device_name_ref}\", \ + but upon trying to determine the handle for the device using name-based lookup, we received multiple handles. \ + We checked beforehand whether the device already exists. \ + This should not happen. Unsure how to proceed. Terminating."); + Ok(Some(val)) + }).await? + .with_context(|| format!("\ + Created a network device for Rosenpass under the name \"{device_name}\", \ + but upon trying to determine the handle for the device using name-based lookup, we received no handle. \ + This should not happen. Unsure how to proceed. Terminating."))? + .apply(|msg| msg.header.index); + + // Now we can actually start to mark the device as initialized. + // Note that if the handle retrieval above does not work, the destructor + // will not run and the device will not be erased. + // This is, for now, the desired behavior as we need the handle to erase + // the device anyway. + self.device = Some((device_handle, device_name)); + + // Activate the link, equivalent to `ip link set dev up`.
+ rtnl_link.set(device_handle).up().execute().await?; + + Ok(()) + } + + async fn close(mut self) { + // Check if the device is properly initialized and retrieve the device info + let (device_handle, device_name) = match self.device.take() { + Some(val) => val, + // Nothing to do, not yet properly initialized + None => return, + }; + + // Erase the network device; the rest of the function is just error handling + let res = async move { + self.rtnl_netlink_handle()? + .link() + .del(device_handle) + .execute() + .await?; + log::debug!("Erased network interface!"); + anyhow::Ok(()) + } + .await; + + // Here we test if the error needs printing at all + let res = 'do_print: { + // Short-circuit if the deletion was successful + let err = match res { + Ok(()) => break 'do_print Ok(()), + Err(err) => err, + }; + + // Extract the rtnetlink error, so we can inspect it + let err = match err.downcast::<netlink::rtnl::Error>() { + Ok(rtnl_err) => rtnl_err, + Err(other_err) => break 'do_print Err(other_err), + }; + + // TODO: This is a bit brittle, as the rtnetlink error enum looks like + // E::NetlinkError is a sort of "unknown error" case. If they explicitly + // add support for the "no such device" errors or other ones we check for in + // this block, then this code may no longer filter these errors + // Extract the raw netlink error code + use netlink::rtnl::Error as E; + let error_code = match err { + E::NetlinkError(ref msg) => -msg.raw_code(), + err => break 'do_print Err(err.into()), + }; + + // Check whether it's just the "no such device" error + #[allow(clippy::single_match)] + match error_code { + libc::ENODEV => break 'do_print Ok(()), + _ => {} + } + + // Otherwise, we just print the error + Err(err.into()) + }; + + if let Err(err) = res { + log::warn!("Could not remove network device `{device_name}`: {err:?}"); + } + } + + pub async fn add_ip_address(&self, addr: &str) -> anyhow::Result<()> { + // TODO: Migrate to using netlink + Command::new("ip") + .args(["address", "add", addr, "dev", self.name()?]) + .status() + .await?; + Ok(()) + } + + pub fn is_open(&self) -> bool { + self.device.is_some() + } + + pub fn maybe_name(&self) -> Option<&str> { + self.device.as_ref().map(|slot| slot.1.borrow()) + } + + /// Return the raw handle for this device + pub fn maybe_raw_handle(&self) -> Option<u32> { + self.device.as_ref().map(|slot| slot.0) + } + + pub fn name(&self) -> anyhow::Result<&str> { + self.maybe_name() + .with_context(|| format!("{} has not been initialized!", type_name::<Self>())) + } + + /// Return the raw handle for this device + pub fn raw_handle(&self) -> anyhow::Result<u32> { + self.maybe_raw_handle() + .with_context(|| format!("{} has not been initialized!", type_name::<Self>())) + } + + pub async fn set_private_key_and_listen_addr( + &mut self, + wgsk: &WgSecretKey, + listen_port: Option<u16>, + ) -> anyhow::Result<()> { + use netlink as nl; + + // The attributes to set + // TODO: This exposes the secret key; we should probably run this in a separate process + // or on a separate stack and have a zeroizing allocator globally.
+ let mut attrs = vec![ + nl::wg::DeviceAttrs::IfIndex(self.raw_handle()?), + nl::wg::DeviceAttrs::PrivateKey(*wgsk.secret()), + ]; + + // Optional listen port for WireGuard + if let Some(port) = listen_port { + attrs.push(nl::wg::DeviceAttrs::ListenPort(port)); + } + + // The netlink request we are trying to send + let req = nl::wg::Wireguard { + cmd: nl::wg::WireguardCmd::SetDevice, + nlas: attrs, + }; + + // Boilerplate; wrap the request into more structures + let req = req + .apply(nl::genl::Message::from_payload) + .apply(nl::core::NetlinkMessage::from) + .mutating(|req| { + req.header.flags = nl::core::NLM_F_REQUEST | nl::core::NLM_F_ACK; + }); + + // Send the request + self.genl_netlink_handle()? + .request(req) + .await? + // Collect all errors (let try_fold do all the work) + .try_fold((), |_, _| async move { Ok(()) }) + .await?; + + Ok(()) + } + + fn take_rtnl_netlink_handle(&mut self) -> Result<netlink::rtnl::Handle> { + if let Some(handle) = self.rtnl_netlink_handle_cache.take() { + Ok(handle) + } else { + let (connection, handle, _) = rtnetlink::new_connection()?; + + // Making sure that the connection has a chance to terminate before the + // application exits + try_spawn_daemon(async move { + connection.await; + Ok(()) + })?; + + Ok(handle) + } + } + + fn rtnl_netlink_handle(&mut self) -> Result<&mut netlink::rtnl::Handle> { + let netlink_handle = self.take_rtnl_netlink_handle()?; + self.rtnl_netlink_handle_cache.insert(netlink_handle).ok() + } + + fn take_genl_netlink_handle(&mut self) -> Result<netlink::genl::Handle> { + if let Some(handle) = self.genl_netlink_handle_cache.take() { + Ok(handle) + } else { + let (connection, handle, _) = genetlink::new_connection()?; + + // Making sure that the connection has a chance to terminate before the + // application exits + try_spawn_daemon(async move { + connection.await; + Ok(()) + })?; + + Ok(handle) + } + } + + fn genl_netlink_handle(&mut self) -> Result<&mut netlink::genl::Handle> { + let netlink_handle = self.take_genl_netlink_handle()?; + self.genl_netlink_handle_cache.insert(netlink_handle).ok() + } +} + +struct WireGuardDevice { + _impl: WireGuardDeviceImpl, +} + +impl WireGuardDevice { /// Creates a netlink named `link_name` and changes the state to up. It returns the index /// of the interface in the list of interfaces as the result or an error if any of the /// operations of creating the link or changing its state to up fails. - pub async fn link_create_and_up(rtnetlink: &Handle, link_name: String) -> Result<u32> { - // Add the link, equivalent to `ip link add type wireguard`. - rtnetlink - .link() - .add() - .wireguard(link_name.clone()) - .execute() - .await?; + pub async fn create_device(device_name: String) -> Result<Self> { + let mut _impl = WireGuardDeviceImpl::default(); + _impl.open(device_name).await?; + assert!(_impl.is_open()); // Sanity check + Ok(WireGuardDevice { _impl }) + } - // Retrieve the link to be able to up it, equivalent to `ip link show` and then - // using the link shown that is identified by `link_name`.
- let link = rtnetlink - .link() - .get() - .match_name(link_name.clone()) - .execute() - .into_stream() - .into_future() + pub fn name(&self) -> &str { + self._impl.name().unwrap() + } + + /// Return the raw handle for this device + #[allow(dead_code)] + pub fn raw_handle(&self) -> u32 { + self._impl.raw_handle().unwrap() + } + + pub async fn add_ip_address(&self, addr: &str) -> anyhow::Result<()> { + self._impl.add_ip_address(addr).await + } + + pub async fn set_private_key_and_listen_addr( + &mut self, + wgsk: &WgSecretKey, + listen_port: Option, + ) -> anyhow::Result<()> { + self._impl + .set_private_key_and_listen_addr(wgsk, listen_port) .await - .0 - .unwrap()?; - - // Up the link, equivalent to `ip link set dev up`. - rtnetlink - .link() - .set(link.header.index) - .up() - .execute() - .await?; - - Ok(link.header.index) - } - - /// Deletes a link using rtnetlink. The link is specified using its index in the list of links. - pub async fn link_cleanup(rtnetlink: &Handle, index: u32) -> Result<()> { - rtnetlink.link().del(index).execute().await?; - - Ok(()) - } - - /// Deletes a link using rtnetlink. The link is specified using its index in the list of links. - /// In contrast to [link_cleanup], this function create a new socket connection to netlink and - /// *ignores errors* that occur during deletion. - pub async fn link_cleanup_standalone(index: u32) -> Result<()> { - let (connection, rtnetlink, _) = rtnetlink::new_connection()?; - tokio::spawn(connection); - - // We don't care if this fails, as the device may already have been auto-cleaned up. - let _ = rtnetlink.link().del(index).execute().await; - - Ok(()) - } - - /// This replicates the functionality of the `wg set` command line tool. - /// - /// It sets the specified WireGuard attributes of the indexed device by - /// communicating with WireGuard's generic netlink interface, like the - /// `wg` tool does. - pub async fn wg_set( - genetlink: &mut GenetlinkHandle, - index: u32, - mut attr: Vec, - ) -> Result<()> { - use futures_util::StreamExt as _; - use netlink_packet_core::{NetlinkMessage, NetlinkPayload}; - use netlink_packet_generic::GenlMessage; - use netlink_packet_wireguard::{Wireguard, WireguardCmd}; - - // Scope our `set` command to only the device of the specified index. - attr.insert(0, WgDeviceAttrs::IfIndex(index)); - - // Construct the WireGuard-specific netlink packet - let wgc = Wireguard { - cmd: WireguardCmd::SetDevice, - nlas: attr, - }; - - // Construct final message. - let genl = GenlMessage::from_payload(wgc); - let mut nlmsg = NetlinkMessage::from(genl); - nlmsg.header.flags = NLM_F_REQUEST | NLM_F_ACK; - - // Send and wait for the ACK or error. - let (res, _) = genetlink.request(nlmsg).await?.into_future().await; - if let Some(res) = res { - let res = res?; - if let NetlinkPayload::Error(err) = res.payload { - return Err(err.to_io().into()); - } - } - - Ok(()) } } -/// A wrapper for a list of cleanup handlers that can be used in an asynchronous context -/// to clean up after the usage of rosenpass or if the `rp` binary is interrupted with ctrl+c -/// or a `SIGINT` signal in general. -#[derive(Clone)] -#[cfg(any(target_os = "linux", target_os = "freebsd"))] -struct CleanupHandlers( - Arc<::futures::lock::Mutex> + Send>>>>>, -); - -#[cfg(any(target_os = "linux", target_os = "freebsd"))] -impl CleanupHandlers { - /// Creates a new list of [CleanupHandlers]. 
- fn new() -> Self { - CleanupHandlers(Arc::new(::futures::lock::Mutex::new(vec![]))) - } - - /// Enqueues a new cleanup handler in the form of a [Future]. - async fn enqueue(&self, handler: Pin<Box<dyn Future<Output = Result<()>> + Send>>) { - self.0.lock().await.push(Box::pin(handler)) - } - - /// Runs all cleanup handlers. Following the documentation of [futures::future::try_join_all]: - /// If any cleanup handler returns an error then all other cleanup handlers will be canceled and - /// an error will be returned immediately. If all cleanup handlers complete successfully, - /// however, then the returned future will succeed with a Vec of all the successful results. - async fn run(self) -> Result<Vec<()>, Error> { - futures::future::try_join_all(self.0.lock().await.deref_mut()).await +impl Drop for WireGuardDevice { + fn drop(&mut self) { + let _impl = self._impl.take(); + spawn_cleanup_job(async move { + _impl.close().await; + Ok(()) + }); } } /// Sets up the rosenpass link and wireguard and configures both with the configuration specified by /// `options`. -#[cfg(any(target_os = "linux", target_os = "freebsd"))] pub async fn exchange(options: ExchangeOptions) -> Result<()> { - use std::fs; + // Load the server parameter files + let wgsk = options.private_keys_dir.join("wgsk"); + let rpsk = options.private_keys_dir.join("pqsk"); + let rppk = options.private_keys_dir.join("pqpk"); + let (wgsk, rpsk, rppk) = spawn_blocking(move || { + let wgsk = WgSecretKey::load_b64::<WG_B64_LEN, _>(wgsk)?; + let rpsk = SSk::load(rpsk)?; + let rppk = SPk::load(rppk)?; + anyhow::Ok((wgsk, rpsk, rppk)) + }) + .await??; - use anyhow::anyhow; - use netlink_packet_wireguard::{constants::WG_KEY_LEN, nlas::WgDeviceAttrs}; - use rosenpass::{ - app_server::{AppServer, BrokerPeer}, - config::Verbosity, - protocol::{ - basic_types::{SPk, SSk, SymKey}, - osk_domain_separator::OskDomainSeparator, - }, - }; - use rosenpass_secret_memory::Secret; - use rosenpass_util::file::{LoadValue as _, LoadValueB64}; - use rosenpass_wireguard_broker::brokers::native_unix::{ - NativeUnixBroker, NativeUnixBrokerConfigBaseBuilder, NativeUnixBrokerConfigBaseBuilderError, - }; + // Setup the WireGuard device + let device = options.dev.as_deref().unwrap_or("rosenpass0"); + let mut device = WireGuardDevice::create_device(device.to_owned()).await?; - let (connection, rtnetlink, _) = rtnetlink::new_connection()?; - tokio::spawn(connection); + // Assign WG secret key & port + device + .set_private_key_and_listen_addr(&wgsk, options.listen.map(|ip| ip.port() + 1)) + .await?; + std::mem::drop(wgsk); - let link_name = options.dev.clone().unwrap_or("rosenpass0".to_string()); - let link_index = netlink::link_create_and_up(&rtnetlink, link_name.clone()).await?; - - // Set up a list of (initiallc empty) cleanup handlers that are to be run if - // ctrl-c is hit or generally a `SIGINT` signal is received and always in the end. - let cleanup_handlers = CleanupHandlers::new(); - let final_cleanup_handlers = (&cleanup_handlers).clone(); - - cleanup_handlers - .enqueue(Box::pin(async move { - netlink::link_cleanup_standalone(link_index).await - })) - .await; - - ctrlc_async::set_async_handler(async move { - final_cleanup_handlers - .run() - .await - .expect("Failed to clean up"); - })?; - - // Run `ip address add <ip> dev <dev>` and enqueue `ip address del <ip> dev <dev>` as a cleanup.
- if let Some(ip) = options.ip { - let dev = options.dev.clone().unwrap_or("rosenpass0".to_string()); - Command::new("ip") - .arg("address") - .arg("add") - .arg(ip.clone()) - .arg("dev") - .arg(dev.clone()) - .status() - .expect("failed to configure ip"); - cleanup_handlers - .enqueue(Box::pin(async move { - Command::new("ip") - .arg("address") - .arg("del") - .arg(ip) - .arg("dev") - .arg(dev) - .status() - .expect("failed to remove ip"); - Ok(()) - })) - .await; + // Assign the public IP address for the interface + if let Some(ref ip) = options.ip { + device.add_ip_address(ip).await?; } - // Deploy the classic wireguard private key. - let (connection, mut genetlink, _) = genetlink::new_connection()?; - tokio::spawn(connection); - - let wgsk_path = options.private_keys_dir.join("wgsk"); - - let wgsk = Secret::::load_b64::(wgsk_path)?; - - let mut attr: Vec = Vec::with_capacity(2); - attr.push(WgDeviceAttrs::PrivateKey(*wgsk.secret())); - - if let Some(listen) = options.listen { - if listen.port() == u16::MAX { - return Err(anyhow!("You may not use {} as the listen port.", u16::MAX)); - } - - attr.push(WgDeviceAttrs::ListenPort(listen.port() + 1)); - } - - netlink::wg_set(&mut genetlink, link_index, attr).await?; - - // set up the rosenpass AppServer - let pqsk = options.private_keys_dir.join("pqsk"); - let pqpk = options.private_keys_dir.join("pqpk"); - - let sk = SSk::load(&pqsk)?; - let pk = SPk::load(&pqpk)?; - let mut srv = Box::new(AppServer::new( - Some((sk, pk)), - if let Some(listen) = options.listen { - vec![listen] - } else { - Vec::with_capacity(0) - }, - if options.verbose { - Verbosity::Verbose - } else { - Verbosity::Quiet + Some((rpsk, rppk)), + Vec::from_iter(options.listen), + match options.verbose { + true => Verbosity::Verbose, + false => Verbosity::Quiet, }, None, )?); let broker_store_ptr = srv.register_broker(Box::new(NativeUnixBroker::new()))?; - fn cfg_err_map(e: NativeUnixBrokerConfigBaseBuilderError) -> anyhow::Error { - anyhow::Error::msg(format!("NativeUnixBrokerConfigBaseBuilderError: {:?}", e)) - } - // Configure everything per peer. for peer in options.peers { - let wgpk = peer.public_keys_dir.join("wgpk"); + // TODO: Some of this is sync but should be async + let wgpk = peer + .public_keys_dir + .join("wgpk") + .apply(tokio::fs::read_to_string) + .await?; let pqpk = peer.public_keys_dir.join("pqpk"); let psk = peer.public_keys_dir.join("psk"); + let (pqpk, psk) = spawn_blocking(move || { + let pqpk = SPk::load(pqpk)?; + let psk = psk + .exists() + .then(|| SymKey::load_b64::(psk)) + .transpose()?; + anyhow::Ok((pqpk, psk)) + }) + .await??; let mut extra_params: Vec = Vec::with_capacity(6); if let Some(endpoint) = peer.endpoint { @@ -342,11 +504,11 @@ pub async fn exchange(options: ExchangeOptions) -> Result<()> { } let peer_cfg = NativeUnixBrokerConfigBaseBuilder::default() - .peer_id_b64(&fs::read_to_string(wgpk)?)? - .interface(link_name.clone()) + .peer_id_b64(&wgpk)? + .interface(device.name().to_owned()) .extra_params_ser(&extra_params)? 
.build() - .map_err(cfg_err_map)?; + .with_context(|| format!("Could not configure broker to supply keys from Rosenpass to WireGuard for peer {wgpk}."))?; let broker_peer = Some(BrokerPeer::new( broker_store_ptr.clone(), @@ -354,13 +516,8 @@ pub async fn exchange(options: ExchangeOptions) -> Result<()> { )); srv.add_peer( - if psk.exists() { - Some(SymKey::load_b64::(psk)) - } else { - None - } - .transpose()?, - SPk::load(&pqpk)?, + psk, + pqpk, None, broker_peer, peer.endpoint.map(|x| x.to_string()), @@ -372,47 +529,13 @@ pub async fn exchange(options: ExchangeOptions) -> Result<()> { // the cleanup as `ip route del `. if let Some(allowed_ips) = peer.allowed_ips { Command::new("ip") - .arg("route") - .arg("replace") - .arg(allowed_ips.clone()) - .arg("dev") - .arg(options.dev.clone().unwrap_or("rosenpass0".to_string())) + .args(["route", "replace", &allowed_ips, "dev", device.name()]) .status() - .expect("failed to configure route"); - cleanup_handlers - .enqueue(Box::pin(async move { - Command::new("ip") - .arg("route") - .arg("del") - .arg(allowed_ips) - .status() - .expect("failed to remove ip"); - Ok(()) - })) - .await; + .await + .with_context(|| format!("Could not configure routes for peer {wgpk}"))?; } } - let out = srv.event_loop(); - - netlink::link_cleanup(&rtnetlink, link_index).await?; - - match out { - Ok(_) => Ok(()), - Err(e) => { - // Check if the returned error is actually EINTR, in which case, the run actually - // succeeded. - let is_ok = if let Some(e) = e.root_cause().downcast_ref::() { - matches!(e.kind(), std::io::ErrorKind::Interrupted) - } else { - false - }; - - if is_ok { - Ok(()) - } else { - Err(e) - } - } - } + log::info!("Starting to perform rosenpass key exchanges!"); + spawn_blocking(move || srv.event_loop()).await? } diff --git a/rp/src/main.rs b/rp/src/main.rs index 1b9f696..40c6783 100644 --- a/rp/src/main.rs +++ b/rp/src/main.rs @@ -1,21 +1,28 @@ use std::{fs, process::exit}; -use cli::{Cli, Command}; -use exchange::exchange; -use key::{genkey, pubkey}; +use rosenpass_util::tokio::janitor::ensure_janitor; + use rosenpass_secret_memory::policy; +use crate::cli::{Cli, Command}; +use crate::exchange::exchange; +use crate::key::{genkey, pubkey}; + mod cli; mod exchange; mod key; #[tokio::main] -async fn main() { +async fn main() -> anyhow::Result<()> { #[cfg(feature = "experiment_memfd_secret")] policy::secret_policy_try_use_memfd_secrets(); #[cfg(not(feature = "experiment_memfd_secret"))] policy::secret_policy_use_only_malloc_secrets(); + ensure_janitor(async move { main_impl().await }).await +} + +async fn main_impl() -> anyhow::Result<()> { let cli = match Cli::parse(std::env::args().peekable()) { Ok(cli) => cli, Err(err) => { @@ -24,9 +31,13 @@ async fn main() { } }; + // init logging + // TODO: Taken from rosenpass; we should deduplicate the code. 
+ env_logger::Builder::from_default_env().init(); // sets log level filter from environment (or defaults) + let command = cli.command.unwrap(); - let res = match command { + match command { Command::GenKey { private_keys_dir } => genkey(&private_keys_dir), Command::PubKey { private_keys_dir, @@ -47,13 +58,5 @@ async fn main() { println!("Usage: rp [verbose] genkey|pubkey|exchange [ARGS]..."); Ok(()) } - }; - - match res { - Ok(_) => {} - Err(err) => { - eprintln!("An error occurred: {}", err); - exit(1); - } } } diff --git a/secret-memory/src/secret.rs b/secret-memory/src/secret.rs index 90282ef..ec235ec 100644 --- a/secret-memory/src/secret.rs +++ b/secret-memory/src/secret.rs @@ -379,10 +379,7 @@ impl StoreSecret for Secret { #[cfg(test)] mod test { - use crate::{ - secret_policy_try_use_memfd_secrets, secret_policy_use_only_malloc_secrets, - test_spawn_process_provided_policies, - }; + use crate::{secret_policy_use_only_malloc_secrets, test_spawn_process_provided_policies}; use super::*; use std::{fs, os::unix::fs::PermissionsExt}; diff --git a/supply-chain/config.toml b/supply-chain/config.toml index a261f9a..616f97d 100644 --- a/supply-chain/config.toml +++ b/supply-chain/config.toml @@ -630,7 +630,11 @@ version = "3.2.0" criteria = "safe-to-run" [[exemptions.signal-hook]] -version = "0.3.17" +version = "0.3.18" +criteria = "safe-to-deploy" + +[[exemptions.signal-hook-mio]] +version = "0.2.4" criteria = "safe-to-deploy" [[exemptions.signal-hook-registry]] diff --git a/util/Cargo.toml b/util/Cargo.toml index bdd1ddf..df09eb2 100644 --- a/util/Cargo.toml +++ b/util/Cargo.toml @@ -25,7 +25,15 @@ mio = { workspace = true } tempfile = { workspace = true } uds = { workspace = true, optional = true, features = ["mio_1xx"] } libcrux-test-utils = { workspace = true, optional = true } +tokio = { workspace = true, optional = true, features = [ + "macros", + "rt-multi-thread", + "sync", + "time", +] } +log = { workspace = true } [features] experiment_file_descriptor_passing = ["uds"] trace_bench = ["dep:libcrux-test-utils"] +tokio = ["dep:tokio"] diff --git a/util/src/fmt/debug.rs b/util/src/fmt/debug.rs new file mode 100644 index 0000000..839b159 --- /dev/null +++ b/util/src/fmt/debug.rs @@ -0,0 +1,82 @@ +//! 
Helpers for string formatting with the debug formatter; extensions for [std::fmt::Debug] + +use std::any::type_name; +use std::borrow::{Borrow, BorrowMut}; +use std::ops::{Deref, DerefMut}; + +/// Debug formatter which just prints the type name; +/// used to wrap values which do not support the Debug +/// trait themselves +/// +/// # Examples +/// +/// ```rust +/// use rosenpass_util::fmt::debug::NullDebug; +/// +/// // Does not implement debug +/// struct NoDebug; +/// +/// #[derive(Debug)] +/// struct ShouldSupportDebug { +/// #[allow(dead_code)] +/// no_debug: NullDebug<NoDebug>, +/// } +/// +/// let val = ShouldSupportDebug { +/// no_debug: NullDebug(NoDebug), +/// }; +/// ``` +pub struct NullDebug<T>(pub T); + +impl<T> std::fmt::Debug for NullDebug<T> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("NullDebug<")?; + f.write_str(type_name::<T>())?; + f.write_str(">")?; + Ok(()) + } +} + +impl<T> From<T> for NullDebug<T> { + fn from(value: T) -> Self { + NullDebug(value) + } +} + +impl<T> Deref for NullDebug<T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.0.borrow() + } +} + +impl<T> DerefMut for NullDebug<T> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.0.borrow_mut() + } +} + +impl<T> Borrow<T> for NullDebug<T> { + fn borrow(&self) -> &T { + self.deref() + } +} + +impl<T> BorrowMut<T> for NullDebug<T> { + fn borrow_mut(&mut self) -> &mut T { + self.deref_mut() + } +} + +impl<T> AsRef<T> for NullDebug<T> { + fn as_ref(&self) -> &T { + self.deref() + } +} + +impl<T> AsMut<T> for NullDebug<T> { + fn as_mut(&mut self) -> &mut T { + self.deref_mut() + } +} diff --git a/util/src/fmt/mod.rs new file mode 100644 index 0000000..0d35b06 --- /dev/null +++ b/util/src/fmt/mod.rs @@ -0,0 +1,3 @@ +//! Helpers for string formatting; extensions for [std::fmt] + +pub mod debug; diff --git a/util/src/lib.rs index 7949a3b..69d7698 100644 --- a/util/src/lib.rs +++ b/util/src/lib.rs @@ -14,6 +14,7 @@ pub mod controlflow; pub mod fd; /// File system operations and handling. pub mod file; +pub mod fmt; /// Functional programming utilities. pub mod functional; /// Input/output operations. @@ -30,6 +31,8 @@ pub mod option; pub mod result; /// Time and duration utilities. pub mod time; +#[cfg(feature = "tokio")] +pub mod tokio; /// Trace benchmarking utilities #[cfg(feature = "trace_bench")] pub mod trace_bench; diff --git a/util/src/mio/uds_recv_fd.rs index 8cba2a3..31e0f9c 100644 --- a/util/src/mio/uds_recv_fd.rs +++ b/util/src/mio/uds_recv_fd.rs @@ -39,7 +39,7 @@ use crate::fd::{claim_fd_inplace, IntoStdioErr}; /// &io_stream, /// &mut read_fd_buffer, /// ); -//// +/// /// // Simulated reads; the actual operations will depend on the protocol (implementation details) /// let mut recv_buffer = Vec::<u8>::new(); /// let bytes_read = fd_passing_sock.read(&mut recv_buffer[..]).expect("error reading from socket"); diff --git a/util/src/tokio/janitor.rs new file mode 100644 index 0000000..91624e7 --- /dev/null +++ b/util/src/tokio/janitor.rs @@ -0,0 +1,618 @@ +//! Facilities to spawn tasks that will be reliably executed +//! before the current tokio context finishes. +//! +//! Asynchronous applications often need to manage multiple parallel tasks. +//! Tokio supports spawning these tasks with [tokio::task::spawn], but when the +//! tokio event loop exits, all lingering background tasks will be aborted. +//! +//! Tokio supports managing multiple parallel tasks, all of which should exit successfully, through +//!
[tokio::task::JoinSet]. This is a useful and very explicit API. To launch a background job, +//! user code needs to be aware of which JoinSet to use, so this can lead to a JoinSet needing to +//! be handed around in many parts of the application. +//! +//! This level of explicitness avoids bugs, but it can be cumbersome to use and it can introduce a +//! [function coloring](https://morestina.net/1686/rust-async-is-colored) issue; +//! creating a strong distinction between functions which have access +//! to a JoinSet (one color) and those that have not (the other color). Functions with the color +//! that has access to a JoinSet can call those functions that do not need access, but not the +//! other way around. This can make refactoring quite difficult: your refactor needs to use a +//! function that requires a JoinSet? Then have fun spending quite a bit of time recoloring +//! possibly many parts of your code base. +//! +//! This module solves this issue by essentially registering a central [JoinSet] through ambient +//! (semi-global), task-local variables. The mechanism to register this task-local JoinSet is +//! [tokio::task_local]. +//! +//! # Error-handling +//! +//! The janitor accepts daemons/cleanup jobs which may fail with an [anyhow::Error]. +//! When any daemon returns an error, the entire janitor will immediately exit with a failure +//! without awaiting the other registered tasks. +//! +//! The janitor can generally produce errors in three scenarios: +//! +//! - A daemon panics +//! - A daemon returns an error +//! - An internal error +//! +//! When [enter_janitor]/[ensure_janitor] is used to set up a janitor, these functions will always +//! panic in case of a janitor error. **This also means that these functions panic if any daemon +//! returns an error**. +//! +//! You can explicitly handle janitor errors through [try_enter_janitor]/[try_ensure_janitor]. +//! +//! # Examples +//! +#![doc = "```ignore"] +#![doc = include_str!("../../tests/janitor.rs")] +#![doc = "```"] + +use std::any::type_name; +use std::future::Future; + +use anyhow::{bail, Context}; + +use tokio::task::{AbortHandle, JoinError, JoinHandle, JoinSet}; +use tokio::task_local; + +use tokio::sync::mpsc::unbounded_channel as janitor_channel; + +use crate::tokio::local_key::LocalKeyExt; + +/// Type for the message queue from [JanitorClient]/[JanitorSupervisor] to [JanitorAgent]: Receiving side +type JanitorQueueRx = tokio::sync::mpsc::UnboundedReceiver<JanitorTicket>; +/// Type for the message queue from [JanitorClient]/[JanitorSupervisor] to [JanitorAgent]: Sending side +type JanitorQueueTx = tokio::sync::mpsc::UnboundedSender<JanitorTicket>; +/// Type for the message queue from [JanitorClient]/[JanitorSupervisor] to [JanitorAgent]: Sending side, Weak reference +type WeakJanitorQueueTx = tokio::sync::mpsc::WeakUnboundedSender<JanitorTicket>; + +/// Type of the return value for jobs submitted to [spawn_daemon]/[spawn_cleanup_job] +type CleanupJobResult = anyhow::Result<()>; +/// Handle by which we internally refer to cleanup jobs submitted by [spawn_daemon]/[spawn_cleanup_job] +/// to the current [JanitorAgent] +type CleanupJob = JoinHandle<CleanupJobResult>; + +task_local!
+    /// Handle to the current [JanitorAgent]; this is where [ensure_janitor]/[enter_janitor]
+    /// register the newly created janitor
+    static CURRENT_JANITOR: JanitorClient;
+}
+
+/// Messages supported by [JanitorAgent]
+#[derive(Debug)]
+enum JanitorTicket {
+    /// This message transmits a new cleanup job to the [JanitorAgent]
+    CleanupJob(CleanupJob),
+}
+
+/// Represents the background task which actually manages cleanup jobs.
+///
+/// This is what is started by [enter_janitor]/[ensure_janitor]
+/// and what receives the messages sent by [JanitorSupervisor]/[JanitorClient]
+#[derive(Debug)]
+struct JanitorAgent {
+    /// Background tasks currently registered with this agent.
+    ///
+    /// This contains two types of tasks:
+    ///
+    /// 1. Background jobs launched through [enter_janitor]/[ensure_janitor]
+    /// 2. A single task waiting for new [JanitorTicket]s being transmitted from a [JanitorSupervisor]/[JanitorClient]
+    tasks: JoinSet<AgentInternalEvent>,
+    /// Whether this [JanitorAgent] will ever receive new [JanitorTicket]s
+    ///
+    /// Communication between [JanitorAgent] and [JanitorSupervisor]/[JanitorClient] uses a message
+    /// queue (see [JanitorQueueTx]/[JanitorQueueRx]/[WeakJanitorQueueTx]), but you may notice that
+    /// the agent does not actually contain a field storing the message queue.
+    /// Instead, to appease the borrow checker, the message queue is moved into the internal
+    /// background task (see [Self::tasks]) that waits for new [JanitorTicket]s.
+    ///
+    /// Since our state machine still needs to know whether that queue is closed, we maintain this
+    /// flag.
+    ///
+    /// See [AgentInternalEvent::TicketQueueClosed].
+    ticket_queue_closed: bool,
+}
+
+/// These are the return values (events) returned by [JanitorAgent] internal tasks (see
+/// [JanitorAgent::tasks]).
+#[derive(Debug)]
+enum AgentInternalEvent {
+    /// Notifies the [JanitorAgent] state machine that a cleanup job finished successfully
+    ///
+    /// Sent by genuine background tasks registered through [enter_janitor]/[ensure_janitor].
+    CleanupJobSuccessful,
+    /// Notifies the [JanitorAgent] state machine that a cleanup job finished with a tokio
+    /// [JoinError].
+    ///
+    /// Sent by genuine background tasks registered through [enter_janitor]/[ensure_janitor].
+    CleanupJobJoinError(JoinError),
+    /// Notifies the [JanitorAgent] state machine that a cleanup job returned an error.
+    ///
+    /// Sent by genuine background tasks registered through [enter_janitor]/[ensure_janitor].
+    CleanupJobReturnedError(anyhow::Error),
+    /// Notifies the [JanitorAgent] state machine that a new cleanup job was received through the
+    /// ticket queue.
+    ///
+    /// Sent by the background task managing the ticket queue.
+    ReceivedCleanupJob(JanitorQueueRx, CleanupJob),
+    /// Notifies the [JanitorAgent] state machine that the ticket queue has been closed.
+    ///
+    /// Sent by the background task managing the ticket queue.
+    ///
+    /// See [JanitorAgent::ticket_queue_closed].
+    TicketQueueClosed,
+}
+
+impl JanitorAgent {
+    /// Create a new agent. Start with [Self::start].
+    fn new() -> Self {
+        let tasks = JoinSet::new();
+        let ticket_queue_closed = false;
+        Self {
+            tasks,
+            ticket_queue_closed,
+        }
+    }
+
+    /// Main entry point for the [JanitorAgent]. Launches the background task and returns a [JanitorSupervisor]
+    /// which can be used to send tickets to the agent and to wait for agent termination.
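+    ///
+    /// A minimal sketch of the life cycle (illustrative only; the real entry
+    /// points are [enter_janitor]/[ensure_janitor], which drive exactly this
+    /// sequence):
+    ///
+    /// ```ignore
+    /// let supervisor = JanitorAgent::start().await;
+    /// let client = supervisor.get_client();
+    /// // ... run user code with `client` installed as CURRENT_JANITOR ...
+    /// supervisor.terminate_janitor().await?;
+    /// ```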
+    pub async fn start() -> JanitorSupervisor {
+        let (queue_tx, queue_rx) = janitor_channel();
+        let join_handle = tokio::spawn(async move { Self::new().event_loop(queue_rx).await });
+        JanitorSupervisor::new(join_handle, queue_tx)
+    }
+
+    /// Event loop, processing events from the ticket queue and from [Self::tasks]
+    async fn event_loop(&mut self, queue_rx: JanitorQueueRx) -> anyhow::Result<()> {
+        // Seed the internal task list with a single task to receive tickets
+        self.spawn_internal_recv_ticket_task(queue_rx).await;
+
+        // Process all incoming events until handle_one_event indicates there are
+        // no more events to process
+        while self.handle_one_event().await?.is_some() {}
+
+        Ok(())
+    }
+
+    /// Process events from [Self::tasks] (and by proxy from the ticket queue)
+    ///
+    /// This is the agent's main state machine.
+    async fn handle_one_event(&mut self) -> anyhow::Result<Option<()>> {
+        use AgentInternalEvent as E;
+        match (self.tasks.join_next().await, self.ticket_queue_closed) {
+            // Normal, successful operation
+
+            // Cleanup job exited successfully, no action necessary
+            (Some(Ok(E::CleanupJobSuccessful)), _) => Ok(Some(())),
+
+            // New cleanup job scheduled, add to task list and wait for another task
+            (Some(Ok(E::ReceivedCleanupJob(queue_rx, job))), _) => {
+                self.spawn_internal_recv_ticket_task(queue_rx).await;
+                self.spawn_internal_cleanup_task(job).await;
+                Ok(Some(()))
+            }
+
+            // Ticket queue is closed; now we are just waiting for the remaining cleanup jobs
+            // to terminate
+            (Some(Ok(E::TicketQueueClosed)), _) => {
+                self.ticket_queue_closed = true;
+                Ok(Some(()))
+            }
+
+            // No more tasks in the task manager and the ticket queue is already closed.
+            // This just means we are done and can finally terminate the janitor agent
+            (Option::None, true) => Ok(None),
+
+            // Error handling
+
+            // User callback errors
+
+            // Some cleanup job returned an error as a result
+            (Some(Ok(E::CleanupJobReturnedError(err))), _) => Err(err).with_context(|| {
+                format!("Error in cleanup job handled by {}", type_name::<Self>())
+            }),
+
+            // JoinError produced by the user task: the user task was cancelled
+            (Some(Ok(E::CleanupJobJoinError(err))), _) if err.is_cancelled() => {
+                Err(err).with_context(|| {
+                    format!(
+                        "Error in cleanup job handled by {me}; the cleanup task was cancelled. \
+                         This should not happen and likely indicates a developer error in {me}.",
+                        me = type_name::<Self>()
+                    )
+                })
+            }
+
+            // JoinError produced by the user task: the user task panicked
+            (Some(Ok(E::CleanupJobJoinError(err))), _) => Err(err).with_context(|| {
+                format!(
+                    "Error in cleanup job handled by {}; looks like the cleanup task panicked.",
+                    type_name::<Self>()
+                )
+            }),
+
+            // Internal errors: internal task error
+
+            // JoinError produced by JoinSet::join_next(): the internal task was cancelled
+            (Some(Err(err)), _) if err.is_cancelled() => Err(err).with_context(|| {
+                format!(
+                    "Internal error in {me}; internal async task was cancelled. \
+                     This is probably a developer error in {me}.",
+                    me = type_name::<Self>()
+                )
+            }),
+
+            // JoinError produced by JoinSet::join_next(): the internal task panicked
+            (Some(Err(err)), _) => Err(err).with_context(|| {
+                format!(
+                    "Internal error in {me}; internal async task panicked. \
+                     This is probably a developer error in {me}.",
+                    me = type_name::<Self>()
+                )
+            }),
+
+            // Internal errors: state machine failure
+
+            // No tasks left, but ticket queue was not drained
+            (Option::None, false) => bail!(
+                "Internal error in {me}::handle_one_event(); \
+                 there are no more internal tasks active, but the ticket queue was not drained. \
+                 The {me}::handle_one_event() code is deliberately designed to never leave the internal task set empty; \
+                 instead, there should always be one task to receive new cleanup jobs from the task queue unless the task \
+                 queue has been closed. \
+                 This is probably a developer error.",
+                me = type_name::<Self>()
+            ),
+        }
+    }
+
+    /// Used by [Self::event_loop] and [Self::handle_one_event] to start the internal
+    /// task waiting for tickets on the ticket queue.
+    async fn spawn_internal_recv_ticket_task(
+        &mut self,
+        mut queue_rx: JanitorQueueRx,
+    ) -> AbortHandle {
+        self.tasks.spawn(async {
+            use AgentInternalEvent as E;
+            use JanitorTicket as T;
+
+            let ticket = queue_rx.recv().await;
+            match ticket {
+                Some(T::CleanupJob(job)) => E::ReceivedCleanupJob(queue_rx, job),
+                Option::None => E::TicketQueueClosed,
+            }
+        })
+    }
+
+    /// Used by [Self::event_loop] and [Self::handle_one_event] to register
+    /// background daemons/cleanup jobs submitted via [JanitorTicket]
+    async fn spawn_internal_cleanup_task(&mut self, job: CleanupJob) -> AbortHandle {
+        self.tasks.spawn(async {
+            use AgentInternalEvent as E;
+            match job.await {
+                Ok(Ok(())) => E::CleanupJobSuccessful,
+                Ok(Err(e)) => E::CleanupJobReturnedError(e),
+                Err(e) => E::CleanupJobJoinError(e),
+            }
+        })
+    }
+}
+
+/// Client for [JanitorAgent]. Allows for [JanitorTicket]s (background jobs)
+/// to be transmitted to the current [JanitorAgent].
+///
+/// This is stored in [CURRENT_JANITOR] as a task-local variable.
+#[derive(Debug)]
+struct JanitorClient {
+    /// Queue we can use to send messages to the current janitor
+    queue_tx: WeakJanitorQueueTx,
+}
+
+impl JanitorClient {
+    /// Create a new client. Use through [JanitorSupervisor::get_client]
+    fn new(queue_tx: WeakJanitorQueueTx) -> Self {
+        Self { queue_tx }
+    }
+
+    /// Has the associated [JanitorAgent] shut down?
+    pub fn is_closed(&self) -> bool {
+        // If the weak sender can no longer be upgraded, all strong senders are
+        // gone and the agent is shutting down, so the channel counts as closed
+        self.queue_tx
+            .upgrade()
+            .map(|channel| channel.is_closed())
+            .unwrap_or(true)
+    }
+
+    /// Spawn a new cleanup job/daemon with the associated [JanitorAgent].
+    ///
+    /// Used internally by [spawn_daemon]/[spawn_cleanup_job].
+    pub fn spawn_cleanup_task<F>(&self, future: F) -> Result<(), TrySpawnCleanupJobError>
+    where
+        F: Future<Output = anyhow::Result<()>> + Send + 'static,
+    {
+        let background_task = tokio::spawn(future);
+        self.queue_tx
+            .upgrade()
+            .ok_or(TrySpawnCleanupJobError::ActiveJanitorTerminating)?
+            .send(JanitorTicket::CleanupJob(background_task))
+            .map_err(|_| TrySpawnCleanupJobError::ActiveJanitorTerminating)
+    }
+}
+
+/// Supervisor for [JanitorAgent]. Allows waiting for [JanitorAgent] termination as well as creating
+/// [JanitorClient]s, which in turn can be used to submit background daemons/termination jobs
+/// to the agent.
+#[derive(Debug)]
+struct JanitorSupervisor {
+    /// Represents the tokio task associated with the [JanitorAgent].
+    ///
+    /// We use this to wait for [JanitorAgent] termination in [enter_janitor]/[ensure_janitor]
+    agent_join_handle: CleanupJob,
+    /// Queue we can use to send messages to the current janitor
+    queue_tx: JanitorQueueTx,
+}
+
+impl JanitorSupervisor {
+    /// Create a new janitor supervisor. Use through [JanitorAgent::start].
+    pub fn new(agent_join_handle: CleanupJob, queue_tx: JanitorQueueTx) -> Self {
+        Self {
+            agent_join_handle,
+            queue_tx,
+        }
+    }
+
+    /// Create a [JanitorClient] for submitting background daemons/cleanup jobs
+    pub fn get_client(&self) -> JanitorClient {
+        JanitorClient::new(self.queue_tx.clone().downgrade())
+    }
+
+    /// Wait for [JanitorAgent] termination
+    pub async fn terminate_janitor(self) -> anyhow::Result<()> {
+        std::mem::drop(self.queue_tx);
+        self.agent_join_handle.await?
+    }
+}
+
+/// Return value of [try_enter_janitor].
+#[derive(Debug)]
+pub struct EnterJanitorResult<T, E> {
+    /// The result produced by the janitor itself.
+    ///
+    /// This may contain an error if one of the background daemons/cleanup tasks returned an error,
+    /// panicked, or in case there is an internal error in the janitor.
+    pub janitor_result: anyhow::Result<()>,
+    /// Contains the result of the future passed to [try_enter_janitor].
+    pub callee_result: Result<T, E>,
+}
+
+impl<T, E> EnterJanitorResult<T, E> {
+    /// Create a new result from its components
+    pub fn new(janitor_result: anyhow::Result<()>, callee_result: Result<T, E>) -> Self {
+        Self {
+            janitor_result,
+            callee_result,
+        }
+    }
+
+    /// Turn this named type into a tuple
+    pub fn into_tuple(self) -> (anyhow::Result<()>, Result<T, E>) {
+        (self.janitor_result, self.callee_result)
+    }
+
+    /// Panic if [Self::janitor_result] contains an error; returning [Self::callee_result]
+    /// otherwise.
+    ///
+    /// If this panics and both [Self::janitor_result] and [Self::callee_result] contain an error,
+    /// this will print both errors.
+    pub fn unwrap_janitor_result(self) -> Result<T, E>
+    where
+        E: std::fmt::Debug,
+    {
+        let me: EnsureJanitorResult<T, E> = self.into();
+        me.unwrap_janitor_result()
+    }
+
+    /// Panic if [Self::janitor_result] or [Self::callee_result] contain an error,
+    /// returning the Ok value of [Self::callee_result].
+    ///
+    /// If this panics and both [Self::janitor_result] and [Self::callee_result] contain an error,
+    /// this will print both errors.
+    pub fn unwrap(self) -> T
+    where
+        E: std::fmt::Debug,
+    {
+        let me: EnsureJanitorResult<T, E> = self.into();
+        me.unwrap()
+    }
+}
+
+/// Return value of [try_ensure_janitor]. The only difference compared to [EnterJanitorResult]
+/// is that [Self::janitor_result] contains None in case an ambient janitor had already existed.
+#[derive(Debug)]
+pub struct EnsureJanitorResult<T, E> {
+    /// See [EnterJanitorResult::janitor_result]
+    ///
+    /// This is:
+    ///
+    /// - `None` if a pre-existing ambient janitor was used
+    /// - `Some(Ok(()))` if a new janitor had to be created and it exited successfully
+    /// - `Some(Err(...))` if a new janitor had to be created and it exited with an error
+    pub janitor_result: Option<anyhow::Result<()>>,
+    /// See [EnterJanitorResult::callee_result]
+    pub callee_result: Result<T, E>,
+}
+
+impl<T, E> EnsureJanitorResult<T, E> {
+    /// See [EnterJanitorResult::new]
+    pub fn new(janitor_result: Option<anyhow::Result<()>>, callee_result: Result<T, E>) -> Self {
+        Self {
+            janitor_result,
+            callee_result,
+        }
+    }
+
+    /// Sets up an [EnsureJanitorResult] with [EnsureJanitorResult::janitor_result] = None.
+    pub fn from_callee_result(callee_result: Result<T, E>) -> Self {
+        Self::new(None, callee_result)
+    }
+
+    /// Turn this named type into a tuple
+    pub fn into_tuple(self) -> (Option<anyhow::Result<()>>, Result<T, E>) {
+        (self.janitor_result, self.callee_result)
+    }
+
+    /// See [EnterJanitorResult::unwrap_janitor_result]
+    ///
+    /// If [Self::janitor_result] is None, this won't panic.
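+    ///
+    /// A minimal sketch (hypothetical values, not taken from the test suite):
+    ///
+    /// ```ignore
+    /// // janitor_result is None when a pre-existing ambient janitor was used,
+    /// // so unwrap_janitor_result simply yields the callee result
+    /// let res = EnsureJanitorResult::<u32, anyhow::Error>::from_callee_result(Ok(42));
+    /// assert_eq!(res.unwrap_janitor_result().unwrap(), 42);
+    /// ```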
+    pub fn unwrap_janitor_result(self) -> Result<T, E>
+    where
+        E: std::fmt::Debug,
+    {
+        match self.into_tuple() {
+            (Some(Ok(())) | None, res) => res,
+            (Some(Err(err)), Ok(_)) => panic!(
+                "Callee in enter_janitor()/ensure_janitor() was successful, \
+                 but the janitor or some of its daemons failed: {err:?}"
+            ),
+            (Some(Err(jerr)), Err(cerr)) => panic!(
+                "Both the callee and the janitor or \
+                 some of its daemons failed in enter_janitor()/ensure_janitor():\n\
+                 \n\
+                 Janitor/Daemon error: {jerr:?}\n\
+                 \n\
+                 Callee error: {cerr:?}"
+            ),
+        }
+    }
+
+    /// See [EnterJanitorResult::unwrap]
+    ///
+    /// If [Self::janitor_result] is None, this is not considered a failure.
+    pub fn unwrap(self) -> T
+    where
+        E: std::fmt::Debug,
+    {
+        match self.unwrap_janitor_result() {
+            Ok(val) => val,
+            Err(err) => panic!(
+                "The janitor and its daemons in enter_janitor()/ensure_janitor() were successful, \
+                 but the callee itself failed: {err:?}"
+            ),
+        }
+    }
+}
+
+impl<T, E> From<EnterJanitorResult<T, E>> for EnsureJanitorResult<T, E> {
+    fn from(val: EnterJanitorResult<T, E>) -> Self {
+        EnsureJanitorResult::new(Some(val.janitor_result), val.callee_result)
+    }
+}
+
+/// Non-panicking version of [enter_janitor].
+pub async fn try_enter_janitor<T, E, F>(future: F) -> EnterJanitorResult<T, E>
+where
+    T: 'static,
+    F: Future<Output = Result<T, E>> + 'static,
+{
+    let janitor_handle = JanitorAgent::start().await;
+    let callee_result = CURRENT_JANITOR
+        .scope(janitor_handle.get_client(), future)
+        .await;
+    let janitor_result = janitor_handle.terminate_janitor().await;
+    EnterJanitorResult::new(janitor_result, callee_result)
+}
+
+/// Non-panicking version of [ensure_janitor]
+pub async fn try_ensure_janitor<T, E, F>(future: F) -> EnsureJanitorResult<T, E>
+where
+    T: 'static,
+    F: Future<Output = Result<T, E>> + 'static,
+{
+    match CURRENT_JANITOR.is_set() {
+        true => EnsureJanitorResult::from_callee_result(future.await),
+        false => try_enter_janitor(future).await.into(),
+    }
+}
+
+/// Register a janitor that can be used to register background daemons/cleanup jobs **only within
+/// the future passed to this function**.
+///
+/// The function will wait for both the given future and all background jobs registered with the
+/// janitor to terminate.
+///
+/// For a version that does not panic, see [try_enter_janitor].
+pub async fn enter_janitor<T, E, F>(future: F) -> Result<T, E>
+where
+    T: 'static,
+    E: std::fmt::Debug,
+    F: Future<Output = Result<T, E>> + 'static,
+{
+    try_enter_janitor(future).await.unwrap_janitor_result()
+}
+
+/// Variant of [enter_janitor] that will first check if a janitor already exists.
+/// A new janitor is only set up if no janitor has been previously registered.
+pub async fn ensure_janitor<T, E, F>(future: F) -> Result<T, E>
+where
+    T: 'static,
+    E: std::fmt::Debug,
+    F: Future<Output = Result<T, E>> + 'static,
+{
+    try_ensure_janitor(future).await.unwrap_janitor_result()
+}
+
+/// Error returned by [try_spawn_cleanup_job]
+#[derive(thiserror::Error, Debug)]
+pub enum TrySpawnCleanupJobError {
+    /// No active janitor exists
+    #[error("No janitor registered. Did the developer forget to call enter_janitor(…) or ensure_janitor(…)?")]
+    NoActiveJanitor,
+    /// The currently active janitor is in the process of terminating
+    #[error("There is a registered janitor, but it is currently in the process of terminating and won't accept new tasks.")]
+    ActiveJanitorTerminating,
+}
+
+/// Check whether a janitor has been set up with [enter_janitor]/[ensure_janitor]
+pub fn has_active_janitor() -> bool {
+    CURRENT_JANITOR
+        .try_with(|client| !client.is_closed())
+        .unwrap_or(false)
+}
+
+/// Non-panicking variant of [spawn_cleanup_job].
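+///
+/// Returns an error instead of panicking when no janitor is registered or when the
+/// active janitor is already terminating. A minimal sketch (the `log::warn!` call is
+/// just one illustrative way to handle the error):
+///
+/// ```ignore
+/// if let Err(e) = try_spawn_cleanup_job(async { Ok(()) }) {
+///     log::warn!("could not register cleanup job: {e}");
+/// }
+/// ```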
+///
+/// This function is available under two names; see [spawn_cleanup_job] for details about this:
+///
+/// 1. [try_spawn_cleanup_job]
+/// 2. [try_spawn_daemon]
+pub fn try_spawn_cleanup_job<F>(future: F) -> Result<(), TrySpawnCleanupJobError>
+where
+    F: Future<Output = anyhow::Result<()>> + Send + 'static,
+{
+    CURRENT_JANITOR
+        .try_with(|client| client.spawn_cleanup_task(future))
+        .map_err(|_| TrySpawnCleanupJobError::NoActiveJanitor)??;
+    Ok(())
+}
+
+/// Register a cleanup job or a daemon with the current janitor registered through
+/// [enter_janitor]/[ensure_janitor].
+///
+/// This function is available under two names:
+///
+/// 1. [spawn_cleanup_job]
+/// 2. [spawn_daemon]
+///
+/// The first name should be used in destructors and to spawn cleanup actions which immediately
+/// begin their task.
+///
+/// The second name should be used for any other tasks, e.g. when the janitor setup is used to
+/// manage multiple parallel jobs, all of which must be waited for.
+pub fn spawn_cleanup_job<F>(future: F)
+where
+    F: Future<Output = anyhow::Result<()>> + Send + 'static,
+{
+    if let Err(e) = try_spawn_cleanup_job(future) {
+        panic!("Could not spawn cleanup job/daemon: {e:?}");
+    }
+}
+
+pub use spawn_cleanup_job as spawn_daemon;
+pub use try_spawn_cleanup_job as try_spawn_daemon;
diff --git a/util/src/tokio/local_key.rs b/util/src/tokio/local_key.rs
new file mode 100644
index 0000000..e1c11ae
--- /dev/null
+++ b/util/src/tokio/local_key.rs
@@ -0,0 +1,13 @@
+//! Helpers for [tokio::task::LocalKey]
+
+/// Extension trait for [tokio::task::LocalKey]
+pub trait LocalKeyExt {
+    /// Check whether a tokio LocalKey is set
+    fn is_set(&'static self) -> bool;
+}
+
+impl<T: 'static> LocalKeyExt for tokio::task::LocalKey<T> {
+    fn is_set(&'static self) -> bool {
+        self.try_with(|_| ()).is_ok()
+    }
+}
diff --git a/util/src/tokio/mod.rs b/util/src/tokio/mod.rs
new file mode 100644
index 0000000..d257293
--- /dev/null
+++ b/util/src/tokio/mod.rs
@@ -0,0 +1,4 @@
+//! Tokio-related utilities
+
+pub mod janitor;
+pub mod local_key;
diff --git a/util/tests/janitor.rs b/util/tests/janitor.rs
new file mode 100644
index 0000000..0661203
--- /dev/null
+++ b/util/tests/janitor.rs
@@ -0,0 +1,85 @@
+#![cfg(feature = "tokio")]
+
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::Arc;
+use std::time::Duration;
+
+use tokio::time::sleep;
+
+use rosenpass_util::tokio::janitor::{enter_janitor, spawn_cleanup_job, try_spawn_daemon};
+
+#[tokio::test]
+async fn janitor_demo() -> anyhow::Result<()> {
+    let count = Arc::new(AtomicUsize::new(0));
+
+    // Make sure the program has access to an ambient janitor
+    {
+        let count = count.clone();
+        enter_janitor(async move {
+            let _drop_guard = AsyncDropDemo::new(count.clone()).await;
+
+            // Start a background job
+            {
+                let count = count.clone();
+                try_spawn_daemon(async move {
+                    for _ in 0..17 {
+                        count.fetch_add(1, Ordering::Relaxed);
+                        sleep(Duration::from_micros(200)).await;
+                    }
+                    Ok(())
+                })?;
+            }
+
+            // Start another
+            {
+                let count = count.clone();
+                try_spawn_daemon(async move {
+                    for _ in 0..6 {
+                        count.fetch_add(100, Ordering::Relaxed);
+                        sleep(Duration::from_micros(800)).await;
+                    }
+                    Ok(())
+                })?;
+            }
+
+            // Note how this function just starts a couple of background jobs, but exits immediately
+
+            anyhow::Ok(())
+        })
+    }
+    .await;
+
+    // At this point, all background jobs have finished; now we can check the result of all our
+    // additions
+    assert_eq!(count.load(Ordering::Acquire), 41617);
+
+    Ok(())
+}
+
+/// Demo of how the janitor can be used to implement async destructors
+struct AsyncDropDemo {
+    count: Arc<AtomicUsize>,
+}
+
+impl AsyncDropDemo {
+    async fn new(count: Arc<AtomicUsize>) -> Self {
+        count.fetch_add(1000, Ordering::Relaxed);
+        sleep(Duration::from_micros(50)).await;
+        AsyncDropDemo { count }
+    }
+}
+
+impl Drop for AsyncDropDemo {
+    fn drop(&mut self) {
+        let count = self.count.clone();
+        // This necessarily uses the panicking variant;
+        // we use spawn_cleanup_job because this makes more semantic sense in this context
+        spawn_cleanup_job(async move {
+            for _ in 0..4 {
+                count.fetch_add(10000, Ordering::Relaxed);
+                sleep(Duration::from_micros(800)).await;
+            }
+            Ok(())
+        })
+    }
+}
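+
+// A minimal usage sketch (comments only, not part of the test above): the typical
+// pattern is to wrap the async main body in `enter_janitor` and register daemons
+// from anywhere below it, without threading a JoinSet through the call graph:
+//
+//     enter_janitor(async {
+//         spawn_daemon(async { /* background work */ Ok(()) });
+//         anyhow::Ok(())
+//     })
+//     .await?;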