diff --git a/Cargo.lock b/Cargo.lock index 44e4884..47562ce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2068,6 +2068,7 @@ dependencies = [ "rosenpass-wireguard-broker", "rustix", "serde", + "serde_json", "serial_test", "signal-hook", "signal-hook-mio", @@ -2376,9 +2377,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.139" +version = "1.0.140" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44f86c3acccc9c65b153fe1b85a3be07fe5515274ec9f0653b4a0875731c72a6" +checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" dependencies = [ "itoa", "memchr", diff --git a/Cargo.toml b/Cargo.toml index 8e9b19c..61deb60 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,17 +2,17 @@ resolver = "2" members = [ - "rosenpass", - "cipher-traits", - "ciphers", - "util", - "constant-time", - "oqs", - "to", - "fuzz", - "secret-memory", - "rp", - "wireguard-broker", + "rosenpass", + "cipher-traits", + "ciphers", + "util", + "constant-time", + "oqs", + "to", + "fuzz", + "secret-memory", + "rp", + "wireguard-broker", ] default-members = ["rosenpass", "rp", "wireguard-broker"] @@ -42,7 +42,7 @@ toml = "0.7.8" static_assertions = "1.1.0" allocator-api2 = "0.2.14" memsec = { git = "https://github.com/rosenpass/memsec.git", rev = "aceb9baee8aec6844125bd6612f92e9a281373df", features = [ - "alloc_ext", + "alloc_ext", ] } rand = "0.8.5" typenum = "1.17.0" @@ -57,14 +57,14 @@ mio = { version = "1.0.3", features = ["net", "os-poll"] } signal-hook-mio = { version = "0.2.4", features = ["support-v1_0"] } signal-hook = "0.3.17" oqs-sys = { version = "0.9.1", default-features = false, features = [ - 'classic_mceliece', - 'kyber', + 'classic_mceliece', + 'kyber', ] } blake2 = "0.10.6" sha3 = "0.10.8" chacha20poly1305 = { version = "0.10.1", default-features = false, features = [ - "std", - "heapless", + "std", + "heapless", ] } zerocopy = { version = "0.7.35", features = ["derive"] } home = "=0.5.9" # 5.11 requires rustc 1.81 @@ -92,6 +92,7 @@ test_bin = "0.4.0" criterion = "0.5.1" allocator-api2-tests = "0.2.15" procspawn = { version = "1.0.1", features = ["test-support"] } +serde_json = { version = "1.0.140" } #Broker dependencies (might need cleanup or changes) wireguard-uapi = { version = "3.0.0", features = ["xplatform"] } diff --git a/rosenpass/Cargo.toml b/rosenpass/Cargo.toml index 9ce2686..a0e639d 100644 --- a/rosenpass/Cargo.toml +++ b/rosenpass/Cargo.toml @@ -30,9 +30,9 @@ required-features = ["experiment_api", "internal_testing"] [[test]] name = "gen-ipc-msg-types" required-features = [ - "experiment_api", - "internal_testing", - "internal_bin_gen_ipc_msg_types", + "experiment_api", + "internal_testing", + "internal_bin_gen_ipc_msg_types", ] [[bench]] @@ -91,24 +91,24 @@ serial_test = { workspace = true } procspawn = { workspace = true } tempfile = { workspace = true } rustix = { workspace = true } +serde_json = { workspace = true } [features] -#default = ["experiment_libcrux_all"] experiment_cookie_dos_mitigation = [] experiment_memfd_secret = ["rosenpass-wireguard-broker/experiment_memfd_secret"] experiment_libcrux_all = ["rosenpass-ciphers/experiment_libcrux_all"] experiment_libcrux_blake2 = ["rosenpass-ciphers/experiment_libcrux_blake2"] experiment_libcrux_chachapoly = [ - "rosenpass-ciphers/experiment_libcrux_chachapoly", + "rosenpass-ciphers/experiment_libcrux_chachapoly", ] experiment_libcrux_kyber = ["rosenpass-ciphers/experiment_libcrux_kyber"] experiment_api = [ - "hex-literal", - "uds", - "command-fds", - "rustix", - "rosenpass-util/experiment_file_descriptor_passing", - "rosenpass-wireguard-broker/experiment_api", + "hex-literal", + "uds", + "command-fds", + "rustix", + "rosenpass-util/experiment_file_descriptor_passing", + "rosenpass-wireguard-broker/experiment_api", ] internal_testing = [] internal_bin_gen_ipc_msg_types = ["hex", "heck"] diff --git a/rosenpass/benches/trace_handshake.rs b/rosenpass/benches/trace_handshake.rs index 8f55ed5..0966046 100644 --- a/rosenpass/benches/trace_handshake.rs +++ b/rosenpass/benches/trace_handshake.rs @@ -1,6 +1,9 @@ -use std::io::{self, Write}; -use std::time::{Duration, Instant}; -use std::{collections::HashMap, hint::black_box, ops::DerefMut}; +use std::{ + collections::HashMap, + hint::black_box, + ops::DerefMut, + time::{Duration, Instant}, +}; use anyhow::Result; @@ -9,11 +12,12 @@ use libcrux_test_utils::tracing::{EventType, Trace as _}; use rosenpass_cipher_traits::primitives::Kem; use rosenpass_ciphers::StaticKem; use rosenpass_secret_memory::secret_policy_try_use_memfd_secrets; -use rosenpass_util::trace_bench::RpEventType; +use rosenpass_util::trace_bench::RpEvent; use rosenpass::protocol::basic_types::{MsgBuf, SPk, SSk, SymKey}; use rosenpass::protocol::osk_domain_separator::OskDomainSeparator; use rosenpass::protocol::{CryptoServer, HandleMsgResult, PeerPtr, ProtocolVersion}; +use serde::ser::SerializeStruct; const ITERATIONS: usize = 100; @@ -124,15 +128,30 @@ fn main() { (v02, &v03_with_marker[1..]) }; - // Perform statistical analysis on both trace sections and write results as JSON - write_json_arrays( - &mut std::io::stdout(), // Write to standard output - vec![ - ("V02", statistical_analysis(trace_v02.to_vec())), - ("V03", statistical_analysis(trace_v03.to_vec())), - ], - ) - .expect("error writing json data"); + // Perform statistical analysis on both trace sections + let analysis_v02 = statistical_analysis(trace_v02); + let analysis_v03 = statistical_analysis(trace_v03); + + // Transform analysis results to JSON-encodable data type + let stats_v02 = analysis_v02 + .iter() + .map(|(label, agg_stat)| JsonAggregateStat { + protocol_version: "V02", + label, + agg_stat, + }); + let stats_v03 = analysis_v03 + .iter() + .map(|(label, agg_stat)| JsonAggregateStat { + protocol_version: "V03", + label, + agg_stat: &agg_stat, + }); + + // Write results as JSON + let stats_all: Vec<_> = stats_v02.chain(stats_v03).collect(); + let stats_json = serde_json::to_string_pretty(&stats_all).expect("error encoding to json"); + println!("{stats_json}"); } /// Performs a simple statistical analysis: @@ -140,7 +159,7 @@ fn main() { /// - extracts durations of spamns /// - filters out empty bins /// - calculates aggregate statistics (mean, std dev) -fn statistical_analysis(trace: Vec) -> Vec<(&'static str, AggregateStat)> { +fn statistical_analysis(trace: &[RpEvent]) -> Vec<(&'static str, AggregateStat)> { bin_events(trace) .into_iter() .map(|(label, spans)| (label, extract_span_durations(label, spans.as_slice()))) @@ -149,44 +168,6 @@ fn statistical_analysis(trace: Vec) -> Vec<(&'static str, Aggregate .collect() } -/// Takes an iterator of ("protocol_version", iterator_of_stats) pairs and writes them -/// as a single flat JSON array to the provided writer. -/// -/// # Arguments -/// * `w` - The writer to output JSON to (e.g., stdout, file). -/// * `item_groups` - An iterator producing tuples `(version, stats): (&'static str, II)`. -/// Here `II` is itself an iterator producing `(label, agg_stat): (&'static str, AggregateStat)`, -/// where the label is the label of the span, e.g. "IHI2". -/// -/// # Type Parameters -/// * `W` - A type that implements `std::io::Write`. -/// * `II` - An iterator type yielding (`&'static str`, `AggregateStat`). -fn write_json_arrays)>>( - w: &mut W, - item_groups: impl IntoIterator, -) -> io::Result<()> { - // Flatten the groups into a single iterator of (protocol_version, label, stats) - let iter = item_groups.into_iter().flat_map(|(version, items)| { - items - .into_iter() - .map(move |(label, agg_stat)| (version, label, agg_stat)) - }); - let mut delim = ""; // Start with no delimiter - - // Start the JSON array - write!(w, "[")?; - - // Write the flattened statistics as JSON objects, separated by commas. - for (version, label, agg_stat) in iter { - write!(w, "{delim}")?; // Write delimiter (empty for first item, "," for subsequent) - agg_stat.write_json_ns(label, version, w)?; // Write the JSON object for the stat entry - delim = ","; // Set delimiter for the next iteration - } - - // End the JSON array - write!(w, "]") -} - /// Used to group benchmark results in visualizations enum RunTimeGroup { /// For particularly long operations. @@ -239,13 +220,13 @@ enum StatEntry { /// Takes a flat list of events and organizes them into a HashMap where keys /// are event labels and values are vectors of events with that label. -fn bin_events(events: Vec) -> HashMap<&'static str, Vec> { +fn bin_events(events: &[RpEvent]) -> HashMap<&'static str, Vec> { let mut spans = HashMap::<_, Vec<_>>::new(); for event in events { // Get the vector for the event's label, or create a new one let spans_for_label = spans.entry(event.label).or_default(); // Add the event to the vector - spans_for_label.push(event); + spans_for_label.push(event.clone()); } spans } @@ -253,7 +234,7 @@ fn bin_events(events: Vec) -> HashMap<&'static str, Vec Vec { +fn extract_span_durations(label: &str, events: &[RpEvent]) -> Vec { let mut processing_list: Vec = vec![]; // List to track open spans and final durations for entry in events { @@ -313,6 +294,7 @@ fn extract_span_durations(label: &str, events: &[RpEventType]) -> Vec /// Stores the mean, standard deviation, relative standard deviation (sd/mean), /// and the number of samples used for calculation. #[derive(Debug)] +#[allow(dead_code)] struct AggregateStat { /// Average duration. mean_duration: T, @@ -362,32 +344,33 @@ impl AggregateStat { sample_size, } } +} - /// Writes the statistics as a JSON object to the provided writer. - /// Includes metadata like label, protocol_version, OS, architecture, and run time group. - /// - /// # Arguments - /// * `label` - The specific benchmark/span label. - /// * `protocol_version` - Version of the protocol that is benchmarked. - /// * `w` - The output writer (must implement `std::io::Write`). - fn write_json_ns( - &self, - label: &str, - protocol_version: &str, - w: &mut impl io::Write, - ) -> io::Result<()> { - // Format the JSON string using measured values and environment constants - writeln!( - w, - r#"{{"name":"{name}", "unit":"ns/iter", "value":"{value}", "range":"± {range}", "protocol version":"{protocol_version}", "sample size":"{sample_size}", "operating system":"{os}", "architecture":"{arch}", "run time":"{run_time}"}}"#, - name = label, // Benchmark name - value = self.mean_duration.as_nanos(), // Mean duration in nanoseconds - range = self.sd_duration.as_nanos(), // Standard deviation in nanoseconds - sample_size = self.sample_size, // Number of samples - os = std::env::consts::OS, // Operating system - arch = std::env::consts::ARCH, // CPU architecture - run_time = run_time_group(label), // Run time group category (long, medium, etc.) - protocol_version = protocol_version // Overall protocol_version (e.g., protocol version) - ) +struct JsonAggregateStat<'a, T> { + agg_stat: &'a AggregateStat, + label: &'a str, + protocol_version: &'a str, +} + +impl<'a> serde::Serialize for JsonAggregateStat<'a, Duration> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut stat = serializer.serialize_struct("AggregateStat", 9)?; + stat.serialize_field("name", self.label)?; + stat.serialize_field("unit", "ns/iter")?; + stat.serialize_field("value", &self.agg_stat.mean_duration.as_nanos().to_string())?; + stat.serialize_field( + "range", + &format!("± {}", self.agg_stat.sd_duration.as_nanos()), + )?; + stat.serialize_field("protocol version", self.protocol_version)?; + stat.serialize_field("sample size", &self.agg_stat.sample_size)?; + stat.serialize_field("operating system", std::env::consts::OS)?; + stat.serialize_field("architecture", std::env::consts::ARCH)?; + stat.serialize_field("run time", &run_time_group(self.label).to_string())?; + + stat.end() } } diff --git a/util/src/trace_bench.rs b/util/src/trace_bench.rs index 367efa6..3ede377 100644 --- a/util/src/trace_bench.rs +++ b/util/src/trace_bench.rs @@ -10,7 +10,7 @@ static TRACE: OnceLock = OnceLock::new(); pub type RpTrace = tracing::MutexTrace<&'static str, Instant>; /// The trace event type used to trace Rosenpass for performance measurement. -pub type RpEventType = tracing::TraceEvent<&'static str, Instant>; +pub type RpEvent = tracing::TraceEvent<&'static str, Instant>; // Re-export to make functionality available and callers don't need to also directly depend on // [`libcrux_test_utils`].