mirror of
https://github.com/rosenpass/rosenpass.git
synced 2025-12-12 07:40:30 -08:00
Fix signal handling in rp and rosenpass (#685)
This commit is contained in:
79
Cargo.lock
generated
79
Cargo.lock
generated
@@ -2070,6 +2070,7 @@ dependencies = [
|
|||||||
"serde",
|
"serde",
|
||||||
"serial_test",
|
"serial_test",
|
||||||
"signal-hook",
|
"signal-hook",
|
||||||
|
"signal-hook-mio",
|
||||||
"stacker",
|
"stacker",
|
||||||
"static_assertions",
|
"static_assertions",
|
||||||
"tempfile",
|
"tempfile",
|
||||||
@@ -2152,6 +2153,38 @@ dependencies = [
|
|||||||
"rosenpass-util",
|
"rosenpass-util",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rosenpass-rp"
|
||||||
|
version = "0.2.1"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"base64ct",
|
||||||
|
"ctrlc-async",
|
||||||
|
"env_logger",
|
||||||
|
"futures",
|
||||||
|
"futures-util",
|
||||||
|
"genetlink",
|
||||||
|
"libc",
|
||||||
|
"log",
|
||||||
|
"netlink-packet-core",
|
||||||
|
"netlink-packet-generic",
|
||||||
|
"netlink-packet-wireguard",
|
||||||
|
"rosenpass",
|
||||||
|
"rosenpass-cipher-traits",
|
||||||
|
"rosenpass-ciphers",
|
||||||
|
"rosenpass-secret-memory",
|
||||||
|
"rosenpass-util",
|
||||||
|
"rosenpass-wireguard-broker",
|
||||||
|
"rtnetlink",
|
||||||
|
"serde",
|
||||||
|
"stacker",
|
||||||
|
"tempfile",
|
||||||
|
"tokio",
|
||||||
|
"toml",
|
||||||
|
"x25519-dalek",
|
||||||
|
"zeroize",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rosenpass-secret-memory"
|
name = "rosenpass-secret-memory"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
@@ -2184,11 +2217,13 @@ dependencies = [
|
|||||||
"anyhow",
|
"anyhow",
|
||||||
"base64ct",
|
"base64ct",
|
||||||
"libcrux-test-utils",
|
"libcrux-test-utils",
|
||||||
|
"log",
|
||||||
"mio",
|
"mio",
|
||||||
"rustix",
|
"rustix",
|
||||||
"static_assertions",
|
"static_assertions",
|
||||||
"tempfile",
|
"tempfile",
|
||||||
"thiserror 1.0.69",
|
"thiserror 1.0.69",
|
||||||
|
"tokio",
|
||||||
"typenum",
|
"typenum",
|
||||||
"uds",
|
"uds",
|
||||||
"zerocopy 0.7.35",
|
"zerocopy 0.7.35",
|
||||||
@@ -2219,35 +2254,6 @@ dependencies = [
|
|||||||
"zerocopy 0.7.35",
|
"zerocopy 0.7.35",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "rp"
|
|
||||||
version = "0.2.1"
|
|
||||||
dependencies = [
|
|
||||||
"anyhow",
|
|
||||||
"base64ct",
|
|
||||||
"ctrlc-async",
|
|
||||||
"futures",
|
|
||||||
"futures-util",
|
|
||||||
"genetlink",
|
|
||||||
"netlink-packet-core",
|
|
||||||
"netlink-packet-generic",
|
|
||||||
"netlink-packet-wireguard",
|
|
||||||
"rosenpass",
|
|
||||||
"rosenpass-cipher-traits",
|
|
||||||
"rosenpass-ciphers",
|
|
||||||
"rosenpass-secret-memory",
|
|
||||||
"rosenpass-util",
|
|
||||||
"rosenpass-wireguard-broker",
|
|
||||||
"rtnetlink",
|
|
||||||
"serde",
|
|
||||||
"stacker",
|
|
||||||
"tempfile",
|
|
||||||
"tokio",
|
|
||||||
"toml",
|
|
||||||
"x25519-dalek",
|
|
||||||
"zeroize",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rtnetlink"
|
name = "rtnetlink"
|
||||||
version = "0.14.1"
|
version = "0.14.1"
|
||||||
@@ -2432,14 +2438,25 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "signal-hook"
|
name = "signal-hook"
|
||||||
version = "0.3.17"
|
version = "0.3.18"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801"
|
checksum = "d881a16cf4426aa584979d30bd82cb33429027e42122b169753d6ef1085ed6e2"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
"signal-hook-registry",
|
"signal-hook-registry",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "signal-hook-mio"
|
||||||
|
version = "0.2.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "34db1a06d485c9142248b7a054f034b349b212551f3dfd19c94d45a754a217cd"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
"mio",
|
||||||
|
"signal-hook",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "signal-hook-registry"
|
name = "signal-hook-registry"
|
||||||
version = "1.4.2"
|
version = "1.4.2"
|
||||||
|
|||||||
@@ -54,6 +54,8 @@ serde = { version = "1.0.217", features = ["derive"] }
|
|||||||
arbitrary = { version = "1.4.1", features = ["derive"] }
|
arbitrary = { version = "1.4.1", features = ["derive"] }
|
||||||
anyhow = { version = "1.0.95", features = ["backtrace", "std"] }
|
anyhow = { version = "1.0.95", features = ["backtrace", "std"] }
|
||||||
mio = { version = "1.0.3", features = ["net", "os-poll"] }
|
mio = { version = "1.0.3", features = ["net", "os-poll"] }
|
||||||
|
signal-hook-mio = { version = "0.2.4", features = ["support-v1_0"] }
|
||||||
|
signal-hook = "0.3.17"
|
||||||
oqs-sys = { version = "0.9.1", default-features = false, features = [
|
oqs-sys = { version = "0.9.1", default-features = false, features = [
|
||||||
'classic_mceliece',
|
'classic_mceliece',
|
||||||
'kyber',
|
'kyber',
|
||||||
@@ -79,7 +81,6 @@ hex = { version = "0.4.3" }
|
|||||||
heck = { version = "0.5.0" }
|
heck = { version = "0.5.0" }
|
||||||
libc = { version = "0.2" }
|
libc = { version = "0.2" }
|
||||||
uds = { git = "https://github.com/rosenpass/uds" }
|
uds = { git = "https://github.com/rosenpass/uds" }
|
||||||
signal-hook = "0.3.17"
|
|
||||||
lazy_static = "1.5"
|
lazy_static = "1.5"
|
||||||
|
|
||||||
#Dev dependencies
|
#Dev dependencies
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ pub struct InferKeyedHash<Static, const KEY_LEN: usize, const HASH_LEN: usize>
|
|||||||
where
|
where
|
||||||
Static: KeyedHash<KEY_LEN, HASH_LEN>,
|
Static: KeyedHash<KEY_LEN, HASH_LEN>,
|
||||||
{
|
{
|
||||||
pub _phantom_keyed_hasher: PhantomData<*const Static>,
|
pub _phantom_keyed_hasher: PhantomData<Static>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<Static, const KEY_LEN: usize, const HASH_LEN: usize> InferKeyedHash<Static, KEY_LEN, HASH_LEN>
|
impl<Static, const KEY_LEN: usize, const HASH_LEN: usize> InferKeyedHash<Static, KEY_LEN, HASH_LEN>
|
||||||
|
|||||||
@@ -64,6 +64,8 @@ clap = { workspace = true }
|
|||||||
clap_complete = { workspace = true }
|
clap_complete = { workspace = true }
|
||||||
clap_mangen = { workspace = true }
|
clap_mangen = { workspace = true }
|
||||||
mio = { workspace = true }
|
mio = { workspace = true }
|
||||||
|
signal-hook = { workspace = true }
|
||||||
|
signal-hook-mio = { workspace = true }
|
||||||
rand = { workspace = true }
|
rand = { workspace = true }
|
||||||
zerocopy = { workspace = true }
|
zerocopy = { workspace = true }
|
||||||
home = { workspace = true }
|
home = { workspace = true }
|
||||||
@@ -76,7 +78,6 @@ heck = { workspace = true, optional = true }
|
|||||||
command-fds = { workspace = true, optional = true }
|
command-fds = { workspace = true, optional = true }
|
||||||
rustix = { workspace = true, optional = true }
|
rustix = { workspace = true, optional = true }
|
||||||
uds = { workspace = true, optional = true, features = ["mio_1xx"] }
|
uds = { workspace = true, optional = true, features = ["mio_1xx"] }
|
||||||
signal-hook = { workspace = true, optional = true }
|
|
||||||
libcrux-test-utils = { workspace = true, optional = true }
|
libcrux-test-utils = { workspace = true, optional = true }
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
@@ -109,7 +110,6 @@ experiment_api = [
|
|||||||
"rosenpass-util/experiment_file_descriptor_passing",
|
"rosenpass-util/experiment_file_descriptor_passing",
|
||||||
"rosenpass-wireguard-broker/experiment_api",
|
"rosenpass-wireguard-broker/experiment_api",
|
||||||
]
|
]
|
||||||
internal_signal_handling_for_coverage_reports = ["signal-hook"]
|
|
||||||
internal_testing = []
|
internal_testing = []
|
||||||
internal_bin_gen_ipc_msg_types = ["hex", "heck"]
|
internal_bin_gen_ipc_msg_types = ["hex", "heck"]
|
||||||
trace_bench = ["rosenpass-util/trace_bench", "dep:libcrux-test-utils"]
|
trace_bench = ["rosenpass-util/trace_bench", "dep:libcrux-test-utils"]
|
||||||
|
|||||||
@@ -7,17 +7,20 @@ use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSoc
|
|||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use std::{cell::Cell, fmt::Debug, io, path::PathBuf, slice};
|
use std::{cell::Cell, fmt::Debug, io, path::PathBuf, slice};
|
||||||
|
|
||||||
|
use mio::{Interest, Token};
|
||||||
|
use signal_hook_mio::v1_0 as signal_hook_mio;
|
||||||
|
|
||||||
use anyhow::{bail, Context, Result};
|
use anyhow::{bail, Context, Result};
|
||||||
use derive_builder::Builder;
|
use derive_builder::Builder;
|
||||||
use log::{error, info, warn};
|
use log::{error, info, warn};
|
||||||
use mio::{Interest, Token};
|
|
||||||
use zerocopy::AsBytes;
|
use zerocopy::AsBytes;
|
||||||
|
|
||||||
use rosenpass_util::attempt;
|
use rosenpass_util::attempt;
|
||||||
|
use rosenpass_util::fmt::debug::NullDebug;
|
||||||
use rosenpass_util::functional::{run, ApplyExt};
|
use rosenpass_util::functional::{run, ApplyExt};
|
||||||
use rosenpass_util::io::{IoResultKindHintExt, SubstituteForIoErrorKindExt};
|
use rosenpass_util::io::{IoResultKindHintExt, SubstituteForIoErrorKindExt};
|
||||||
use rosenpass_util::{
|
use rosenpass_util::{
|
||||||
b64::B64Display, build::ConstructionSite, file::StoreValueB64, option::SomeExt, result::OkExt,
|
b64::B64Display, build::ConstructionSite, file::StoreValueB64, result::OkExt,
|
||||||
};
|
};
|
||||||
|
|
||||||
use rosenpass_secret_memory::{Public, Secret};
|
use rosenpass_secret_memory::{Public, Secret};
|
||||||
@@ -129,7 +132,7 @@ pub struct BrokerStore {
|
|||||||
/// The collection of WireGuard brokers. See [Self].
|
/// The collection of WireGuard brokers. See [Self].
|
||||||
pub store: HashMap<
|
pub store: HashMap<
|
||||||
Public<BROKER_ID_BYTES>,
|
Public<BROKER_ID_BYTES>,
|
||||||
Box<dyn WireguardBrokerMio<Error = anyhow::Error, MioError = anyhow::Error>>,
|
Box<dyn WireguardBrokerMio<Error = anyhow::Error, MioError = anyhow::Error> + Send>,
|
||||||
>,
|
>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -146,12 +149,12 @@ pub struct BrokerPeer {
|
|||||||
///
|
///
|
||||||
/// This is woefully overengineered and there is very little reason why the broker
|
/// This is woefully overengineered and there is very little reason why the broker
|
||||||
/// configuration should not live in the particular WireGuard broker.
|
/// configuration should not live in the particular WireGuard broker.
|
||||||
peer_cfg: Box<dyn WireguardBrokerCfg>,
|
peer_cfg: Box<dyn WireguardBrokerCfg + Send>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BrokerPeer {
|
impl BrokerPeer {
|
||||||
/// Create a broker peer
|
/// Create a broker peer
|
||||||
pub fn new(ptr: BrokerStorePtr, peer_cfg: Box<dyn WireguardBrokerCfg>) -> Self {
|
pub fn new(ptr: BrokerStorePtr, peer_cfg: Box<dyn WireguardBrokerCfg + Send>) -> Self {
|
||||||
Self { ptr, peer_cfg }
|
Self { ptr, peer_cfg }
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -286,12 +289,20 @@ pub enum AppServerIoSource {
|
|||||||
Socket(usize),
|
Socket(usize),
|
||||||
/// IO source refers to a PSK broker in [AppServer::brokers]
|
/// IO source refers to a PSK broker in [AppServer::brokers]
|
||||||
PskBroker(Public<BROKER_ID_BYTES>),
|
PskBroker(Public<BROKER_ID_BYTES>),
|
||||||
|
/// IO source refers to our signal handlers
|
||||||
|
SignalHandler,
|
||||||
/// IO source refers to some IO sources used in the API;
|
/// IO source refers to some IO sources used in the API;
|
||||||
/// see [AppServer::api_manager]
|
/// see [AppServer::api_manager]
|
||||||
#[cfg(feature = "experiment_api")]
|
#[cfg(feature = "experiment_api")]
|
||||||
MioManager(crate::api::mio::MioManagerIoSource),
|
MioManager(crate::api::mio::MioManagerIoSource),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub enum AppServerTryRecvResult {
|
||||||
|
None,
|
||||||
|
Terminate,
|
||||||
|
NetworkMessage(usize, Endpoint),
|
||||||
|
}
|
||||||
|
|
||||||
/// Number of epoll(7) events Rosenpass can receive at a time
|
/// Number of epoll(7) events Rosenpass can receive at a time
|
||||||
const EVENT_CAPACITY: usize = 20;
|
const EVENT_CAPACITY: usize = 20;
|
||||||
|
|
||||||
@@ -332,6 +343,8 @@ pub struct AppServer {
|
|||||||
/// MIO associates IO sources with numeric tokens. This struct takes care of generating these
|
/// MIO associates IO sources with numeric tokens. This struct takes care of generating these
|
||||||
/// tokens
|
/// tokens
|
||||||
pub mio_token_dispenser: MioTokenDispenser,
|
pub mio_token_dispenser: MioTokenDispenser,
|
||||||
|
/// Mio-based handler for signals
|
||||||
|
pub signal_handler: NullDebug<signal_hook_mio::Signals>,
|
||||||
/// Helpers handling communication with WireGuard; these take a generated key and forward it to
|
/// Helpers handling communication with WireGuard; these take a generated key and forward it to
|
||||||
/// WireGuard
|
/// WireGuard
|
||||||
pub brokers: BrokerStore,
|
pub brokers: BrokerStore,
|
||||||
@@ -357,16 +370,6 @@ pub struct AppServer {
|
|||||||
/// Used by integration tests to force [Self] into DoS condition
|
/// Used by integration tests to force [Self] into DoS condition
|
||||||
/// and to terminate the AppServer after the test is complete
|
/// and to terminate the AppServer after the test is complete
|
||||||
pub test_helpers: Option<AppServerTest>,
|
pub test_helpers: Option<AppServerTest>,
|
||||||
/// Helper for integration tests running rosenpass as a subprocess
|
|
||||||
/// to terminate properly upon receiving an appropriate system signal.
|
|
||||||
///
|
|
||||||
/// This is primarily needed for coverage testing, since llvm-cov does not
|
|
||||||
/// write coverage reports to disk when a process is stopped by the default
|
|
||||||
/// signal handler.
|
|
||||||
///
|
|
||||||
/// See <https://github.com/rosenpass/rosenpass/issues/385>
|
|
||||||
#[cfg(feature = "internal_signal_handling_for_coverage_reports")]
|
|
||||||
pub term_signal: terminate::TerminateRequested,
|
|
||||||
#[cfg(feature = "experiment_api")]
|
#[cfg(feature = "experiment_api")]
|
||||||
/// The Rosenpass unix socket API handler; this is an experimental
|
/// The Rosenpass unix socket API handler; this is an experimental
|
||||||
/// feature that can be used to embed Rosenpass in external applications
|
/// feature that can be used to embed Rosenpass in external applications
|
||||||
@@ -456,6 +459,8 @@ impl AppPeerPtr {
|
|||||||
/// Instructs [AppServer::event_loop_without_error_handling] on how to proceed.
|
/// Instructs [AppServer::event_loop_without_error_handling] on how to proceed.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum AppPollResult {
|
pub enum AppPollResult {
|
||||||
|
/// Received request to terminate the application
|
||||||
|
Terminate,
|
||||||
/// Erase the key for a given peer. Corresponds to [crate::protocol::PollResult::DeleteKey]
|
/// Erase the key for a given peer. Corresponds to [crate::protocol::PollResult::DeleteKey]
|
||||||
DeleteKey(AppPeerPtr),
|
DeleteKey(AppPeerPtr),
|
||||||
/// Send an initiation to the given peer. Corresponds to [crate::protocol::PollResult::SendInitiation]
|
/// Send an initiation to the given peer. Corresponds to [crate::protocol::PollResult::SendInitiation]
|
||||||
@@ -802,10 +807,27 @@ impl AppServer {
|
|||||||
verbosity: Verbosity,
|
verbosity: Verbosity,
|
||||||
test_helpers: Option<AppServerTest>,
|
test_helpers: Option<AppServerTest>,
|
||||||
) -> anyhow::Result<Self> {
|
) -> anyhow::Result<Self> {
|
||||||
// setup mio
|
// Setup Mio itself
|
||||||
let mio_poll = mio::Poll::new()?;
|
let mio_poll = mio::Poll::new()?;
|
||||||
let events = mio::Events::with_capacity(EVENT_CAPACITY);
|
let events = mio::Events::with_capacity(EVENT_CAPACITY);
|
||||||
|
|
||||||
|
// And helpers to map mio tokens to internal event types
|
||||||
let mut mio_token_dispenser = MioTokenDispenser::default();
|
let mut mio_token_dispenser = MioTokenDispenser::default();
|
||||||
|
let mut io_source_index = HashMap::new();
|
||||||
|
|
||||||
|
// Setup signal handling
|
||||||
|
let signal_handler = attempt!({
|
||||||
|
let mut handler =
|
||||||
|
signal_hook_mio::Signals::new(signal_hook::consts::TERM_SIGNALS.iter())?;
|
||||||
|
let mio_token = mio_token_dispenser.dispense();
|
||||||
|
mio_poll
|
||||||
|
.registry()
|
||||||
|
.register(&mut handler, mio_token, Interest::READABLE)?;
|
||||||
|
let prev = io_source_index.insert(mio_token, AppServerIoSource::SignalHandler);
|
||||||
|
assert!(prev.is_none());
|
||||||
|
Ok(NullDebug(handler))
|
||||||
|
})
|
||||||
|
.context("Failed to set up signal (user triggered program termination) handler")?;
|
||||||
|
|
||||||
// bind each SocketAddr to a socket
|
// bind each SocketAddr to a socket
|
||||||
let maybe_sockets: Result<Vec<_>, _> =
|
let maybe_sockets: Result<Vec<_>, _> =
|
||||||
@@ -879,7 +901,6 @@ impl AppServer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// register all sockets to mio
|
// register all sockets to mio
|
||||||
let mut io_source_index = HashMap::new();
|
|
||||||
for (idx, socket) in sockets.iter_mut().enumerate() {
|
for (idx, socket) in sockets.iter_mut().enumerate() {
|
||||||
let mio_token = mio_token_dispenser.dispense();
|
let mio_token = mio_token_dispenser.dispense();
|
||||||
mio_poll
|
mio_poll
|
||||||
@@ -895,8 +916,6 @@ impl AppServer {
|
|||||||
};
|
};
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
#[cfg(feature = "internal_signal_handling_for_coverage_reports")]
|
|
||||||
term_signal: terminate::TerminateRequested::new()?,
|
|
||||||
crypto_site,
|
crypto_site,
|
||||||
peers: Vec::new(),
|
peers: Vec::new(),
|
||||||
verbosity,
|
verbosity,
|
||||||
@@ -907,6 +926,7 @@ impl AppServer {
|
|||||||
io_source_index,
|
io_source_index,
|
||||||
mio_poll,
|
mio_poll,
|
||||||
mio_token_dispenser,
|
mio_token_dispenser,
|
||||||
|
signal_handler,
|
||||||
brokers: BrokerStore::default(),
|
brokers: BrokerStore::default(),
|
||||||
all_sockets_drained: false,
|
all_sockets_drained: false,
|
||||||
under_load: DoSOperation::Normal,
|
under_load: DoSOperation::Normal,
|
||||||
@@ -977,7 +997,7 @@ impl AppServer {
|
|||||||
/// Register a new WireGuard PSK broker
|
/// Register a new WireGuard PSK broker
|
||||||
pub fn register_broker(
|
pub fn register_broker(
|
||||||
&mut self,
|
&mut self,
|
||||||
broker: Box<dyn WireguardBrokerMio<Error = anyhow::Error, MioError = anyhow::Error>>,
|
broker: Box<dyn WireguardBrokerMio<Error = anyhow::Error, MioError = anyhow::Error> + Send>,
|
||||||
) -> Result<BrokerStorePtr> {
|
) -> Result<BrokerStorePtr> {
|
||||||
let ptr = Public::from_slice((self.brokers.store.len() as u64).as_bytes());
|
let ptr = Public::from_slice((self.brokers.store.len() as u64).as_bytes());
|
||||||
if self.brokers.store.insert(ptr, broker).is_some() {
|
if self.brokers.store.insert(ptr, broker).is_some() {
|
||||||
@@ -1049,7 +1069,7 @@ impl AppServer {
|
|||||||
Ok(AppPeerPtr(pn))
|
Ok(AppPeerPtr(pn))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Main IO handler; this generally does not terminate
|
/// Main IO handler; this generally does not terminate other than through unix signals
|
||||||
///
|
///
|
||||||
/// # Examples
|
/// # Examples
|
||||||
///
|
///
|
||||||
@@ -1066,23 +1086,6 @@ impl AppServer {
|
|||||||
Err(e) => e,
|
Err(e) => e,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[cfg(feature = "internal_signal_handling_for_coverage_reports")]
|
|
||||||
{
|
|
||||||
let terminated_by_signal = err
|
|
||||||
.downcast_ref::<std::io::Error>()
|
|
||||||
.filter(|e| e.kind() == std::io::ErrorKind::Interrupted)
|
|
||||||
.filter(|_| self.term_signal.value())
|
|
||||||
.is_some();
|
|
||||||
if terminated_by_signal {
|
|
||||||
log::warn!(
|
|
||||||
"\
|
|
||||||
Terminated by signal; this signal handler is correct during coverage testing \
|
|
||||||
but should be otherwise disabled"
|
|
||||||
);
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// This should not happen…
|
// This should not happen…
|
||||||
failure_cnt = if msgs_processed > 0 {
|
failure_cnt = if msgs_processed > 0 {
|
||||||
0
|
0
|
||||||
@@ -1135,6 +1138,7 @@ impl AppServer {
|
|||||||
use AppPollResult::*;
|
use AppPollResult::*;
|
||||||
use KeyOutputReason::*;
|
use KeyOutputReason::*;
|
||||||
|
|
||||||
|
// TODO: We should read from this using a mio channel
|
||||||
if let Some(AppServerTest {
|
if let Some(AppServerTest {
|
||||||
termination_handler: Some(terminate),
|
termination_handler: Some(terminate),
|
||||||
..
|
..
|
||||||
@@ -1158,6 +1162,8 @@ impl AppServer {
|
|||||||
|
|
||||||
#[allow(clippy::redundant_closure_call)]
|
#[allow(clippy::redundant_closure_call)]
|
||||||
match (have_crypto, poll_result) {
|
match (have_crypto, poll_result) {
|
||||||
|
(_, Terminate) => return Ok(()),
|
||||||
|
|
||||||
(CryptoSrv::Missing, SendInitiation(_)) => {}
|
(CryptoSrv::Missing, SendInitiation(_)) => {}
|
||||||
(CryptoSrv::Avail, SendInitiation(peer)) => tx_maybe_with!(peer, || self
|
(CryptoSrv::Avail, SendInitiation(peer)) => tx_maybe_with!(peer, || self
|
||||||
.crypto_server_mut()?
|
.crypto_server_mut()?
|
||||||
@@ -1305,6 +1311,7 @@ impl AppServer {
|
|||||||
pub fn poll(&mut self, rx_buf: &mut [u8]) -> anyhow::Result<AppPollResult> {
|
pub fn poll(&mut self, rx_buf: &mut [u8]) -> anyhow::Result<AppPollResult> {
|
||||||
use crate::protocol::PollResult as C;
|
use crate::protocol::PollResult as C;
|
||||||
use AppPollResult as A;
|
use AppPollResult as A;
|
||||||
|
use AppServerTryRecvResult as R;
|
||||||
let res = loop {
|
let res = loop {
|
||||||
// Call CryptoServer's poll (if available)
|
// Call CryptoServer's poll (if available)
|
||||||
let crypto_poll = self
|
let crypto_poll = self
|
||||||
@@ -1325,8 +1332,10 @@ impl AppServer {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Perform IO (look for a message)
|
// Perform IO (look for a message)
|
||||||
if let Some((len, addr)) = self.try_recv(rx_buf, io_poll_timeout)? {
|
match self.try_recv(rx_buf, io_poll_timeout)? {
|
||||||
break A::ReceivedMessage(len, addr);
|
R::None => {}
|
||||||
|
R::Terminate => break A::Terminate,
|
||||||
|
R::NetworkMessage(len, addr) => break A::ReceivedMessage(len, addr),
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -1344,12 +1353,12 @@ impl AppServer {
|
|||||||
&mut self,
|
&mut self,
|
||||||
buf: &mut [u8],
|
buf: &mut [u8],
|
||||||
timeout: Timing,
|
timeout: Timing,
|
||||||
) -> anyhow::Result<Option<(usize, Endpoint)>> {
|
) -> anyhow::Result<AppServerTryRecvResult> {
|
||||||
let timeout = Duration::from_secs_f64(timeout);
|
let timeout = Duration::from_secs_f64(timeout);
|
||||||
|
|
||||||
// if there is no time to wait on IO, well, then, lets not waste any time!
|
// if there is no time to wait on IO, well, then, lets not waste any time!
|
||||||
if timeout.is_zero() {
|
if timeout.is_zero() {
|
||||||
return Ok(None);
|
return Ok(AppServerTryRecvResult::None);
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE when using mio::Poll, there are some particularities (taken from
|
// NOTE when using mio::Poll, there are some particularities (taken from
|
||||||
@@ -1459,12 +1468,19 @@ impl AppServer {
|
|||||||
// blocking poll, we go through all available IO sources to see if we missed anything.
|
// blocking poll, we go through all available IO sources to see if we missed anything.
|
||||||
{
|
{
|
||||||
while let Some(ev) = self.short_poll_queue.pop_front() {
|
while let Some(ev) = self.short_poll_queue.pop_front() {
|
||||||
if let Some(v) = self.try_recv_from_mio_token(buf, ev.token())? {
|
match self.try_recv_from_mio_token(buf, ev.token())? {
|
||||||
return Ok(Some(v));
|
AppServerTryRecvResult::None => continue,
|
||||||
|
res => return Ok(res),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Drain operating system signals
|
||||||
|
match self.try_recv_from_signal_handler()? {
|
||||||
|
AppServerTryRecvResult::None => {} // Nop
|
||||||
|
res => return Ok(res),
|
||||||
|
}
|
||||||
|
|
||||||
// drain all sockets
|
// drain all sockets
|
||||||
let mut would_block_count = 0;
|
let mut would_block_count = 0;
|
||||||
for sock_no in 0..self.sockets.len() {
|
for sock_no in 0..self.sockets.len() {
|
||||||
@@ -1472,11 +1488,11 @@ impl AppServer {
|
|||||||
.try_recv_from_listen_socket(buf, sock_no)
|
.try_recv_from_listen_socket(buf, sock_no)
|
||||||
.io_err_kind_hint()
|
.io_err_kind_hint()
|
||||||
{
|
{
|
||||||
Ok(None) => continue,
|
Ok(AppServerTryRecvResult::None) => continue,
|
||||||
Ok(Some(v)) => {
|
Ok(res) => {
|
||||||
// at least one socket was not drained...
|
// at least one socket was not drained...
|
||||||
self.all_sockets_drained = false;
|
self.all_sockets_drained = false;
|
||||||
return Ok(Some(v));
|
return Ok(res);
|
||||||
}
|
}
|
||||||
Err((_, ErrorKind::WouldBlock)) => {
|
Err((_, ErrorKind::WouldBlock)) => {
|
||||||
would_block_count += 1;
|
would_block_count += 1;
|
||||||
@@ -1504,12 +1520,24 @@ impl AppServer {
|
|||||||
|
|
||||||
self.performed_long_poll = true;
|
self.performed_long_poll = true;
|
||||||
|
|
||||||
Ok(None)
|
Ok(AppServerTryRecvResult::None)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Internal helper for [Self::try_recv]
|
/// Internal helper for [Self::try_recv]
|
||||||
fn perform_mio_poll_and_register_events(&mut self, timeout: Duration) -> io::Result<()> {
|
fn perform_mio_poll_and_register_events(&mut self, timeout: Duration) -> io::Result<()> {
|
||||||
self.mio_poll.poll(&mut self.events, Some(timeout))?;
|
loop {
|
||||||
|
use std::io::ErrorKind as IOE;
|
||||||
|
match self
|
||||||
|
.mio_poll
|
||||||
|
.poll(&mut self.events, Some(timeout))
|
||||||
|
.io_err_kind_hint()
|
||||||
|
{
|
||||||
|
Ok(()) => break,
|
||||||
|
Err((_, IOE::Interrupted)) => continue,
|
||||||
|
Err((err, _)) => return Err(err),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Fill the short poll buffer with the acquired events
|
// Fill the short poll buffer with the acquired events
|
||||||
self.events
|
self.events
|
||||||
.iter()
|
.iter()
|
||||||
@@ -1523,12 +1551,12 @@ impl AppServer {
|
|||||||
&mut self,
|
&mut self,
|
||||||
buf: &mut [u8],
|
buf: &mut [u8],
|
||||||
token: mio::Token,
|
token: mio::Token,
|
||||||
) -> anyhow::Result<Option<(usize, Endpoint)>> {
|
) -> anyhow::Result<AppServerTryRecvResult> {
|
||||||
let io_source = match self.io_source_index.get(&token) {
|
let io_source = match self.io_source_index.get(&token) {
|
||||||
Some(io_source) => *io_source,
|
Some(io_source) => *io_source,
|
||||||
None => {
|
None => {
|
||||||
log::warn!("No IO source assiociated with mio token ({token:?}). Polling using mio tokens directly is an experimental feature and IO handler should recover when all available io sources are polled. This is a developer error. Please report it.");
|
log::warn!("No IO source assiociated with mio token ({token:?}). Polling using mio tokens directly is an experimental feature and IO handler should recover when all available io sources are polled. This is a developer error. Please report it.");
|
||||||
return Ok(None);
|
return Ok(AppServerTryRecvResult::None);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -1540,11 +1568,13 @@ impl AppServer {
|
|||||||
&mut self,
|
&mut self,
|
||||||
buf: &mut [u8],
|
buf: &mut [u8],
|
||||||
io_source: AppServerIoSource,
|
io_source: AppServerIoSource,
|
||||||
) -> anyhow::Result<Option<(usize, Endpoint)>> {
|
) -> anyhow::Result<AppServerTryRecvResult> {
|
||||||
match io_source {
|
match io_source {
|
||||||
|
AppServerIoSource::SignalHandler => self.try_recv_from_signal_handler()?.ok(),
|
||||||
|
|
||||||
AppServerIoSource::Socket(idx) => self
|
AppServerIoSource::Socket(idx) => self
|
||||||
.try_recv_from_listen_socket(buf, idx)
|
.try_recv_from_listen_socket(buf, idx)
|
||||||
.substitute_for_ioerr_wouldblock(None)?
|
.substitute_for_ioerr_wouldblock(AppServerTryRecvResult::None)?
|
||||||
.ok(),
|
.ok(),
|
||||||
|
|
||||||
AppServerIoSource::PskBroker(key) => self
|
AppServerIoSource::PskBroker(key) => self
|
||||||
@@ -1553,7 +1583,7 @@ impl AppServer {
|
|||||||
.get_mut(&key)
|
.get_mut(&key)
|
||||||
.with_context(|| format!("No PSK broker under key {key:?}"))?
|
.with_context(|| format!("No PSK broker under key {key:?}"))?
|
||||||
.process_poll()
|
.process_poll()
|
||||||
.map(|_| None),
|
.map(|_| AppServerTryRecvResult::None),
|
||||||
|
|
||||||
#[cfg(feature = "experiment_api")]
|
#[cfg(feature = "experiment_api")]
|
||||||
AppServerIoSource::MioManager(mmio_src) => {
|
AppServerIoSource::MioManager(mmio_src) => {
|
||||||
@@ -1561,17 +1591,28 @@ impl AppServer {
|
|||||||
|
|
||||||
MioManagerFocus(self)
|
MioManagerFocus(self)
|
||||||
.poll_particular(mmio_src)
|
.poll_particular(mmio_src)
|
||||||
.map(|_| None)
|
.map(|_| AppServerTryRecvResult::None)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Internal helper for [Self::try_recv]
|
||||||
|
fn try_recv_from_signal_handler(&mut self) -> io::Result<AppServerTryRecvResult> {
|
||||||
|
#[allow(clippy::never_loop)]
|
||||||
|
for signal in self.signal_handler.pending() {
|
||||||
|
log::debug!("Received operating system signal no {signal}.");
|
||||||
|
log::info!("Received termination request; exiting.");
|
||||||
|
return Ok(AppServerTryRecvResult::Terminate);
|
||||||
|
}
|
||||||
|
Ok(AppServerTryRecvResult::None)
|
||||||
|
}
|
||||||
|
|
||||||
/// Internal helper for [Self::try_recv]
|
/// Internal helper for [Self::try_recv]
|
||||||
fn try_recv_from_listen_socket(
|
fn try_recv_from_listen_socket(
|
||||||
&mut self,
|
&mut self,
|
||||||
buf: &mut [u8],
|
buf: &mut [u8],
|
||||||
idx: usize,
|
idx: usize,
|
||||||
) -> io::Result<Option<(usize, Endpoint)>> {
|
) -> io::Result<AppServerTryRecvResult> {
|
||||||
use std::io::ErrorKind as K;
|
use std::io::ErrorKind as K;
|
||||||
let (n, addr) = loop {
|
let (n, addr) = loop {
|
||||||
match self.sockets[idx].recv_from(buf).io_err_kind_hint() {
|
match self.sockets[idx].recv_from(buf).io_err_kind_hint() {
|
||||||
@@ -1583,8 +1624,7 @@ impl AppServer {
|
|||||||
SocketPtr(idx)
|
SocketPtr(idx)
|
||||||
.apply(|sp| SocketBoundEndpoint::new(sp, addr))
|
.apply(|sp| SocketBoundEndpoint::new(sp, addr))
|
||||||
.apply(Endpoint::SocketBoundAddress)
|
.apply(Endpoint::SocketBoundAddress)
|
||||||
.apply(|ep| (n, ep))
|
.apply(|ep| AppServerTryRecvResult::NetworkMessage(n, ep))
|
||||||
.some()
|
|
||||||
.ok()
|
.ok()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1636,48 +1676,3 @@ impl crate::api::mio::MioManagerContext for MioManagerFocus<'_> {
|
|||||||
self.0
|
self.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// These signal handlers are used exclusively used during coverage testing
|
|
||||||
/// to ensure that the llvm-cov can produce reports during integration tests
|
|
||||||
/// with multiple processes where subprocesses are terminated via kill(2).
|
|
||||||
///
|
|
||||||
/// llvm-cov does not support producing coverage reports when the process exits
|
|
||||||
/// through a signal, so this is necessary.
|
|
||||||
///
|
|
||||||
/// The functionality of exiting gracefully upon reception of a terminating signal
|
|
||||||
/// is desired for the production variant of Rosenpass, but we should make sure
|
|
||||||
/// to use a higher quality implementation; in particular, we should use signalfd(2).
|
|
||||||
///
|
|
||||||
#[cfg(feature = "internal_signal_handling_for_coverage_reports")]
|
|
||||||
mod terminate {
|
|
||||||
use signal_hook::flag::register as sig_register;
|
|
||||||
use std::sync::{
|
|
||||||
atomic::{AtomicBool, Ordering},
|
|
||||||
Arc,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Automatically register a signal handler for common termination signals;
|
|
||||||
/// whether one of these signals was issued can be polled using [Self::value].
|
|
||||||
///
|
|
||||||
/// The signal handler is not removed when this struct goes out of scope.
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct TerminateRequested {
|
|
||||||
value: Arc<AtomicBool>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TerminateRequested {
|
|
||||||
/// Register signal handlers watching for common termination signals
|
|
||||||
pub fn new() -> anyhow::Result<Self> {
|
|
||||||
let value = Arc::new(AtomicBool::new(false));
|
|
||||||
for sig in signal_hook::consts::TERM_SIGNALS.iter().copied() {
|
|
||||||
sig_register(sig, Arc::clone(&value))?;
|
|
||||||
}
|
|
||||||
Ok(Self { value })
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check whether a termination signal has been set
|
|
||||||
pub fn value(&self) -> bool {
|
|
||||||
self.value.load(Ordering::Relaxed)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -490,7 +490,7 @@ impl CliArgs {
|
|||||||
cfg_peer.key_out,
|
cfg_peer.key_out,
|
||||||
broker_peer,
|
broker_peer,
|
||||||
cfg_peer.endpoint.clone(),
|
cfg_peer.endpoint.clone(),
|
||||||
cfg_peer.protocol_version.into(),
|
cfg_peer.protocol_version,
|
||||||
cfg_peer.osk_domain_separator.try_into()?,
|
cfg_peer.osk_domain_separator.try_into()?,
|
||||||
)?;
|
)?;
|
||||||
}
|
}
|
||||||
@@ -515,7 +515,7 @@ impl CliArgs {
|
|||||||
fn create_broker(
|
fn create_broker(
|
||||||
broker_interface: Option<BrokerInterface>,
|
broker_interface: Option<BrokerInterface>,
|
||||||
) -> Result<
|
) -> Result<
|
||||||
Box<dyn WireguardBrokerMio<MioError = anyhow::Error, Error = anyhow::Error>>,
|
Box<dyn WireguardBrokerMio<MioError = anyhow::Error, Error = anyhow::Error> + Send>,
|
||||||
anyhow::Error,
|
anyhow::Error,
|
||||||
> {
|
> {
|
||||||
if let Some(interface) = broker_interface {
|
if let Some(interface) = broker_interface {
|
||||||
|
|||||||
@@ -200,7 +200,7 @@ impl RosenpassPeerOskDomainSeparator {
|
|||||||
pub fn org_and_label(&self) -> anyhow::Result<Option<(&String, &Vec<String>)>> {
|
pub fn org_and_label(&self) -> anyhow::Result<Option<(&String, &Vec<String>)>> {
|
||||||
match (&self.osk_organization, &self.osk_label) {
|
match (&self.osk_organization, &self.osk_label) {
|
||||||
(None, None) => Ok(None),
|
(None, None) => Ok(None),
|
||||||
(Some(org), Some(label)) => Ok(Some((&org, &label))),
|
(Some(org), Some(label)) => Ok(Some((org, label))),
|
||||||
(Some(_), None) => bail!("Specified osk_organization but not osk_label in config file. You need to specify both, or none."),
|
(Some(_), None) => bail!("Specified osk_organization but not osk_label in config file. You need to specify both, or none."),
|
||||||
(None, Some(_)) => bail!("Specified osk_label but not osk_organization in config file. You need to specify both, or none."),
|
(None, Some(_)) => bail!("Specified osk_label but not osk_organization in config file. You need to specify both, or none."),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1943,7 +1943,7 @@ impl CryptoServer {
|
|||||||
&mut self,
|
&mut self,
|
||||||
rx_buf: &[u8],
|
rx_buf: &[u8],
|
||||||
tx_buf: &mut [u8],
|
tx_buf: &mut [u8],
|
||||||
host_identification: &H,
|
_host_identification: &H,
|
||||||
) -> Result<HandleMsgResult> {
|
) -> Result<HandleMsgResult> {
|
||||||
self.handle_msg(rx_buf, tx_buf)
|
self.handle_msg(rx_buf, tx_buf)
|
||||||
}
|
}
|
||||||
@@ -3231,7 +3231,7 @@ impl HandshakeState {
|
|||||||
|
|
||||||
let k = bk.get(srv).value.secret();
|
let k = bk.get(srv).value.secret();
|
||||||
let pt = biscuit.as_bytes();
|
let pt = biscuit.as_bytes();
|
||||||
XAead.encrypt_with_nonce_in_ctxt(biscuit_ct, k, &*n, &ad, pt)?;
|
XAead.encrypt_with_nonce_in_ctxt(biscuit_ct, k, &n, &ad, pt)?;
|
||||||
|
|
||||||
self.mix(biscuit_ct)
|
self.mix(biscuit_ct)
|
||||||
}
|
}
|
||||||
@@ -3421,7 +3421,7 @@ impl CryptoServer {
|
|||||||
|
|
||||||
// IHI3
|
// IHI3
|
||||||
protocol_section!("IHI3", {
|
protocol_section!("IHI3", {
|
||||||
EphemeralKem.keygen(hs.eski.secret_mut(), &mut *hs.epki)?;
|
EphemeralKem.keygen(hs.eski.secret_mut(), &mut hs.epki)?;
|
||||||
ih.epki.copy_from_slice(&hs.epki.value);
|
ih.epki.copy_from_slice(&hs.epki.value);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -105,7 +105,7 @@ fn api_integration_api_setup(protocol_version: ProtocolVersion) -> anyhow::Resul
|
|||||||
peer: format!("{}", peer_b_wg_peer_id.fmt_b64::<8129>()),
|
peer: format!("{}", peer_b_wg_peer_id.fmt_b64::<8129>()),
|
||||||
extra_params: vec![],
|
extra_params: vec![],
|
||||||
}),
|
}),
|
||||||
protocol_version: protocol_version.clone(),
|
protocol_version: protocol_version,
|
||||||
osk_domain_separator: Default::default(),
|
osk_domain_separator: Default::default(),
|
||||||
}],
|
}],
|
||||||
};
|
};
|
||||||
@@ -127,7 +127,7 @@ fn api_integration_api_setup(protocol_version: ProtocolVersion) -> anyhow::Resul
|
|||||||
endpoint: Some(peer_a_endpoint.to_owned()),
|
endpoint: Some(peer_a_endpoint.to_owned()),
|
||||||
pre_shared_key: None,
|
pre_shared_key: None,
|
||||||
wg: None,
|
wg: None,
|
||||||
protocol_version: protocol_version.clone(),
|
protocol_version: protocol_version,
|
||||||
osk_domain_separator: Default::default(),
|
osk_domain_separator: Default::default(),
|
||||||
}],
|
}],
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ fn api_integration_test(protocol_version: ProtocolVersion) -> anyhow::Result<()>
|
|||||||
endpoint: None,
|
endpoint: None,
|
||||||
pre_shared_key: None,
|
pre_shared_key: None,
|
||||||
wg: None,
|
wg: None,
|
||||||
protocol_version: protocol_version.clone(),
|
protocol_version: protocol_version,
|
||||||
osk_domain_separator: Default::default(),
|
osk_domain_separator: Default::default(),
|
||||||
}],
|
}],
|
||||||
};
|
};
|
||||||
@@ -104,7 +104,7 @@ fn api_integration_test(protocol_version: ProtocolVersion) -> anyhow::Result<()>
|
|||||||
endpoint: Some(peer_a_endpoint.to_owned()),
|
endpoint: Some(peer_a_endpoint.to_owned()),
|
||||||
pre_shared_key: None,
|
pre_shared_key: None,
|
||||||
wg: None,
|
wg: None,
|
||||||
protocol_version: protocol_version.clone(),
|
protocol_version: protocol_version,
|
||||||
osk_domain_separator: Default::default(),
|
osk_domain_separator: Default::default(),
|
||||||
}],
|
}],
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -144,7 +144,7 @@ fn check_example_config() {
|
|||||||
|
|
||||||
let tmp_dir = tempdir().unwrap();
|
let tmp_dir = tempdir().unwrap();
|
||||||
let config_path = tmp_dir.path().join("config.toml");
|
let config_path = tmp_dir.path().join("config.toml");
|
||||||
let mut config_file = File::create(config_path.to_owned()).unwrap();
|
let mut config_file = File::create(&config_path).unwrap();
|
||||||
|
|
||||||
config_file
|
config_file
|
||||||
.write_all(
|
.write_all(
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "rp"
|
name = "rosenpass-rp"
|
||||||
version = "0.2.1"
|
version = "0.2.1"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
license = "MIT OR Apache-2.0"
|
license = "MIT OR Apache-2.0"
|
||||||
@@ -8,7 +8,9 @@ homepage = "https://rosenpass.eu/"
|
|||||||
repository = "https://github.com/rosenpass/rosenpass"
|
repository = "https://github.com/rosenpass/rosenpass"
|
||||||
rust-version = "1.77.0"
|
rust-version = "1.77.0"
|
||||||
|
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
[[bin]]
|
||||||
|
name = "rp"
|
||||||
|
path = "src/main.rs"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = { workspace = true }
|
anyhow = { workspace = true }
|
||||||
@@ -17,12 +19,15 @@ serde = { workspace = true }
|
|||||||
toml = { workspace = true }
|
toml = { workspace = true }
|
||||||
x25519-dalek = { workspace = true, features = ["static_secrets"] }
|
x25519-dalek = { workspace = true, features = ["static_secrets"] }
|
||||||
zeroize = { workspace = true }
|
zeroize = { workspace = true }
|
||||||
|
libc = { workspace = true }
|
||||||
|
log = { workspace = true }
|
||||||
|
env_logger = { workspace = true }
|
||||||
|
|
||||||
rosenpass = { workspace = true }
|
rosenpass = { workspace = true }
|
||||||
rosenpass-ciphers = { workspace = true }
|
rosenpass-ciphers = { workspace = true }
|
||||||
rosenpass-cipher-traits = { workspace = true }
|
rosenpass-cipher-traits = { workspace = true }
|
||||||
rosenpass-secret-memory = { workspace = true }
|
rosenpass-secret-memory = { workspace = true }
|
||||||
rosenpass-util = { workspace = true }
|
rosenpass-util = { workspace = true, features = ["tokio"] }
|
||||||
rosenpass-wireguard-broker = { workspace = true }
|
rosenpass-wireguard-broker = { workspace = true }
|
||||||
|
|
||||||
tokio = { workspace = true }
|
tokio = { workspace = true }
|
||||||
|
|||||||
@@ -1,16 +1,63 @@
|
|||||||
use std::{
|
use std::any::type_name;
|
||||||
future::Future, net::SocketAddr, ops::DerefMut, path::PathBuf, pin::Pin, process::Command,
|
use std::{borrow::Borrow, net::SocketAddr, path::PathBuf};
|
||||||
sync::Arc,
|
|
||||||
};
|
|
||||||
|
|
||||||
use anyhow::{Error, Result};
|
use tokio::process::Command;
|
||||||
|
|
||||||
|
use anyhow::{bail, ensure, Context, Result};
|
||||||
|
use futures_util::TryStreamExt as _;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
|
||||||
use rosenpass::config::ProtocolVersion;
|
use rosenpass::config::ProtocolVersion;
|
||||||
|
use rosenpass::{
|
||||||
|
app_server::{AppServer, BrokerPeer},
|
||||||
|
config::Verbosity,
|
||||||
|
protocol::{
|
||||||
|
basic_types::{SPk, SSk, SymKey},
|
||||||
|
osk_domain_separator::OskDomainSeparator,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
use rosenpass_secret_memory::Secret;
|
||||||
|
use rosenpass_util::file::{LoadValue as _, LoadValueB64};
|
||||||
|
use rosenpass_util::functional::{ApplyExt, MutatingExt};
|
||||||
|
use rosenpass_util::result::OkExt;
|
||||||
|
use rosenpass_util::tokio::janitor::{spawn_cleanup_job, try_spawn_daemon};
|
||||||
|
use rosenpass_wireguard_broker::brokers::native_unix::{
|
||||||
|
NativeUnixBroker, NativeUnixBrokerConfigBaseBuilder,
|
||||||
|
};
|
||||||
|
use tokio::task::spawn_blocking;
|
||||||
|
|
||||||
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
|
|
||||||
use crate::key::WG_B64_LEN;
|
use crate::key::WG_B64_LEN;
|
||||||
|
|
||||||
|
/// Extra-special measure to structure imports from the various
|
||||||
|
/// netlink related crates used in [super]
|
||||||
|
mod netlink {
|
||||||
|
/// Re-exports from [::netlink_packet_core]
|
||||||
|
pub mod core {
|
||||||
|
pub use ::netlink_packet_core::{NetlinkMessage, NLM_F_ACK, NLM_F_REQUEST};
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Re-exports from [::rtnetlink]
|
||||||
|
pub mod rtnl {
|
||||||
|
pub use ::rtnetlink::Error;
|
||||||
|
pub use ::rtnetlink::Handle;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Re-exports from [::genetlink] and [::netlink_packet_generic]
|
||||||
|
pub mod genl {
|
||||||
|
pub use ::genetlink::GenetlinkHandle as Handle;
|
||||||
|
pub use ::netlink_packet_generic::GenlMessage as Message;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Re-exports from [::netlink_packet_wireguard]
|
||||||
|
pub mod wg {
|
||||||
|
pub use ::netlink_packet_wireguard::constants::WG_KEY_LEN as KEY_LEN;
|
||||||
|
pub use ::netlink_packet_wireguard::nlas::WgDeviceAttrs as DeviceAttrs;
|
||||||
|
pub use ::netlink_packet_wireguard::{Wireguard, WireguardCmd};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type WgSecretKey = Secret<{ netlink::wg::KEY_LEN }>;
|
||||||
|
|
||||||
/// Used to define a peer for the rosenpass connection that consists of
|
/// Used to define a peer for the rosenpass connection that consists of
|
||||||
/// a directory for storing public keys and optionally an IP address and port of the endpoint,
|
/// a directory for storing public keys and optionally an IP address and port of the endpoint,
|
||||||
/// for how long the connection should be kept alive and a list of allowed IPs for the peer.
|
/// for how long the connection should be kept alive and a list of allowed IPs for the peer.
|
||||||
@@ -43,286 +90,401 @@ pub struct ExchangeOptions {
|
|||||||
pub dev: Option<String>,
|
pub dev: Option<String>,
|
||||||
/// The IP-address rosenpass should run under.
|
/// The IP-address rosenpass should run under.
|
||||||
pub ip: Option<String>,
|
pub ip: Option<String>,
|
||||||
/// The IP-address and port that the rosenpass [AppServer](rosenpass::app_server::AppServer)
|
/// The IP-address and port that the rosenpass [AppServer]
|
||||||
/// should use.
|
/// should use.
|
||||||
pub listen: Option<SocketAddr>,
|
pub listen: Option<SocketAddr>,
|
||||||
/// Other peers a connection should be initialized to
|
/// Other peers a connection should be initialized to
|
||||||
pub peers: Vec<ExchangePeer>,
|
pub peers: Vec<ExchangePeer>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
|
/// Manage the lifetime of WireGuard devices uses for rp
|
||||||
pub async fn exchange(_: ExchangeOptions) -> Result<()> {
|
#[derive(Debug, Default)]
|
||||||
use anyhow::anyhow;
|
struct WireGuardDeviceImpl {
|
||||||
|
// TODO: Can we merge these two somehow?
|
||||||
Err(anyhow!(
|
rtnl_netlink_handle_cache: Option<netlink::rtnl::Handle>,
|
||||||
"Your system {} is not yet supported. We are happy to receive patches to address this :)",
|
genl_netlink_handle_cache: Option<netlink::genl::Handle>,
|
||||||
std::env::consts::OS
|
/// Handle and name of the device
|
||||||
))
|
device: Option<(u32, String)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
|
impl WireGuardDeviceImpl {
|
||||||
mod netlink {
|
fn take(&mut self) -> WireGuardDeviceImpl {
|
||||||
use anyhow::Result;
|
Self::default().mutating(|nu| std::mem::swap(self, nu))
|
||||||
use futures_util::{StreamExt as _, TryStreamExt as _};
|
}
|
||||||
use genetlink::GenetlinkHandle;
|
|
||||||
use netlink_packet_core::{NLM_F_ACK, NLM_F_REQUEST};
|
|
||||||
use netlink_packet_wireguard::nlas::WgDeviceAttrs;
|
|
||||||
use rtnetlink::Handle;
|
|
||||||
|
|
||||||
|
async fn open(&mut self, device_name: String) -> anyhow::Result<()> {
|
||||||
|
let mut rtnl_link = self.rtnl_netlink_handle()?.link();
|
||||||
|
let device_name_ref = &device_name;
|
||||||
|
|
||||||
|
// Make sure that there is no device called `device_name` before we start
|
||||||
|
rtnl_link
|
||||||
|
.get()
|
||||||
|
.match_name(device_name.to_owned())
|
||||||
|
.execute()
|
||||||
|
// Count the number of occurences
|
||||||
|
.try_fold(0, |acc, _val| async move {
|
||||||
|
Ok(acc + 1)
|
||||||
|
}).await
|
||||||
|
// Extract the error's raw system error code
|
||||||
|
.map_err(|e| {
|
||||||
|
use netlink::rtnl::Error as E;
|
||||||
|
match e {
|
||||||
|
E::NetlinkError(msg) => {
|
||||||
|
let raw_code = -msg.raw_code();
|
||||||
|
(E::NetlinkError(msg), Some(raw_code))
|
||||||
|
},
|
||||||
|
_ => (e, None),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.apply(|r| {
|
||||||
|
match r {
|
||||||
|
// No such device, which is exactly what we are expecting
|
||||||
|
Ok(0) | Err((_, Some(libc::ENODEV))) => Ok(()),
|
||||||
|
// Device already exists
|
||||||
|
Ok(_) => bail!("\
|
||||||
|
Trying to create a network device for Rosenpass under the name \"{device_name}\", \
|
||||||
|
but at least one device under the name aready exists."),
|
||||||
|
// Other error
|
||||||
|
Err((e, _)) => bail!(e),
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Add the link, equivalent to `ip link add <link_name> type wireguard`.
|
||||||
|
rtnl_link
|
||||||
|
.add()
|
||||||
|
.wireguard(device_name.to_owned())
|
||||||
|
.execute()
|
||||||
|
.await?;
|
||||||
|
log::info!("Created network device!");
|
||||||
|
|
||||||
|
// Retrieve a handle for the newly created device
|
||||||
|
let device_handle = rtnl_link
|
||||||
|
.get()
|
||||||
|
.match_name(device_name.to_owned())
|
||||||
|
.execute()
|
||||||
|
.err_into::<anyhow::Error>()
|
||||||
|
.try_fold(Option::None, |acc, val| async move {
|
||||||
|
ensure!(acc.is_none(), "\
|
||||||
|
Created a network device for Rosenpass under the name \"{device_name_ref}\", \
|
||||||
|
but upon trying to determine the handle for the device using named-based lookup, we received multiple handles. \
|
||||||
|
We checked beforehand whether the device already exists. \
|
||||||
|
This should not happen. Unsure how to proceed. Terminating.");
|
||||||
|
Ok(Some(val))
|
||||||
|
}).await?
|
||||||
|
.with_context(|| format!("\
|
||||||
|
Created a network device for Rosenpass under the name \"{device_name}\", \
|
||||||
|
but upon trying to determine the handle for the device using named-based lookup, we received no handle. \
|
||||||
|
This should not happen. Unsure how to proceed. Terminating."))?
|
||||||
|
.apply(|msg| msg.header.index);
|
||||||
|
|
||||||
|
// Now we can actually start to mark the device as initialized.
|
||||||
|
// Note that if the handle retrieval above does not work, the destructor
|
||||||
|
// will not run and the device will not be erased.
|
||||||
|
// This is, for now, the desired behavior as we need the handle to erase
|
||||||
|
// the device anyway.
|
||||||
|
self.device = Some((device_handle, device_name));
|
||||||
|
|
||||||
|
// Activate the link, equivalent to `ip link set dev <DEV> up`.
|
||||||
|
rtnl_link.set(device_handle).up().execute().await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn close(mut self) {
|
||||||
|
// Check if the device is properly initialized and retrieve the device info
|
||||||
|
let (device_handle, device_name) = match self.device.take() {
|
||||||
|
Some(val) => val,
|
||||||
|
// Nothing to do, not yet properly initialized
|
||||||
|
None => return,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Erase the network device; the rest of the function is just error handling
|
||||||
|
let res = async move {
|
||||||
|
self.rtnl_netlink_handle()?
|
||||||
|
.link()
|
||||||
|
.del(device_handle)
|
||||||
|
.execute()
|
||||||
|
.await?;
|
||||||
|
log::debug!("Erased network interface!");
|
||||||
|
anyhow::Ok(())
|
||||||
|
}
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Here we test if the error needs printing at all
|
||||||
|
let res = 'do_print: {
|
||||||
|
// Short-circuit if the deletion was successful
|
||||||
|
let err = match res {
|
||||||
|
Ok(()) => break 'do_print Ok(()),
|
||||||
|
Err(err) => err,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Extract the rtnetlink error, so we can inspect it
|
||||||
|
let err = match err.downcast::<netlink::rtnl::Error>() {
|
||||||
|
Ok(rtnl_err) => rtnl_err,
|
||||||
|
Err(other_err) => break 'do_print Err(other_err),
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO: This is a bit brittle, as the rtnetlink error enum looks like
|
||||||
|
// E::NetlinkError is a sort of "unknown error" case. If they explicitly
|
||||||
|
// add support for the "no such device" errors or other ones we check for in
|
||||||
|
// this block, then this code may no longer filter these errors
|
||||||
|
// Extract the raw netlink error code
|
||||||
|
use netlink::rtnl::Error as E;
|
||||||
|
let error_code = match err {
|
||||||
|
E::NetlinkError(ref msg) => -msg.raw_code(),
|
||||||
|
err => break 'do_print Err(err.into()),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check whether its just the "no such device" error
|
||||||
|
#[allow(clippy::single_match)]
|
||||||
|
match error_code {
|
||||||
|
libc::ENODEV => break 'do_print Ok(()),
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, we just print the error
|
||||||
|
Err(err.into())
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Err(err) = res {
|
||||||
|
log::warn!("Could not remove network device `{device_name}`: {err:?}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn add_ip_address(&self, addr: &str) -> anyhow::Result<()> {
|
||||||
|
// TODO: Migrate to using netlink
|
||||||
|
Command::new("ip")
|
||||||
|
.args(["address", "add", addr, "dev", self.name()?])
|
||||||
|
.status()
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_open(&self) -> bool {
|
||||||
|
self.device.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn maybe_name(&self) -> Option<&str> {
|
||||||
|
self.device.as_ref().map(|slot| slot.1.borrow())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the raw handle for this device
|
||||||
|
pub fn maybe_raw_handle(&self) -> Option<u32> {
|
||||||
|
self.device.as_ref().map(|slot| slot.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn name(&self) -> anyhow::Result<&str> {
|
||||||
|
self.maybe_name()
|
||||||
|
.with_context(|| format!("{} has not been initialized!", type_name::<Self>()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the raw handle for this device
|
||||||
|
pub fn raw_handle(&self) -> anyhow::Result<u32> {
|
||||||
|
self.maybe_raw_handle()
|
||||||
|
.with_context(|| format!("{} has not been initialized!", type_name::<Self>()))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn set_private_key_and_listen_addr(
|
||||||
|
&mut self,
|
||||||
|
wgsk: &WgSecretKey,
|
||||||
|
listen_port: Option<u16>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
use netlink as nl;
|
||||||
|
|
||||||
|
// The attributes to set
|
||||||
|
// TODO: This exposes the secret key; we should probably run this in a separate process
|
||||||
|
// or on a separate stack and have zeroizing allocator globally.
|
||||||
|
let mut attrs = vec![
|
||||||
|
nl::wg::DeviceAttrs::IfIndex(self.raw_handle()?),
|
||||||
|
nl::wg::DeviceAttrs::PrivateKey(*wgsk.secret()),
|
||||||
|
];
|
||||||
|
|
||||||
|
// Optional listen port for WireGuard
|
||||||
|
if let Some(port) = listen_port {
|
||||||
|
attrs.push(nl::wg::DeviceAttrs::ListenPort(port));
|
||||||
|
}
|
||||||
|
|
||||||
|
// The netlink request we are trying to send
|
||||||
|
let req = nl::wg::Wireguard {
|
||||||
|
cmd: nl::wg::WireguardCmd::SetDevice,
|
||||||
|
nlas: attrs,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Boilerplate; wrap the request into more structures
|
||||||
|
let req = req
|
||||||
|
.apply(nl::genl::Message::from_payload)
|
||||||
|
.apply(nl::core::NetlinkMessage::from)
|
||||||
|
.mutating(|req| {
|
||||||
|
req.header.flags = nl::core::NLM_F_REQUEST | nl::core::NLM_F_ACK;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Send the request
|
||||||
|
self.genl_netlink_handle()?
|
||||||
|
.request(req)
|
||||||
|
.await?
|
||||||
|
// Collect all errors (let try_fold do all the work)
|
||||||
|
.try_fold((), |_, _| async move { Ok(()) })
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn take_rtnl_netlink_handle(&mut self) -> Result<netlink::rtnl::Handle> {
|
||||||
|
if let Some(handle) = self.rtnl_netlink_handle_cache.take() {
|
||||||
|
Ok(handle)
|
||||||
|
} else {
|
||||||
|
let (connection, handle, _) = rtnetlink::new_connection()?;
|
||||||
|
|
||||||
|
// Making sure that the connection has a chance to terminate before the
|
||||||
|
// application exits
|
||||||
|
try_spawn_daemon(async move {
|
||||||
|
connection.await;
|
||||||
|
Ok(())
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(handle)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rtnl_netlink_handle(&mut self) -> Result<&mut netlink::rtnl::Handle> {
|
||||||
|
let netlink_handle = self.take_rtnl_netlink_handle()?;
|
||||||
|
self.rtnl_netlink_handle_cache.insert(netlink_handle).ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn take_genl_netlink_handle(&mut self) -> Result<netlink::genl::Handle> {
|
||||||
|
if let Some(handle) = self.genl_netlink_handle_cache.take() {
|
||||||
|
Ok(handle)
|
||||||
|
} else {
|
||||||
|
let (connection, handle, _) = genetlink::new_connection()?;
|
||||||
|
|
||||||
|
// Making sure that the connection has a chance to terminate before the
|
||||||
|
// application exits
|
||||||
|
try_spawn_daemon(async move {
|
||||||
|
connection.await;
|
||||||
|
Ok(())
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(handle)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn genl_netlink_handle(&mut self) -> Result<&mut netlink::genl::Handle> {
|
||||||
|
let netlink_handle = self.take_genl_netlink_handle()?;
|
||||||
|
self.genl_netlink_handle_cache.insert(netlink_handle).ok()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct WireGuardDevice {
|
||||||
|
_impl: WireGuardDeviceImpl,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WireGuardDevice {
|
||||||
/// Creates a netlink named `link_name` and changes the state to up. It returns the index
|
/// Creates a netlink named `link_name` and changes the state to up. It returns the index
|
||||||
/// of the interface in the list of interfaces as the result or an error if any of the
|
/// of the interface in the list of interfaces as the result or an error if any of the
|
||||||
/// operations of creating the link or changing its state to up fails.
|
/// operations of creating the link or changing its state to up fails.
|
||||||
pub async fn link_create_and_up(rtnetlink: &Handle, link_name: String) -> Result<u32> {
|
pub async fn create_device(device_name: String) -> Result<Self> {
|
||||||
// Add the link, equivalent to `ip link add <link_name> type wireguard`.
|
let mut _impl = WireGuardDeviceImpl::default();
|
||||||
rtnetlink
|
_impl.open(device_name).await?;
|
||||||
.link()
|
assert!(_impl.is_open()); // Sanity check
|
||||||
.add()
|
Ok(WireGuardDevice { _impl })
|
||||||
.wireguard(link_name.clone())
|
}
|
||||||
.execute()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
// Retrieve the link to be able to up it, equivalent to `ip link show` and then
|
pub fn name(&self) -> &str {
|
||||||
// using the link shown that is identified by `link_name`.
|
self._impl.name().unwrap()
|
||||||
let link = rtnetlink
|
}
|
||||||
.link()
|
|
||||||
.get()
|
/// Return the raw handle for this device
|
||||||
.match_name(link_name.clone())
|
#[allow(dead_code)]
|
||||||
.execute()
|
pub fn raw_handle(&self) -> u32 {
|
||||||
.into_stream()
|
self._impl.raw_handle().unwrap()
|
||||||
.into_future()
|
}
|
||||||
|
|
||||||
|
pub async fn add_ip_address(&self, addr: &str) -> anyhow::Result<()> {
|
||||||
|
self._impl.add_ip_address(addr).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn set_private_key_and_listen_addr(
|
||||||
|
&mut self,
|
||||||
|
wgsk: &WgSecretKey,
|
||||||
|
listen_port: Option<u16>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
self._impl
|
||||||
|
.set_private_key_and_listen_addr(wgsk, listen_port)
|
||||||
.await
|
.await
|
||||||
.0
|
|
||||||
.unwrap()?;
|
|
||||||
|
|
||||||
// Up the link, equivalent to `ip link set dev <DEV> up`.
|
|
||||||
rtnetlink
|
|
||||||
.link()
|
|
||||||
.set(link.header.index)
|
|
||||||
.up()
|
|
||||||
.execute()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(link.header.index)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Deletes a link using rtnetlink. The link is specified using its index in the list of links.
|
|
||||||
pub async fn link_cleanup(rtnetlink: &Handle, index: u32) -> Result<()> {
|
|
||||||
rtnetlink.link().del(index).execute().await?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Deletes a link using rtnetlink. The link is specified using its index in the list of links.
|
|
||||||
/// In contrast to [link_cleanup], this function create a new socket connection to netlink and
|
|
||||||
/// *ignores errors* that occur during deletion.
|
|
||||||
pub async fn link_cleanup_standalone(index: u32) -> Result<()> {
|
|
||||||
let (connection, rtnetlink, _) = rtnetlink::new_connection()?;
|
|
||||||
tokio::spawn(connection);
|
|
||||||
|
|
||||||
// We don't care if this fails, as the device may already have been auto-cleaned up.
|
|
||||||
let _ = rtnetlink.link().del(index).execute().await;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// This replicates the functionality of the `wg set` command line tool.
|
|
||||||
///
|
|
||||||
/// It sets the specified WireGuard attributes of the indexed device by
|
|
||||||
/// communicating with WireGuard's generic netlink interface, like the
|
|
||||||
/// `wg` tool does.
|
|
||||||
pub async fn wg_set(
|
|
||||||
genetlink: &mut GenetlinkHandle,
|
|
||||||
index: u32,
|
|
||||||
mut attr: Vec<WgDeviceAttrs>,
|
|
||||||
) -> Result<()> {
|
|
||||||
use futures_util::StreamExt as _;
|
|
||||||
use netlink_packet_core::{NetlinkMessage, NetlinkPayload};
|
|
||||||
use netlink_packet_generic::GenlMessage;
|
|
||||||
use netlink_packet_wireguard::{Wireguard, WireguardCmd};
|
|
||||||
|
|
||||||
// Scope our `set` command to only the device of the specified index.
|
|
||||||
attr.insert(0, WgDeviceAttrs::IfIndex(index));
|
|
||||||
|
|
||||||
// Construct the WireGuard-specific netlink packet
|
|
||||||
let wgc = Wireguard {
|
|
||||||
cmd: WireguardCmd::SetDevice,
|
|
||||||
nlas: attr,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Construct final message.
|
|
||||||
let genl = GenlMessage::from_payload(wgc);
|
|
||||||
let mut nlmsg = NetlinkMessage::from(genl);
|
|
||||||
nlmsg.header.flags = NLM_F_REQUEST | NLM_F_ACK;
|
|
||||||
|
|
||||||
// Send and wait for the ACK or error.
|
|
||||||
let (res, _) = genetlink.request(nlmsg).await?.into_future().await;
|
|
||||||
if let Some(res) = res {
|
|
||||||
let res = res?;
|
|
||||||
if let NetlinkPayload::Error(err) = res.payload {
|
|
||||||
return Err(err.to_io().into());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A wrapper for a list of cleanup handlers that can be used in an asynchronous context
|
impl Drop for WireGuardDevice {
|
||||||
/// to clean up after the usage of rosenpass or if the `rp` binary is interrupted with ctrl+c
|
fn drop(&mut self) {
|
||||||
/// or a `SIGINT` signal in general.
|
let _impl = self._impl.take();
|
||||||
#[derive(Clone)]
|
spawn_cleanup_job(async move {
|
||||||
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
|
_impl.close().await;
|
||||||
struct CleanupHandlers(
|
Ok(())
|
||||||
Arc<::futures::lock::Mutex<Vec<Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>>>>,
|
});
|
||||||
);
|
|
||||||
|
|
||||||
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
|
|
||||||
impl CleanupHandlers {
|
|
||||||
/// Creates a new list of [CleanupHandlers].
|
|
||||||
fn new() -> Self {
|
|
||||||
CleanupHandlers(Arc::new(::futures::lock::Mutex::new(vec![])))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Enqueues a new cleanup handler in the form of a [Future].
|
|
||||||
async fn enqueue(&self, handler: Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>) {
|
|
||||||
self.0.lock().await.push(Box::pin(handler))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Runs all cleanup handlers. Following the documentation of [futures::future::try_join_all]:
|
|
||||||
/// If any cleanup handler returns an error then all other cleanup handlers will be canceled and
|
|
||||||
/// an error will be returned immediately. If all cleanup handlers complete successfully,
|
|
||||||
/// however, then the returned future will succeed with a Vec of all the successful results.
|
|
||||||
async fn run(self) -> Result<Vec<()>, Error> {
|
|
||||||
futures::future::try_join_all(self.0.lock().await.deref_mut()).await
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sets up the rosenpass link and wireguard and configures both with the configuration specified by
|
/// Sets up the rosenpass link and wireguard and configures both with the configuration specified by
|
||||||
/// `options`.
|
/// `options`.
|
||||||
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
|
|
||||||
pub async fn exchange(options: ExchangeOptions) -> Result<()> {
|
pub async fn exchange(options: ExchangeOptions) -> Result<()> {
|
||||||
use std::fs;
|
// Load the server parameter files
|
||||||
|
let wgsk = options.private_keys_dir.join("wgsk");
|
||||||
|
let rpsk = options.private_keys_dir.join("pqsk");
|
||||||
|
let rppk = options.private_keys_dir.join("pqpk");
|
||||||
|
let (wgsk, rpsk, rppk) = spawn_blocking(move || {
|
||||||
|
let wgsk = WgSecretKey::load_b64::<WG_B64_LEN, _>(wgsk)?;
|
||||||
|
let rpsk = SSk::load(rpsk)?;
|
||||||
|
let wgpk = SPk::load(rppk)?;
|
||||||
|
anyhow::Ok((wgsk, rpsk, wgpk))
|
||||||
|
})
|
||||||
|
.await??;
|
||||||
|
|
||||||
use anyhow::anyhow;
|
// Setup the WireGuard device
|
||||||
use netlink_packet_wireguard::{constants::WG_KEY_LEN, nlas::WgDeviceAttrs};
|
let device = options.dev.as_deref().unwrap_or("rosenpass0");
|
||||||
use rosenpass::{
|
let mut device = WireGuardDevice::create_device(device.to_owned()).await?;
|
||||||
app_server::{AppServer, BrokerPeer},
|
|
||||||
config::Verbosity,
|
|
||||||
protocol::{
|
|
||||||
basic_types::{SPk, SSk, SymKey},
|
|
||||||
osk_domain_separator::OskDomainSeparator,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
use rosenpass_secret_memory::Secret;
|
|
||||||
use rosenpass_util::file::{LoadValue as _, LoadValueB64};
|
|
||||||
use rosenpass_wireguard_broker::brokers::native_unix::{
|
|
||||||
NativeUnixBroker, NativeUnixBrokerConfigBaseBuilder, NativeUnixBrokerConfigBaseBuilderError,
|
|
||||||
};
|
|
||||||
|
|
||||||
let (connection, rtnetlink, _) = rtnetlink::new_connection()?;
|
// Assign WG secret key & port
|
||||||
tokio::spawn(connection);
|
device
|
||||||
|
.set_private_key_and_listen_addr(&wgsk, options.listen.map(|ip| ip.port() + 1))
|
||||||
|
.await?;
|
||||||
|
std::mem::drop(wgsk);
|
||||||
|
|
||||||
let link_name = options.dev.clone().unwrap_or("rosenpass0".to_string());
|
// Assign the public IP address for the interface
|
||||||
let link_index = netlink::link_create_and_up(&rtnetlink, link_name.clone()).await?;
|
if let Some(ref ip) = options.ip {
|
||||||
|
device.add_ip_address(ip).await?;
|
||||||
// Set up a list of (initiallc empty) cleanup handlers that are to be run if
|
|
||||||
// ctrl-c is hit or generally a `SIGINT` signal is received and always in the end.
|
|
||||||
let cleanup_handlers = CleanupHandlers::new();
|
|
||||||
let final_cleanup_handlers = (&cleanup_handlers).clone();
|
|
||||||
|
|
||||||
cleanup_handlers
|
|
||||||
.enqueue(Box::pin(async move {
|
|
||||||
netlink::link_cleanup_standalone(link_index).await
|
|
||||||
}))
|
|
||||||
.await;
|
|
||||||
|
|
||||||
ctrlc_async::set_async_handler(async move {
|
|
||||||
final_cleanup_handlers
|
|
||||||
.run()
|
|
||||||
.await
|
|
||||||
.expect("Failed to clean up");
|
|
||||||
})?;
|
|
||||||
|
|
||||||
// Run `ip address add <ip> dev <dev>` and enqueue `ip address del <ip> dev <dev>` as a cleanup.
|
|
||||||
if let Some(ip) = options.ip {
|
|
||||||
let dev = options.dev.clone().unwrap_or("rosenpass0".to_string());
|
|
||||||
Command::new("ip")
|
|
||||||
.arg("address")
|
|
||||||
.arg("add")
|
|
||||||
.arg(ip.clone())
|
|
||||||
.arg("dev")
|
|
||||||
.arg(dev.clone())
|
|
||||||
.status()
|
|
||||||
.expect("failed to configure ip");
|
|
||||||
cleanup_handlers
|
|
||||||
.enqueue(Box::pin(async move {
|
|
||||||
Command::new("ip")
|
|
||||||
.arg("address")
|
|
||||||
.arg("del")
|
|
||||||
.arg(ip)
|
|
||||||
.arg("dev")
|
|
||||||
.arg(dev)
|
|
||||||
.status()
|
|
||||||
.expect("failed to remove ip");
|
|
||||||
Ok(())
|
|
||||||
}))
|
|
||||||
.await;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deploy the classic wireguard private key.
|
|
||||||
let (connection, mut genetlink, _) = genetlink::new_connection()?;
|
|
||||||
tokio::spawn(connection);
|
|
||||||
|
|
||||||
let wgsk_path = options.private_keys_dir.join("wgsk");
|
|
||||||
|
|
||||||
let wgsk = Secret::<WG_KEY_LEN>::load_b64::<WG_B64_LEN, _>(wgsk_path)?;
|
|
||||||
|
|
||||||
let mut attr: Vec<WgDeviceAttrs> = Vec::with_capacity(2);
|
|
||||||
attr.push(WgDeviceAttrs::PrivateKey(*wgsk.secret()));
|
|
||||||
|
|
||||||
if let Some(listen) = options.listen {
|
|
||||||
if listen.port() == u16::MAX {
|
|
||||||
return Err(anyhow!("You may not use {} as the listen port.", u16::MAX));
|
|
||||||
}
|
|
||||||
|
|
||||||
attr.push(WgDeviceAttrs::ListenPort(listen.port() + 1));
|
|
||||||
}
|
|
||||||
|
|
||||||
netlink::wg_set(&mut genetlink, link_index, attr).await?;
|
|
||||||
|
|
||||||
// set up the rosenpass AppServer
|
|
||||||
let pqsk = options.private_keys_dir.join("pqsk");
|
|
||||||
let pqpk = options.private_keys_dir.join("pqpk");
|
|
||||||
|
|
||||||
let sk = SSk::load(&pqsk)?;
|
|
||||||
let pk = SPk::load(&pqpk)?;
|
|
||||||
|
|
||||||
let mut srv = Box::new(AppServer::new(
|
let mut srv = Box::new(AppServer::new(
|
||||||
Some((sk, pk)),
|
Some((rpsk, rppk)),
|
||||||
if let Some(listen) = options.listen {
|
Vec::from_iter(options.listen),
|
||||||
vec![listen]
|
match options.verbose {
|
||||||
} else {
|
true => Verbosity::Verbose,
|
||||||
Vec::with_capacity(0)
|
false => Verbosity::Quiet,
|
||||||
},
|
|
||||||
if options.verbose {
|
|
||||||
Verbosity::Verbose
|
|
||||||
} else {
|
|
||||||
Verbosity::Quiet
|
|
||||||
},
|
},
|
||||||
None,
|
None,
|
||||||
)?);
|
)?);
|
||||||
|
|
||||||
let broker_store_ptr = srv.register_broker(Box::new(NativeUnixBroker::new()))?;
|
let broker_store_ptr = srv.register_broker(Box::new(NativeUnixBroker::new()))?;
|
||||||
|
|
||||||
fn cfg_err_map(e: NativeUnixBrokerConfigBaseBuilderError) -> anyhow::Error {
|
|
||||||
anyhow::Error::msg(format!("NativeUnixBrokerConfigBaseBuilderError: {:?}", e))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Configure everything per peer.
|
// Configure everything per peer.
|
||||||
for peer in options.peers {
|
for peer in options.peers {
|
||||||
let wgpk = peer.public_keys_dir.join("wgpk");
|
// TODO: Some of this is sync but should be async
|
||||||
|
let wgpk = peer
|
||||||
|
.public_keys_dir
|
||||||
|
.join("wgpk")
|
||||||
|
.apply(tokio::fs::read_to_string)
|
||||||
|
.await?;
|
||||||
let pqpk = peer.public_keys_dir.join("pqpk");
|
let pqpk = peer.public_keys_dir.join("pqpk");
|
||||||
let psk = peer.public_keys_dir.join("psk");
|
let psk = peer.public_keys_dir.join("psk");
|
||||||
|
let (pqpk, psk) = spawn_blocking(move || {
|
||||||
|
let pqpk = SPk::load(pqpk)?;
|
||||||
|
let psk = psk
|
||||||
|
.exists()
|
||||||
|
.then(|| SymKey::load_b64::<WG_B64_LEN, _>(psk))
|
||||||
|
.transpose()?;
|
||||||
|
anyhow::Ok((pqpk, psk))
|
||||||
|
})
|
||||||
|
.await??;
|
||||||
|
|
||||||
let mut extra_params: Vec<String> = Vec::with_capacity(6);
|
let mut extra_params: Vec<String> = Vec::with_capacity(6);
|
||||||
if let Some(endpoint) = peer.endpoint {
|
if let Some(endpoint) = peer.endpoint {
|
||||||
@@ -342,11 +504,11 @@ pub async fn exchange(options: ExchangeOptions) -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let peer_cfg = NativeUnixBrokerConfigBaseBuilder::default()
|
let peer_cfg = NativeUnixBrokerConfigBaseBuilder::default()
|
||||||
.peer_id_b64(&fs::read_to_string(wgpk)?)?
|
.peer_id_b64(&wgpk)?
|
||||||
.interface(link_name.clone())
|
.interface(device.name().to_owned())
|
||||||
.extra_params_ser(&extra_params)?
|
.extra_params_ser(&extra_params)?
|
||||||
.build()
|
.build()
|
||||||
.map_err(cfg_err_map)?;
|
.with_context(|| format!("Could not configure broker to supply keys from Rosenpass to WireGuard for peer {wgpk}."))?;
|
||||||
|
|
||||||
let broker_peer = Some(BrokerPeer::new(
|
let broker_peer = Some(BrokerPeer::new(
|
||||||
broker_store_ptr.clone(),
|
broker_store_ptr.clone(),
|
||||||
@@ -354,13 +516,8 @@ pub async fn exchange(options: ExchangeOptions) -> Result<()> {
|
|||||||
));
|
));
|
||||||
|
|
||||||
srv.add_peer(
|
srv.add_peer(
|
||||||
if psk.exists() {
|
psk,
|
||||||
Some(SymKey::load_b64::<WG_B64_LEN, _>(psk))
|
pqpk,
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
.transpose()?,
|
|
||||||
SPk::load(&pqpk)?,
|
|
||||||
None,
|
None,
|
||||||
broker_peer,
|
broker_peer,
|
||||||
peer.endpoint.map(|x| x.to_string()),
|
peer.endpoint.map(|x| x.to_string()),
|
||||||
@@ -372,47 +529,13 @@ pub async fn exchange(options: ExchangeOptions) -> Result<()> {
|
|||||||
// the cleanup as `ip route del <allowed_ips>`.
|
// the cleanup as `ip route del <allowed_ips>`.
|
||||||
if let Some(allowed_ips) = peer.allowed_ips {
|
if let Some(allowed_ips) = peer.allowed_ips {
|
||||||
Command::new("ip")
|
Command::new("ip")
|
||||||
.arg("route")
|
.args(["route", "replace", &allowed_ips, "dev", device.name()])
|
||||||
.arg("replace")
|
|
||||||
.arg(allowed_ips.clone())
|
|
||||||
.arg("dev")
|
|
||||||
.arg(options.dev.clone().unwrap_or("rosenpass0".to_string()))
|
|
||||||
.status()
|
.status()
|
||||||
.expect("failed to configure route");
|
.await
|
||||||
cleanup_handlers
|
.with_context(|| format!("Could not configure routes for peer {wgpk}"))?;
|
||||||
.enqueue(Box::pin(async move {
|
|
||||||
Command::new("ip")
|
|
||||||
.arg("route")
|
|
||||||
.arg("del")
|
|
||||||
.arg(allowed_ips)
|
|
||||||
.status()
|
|
||||||
.expect("failed to remove ip");
|
|
||||||
Ok(())
|
|
||||||
}))
|
|
||||||
.await;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let out = srv.event_loop();
|
log::info!("Starting to perform rosenpass key exchanges!");
|
||||||
|
spawn_blocking(move || srv.event_loop()).await?
|
||||||
netlink::link_cleanup(&rtnetlink, link_index).await?;
|
|
||||||
|
|
||||||
match out {
|
|
||||||
Ok(_) => Ok(()),
|
|
||||||
Err(e) => {
|
|
||||||
// Check if the returned error is actually EINTR, in which case, the run actually
|
|
||||||
// succeeded.
|
|
||||||
let is_ok = if let Some(e) = e.root_cause().downcast_ref::<std::io::Error>() {
|
|
||||||
matches!(e.kind(), std::io::ErrorKind::Interrupted)
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
};
|
|
||||||
|
|
||||||
if is_ok {
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err(e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,21 +1,28 @@
|
|||||||
use std::{fs, process::exit};
|
use std::{fs, process::exit};
|
||||||
|
|
||||||
use cli::{Cli, Command};
|
use rosenpass_util::tokio::janitor::ensure_janitor;
|
||||||
use exchange::exchange;
|
|
||||||
use key::{genkey, pubkey};
|
|
||||||
use rosenpass_secret_memory::policy;
|
use rosenpass_secret_memory::policy;
|
||||||
|
|
||||||
|
use crate::cli::{Cli, Command};
|
||||||
|
use crate::exchange::exchange;
|
||||||
|
use crate::key::{genkey, pubkey};
|
||||||
|
|
||||||
mod cli;
|
mod cli;
|
||||||
mod exchange;
|
mod exchange;
|
||||||
mod key;
|
mod key;
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() -> anyhow::Result<()> {
|
||||||
#[cfg(feature = "experiment_memfd_secret")]
|
#[cfg(feature = "experiment_memfd_secret")]
|
||||||
policy::secret_policy_try_use_memfd_secrets();
|
policy::secret_policy_try_use_memfd_secrets();
|
||||||
#[cfg(not(feature = "experiment_memfd_secret"))]
|
#[cfg(not(feature = "experiment_memfd_secret"))]
|
||||||
policy::secret_policy_use_only_malloc_secrets();
|
policy::secret_policy_use_only_malloc_secrets();
|
||||||
|
|
||||||
|
ensure_janitor(async move { main_impl().await }).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn main_impl() -> anyhow::Result<()> {
|
||||||
let cli = match Cli::parse(std::env::args().peekable()) {
|
let cli = match Cli::parse(std::env::args().peekable()) {
|
||||||
Ok(cli) => cli,
|
Ok(cli) => cli,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
@@ -24,9 +31,13 @@ async fn main() {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// init logging
|
||||||
|
// TODO: Taken from rosenpass; we should deduplicate the code.
|
||||||
|
env_logger::Builder::from_default_env().init(); // sets log level filter from environment (or defaults)
|
||||||
|
|
||||||
let command = cli.command.unwrap();
|
let command = cli.command.unwrap();
|
||||||
|
|
||||||
let res = match command {
|
match command {
|
||||||
Command::GenKey { private_keys_dir } => genkey(&private_keys_dir),
|
Command::GenKey { private_keys_dir } => genkey(&private_keys_dir),
|
||||||
Command::PubKey {
|
Command::PubKey {
|
||||||
private_keys_dir,
|
private_keys_dir,
|
||||||
@@ -47,13 +58,5 @@ async fn main() {
|
|||||||
println!("Usage: rp [verbose] genkey|pubkey|exchange [ARGS]...");
|
println!("Usage: rp [verbose] genkey|pubkey|exchange [ARGS]...");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
|
||||||
match res {
|
|
||||||
Ok(_) => {}
|
|
||||||
Err(err) => {
|
|
||||||
eprintln!("An error occurred: {}", err);
|
|
||||||
exit(1);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -379,10 +379,7 @@ impl<const N: usize> StoreSecret for Secret<N> {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test {
|
mod test {
|
||||||
use crate::{
|
use crate::{secret_policy_use_only_malloc_secrets, test_spawn_process_provided_policies};
|
||||||
secret_policy_try_use_memfd_secrets, secret_policy_use_only_malloc_secrets,
|
|
||||||
test_spawn_process_provided_policies,
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use std::{fs, os::unix::fs::PermissionsExt};
|
use std::{fs, os::unix::fs::PermissionsExt};
|
||||||
|
|||||||
@@ -630,7 +630,11 @@ version = "3.2.0"
|
|||||||
criteria = "safe-to-run"
|
criteria = "safe-to-run"
|
||||||
|
|
||||||
[[exemptions.signal-hook]]
|
[[exemptions.signal-hook]]
|
||||||
version = "0.3.17"
|
version = "0.3.18"
|
||||||
|
criteria = "safe-to-deploy"
|
||||||
|
|
||||||
|
[[exemptions.signal-hook-mio]]
|
||||||
|
version = "0.2.4"
|
||||||
criteria = "safe-to-deploy"
|
criteria = "safe-to-deploy"
|
||||||
|
|
||||||
[[exemptions.signal-hook-registry]]
|
[[exemptions.signal-hook-registry]]
|
||||||
|
|||||||
@@ -25,7 +25,15 @@ mio = { workspace = true }
|
|||||||
tempfile = { workspace = true }
|
tempfile = { workspace = true }
|
||||||
uds = { workspace = true, optional = true, features = ["mio_1xx"] }
|
uds = { workspace = true, optional = true, features = ["mio_1xx"] }
|
||||||
libcrux-test-utils = { workspace = true, optional = true }
|
libcrux-test-utils = { workspace = true, optional = true }
|
||||||
|
tokio = { workspace = true, optional = true, features = [
|
||||||
|
"macros",
|
||||||
|
"rt-multi-thread",
|
||||||
|
"sync",
|
||||||
|
"time",
|
||||||
|
] }
|
||||||
|
log = { workspace = true }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
experiment_file_descriptor_passing = ["uds"]
|
experiment_file_descriptor_passing = ["uds"]
|
||||||
trace_bench = ["dep:libcrux-test-utils"]
|
trace_bench = ["dep:libcrux-test-utils"]
|
||||||
|
tokio = ["dep:tokio"]
|
||||||
|
|||||||
82
util/src/fmt/debug.rs
Normal file
82
util/src/fmt/debug.rs
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
//! Helpers for string formatting with the debug formatter; extensions for [std::fmt::Debug]
|
||||||
|
|
||||||
|
use std::any::type_name;
|
||||||
|
use std::borrow::{Borrow, BorrowMut};
|
||||||
|
use std::ops::{Deref, DerefMut};
|
||||||
|
|
||||||
|
/// Debug formatter which just prints the type name;
|
||||||
|
/// used to wrap values which do not support the Debug
|
||||||
|
/// trait themselves
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// use rosenpass_util::fmt::debug::NullDebug;
|
||||||
|
///
|
||||||
|
/// // Does not implement debug
|
||||||
|
/// struct NoDebug;
|
||||||
|
///
|
||||||
|
/// #[derive(Debug)]
|
||||||
|
/// struct ShouldSupportDebug {
|
||||||
|
/// #[allow(dead_code)]
|
||||||
|
/// no_debug: NullDebug<NoDebug>,
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// let val = ShouldSupportDebug {
|
||||||
|
/// no_debug: NullDebug(NoDebug),
|
||||||
|
/// };
|
||||||
|
/// ```
|
||||||
|
pub struct NullDebug<T>(pub T);
|
||||||
|
|
||||||
|
impl<T> std::fmt::Debug for NullDebug<T> {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.write_str("NullDebug<")?;
|
||||||
|
f.write_str(type_name::<T>())?;
|
||||||
|
f.write_str(">")?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> From<T> for NullDebug<T> {
|
||||||
|
fn from(value: T) -> Self {
|
||||||
|
NullDebug(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> Deref for NullDebug<T> {
|
||||||
|
type Target = T;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
self.0.borrow()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> DerefMut for NullDebug<T> {
|
||||||
|
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||||
|
self.0.borrow_mut()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> Borrow<T> for NullDebug<T> {
|
||||||
|
fn borrow(&self) -> &T {
|
||||||
|
self.deref()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> BorrowMut<T> for NullDebug<T> {
|
||||||
|
fn borrow_mut(&mut self) -> &mut T {
|
||||||
|
self.deref_mut()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> AsRef<T> for NullDebug<T> {
|
||||||
|
fn as_ref(&self) -> &T {
|
||||||
|
self.deref()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> AsMut<T> for NullDebug<T> {
|
||||||
|
fn as_mut(&mut self) -> &mut T {
|
||||||
|
self.deref_mut()
|
||||||
|
}
|
||||||
|
}
|
||||||
3
util/src/fmt/mod.rs
Normal file
3
util/src/fmt/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
//! Helpers for string formatting; extensions for [std::fmt]
|
||||||
|
|
||||||
|
pub mod debug;
|
||||||
@@ -14,6 +14,7 @@ pub mod controlflow;
|
|||||||
pub mod fd;
|
pub mod fd;
|
||||||
/// File system operations and handling.
|
/// File system operations and handling.
|
||||||
pub mod file;
|
pub mod file;
|
||||||
|
pub mod fmt;
|
||||||
/// Functional programming utilities.
|
/// Functional programming utilities.
|
||||||
pub mod functional;
|
pub mod functional;
|
||||||
/// Input/output operations.
|
/// Input/output operations.
|
||||||
@@ -30,6 +31,8 @@ pub mod option;
|
|||||||
pub mod result;
|
pub mod result;
|
||||||
/// Time and duration utilities.
|
/// Time and duration utilities.
|
||||||
pub mod time;
|
pub mod time;
|
||||||
|
#[cfg(feature = "tokio")]
|
||||||
|
pub mod tokio;
|
||||||
/// Trace benchmarking utilities
|
/// Trace benchmarking utilities
|
||||||
#[cfg(feature = "trace_bench")]
|
#[cfg(feature = "trace_bench")]
|
||||||
pub mod trace_bench;
|
pub mod trace_bench;
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ use crate::fd::{claim_fd_inplace, IntoStdioErr};
|
|||||||
/// &io_stream,
|
/// &io_stream,
|
||||||
/// &mut read_fd_buffer,
|
/// &mut read_fd_buffer,
|
||||||
/// );
|
/// );
|
||||||
////
|
///
|
||||||
/// // Simulated reads; the actual operations will depend on the protocol (implementation details)
|
/// // Simulated reads; the actual operations will depend on the protocol (implementation details)
|
||||||
/// let mut recv_buffer = Vec::<u8>::new();
|
/// let mut recv_buffer = Vec::<u8>::new();
|
||||||
/// let bytes_read = fd_passing_sock.read(&mut recv_buffer[..]).expect("error reading from socket");
|
/// let bytes_read = fd_passing_sock.read(&mut recv_buffer[..]).expect("error reading from socket");
|
||||||
|
|||||||
618
util/src/tokio/janitor.rs
Normal file
618
util/src/tokio/janitor.rs
Normal file
@@ -0,0 +1,618 @@
|
|||||||
|
//! Facilities to spawn tasks that will be reliably executed
|
||||||
|
//! before the current tokio context finishes.
|
||||||
|
//!
|
||||||
|
//! Asynchronous applications often need to manage multiple parallel tasks.
|
||||||
|
//! Tokio supports spawning these tasks with [tokio::task::spawn], but when the
|
||||||
|
//! tokio event loop exits, all lingering background tasks will aborted.
|
||||||
|
//!
|
||||||
|
//! Tokio supports managing multiple parallel tasks, all of which should exit successfully, through
|
||||||
|
//! [tokio::task::JoinSet]. This is a useful and very explicit API. To launch a background job,
|
||||||
|
//! user code needs to be aware of which JoinSet to use, so this can lead to a JoinSet needing to
|
||||||
|
//! be handed around in many parts of the application.
|
||||||
|
//!
|
||||||
|
//! This level of explicitness avoids bugs, but it can be cumbersome to use and it can introduce a
|
||||||
|
//! [function coloring](https://morestina.net/1686/rust-async-is-colored) issue;
|
||||||
|
//! creating a strong distinction between functions which have access
|
||||||
|
//! to a JoinSet (one color) and those that have not (the other color). Functions with the color
|
||||||
|
//! that has access to a JoinSet can call those functions that do not need access, but not the
|
||||||
|
//! other way around. This can make refactoring quite difficult: your refactor needs to use a
|
||||||
|
//! function that requires a JoinSet? Then have fun spending quite a bit of time recoloring
|
||||||
|
//! possibly many parts of your code base.
|
||||||
|
//!
|
||||||
|
//! This module solves this issue by essentially registering a central [JoinSet] through ambient
|
||||||
|
//! (semi-global), task-local variables. The mechanism to register this task-local JoinSet is
|
||||||
|
//! [tokio::task_local].
|
||||||
|
//!
|
||||||
|
//! # Error-handling
|
||||||
|
//!
|
||||||
|
//! The janitor accepts daemons/cleanup jobs which return an [anyhow::Error].
|
||||||
|
//! When any daemon returns an error, then the entire janitor will immediately exit with a failure
|
||||||
|
//! without awaiting the other registered tasks.
|
||||||
|
//!
|
||||||
|
//! The janitor can generally produce errors in three scenarios:
|
||||||
|
//!
|
||||||
|
//! - A daemon panics
|
||||||
|
//! - A daemon returns an error
|
||||||
|
//! - An internal error
|
||||||
|
//!
|
||||||
|
//! When [enter_janitor]/[ensure_janitor] is used to set up a janitor, these functions will always
|
||||||
|
//! panic in case of a janitor error. **This also means, that these functions panic if any daemon
|
||||||
|
//! returns an error**.
|
||||||
|
//!
|
||||||
|
//! You can explicitly handle janitor errors through [try_enter_janitor]/[try_ensure_janitor].
|
||||||
|
//!
|
||||||
|
//! # Examples
|
||||||
|
//!
|
||||||
|
#![doc = "```ignore"]
|
||||||
|
#![doc = include_str!("../../tests/janitor.rs")]
|
||||||
|
#![doc = "```"]
|
||||||
|
|
||||||
|
use std::any::type_name;
|
||||||
|
use std::future::Future;
|
||||||
|
|
||||||
|
use anyhow::{bail, Context};
|
||||||
|
|
||||||
|
use tokio::task::{AbortHandle, JoinError, JoinHandle, JoinSet};
|
||||||
|
use tokio::task_local;
|
||||||
|
|
||||||
|
use tokio::sync::mpsc::unbounded_channel as janitor_channel;
|
||||||
|
|
||||||
|
use crate::tokio::local_key::LocalKeyExt;
|
||||||
|
|
||||||
|
/// Type for the message queue from [JanitorClient]/[JanitorSupervisor] to [JanitorAgent]: Receiving side
|
||||||
|
type JanitorQueueRx = tokio::sync::mpsc::UnboundedReceiver<JanitorTicket>;
|
||||||
|
/// Type for the message queue from [JanitorClient]/[JanitorSupervisor] to [JanitorAgent]: Sending side
|
||||||
|
type JanitorQueueTx = tokio::sync::mpsc::UnboundedSender<JanitorTicket>;
|
||||||
|
/// Type for the message queue from [JanitorClient]/[JanitorSupervisor] to [JanitorAgent]: Sending side, Weak reference
|
||||||
|
type WeakJanitorQueueTx = tokio::sync::mpsc::WeakUnboundedSender<JanitorTicket>;
|
||||||
|
|
||||||
|
/// Type of the return value for jobs submitted to [spawn_daemon]/[spawn_cleanup_job]
|
||||||
|
type CleanupJobResult = anyhow::Result<()>;
|
||||||
|
/// Handle by which we internally refer to cleanup jobs submitted by [spawn_daemon]/[spawn_cleanup_job]
|
||||||
|
/// to the current [JanitorAgent]
|
||||||
|
type CleanupJob = JoinHandle<CleanupJobResult>;
|
||||||
|
|
||||||
|
task_local! {
|
||||||
|
/// Handle to the current [JanitorAgent]; this is where [ensure_janitor]/[enter_janitor]
|
||||||
|
/// register the newly created janitor
|
||||||
|
static CURRENT_JANITOR: JanitorClient;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Messages supported by [JanitorAgent]
|
||||||
|
#[derive(Debug)]
|
||||||
|
enum JanitorTicket {
|
||||||
|
/// This message transmits a new cleanup job to the [JanitorAgent]
|
||||||
|
CleanupJob(CleanupJob),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Represents the background task which actually manages cleanup jobs.
|
||||||
|
///
|
||||||
|
/// This is what is started by [enter_janitor]/[ensure_janitor]
|
||||||
|
/// and what receives the messages sent by [JanitorSupervisor]/[JanitorClient]
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct JanitorAgent {
|
||||||
|
/// Background tasks currently registered with this agent.
|
||||||
|
///
|
||||||
|
/// This contains two types of tasks:
|
||||||
|
///
|
||||||
|
/// 1. Background jobs launched through [enter_janitor]/[ensure_janitor]
|
||||||
|
/// 2. A single task waiting for new [JanitorTicket]s being transmitted from a [JanitorSupervisor]/[JanitorClient]
|
||||||
|
tasks: JoinSet<AgentInternalEvent>,
|
||||||
|
/// Whether this [JanitorAgent] will ever receive new [JanitorTicket]s
|
||||||
|
///
|
||||||
|
/// Communication between [JanitorAgent] and [JanitorSupervisor]/[JanitorClient] uses a message
|
||||||
|
/// queue (see [JanitorQueueTx]/[JanitorQueueRx]/[WeakJanitorQueueTx]), but you may notice that
|
||||||
|
/// the Agent does not actually contain a field storing the message queue.
|
||||||
|
/// Instead, to appease the borrow checker, the message queue is moved into the internal
|
||||||
|
/// background task (see [Self::tasks]) that waits for new [JanitorTicket]s.
|
||||||
|
///
|
||||||
|
/// Since our state machine still needs to know, whether that queue is closed, we maintain this
|
||||||
|
/// flag.
|
||||||
|
///
|
||||||
|
/// See [AgentInternalEvent::TicketQueueClosed].
|
||||||
|
ticket_queue_closed: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// These are the return values (events) returned by [JanitorAgent] internal tasks (see
|
||||||
|
/// [JanitorAgent::tasks]).
|
||||||
|
#[derive(Debug)]
|
||||||
|
enum AgentInternalEvent {
|
||||||
|
/// Notifies the [JanitorAgent] state machine that a cleanup job finished successfully
|
||||||
|
///
|
||||||
|
/// Sent by genuine background tasks registered through [enter_janitor]/[ensure_janitor].
|
||||||
|
CleanupJobSuccessful,
|
||||||
|
/// Notifies the [JanitorAgent] state machine that a cleanup job finished with a tokio
|
||||||
|
/// [JoinError].
|
||||||
|
///
|
||||||
|
/// Sent by genuine background tasks registered through [enter_janitor]/[ensure_janitor].
|
||||||
|
CleanupJobJoinError(JoinError),
|
||||||
|
/// Notifies the [JanitorAgent] state machine that a cleanup job returned an error.
|
||||||
|
///
|
||||||
|
/// Sent by genuine background tasks registered through [enter_janitor]/[ensure_janitor].
|
||||||
|
CleanupJobReturnedError(anyhow::Error),
|
||||||
|
/// Notifies the [JanitorAgent] state machine that a new cleanup job was received through the
|
||||||
|
/// ticket queue.
|
||||||
|
///
|
||||||
|
/// Sent by the background task managing the ticket queue.
|
||||||
|
ReceivedCleanupJob(JanitorQueueRx, CleanupJob),
|
||||||
|
/// Notifies the [JanitorAgent] state machine that a new cleanup job was received through the
|
||||||
|
/// ticket queue.
|
||||||
|
///
|
||||||
|
/// Sent by the background task managing the ticket queue.
|
||||||
|
///
|
||||||
|
/// See [JanitorAgent::ticket_queue_closed].
|
||||||
|
TicketQueueClosed,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl JanitorAgent {
|
||||||
|
/// Create a new agent. Start with [Self::start].
|
||||||
|
fn new() -> Self {
|
||||||
|
let tasks = JoinSet::new();
|
||||||
|
let ticket_queue_closed = false;
|
||||||
|
Self {
|
||||||
|
tasks,
|
||||||
|
ticket_queue_closed,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Main entry point for the [JanitorAgent]. Launches the background task and returns a [JanitorSupervisor]
|
||||||
|
/// which can be used to send tickets to the agent and to wait for agent termination.
|
||||||
|
pub async fn start() -> JanitorSupervisor {
|
||||||
|
let (queue_tx, queue_rx) = janitor_channel();
|
||||||
|
let join_handle = tokio::spawn(async move { Self::new().event_loop(queue_rx).await });
|
||||||
|
JanitorSupervisor::new(join_handle, queue_tx)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Event loop, processing events from the ticket queue and from [Self::tasks]
|
||||||
|
async fn event_loop(&mut self, queue_rx: JanitorQueueRx) -> anyhow::Result<()> {
|
||||||
|
// Seed the internal task list with a single task to receive
|
||||||
|
self.spawn_internal_recv_ticket_task(queue_rx).await;
|
||||||
|
|
||||||
|
// Process all incoming events until handle_one_event indicates there are
|
||||||
|
// no more events to process
|
||||||
|
while self.handle_one_event().await?.is_some() {}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Process events from [Self::tasks] (and by proxy from the ticket queue)
|
||||||
|
///
|
||||||
|
/// This is the agent's main state machine.
|
||||||
|
async fn handle_one_event(&mut self) -> anyhow::Result<Option<()>> {
|
||||||
|
use AgentInternalEvent as E;
|
||||||
|
match (self.tasks.join_next().await, self.ticket_queue_closed) {
|
||||||
|
// Normal, successful operation
|
||||||
|
|
||||||
|
// CleanupJob exited successfully, no action neccesary
|
||||||
|
(Some(Ok(E::CleanupJobSuccessful)), _) => Ok(Some(())),
|
||||||
|
|
||||||
|
// New cleanup job scheduled, add to task list and wait for another task
|
||||||
|
(Some(Ok(E::ReceivedCleanupJob(queue_rx, job))), _) => {
|
||||||
|
self.spawn_internal_recv_ticket_task(queue_rx).await;
|
||||||
|
self.spawn_internal_cleanup_task(job).await;
|
||||||
|
Ok(Some(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ticket queue is closed; now we are just waiting for the remaining cleanup jobs
|
||||||
|
// to terminate
|
||||||
|
(Some(Ok(E::TicketQueueClosed)), _) => {
|
||||||
|
self.ticket_queue_closed = true;
|
||||||
|
Ok(Some(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// No more tasks in the task manager and the ticket queue is already closed.
|
||||||
|
// This just means we are done and can finally terminate the janitor agent
|
||||||
|
(Option::None, true) => Ok(None),
|
||||||
|
|
||||||
|
// Error handling
|
||||||
|
|
||||||
|
// User callback errors
|
||||||
|
|
||||||
|
// Some cleanup job returned an error as a result
|
||||||
|
(Some(Ok(E::CleanupJobReturnedError(err))), _) => Err(err).with_context(|| {
|
||||||
|
format!("Error in cleanup job handled by {}", type_name::<Self>())
|
||||||
|
}),
|
||||||
|
|
||||||
|
// JoinError produced by the user task: The user task was cancelled.
|
||||||
|
(Some(Ok(E::CleanupJobJoinError(err))), _) if err.is_cancelled() => Err(err).with_context(|| {
|
||||||
|
format!(
|
||||||
|
"Error in cleanup job handled by {me}; the cleanup task was cancelled.
|
||||||
|
This should not happend and likely indicates a developer error in {me}.",
|
||||||
|
me = type_name::<Self>()
|
||||||
|
)
|
||||||
|
}),
|
||||||
|
|
||||||
|
// JoinError produced by the user task: The user task panicked
|
||||||
|
(Some(Ok(E::CleanupJobJoinError(err))), _) => Err(err).with_context(|| {
|
||||||
|
format!(
|
||||||
|
"Error in cleanup job handled by {}; looks like the cleanup task panicked.",
|
||||||
|
type_name::<Self>()
|
||||||
|
)
|
||||||
|
}),
|
||||||
|
|
||||||
|
// Internal errors: Internal task error
|
||||||
|
|
||||||
|
// JoinError produced by JoinSet::join_next(): The internal task was cancelled
|
||||||
|
(Some(Err(err)), _) if err.is_cancelled() => Err(err).with_context(|| {
|
||||||
|
format!(
|
||||||
|
"Internal error in {me}; internal async task was cancelled. \
|
||||||
|
This is probably a developer error in {me}.",
|
||||||
|
me = type_name::<Self>()
|
||||||
|
)
|
||||||
|
}),
|
||||||
|
|
||||||
|
// JoinError produced by JoinSet::join_next(): The internal task panicked
|
||||||
|
(Some(Err(err)), _) => Err(err).with_context(|| {
|
||||||
|
format!(
|
||||||
|
"Internal error in {me}; internal async task panicked. \
|
||||||
|
This is probably a developer error in {me}.",
|
||||||
|
me = type_name::<Self>()
|
||||||
|
)
|
||||||
|
}),
|
||||||
|
|
||||||
|
|
||||||
|
// Internal errors: State machine failure
|
||||||
|
|
||||||
|
// No tasks left, but ticket queue was not drained
|
||||||
|
(Option::None, false) => bail!("Internal error in {me}::handle_one_event(); \
|
||||||
|
there are no more internal tasks active, but the ticket queue was not drained. \
|
||||||
|
The {me}::handle_one_event() code is deliberately designed to never leave the internal task set empty; \
|
||||||
|
instead, there should always be one task to receive new cleanup jobs from the task queue unless the task \
|
||||||
|
queue has been closed. \
|
||||||
|
This is probably a developer error.",
|
||||||
|
me = type_name::<Self>())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Used by [Self::event_loop] and [Self::handle_one_event] to start the internal
|
||||||
|
/// task waiting for tickets on the ticket queue.
|
||||||
|
async fn spawn_internal_recv_ticket_task(
|
||||||
|
&mut self,
|
||||||
|
mut queue_rx: JanitorQueueRx,
|
||||||
|
) -> AbortHandle {
|
||||||
|
self.tasks.spawn(async {
|
||||||
|
use AgentInternalEvent as E;
|
||||||
|
use JanitorTicket as T;
|
||||||
|
|
||||||
|
let ticket = queue_rx.recv().await;
|
||||||
|
match ticket {
|
||||||
|
Some(T::CleanupJob(job)) => E::ReceivedCleanupJob(queue_rx, job),
|
||||||
|
Option::None => E::TicketQueueClosed,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Used by [Self::event_loop] and [Self::handle_one_event] to register
|
||||||
|
/// background deamons/cleanup jobs submitted via [JanitorTicket]
|
||||||
|
async fn spawn_internal_cleanup_task(&mut self, job: CleanupJob) -> AbortHandle {
|
||||||
|
self.tasks.spawn(async {
|
||||||
|
use AgentInternalEvent as E;
|
||||||
|
match job.await {
|
||||||
|
Ok(Ok(())) => E::CleanupJobSuccessful,
|
||||||
|
Ok(Err(e)) => E::CleanupJobReturnedError(e),
|
||||||
|
Err(e) => E::CleanupJobJoinError(e),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Client for [JanitorAgent]. Allows for [JanitorTicket]s (background jobs)
|
||||||
|
/// to be transmitted to the current [JanitorAgent].
|
||||||
|
///
|
||||||
|
/// This is stored in [CURRENT_JANITOR] as a task.-local variable.
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct JanitorClient {
|
||||||
|
/// Queue we can use to send messages to the current janitor
|
||||||
|
queue_tx: WeakJanitorQueueTx,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl JanitorClient {
|
||||||
|
/// Create a new client. Use through [JanitorSupervisor::get_client]
|
||||||
|
fn new(queue_tx: WeakJanitorQueueTx) -> Self {
|
||||||
|
Self { queue_tx }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Has the associated [JanitorAgent] shut down?
|
||||||
|
pub fn is_closed(&self) -> bool {
|
||||||
|
self.queue_tx
|
||||||
|
.upgrade()
|
||||||
|
.map(|channel| channel.is_closed())
|
||||||
|
.unwrap_or(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Spawn a new cleanup job/daemon with the associated [JanitorAgent].
|
||||||
|
///
|
||||||
|
/// Used internally by [spawn_daemon]/[spawn_cleanup_job].
|
||||||
|
pub fn spawn_cleanup_task<F>(&self, future: F) -> Result<(), TrySpawnCleanupJobError>
|
||||||
|
where
|
||||||
|
F: Future<Output = anyhow::Result<()>> + Send + 'static,
|
||||||
|
{
|
||||||
|
let background_task = tokio::spawn(future);
|
||||||
|
self.queue_tx
|
||||||
|
.upgrade()
|
||||||
|
.ok_or(TrySpawnCleanupJobError::ActiveJanitorTerminating)?
|
||||||
|
.send(JanitorTicket::CleanupJob(background_task))
|
||||||
|
.map_err(|_| TrySpawnCleanupJobError::ActiveJanitorTerminating)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Client for [JanitorAgent]. Allows waiting for [JanitorAgent] termination as well as creating
|
||||||
|
/// [JanitorClient]s, which in turn can be used to submit background daemons/termination jobs
|
||||||
|
/// to the agent.
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct JanitorSupervisor {
|
||||||
|
/// Represents the tokio task associated with the [JanitorAgent].
|
||||||
|
///
|
||||||
|
/// We use this to wait for [JanitorAgent] termination in [enter_janitor]/[ensure_janitor]
|
||||||
|
agent_join_handle: CleanupJob,
|
||||||
|
/// Queue we can use to send messages to the current janitor
|
||||||
|
queue_tx: JanitorQueueTx,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl JanitorSupervisor {
|
||||||
|
/// Create a new janitor supervisor. Use through [JanitorAgent::start]
|
||||||
|
pub fn new(agent_join_handle: CleanupJob, queue_tx: JanitorQueueTx) -> Self {
|
||||||
|
Self {
|
||||||
|
agent_join_handle,
|
||||||
|
queue_tx,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a [JanitorClient] for submitting background daemons/cleanup jobs
|
||||||
|
pub fn get_client(&self) -> JanitorClient {
|
||||||
|
JanitorClient::new(self.queue_tx.clone().downgrade())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wait for [JanitorAgent] termination
|
||||||
|
pub async fn terminate_janitor(self) -> anyhow::Result<()> {
|
||||||
|
std::mem::drop(self.queue_tx);
|
||||||
|
self.agent_join_handle.await?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return value of [try_enter_janitor].
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct EnterJanitorResult<T, E> {
|
||||||
|
/// The result produced by the janitor itself.
|
||||||
|
///
|
||||||
|
/// This may contain an error if one of the background daemons/cleanup tasks returned an error,
|
||||||
|
/// panicked, or in case there is an internal error in the janitor.
|
||||||
|
pub janitor_result: anyhow::Result<()>,
|
||||||
|
/// Contains the result of the future passed to [try_enter_janitor].
|
||||||
|
pub callee_result: Result<T, E>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, E> EnterJanitorResult<T, E> {
|
||||||
|
/// Create a new result from its components
|
||||||
|
pub fn new(janitor_result: anyhow::Result<()>, callee_result: Result<T, E>) -> Self {
|
||||||
|
Self {
|
||||||
|
janitor_result,
|
||||||
|
callee_result,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Turn this named type into a tuple
|
||||||
|
pub fn into_tuple(self) -> (anyhow::Result<()>, Result<T, E>) {
|
||||||
|
(self.janitor_result, self.callee_result)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Panic if [Self::janitor_result] contains an error; returning [Self::callee_result]
|
||||||
|
/// otherwise.
|
||||||
|
///
|
||||||
|
/// If this panics and both [Self::janitor_result] and [Self::callee_result] contain an error,
|
||||||
|
/// this will print both errors.
|
||||||
|
pub fn unwrap_janitor_result(self) -> Result<T, E>
|
||||||
|
where
|
||||||
|
E: std::fmt::Debug,
|
||||||
|
{
|
||||||
|
let me: EnsureJanitorResult<T, E> = self.into();
|
||||||
|
me.unwrap_janitor_result()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Panic if [Self::janitor_result] or [Self::callee_result] contain an error,
|
||||||
|
/// returning the Ok value of [Self::callee_result].
|
||||||
|
///
|
||||||
|
/// If this panics and both [Self::janitor_result] and [Self::callee_result] contain an error,
|
||||||
|
/// this will print both errors.
|
||||||
|
pub fn unwrap(self) -> T
|
||||||
|
where
|
||||||
|
E: std::fmt::Debug,
|
||||||
|
{
|
||||||
|
let me: EnsureJanitorResult<T, E> = self.into();
|
||||||
|
me.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return value of [try_ensure_janitor]. The only difference compared to [EnterJanitorResult]
|
||||||
|
/// is that [Self::janitor_result] contains None in case an ambient janitor had already existed.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct EnsureJanitorResult<T, E> {
|
||||||
|
/// See [EnterJanitorResult::janitor_result]
|
||||||
|
///
|
||||||
|
/// This is:
|
||||||
|
///
|
||||||
|
/// - `None` if a pre-existing ambient janitor was used
|
||||||
|
/// - `Some(Ok(()))` if a new janitor had to be created and it exited successfully
|
||||||
|
/// - `Some(Err(...))` if a new janitor had to be created and it exited with an error
|
||||||
|
pub janitor_result: Option<anyhow::Result<()>>,
|
||||||
|
/// See [EnterJanitorResult::callee]
|
||||||
|
pub callee_result: Result<T, E>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, E> EnsureJanitorResult<T, E> {
|
||||||
|
/// See [EnterJanitorResult::new]
|
||||||
|
pub fn new(janitor_result: Option<anyhow::Result<()>>, callee_result: Result<T, E>) -> Self {
|
||||||
|
Self {
|
||||||
|
janitor_result,
|
||||||
|
callee_result,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sets up a [EnsureJanitorResult] with [EnsureJanitorResult::janitor_result] = None.
|
||||||
|
pub fn from_callee_result(callee_result: Result<T, E>) -> Self {
|
||||||
|
Self::new(None, callee_result)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Turn this named type into a tuple
|
||||||
|
pub fn into_tuple(self) -> (Option<anyhow::Result<()>>, Result<T, E>) {
|
||||||
|
(self.janitor_result, self.callee_result)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// See [EnterJanitorResult::unwrap_janitor_result]
|
||||||
|
///
|
||||||
|
/// If [Self::janitor_result] is None, this won't panic.
|
||||||
|
pub fn unwrap_janitor_result(self) -> Result<T, E>
|
||||||
|
where
|
||||||
|
E: std::fmt::Debug,
|
||||||
|
{
|
||||||
|
match self.into_tuple() {
|
||||||
|
(Some(Ok(())) | None, res) => res,
|
||||||
|
(Some(Err(err)), Ok(_)) => panic!(
|
||||||
|
"Callee in enter_janitor()/ensure_janitor() was successful, \
|
||||||
|
but the janitor or some of its deamons failed: {err:?}"
|
||||||
|
),
|
||||||
|
(Some(Err(jerr)), Err(cerr)) => panic!(
|
||||||
|
"Both the calee and the janitor or \
|
||||||
|
some of its deamons falied in enter_janitor()/ensure_janitor():\n\
|
||||||
|
\n\
|
||||||
|
Janitor/Daemon error: {jerr:?}
|
||||||
|
\n\
|
||||||
|
Callee error: {cerr:?}"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// See [EnterJanitorResult::unwrap]
|
||||||
|
///
|
||||||
|
/// If [Self::janitor_result] is None, this is not considered a failure.
|
||||||
|
pub fn unwrap(self) -> T
|
||||||
|
where
|
||||||
|
E: std::fmt::Debug,
|
||||||
|
{
|
||||||
|
match self.unwrap_janitor_result() {
|
||||||
|
Ok(val) => val,
|
||||||
|
Err(err) => panic!(
|
||||||
|
"Janitor or and its deamons in in enter_janitor()/ensure_janitor() was successful, \
|
||||||
|
but the callee itself failed: {err:?}"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, E> From<EnterJanitorResult<T, E>> for EnsureJanitorResult<T, E> {
|
||||||
|
fn from(val: EnterJanitorResult<T, E>) -> Self {
|
||||||
|
EnsureJanitorResult::new(Some(val.janitor_result), val.callee_result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Non-panicking version of [enter_janitor].
|
||||||
|
pub async fn try_enter_janitor<T, E, F>(future: F) -> EnterJanitorResult<T, E>
|
||||||
|
where
|
||||||
|
T: 'static,
|
||||||
|
F: Future<Output = Result<T, E>> + 'static,
|
||||||
|
{
|
||||||
|
let janitor_handle = JanitorAgent::start().await;
|
||||||
|
let callee_result = CURRENT_JANITOR
|
||||||
|
.scope(janitor_handle.get_client(), future)
|
||||||
|
.await;
|
||||||
|
let janitor_result = janitor_handle.terminate_janitor().await;
|
||||||
|
EnterJanitorResult::new(janitor_result, callee_result)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Non-panicking version of [ensure_janitor]
|
||||||
|
pub async fn try_ensure_janitor<T, E, F>(future: F) -> EnsureJanitorResult<T, E>
|
||||||
|
where
|
||||||
|
T: 'static,
|
||||||
|
F: Future<Output = Result<T, E>> + 'static,
|
||||||
|
{
|
||||||
|
match CURRENT_JANITOR.is_set() {
|
||||||
|
true => EnsureJanitorResult::from_callee_result(future.await),
|
||||||
|
false => try_enter_janitor(future).await.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register a janitor that can be used to register background daemons/cleanup jobs **only within
|
||||||
|
/// the future passed to this**.
|
||||||
|
///
|
||||||
|
/// The function will wait for both the given future and all background jobs registered with the
|
||||||
|
/// janitor to terminate.
|
||||||
|
///
|
||||||
|
/// For a version that does not panick, see [try_enter_janitor].
|
||||||
|
pub async fn enter_janitor<T, E, F>(future: F) -> Result<T, E>
|
||||||
|
where
|
||||||
|
T: 'static,
|
||||||
|
E: std::fmt::Debug,
|
||||||
|
F: Future<Output = Result<T, E>> + 'static,
|
||||||
|
{
|
||||||
|
try_enter_janitor(future).await.unwrap_janitor_result()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Variant of [enter_janitor] that will first check if a janitor already exists.
|
||||||
|
/// A new janitor is only set up, if no janitor has been previously registered.
|
||||||
|
pub async fn ensure_janitor<T, E, F>(future: F) -> Result<T, E>
|
||||||
|
where
|
||||||
|
T: 'static,
|
||||||
|
E: std::fmt::Debug,
|
||||||
|
F: Future<Output = Result<T, E>> + 'static,
|
||||||
|
{
|
||||||
|
try_ensure_janitor(future).await.unwrap_janitor_result()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Error returned by [try_spawn_cleanup_job]
|
||||||
|
#[derive(thiserror::Error, Debug)]
|
||||||
|
pub enum TrySpawnCleanupJobError {
|
||||||
|
/// No active janitor exists
|
||||||
|
#[error("No janitor registered. Did the developer forget to call enter_janitor(…) or ensure_janitor(…)?")]
|
||||||
|
NoActiveJanitor,
|
||||||
|
/// The currently active janitor is in the process of terminating
|
||||||
|
#[error("There is a registered janitor, but it is currently in the process of terminating and won't accept new tasks.")]
|
||||||
|
ActiveJanitorTerminating,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check whether a janitor has been set up with [enter_janitor]/[ensure_janitor]
|
||||||
|
pub fn has_active_janitor() -> bool {
|
||||||
|
CURRENT_JANITOR
|
||||||
|
.try_with(|client| client.is_closed())
|
||||||
|
.unwrap_or(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Non-panicking variant of [spawn_cleanup_job].
|
||||||
|
///
|
||||||
|
/// This function is available under two names; see [spawn_cleanup_job] for details about this:
|
||||||
|
///
|
||||||
|
/// 1. [try_spawn_cleanup_job]
|
||||||
|
/// 2. [try_spawn_daemon]
|
||||||
|
pub fn try_spawn_cleanup_job<F>(future: F) -> Result<(), TrySpawnCleanupJobError>
|
||||||
|
where
|
||||||
|
F: Future<Output = anyhow::Result<()>> + Send + 'static,
|
||||||
|
{
|
||||||
|
CURRENT_JANITOR
|
||||||
|
.try_with(|client| client.spawn_cleanup_task(future))
|
||||||
|
.map_err(|_| TrySpawnCleanupJobError::NoActiveJanitor)??;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register a cleanup job or a daemon with the current janitor registered through
|
||||||
|
/// [enter_janitor]/[ensure_janitor]:
|
||||||
|
///
|
||||||
|
/// This function is available under two names:
|
||||||
|
///
|
||||||
|
/// 1. [spawn_cleanup_job]
|
||||||
|
/// 2. [spawn_daemon]
|
||||||
|
///
|
||||||
|
/// The first name should be used in destructors and to spawn cleanup actions which immediately
|
||||||
|
/// begin their task.
|
||||||
|
///
|
||||||
|
/// The second name should be used for any other tasks; e.g. when the janitor setup is used to
|
||||||
|
/// manage multiple parallel jobs, all of which must be waited for.
|
||||||
|
pub fn spawn_cleanup_job<F>(future: F)
|
||||||
|
where
|
||||||
|
F: Future<Output = anyhow::Result<()>> + Send + 'static,
|
||||||
|
{
|
||||||
|
if let Err(e) = try_spawn_cleanup_job(future) {
|
||||||
|
panic!("Could not spawn cleanup job/daemon: {e:?}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub use spawn_cleanup_job as spawn_daemon;
|
||||||
|
pub use try_spawn_cleanup_job as try_spawn_daemon;
|
||||||
13
util/src/tokio/local_key.rs
Normal file
13
util/src/tokio/local_key.rs
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
//! Helpers for [tokio::task::LocalKey]
|
||||||
|
|
||||||
|
/// Extension trait for [tokio::task::LocalKey]
|
||||||
|
pub trait LocalKeyExt {
|
||||||
|
/// Check whether a tokio LocalKey is set
|
||||||
|
fn is_set(&'static self) -> bool;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: 'static> LocalKeyExt for tokio::task::LocalKey<T> {
|
||||||
|
fn is_set(&'static self) -> bool {
|
||||||
|
self.try_with(|_| ()).is_ok()
|
||||||
|
}
|
||||||
|
}
|
||||||
4
util/src/tokio/mod.rs
Normal file
4
util/src/tokio/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
//! Tokio-related utilities
|
||||||
|
|
||||||
|
pub mod janitor;
|
||||||
|
pub mod local_key;
|
||||||
85
util/tests/janitor.rs
Normal file
85
util/tests/janitor.rs
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
#![cfg(feature = "tokio")]
|
||||||
|
|
||||||
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use tokio::time::sleep;
|
||||||
|
|
||||||
|
use rosenpass_util::tokio::janitor::{enter_janitor, spawn_cleanup_job, try_spawn_daemon};
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn janitor_demo() -> anyhow::Result<()> {
|
||||||
|
let count = Arc::new(AtomicUsize::new(0));
|
||||||
|
|
||||||
|
// Make sure the program has access to an ambient janitor
|
||||||
|
{
|
||||||
|
let count = count.clone();
|
||||||
|
enter_janitor(async move {
|
||||||
|
let _drop_guard = AsyncDropDemo::new(count.clone()).await;
|
||||||
|
|
||||||
|
// Start a background job
|
||||||
|
{
|
||||||
|
let count = count.clone();
|
||||||
|
try_spawn_daemon(async move {
|
||||||
|
for _ in 0..17 {
|
||||||
|
count.fetch_add(1, Ordering::Relaxed);
|
||||||
|
sleep(Duration::from_micros(200)).await;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start another
|
||||||
|
{
|
||||||
|
let count = count.clone();
|
||||||
|
try_spawn_daemon(async move {
|
||||||
|
for _ in 0..6 {
|
||||||
|
count.fetch_add(100, Ordering::Relaxed);
|
||||||
|
sleep(Duration::from_micros(800)).await;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note how this function just starts a couple background jobs, but exits immediately
|
||||||
|
|
||||||
|
anyhow::Ok(())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// At this point, all background jobs have finished, now we can check the result of all our
|
||||||
|
// additions
|
||||||
|
assert_eq!(count.load(Ordering::Acquire), 41617);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Demo of how janitor can be used to implement async destructors
|
||||||
|
struct AsyncDropDemo {
|
||||||
|
count: Arc<AtomicUsize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncDropDemo {
|
||||||
|
async fn new(count: Arc<AtomicUsize>) -> Self {
|
||||||
|
count.fetch_add(1000, Ordering::Relaxed);
|
||||||
|
sleep(Duration::from_micros(50)).await;
|
||||||
|
AsyncDropDemo { count }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for AsyncDropDemo {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
let count = self.count.clone();
|
||||||
|
// This necessarily uses the panicking variant;
|
||||||
|
// we use spawn_cleanup_job because this makes more semantic sense in this context
|
||||||
|
spawn_cleanup_job(async move {
|
||||||
|
for _ in 0..4 {
|
||||||
|
count.fetch_add(10000, Ordering::Relaxed);
|
||||||
|
sleep(Duration::from_micros(800)).await;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user