refactor(keyholder): rename KeyHolderActor to KeyHolder and optimize db connection lifetime

This commit is contained in:
hdbg
2026-02-16 18:25:17 +01:00
parent c82339d764
commit e4038d9188

View File

@@ -68,13 +68,13 @@ pub enum Error {
/// Provides API for encrypting and decrypting data using the vault root key. /// Provides API for encrypting and decrypting data using the vault root key.
/// Abstraction over database to make sure nonces are never reused and encryption keys are never exposed in plaintext outside of this actor. /// Abstraction over database to make sure nonces are never reused and encryption keys are never exposed in plaintext outside of this actor.
#[derive(Actor)] #[derive(Actor)]
pub struct KeyHolderActor { pub struct KeyHolder {
db: db::DatabasePool, db: db::DatabasePool,
state: State, state: State,
} }
#[messages] #[messages]
impl KeyHolderActor { impl KeyHolder {
pub async fn new(db: db::DatabasePool) -> Result<Self, Error> { pub async fn new(db: db::DatabasePool) -> Result<Self, Error> {
let state = { let state = {
let mut conn = db.get().await?; let mut conn = db.get().await?;
@@ -206,14 +206,16 @@ impl KeyHolderActor {
return Err(Error::NotBootstrapped); return Err(Error::NotBootstrapped);
}; };
let mut conn = self.db.get().await?; // We don't want to hold connection while doing expensive KDF work
let current_key = {
let current_key = schema::root_key_history::table let mut conn = self.db.get().await?;
.filter(schema::root_key_history::id.eq(*root_key_history_id)) schema::root_key_history::table
.select((schema::root_key_history::data_encryption_nonce)) .filter(schema::root_key_history::id.eq(*root_key_history_id))
.select((RootKeyHistory::as_select())) .select((schema::root_key_history::data_encryption_nonce))
.first(&mut conn) .select((RootKeyHistory::as_select()))
.await?; .first(&mut conn)
.await?
};
let salt = &current_key.salt; let salt = &current_key.salt;
let salt = v1::Salt::try_from(salt.as_slice()).map_err(|_| { let salt = v1::Salt::try_from(salt.as_slice()).map_err(|_| {
@@ -257,14 +259,17 @@ impl KeyHolderActor {
let State::Unsealed { root_key, .. } = &mut self.state else { let State::Unsealed { root_key, .. } = &mut self.state else {
return Err(Error::NotBootstrapped); return Err(Error::NotBootstrapped);
}; };
let mut conn = self.db.get().await?;
let row: models::AeadEncrypted = schema::aead_encrypted::table let row: models::AeadEncrypted = {
.select(models::AeadEncrypted::as_select()) let mut conn = self.db.get().await?;
.filter(schema::aead_encrypted::id.eq(aead_id)) schema::aead_encrypted::table
.first(&mut conn) .select(models::AeadEncrypted::as_select())
.await .filter(schema::aead_encrypted::id.eq(aead_id))
.optional()? .first(&mut conn)
.ok_or(Error::NotFound)?; .await
.optional()?
.ok_or(Error::NotFound)?
};
let nonce = v1::Nonce::try_from(row.current_nonce.as_slice()).map_err(|_| { let nonce = v1::Nonce::try_from(row.current_nonce.as_slice()).map_err(|_| {
error!( error!(
@@ -293,14 +298,13 @@ impl KeyHolderActor {
// Borrow checker note: &mut borrow a few lines above is disjoint from this field // Borrow checker note: &mut borrow a few lines above is disjoint from this field
let nonce = Self::get_new_nonce(&self.db, *root_key_history_id).await?; let nonce = Self::get_new_nonce(&self.db, *root_key_history_id).await?;
let mut conn = self.db.get().await?;
let mut ciphertext_buffer = plaintext.write().unwrap(); let mut ciphertext_buffer = plaintext.write().unwrap();
let ciphertext_buffer: &mut Vec<u8> = ciphertext_buffer.as_mut(); let ciphertext_buffer: &mut Vec<u8> = ciphertext_buffer.as_mut();
root_key.encrypt_in_place(&nonce, v1::TAG, &mut *ciphertext_buffer)?; root_key.encrypt_in_place(&nonce, v1::TAG, &mut *ciphertext_buffer)?;
let ciphertext = std::mem::take(ciphertext_buffer); let ciphertext = std::mem::take(ciphertext_buffer);
let mut conn = self.db.get().await?;
let aead_id: i32 = insert_into(schema::aead_encrypted::table) let aead_id: i32 = insert_into(schema::aead_encrypted::table)
.values(&models::NewAeadEncrypted { .values(&models::NewAeadEncrypted {
ciphertext, ciphertext,
@@ -319,11 +323,16 @@ impl KeyHolderActor {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::collections::HashSet; use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use diesel::dsl::insert_into; use diesel::dsl::{insert_into, sql_query, update};
use diesel_async::RunQueryDsl; use diesel_async::RunQueryDsl;
use futures::stream::TryUnfold;
use kameo::actor::{ActorRef, Spawn as _};
use memsafe::MemSafe; use memsafe::MemSafe;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
use crate::db::{self, models::ArbiterSetting}; use crate::db::{self, models::ArbiterSetting};
@@ -343,20 +352,49 @@ mod tests {
.unwrap(); .unwrap();
} }
async fn bootstrapped_actor(db: &db::DatabasePool) -> KeyHolderActor { async fn bootstrapped_actor(db: &db::DatabasePool) -> KeyHolder {
seed_settings(db).await; seed_settings(db).await;
let mut actor = KeyHolderActor::new(db.clone()).await.unwrap(); let mut actor = KeyHolder::new(db.clone()).await.unwrap();
let seal_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap(); let seal_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap();
actor.bootstrap(seal_key).await.unwrap(); actor.bootstrap(seal_key).await.unwrap();
actor actor
} }
async fn write_concurrently(
actor: ActorRef<KeyHolder>,
prefix: &'static str,
count: usize,
) -> Vec<(i32, Vec<u8>)> {
let mut set = JoinSet::new();
for i in 0..count {
let actor = actor.clone();
set.spawn(async move {
let plaintext = format!("{prefix}-{i}").into_bytes();
let id = {
actor
.ask(CreateNew {
plaintext: MemSafe::new(plaintext.clone()).unwrap(),
})
.await
.unwrap()
};
(id, plaintext)
});
}
let mut out = Vec::with_capacity(count);
while let Some(res) = set.join_next().await {
out.push(res.unwrap());
}
out
}
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
async fn test_bootstrap() { async fn test_bootstrap() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
seed_settings(&db).await; seed_settings(&db).await;
let mut actor = KeyHolderActor::new(db.clone()).await.unwrap(); let mut actor = KeyHolder::new(db.clone()).await.unwrap();
assert!(matches!(actor.state, State::Unbootstrapped)); assert!(matches!(actor.state, State::Unbootstrapped));
@@ -412,7 +450,7 @@ mod tests {
async fn test_create_new_before_bootstrap_fails() { async fn test_create_new_before_bootstrap_fails() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
seed_settings(&db).await; seed_settings(&db).await;
let mut actor = KeyHolderActor::new(db).await.unwrap(); let mut actor = KeyHolder::new(db).await.unwrap();
let err = actor let err = actor
.create_new(MemSafe::new(b"data".to_vec()).unwrap()) .create_new(MemSafe::new(b"data".to_vec()).unwrap())
@@ -426,7 +464,7 @@ mod tests {
async fn test_decrypt_before_bootstrap_fails() { async fn test_decrypt_before_bootstrap_fails() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
seed_settings(&db).await; seed_settings(&db).await;
let mut actor = KeyHolderActor::new(db).await.unwrap(); let mut actor = KeyHolder::new(db).await.unwrap();
let err = actor.decrypt(1).await.unwrap_err(); let err = actor.decrypt(1).await.unwrap_err();
assert!(matches!(err, Error::NotBootstrapped)); assert!(matches!(err, Error::NotBootstrapped));
@@ -449,7 +487,7 @@ mod tests {
let actor = bootstrapped_actor(&db).await; let actor = bootstrapped_actor(&db).await;
drop(actor); drop(actor);
let actor2 = KeyHolderActor::new(db).await.unwrap(); let actor2 = KeyHolder::new(db).await.unwrap();
assert!(matches!(actor2.state, State::Sealed { .. })); assert!(matches!(actor2.state, State::Sealed { .. }));
} }
@@ -518,7 +556,7 @@ mod tests {
.unwrap(); .unwrap();
drop(actor); drop(actor);
let mut actor = KeyHolderActor::new(db.clone()).await.unwrap(); let mut actor = KeyHolder::new(db.clone()).await.unwrap();
assert!(matches!(actor.state, State::Sealed { .. })); assert!(matches!(actor.state, State::Sealed { .. }));
let seal_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap(); let seal_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap();
@@ -543,7 +581,7 @@ mod tests {
.unwrap(); .unwrap();
drop(actor); drop(actor);
let mut actor = KeyHolderActor::new(db.clone()).await.unwrap(); let mut actor = KeyHolder::new(db.clone()).await.unwrap();
assert!(matches!(actor.state, State::Sealed { .. })); assert!(matches!(actor.state, State::Sealed { .. }));
// wrong password // wrong password
@@ -603,4 +641,291 @@ mod tests {
assert_eq!(*d1.read().unwrap(), plaintext); assert_eq!(*d1.read().unwrap(), plaintext);
assert_eq!(*d2.read().unwrap(), plaintext); assert_eq!(*d2.read().unwrap(), plaintext);
} }
#[tokio::test]
#[test_log::test]
async fn concurrent_create_new_no_duplicate_nonces_() {
let db = db::create_test_pool().await;
let actor = KeyHolder::spawn(bootstrapped_actor(&db).await);
let writes = write_concurrently(actor, "nonce-unique", 32).await;
assert_eq!(writes.len(), 32);
let mut conn = db.get().await.unwrap();
let rows: Vec<models::AeadEncrypted> = schema::aead_encrypted::table
.select(models::AeadEncrypted::as_select())
.load(&mut conn)
.await
.unwrap();
assert_eq!(rows.len(), 32);
let nonces: Vec<&Vec<u8>> = rows.iter().map(|r| &r.current_nonce).collect();
let unique: HashSet<&Vec<u8>> = nonces.iter().copied().collect();
assert_eq!(nonces.len(), unique.len(), "all nonces must be unique");
}
#[tokio::test]
#[test_log::test]
async fn concurrent_create_new_root_nonce_never_moves_backward() {
let db = db::create_test_pool().await;
let actor = KeyHolder::spawn(bootstrapped_actor(&db).await);
write_concurrently(actor, "root-max", 24).await;
let mut conn = db.get().await.unwrap();
let rows: Vec<models::AeadEncrypted> = schema::aead_encrypted::table
.select(models::AeadEncrypted::as_select())
.load(&mut conn)
.await
.unwrap();
let max_nonce = rows
.iter()
.map(|r| r.current_nonce.clone())
.max()
.expect("at least one row");
let root_row: models::RootKeyHistory = schema::root_key_history::table
.select(models::RootKeyHistory::as_select())
.first(&mut conn)
.await
.unwrap();
assert_eq!(root_row.data_encryption_nonce, max_nonce);
}
#[tokio::test]
#[test_log::test]
async fn nonce_monotonic_even_when_nonce_allocation_interleaves() {
let db = db::create_test_pool().await;
let mut actor = bootstrapped_actor(&db).await;
let root_key_history_id = match actor.state {
State::Unsealed {
root_key_history_id,
..
} => root_key_history_id,
_ => panic!("expected unsealed state"),
};
let n1 = KeyHolder::get_new_nonce(&db, root_key_history_id)
.await
.unwrap();
let n2 = KeyHolder::get_new_nonce(&db, root_key_history_id)
.await
.unwrap();
assert!(n2.to_vec() > n1.to_vec(), "nonce must increase");
let mut conn = db.get().await.unwrap();
let root_row: models::RootKeyHistory = schema::root_key_history::table
.select(models::RootKeyHistory::as_select())
.first(&mut conn)
.await
.unwrap();
assert_eq!(root_row.data_encryption_nonce, n2.to_vec());
let id = actor
.create_new(MemSafe::new(b"post-interleave".to_vec()).unwrap())
.await
.unwrap();
let row: models::AeadEncrypted = schema::aead_encrypted::table
.filter(schema::aead_encrypted::id.eq(id))
.select(models::AeadEncrypted::as_select())
.first(&mut conn)
.await
.unwrap();
assert!(
row.current_nonce > n2.to_vec(),
"next write must advance nonce"
);
}
#[tokio::test]
#[test_log::test]
async fn insert_failure_does_not_create_partial_row() {
let db = db::create_test_pool().await;
let mut actor = bootstrapped_actor(&db).await;
let root_key_history_id = match actor.state {
State::Unsealed {
root_key_history_id,
..
} => root_key_history_id,
_ => panic!("expected unsealed state"),
};
let mut conn = db.get().await.unwrap();
let before_count: i64 = schema::aead_encrypted::table
.count()
.get_result(&mut conn)
.await
.unwrap();
let before_root_nonce: Vec<u8> = schema::root_key_history::table
.filter(schema::root_key_history::id.eq(root_key_history_id))
.select(schema::root_key_history::data_encryption_nonce)
.first(&mut conn)
.await
.unwrap();
sql_query(
"CREATE TRIGGER fail_aead_insert BEFORE INSERT ON aead_encrypted BEGIN SELECT RAISE(ABORT, 'forced test failure'); END;",
)
.execute(&mut conn)
.await
.unwrap();
drop(conn);
let err = actor
.create_new(MemSafe::new(b"should fail".to_vec()).unwrap())
.await
.unwrap_err();
assert!(matches!(err, Error::DatabaseTransaction(_)));
let mut conn = db.get().await.unwrap();
sql_query("DROP TRIGGER fail_aead_insert;")
.execute(&mut conn)
.await
.unwrap();
let after_count: i64 = schema::aead_encrypted::table
.count()
.get_result(&mut conn)
.await
.unwrap();
assert_eq!(
before_count, after_count,
"failed insert must not create row"
);
let after_root_nonce: Vec<u8> = schema::root_key_history::table
.filter(schema::root_key_history::id.eq(root_key_history_id))
.select(schema::root_key_history::data_encryption_nonce)
.first(&mut conn)
.await
.unwrap();
assert!(
after_root_nonce > before_root_nonce,
"current behavior allows nonce gap on failed insert"
);
}
#[tokio::test]
#[test_log::test]
async fn decrypt_roundtrip_after_high_concurrency() {
let db = db::create_test_pool().await;
let actor = KeyHolder::spawn(bootstrapped_actor(&db).await);
let writes = write_concurrently(actor, "roundtrip", 40).await;
let expected: HashMap<i32, Vec<u8>> = writes.into_iter().collect();
let mut decryptor = KeyHolder::new(db.clone()).await.unwrap();
decryptor
.try_unseal(MemSafe::new(b"test-seal-key".to_vec()).unwrap())
.await
.unwrap();
for (id, plaintext) in expected {
let mut decrypted = decryptor.decrypt(id).await.unwrap();
assert_eq!(*decrypted.read().unwrap(), plaintext);
}
}
#[tokio::test]
#[test_log::test]
async fn swapping_ciphertext_and_nonce_between_rows_changes_logical_binding() {
let db = db::create_test_pool().await;
let mut actor = bootstrapped_actor(&db).await;
let plaintext1 = b"entry-one";
let plaintext2 = b"entry-two";
let id1 = actor
.create_new(MemSafe::new(plaintext1.to_vec()).unwrap())
.await
.unwrap();
let id2 = actor
.create_new(MemSafe::new(plaintext2.to_vec()).unwrap())
.await
.unwrap();
let mut conn = db.get().await.unwrap();
let row1: models::AeadEncrypted = schema::aead_encrypted::table
.filter(schema::aead_encrypted::id.eq(id1))
.select(models::AeadEncrypted::as_select())
.first(&mut conn)
.await
.unwrap();
let row2: models::AeadEncrypted = schema::aead_encrypted::table
.filter(schema::aead_encrypted::id.eq(id2))
.select(models::AeadEncrypted::as_select())
.first(&mut conn)
.await
.unwrap();
update(schema::aead_encrypted::table.filter(schema::aead_encrypted::id.eq(id1)))
.set((
schema::aead_encrypted::ciphertext.eq(row2.ciphertext.clone()),
schema::aead_encrypted::current_nonce.eq(row2.current_nonce.clone()),
))
.execute(&mut conn)
.await
.unwrap();
update(schema::aead_encrypted::table.filter(schema::aead_encrypted::id.eq(id2)))
.set((
schema::aead_encrypted::ciphertext.eq(row1.ciphertext.clone()),
schema::aead_encrypted::current_nonce.eq(row1.current_nonce.clone()),
))
.execute(&mut conn)
.await
.unwrap();
let mut d1 = actor.decrypt(id1).await.unwrap();
let mut d2 = actor.decrypt(id2).await.unwrap();
assert_eq!(*d1.read().unwrap(), plaintext2);
assert_eq!(*d2.read().unwrap(), plaintext1);
}
#[tokio::test]
#[test_log::test]
async fn broken_db_nonce_format_fails_closed() {
// malformed root_key_history nonce must fail create_new
let db = db::create_test_pool().await;
let mut actor = bootstrapped_actor(&db).await;
let root_key_history_id = match actor.state {
State::Unsealed {
root_key_history_id,
..
} => root_key_history_id,
_ => panic!("expected unsealed state"),
};
let mut conn = db.get().await.unwrap();
update(
schema::root_key_history::table
.filter(schema::root_key_history::id.eq(root_key_history_id)),
)
.set(schema::root_key_history::data_encryption_nonce.eq(vec![1, 2, 3]))
.execute(&mut conn)
.await
.unwrap();
drop(conn);
let err = actor
.create_new(MemSafe::new(b"must fail".to_vec()).unwrap())
.await
.unwrap_err();
assert!(matches!(err, Error::BrokenDatabase));
// malformed per-row nonce must fail decrypt
let db = db::create_test_pool().await;
let mut actor = bootstrapped_actor(&db).await;
let id = actor
.create_new(MemSafe::new(b"decrypt target".to_vec()).unwrap())
.await
.unwrap();
let mut conn = db.get().await.unwrap();
update(schema::aead_encrypted::table.filter(schema::aead_encrypted::id.eq(id)))
.set(schema::aead_encrypted::current_nonce.eq(vec![7, 8]))
.execute(&mut conn)
.await
.unwrap();
drop(conn);
let err = actor.decrypt(id).await.unwrap_err();
assert!(matches!(err, Error::BrokenDatabase));
}
} }