diff --git a/server/crates/arbiter-server/src/actors/keyholder.rs b/server/crates/arbiter-server/src/actors/keyholder.rs index 2b21327..eeded0d 100644 --- a/server/crates/arbiter-server/src/actors/keyholder.rs +++ b/server/crates/arbiter-server/src/actors/keyholder.rs @@ -68,13 +68,13 @@ pub enum Error { /// Provides API for encrypting and decrypting data using the vault root key. /// Abstraction over database to make sure nonces are never reused and encryption keys are never exposed in plaintext outside of this actor. #[derive(Actor)] -pub struct KeyHolderActor { +pub struct KeyHolder { db: db::DatabasePool, state: State, } #[messages] -impl KeyHolderActor { +impl KeyHolder { pub async fn new(db: db::DatabasePool) -> Result { let state = { let mut conn = db.get().await?; @@ -206,14 +206,16 @@ impl KeyHolderActor { return Err(Error::NotBootstrapped); }; - let mut conn = self.db.get().await?; - - let current_key = schema::root_key_history::table - .filter(schema::root_key_history::id.eq(*root_key_history_id)) - .select((schema::root_key_history::data_encryption_nonce)) - .select((RootKeyHistory::as_select())) - .first(&mut conn) - .await?; + // We don't want to hold connection while doing expensive KDF work + let current_key = { + let mut conn = self.db.get().await?; + schema::root_key_history::table + .filter(schema::root_key_history::id.eq(*root_key_history_id)) + .select((schema::root_key_history::data_encryption_nonce)) + .select((RootKeyHistory::as_select())) + .first(&mut conn) + .await? + }; let salt = ¤t_key.salt; let salt = v1::Salt::try_from(salt.as_slice()).map_err(|_| { @@ -257,14 +259,17 @@ impl KeyHolderActor { let State::Unsealed { root_key, .. } = &mut self.state else { return Err(Error::NotBootstrapped); }; - let mut conn = self.db.get().await?; - let row: models::AeadEncrypted = schema::aead_encrypted::table - .select(models::AeadEncrypted::as_select()) - .filter(schema::aead_encrypted::id.eq(aead_id)) - .first(&mut conn) - .await - .optional()? - .ok_or(Error::NotFound)?; + + let row: models::AeadEncrypted = { + let mut conn = self.db.get().await?; + schema::aead_encrypted::table + .select(models::AeadEncrypted::as_select()) + .filter(schema::aead_encrypted::id.eq(aead_id)) + .first(&mut conn) + .await + .optional()? + .ok_or(Error::NotFound)? + }; let nonce = v1::Nonce::try_from(row.current_nonce.as_slice()).map_err(|_| { error!( @@ -293,14 +298,13 @@ impl KeyHolderActor { // Borrow checker note: &mut borrow a few lines above is disjoint from this field let nonce = Self::get_new_nonce(&self.db, *root_key_history_id).await?; - let mut conn = self.db.get().await?; - let mut ciphertext_buffer = plaintext.write().unwrap(); let ciphertext_buffer: &mut Vec = ciphertext_buffer.as_mut(); root_key.encrypt_in_place(&nonce, v1::TAG, &mut *ciphertext_buffer)?; let ciphertext = std::mem::take(ciphertext_buffer); + let mut conn = self.db.get().await?; let aead_id: i32 = insert_into(schema::aead_encrypted::table) .values(&models::NewAeadEncrypted { ciphertext, @@ -319,11 +323,16 @@ impl KeyHolderActor { #[cfg(test)] mod tests { - use std::collections::HashSet; + use std::collections::{HashMap, HashSet}; + use std::sync::Arc; - use diesel::dsl::insert_into; + use diesel::dsl::{insert_into, sql_query, update}; use diesel_async::RunQueryDsl; + use futures::stream::TryUnfold; + use kameo::actor::{ActorRef, Spawn as _}; use memsafe::MemSafe; + use tokio::sync::Mutex; + use tokio::task::JoinSet; use crate::db::{self, models::ArbiterSetting}; @@ -343,20 +352,49 @@ mod tests { .unwrap(); } - async fn bootstrapped_actor(db: &db::DatabasePool) -> KeyHolderActor { + async fn bootstrapped_actor(db: &db::DatabasePool) -> KeyHolder { seed_settings(db).await; - let mut actor = KeyHolderActor::new(db.clone()).await.unwrap(); + let mut actor = KeyHolder::new(db.clone()).await.unwrap(); let seal_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap(); actor.bootstrap(seal_key).await.unwrap(); actor } + async fn write_concurrently( + actor: ActorRef, + prefix: &'static str, + count: usize, + ) -> Vec<(i32, Vec)> { + let mut set = JoinSet::new(); + for i in 0..count { + let actor = actor.clone(); + set.spawn(async move { + let plaintext = format!("{prefix}-{i}").into_bytes(); + let id = { + actor + .ask(CreateNew { + plaintext: MemSafe::new(plaintext.clone()).unwrap(), + }) + .await + .unwrap() + }; + (id, plaintext) + }); + } + + let mut out = Vec::with_capacity(count); + while let Some(res) = set.join_next().await { + out.push(res.unwrap()); + } + out + } + #[tokio::test] #[test_log::test] async fn test_bootstrap() { let db = db::create_test_pool().await; seed_settings(&db).await; - let mut actor = KeyHolderActor::new(db.clone()).await.unwrap(); + let mut actor = KeyHolder::new(db.clone()).await.unwrap(); assert!(matches!(actor.state, State::Unbootstrapped)); @@ -412,7 +450,7 @@ mod tests { async fn test_create_new_before_bootstrap_fails() { let db = db::create_test_pool().await; seed_settings(&db).await; - let mut actor = KeyHolderActor::new(db).await.unwrap(); + let mut actor = KeyHolder::new(db).await.unwrap(); let err = actor .create_new(MemSafe::new(b"data".to_vec()).unwrap()) @@ -426,7 +464,7 @@ mod tests { async fn test_decrypt_before_bootstrap_fails() { let db = db::create_test_pool().await; seed_settings(&db).await; - let mut actor = KeyHolderActor::new(db).await.unwrap(); + let mut actor = KeyHolder::new(db).await.unwrap(); let err = actor.decrypt(1).await.unwrap_err(); assert!(matches!(err, Error::NotBootstrapped)); @@ -449,7 +487,7 @@ mod tests { let actor = bootstrapped_actor(&db).await; drop(actor); - let actor2 = KeyHolderActor::new(db).await.unwrap(); + let actor2 = KeyHolder::new(db).await.unwrap(); assert!(matches!(actor2.state, State::Sealed { .. })); } @@ -518,7 +556,7 @@ mod tests { .unwrap(); drop(actor); - let mut actor = KeyHolderActor::new(db.clone()).await.unwrap(); + let mut actor = KeyHolder::new(db.clone()).await.unwrap(); assert!(matches!(actor.state, State::Sealed { .. })); let seal_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap(); @@ -543,7 +581,7 @@ mod tests { .unwrap(); drop(actor); - let mut actor = KeyHolderActor::new(db.clone()).await.unwrap(); + let mut actor = KeyHolder::new(db.clone()).await.unwrap(); assert!(matches!(actor.state, State::Sealed { .. })); // wrong password @@ -603,4 +641,291 @@ mod tests { assert_eq!(*d1.read().unwrap(), plaintext); assert_eq!(*d2.read().unwrap(), plaintext); } + + #[tokio::test] + #[test_log::test] + async fn concurrent_create_new_no_duplicate_nonces_() { + let db = db::create_test_pool().await; + let actor = KeyHolder::spawn(bootstrapped_actor(&db).await); + + let writes = write_concurrently(actor, "nonce-unique", 32).await; + assert_eq!(writes.len(), 32); + + let mut conn = db.get().await.unwrap(); + let rows: Vec = schema::aead_encrypted::table + .select(models::AeadEncrypted::as_select()) + .load(&mut conn) + .await + .unwrap(); + assert_eq!(rows.len(), 32); + + let nonces: Vec<&Vec> = rows.iter().map(|r| &r.current_nonce).collect(); + let unique: HashSet<&Vec> = nonces.iter().copied().collect(); + assert_eq!(nonces.len(), unique.len(), "all nonces must be unique"); + } + + #[tokio::test] + #[test_log::test] + async fn concurrent_create_new_root_nonce_never_moves_backward() { + let db = db::create_test_pool().await; + let actor = KeyHolder::spawn(bootstrapped_actor(&db).await); + + write_concurrently(actor, "root-max", 24).await; + + let mut conn = db.get().await.unwrap(); + let rows: Vec = schema::aead_encrypted::table + .select(models::AeadEncrypted::as_select()) + .load(&mut conn) + .await + .unwrap(); + let max_nonce = rows + .iter() + .map(|r| r.current_nonce.clone()) + .max() + .expect("at least one row"); + + let root_row: models::RootKeyHistory = schema::root_key_history::table + .select(models::RootKeyHistory::as_select()) + .first(&mut conn) + .await + .unwrap(); + assert_eq!(root_row.data_encryption_nonce, max_nonce); + } + + #[tokio::test] + #[test_log::test] + async fn nonce_monotonic_even_when_nonce_allocation_interleaves() { + let db = db::create_test_pool().await; + let mut actor = bootstrapped_actor(&db).await; + let root_key_history_id = match actor.state { + State::Unsealed { + root_key_history_id, + .. + } => root_key_history_id, + _ => panic!("expected unsealed state"), + }; + + let n1 = KeyHolder::get_new_nonce(&db, root_key_history_id) + .await + .unwrap(); + let n2 = KeyHolder::get_new_nonce(&db, root_key_history_id) + .await + .unwrap(); + assert!(n2.to_vec() > n1.to_vec(), "nonce must increase"); + + let mut conn = db.get().await.unwrap(); + let root_row: models::RootKeyHistory = schema::root_key_history::table + .select(models::RootKeyHistory::as_select()) + .first(&mut conn) + .await + .unwrap(); + assert_eq!(root_row.data_encryption_nonce, n2.to_vec()); + + let id = actor + .create_new(MemSafe::new(b"post-interleave".to_vec()).unwrap()) + .await + .unwrap(); + let row: models::AeadEncrypted = schema::aead_encrypted::table + .filter(schema::aead_encrypted::id.eq(id)) + .select(models::AeadEncrypted::as_select()) + .first(&mut conn) + .await + .unwrap(); + assert!( + row.current_nonce > n2.to_vec(), + "next write must advance nonce" + ); + } + + #[tokio::test] + #[test_log::test] + async fn insert_failure_does_not_create_partial_row() { + let db = db::create_test_pool().await; + let mut actor = bootstrapped_actor(&db).await; + let root_key_history_id = match actor.state { + State::Unsealed { + root_key_history_id, + .. + } => root_key_history_id, + _ => panic!("expected unsealed state"), + }; + + let mut conn = db.get().await.unwrap(); + let before_count: i64 = schema::aead_encrypted::table + .count() + .get_result(&mut conn) + .await + .unwrap(); + let before_root_nonce: Vec = schema::root_key_history::table + .filter(schema::root_key_history::id.eq(root_key_history_id)) + .select(schema::root_key_history::data_encryption_nonce) + .first(&mut conn) + .await + .unwrap(); + + sql_query( + "CREATE TRIGGER fail_aead_insert BEFORE INSERT ON aead_encrypted BEGIN SELECT RAISE(ABORT, 'forced test failure'); END;", + ) + .execute(&mut conn) + .await + .unwrap(); + drop(conn); + + let err = actor + .create_new(MemSafe::new(b"should fail".to_vec()).unwrap()) + .await + .unwrap_err(); + assert!(matches!(err, Error::DatabaseTransaction(_))); + + let mut conn = db.get().await.unwrap(); + sql_query("DROP TRIGGER fail_aead_insert;") + .execute(&mut conn) + .await + .unwrap(); + + let after_count: i64 = schema::aead_encrypted::table + .count() + .get_result(&mut conn) + .await + .unwrap(); + assert_eq!( + before_count, after_count, + "failed insert must not create row" + ); + + let after_root_nonce: Vec = schema::root_key_history::table + .filter(schema::root_key_history::id.eq(root_key_history_id)) + .select(schema::root_key_history::data_encryption_nonce) + .first(&mut conn) + .await + .unwrap(); + assert!( + after_root_nonce > before_root_nonce, + "current behavior allows nonce gap on failed insert" + ); + } + + #[tokio::test] + #[test_log::test] + async fn decrypt_roundtrip_after_high_concurrency() { + let db = db::create_test_pool().await; + let actor = KeyHolder::spawn(bootstrapped_actor(&db).await); + + let writes = write_concurrently(actor, "roundtrip", 40).await; + let expected: HashMap> = writes.into_iter().collect(); + + let mut decryptor = KeyHolder::new(db.clone()).await.unwrap(); + decryptor + .try_unseal(MemSafe::new(b"test-seal-key".to_vec()).unwrap()) + .await + .unwrap(); + + for (id, plaintext) in expected { + let mut decrypted = decryptor.decrypt(id).await.unwrap(); + assert_eq!(*decrypted.read().unwrap(), plaintext); + } + } + + #[tokio::test] + #[test_log::test] + async fn swapping_ciphertext_and_nonce_between_rows_changes_logical_binding() { + let db = db::create_test_pool().await; + let mut actor = bootstrapped_actor(&db).await; + + let plaintext1 = b"entry-one"; + let plaintext2 = b"entry-two"; + let id1 = actor + .create_new(MemSafe::new(plaintext1.to_vec()).unwrap()) + .await + .unwrap(); + let id2 = actor + .create_new(MemSafe::new(plaintext2.to_vec()).unwrap()) + .await + .unwrap(); + + let mut conn = db.get().await.unwrap(); + let row1: models::AeadEncrypted = schema::aead_encrypted::table + .filter(schema::aead_encrypted::id.eq(id1)) + .select(models::AeadEncrypted::as_select()) + .first(&mut conn) + .await + .unwrap(); + let row2: models::AeadEncrypted = schema::aead_encrypted::table + .filter(schema::aead_encrypted::id.eq(id2)) + .select(models::AeadEncrypted::as_select()) + .first(&mut conn) + .await + .unwrap(); + + update(schema::aead_encrypted::table.filter(schema::aead_encrypted::id.eq(id1))) + .set(( + schema::aead_encrypted::ciphertext.eq(row2.ciphertext.clone()), + schema::aead_encrypted::current_nonce.eq(row2.current_nonce.clone()), + )) + .execute(&mut conn) + .await + .unwrap(); + update(schema::aead_encrypted::table.filter(schema::aead_encrypted::id.eq(id2))) + .set(( + schema::aead_encrypted::ciphertext.eq(row1.ciphertext.clone()), + schema::aead_encrypted::current_nonce.eq(row1.current_nonce.clone()), + )) + .execute(&mut conn) + .await + .unwrap(); + + let mut d1 = actor.decrypt(id1).await.unwrap(); + let mut d2 = actor.decrypt(id2).await.unwrap(); + assert_eq!(*d1.read().unwrap(), plaintext2); + assert_eq!(*d2.read().unwrap(), plaintext1); + } + #[tokio::test] + #[test_log::test] + async fn broken_db_nonce_format_fails_closed() { + // malformed root_key_history nonce must fail create_new + let db = db::create_test_pool().await; + let mut actor = bootstrapped_actor(&db).await; + let root_key_history_id = match actor.state { + State::Unsealed { + root_key_history_id, + .. + } => root_key_history_id, + _ => panic!("expected unsealed state"), + }; + + let mut conn = db.get().await.unwrap(); + update( + schema::root_key_history::table + .filter(schema::root_key_history::id.eq(root_key_history_id)), + ) + .set(schema::root_key_history::data_encryption_nonce.eq(vec![1, 2, 3])) + .execute(&mut conn) + .await + .unwrap(); + drop(conn); + + let err = actor + .create_new(MemSafe::new(b"must fail".to_vec()).unwrap()) + .await + .unwrap_err(); + assert!(matches!(err, Error::BrokenDatabase)); + + // malformed per-row nonce must fail decrypt + let db = db::create_test_pool().await; + let mut actor = bootstrapped_actor(&db).await; + let id = actor + .create_new(MemSafe::new(b"decrypt target".to_vec()).unwrap()) + .await + .unwrap(); + let mut conn = db.get().await.unwrap(); + update(schema::aead_encrypted::table.filter(schema::aead_encrypted::id.eq(id))) + .set(schema::aead_encrypted::current_nonce.eq(vec![7, 8])) + .execute(&mut conn) + .await + .unwrap(); + drop(conn); + + let err = actor.decrypt(id).await.unwrap_err(); + assert!(matches!(err, Error::BrokenDatabase)); + } }