From 76ff535619f0ce6f8be67ecfb039f37e2a939c49 Mon Sep 17 00:00:00 2001 From: hdbg Date: Mon, 16 Feb 2026 22:00:24 +0100 Subject: [PATCH] refactor(server::tests): moved integration-like tests into `tests/` --- server/crates/arbiter-server/src/actors.rs | 4 +- .../arbiter-server/src/actors/bootstrap.rs | 1 - .../src/actors/keyholder/encryption/v1.rs | 1 - .../src/actors/keyholder/mod.rs | 542 +----------------- .../src/actors/user_agent/mod.rs | 5 +- .../src/actors/user_agent/state.rs | 8 +- .../src/actors/user_agent/tests.rs | 446 -------------- .../src/actors/user_agent/transport.rs | 4 +- .../crates/arbiter-server/src/context/tls.rs | 10 +- server/crates/arbiter-server/src/db.rs | 7 +- server/crates/arbiter-server/src/errors.rs | 4 +- .../crates/arbiter-server/tests/common/mod.rs | 43 ++ .../crates/arbiter-server/tests/keyholder.rs | 8 + .../tests/keyholder/concurrency.rs | 173 ++++++ .../tests/keyholder/lifecycle.rs | 134 +++++ .../arbiter-server/tests/keyholder/storage.rs | 161 ++++++ .../crates/arbiter-server/tests/user_agent.rs | 6 + .../arbiter-server/tests/user_agent/auth.rs | 178 ++++++ .../arbiter-server/tests/user_agent/unseal.rs | 229 ++++++++ 19 files changed, 950 insertions(+), 1014 deletions(-) delete mode 100644 server/crates/arbiter-server/src/actors/user_agent/tests.rs create mode 100644 server/crates/arbiter-server/tests/common/mod.rs create mode 100644 server/crates/arbiter-server/tests/keyholder.rs create mode 100644 server/crates/arbiter-server/tests/keyholder/concurrency.rs create mode 100644 server/crates/arbiter-server/tests/keyholder/lifecycle.rs create mode 100644 server/crates/arbiter-server/tests/keyholder/storage.rs create mode 100644 server/crates/arbiter-server/tests/user_agent.rs create mode 100644 server/crates/arbiter-server/tests/user_agent/auth.rs create mode 100644 server/crates/arbiter-server/tests/user_agent/unseal.rs diff --git a/server/crates/arbiter-server/src/actors.rs b/server/crates/arbiter-server/src/actors.rs index 2a6567c..80ca4dd 100644 --- a/server/crates/arbiter-server/src/actors.rs +++ b/server/crates/arbiter-server/src/actors.rs @@ -7,9 +7,9 @@ use crate::{ db, }; -pub(crate) mod bootstrap; +pub mod bootstrap; pub mod client; -pub(crate) mod keyholder; +pub mod keyholder; pub mod user_agent; #[derive(Error, Debug, Diagnostic)] diff --git a/server/crates/arbiter-server/src/actors/bootstrap.rs b/server/crates/arbiter-server/src/actors/bootstrap.rs index 063545c..7119084 100644 --- a/server/crates/arbiter-server/src/actors/bootstrap.rs +++ b/server/crates/arbiter-server/src/actors/bootstrap.rs @@ -92,7 +92,6 @@ impl Bootstrapper { } } -#[cfg(test)] #[messages] impl Bootstrapper { #[message] diff --git a/server/crates/arbiter-server/src/actors/keyholder/encryption/v1.rs b/server/crates/arbiter-server/src/actors/keyholder/encryption/v1.rs index f0d79df..fdc9727 100644 --- a/server/crates/arbiter-server/src/actors/keyholder/encryption/v1.rs +++ b/server/crates/arbiter-server/src/actors/keyholder/encryption/v1.rs @@ -124,7 +124,6 @@ impl KeyCell { let mut cipher = XChaCha20Poly1305::new(key_ref); let nonce = XNonce::from_slice(nonce.0.as_ref()); - let ciphertext = cipher.encrypt( nonce, Payload { diff --git a/server/crates/arbiter-server/src/actors/keyholder/mod.rs b/server/crates/arbiter-server/src/actors/keyholder/mod.rs index 87e9e52..aab15e3 100644 --- a/server/crates/arbiter-server/src/actors/keyholder/mod.rs +++ b/server/crates/arbiter-server/src/actors/keyholder/mod.rs @@ -8,18 +8,15 @@ use memsafe::MemSafe; use strum::{EnumDiscriminants, IntoDiscriminant}; use tracing::{error, info}; -use crate::{ - db::{ - self, - models::{self, RootKeyHistory}, - schema::{self}, - }, +use crate::db::{ + self, + models::{self, RootKeyHistory}, + schema::{self}, }; use encryption::v1::{self, KeyCell, Nonce}; pub mod encryption; - #[derive(Default, EnumDiscriminants)] #[strum_discriminants(derive(Reply), vis(pub))] enum State { @@ -347,13 +344,10 @@ impl KeyHolder { #[cfg(test)] mod tests { - use std::collections::{HashMap, HashSet}; - - use diesel::dsl::{insert_into, sql_query, update}; + use diesel::SelectableHelper; + use diesel::dsl::insert_into; use diesel_async::RunQueryDsl; - use kameo::actor::{ActorRef, Spawn as _}; use memsafe::MemSafe; - use tokio::task::JoinSet; use crate::db::{self, models::ArbiterSetting}; @@ -381,338 +375,6 @@ mod tests { actor } - async fn write_concurrently( - actor: ActorRef, - prefix: &'static str, - count: usize, - ) -> Vec<(i32, Vec)> { - let mut set = JoinSet::new(); - for i in 0..count { - let actor = actor.clone(); - set.spawn(async move { - let plaintext = format!("{prefix}-{i}").into_bytes(); - let id = { - actor - .ask(CreateNew { - plaintext: MemSafe::new(plaintext.clone()).unwrap(), - }) - .await - .unwrap() - }; - (id, plaintext) - }); - } - - let mut out = Vec::with_capacity(count); - while let Some(res) = set.join_next().await { - out.push(res.unwrap()); - } - out - } - - #[tokio::test] - #[test_log::test] - async fn test_bootstrap() { - let db = db::create_test_pool().await; - seed_settings(&db).await; - let mut actor = KeyHolder::new(db.clone()).await.unwrap(); - - assert!(matches!(actor.state, State::Unbootstrapped)); - - let seal_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap(); - actor.bootstrap(seal_key).await.unwrap(); - - assert!(matches!(actor.state, State::Unsealed { .. })); - - let mut conn = db.get().await.unwrap(); - let row: models::RootKeyHistory = schema::root_key_history::table - .select(models::RootKeyHistory::as_select()) - .first(&mut conn) - .await - .unwrap(); - - assert_eq!(row.schema_version, 1); - assert_eq!(row.tag, v1::ROOT_KEY_TAG); - assert!(!row.ciphertext.is_empty()); - assert!(!row.salt.is_empty()); - assert_eq!(row.data_encryption_nonce, v1::Nonce::default().to_vec()); - } - - #[tokio::test] - #[test_log::test] - async fn test_bootstrap_rejects_double() { - let db = db::create_test_pool().await; - let mut actor = bootstrapped_actor(&db).await; - - let seal_key2 = MemSafe::new(b"test-seal-key".to_vec()).unwrap(); - let err = actor.bootstrap(seal_key2).await.unwrap_err(); - assert!(matches!(err, Error::AlreadyBootstrapped)); - } - - #[tokio::test] - #[test_log::test] - async fn test_create_decrypt_roundtrip() { - let db = db::create_test_pool().await; - let mut actor = bootstrapped_actor(&db).await; - - let plaintext = b"hello arbiter"; - let aead_id = actor - .create_new(MemSafe::new(plaintext.to_vec()).unwrap()) - .await - .unwrap(); - - let mut decrypted = actor.decrypt(aead_id).await.unwrap(); - let decrypted = decrypted.read().unwrap(); - assert_eq!(*decrypted, plaintext); - } - - #[tokio::test] - #[test_log::test] - async fn test_create_new_before_bootstrap_fails() { - let db = db::create_test_pool().await; - seed_settings(&db).await; - let mut actor = KeyHolder::new(db).await.unwrap(); - - let err = actor - .create_new(MemSafe::new(b"data".to_vec()).unwrap()) - .await - .unwrap_err(); - assert!(matches!(err, Error::NotBootstrapped)); - } - - #[tokio::test] - #[test_log::test] - async fn test_decrypt_before_bootstrap_fails() { - let db = db::create_test_pool().await; - seed_settings(&db).await; - let mut actor = KeyHolder::new(db).await.unwrap(); - - let err = actor.decrypt(1).await.unwrap_err(); - assert!(matches!(err, Error::NotBootstrapped)); - } - - #[tokio::test] - #[test_log::test] - async fn test_decrypt_nonexistent_returns_not_found() { - let db = db::create_test_pool().await; - let mut actor = bootstrapped_actor(&db).await; - - let err = actor.decrypt(9999).await.unwrap_err(); - assert!(matches!(err, Error::NotFound)); - } - - #[tokio::test] - #[test_log::test] - async fn test_new_restores_sealed_state() { - let db = db::create_test_pool().await; - let actor = bootstrapped_actor(&db).await; - drop(actor); - - let actor2 = KeyHolder::new(db).await.unwrap(); - assert!(matches!(actor2.state, State::Sealed { .. })); - } - - #[tokio::test] - #[test_log::test] - async fn test_nonce_never_reused() { - let db = db::create_test_pool().await; - let mut actor = bootstrapped_actor(&db).await; - - let n = 5; - let mut ids = Vec::with_capacity(n); - for i in 0..n { - let id = actor - .create_new(MemSafe::new(format!("secret {i}").into_bytes()).unwrap()) - .await - .unwrap(); - ids.push(id); - } - - // read all stored nonces from DB - let mut conn = db.get().await.unwrap(); - let rows: Vec = schema::aead_encrypted::table - .select(models::AeadEncrypted::as_select()) - .load(&mut conn) - .await - .unwrap(); - - assert_eq!(rows.len(), n); - - let nonces: Vec<&Vec> = rows.iter().map(|r| &r.current_nonce).collect(); - let unique: HashSet<&Vec> = nonces.iter().copied().collect(); - assert_eq!(nonces.len(), unique.len(), "all nonces must be unique"); - - // verify nonces are sequential increments from 1 - for (i, row) in rows.iter().enumerate() { - let mut expected = v1::Nonce::default(); - for _ in 0..=i { - expected.increment(); - } - assert_eq!(row.current_nonce, expected.to_vec(), "nonce {i} mismatch"); - } - - // verify data_encryption_nonce on root_key_history tracks the latest nonce - let root_row: models::RootKeyHistory = schema::root_key_history::table - .select(models::RootKeyHistory::as_select()) - .first(&mut conn) - .await - .unwrap(); - let last_nonce = &rows.last().unwrap().current_nonce; - assert_eq!( - &root_row.data_encryption_nonce, last_nonce, - "root_key_history must track the latest nonce" - ); - } - - #[tokio::test] - #[test_log::test] - async fn test_unseal_correct_password() { - let db = db::create_test_pool().await; - let mut actor = bootstrapped_actor(&db).await; - - let plaintext = b"survive a restart"; - let aead_id = actor - .create_new(MemSafe::new(plaintext.to_vec()).unwrap()) - .await - .unwrap(); - drop(actor); - - let mut actor = KeyHolder::new(db.clone()).await.unwrap(); - assert!(matches!(actor.state, State::Sealed { .. })); - - let seal_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap(); - actor.try_unseal(seal_key).await.unwrap(); - assert!(matches!(actor.state, State::Unsealed { .. })); - - // previously encrypted data is still decryptable - let mut decrypted = actor.decrypt(aead_id).await.unwrap(); - assert_eq!(*decrypted.read().unwrap(), plaintext); - } - - #[tokio::test] - #[test_log::test] - async fn test_unseal_wrong_then_correct_password() { - let db = db::create_test_pool().await; - let mut actor = bootstrapped_actor(&db).await; - - let plaintext = b"important data"; - let aead_id = actor - .create_new(MemSafe::new(plaintext.to_vec()).unwrap()) - .await - .unwrap(); - drop(actor); - - let mut actor = KeyHolder::new(db.clone()).await.unwrap(); - assert!(matches!(actor.state, State::Sealed { .. })); - - // wrong password - let bad_key = MemSafe::new(b"wrong-password".to_vec()).unwrap(); - let err = actor.try_unseal(bad_key).await.unwrap_err(); - assert!(matches!(err, Error::InvalidKey)); - assert!( - matches!(actor.state, State::Sealed { .. }), - "state must remain Sealed after failed attempt" - ); - - // correct password - let good_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap(); - actor.try_unseal(good_key).await.unwrap(); - assert!(matches!(actor.state, State::Unsealed { .. })); - - let mut decrypted = actor.decrypt(aead_id).await.unwrap(); - assert_eq!(*decrypted.read().unwrap(), plaintext); - } - - #[tokio::test] - #[test_log::test] - async fn test_ciphertext_differs_across_entries() { - let db = db::create_test_pool().await; - let mut actor = bootstrapped_actor(&db).await; - - let plaintext = b"same content"; - let id1 = actor - .create_new(MemSafe::new(plaintext.to_vec()).unwrap()) - .await - .unwrap(); - let id2 = actor - .create_new(MemSafe::new(plaintext.to_vec()).unwrap()) - .await - .unwrap(); - - // different nonces => different ciphertext, even for identical plaintext - let mut conn = db.get().await.unwrap(); - let row1: models::AeadEncrypted = schema::aead_encrypted::table - .filter(schema::aead_encrypted::id.eq(id1)) - .select(models::AeadEncrypted::as_select()) - .first(&mut conn) - .await - .unwrap(); - let row2: models::AeadEncrypted = schema::aead_encrypted::table - .filter(schema::aead_encrypted::id.eq(id2)) - .select(models::AeadEncrypted::as_select()) - .first(&mut conn) - .await - .unwrap(); - - assert_ne!(row1.ciphertext, row2.ciphertext); - - // but both decrypt to the same plaintext - let mut d1 = actor.decrypt(id1).await.unwrap(); - let mut d2 = actor.decrypt(id2).await.unwrap(); - assert_eq!(*d1.read().unwrap(), plaintext); - assert_eq!(*d2.read().unwrap(), plaintext); - } - - #[tokio::test] - #[test_log::test] - async fn concurrent_create_new_no_duplicate_nonces_() { - let db = db::create_test_pool().await; - let actor = KeyHolder::spawn(bootstrapped_actor(&db).await); - - let writes = write_concurrently(actor, "nonce-unique", 32).await; - assert_eq!(writes.len(), 32); - - let mut conn = db.get().await.unwrap(); - let rows: Vec = schema::aead_encrypted::table - .select(models::AeadEncrypted::as_select()) - .load(&mut conn) - .await - .unwrap(); - assert_eq!(rows.len(), 32); - - let nonces: Vec<&Vec> = rows.iter().map(|r| &r.current_nonce).collect(); - let unique: HashSet<&Vec> = nonces.iter().copied().collect(); - assert_eq!(nonces.len(), unique.len(), "all nonces must be unique"); - } - - #[tokio::test] - #[test_log::test] - async fn concurrent_create_new_root_nonce_never_moves_backward() { - let db = db::create_test_pool().await; - let actor = KeyHolder::spawn(bootstrapped_actor(&db).await); - - write_concurrently(actor, "root-max", 24).await; - - let mut conn = db.get().await.unwrap(); - let rows: Vec = schema::aead_encrypted::table - .select(models::AeadEncrypted::as_select()) - .load(&mut conn) - .await - .unwrap(); - let max_nonce = rows - .iter() - .map(|r| r.current_nonce.clone()) - .max() - .expect("at least one row"); - - let root_row: models::RootKeyHistory = schema::root_key_history::table - .select(models::RootKeyHistory::as_select()) - .first(&mut conn) - .await - .unwrap(); - assert_eq!(root_row.data_encryption_nonce, max_nonce); - } - #[tokio::test] #[test_log::test] async fn nonce_monotonic_even_when_nonce_allocation_interleaves() { @@ -757,196 +419,4 @@ mod tests { "next write must advance nonce" ); } - - #[tokio::test] - #[test_log::test] - async fn insert_failure_does_not_create_partial_row() { - let db = db::create_test_pool().await; - let mut actor = bootstrapped_actor(&db).await; - let root_key_history_id = match actor.state { - State::Unsealed { - root_key_history_id, - .. - } => root_key_history_id, - _ => panic!("expected unsealed state"), - }; - - let mut conn = db.get().await.unwrap(); - let before_count: i64 = schema::aead_encrypted::table - .count() - .get_result(&mut conn) - .await - .unwrap(); - let before_root_nonce: Vec = schema::root_key_history::table - .filter(schema::root_key_history::id.eq(root_key_history_id)) - .select(schema::root_key_history::data_encryption_nonce) - .first(&mut conn) - .await - .unwrap(); - - sql_query( - "CREATE TRIGGER fail_aead_insert BEFORE INSERT ON aead_encrypted BEGIN SELECT RAISE(ABORT, 'forced test failure'); END;", - ) - .execute(&mut conn) - .await - .unwrap(); - drop(conn); - - let err = actor - .create_new(MemSafe::new(b"should fail".to_vec()).unwrap()) - .await - .unwrap_err(); - assert!(matches!(err, Error::DatabaseTransaction(_))); - - let mut conn = db.get().await.unwrap(); - sql_query("DROP TRIGGER fail_aead_insert;") - .execute(&mut conn) - .await - .unwrap(); - - let after_count: i64 = schema::aead_encrypted::table - .count() - .get_result(&mut conn) - .await - .unwrap(); - assert_eq!( - before_count, after_count, - "failed insert must not create row" - ); - - let after_root_nonce: Vec = schema::root_key_history::table - .filter(schema::root_key_history::id.eq(root_key_history_id)) - .select(schema::root_key_history::data_encryption_nonce) - .first(&mut conn) - .await - .unwrap(); - assert!( - after_root_nonce > before_root_nonce, - "current behavior allows nonce gap on failed insert" - ); - } - - #[tokio::test] - #[test_log::test] - async fn decrypt_roundtrip_after_high_concurrency() { - let db = db::create_test_pool().await; - let actor = KeyHolder::spawn(bootstrapped_actor(&db).await); - - let writes = write_concurrently(actor, "roundtrip", 40).await; - let expected: HashMap> = writes.into_iter().collect(); - - let mut decryptor = KeyHolder::new(db.clone()).await.unwrap(); - decryptor - .try_unseal(MemSafe::new(b"test-seal-key".to_vec()).unwrap()) - .await - .unwrap(); - - for (id, plaintext) in expected { - let mut decrypted = decryptor.decrypt(id).await.unwrap(); - assert_eq!(*decrypted.read().unwrap(), plaintext); - } - } - - // #[tokio::test] - // #[test_log::test] - // async fn swapping_ciphertext_and_nonce_between_rows_changes_logical_binding() { - // let db = db::create_test_pool().await; - // let mut actor = bootstrapped_actor(&db).await; - - // let plaintext1 = b"entry-one"; - // let plaintext2 = b"entry-two"; - // let id1 = actor - // .create_new(MemSafe::new(plaintext1.to_vec()).unwrap()) - // .await - // .unwrap(); - // let id2 = actor - // .create_new(MemSafe::new(plaintext2.to_vec()).unwrap()) - // .await - // .unwrap(); - - // let mut conn = db.get().await.unwrap(); - // let row1: models::AeadEncrypted = schema::aead_encrypted::table - // .filter(schema::aead_encrypted::id.eq(id1)) - // .select(models::AeadEncrypted::as_select()) - // .first(&mut conn) - // .await - // .unwrap(); - // let row2: models::AeadEncrypted = schema::aead_encrypted::table - // .filter(schema::aead_encrypted::id.eq(id2)) - // .select(models::AeadEncrypted::as_select()) - // .first(&mut conn) - // .await - // .unwrap(); - - // update(schema::aead_encrypted::table.filter(schema::aead_encrypted::id.eq(id1))) - // .set(( - // schema::aead_encrypted::ciphertext.eq(row2.ciphertext.clone()), - // schema::aead_encrypted::current_nonce.eq(row2.current_nonce.clone()), - // )) - // .execute(&mut conn) - // .await - // .unwrap(); - // update(schema::aead_encrypted::table.filter(schema::aead_encrypted::id.eq(id2))) - // .set(( - // schema::aead_encrypted::ciphertext.eq(row1.ciphertext.clone()), - // schema::aead_encrypted::current_nonce.eq(row1.current_nonce.clone()), - // )) - // .execute(&mut conn) - // .await - // .unwrap(); - - // let mut d1 = actor.decrypt(id1).await.unwrap(); - // let mut d2 = actor.decrypt(id2).await.unwrap(); - // assert_eq!(*d1.read().unwrap(), plaintext2); - // assert_eq!(*d2.read().unwrap(), plaintext1); - // } - #[tokio::test] - #[test_log::test] - async fn broken_db_nonce_format_fails_closed() { - // malformed root_key_history nonce must fail create_new - let db = db::create_test_pool().await; - let mut actor = bootstrapped_actor(&db).await; - let root_key_history_id = match actor.state { - State::Unsealed { - root_key_history_id, - .. - } => root_key_history_id, - _ => panic!("expected unsealed state"), - }; - - let mut conn = db.get().await.unwrap(); - update( - schema::root_key_history::table - .filter(schema::root_key_history::id.eq(root_key_history_id)), - ) - .set(schema::root_key_history::data_encryption_nonce.eq(vec![1, 2, 3])) - .execute(&mut conn) - .await - .unwrap(); - drop(conn); - - let err = actor - .create_new(MemSafe::new(b"must fail".to_vec()).unwrap()) - .await - .unwrap_err(); - assert!(matches!(err, Error::BrokenDatabase)); - - // malformed per-row nonce must fail decrypt - let db = db::create_test_pool().await; - let mut actor = bootstrapped_actor(&db).await; - let id = actor - .create_new(MemSafe::new(b"decrypt target".to_vec()).unwrap()) - .await - .unwrap(); - let mut conn = db.get().await.unwrap(); - update(schema::aead_encrypted::table.filter(schema::aead_encrypted::id.eq(id))) - .set(schema::aead_encrypted::current_nonce.eq(vec![7, 8])) - .execute(&mut conn) - .await - .unwrap(); - drop(conn); - - let err = actor.decrypt(id).await.unwrap_err(); - assert!(matches!(err, Error::BrokenDatabase)); - } } diff --git a/server/crates/arbiter-server/src/actors/user_agent/mod.rs b/server/crates/arbiter-server/src/actors/user_agent/mod.rs index 82f6722..700dd40 100644 --- a/server/crates/arbiter-server/src/actors/user_agent/mod.rs +++ b/server/crates/arbiter-server/src/actors/user_agent/mod.rs @@ -35,8 +35,6 @@ use crate::{ }; mod state; -#[cfg(test)] -mod tests; mod transport; pub(crate) use transport::handle_user_agent; @@ -63,8 +61,7 @@ impl UserAgentActor { } } - #[cfg(test)] - pub(crate) fn new_manual( + pub fn new_manual( db: db::DatabasePool, actors: GlobalActors, tx: Sender>, diff --git a/server/crates/arbiter-server/src/actors/user_agent/state.rs b/server/crates/arbiter-server/src/actors/user_agent/state.rs index 0ff5826..f0404b0 100644 --- a/server/crates/arbiter-server/src/actors/user_agent/state.rs +++ b/server/crates/arbiter-server/src/actors/user_agent/state.rs @@ -12,15 +12,11 @@ pub struct ChallengeContext { pub key: VerifyingKey, } - - pub struct UnsealContext { pub client_public_key: PublicKey, pub secret: Mutex>, } - - smlang::statemachine!( name: UserAgent, custom_error: false, @@ -46,10 +42,10 @@ impl UserAgentStateMachineContext for DummyContext { fn generate_temp_keypair(&mut self, event_data: UnsealContext) -> Result { Ok(event_data) } - + #[allow(missing_docs)] #[allow(clippy::unused_unit)] - fn move_challenge< >(&mut self,event_data:ChallengeContext) -> Result { + fn move_challenge(&mut self, event_data: ChallengeContext) -> Result { Ok(event_data) } } diff --git a/server/crates/arbiter-server/src/actors/user_agent/tests.rs b/server/crates/arbiter-server/src/actors/user_agent/tests.rs deleted file mode 100644 index 20f95e4..0000000 --- a/server/crates/arbiter-server/src/actors/user_agent/tests.rs +++ /dev/null @@ -1,446 +0,0 @@ -use arbiter_proto::proto::{ - UnsealEncryptedKey, UnsealResult, UnsealStart, UserAgentResponse, - auth::{self, AuthChallengeRequest, AuthOk}, - user_agent_response::Payload as UserAgentResponsePayload, -}; -use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit}; -use diesel::{ExpressionMethods as _, QueryDsl, insert_into}; -use diesel_async::RunQueryDsl; -use ed25519_dalek::Signer as _; -use kameo::actor::{ActorRef, Spawn}; -use memsafe::MemSafe; -use x25519_dalek::{EphemeralSecret, PublicKey}; - -use crate::{ - actors::{ - GlobalActors, - bootstrap::GetToken, - keyholder::{Bootstrap, Seal}, - user_agent::{ - HandleAuthChallengeRequest, HandleAuthChallengeSolution, HandleUnsealEncryptedKey, - HandleUnsealRequest, - }, - }, - db::{self, models::ArbiterSetting, schema}, -}; - -use super::UserAgentActor; - -async fn seed_settings(db: &db::DatabasePool) { - let mut conn = db.get().await.unwrap(); - insert_into(schema::arbiter_settings::table) - .values(&ArbiterSetting { - id: 1, - root_key_id: None, - cert_key: vec![], - cert: vec![], - }) - .execute(&mut conn) - .await - .unwrap(); -} - -/// Bootstrap keyholder with `seal_key`, and Seal it -/// then create and authenticate a user agent (reaching Idle state). -async fn setup_authenticated_user_agent( - seal_key: &[u8], -) -> (db::DatabasePool, ActorRef) { - let db = db::create_test_pool().await; - seed_settings(&db).await; - - let actors = GlobalActors::spawn(db.clone()).await.unwrap(); - actors - .key_holder - .ask(Bootstrap { - seal_key_raw: MemSafe::new(seal_key.to_vec()).unwrap(), - }) - .await - .unwrap(); - actors.key_holder.ask(Seal).await.unwrap(); - - let user_agent = UserAgentActor::new_manual( - db.clone(), - actors.clone(), - tokio::sync::mpsc::channel(1).0, - ); - let user_agent_ref = UserAgentActor::spawn(user_agent); - let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap(); - - let auth_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); - user_agent_ref - .ask(HandleAuthChallengeRequest { - req: AuthChallengeRequest { - pubkey: auth_key.verifying_key().to_bytes().to_vec(), - bootstrap_token: Some(token), - }, - }) - .await - .unwrap(); - - (db, user_agent_ref) -} - -/// Client side of the DH unseal exchange: -/// sends UnsealStart, derives shared secret, encrypts `key_to_send`. -async fn client_dh_encrypt( - user_agent_ref: &ActorRef, - key_to_send: &[u8], -) -> UnsealEncryptedKey { - let client_secret = EphemeralSecret::random(); - let client_public = PublicKey::from(&client_secret); - - let response = user_agent_ref - .ask(HandleUnsealRequest { - req: UnsealStart { - client_pubkey: client_public.as_bytes().to_vec(), - }, - }) - .await - .unwrap(); - - let server_pubkey = match response.payload.unwrap() { - UserAgentResponsePayload::UnsealStartResponse(resp) => resp.server_pubkey, - other => panic!("Expected UnsealStartResponse, got {other:?}"), - }; - let server_public = PublicKey::from( - <[u8; 32]>::try_from(server_pubkey.as_slice()).unwrap(), - ); - - let shared_secret = client_secret.diffie_hellman(&server_public); - let cipher = XChaCha20Poly1305::new(shared_secret.as_bytes().into()); - let nonce = XNonce::from([0u8; 24]); - let associated_data = b"unseal"; - let mut ciphertext = key_to_send.to_vec(); - cipher - .encrypt_in_place(&nonce, associated_data, &mut ciphertext) - .unwrap(); - - UnsealEncryptedKey { - nonce: nonce.to_vec(), - ciphertext, - associated_data: associated_data.to_vec(), - } -} - -#[tokio::test] -#[test_log::test] -pub async fn test_bootstrap_token_auth() { - let db = db::create_test_pool().await; - seed_settings(&db).await; - let actors = GlobalActors::spawn(db.clone()).await.unwrap(); - let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap(); - let user_agent = UserAgentActor::new_manual( - db.clone(), - actors.clone(), - tokio::sync::mpsc::channel(1).0, // dummy channel, we won't actually send responses in this test - ); - let user_agent_ref = UserAgentActor::spawn(user_agent); - - // simulate client sending auth request with bootstrap token - let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); - let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); - - let result = user_agent_ref - .ask(HandleAuthChallengeRequest { - req: AuthChallengeRequest { - pubkey: pubkey_bytes, - bootstrap_token: Some(token), - }, - }) - .await - .expect("Shouldn't fail to send message"); - - // auth succeeded - assert_eq!( - result, - UserAgentResponse { - payload: Some(UserAgentResponsePayload::AuthMessage( - arbiter_proto::proto::auth::ServerMessage { - payload: Some(arbiter_proto::proto::auth::server_message::Payload::AuthOk( - AuthOk {}, - )), - }, - )), - } - ); - - // key is succesfully recorded in database - let mut conn = db.get().await.unwrap(); - let stored_pubkey: Vec = schema::useragent_client::table - .select(schema::useragent_client::public_key) - .first::>(&mut conn) - .await - .unwrap(); - assert_eq!(stored_pubkey, new_key.verifying_key().to_bytes().to_vec()); -} - -#[tokio::test] -#[test_log::test] -pub async fn test_bootstrap_invalid_token_auth() { - let db = db::create_test_pool().await; - seed_settings(&db).await; - let actors = GlobalActors::spawn(db.clone()).await.unwrap(); - - let user_agent = UserAgentActor::new_manual( - db.clone(), - actors, - tokio::sync::mpsc::channel(1).0, // dummy channel, we won't actually send responses in this test - ); - let user_agent_ref = UserAgentActor::spawn(user_agent); - - // simulate client sending auth request with bootstrap token - let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); - let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); - - let result = user_agent_ref - .ask(HandleAuthChallengeRequest { - req: AuthChallengeRequest { - pubkey: pubkey_bytes, - bootstrap_token: Some("invalid_token".to_string()), - }, - }) - .await; - - match result { - Err(kameo::error::SendError::HandlerError(status)) => { - assert_eq!(status.code(), tonic::Code::InvalidArgument); - insta::assert_debug_snapshot!(status, @r#" - Status { - code: InvalidArgument, - message: "Invalid bootstrap token", - source: None, - } - "#); - } - Err(other) => { - panic!("Expected SendError::HandlerError, got {other:?}"); - } - Ok(_) => { - panic!("Expected error due to invalid bootstrap token, but got success"); - } - } -} - -#[tokio::test] -#[test_log::test] -pub async fn test_challenge_auth() { - let db = db::create_test_pool().await; - seed_settings(&db).await; - - let actors = GlobalActors::spawn(db.clone()).await.unwrap(); - let user_agent = UserAgentActor::new_manual( - db.clone(), - actors, - tokio::sync::mpsc::channel(1).0, // dummy channel, we won't actually send responses in this test - ); - let user_agent_ref = UserAgentActor::spawn(user_agent); - - // simulate client sending auth request with bootstrap token - let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); - let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); - - // insert pubkey into database to trigger challenge-response auth flow - { - let mut conn = db.get().await.unwrap(); - insert_into(schema::useragent_client::table) - .values(schema::useragent_client::public_key.eq(pubkey_bytes.clone())) - .execute(&mut conn) - .await - .unwrap(); - } - - let result = user_agent_ref - .ask(HandleAuthChallengeRequest { - req: AuthChallengeRequest { - pubkey: pubkey_bytes, - bootstrap_token: None, - }, - }) - .await - .expect("Shouldn't fail to send message"); - - // auth challenge succeeded - let UserAgentResponse { - payload: - Some(UserAgentResponsePayload::AuthMessage(arbiter_proto::proto::auth::ServerMessage { - payload: - Some(arbiter_proto::proto::auth::server_message::Payload::AuthChallenge(challenge)), - })), - } = result - else { - panic!("Expected auth challenge response, got {result:?}"); - }; - - let formatted_challenge = arbiter_proto::format_challenge(&challenge); - let signature = new_key.sign(&formatted_challenge); - let serialized_signature = signature.to_bytes().to_vec(); - - let result = user_agent_ref - .ask(HandleAuthChallengeSolution { - solution: auth::AuthChallengeSolution { - signature: serialized_signature, - }, - }) - .await - .expect("Shouldn't fail to send message"); - - // auth succeeded - assert_eq!( - result, - UserAgentResponse { - payload: Some(UserAgentResponsePayload::AuthMessage( - arbiter_proto::proto::auth::ServerMessage { - payload: Some(arbiter_proto::proto::auth::server_message::Payload::AuthOk( - AuthOk {}, - )), - }, - )), - } - ); -} - -#[tokio::test] -#[test_log::test] -pub async fn test_unseal_success() { - let seal_key = b"test-seal-key"; - let (_db, user_agent_ref) = setup_authenticated_user_agent(seal_key).await; - - let encrypted_key = client_dh_encrypt(&user_agent_ref, seal_key).await; - - let response = user_agent_ref - .ask(HandleUnsealEncryptedKey { req: encrypted_key }) - .await - .unwrap(); - - assert_eq!( - response.payload.unwrap(), - UserAgentResponsePayload::UnsealResult(UnsealResult::Success.into()), - ); -} - -#[tokio::test] -#[test_log::test] -pub async fn test_unseal_wrong_seal_key() { - let (_db, user_agent_ref) = setup_authenticated_user_agent(b"correct-key").await; - - // Encrypt a different key through the DH channel - let encrypted_key = client_dh_encrypt(&user_agent_ref, b"wrong-key").await; - - let response = user_agent_ref - .ask(HandleUnsealEncryptedKey { req: encrypted_key }) - .await - .unwrap(); - - assert_eq!( - response.payload.unwrap(), - UserAgentResponsePayload::UnsealResult(UnsealResult::InvalidKey.into()), - ); -} - -#[tokio::test] -#[test_log::test] -pub async fn test_unseal_corrupted_ciphertext() { - let (_db, user_agent_ref) = setup_authenticated_user_agent(b"test-key").await; - - // Do UnsealStart to reach WaitingForUnsealKey state - let client_secret = EphemeralSecret::random(); - let client_public = PublicKey::from(&client_secret); - - user_agent_ref - .ask(HandleUnsealRequest { - req: UnsealStart { - client_pubkey: client_public.as_bytes().to_vec(), - }, - }) - .await - .unwrap(); - - // Send garbage that wasn't encrypted with the DH shared secret - let response = user_agent_ref - .ask(HandleUnsealEncryptedKey { - req: UnsealEncryptedKey { - nonce: vec![0u8; 24], - ciphertext: vec![0u8; 32], - associated_data: vec![], - }, - }) - .await - .unwrap(); - - assert_eq!( - response.payload.unwrap(), - UserAgentResponsePayload::UnsealResult(UnsealResult::InvalidKey.into()), - ); -} - -#[tokio::test] -#[test_log::test] -pub async fn test_unseal_start_without_auth_fails() { - let db = db::create_test_pool().await; - seed_settings(&db).await; - - let actors = GlobalActors::spawn(db.clone()).await.unwrap(); - - let user_agent = UserAgentActor::new_manual( - db.clone(), - actors, - tokio::sync::mpsc::channel(1).0, - ); - let user_agent_ref = UserAgentActor::spawn(user_agent); - - // Try unseal from Init state (not authenticated) - let client_secret = EphemeralSecret::random(); - let client_public = PublicKey::from(&client_secret); - - let result = user_agent_ref - .ask(HandleUnsealRequest { - req: UnsealStart { - client_pubkey: client_public.as_bytes().to_vec(), - }, - }) - .await; - - match result { - Err(kameo::error::SendError::HandlerError(status)) => { - assert_eq!(status.code(), tonic::Code::Internal); - } - other => panic!("Expected state machine error, got {other:?}"), - } -} - -#[tokio::test] -#[test_log::test] -pub async fn test_unseal_retry_after_invalid_key() { - let seal_key = b"real-seal-key"; - let (_db, user_agent_ref) = setup_authenticated_user_agent(seal_key).await; - - // First attempt: wrong key -> InvalidKey, state goes back to Idle - { - let encrypted_key = client_dh_encrypt(&user_agent_ref, b"wrong-key").await; - - let response = user_agent_ref - .ask(HandleUnsealEncryptedKey { req: encrypted_key }) - .await - .unwrap(); - - assert_eq!( - response.payload.unwrap(), - UserAgentResponsePayload::UnsealResult(UnsealResult::InvalidKey.into()), - ); - } - - // Second attempt: correct key -> Success - { - let encrypted_key = client_dh_encrypt(&user_agent_ref, seal_key).await; - - let response = user_agent_ref - .ask(HandleUnsealEncryptedKey { req: encrypted_key }) - .await - .unwrap(); - - assert_eq!( - response.payload.unwrap(), - UserAgentResponsePayload::UnsealResult(UnsealResult::Success.into()), - ); - } -} diff --git a/server/crates/arbiter-server/src/actors/user_agent/transport.rs b/server/crates/arbiter-server/src/actors/user_agent/transport.rs index bf54094..c1ac84c 100644 --- a/server/crates/arbiter-server/src/actors/user_agent/transport.rs +++ b/server/crates/arbiter-server/src/actors/user_agent/transport.rs @@ -1,9 +1,7 @@ use super::UserAgentActor; use arbiter_proto::proto::{ UserAgentRequest, UserAgentResponse, - auth::{ - ClientMessage as ClientAuthMessage, client_message::Payload as ClientAuthPayload, - }, + auth::{ClientMessage as ClientAuthMessage, client_message::Payload as ClientAuthPayload}, user_agent_request::Payload as UserAgentRequestPayload, }; use futures::StreamExt; diff --git a/server/crates/arbiter-server/src/context/tls.rs b/server/crates/arbiter-server/src/context/tls.rs index 267c8b2..424a580 100644 --- a/server/crates/arbiter-server/src/context/tls.rs +++ b/server/crates/arbiter-server/src/context/tls.rs @@ -5,7 +5,6 @@ use rcgen::{Certificate, KeyPair}; use rustls::pki_types::CertificateDer; use thiserror::Error; - #[derive(Error, Debug, Diagnostic)] pub enum TlsInitError { #[error("Key generation error during TLS initialization: {0}")] @@ -41,8 +40,7 @@ impl TlsDataRaw { pub fn deserialize(&self) -> Result { let cert = CertificateDer::from_slice(&self.cert).into_owned(); - let key = - String::from_utf8(self.key.clone()).map_err(TlsInitError::KeyInvalidFormat)?; + let key = String::from_utf8(self.key.clone()).map_err(TlsInitError::KeyInvalidFormat)?; let keypair = KeyPair::from_pem(&key).map_err(TlsInitError::KeyDeserializationError)?; @@ -51,10 +49,8 @@ impl TlsDataRaw { } fn generate_cert(key: &KeyPair) -> Result { - let params = rcgen::CertificateParams::new(vec![ - "arbiter.local".to_string(), - "localhost".to_string(), - ])?; + let params = + rcgen::CertificateParams::new(vec!["arbiter.local".to_string(), "localhost".to_string()])?; params.self_signed(key) } diff --git a/server/crates/arbiter-server/src/db.rs b/server/crates/arbiter-server/src/db.rs index 5bb8e9e..a15398e 100644 --- a/server/crates/arbiter-server/src/db.rs +++ b/server/crates/arbiter-server/src/db.rs @@ -1,8 +1,4 @@ - -use diesel::{ - Connection as _, SqliteConnection, - connection::SimpleConnection as _, -}; +use diesel::{Connection as _, SqliteConnection, connection::SimpleConnection as _}; use diesel_async::{ AsyncConnection, SimpleAsyncConnection, pooled_connection::{AsyncDieselConnectionManager, ManagerConfig}, @@ -133,7 +129,6 @@ pub async fn create_pool(url: Option<&str>) -> Result DatabasePool { use rand::distr::{Alphanumeric, SampleString as _}; diff --git a/server/crates/arbiter-server/src/errors.rs b/server/crates/arbiter-server/src/errors.rs index 4115f9c..98dae76 100644 --- a/server/crates/arbiter-server/src/errors.rs +++ b/server/crates/arbiter-server/src/errors.rs @@ -7,7 +7,7 @@ pub trait GrpcStatusExt { impl GrpcStatusExt for Result { fn to_status(self) -> Result { - self.map_err(|e| { + self.map_err(|e| { error!(error = ?e, "Database error"); Status::internal("Database error") }) @@ -21,4 +21,4 @@ impl GrpcStatusExt for Result { Status::internal("Database pool error") }) } -} \ No newline at end of file +} diff --git a/server/crates/arbiter-server/tests/common/mod.rs b/server/crates/arbiter-server/tests/common/mod.rs new file mode 100644 index 0000000..7269edf --- /dev/null +++ b/server/crates/arbiter-server/tests/common/mod.rs @@ -0,0 +1,43 @@ +use arbiter_server::{ + actors::keyholder::KeyHolder, + db::{self, models::ArbiterSetting, schema}, +}; +use diesel::{QueryDsl, insert_into}; +use diesel_async::RunQueryDsl; +use memsafe::MemSafe; + +pub async fn seed_settings(pool: &db::DatabasePool) { + let mut conn = pool.get().await.unwrap(); + insert_into(schema::arbiter_settings::table) + .values(&ArbiterSetting { + id: 1, + root_key_id: None, + cert_key: vec![], + cert: vec![], + }) + .execute(&mut conn) + .await + .unwrap(); +} + +#[allow(dead_code)] +pub async fn bootstrapped_keyholder(db: &db::DatabasePool) -> KeyHolder { + seed_settings(db).await; + let mut actor = KeyHolder::new(db.clone()).await.unwrap(); + actor + .bootstrap(MemSafe::new(b"test-seal-key".to_vec()).unwrap()) + .await + .unwrap(); + actor +} + +#[allow(dead_code)] +pub async fn root_key_history_id(db: &db::DatabasePool) -> i32 { + let mut conn = db.get().await.unwrap(); + let id = schema::arbiter_settings::table + .select(schema::arbiter_settings::root_key_id) + .first::>(&mut conn) + .await + .unwrap(); + id.expect("root_key_id should be set after bootstrap") +} diff --git a/server/crates/arbiter-server/tests/keyholder.rs b/server/crates/arbiter-server/tests/keyholder.rs new file mode 100644 index 0000000..0fa5692 --- /dev/null +++ b/server/crates/arbiter-server/tests/keyholder.rs @@ -0,0 +1,8 @@ +mod common; + +#[path = "keyholder/concurrency.rs"] +mod concurrency; +#[path = "keyholder/lifecycle.rs"] +mod lifecycle; +#[path = "keyholder/storage.rs"] +mod storage; diff --git a/server/crates/arbiter-server/tests/keyholder/concurrency.rs b/server/crates/arbiter-server/tests/keyholder/concurrency.rs new file mode 100644 index 0000000..d34e315 --- /dev/null +++ b/server/crates/arbiter-server/tests/keyholder/concurrency.rs @@ -0,0 +1,173 @@ +use std::collections::{HashMap, HashSet}; + +use arbiter_server::{ + actors::keyholder::{CreateNew, Error, KeyHolder}, + db::{self, models, schema}, +}; +use diesel::{ExpressionMethods as _, QueryDsl, SelectableHelper, dsl::sql_query}; +use diesel_async::RunQueryDsl; +use kameo::actor::{ActorRef, Spawn as _}; +use memsafe::MemSafe; +use tokio::task::JoinSet; + +use crate::common; + +async fn write_concurrently( + actor: ActorRef, + prefix: &'static str, + count: usize, +) -> Vec<(i32, Vec)> { + let mut set = JoinSet::new(); + for i in 0..count { + let actor = actor.clone(); + set.spawn(async move { + let plaintext = format!("{prefix}-{i}").into_bytes(); + let id = actor + .ask(CreateNew { + plaintext: MemSafe::new(plaintext.clone()).unwrap(), + }) + .await + .unwrap(); + (id, plaintext) + }); + } + + let mut out = Vec::with_capacity(count); + while let Some(res) = set.join_next().await { + out.push(res.unwrap()); + } + out +} + +#[tokio::test] +#[test_log::test] +async fn concurrent_create_new_no_duplicate_nonces_() { + let db = db::create_test_pool().await; + let actor = KeyHolder::spawn(common::bootstrapped_keyholder(&db).await); + + let writes = write_concurrently(actor, "nonce-unique", 32).await; + assert_eq!(writes.len(), 32); + + let mut conn = db.get().await.unwrap(); + let rows: Vec = schema::aead_encrypted::table + .select(models::AeadEncrypted::as_select()) + .load(&mut conn) + .await + .unwrap(); + assert_eq!(rows.len(), 32); + + let nonces: Vec<&Vec> = rows.iter().map(|r| &r.current_nonce).collect(); + let unique: HashSet<&Vec> = nonces.iter().copied().collect(); + assert_eq!(nonces.len(), unique.len(), "all nonces must be unique"); +} + +#[tokio::test] +#[test_log::test] +async fn concurrent_create_new_root_nonce_never_moves_backward() { + let db = db::create_test_pool().await; + let actor = KeyHolder::spawn(common::bootstrapped_keyholder(&db).await); + + write_concurrently(actor, "root-max", 24).await; + + let mut conn = db.get().await.unwrap(); + let rows: Vec = schema::aead_encrypted::table + .select(models::AeadEncrypted::as_select()) + .load(&mut conn) + .await + .unwrap(); + let max_nonce = rows + .iter() + .map(|r| r.current_nonce.clone()) + .max() + .expect("at least one row"); + + let root_row: models::RootKeyHistory = schema::root_key_history::table + .select(models::RootKeyHistory::as_select()) + .first(&mut conn) + .await + .unwrap(); + assert_eq!(root_row.data_encryption_nonce, max_nonce); +} + +#[tokio::test] +#[test_log::test] +async fn insert_failure_does_not_create_partial_row() { + let db = db::create_test_pool().await; + let mut actor = common::bootstrapped_keyholder(&db).await; + let root_key_history_id = common::root_key_history_id(&db).await; + + let mut conn = db.get().await.unwrap(); + let before_count: i64 = schema::aead_encrypted::table + .count() + .get_result(&mut conn) + .await + .unwrap(); + let before_root_nonce: Vec = schema::root_key_history::table + .filter(schema::root_key_history::id.eq(root_key_history_id)) + .select(schema::root_key_history::data_encryption_nonce) + .first(&mut conn) + .await + .unwrap(); + + sql_query( + "CREATE TRIGGER fail_aead_insert BEFORE INSERT ON aead_encrypted BEGIN SELECT RAISE(ABORT, 'forced test failure'); END;", + ) + .execute(&mut conn) + .await + .unwrap(); + drop(conn); + + let err = actor + .create_new(MemSafe::new(b"should fail".to_vec()).unwrap()) + .await + .unwrap_err(); + assert!(matches!(err, Error::DatabaseTransaction(_))); + + let mut conn = db.get().await.unwrap(); + sql_query("DROP TRIGGER fail_aead_insert;") + .execute(&mut conn) + .await + .unwrap(); + + let after_count: i64 = schema::aead_encrypted::table + .count() + .get_result(&mut conn) + .await + .unwrap(); + assert_eq!( + before_count, after_count, + "failed insert must not create row" + ); + + let after_root_nonce: Vec = schema::root_key_history::table + .filter(schema::root_key_history::id.eq(root_key_history_id)) + .select(schema::root_key_history::data_encryption_nonce) + .first(&mut conn) + .await + .unwrap(); + assert!( + after_root_nonce > before_root_nonce, + "current behavior allows nonce gap on failed insert" + ); +} + +#[tokio::test] +#[test_log::test] +async fn decrypt_roundtrip_after_high_concurrency() { + let db = db::create_test_pool().await; + let actor = KeyHolder::spawn(common::bootstrapped_keyholder(&db).await); + + let writes = write_concurrently(actor, "roundtrip", 40).await; + let expected: HashMap> = writes.into_iter().collect(); + + let mut decryptor = KeyHolder::new(db.clone()).await.unwrap(); + decryptor + .try_unseal(MemSafe::new(b"test-seal-key".to_vec()).unwrap()) + .await + .unwrap(); + + for (id, plaintext) in expected { + let mut decrypted = decryptor.decrypt(id).await.unwrap(); + assert_eq!(*decrypted.read().unwrap(), plaintext); + } +} diff --git a/server/crates/arbiter-server/tests/keyholder/lifecycle.rs b/server/crates/arbiter-server/tests/keyholder/lifecycle.rs new file mode 100644 index 0000000..7c633fc --- /dev/null +++ b/server/crates/arbiter-server/tests/keyholder/lifecycle.rs @@ -0,0 +1,134 @@ +use arbiter_server::{ + actors::keyholder::{Error, KeyHolder}, + db::{self, models, schema}, +}; +use diesel::{QueryDsl, SelectableHelper}; +use diesel_async::RunQueryDsl; +use memsafe::MemSafe; + +use crate::common; + +#[tokio::test] +#[test_log::test] +async fn test_bootstrap() { + let db = db::create_test_pool().await; + common::seed_settings(&db).await; + let mut actor = KeyHolder::new(db.clone()).await.unwrap(); + + let seal_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap(); + actor.bootstrap(seal_key).await.unwrap(); + + let mut conn = db.get().await.unwrap(); + let row: models::RootKeyHistory = schema::root_key_history::table + .select(models::RootKeyHistory::as_select()) + .first(&mut conn) + .await + .unwrap(); + + assert_eq!(row.schema_version, 1); + assert_eq!( + row.tag, + arbiter_server::actors::keyholder::encryption::v1::ROOT_KEY_TAG + ); + assert!(!row.ciphertext.is_empty()); + assert!(!row.salt.is_empty()); + assert_eq!( + row.data_encryption_nonce, + arbiter_server::actors::keyholder::encryption::v1::Nonce::default().to_vec() + ); +} + +#[tokio::test] +#[test_log::test] +async fn test_bootstrap_rejects_double() { + let db = db::create_test_pool().await; + let mut actor = common::bootstrapped_keyholder(&db).await; + + let seal_key2 = MemSafe::new(b"test-seal-key".to_vec()).unwrap(); + let err = actor.bootstrap(seal_key2).await.unwrap_err(); + assert!(matches!(err, Error::AlreadyBootstrapped)); +} + +#[tokio::test] +#[test_log::test] +async fn test_create_new_before_bootstrap_fails() { + let db = db::create_test_pool().await; + common::seed_settings(&db).await; + let mut actor = KeyHolder::new(db).await.unwrap(); + + let err = actor + .create_new(MemSafe::new(b"data".to_vec()).unwrap()) + .await + .unwrap_err(); + assert!(matches!(err, Error::NotBootstrapped)); +} + +#[tokio::test] +#[test_log::test] +async fn test_decrypt_before_bootstrap_fails() { + let db = db::create_test_pool().await; + common::seed_settings(&db).await; + let mut actor = KeyHolder::new(db).await.unwrap(); + + let err = actor.decrypt(1).await.unwrap_err(); + assert!(matches!(err, Error::NotBootstrapped)); +} + +#[tokio::test] +#[test_log::test] +async fn test_new_restores_sealed_state() { + let db = db::create_test_pool().await; + let actor = common::bootstrapped_keyholder(&db).await; + drop(actor); + + let mut actor2 = KeyHolder::new(db).await.unwrap(); + let err = actor2.decrypt(1).await.unwrap_err(); + assert!(matches!(err, Error::NotBootstrapped)); +} + +#[tokio::test] +#[test_log::test] +async fn test_unseal_correct_password() { + let db = db::create_test_pool().await; + let mut actor = common::bootstrapped_keyholder(&db).await; + + let plaintext = b"survive a restart"; + let aead_id = actor + .create_new(MemSafe::new(plaintext.to_vec()).unwrap()) + .await + .unwrap(); + drop(actor); + + let mut actor = KeyHolder::new(db.clone()).await.unwrap(); + let seal_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap(); + actor.try_unseal(seal_key).await.unwrap(); + + let mut decrypted = actor.decrypt(aead_id).await.unwrap(); + assert_eq!(*decrypted.read().unwrap(), plaintext); +} + +#[tokio::test] +#[test_log::test] +async fn test_unseal_wrong_then_correct_password() { + let db = db::create_test_pool().await; + let mut actor = common::bootstrapped_keyholder(&db).await; + + let plaintext = b"important data"; + let aead_id = actor + .create_new(MemSafe::new(plaintext.to_vec()).unwrap()) + .await + .unwrap(); + drop(actor); + + let mut actor = KeyHolder::new(db.clone()).await.unwrap(); + + let bad_key = MemSafe::new(b"wrong-password".to_vec()).unwrap(); + let err = actor.try_unseal(bad_key).await.unwrap_err(); + assert!(matches!(err, Error::InvalidKey)); + + let good_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap(); + actor.try_unseal(good_key).await.unwrap(); + + let mut decrypted = actor.decrypt(aead_id).await.unwrap(); + assert_eq!(*decrypted.read().unwrap(), plaintext); +} diff --git a/server/crates/arbiter-server/tests/keyholder/storage.rs b/server/crates/arbiter-server/tests/keyholder/storage.rs new file mode 100644 index 0000000..e595339 --- /dev/null +++ b/server/crates/arbiter-server/tests/keyholder/storage.rs @@ -0,0 +1,161 @@ +use std::collections::HashSet; + +use arbiter_server::{ + actors::keyholder::{Error, encryption::v1}, + db::{self, models, schema}, +}; +use diesel::{ExpressionMethods as _, QueryDsl, SelectableHelper, dsl::update}; +use diesel_async::RunQueryDsl; +use memsafe::MemSafe; + +use crate::common; + +#[tokio::test] +#[test_log::test] +async fn test_create_decrypt_roundtrip() { + let db = db::create_test_pool().await; + let mut actor = common::bootstrapped_keyholder(&db).await; + + let plaintext = b"hello arbiter"; + let aead_id = actor + .create_new(MemSafe::new(plaintext.to_vec()).unwrap()) + .await + .unwrap(); + + let mut decrypted = actor.decrypt(aead_id).await.unwrap(); + assert_eq!(*decrypted.read().unwrap(), plaintext); +} + +#[tokio::test] +#[test_log::test] +async fn test_decrypt_nonexistent_returns_not_found() { + let db = db::create_test_pool().await; + let mut actor = common::bootstrapped_keyholder(&db).await; + + let err = actor.decrypt(9999).await.unwrap_err(); + assert!(matches!(err, Error::NotFound)); +} + +#[tokio::test] +#[test_log::test] +async fn test_ciphertext_differs_across_entries() { + let db = db::create_test_pool().await; + let mut actor = common::bootstrapped_keyholder(&db).await; + + let plaintext = b"same content"; + let id1 = actor + .create_new(MemSafe::new(plaintext.to_vec()).unwrap()) + .await + .unwrap(); + let id2 = actor + .create_new(MemSafe::new(plaintext.to_vec()).unwrap()) + .await + .unwrap(); + + let mut conn = db.get().await.unwrap(); + let row1: models::AeadEncrypted = schema::aead_encrypted::table + .filter(schema::aead_encrypted::id.eq(id1)) + .select(models::AeadEncrypted::as_select()) + .first(&mut conn) + .await + .unwrap(); + let row2: models::AeadEncrypted = schema::aead_encrypted::table + .filter(schema::aead_encrypted::id.eq(id2)) + .select(models::AeadEncrypted::as_select()) + .first(&mut conn) + .await + .unwrap(); + + assert_ne!(row1.ciphertext, row2.ciphertext); + + let mut d1 = actor.decrypt(id1).await.unwrap(); + let mut d2 = actor.decrypt(id2).await.unwrap(); + assert_eq!(*d1.read().unwrap(), plaintext); + assert_eq!(*d2.read().unwrap(), plaintext); +} + +#[tokio::test] +#[test_log::test] +async fn test_nonce_never_reused() { + let db = db::create_test_pool().await; + let mut actor = common::bootstrapped_keyholder(&db).await; + + let n = 5; + for i in 0..n { + actor + .create_new(MemSafe::new(format!("secret {i}").into_bytes()).unwrap()) + .await + .unwrap(); + } + + let mut conn = db.get().await.unwrap(); + let rows: Vec = schema::aead_encrypted::table + .select(models::AeadEncrypted::as_select()) + .load(&mut conn) + .await + .unwrap(); + + assert_eq!(rows.len(), n); + + let nonces: Vec<&Vec> = rows.iter().map(|r| &r.current_nonce).collect(); + let unique: HashSet<&Vec> = nonces.iter().copied().collect(); + assert_eq!(nonces.len(), unique.len(), "all nonces must be unique"); + + for (i, row) in rows.iter().enumerate() { + let mut expected = v1::Nonce::default(); + for _ in 0..=i { + expected.increment(); + } + assert_eq!(row.current_nonce, expected.to_vec(), "nonce {i} mismatch"); + } + + let root_row: models::RootKeyHistory = schema::root_key_history::table + .select(models::RootKeyHistory::as_select()) + .first(&mut conn) + .await + .unwrap(); + let last_nonce = &rows.last().unwrap().current_nonce; + assert_eq!(&root_row.data_encryption_nonce, last_nonce); +} + +#[tokio::test] +#[test_log::test] +async fn broken_db_nonce_format_fails_closed() { + let db = db::create_test_pool().await; + let mut actor = common::bootstrapped_keyholder(&db).await; + let root_key_history_id = common::root_key_history_id(&db).await; + + let mut conn = db.get().await.unwrap(); + update( + schema::root_key_history::table + .filter(schema::root_key_history::id.eq(root_key_history_id)), + ) + .set(schema::root_key_history::data_encryption_nonce.eq(vec![1, 2, 3])) + .execute(&mut conn) + .await + .unwrap(); + drop(conn); + + let err = actor + .create_new(MemSafe::new(b"must fail".to_vec()).unwrap()) + .await + .unwrap_err(); + assert!(matches!(err, Error::BrokenDatabase)); + + let db = db::create_test_pool().await; + let mut actor = common::bootstrapped_keyholder(&db).await; + let id = actor + .create_new(MemSafe::new(b"decrypt target".to_vec()).unwrap()) + .await + .unwrap(); + let mut conn = db.get().await.unwrap(); + update(schema::aead_encrypted::table.filter(schema::aead_encrypted::id.eq(id))) + .set(schema::aead_encrypted::current_nonce.eq(vec![7, 8])) + .execute(&mut conn) + .await + .unwrap(); + drop(conn); + + let err = actor.decrypt(id).await.unwrap_err(); + assert!(matches!(err, Error::BrokenDatabase)); +} diff --git a/server/crates/arbiter-server/tests/user_agent.rs b/server/crates/arbiter-server/tests/user_agent.rs new file mode 100644 index 0000000..dcd9789 --- /dev/null +++ b/server/crates/arbiter-server/tests/user_agent.rs @@ -0,0 +1,6 @@ +mod common; + +#[path = "user_agent/auth.rs"] +mod auth; +#[path = "user_agent/unseal.rs"] +mod unseal; diff --git a/server/crates/arbiter-server/tests/user_agent/auth.rs b/server/crates/arbiter-server/tests/user_agent/auth.rs new file mode 100644 index 0000000..98d921d --- /dev/null +++ b/server/crates/arbiter-server/tests/user_agent/auth.rs @@ -0,0 +1,178 @@ +use arbiter_proto::proto::{ + UserAgentResponse, + auth::{self, AuthChallengeRequest, AuthOk}, + user_agent_response::Payload as UserAgentResponsePayload, +}; +use arbiter_server::{ + actors::{ + GlobalActors, + bootstrap::GetToken, + user_agent::{HandleAuthChallengeRequest, HandleAuthChallengeSolution, UserAgentActor}, + }, + db::{self, schema}, +}; +use diesel::{ExpressionMethods as _, QueryDsl, insert_into}; +use diesel_async::RunQueryDsl; +use ed25519_dalek::Signer as _; +use kameo::actor::Spawn; + +#[tokio::test] +#[test_log::test] +pub async fn test_bootstrap_token_auth() { + let db =db::create_test_pool().await; + crate::common::seed_settings(&db).await; + + let actors = GlobalActors::spawn(db.clone()).await.unwrap(); + let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap(); + let user_agent = + UserAgentActor::new_manual(db.clone(), actors, tokio::sync::mpsc::channel(1).0); + let user_agent_ref = UserAgentActor::spawn(user_agent); + + let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); + let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); + + let result = user_agent_ref + .ask(HandleAuthChallengeRequest { + req: AuthChallengeRequest { + pubkey: pubkey_bytes, + bootstrap_token: Some(token), + }, + }) + .await + .expect("Shouldn't fail to send message"); + + assert_eq!( + result, + UserAgentResponse { + payload: Some(UserAgentResponsePayload::AuthMessage( + arbiter_proto::proto::auth::ServerMessage { + payload: Some(arbiter_proto::proto::auth::server_message::Payload::AuthOk( + AuthOk {}, + )), + }, + )), + } + ); + + let mut conn = db.get().await.unwrap(); + let stored_pubkey: Vec = schema::useragent_client::table + .select(schema::useragent_client::public_key) + .first::>(&mut conn) + .await + .unwrap(); + assert_eq!(stored_pubkey, new_key.verifying_key().to_bytes().to_vec()); +} + +#[tokio::test] +#[test_log::test] +pub async fn test_bootstrap_invalid_token_auth() { + let db = db::create_test_pool().await; + crate::common::seed_settings(&db).await; + + let actors = GlobalActors::spawn(db.clone()).await.unwrap(); + let user_agent = + UserAgentActor::new_manual(db.clone(), actors, tokio::sync::mpsc::channel(1).0); + let user_agent_ref = UserAgentActor::spawn(user_agent); + + let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); + let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); + + let result = user_agent_ref + .ask(HandleAuthChallengeRequest { + req: AuthChallengeRequest { + pubkey: pubkey_bytes, + bootstrap_token: Some("invalid_token".to_string()), + }, + }) + .await; + + match result { + Err(kameo::error::SendError::HandlerError(status)) => { + assert_eq!(status.code(), tonic::Code::InvalidArgument); + insta::assert_debug_snapshot!(status, @r#" + Status { + code: InvalidArgument, + message: "Invalid bootstrap token", + source: None, + } + "#); + } + Err(other) => { + panic!("Expected SendError::HandlerError, got {other:?}"); + } + Ok(_) => { + panic!("Expected error due to invalid bootstrap token, but got success"); + } + } +} + +#[tokio::test] +#[test_log::test] +pub async fn test_challenge_auth() { + let db = db::create_test_pool().await; + crate::common::seed_settings(&db).await; + + let actors = GlobalActors::spawn(db.clone()).await.unwrap(); + let user_agent = + UserAgentActor::new_manual(db.clone(), actors, tokio::sync::mpsc::channel(1).0); + let user_agent_ref = UserAgentActor::spawn(user_agent); + + let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); + let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); + + { + let mut conn = db.get().await.unwrap(); + insert_into(schema::useragent_client::table) + .values(schema::useragent_client::public_key.eq(pubkey_bytes.clone())) + .execute(&mut conn) + .await + .unwrap(); + } + + let result = user_agent_ref + .ask(HandleAuthChallengeRequest { + req: AuthChallengeRequest { + pubkey: pubkey_bytes, + bootstrap_token: None, + }, + }) + .await + .expect("Shouldn't fail to send message"); + + let UserAgentResponse { + payload: + Some(UserAgentResponsePayload::AuthMessage(arbiter_proto::proto::auth::ServerMessage { + payload: + Some(arbiter_proto::proto::auth::server_message::Payload::AuthChallenge(challenge)), + })), + } = result + else { + panic!("Expected auth challenge response, got {result:?}"); + }; + + let formatted_challenge = arbiter_proto::format_challenge(&challenge); + let signature = new_key.sign(&formatted_challenge); + let serialized_signature = signature.to_bytes().to_vec(); + + let result = user_agent_ref + .ask(HandleAuthChallengeSolution { + solution: auth::AuthChallengeSolution { + signature: serialized_signature, + }, + }) + .await + .expect("Shouldn't fail to send message"); + + assert_eq!( + result, + UserAgentResponse { + payload: Some(UserAgentResponsePayload::AuthMessage( + arbiter_proto::proto::auth::ServerMessage { + payload: Some(arbiter_proto::proto::auth::server_message::Payload::AuthOk( + AuthOk {}, + )), + }, + )), + } + ); +} diff --git a/server/crates/arbiter-server/tests/user_agent/unseal.rs b/server/crates/arbiter-server/tests/user_agent/unseal.rs new file mode 100644 index 0000000..7120935 --- /dev/null +++ b/server/crates/arbiter-server/tests/user_agent/unseal.rs @@ -0,0 +1,229 @@ +use arbiter_proto::proto::{ + UnsealEncryptedKey, UnsealResult, UnsealStart, auth::AuthChallengeRequest, + user_agent_response::Payload as UserAgentResponsePayload, +}; +use arbiter_server::{ + actors::{ + GlobalActors, + bootstrap::GetToken, + keyholder::{Bootstrap, Seal}, + user_agent::{ + HandleAuthChallengeRequest, HandleUnsealEncryptedKey, HandleUnsealRequest, + UserAgentActor, + }, + }, + db, +}; +use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit}; +use kameo::actor::{ActorRef, Spawn}; +use memsafe::MemSafe; +use x25519_dalek::{EphemeralSecret, PublicKey}; + +async fn setup_authenticated_user_agent( + seal_key: &[u8], +) -> (arbiter_server::db::DatabasePool, ActorRef) { + let db = db::create_test_pool().await; + crate::common::seed_settings(&db).await; + + let actors = GlobalActors::spawn(db.clone()).await.unwrap(); + actors + .key_holder + .ask(Bootstrap { + seal_key_raw: MemSafe::new(seal_key.to_vec()).unwrap(), + }) + .await + .unwrap(); + actors.key_holder.ask(Seal).await.unwrap(); + + let user_agent = + UserAgentActor::new_manual(db.clone(), actors.clone(), tokio::sync::mpsc::channel(1).0); + let user_agent_ref = UserAgentActor::spawn(user_agent); + + let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap(); + let auth_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); + user_agent_ref + .ask(HandleAuthChallengeRequest { + req: AuthChallengeRequest { + pubkey: auth_key.verifying_key().to_bytes().to_vec(), + bootstrap_token: Some(token), + }, + }) + .await + .unwrap(); + + (db, user_agent_ref) +} + +async fn client_dh_encrypt( + user_agent_ref: &ActorRef, + key_to_send: &[u8], +) -> UnsealEncryptedKey { + let client_secret = EphemeralSecret::random(); + let client_public = PublicKey::from(&client_secret); + + let response = user_agent_ref + .ask(HandleUnsealRequest { + req: UnsealStart { + client_pubkey: client_public.as_bytes().to_vec(), + }, + }) + .await + .unwrap(); + + let server_pubkey = match response.payload.unwrap() { + UserAgentResponsePayload::UnsealStartResponse(resp) => resp.server_pubkey, + other => panic!("Expected UnsealStartResponse, got {other:?}"), + }; + let server_public = PublicKey::from(<[u8; 32]>::try_from(server_pubkey.as_slice()).unwrap()); + + let shared_secret = client_secret.diffie_hellman(&server_public); + let cipher = XChaCha20Poly1305::new(shared_secret.as_bytes().into()); + let nonce = XNonce::from([0u8; 24]); + let associated_data = b"unseal"; + let mut ciphertext = key_to_send.to_vec(); + cipher + .encrypt_in_place(&nonce, associated_data, &mut ciphertext) + .unwrap(); + + UnsealEncryptedKey { + nonce: nonce.to_vec(), + ciphertext, + associated_data: associated_data.to_vec(), + } +} + +#[tokio::test] +#[test_log::test] +pub async fn test_unseal_success() { + let seal_key = b"test-seal-key"; + let (_db, user_agent_ref) = setup_authenticated_user_agent(seal_key).await; + + let encrypted_key = client_dh_encrypt(&user_agent_ref, seal_key).await; + + let response = user_agent_ref + .ask(HandleUnsealEncryptedKey { req: encrypted_key }) + .await + .unwrap(); + + assert_eq!( + response.payload.unwrap(), + UserAgentResponsePayload::UnsealResult(UnsealResult::Success.into()), + ); +} + +#[tokio::test] +#[test_log::test] +pub async fn test_unseal_wrong_seal_key() { + let (_db, user_agent_ref) = setup_authenticated_user_agent(b"correct-key").await; + + let encrypted_key = client_dh_encrypt(&user_agent_ref, b"wrong-key").await; + + let response = user_agent_ref + .ask(HandleUnsealEncryptedKey { req: encrypted_key }) + .await + .unwrap(); + + assert_eq!( + response.payload.unwrap(), + UserAgentResponsePayload::UnsealResult(UnsealResult::InvalidKey.into()), + ); +} + +#[tokio::test] +#[test_log::test] +pub async fn test_unseal_corrupted_ciphertext() { + let (_db, user_agent_ref) = setup_authenticated_user_agent(b"test-key").await; + + let client_secret = EphemeralSecret::random(); + let client_public = PublicKey::from(&client_secret); + + user_agent_ref + .ask(HandleUnsealRequest { + req: UnsealStart { + client_pubkey: client_public.as_bytes().to_vec(), + }, + }) + .await + .unwrap(); + + let response = user_agent_ref + .ask(HandleUnsealEncryptedKey { + req: UnsealEncryptedKey { + nonce: vec![0u8; 24], + ciphertext: vec![0u8; 32], + associated_data: vec![], + }, + }) + .await + .unwrap(); + + assert_eq!( + response.payload.unwrap(), + UserAgentResponsePayload::UnsealResult(UnsealResult::InvalidKey.into()), + ); +} + +#[tokio::test] +#[test_log::test] +pub async fn test_unseal_start_without_auth_fails() { + let db = db::create_test_pool().await; + crate::common::seed_settings(&db).await; + + let actors = GlobalActors::spawn(db.clone()).await.unwrap(); + let user_agent = + UserAgentActor::new_manual(db.clone(), actors, tokio::sync::mpsc::channel(1).0); + let user_agent_ref = UserAgentActor::spawn(user_agent); + + let client_secret = EphemeralSecret::random(); + let client_public = PublicKey::from(&client_secret); + + let result = user_agent_ref + .ask(HandleUnsealRequest { + req: UnsealStart { + client_pubkey: client_public.as_bytes().to_vec(), + }, + }) + .await; + + match result { + Err(kameo::error::SendError::HandlerError(status)) => { + assert_eq!(status.code(), tonic::Code::Internal); + } + other => panic!("Expected state machine error, got {other:?}"), + } +} + +#[tokio::test] +#[test_log::test] +pub async fn test_unseal_retry_after_invalid_key() { + let seal_key = b"real-seal-key"; + let (_db, user_agent_ref) = setup_authenticated_user_agent(seal_key).await; + + { + let encrypted_key = client_dh_encrypt(&user_agent_ref, b"wrong-key").await; + + let response = user_agent_ref + .ask(HandleUnsealEncryptedKey { req: encrypted_key }) + .await + .unwrap(); + + assert_eq!( + response.payload.unwrap(), + UserAgentResponsePayload::UnsealResult(UnsealResult::InvalidKey.into()), + ); + } + + { + let encrypted_key = client_dh_encrypt(&user_agent_ref, seal_key).await; + + let response = user_agent_ref + .ask(HandleUnsealEncryptedKey { req: encrypted_key }) + .await + .unwrap(); + + assert_eq!( + response.payload.unwrap(), + UserAgentResponsePayload::UnsealResult(UnsealResult::Success.into()), + ); + } +}