Secret memory with memfd_secret (#321)

Implements:
- An additional allocator to use memfd_secret(2) and guard pages using mmap(2), implemented in quininer/memsec#16
- An allocator that abstracts away underlying allocators, and uses specified allocator set by rosenpass_secret_memory::policy functions (or a function that sets rosenpass_secret_memory::alloc::ALLOC_INIT
- Updates to tests- integration, fuzz, bench: some tests use procspawn to spawn multiple processes with different allocator policies
This commit is contained in:
Prabhpreet Dua
2024-06-10 13:12:44 +05:30
committed by GitHub
parent b46fca99cb
commit 526c930119
29 changed files with 1010 additions and 307 deletions

View File

@@ -176,8 +176,12 @@ jobs:
cargo fuzz run fuzz_handle_msg -- -max_total_time=5 cargo fuzz run fuzz_handle_msg -- -max_total_time=5
ulimit -s 8192000 && RUST_MIN_STACK=33554432000 && cargo fuzz run fuzz_kyber_encaps -- -max_total_time=5 ulimit -s 8192000 && RUST_MIN_STACK=33554432000 && cargo fuzz run fuzz_kyber_encaps -- -max_total_time=5
cargo fuzz run fuzz_mceliece_encaps -- -max_total_time=5 cargo fuzz run fuzz_mceliece_encaps -- -max_total_time=5
cargo fuzz run fuzz_box_secret_alloc -- -max_total_time=5 cargo fuzz run fuzz_box_secret_alloc_malloc -- -max_total_time=5
cargo fuzz run fuzz_vec_secret_alloc -- -max_total_time=5 cargo fuzz run fuzz_box_secret_alloc_memfdsec -- -max_total_time=5
cargo fuzz run fuzz_box_secret_alloc_memfdsec_mallocfb -- -max_total_time=5
cargo fuzz run fuzz_vec_secret_alloc_malloc -- -max_total_time=5
cargo fuzz run fuzz_vec_secret_alloc_memfdsec -- -max_total_time=5
cargo fuzz run fuzz_vec_secret_alloc_memfdsec_mallocfb -- -max_total_time=5
codecov: codecov:
runs-on: ubuntu-latest runs-on: ubuntu-latest

476
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -12,15 +12,11 @@ members = [
"fuzz", "fuzz",
"secret-memory", "secret-memory",
"rp", "rp",
"wireguard-broker"
]
default-members = [
"rosenpass",
"rp",
"wireguard-broker", "wireguard-broker",
] ]
default-members = ["rosenpass", "rp", "wireguard-broker"]
[workspace.metadata.release] [workspace.metadata.release]
# ensure that adding `--package` as argument to `cargo release` still creates version tags in the form of `vx.y.z` # ensure that adding `--package` as argument to `cargo release` still creates version tags in the form of `vx.y.z`
tag-prefix = "" tag-prefix = ""
@@ -45,7 +41,7 @@ env_logger = "0.10.2"
toml = "0.7.8" toml = "0.7.8"
static_assertions = "1.1.0" static_assertions = "1.1.0"
allocator-api2 = "0.2.14" allocator-api2 = "0.2.14"
memsec = "0.6.3" memsec = { version="0.7.0", features = [ "alloc_ext", ] }
rand = "0.8.5" rand = "0.8.5"
typenum = "1.17.0" typenum = "1.17.0"
log = { version = "0.4.21" } log = { version = "0.4.21" }
@@ -54,9 +50,15 @@ serde = { version = "1.0.203", features = ["derive"] }
arbitrary = { version = "1.3.2", features = ["derive"] } arbitrary = { version = "1.3.2", features = ["derive"] }
anyhow = { version = "1.0.86", features = ["backtrace", "std"] } anyhow = { version = "1.0.86", features = ["backtrace", "std"] }
mio = { version = "0.8.11", features = ["net", "os-poll"] } mio = { version = "0.8.11", features = ["net", "os-poll"] }
oqs-sys = { version = "0.9.1", default-features = false, features = ['classic_mceliece', 'kyber'] } oqs-sys = { version = "0.9.1", default-features = false, features = [
'classic_mceliece',
'kyber',
] }
blake2 = "0.10.6" blake2 = "0.10.6"
chacha20poly1305 = { version = "0.10.1", default-features = false, features = [ "std", "heapless" ] } chacha20poly1305 = { version = "0.10.1", default-features = false, features = [
"std",
"heapless",
] }
zerocopy = { version = "0.7.34", features = ["derive"] } zerocopy = { version = "0.7.34", features = ["derive"] }
home = "0.5.9" home = "0.5.9"
derive_builder = "0.20.0" derive_builder = "0.20.0"
@@ -65,14 +67,16 @@ postcard= {version = "1.0.8", features = ["alloc"]}
#Dev dependencies #Dev dependencies
serial_test = "3.1.1" serial_test = "3.1.1"
tempfile="3" tempfile = "3"
stacker = "0.1.15" stacker = "0.1.15"
libfuzzer-sys = "0.4" libfuzzer-sys = "0.4"
test_bin = "0.4.0" test_bin = "0.4.0"
criterion = "0.4.0" criterion = "0.4.0"
allocator-api2-tests = "0.2.15" allocator-api2-tests = "0.2.15"
procspawn = {version = "1.0.0", features= ["test-support"]}
#Broker dependencies (might need cleanup or changes) #Broker dependencies (might need cleanup or changes)
wireguard-uapi = "3.0.0" wireguard-uapi = "3.0.0"
command-fds = "0.2.3" command-fds = "0.2.3"
rustix = { version = "0.38.27", features = ["net"] } rustix = { version = "0.38.27", features = ["net"] }

View File

@@ -48,13 +48,37 @@ test = false
doc = false doc = false
[[bin]] [[bin]]
name = "fuzz_box_secret_alloc" name = "fuzz_box_secret_alloc_malloc"
path = "fuzz_targets/box_secret_alloc.rs" path = "fuzz_targets/box_secret_alloc_malloc.rs"
test = false test = false
doc = false doc = false
[[bin]] [[bin]]
name = "fuzz_vec_secret_alloc" name = "fuzz_vec_secret_alloc_malloc"
path = "fuzz_targets/vec_secret_alloc.rs" path = "fuzz_targets/vec_secret_alloc_malloc.rs"
test = false test = false
doc = false doc = false
[[bin]]
name = "fuzz_box_secret_alloc_memfdsec"
path = "fuzz_targets/box_secret_alloc_memfdsec.rs"
test = false
doc = false
[[bin]]
name = "fuzz_vec_secret_alloc_memfdsec"
path = "fuzz_targets/vec_secret_alloc_memfdsec.rs"
test = false
doc = false
[[bin]]
name = "fuzz_box_secret_alloc_memfdsec_mallocfb"
path = "fuzz_targets/box_secret_alloc_memfdsec_mallocfb.rs"
test = false
doc = false
[[bin]]
name = "fuzz_vec_secret_alloc_memfdsec_mallocfb"
path = "fuzz_targets/vec_secret_alloc_memfdsec_mallocfb.rs"
test = false
doc = false

View File

@@ -0,0 +1,12 @@
#![no_main]
use libfuzzer_sys::fuzz_target;
use rosenpass_secret_memory::alloc::secret_box;
use rosenpass_secret_memory::policy::*;
use std::sync::Once;
static ONCE: Once = Once::new();
fuzz_target!(|data: &[u8]| {
ONCE.call_once(secret_policy_use_only_malloc_secrets);
let _ = secret_box(data);
});

View File

@@ -0,0 +1,13 @@
#![no_main]
use libfuzzer_sys::fuzz_target;
use rosenpass_secret_memory::alloc::secret_box;
use rosenpass_secret_memory::policy::*;
use std::sync::Once;
static ONCE: Once = Once::new();
fuzz_target!(|data: &[u8]| {
ONCE.call_once(secret_policy_use_only_memfd_secrets);
let _ = secret_box(data);
});

View File

@@ -2,7 +2,12 @@
use libfuzzer_sys::fuzz_target; use libfuzzer_sys::fuzz_target;
use rosenpass_secret_memory::alloc::secret_box; use rosenpass_secret_memory::alloc::secret_box;
use rosenpass_secret_memory::policy::*;
use std::sync::Once;
static ONCE: Once = Once::new();
fuzz_target!(|data: &[u8]| { fuzz_target!(|data: &[u8]| {
ONCE.call_once(secret_policy_try_use_memfd_secrets);
let _ = secret_box(data); let _ = secret_box(data);
}); });

View File

@@ -6,9 +6,13 @@ use libfuzzer_sys::fuzz_target;
use rosenpass::protocol::CryptoServer; use rosenpass::protocol::CryptoServer;
use rosenpass_cipher_traits::Kem; use rosenpass_cipher_traits::Kem;
use rosenpass_ciphers::kem::StaticKem; use rosenpass_ciphers::kem::StaticKem;
use rosenpass_secret_memory::policy::*;
use rosenpass_secret_memory::Secret; use rosenpass_secret_memory::Secret;
use std::sync::Once;
static ONCE: Once = Once::new();
fuzz_target!(|rx_buf: &[u8]| { fuzz_target!(|rx_buf: &[u8]| {
ONCE.call_once(secret_policy_use_only_malloc_secrets);
let sk = Secret::from_slice(&[0; StaticKem::SK_LEN]); let sk = Secret::from_slice(&[0; StaticKem::SK_LEN]);
let pk = Secret::from_slice(&[0; StaticKem::PK_LEN]); let pk = Secret::from_slice(&[0; StaticKem::PK_LEN]);

View File

@@ -0,0 +1,15 @@
#![no_main]
use std::sync::Once;
use libfuzzer_sys::fuzz_target;
use rosenpass_secret_memory::alloc::secret_vec;
use rosenpass_secret_memory::policy::*;
static ONCE: Once = Once::new();
fuzz_target!(|data: &[u8]| {
ONCE.call_once(secret_policy_use_only_malloc_secrets);
let mut vec = secret_vec();
vec.extend_from_slice(data);
});

View File

@@ -0,0 +1,15 @@
#![no_main]
use std::sync::Once;
use libfuzzer_sys::fuzz_target;
use rosenpass_secret_memory::alloc::secret_vec;
use rosenpass_secret_memory::policy::*;
static ONCE: Once = Once::new();
fuzz_target!(|data: &[u8]| {
ONCE.call_once(secret_policy_use_only_memfd_secrets);
let mut vec = secret_vec();
vec.extend_from_slice(data);
});

View File

@@ -1,9 +1,15 @@
#![no_main] #![no_main]
use std::sync::Once;
use libfuzzer_sys::fuzz_target; use libfuzzer_sys::fuzz_target;
use rosenpass_secret_memory::alloc::secret_vec; use rosenpass_secret_memory::alloc::secret_vec;
use rosenpass_secret_memory::policy::*;
static ONCE: Once = Once::new();
fuzz_target!(|data: &[u8]| { fuzz_target!(|data: &[u8]| {
ONCE.call_once(secret_policy_try_use_memfd_secrets);
let mut vec = secret_vec(); let mut vec = secret_vec();
vec.extend_from_slice(data); vec.extend_from_slice(data);
}); });

View File

@@ -49,6 +49,7 @@ criterion = { workspace = true }
test_bin = { workspace = true } test_bin = { workspace = true }
stacker = { workspace = true } stacker = { workspace = true }
serial_test = {workspace = true} serial_test = {workspace = true}
procspawn = {workspace = true}
[features] [features]
enable_broker_api = ["rosenpass-wireguard-broker/enable_broker_api"] enable_broker_api = ["rosenpass-wireguard-broker/enable_broker_api"]

View File

@@ -5,6 +5,7 @@ use rosenpass_cipher_traits::Kem;
use rosenpass_ciphers::kem::StaticKem; use rosenpass_ciphers::kem::StaticKem;
use criterion::{black_box, criterion_group, criterion_main, Criterion}; use criterion::{black_box, criterion_group, criterion_main, Criterion};
use rosenpass_secret_memory::secret_policy_try_use_memfd_secrets;
fn handle( fn handle(
tx: &mut CryptoServer, tx: &mut CryptoServer,
@@ -56,6 +57,7 @@ fn make_server_pair() -> Result<(CryptoServer, CryptoServer)> {
} }
fn criterion_benchmark(c: &mut Criterion) { fn criterion_benchmark(c: &mut Criterion) {
secret_policy_try_use_memfd_secrets();
let (mut a, mut b) = make_server_pair().unwrap(); let (mut a, mut b) = make_server_pair().unwrap();
c.bench_function("cca_secret_alloc", |bench| { c.bench_function("cca_secret_alloc", |bench| {
bench.iter(|| { bench.iter(|| {

View File

@@ -3,6 +3,7 @@ use clap::{Parser, Subcommand};
use rosenpass_cipher_traits::Kem; use rosenpass_cipher_traits::Kem;
use rosenpass_ciphers::kem::StaticKem; use rosenpass_ciphers::kem::StaticKem;
use rosenpass_secret_memory::file::StoreSecret; use rosenpass_secret_memory::file::StoreSecret;
use rosenpass_secret_memory::secret_policy_try_use_memfd_secrets;
use rosenpass_util::file::{LoadValue, LoadValueB64}; use rosenpass_util::file::{LoadValue, LoadValueB64};
use rosenpass_wireguard_broker::brokers::native_unix::{ use rosenpass_wireguard_broker::brokers::native_unix::{
NativeUnixBroker, NativeUnixBrokerConfigBaseBuilder, NativeUnixBrokerConfigBaseBuilderError, NativeUnixBroker, NativeUnixBrokerConfigBaseBuilder, NativeUnixBrokerConfigBaseBuilderError,
@@ -154,6 +155,9 @@ impl CliCommand {
/// ## TODO /// ## TODO
/// - This method consumes the [`CliCommand`] value. It might be wise to use a reference... /// - This method consumes the [`CliCommand`] value. It might be wise to use a reference...
pub fn run(self, test_helpers: Option<AppServerTest>) -> anyhow::Result<()> { pub fn run(self, test_helpers: Option<AppServerTest>) -> anyhow::Result<()> {
//Specify secret policy
secret_policy_try_use_memfd_secrets();
use CliCommand::*; use CliCommand::*;
match self { match self {
Man => { Man => {

View File

@@ -19,12 +19,16 @@
//! [CryptoServer]. //! [CryptoServer].
//! //!
//! ``` //! ```
//! use rosenpass_secret_memory::policy::*;
//! use rosenpass_cipher_traits::Kem; //! use rosenpass_cipher_traits::Kem;
//! use rosenpass_ciphers::kem::StaticKem; //! use rosenpass_ciphers::kem::StaticKem;
//! use rosenpass::{ //! use rosenpass::{
//! protocol::{SSk, SPk, MsgBuf, PeerPtr, CryptoServer, SymKey}, //! protocol::{SSk, SPk, MsgBuf, PeerPtr, CryptoServer, SymKey},
//! }; //! };
//! # fn main() -> anyhow::Result<()> { //! # fn main() -> anyhow::Result<()> {
//! // Set security policy for storing secrets
//!
//! secret_policy_try_use_memfd_secrets();
//! //!
//! // initialize secret and public key for peer a ... //! // initialize secret and public key for peer a ...
//! let (mut peer_a_sk, mut peer_a_pk) = (SSk::zero(), SPk::zero()); //! let (mut peer_a_sk, mut peer_a_pk) = (SSk::zero(), SPk::zero());
@@ -2145,6 +2149,7 @@ mod test {
use std::{net::SocketAddrV4, thread::sleep, time::Duration}; use std::{net::SocketAddrV4, thread::sleep, time::Duration};
use super::*; use super::*;
use serial_test::serial;
struct VecHostIdentifier(Vec<u8>); struct VecHostIdentifier(Vec<u8>);
@@ -2166,7 +2171,21 @@ mod test {
} }
} }
fn setup_logging() {
use std::io::Write;
let mut log_builder = env_logger::Builder::from_default_env(); // sets log level filter from environment (or defaults)
log_builder.filter_level(log::LevelFilter::Info);
log_builder.format_timestamp_nanos();
log_builder.format(|buf, record| {
let ts_format = buf.timestamp_nanos().to_string();
writeln!(buf, "{}: {}", &ts_format[14..], record.args())
});
let _ = log_builder.try_init();
}
#[test] #[test]
#[serial]
/// Ensure that the protocol implementation can deal with truncated /// Ensure that the protocol implementation can deal with truncated
/// messages and with overlong messages. /// messages and with overlong messages.
/// ///
@@ -2182,6 +2201,8 @@ mod test {
/// Through all this, the handshake should still successfully terminate; /// Through all this, the handshake should still successfully terminate;
/// i.e. an exchanged key must be produced in both servers. /// i.e. an exchanged key must be produced in both servers.
fn handles_incorrect_size_messages() { fn handles_incorrect_size_messages() {
setup_logging();
rosenpass_secret_memory::secret_policy_try_use_memfd_secrets();
stacker::grow(8 * 1024 * 1024, || { stacker::grow(8 * 1024 * 1024, || {
const OVERSIZED_MESSAGE: usize = ((MAX_MESSAGE_LEN as f32) * 1.2) as usize; const OVERSIZED_MESSAGE: usize = ((MAX_MESSAGE_LEN as f32) * 1.2) as usize;
type MsgBufPlus = Public<OVERSIZED_MESSAGE>; type MsgBufPlus = Public<OVERSIZED_MESSAGE>;
@@ -2252,7 +2273,10 @@ mod test {
} }
#[test] #[test]
#[serial]
fn test_regular_exchange() { fn test_regular_exchange() {
setup_logging();
rosenpass_secret_memory::secret_policy_try_use_memfd_secrets();
stacker::grow(8 * 1024 * 1024, || { stacker::grow(8 * 1024 * 1024, || {
type MsgBufPlus = Public<MAX_MESSAGE_LEN>; type MsgBufPlus = Public<MAX_MESSAGE_LEN>;
let (mut a, mut b) = make_server_pair().unwrap(); let (mut a, mut b) = make_server_pair().unwrap();
@@ -2296,7 +2320,7 @@ mod test {
//B handles InitConf, sends EmptyData //B handles InitConf, sends EmptyData
let HandleMsgResult { let HandleMsgResult {
resp, resp: _,
exchanged_with, exchanged_with,
} = b } = b
.handle_msg(&a_to_b_buf.as_slice()[..init_conf_len], &mut *b_to_a_buf) .handle_msg(&a_to_b_buf.as_slice()[..init_conf_len], &mut *b_to_a_buf)
@@ -2310,7 +2334,10 @@ mod test {
} }
#[test] #[test]
#[serial]
fn test_regular_init_conf_retransmit() { fn test_regular_init_conf_retransmit() {
setup_logging();
rosenpass_secret_memory::secret_policy_try_use_memfd_secrets();
stacker::grow(8 * 1024 * 1024, || { stacker::grow(8 * 1024 * 1024, || {
type MsgBufPlus = Public<MAX_MESSAGE_LEN>; type MsgBufPlus = Public<MAX_MESSAGE_LEN>;
let (mut a, mut b) = make_server_pair().unwrap(); let (mut a, mut b) = make_server_pair().unwrap();
@@ -2355,7 +2382,7 @@ mod test {
//B handles InitConf, sends EmptyData //B handles InitConf, sends EmptyData
let HandleMsgResult { let HandleMsgResult {
resp, resp: _,
exchanged_with, exchanged_with,
} = b } = b
.handle_msg(&a_to_b_buf.as_slice()[..init_conf_len], &mut *b_to_a_buf) .handle_msg(&a_to_b_buf.as_slice()[..init_conf_len], &mut *b_to_a_buf)
@@ -2368,7 +2395,7 @@ mod test {
//B handles InitConf again, sends EmptyData //B handles InitConf again, sends EmptyData
let HandleMsgResult { let HandleMsgResult {
resp, resp: _,
exchanged_with, exchanged_with,
} = b } = b
.handle_msg(&a_to_b_buf.as_slice()[..init_conf_len], &mut *b_to_a_buf) .handle_msg(&a_to_b_buf.as_slice()[..init_conf_len], &mut *b_to_a_buf)
@@ -2382,7 +2409,10 @@ mod test {
} }
#[test] #[test]
#[serial]
fn cookie_reply_mechanism_responder_under_load() { fn cookie_reply_mechanism_responder_under_load() {
setup_logging();
rosenpass_secret_memory::secret_policy_try_use_memfd_secrets();
stacker::grow(8 * 1024 * 1024, || { stacker::grow(8 * 1024 * 1024, || {
type MsgBufPlus = Public<MAX_MESSAGE_LEN>; type MsgBufPlus = Public<MAX_MESSAGE_LEN>;
let (mut a, mut b) = make_server_pair().unwrap(); let (mut a, mut b) = make_server_pair().unwrap();
@@ -2476,7 +2506,10 @@ mod test {
} }
#[test] #[test]
#[serial]
fn cookie_reply_mechanism_initiator_bails_on_message_under_load() { fn cookie_reply_mechanism_initiator_bails_on_message_under_load() {
setup_logging();
rosenpass_secret_memory::secret_policy_try_use_memfd_secrets();
stacker::grow(8 * 1024 * 1024, || { stacker::grow(8 * 1024 * 1024, || {
type MsgBufPlus = Public<MAX_MESSAGE_LEN>; type MsgBufPlus = Public<MAX_MESSAGE_LEN>;
let (mut a, mut b) = make_server_pair().unwrap(); let (mut a, mut b) = make_server_pair().unwrap();

View File

@@ -6,7 +6,7 @@ use std::{
time::Duration, time::Duration,
}; };
use clap::{builder::Str, Parser}; use clap::Parser;
use rosenpass::{app_server::AppServerTestBuilder, cli::CliArgs}; use rosenpass::{app_server::AppServerTestBuilder, cli::CliArgs};
use rosenpass_secret_memory::{Public, Secret}; use rosenpass_secret_memory::{Public, Secret};
use rosenpass_wireguard_broker::{WireguardBrokerMio, WG_KEY_LEN, WG_PEER_LEN}; use rosenpass_wireguard_broker::{WireguardBrokerMio, WG_KEY_LEN, WG_PEER_LEN};
@@ -275,6 +275,7 @@ fn check_exchange_under_dos() {
fs::remove_dir_all(&tmpdir).unwrap(); fs::remove_dir_all(&tmpdir).unwrap();
} }
#[allow(dead_code)]
#[derive(Debug, Default)] #[derive(Debug, Default)]
struct MockBrokerInner { struct MockBrokerInner {
psk: Option<Secret<WG_KEY_LEN>>, psk: Option<Secret<WG_KEY_LEN>>,

View File

@@ -102,6 +102,7 @@ mod tests {
use std::fs; use std::fs;
use rosenpass::protocol::{SPk, SSk}; use rosenpass::protocol::{SPk, SSk};
use rosenpass_secret_memory::secret_policy_try_use_memfd_secrets;
use rosenpass_secret_memory::Secret; use rosenpass_secret_memory::Secret;
use rosenpass_util::file::LoadValue; use rosenpass_util::file::LoadValue;
use rosenpass_util::file::LoadValueB64; use rosenpass_util::file::LoadValueB64;
@@ -110,7 +111,8 @@ mod tests {
use crate::key::{genkey, pubkey, WG_B64_LEN}; use crate::key::{genkey, pubkey, WG_B64_LEN};
#[test] #[test]
fn it_works() { fn test_key_loopback() {
secret_policy_try_use_memfd_secrets();
let private_keys_dir = tempdir().unwrap(); let private_keys_dir = tempdir().unwrap();
fs::remove_dir(private_keys_dir.path()).unwrap(); fs::remove_dir(private_keys_dir.path()).unwrap();

View File

@@ -3,6 +3,7 @@ use std::process::exit;
use cli::{Cli, Command}; use cli::{Cli, Command};
use exchange::exchange; use exchange::exchange;
use key::{genkey, pubkey}; use key::{genkey, pubkey};
use rosenpass_secret_memory::policy;
mod cli; mod cli;
mod exchange; mod exchange;
@@ -10,6 +11,8 @@ mod key;
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
policy::secret_policy_try_use_memfd_secrets();
let cli = match Cli::parse(std::env::args().peekable()) { let cli = match Cli::parse(std::env::args().peekable()) {
Ok(cli) => cli, Ok(cli) => cli,
Err(err) => { Err(err) => {

View File

@@ -23,3 +23,4 @@ log = { workspace = true }
allocator-api2-tests = { workspace = true } allocator-api2-tests = { workspace = true }
tempfile = {workspace = true} tempfile = {workspace = true}
base64ct = {workspace = true} base64ct = {workspace = true}
procspawn = {workspace = true}

View File

@@ -4,37 +4,41 @@ use std::ptr::NonNull;
use allocator_api2::alloc::{AllocError, Allocator, Layout}; use allocator_api2::alloc::{AllocError, Allocator, Layout};
#[derive(Copy, Clone, Default)] #[derive(Copy, Clone, Default)]
struct MemsecAllocatorContents; struct MallocAllocatorContents;
/// Memory allocation using using the memsec crate /// Memory allocation using using the memsec crate
#[derive(Copy, Clone, Default)] #[derive(Copy, Clone, Default)]
pub struct MemsecAllocator { pub struct MallocAllocator {
_dummy_private_data: MemsecAllocatorContents, _dummy_private_data: MallocAllocatorContents,
} }
/// A box backed by the memsec allocator /// A box backed by the memsec allocator
pub type MemsecBox<T> = allocator_api2::boxed::Box<T, MemsecAllocator>; pub type MallocBox<T> = allocator_api2::boxed::Box<T, MallocAllocator>;
/// A vector backed by the memsec allocator /// A vector backed by the memsec allocator
pub type MemsecVec<T> = allocator_api2::vec::Vec<T, MemsecAllocator>; pub type MallocVec<T> = allocator_api2::vec::Vec<T, MallocAllocator>;
pub fn memsec_box<T>(x: T) -> MemsecBox<T> { pub fn malloc_box_try<T>(x: T) -> Result<MallocBox<T>, AllocError> {
MemsecBox::<T>::new_in(x, MemsecAllocator::new()) MallocBox::<T>::try_new_in(x, MallocAllocator::new())
} }
pub fn memsec_vec<T>() -> MemsecVec<T> { pub fn malloc_box<T>(x: T) -> MallocBox<T> {
MemsecVec::<T>::new_in(MemsecAllocator::new()) MallocBox::<T>::new_in(x, MallocAllocator::new())
} }
impl MemsecAllocator { pub fn malloc_vec<T>() -> MallocVec<T> {
MallocVec::<T>::new_in(MallocAllocator::new())
}
impl MallocAllocator {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
_dummy_private_data: MemsecAllocatorContents, _dummy_private_data: MallocAllocatorContents,
} }
} }
} }
unsafe impl Allocator for MemsecAllocator { unsafe impl Allocator for MallocAllocator {
fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> { fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
// Call memsec allocator // Call memsec allocator
let mem: Option<NonNull<[u8]>> = unsafe { memsec::malloc_sized(layout.size()) }; let mem: Option<NonNull<[u8]>> = unsafe { memsec::malloc_sized(layout.size()) };
@@ -48,8 +52,8 @@ unsafe impl Allocator for MemsecAllocator {
// Ensure the right alignment is used // Ensure the right alignment is used
let off = (mem.as_ptr() as *const u8).align_offset(layout.align()); let off = (mem.as_ptr() as *const u8).align_offset(layout.align());
if off != 0 { if off != 0 {
log::error!("Allocation {layout:?} was requested but memsec returned allocation \ log::error!("Allocation {layout:?} was requested but malloc-based memsec returned allocation \
with offset {off} from the requested alignment. Memsec always allocates values \ with offset {off} from the requested alignment. Malloc always allocates values \
at the end of a memory page for security reasons, custom alignments are not supported. \ at the end of a memory page for security reasons, custom alignments are not supported. \
You could try allocating an oversized value."); You could try allocating an oversized value.");
unsafe { memsec::free(mem) }; unsafe { memsec::free(mem) };
@@ -66,7 +70,7 @@ unsafe impl Allocator for MemsecAllocator {
} }
} }
impl fmt::Debug for MemsecAllocator { impl fmt::Debug for MallocAllocator {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.write_str("<memsec based Rust allocator>") fmt.write_str("<memsec based Rust allocator>")
} }
@@ -78,21 +82,21 @@ mod test {
use super::*; use super::*;
make_test! { test_sizes(MemsecAllocator::new()) } make_test! { test_sizes(MallocAllocator::new()) }
make_test! { test_vec(MemsecAllocator::new()) } make_test! { test_vec(MallocAllocator::new()) }
make_test! { test_many_boxes(MemsecAllocator::new()) } make_test! { test_many_boxes(MallocAllocator::new()) }
#[test] #[test]
fn memsec_allocation() { fn malloc_allocation() {
let alloc = MemsecAllocator::new(); let alloc = MallocAllocator::new();
memsec_allocation_impl::<0>(&alloc); malloc_allocation_impl::<0>(&alloc);
memsec_allocation_impl::<7>(&alloc); malloc_allocation_impl::<7>(&alloc);
memsec_allocation_impl::<8>(&alloc); malloc_allocation_impl::<8>(&alloc);
memsec_allocation_impl::<64>(&alloc); malloc_allocation_impl::<64>(&alloc);
memsec_allocation_impl::<999>(&alloc); malloc_allocation_impl::<999>(&alloc);
} }
fn memsec_allocation_impl<const N: usize>(alloc: &MemsecAllocator) { fn malloc_allocation_impl<const N: usize>(alloc: &MallocAllocator) {
let layout = Layout::new::<[u8; N]>(); let layout = Layout::new::<[u8; N]>();
let mem = alloc.allocate(layout).unwrap(); let mem = alloc.allocate(layout).unwrap();

View File

@@ -0,0 +1,112 @@
#![cfg(target_os = "linux")]
use std::fmt;
use std::ptr::NonNull;
use allocator_api2::alloc::{AllocError, Allocator, Layout};
#[derive(Copy, Clone, Default)]
struct MemfdSecAllocatorContents;
/// Memory allocation using using the memsec crate
#[derive(Copy, Clone, Default)]
pub struct MemfdSecAllocator {
_dummy_private_data: MemfdSecAllocatorContents,
}
/// A box backed by the memsec allocator
pub type MemfdSecBox<T> = allocator_api2::boxed::Box<T, MemfdSecAllocator>;
/// A vector backed by the memsec allocator
pub type MemfdSecVec<T> = allocator_api2::vec::Vec<T, MemfdSecAllocator>;
pub fn memfdsec_box_try<T>(x: T) -> Result<MemfdSecBox<T>, AllocError> {
MemfdSecBox::<T>::try_new_in(x, MemfdSecAllocator::new())
}
pub fn memfdsec_box<T>(x: T) -> MemfdSecBox<T> {
MemfdSecBox::<T>::new_in(x, MemfdSecAllocator::new())
}
pub fn memfdsec_vec<T>() -> MemfdSecVec<T> {
MemfdSecVec::<T>::new_in(MemfdSecAllocator::new())
}
impl MemfdSecAllocator {
pub fn new() -> Self {
Self {
_dummy_private_data: MemfdSecAllocatorContents,
}
}
}
unsafe impl Allocator for MemfdSecAllocator {
fn allocate(&self, layout: Layout) -> Result<NonNull<[u8]>, AllocError> {
// Call memsec allocator
let mem: Option<NonNull<[u8]>> = unsafe { memsec::memfd_secret_sized(layout.size()) };
// Unwrap the option
let Some(mem) = mem else {
log::error!("Allocation {layout:?} was requested but memfd-based memsec returned a null pointer");
return Err(AllocError);
};
// Ensure the right alignment is used
let off = (mem.as_ptr() as *const u8).align_offset(layout.align());
if off != 0 {
log::error!("Allocation {layout:?} was requested but memfd-based memsec returned allocation \
with offset {off} from the requested alignment. Memfd always allocates values \
at the end of a memory page for security reasons, custom alignments are not supported. \
You could try allocating an oversized value.");
unsafe { memsec::free_memfd_secret(mem) };
return Err(AllocError);
};
Ok(mem)
}
unsafe fn deallocate(&self, ptr: NonNull<u8>, _layout: Layout) {
unsafe {
memsec::free_memfd_secret(ptr);
}
}
}
impl fmt::Debug for MemfdSecAllocator {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.write_str("<memsec based Rust allocator>")
}
}
#[cfg(test)]
mod test {
use allocator_api2_tests::make_test;
use super::*;
make_test! { test_sizes(MemfdSecAllocator::new()) }
make_test! { test_vec(MemfdSecAllocator::new()) }
make_test! { test_many_boxes(MemfdSecAllocator::new()) }
#[test]
fn memfdsec_allocation() {
let alloc = MemfdSecAllocator::new();
memfdsec_allocation_impl::<0>(&alloc);
memfdsec_allocation_impl::<7>(&alloc);
memfdsec_allocation_impl::<8>(&alloc);
memfdsec_allocation_impl::<64>(&alloc);
memfdsec_allocation_impl::<999>(&alloc);
}
fn memfdsec_allocation_impl<const N: usize>(alloc: &MemfdSecAllocator) {
let layout = Layout::new::<[u8; N]>();
let mem = alloc.allocate(layout).unwrap();
// https://libsodium.gitbook.io/doc/memory_management#guarded-heap-allocations
// promises us that allocated memory is initialized with the magic byte 0xDB
// and memsec promises to provide a reimplementation of the libsodium mechanism;
// it uses the magic value 0xD0 though
assert_eq!(unsafe { mem.as_ref() }, &[0xD0u8; N]);
let mem = NonNull::new(mem.as_ptr() as *mut u8).unwrap();
unsafe { alloc.deallocate(mem, layout) };
}
}

View File

@@ -0,0 +1,2 @@
pub mod malloc;
pub mod memfdsec;

View File

@@ -1,6 +1,86 @@
pub mod memsec; pub mod memsec;
pub use crate::alloc::memsec::{ use std::sync::OnceLock;
memsec_box as secret_box, memsec_vec as secret_vec, MemsecAllocator as SecretAllocator,
MemsecBox as SecretBox, MemsecVec as SecretVec, use allocator_api2::alloc::{AllocError, Allocator};
}; use memsec::malloc::MallocAllocator;
#[cfg(target_os = "linux")]
use memsec::memfdsec::MemfdSecAllocator;
static ALLOC_TYPE: OnceLock<SecretAllocType> = OnceLock::new();
/// Sets the secret allocation type to use.
/// Intended usage at startup before secret allocation
/// takes place
pub fn set_secret_alloc_type(alloc_type: SecretAllocType) {
ALLOC_TYPE.set(alloc_type).unwrap();
}
pub fn get_or_init_secret_alloc_type(alloc_type: SecretAllocType) -> SecretAllocType {
*ALLOC_TYPE.get_or_init(|| alloc_type)
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SecretAllocType {
MemsecMalloc,
#[cfg(target_os = "linux")]
MemsecMemfdSec,
}
pub struct SecretAlloc {
alloc_type: SecretAllocType,
}
impl Default for SecretAlloc {
fn default() -> Self {
Self {
alloc_type: *ALLOC_TYPE.get().expect(
"Secret security policy not specified. \
Run the specifying policy function in \
rosenpass_secret_memory::policy or set a \
custom policy by initializing \
rosenpass_secret_memory::alloc::ALLOC_TYPE \
before using secrets",
),
}
}
}
unsafe impl Allocator for SecretAlloc {
fn allocate(
&self,
layout: std::alloc::Layout,
) -> Result<std::ptr::NonNull<[u8]>, allocator_api2::alloc::AllocError> {
match self.alloc_type {
SecretAllocType::MemsecMalloc => MallocAllocator::default().allocate(layout),
#[cfg(target_os = "linux")]
SecretAllocType::MemsecMemfdSec => MemfdSecAllocator::default().allocate(layout),
}
}
unsafe fn deallocate(&self, ptr: std::ptr::NonNull<u8>, layout: std::alloc::Layout) {
match self.alloc_type {
SecretAllocType::MemsecMalloc => MallocAllocator::default().deallocate(ptr, layout),
#[cfg(target_os = "linux")]
SecretAllocType::MemsecMemfdSec => MemfdSecAllocator::default().deallocate(ptr, layout),
}
}
}
pub type SecretBox<T> = allocator_api2::boxed::Box<T, SecretAlloc>;
/// A vector backed by the memsec allocator
pub type SecretVec<T> = allocator_api2::vec::Vec<T, SecretAlloc>;
pub fn secret_box_try<T>(x: T) -> Result<SecretBox<T>, AllocError> {
SecretBox::<T>::try_new_in(x, SecretAlloc::default())
}
pub fn secret_box<T>(x: T) -> SecretBox<T> {
SecretBox::<T>::new_in(x, SecretAlloc::default())
}
pub fn secret_vec<T>() -> SecretVec<T> {
SecretVec::<T>::new_in(SecretAlloc::default())
}

View File

@@ -9,3 +9,6 @@ pub use crate::public::Public;
mod secret; mod secret;
pub use crate::secret::Secret; pub use crate::secret::Secret;
pub mod policy;
pub use crate::policy::*;

View File

@@ -0,0 +1,82 @@
pub fn secret_policy_try_use_memfd_secrets() {
let alloc_type = {
#[cfg(target_os = "linux")]
{
if crate::alloc::memsec::memfdsec::memfdsec_box_try(0u8).is_ok() {
crate::alloc::SecretAllocType::MemsecMemfdSec
} else {
crate::alloc::SecretAllocType::MemsecMalloc
}
}
#[cfg(not(target_os = "linux"))]
{
crate::alloc::SecretAllocType::MemsecMalloc
}
};
assert_eq!(
alloc_type,
crate::alloc::get_or_init_secret_alloc_type(alloc_type)
);
log::info!("Secrets will be allocated using {:?}", alloc_type);
}
#[cfg(target_os = "linux")]
pub fn secret_policy_use_only_memfd_secrets() {
let alloc_type = crate::alloc::SecretAllocType::MemsecMemfdSec;
assert_eq!(
alloc_type,
crate::alloc::get_or_init_secret_alloc_type(alloc_type)
);
log::info!("Secrets will be allocated using {:?}", alloc_type);
}
pub fn secret_policy_use_only_malloc_secrets() {
let alloc_type = crate::alloc::SecretAllocType::MemsecMalloc;
assert_eq!(
alloc_type,
crate::alloc::get_or_init_secret_alloc_type(alloc_type)
);
log::info!("Secrets will be allocated using {:?}", alloc_type);
}
pub mod test {
#[macro_export]
macro_rules! test_spawn_process_with_policies {
($body:block, $($f: expr),*) => {
$(
let handle = procspawn::spawn((), |_| {
$f();
$body
});
handle.join().unwrap();
)*
};
}
#[macro_export]
macro_rules! test_spawn_process_provided_policies {
($body: block) => {
$crate::test_spawn_process_with_policies!(
$body,
$crate::policy::secret_policy_try_use_memfd_secrets,
$crate::secret_policy_use_only_malloc_secrets
);
#[cfg(target_os = "linux")]
{
$crate::test_spawn_process_with_policies!(
$body,
$crate::policy::secret_policy_use_only_memfd_secrets
);
}
};
}
}

View File

@@ -321,133 +321,147 @@ impl<const N: usize> StoreSecret for Secret<N> {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::test_spawn_process_provided_policies;
use super::*; use super::*;
use std::{fs, os::unix::fs::PermissionsExt}; use std::{fs, os::unix::fs::PermissionsExt};
use tempfile::tempdir; use tempfile::tempdir;
procspawn::enable_test_support!();
/// check that we can alloc using the magic pool /// check that we can alloc using the magic pool
#[test] #[test]
fn secret_memory_pool_take() { fn secret_memory_pool_take() {
const N: usize = 0x100; test_spawn_process_provided_policies!({
let mut pool = SecretMemoryPool::new(); const N: usize = 0x100;
let secret: ZeroizingSecretBox<[u8; N]> = pool.take(); let mut pool = SecretMemoryPool::new();
assert_eq!(secret.as_ref(), &[0; N]); let secret: ZeroizingSecretBox<[u8; N]> = pool.take();
assert_eq!(secret.as_ref(), &[0; N]);
});
} }
/// check that a secret lives, even if its [SecretMemoryPool] is deleted /// check that a secret lives, even if its [SecretMemoryPool] is deleted
#[test] #[test]
fn secret_memory_pool_drop() { fn secret_memory_pool_drop() {
const N: usize = 0x100; test_spawn_process_provided_policies!({
let mut pool = SecretMemoryPool::new(); const N: usize = 0x100;
let secret: ZeroizingSecretBox<[u8; N]> = pool.take(); let mut pool = SecretMemoryPool::new();
std::mem::drop(pool); let secret: ZeroizingSecretBox<[u8; N]> = pool.take();
assert_eq!(secret.as_ref(), &[0; N]); std::mem::drop(pool);
assert_eq!(secret.as_ref(), &[0; N]);
});
} }
/// check that a secret can be reborn, freshly initialized with zero /// check that a secret can be reborn, freshly initialized with zero
#[test] #[test]
fn secret_memory_pool_release() { fn secret_memory_pool_release() {
const N: usize = 1; test_spawn_process_provided_policies!({
let mut pool = SecretMemoryPool::new(); const N: usize = 1;
let mut secret: ZeroizingSecretBox<[u8; N]> = pool.take(); let mut pool = SecretMemoryPool::new();
let old_secret_ptr = secret.as_ref().as_ptr(); let mut secret: ZeroizingSecretBox<[u8; N]> = pool.take();
let old_secret_ptr = secret.as_ref().as_ptr();
secret.as_mut()[0] = 0x13; secret.as_mut()[0] = 0x13;
pool.release(secret); pool.release(secret);
// now check that we get the same ptr // now check that we get the same ptr
let new_secret: ZeroizingSecretBox<[u8; N]> = pool.take(); let new_secret: ZeroizingSecretBox<[u8; N]> = pool.take();
assert_eq!(old_secret_ptr, new_secret.as_ref().as_ptr()); assert_eq!(old_secret_ptr, new_secret.as_ref().as_ptr());
// and that the secret was zeroized // and that the secret was zeroized
assert_eq!(new_secret.as_ref(), &[0; N]); assert_eq!(new_secret.as_ref(), &[0; N]);
});
} }
/// test loading a secret from an example file, and then storing it again in a different file /// test loading a secret from an example file, and then storing it again in a different file
#[test] #[test]
fn test_secret_load_store() { fn test_secret_load_store() {
const N: usize = 100; test_spawn_process_provided_policies!({
const N: usize = 100;
// Generate original random bytes // Generate original random bytes
let original_bytes: [u8; N] = [rand::random(); N]; let original_bytes: [u8; N] = [rand::random(); N];
// Create a temporary directory // Create a temporary directory
let temp_dir = tempdir().unwrap(); let temp_dir = tempdir().unwrap();
// Store the original secret to an example file in the temporary directory // Store the original secret to an example file in the temporary directory
let example_file = temp_dir.path().join("example_file"); let example_file = temp_dir.path().join("example_file");
std::fs::write(example_file.clone(), &original_bytes).unwrap(); std::fs::write(example_file.clone(), &original_bytes).unwrap();
// Load the secret from the example file // Load the secret from the example file
let loaded_secret = Secret::load(&example_file).unwrap(); let loaded_secret = Secret::load(&example_file).unwrap();
// Check that the loaded secret matches the original bytes // Check that the loaded secret matches the original bytes
assert_eq!(loaded_secret.secret(), &original_bytes); assert_eq!(loaded_secret.secret(), &original_bytes);
// Store the loaded secret to a different file in the temporary directory // Store the loaded secret to a different file in the temporary directory
let new_file = temp_dir.path().join("new_file"); let new_file = temp_dir.path().join("new_file");
loaded_secret.store(&new_file).unwrap(); loaded_secret.store(&new_file).unwrap();
// Read the contents of the new file // Read the contents of the new file
let new_file_contents = fs::read(&new_file).unwrap(); let new_file_contents = fs::read(&new_file).unwrap();
// Read the contents of the original file // Read the contents of the original file
let original_file_contents = fs::read(&example_file).unwrap(); let original_file_contents = fs::read(&example_file).unwrap();
// Check that the contents of the new file match the original file // Check that the contents of the new file match the original file
assert_eq!(new_file_contents, original_file_contents); assert_eq!(new_file_contents, original_file_contents);
});
} }
/// test loading a base64 encoded secret from an example file, and then storing it again in a different file /// test loading a base64 encoded secret from an example file, and then storing it again in a different file
#[test] #[test]
fn test_secret_load_store_base64() { fn test_secret_load_store_base64() {
const N: usize = 100; test_spawn_process_provided_policies!({
// Generate original random bytes const N: usize = 100;
let original_bytes: [u8; N] = [rand::random(); N]; // Generate original random bytes
// Create a temporary directory let original_bytes: [u8; N] = [rand::random(); N];
let temp_dir = tempdir().unwrap(); // Create a temporary directory
let example_file = temp_dir.path().join("example_file"); let temp_dir = tempdir().unwrap();
let mut encoded_secret = [0u8; N * 2]; let example_file = temp_dir.path().join("example_file");
let encoded_secret = b64_encode(&original_bytes, &mut encoded_secret).unwrap(); let mut encoded_secret = [0u8; N * 2];
let encoded_secret = b64_encode(&original_bytes, &mut encoded_secret).unwrap();
std::fs::write(&example_file, encoded_secret).unwrap(); std::fs::write(&example_file, encoded_secret).unwrap();
// Load the secret from the example file // Load the secret from the example file
let loaded_secret = Secret::load_b64::<{ N * 2 }, _>(&example_file).unwrap(); let loaded_secret = Secret::load_b64::<{ N * 2 }, _>(&example_file).unwrap();
// Check that the loaded secret matches the original bytes // Check that the loaded secret matches the original bytes
assert_eq!(loaded_secret.secret(), &original_bytes); assert_eq!(loaded_secret.secret(), &original_bytes);
// Store the loaded secret to a different file in the temporary directory // Store the loaded secret to a different file in the temporary directory
let new_file = temp_dir.path().join("new_file"); let new_file = temp_dir.path().join("new_file");
loaded_secret.store_b64::<{ N * 2 }, _>(&new_file).unwrap(); loaded_secret.store_b64::<{ N * 2 }, _>(&new_file).unwrap();
// Read the contents of the new file // Read the contents of the new file
let new_file_contents = fs::read(&new_file).unwrap(); let new_file_contents = fs::read(&new_file).unwrap();
// Read the contents of the original file // Read the contents of the original file
let original_file_contents = fs::read(&example_file).unwrap(); let original_file_contents = fs::read(&example_file).unwrap();
// Check that the contents of the new file match the original file // Check that the contents of the new file match the original file
assert_eq!(new_file_contents, original_file_contents); assert_eq!(new_file_contents, original_file_contents);
//Check new file permissions are secret //Check new file permissions are secret
let metadata = fs::metadata(&new_file).unwrap(); let metadata = fs::metadata(&new_file).unwrap();
assert_eq!(metadata.permissions().mode() & 0o000777, 0o600); assert_eq!(metadata.permissions().mode() & 0o000777, 0o600);
// Store the loaded secret to a different file in the temporary directory for a second time // Store the loaded secret to a different file in the temporary directory for a second time
let new_file = temp_dir.path().join("new_file_writer"); let new_file = temp_dir.path().join("new_file_writer");
let new_file_writer = fopen_w(new_file.clone(), Visibility::Secret).unwrap(); let new_file_writer = fopen_w(new_file.clone(), Visibility::Secret).unwrap();
loaded_secret loaded_secret
.store_b64_writer::<{ N * 2 }, _>(&new_file_writer) .store_b64_writer::<{ N * 2 }, _>(&new_file_writer)
.unwrap(); .unwrap();
// Read the contents of the new file // Read the contents of the new file
let new_file_contents = fs::read(&new_file).unwrap(); let new_file_contents = fs::read(&new_file).unwrap();
// Read the contents of the original file // Read the contents of the original file
let original_file_contents = fs::read(&example_file).unwrap(); let original_file_contents = fs::read(&example_file).unwrap();
// Check that the contents of the new file match the original file // Check that the contents of the new file match the original file
assert_eq!(new_file_contents, original_file_contents); assert_eq!(new_file_contents, original_file_contents);
//Check new file permissions are secret //Check new file permissions are secret
let metadata = fs::metadata(&new_file).unwrap(); let metadata = fs::metadata(&new_file).unwrap();
assert_eq!(metadata.permissions().mode() & 0o000777, 0o600); assert_eq!(metadata.permissions().mode() & 0o000777, 0o600);
});
} }
} }

View File

@@ -33,6 +33,7 @@ rosenpass-util = { workspace = true }
[dev-dependencies] [dev-dependencies]
rand = {workspace = true} rand = {workspace = true}
procspawn = {workspace = true}
[features] [features]
enable_broker_api=[] enable_broker_api=[]

View File

@@ -77,7 +77,7 @@ impl WireGuardBroker for NetlinkWireGuardBroker {
fn set_psk(&mut self, config: SerializedBrokerConfig) -> Result<(), Self::Error> { fn set_psk(&mut self, config: SerializedBrokerConfig) -> Result<(), Self::Error> {
let config: NetworkBrokerConfig = config let config: NetworkBrokerConfig = config
.try_into() .try_into()
.map_err(|e| SetPskError::NoSuchInterface)?; .map_err(|_e| SetPskError::NoSuchInterface)?;
// Ensure that the peer exists by querying the device configuration // Ensure that the peer exists by querying the device configuration
// TODO: Use InvalidInterfaceError // TODO: Use InvalidInterfaceError

View File

@@ -53,82 +53,88 @@ mod integration_tests {
} }
} }
procspawn::enable_test_support!();
#[test] #[test]
fn test_psk_exchanges() { fn test_psk_exchanges() {
const TEST_RUNS: usize = 100; const TEST_RUNS: usize = 100;
let server_broker_inner = Arc::new(Mutex::new(MockServerBrokerInner::default())); use rosenpass_secret_memory::test_spawn_process_provided_policies;
// Create a mock BrokerServer
let server_broker = MockServerBroker::new(server_broker_inner.clone());
let mut server = BrokerServer::<SetPskError, MockServerBroker>::new(server_broker); test_spawn_process_provided_policies!({
let server_broker_inner = Arc::new(Mutex::new(MockServerBrokerInner::default()));
// Create a mock BrokerServer
let server_broker = MockServerBroker::new(server_broker_inner.clone());
let (client_socket, mut server_socket) = mio::net::UnixStream::pair().unwrap(); let mut server = BrokerServer::<SetPskError, MockServerBroker>::new(server_broker);
let (client_socket, mut server_socket) = mio::net::UnixStream::pair().unwrap();
// Spawn a new thread to connect to the unix socket
let handle = std::thread::spawn(move || {
for _ in 0..TEST_RUNS {
// Wait for 8 bytes of length to come in
let mut length_buffer = [0; 8];
while let Err(_err) = server_socket.read_exact(&mut length_buffer) {}
let length = u64::from_le_bytes(length_buffer) as usize;
// Read the amount of length bytes into a buffer
let mut data_buffer = [0; REQUEST_MSG_BUFFER_SIZE];
while let Err(_err) = server_socket.read_exact(&mut data_buffer[0..length]) {}
let mut response = [0; RESPONSE_MSG_BUFFER_SIZE];
server.handle_message(&data_buffer[0..length], &mut response)?;
}
Ok::<(), BrokerServerError>(())
});
// Create a MioBrokerClient and send a psk
let mut client = MioBrokerClient::new(client_socket);
// Spawn a new thread to connect to the unix socket
let handle = std::thread::spawn(move || {
for _ in 0..TEST_RUNS { for _ in 0..TEST_RUNS {
// Wait for 8 bytes of length to come in //Create psk of random 32 bytes
let mut length_buffer = [0; 8]; let psk = Secret::random();
let peer_id = Public::random();
let interface = "test";
let config = SerializedBrokerConfig {
psk: &psk,
peer_id: &peer_id,
interface: interface.as_bytes(),
additional_params: &[],
};
client.set_psk(config).unwrap();
while let Err(_err) = server_socket.read_exact(&mut length_buffer) {} //Sleep for a while to allow the server to process the message
std::thread::sleep(std::time::Duration::from_millis(
rand::thread_rng().gen_range(100..500),
));
let length = u64::from_le_bytes(length_buffer) as usize; let psk = psk.secret().to_owned();
// Read the amount of length bytes into a buffer loop {
let mut data_buffer = [0; REQUEST_MSG_BUFFER_SIZE]; let mut lock = server_broker_inner.try_lock();
while let Err(_err) = server_socket.read_exact(&mut data_buffer[0..length]) {}
let mut response = [0; RESPONSE_MSG_BUFFER_SIZE]; if let Ok(ref mut inner) = lock {
server.handle_message(&data_buffer[0..length], &mut response)?; // Check if the psk is received by the server
} let received_psk = &inner.psk;
Ok::<(), BrokerServerError>(()) assert_eq!(
}); received_psk.as_ref().map(|psk| psk.secret().to_owned()),
Some(psk)
);
// Create a MioBrokerClient and send a psk let recieved_peer_id = inner.peer_id;
let mut client = MioBrokerClient::new(client_socket); assert_eq!(recieved_peer_id, Some(peer_id));
for _ in 0..TEST_RUNS { let target_interface = &inner.interface;
//Create psk of random 32 bytes assert_eq!(target_interface.as_deref(), Some(interface));
let psk = Secret::random();
let peer_id = Public::random();
let interface = "test";
let config = SerializedBrokerConfig {
psk: &psk,
peer_id: &peer_id,
interface: interface.as_bytes(),
additional_params: &[],
};
client.set_psk(config).unwrap();
//Sleep for a while to allow the server to process the message break;
std::thread::sleep(std::time::Duration::from_millis( }
rand::thread_rng().gen_range(100..500),
));
let psk = psk.secret().to_owned();
loop {
let mut lock = server_broker_inner.try_lock();
if let Ok(ref mut inner) = lock {
// Check if the psk is received by the server
let received_psk = &inner.psk;
assert_eq!(
received_psk.as_ref().map(|psk| psk.secret().to_owned()),
Some(psk)
);
let recieved_peer_id = inner.peer_id;
assert_eq!(recieved_peer_id, Some(peer_id));
let target_interface = &inner.interface;
assert_eq!(target_interface.as_deref(), Some(interface));
break;
} }
} }
} handle.join().unwrap().unwrap();
handle.join().unwrap().unwrap(); });
} }
} }