7 Commits

Author SHA1 Message Date
hdbg
4236f2c36d refactor(server): reogranized actors, context, and db modules into <dir>/mod.rs structure
Some checks failed
ci/woodpecker/push/server-lint Pipeline was successful
ci/woodpecker/push/server-vet Pipeline failed
ci/woodpecker/push/server-audit Pipeline was successful
ci/woodpecker/push/server-test Pipeline was successful
2026-02-16 22:29:48 +01:00
hdbg
76ff535619 refactor(server::tests): moved integration-like tests into tests/ 2026-02-16 22:27:59 +01:00
hdbg
b3566c8af6 refactor(server): separated global actors into their own handle 2026-02-16 21:58:14 +01:00
hdbg
bdb9f01757 refactor(server): actors reorganization & linter fixes 2026-02-16 21:43:59 +01:00
hdbg
0805e7a846 feat(keyholder): add seal method and unseal integration tests 2026-02-16 21:38:29 +01:00
hdbg
eb9cbc88e9 feat(server::user-agent): Unseal implemented 2026-02-16 21:17:06 +01:00
hdbg
dd716da4cd test(keyholder): remove unused imports from test modules 2026-02-16 21:15:13 +01:00
25 changed files with 1379 additions and 1147 deletions

View File

@@ -1,25 +0,0 @@
syntax = "proto3";
package arbiter.unseal;
import "google/protobuf/empty.proto";
message UnsealStart {
bytes client_pubkey = 1;
}
message UnsealStartResponse {
bytes server_pubkey = 1;
}
message UnsealEncryptedKey {
bytes nonce = 1;
bytes ciphertext = 2;
bytes associated_data = 3;
}
enum UnsealResult {
UNSEAL_RESULT_UNSPECIFIED = 0;
UNSEAL_RESULT_SUCCESS = 1;
UNSEAL_RESULT_INVALID_KEY = 2;
UNSEAL_RESULT_UNBOOTSTRAPPED = 3;
}

View File

@@ -3,19 +3,49 @@ syntax = "proto3";
package arbiter; package arbiter;
import "auth.proto"; import "auth.proto";
import "unseal.proto"; import "google/protobuf/empty.proto";
message UnsealStart {
bytes client_pubkey = 1;
}
message UnsealStartResponse {
bytes server_pubkey = 1;
}
message UnsealEncryptedKey {
bytes nonce = 1;
bytes ciphertext = 2;
bytes associated_data = 3;
}
enum UnsealResult {
UNSEAL_RESULT_UNSPECIFIED = 0;
UNSEAL_RESULT_SUCCESS = 1;
UNSEAL_RESULT_INVALID_KEY = 2;
UNSEAL_RESULT_UNBOOTSTRAPPED = 3;
}
enum VaultState {
VAULT_STATE_UNSPECIFIED = 0;
VAULT_STATE_UNBOOTSTRAPPED = 1;
VAULT_STATE_SEALED = 2;
VAULT_STATE_UNSEALED = 3;
VAULT_STATE_ERROR = 4;
}
message UserAgentRequest { message UserAgentRequest {
oneof payload { oneof payload {
arbiter.auth.ClientMessage auth_message = 1; arbiter.auth.ClientMessage auth_message = 1;
arbiter.unseal.UnsealStart unseal_start = 2; UnsealStart unseal_start = 2;
arbiter.unseal.UnsealEncryptedKey unseal_encrypted_key = 3; UnsealEncryptedKey unseal_encrypted_key = 3;
google.protobuf.Empty query_vault_state = 4;
} }
} }
message UserAgentResponse { message UserAgentResponse {
oneof payload { oneof payload {
arbiter.auth.ServerMessage auth_message = 1; arbiter.auth.ServerMessage auth_message = 1;
arbiter.unseal.UnsealStartResponse unseal_start_response = 2; UnsealStartResponse unseal_start_response = 2;
arbiter.unseal.UnsealResult unseal_result = 3; UnsealResult unseal_result = 3;
VaultState vault_state = 4;
} }
} }

View File

@@ -6,17 +6,14 @@ pub mod proto {
pub mod auth { pub mod auth {
tonic::include_proto!("arbiter.auth"); tonic::include_proto!("arbiter.auth");
} }
pub mod unseal {
tonic::include_proto!("arbiter.unseal");
}
} }
pub mod transport; pub mod transport;
pub static BOOTSTRAP_TOKEN_PATH: &'static str = "bootstrap_token"; pub static BOOTSTRAP_TOKEN_PATH: &str = "bootstrap_token";
pub fn home_path() -> Result<std::path::PathBuf, std::io::Error> { pub fn home_path() -> Result<std::path::PathBuf, std::io::Error> {
static ARBITER_HOME: &'static str = ".arbiter"; static ARBITER_HOME: &str = ".arbiter";
let home_dir = std::env::home_dir().ok_or(std::io::Error::new( let home_dir = std::env::home_dir().ok_or(std::io::Error::new(
std::io::ErrorKind::PermissionDenied, std::io::ErrorKind::PermissionDenied,
"can not get home directory", "can not get home directory",

View File

@@ -1,4 +0,0 @@
pub mod user_agent;
pub mod client;
pub(crate) mod bootstrap;
pub(crate) mod keyholder;

View File

@@ -28,7 +28,7 @@ pub async fn generate_token() -> Result<String, std::io::Error> {
} }
#[derive(Error, Debug, Diagnostic)] #[derive(Error, Debug, Diagnostic)]
pub enum BootstrapError { pub enum Error {
#[error("Database error: {0}")] #[error("Database error: {0}")]
#[diagnostic(code(arbiter_server::bootstrap::database))] #[diagnostic(code(arbiter_server::bootstrap::database))]
Database(#[from] db::PoolError), Database(#[from] db::PoolError),
@@ -48,7 +48,7 @@ pub struct Bootstrapper {
} }
impl Bootstrapper { impl Bootstrapper {
pub async fn new(db: &DatabasePool) -> Result<Self, BootstrapError> { pub async fn new(db: &DatabasePool) -> Result<Self, Error> {
let mut conn = db.get().await?; let mut conn = db.get().await?;
let row_count: i64 = schema::useragent_client::table let row_count: i64 = schema::useragent_client::table
@@ -69,11 +69,6 @@ impl Bootstrapper {
Ok(Self { token }) Ok(Self { token })
} }
#[cfg(test)]
pub fn get_token(&self) -> Option<String> {
self.token.clone()
}
} }
#[messages] #[messages]
@@ -96,3 +91,11 @@ impl Bootstrapper {
} }
} }
} }
#[messages]
impl Bootstrapper {
#[message]
pub fn get_token(&self) -> Option<String> {
self.token.clone()
}
}

View File

@@ -1,939 +0,0 @@
use diesel::{
ExpressionMethods as _, OptionalExtension, QueryDsl, SelectableHelper,
dsl::{insert_into, update},
};
use diesel_async::{AsyncConnection, RunQueryDsl};
use kameo::{Actor, Reply, messages};
use memsafe::MemSafe;
use strum::{EnumDiscriminants, IntoDiscriminant};
use tracing::{error, info};
use crate::{
actors::keyholder::v1::{KeyCell, Nonce},
db::{
self,
models::{self, RootKeyHistory},
schema::{self},
},
};
pub mod v1;
#[derive(Default, EnumDiscriminants)]
#[strum_discriminants(derive(Reply), vis(pub))]
enum State {
#[default]
Unbootstrapped,
Sealed {
root_key_history_id: i32,
},
Unsealed {
root_key_history_id: i32,
root_key: KeyCell,
},
}
#[derive(Debug, thiserror::Error, miette::Diagnostic)]
pub enum Error {
#[error("Keyholder is already bootstrapped")]
#[diagnostic(code(arbiter::keyholder::already_bootstrapped))]
AlreadyBootstrapped,
#[error("Keyholder is not bootstrapped")]
#[diagnostic(code(arbiter::keyholder::not_bootstrapped))]
NotBootstrapped,
#[error("Invalid key provided")]
#[diagnostic(code(arbiter::keyholder::invalid_key))]
InvalidKey,
#[error("Requested aead entry not found")]
#[diagnostic(code(arbiter::keyholder::aead_not_found))]
NotFound,
#[error("Encryption error: {0}")]
#[diagnostic(code(arbiter::keyholder::encryption_error))]
Encryption(#[from] chacha20poly1305::aead::Error),
#[error("Database error: {0}")]
#[diagnostic(code(arbiter::keyholder::database_error))]
DatabaseConnection(#[from] db::PoolError),
#[error("Database transaction error: {0}")]
#[diagnostic(code(arbiter::keyholder::database_transaction_error))]
DatabaseTransaction(#[from] diesel::result::Error),
#[error("Broken database")]
#[diagnostic(code(arbiter::keyholder::broken_database))]
BrokenDatabase,
}
/// Manages vault root key and tracks current state of the vault (bootstrapped/unbootstrapped, sealed/unsealed).
/// Provides API for encrypting and decrypting data using the vault root key.
/// Abstraction over database to make sure nonces are never reused and encryption keys are never exposed in plaintext outside of this actor.
#[derive(Actor)]
pub struct KeyHolder {
db: db::DatabasePool,
state: State,
}
#[messages]
impl KeyHolder {
pub async fn new(db: db::DatabasePool) -> Result<Self, Error> {
let state = {
let mut conn = db.get().await?;
let (root_key_history,) = schema::arbiter_settings::table
.left_join(schema::root_key_history::table)
.select((Option::<RootKeyHistory>::as_select(),))
.get_result::<(Option<RootKeyHistory>,)>(&mut conn)
.await?;
match root_key_history {
Some(root_key_history) => State::Sealed {
root_key_history_id: root_key_history.id,
},
None => State::Unbootstrapped,
}
};
Ok(Self { db, state })
}
// Exclusive transaction to avoid race condtions if multiple keyholders write
// additional layer of protection against nonce-reuse
async fn get_new_nonce(pool: &db::DatabasePool, root_key_id: i32) -> Result<Nonce, Error> {
let mut conn = pool.get().await?;
let nonce = conn
.exclusive_transaction(|conn| {
Box::pin(async move {
let current_nonce: Vec<u8> = schema::root_key_history::table
.filter(schema::root_key_history::id.eq(root_key_id))
.select(schema::root_key_history::data_encryption_nonce)
.first(conn)
.await?;
let mut nonce =
v1::Nonce::try_from(current_nonce.as_slice()).map_err(|_| {
error!(
"Broken database: invalid nonce for root key history id={}",
root_key_id
);
Error::BrokenDatabase
})?;
nonce.increment();
update(schema::root_key_history::table)
.filter(schema::root_key_history::id.eq(root_key_id))
.set(schema::root_key_history::data_encryption_nonce.eq(nonce.to_vec()))
.execute(conn)
.await?;
Result::<_, Error>::Ok(nonce)
})
})
.await?;
Ok(nonce)
}
#[message]
pub async fn bootstrap(&mut self, seal_key_raw: MemSafe<Vec<u8>>) -> Result<(), Error> {
if !matches!(self.state, State::Unbootstrapped) {
return Err(Error::AlreadyBootstrapped);
}
let salt = v1::generate_salt();
let mut seal_key = v1::derive_seal_key(seal_key_raw, &salt);
let mut root_key = KeyCell::new_secure_random();
// Zero nonces are fine because they are one-time
let root_key_nonce = v1::Nonce::default();
let data_encryption_nonce = v1::Nonce::default();
let root_key_ciphertext: Vec<u8> = {
let root_key_reader = root_key.0.read().unwrap();
let root_key_reader = root_key_reader.as_slice();
seal_key
.encrypt(&root_key_nonce, v1::ROOT_KEY_TAG, root_key_reader)
.map_err(|err| {
error!(?err, "Fatal bootstrap error");
Error::Encryption(err)
})?
};
let mut conn = self.db.get().await?;
let data_encryption_nonce_bytes = data_encryption_nonce.to_vec();
let root_key_history_id = conn
.transaction(|conn| {
Box::pin(async move {
let root_key_history_id: i32 = insert_into(schema::root_key_history::table)
.values(&models::NewRootKeyHistory {
ciphertext: root_key_ciphertext,
tag: v1::ROOT_KEY_TAG.to_vec(),
root_key_encryption_nonce: root_key_nonce.to_vec(),
data_encryption_nonce: data_encryption_nonce_bytes,
schema_version: 1,
salt: salt.to_vec(),
})
.returning(schema::root_key_history::id)
.get_result(conn)
.await?;
update(schema::arbiter_settings::table)
.set(schema::arbiter_settings::root_key_id.eq(root_key_history_id))
.execute(conn)
.await?;
Result::<_, diesel::result::Error>::Ok(root_key_history_id)
})
})
.await?;
self.state = State::Unsealed {
root_key,
root_key_history_id,
};
info!("Keyholder bootstrapped successfully");
Ok(())
}
#[message]
pub async fn try_unseal(&mut self, seal_key_raw: MemSafe<Vec<u8>>) -> Result<(), Error> {
let State::Sealed {
root_key_history_id,
} = &self.state
else {
return Err(Error::NotBootstrapped);
};
// We don't want to hold connection while doing expensive KDF work
let current_key = {
let mut conn = self.db.get().await?;
schema::root_key_history::table
.filter(schema::root_key_history::id.eq(*root_key_history_id))
.select(schema::root_key_history::data_encryption_nonce )
.select(RootKeyHistory::as_select() )
.first(&mut conn)
.await?
};
let salt = &current_key.salt;
let salt = v1::Salt::try_from(salt.as_slice()).map_err(|_| {
error!("Broken database: invalid salt for root key");
Error::BrokenDatabase
})?;
let mut seal_key = v1::derive_seal_key(seal_key_raw, &salt);
let mut root_key = MemSafe::new(current_key.ciphertext.clone()).unwrap();
let nonce = v1::Nonce::try_from(current_key.root_key_encryption_nonce.as_slice()).map_err(
|_| {
error!("Broken database: invalid nonce for root key");
Error::BrokenDatabase
},
)?;
seal_key
.decrypt_in_place(&nonce, v1::ROOT_KEY_TAG, &mut root_key)
.map_err(|err| {
error!(?err, "Failed to unseal root key: invalid seal key");
Error::InvalidKey
})?;
self.state = State::Unsealed {
root_key_history_id: current_key.id,
root_key: v1::KeyCell::try_from(root_key).map_err(|err| {
error!(?err, "Broken database: invalid encryption key size");
Error::BrokenDatabase
})?,
};
info!("Keyholder unsealed successfully");
Ok(())
}
// Decrypts the `aead_encrypted` entry with the given ID and returns the plaintext
#[message]
pub async fn decrypt(&mut self, aead_id: i32) -> Result<MemSafe<Vec<u8>>, Error> {
let State::Unsealed { root_key, .. } = &mut self.state else {
return Err(Error::NotBootstrapped);
};
let row: models::AeadEncrypted = {
let mut conn = self.db.get().await?;
schema::aead_encrypted::table
.select(models::AeadEncrypted::as_select())
.filter(schema::aead_encrypted::id.eq(aead_id))
.first(&mut conn)
.await
.optional()?
.ok_or(Error::NotFound)?
};
let nonce = v1::Nonce::try_from(row.current_nonce.as_slice()).map_err(|_| {
error!(
"Broken database: invalid nonce for aead_encrypted id={}",
aead_id
);
Error::BrokenDatabase
})?;
let mut output = MemSafe::new(row.ciphertext).unwrap();
root_key.decrypt_in_place(&nonce, v1::TAG, &mut output)?;
Ok(output)
}
// Creates new `aead_encrypted` entry in the database and returns it's ID
#[message]
pub async fn create_new(&mut self, mut plaintext: MemSafe<Vec<u8>>) -> Result<i32, Error> {
let State::Unsealed {
root_key,
root_key_history_id,
} = &mut self.state
else {
return Err(Error::NotBootstrapped);
};
// Order matters here - `get_new_nonce` acquires connection, so we need to call it before next acquire
// Borrow checker note: &mut borrow a few lines above is disjoint from this field
let nonce = Self::get_new_nonce(&self.db, *root_key_history_id).await?;
let mut ciphertext_buffer = plaintext.write().unwrap();
let ciphertext_buffer: &mut Vec<u8> = ciphertext_buffer.as_mut();
root_key.encrypt_in_place(&nonce, v1::TAG, &mut *ciphertext_buffer)?;
let ciphertext = std::mem::take(ciphertext_buffer);
let mut conn = self.db.get().await?;
let aead_id: i32 = insert_into(schema::aead_encrypted::table)
.values(&models::NewAeadEncrypted {
ciphertext,
tag: v1::TAG.to_vec(),
current_nonce: nonce.to_vec(),
schema_version: 1,
associated_root_key_id: *root_key_history_id,
created_at: chrono::Utc::now().timestamp() as i32,
})
.returning(schema::aead_encrypted::id)
.get_result(&mut conn)
.await?;
Ok(aead_id)
}
#[message]
pub fn get_state(&self) -> StateDiscriminants {
self.state.discriminant()
}
}
#[cfg(test)]
mod tests {
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use diesel::dsl::{insert_into, sql_query, update};
use diesel_async::RunQueryDsl;
use futures::stream::TryUnfold;
use kameo::actor::{ActorRef, Spawn as _};
use memsafe::MemSafe;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
use crate::db::{self, models::ArbiterSetting};
use super::*;
async fn seed_settings(pool: &db::DatabasePool) {
let mut conn = pool.get().await.unwrap();
insert_into(schema::arbiter_settings::table)
.values(&ArbiterSetting {
id: 1,
root_key_id: None,
cert_key: vec![],
cert: vec![],
})
.execute(&mut conn)
.await
.unwrap();
}
async fn bootstrapped_actor(db: &db::DatabasePool) -> KeyHolder {
seed_settings(db).await;
let mut actor = KeyHolder::new(db.clone()).await.unwrap();
let seal_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap();
actor.bootstrap(seal_key).await.unwrap();
actor
}
async fn write_concurrently(
actor: ActorRef<KeyHolder>,
prefix: &'static str,
count: usize,
) -> Vec<(i32, Vec<u8>)> {
let mut set = JoinSet::new();
for i in 0..count {
let actor = actor.clone();
set.spawn(async move {
let plaintext = format!("{prefix}-{i}").into_bytes();
let id = {
actor
.ask(CreateNew {
plaintext: MemSafe::new(plaintext.clone()).unwrap(),
})
.await
.unwrap()
};
(id, plaintext)
});
}
let mut out = Vec::with_capacity(count);
while let Some(res) = set.join_next().await {
out.push(res.unwrap());
}
out
}
#[tokio::test]
#[test_log::test]
async fn test_bootstrap() {
let db = db::create_test_pool().await;
seed_settings(&db).await;
let mut actor = KeyHolder::new(db.clone()).await.unwrap();
assert!(matches!(actor.state, State::Unbootstrapped));
let seal_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap();
actor.bootstrap(seal_key).await.unwrap();
assert!(matches!(actor.state, State::Unsealed { .. }));
let mut conn = db.get().await.unwrap();
let row: models::RootKeyHistory = schema::root_key_history::table
.select(models::RootKeyHistory::as_select())
.first(&mut conn)
.await
.unwrap();
assert_eq!(row.schema_version, 1);
assert_eq!(row.tag, v1::ROOT_KEY_TAG);
assert!(!row.ciphertext.is_empty());
assert!(!row.salt.is_empty());
assert_eq!(row.data_encryption_nonce, v1::Nonce::default().to_vec());
}
#[tokio::test]
#[test_log::test]
async fn test_bootstrap_rejects_double() {
let db = db::create_test_pool().await;
let mut actor = bootstrapped_actor(&db).await;
let seal_key2 = MemSafe::new(b"test-seal-key".to_vec()).unwrap();
let err = actor.bootstrap(seal_key2).await.unwrap_err();
assert!(matches!(err, Error::AlreadyBootstrapped));
}
#[tokio::test]
#[test_log::test]
async fn test_create_decrypt_roundtrip() {
let db = db::create_test_pool().await;
let mut actor = bootstrapped_actor(&db).await;
let plaintext = b"hello arbiter";
let aead_id = actor
.create_new(MemSafe::new(plaintext.to_vec()).unwrap())
.await
.unwrap();
let mut decrypted = actor.decrypt(aead_id).await.unwrap();
let decrypted = decrypted.read().unwrap();
assert_eq!(*decrypted, plaintext);
}
#[tokio::test]
#[test_log::test]
async fn test_create_new_before_bootstrap_fails() {
let db = db::create_test_pool().await;
seed_settings(&db).await;
let mut actor = KeyHolder::new(db).await.unwrap();
let err = actor
.create_new(MemSafe::new(b"data".to_vec()).unwrap())
.await
.unwrap_err();
assert!(matches!(err, Error::NotBootstrapped));
}
#[tokio::test]
#[test_log::test]
async fn test_decrypt_before_bootstrap_fails() {
let db = db::create_test_pool().await;
seed_settings(&db).await;
let mut actor = KeyHolder::new(db).await.unwrap();
let err = actor.decrypt(1).await.unwrap_err();
assert!(matches!(err, Error::NotBootstrapped));
}
#[tokio::test]
#[test_log::test]
async fn test_decrypt_nonexistent_returns_not_found() {
let db = db::create_test_pool().await;
let mut actor = bootstrapped_actor(&db).await;
let err = actor.decrypt(9999).await.unwrap_err();
assert!(matches!(err, Error::NotFound));
}
#[tokio::test]
#[test_log::test]
async fn test_new_restores_sealed_state() {
let db = db::create_test_pool().await;
let actor = bootstrapped_actor(&db).await;
drop(actor);
let actor2 = KeyHolder::new(db).await.unwrap();
assert!(matches!(actor2.state, State::Sealed { .. }));
}
#[tokio::test]
#[test_log::test]
async fn test_nonce_never_reused() {
let db = db::create_test_pool().await;
let mut actor = bootstrapped_actor(&db).await;
let n = 5;
let mut ids = Vec::with_capacity(n);
for i in 0..n {
let id = actor
.create_new(MemSafe::new(format!("secret {i}").into_bytes()).unwrap())
.await
.unwrap();
ids.push(id);
}
// read all stored nonces from DB
let mut conn = db.get().await.unwrap();
let rows: Vec<models::AeadEncrypted> = schema::aead_encrypted::table
.select(models::AeadEncrypted::as_select())
.load(&mut conn)
.await
.unwrap();
assert_eq!(rows.len(), n);
let nonces: Vec<&Vec<u8>> = rows.iter().map(|r| &r.current_nonce).collect();
let unique: HashSet<&Vec<u8>> = nonces.iter().copied().collect();
assert_eq!(nonces.len(), unique.len(), "all nonces must be unique");
// verify nonces are sequential increments from 1
for (i, row) in rows.iter().enumerate() {
let mut expected = v1::Nonce::default();
for _ in 0..=i {
expected.increment();
}
assert_eq!(row.current_nonce, expected.to_vec(), "nonce {i} mismatch");
}
// verify data_encryption_nonce on root_key_history tracks the latest nonce
let root_row: models::RootKeyHistory = schema::root_key_history::table
.select(models::RootKeyHistory::as_select())
.first(&mut conn)
.await
.unwrap();
let last_nonce = &rows.last().unwrap().current_nonce;
assert_eq!(
&root_row.data_encryption_nonce, last_nonce,
"root_key_history must track the latest nonce"
);
}
#[tokio::test]
#[test_log::test]
async fn test_unseal_correct_password() {
let db = db::create_test_pool().await;
let mut actor = bootstrapped_actor(&db).await;
let plaintext = b"survive a restart";
let aead_id = actor
.create_new(MemSafe::new(plaintext.to_vec()).unwrap())
.await
.unwrap();
drop(actor);
let mut actor = KeyHolder::new(db.clone()).await.unwrap();
assert!(matches!(actor.state, State::Sealed { .. }));
let seal_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap();
actor.try_unseal(seal_key).await.unwrap();
assert!(matches!(actor.state, State::Unsealed { .. }));
// previously encrypted data is still decryptable
let mut decrypted = actor.decrypt(aead_id).await.unwrap();
assert_eq!(*decrypted.read().unwrap(), plaintext);
}
#[tokio::test]
#[test_log::test]
async fn test_unseal_wrong_then_correct_password() {
let db = db::create_test_pool().await;
let mut actor = bootstrapped_actor(&db).await;
let plaintext = b"important data";
let aead_id = actor
.create_new(MemSafe::new(plaintext.to_vec()).unwrap())
.await
.unwrap();
drop(actor);
let mut actor = KeyHolder::new(db.clone()).await.unwrap();
assert!(matches!(actor.state, State::Sealed { .. }));
// wrong password
let bad_key = MemSafe::new(b"wrong-password".to_vec()).unwrap();
let err = actor.try_unseal(bad_key).await.unwrap_err();
assert!(matches!(err, Error::InvalidKey));
assert!(
matches!(actor.state, State::Sealed { .. }),
"state must remain Sealed after failed attempt"
);
// correct password
let good_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap();
actor.try_unseal(good_key).await.unwrap();
assert!(matches!(actor.state, State::Unsealed { .. }));
let mut decrypted = actor.decrypt(aead_id).await.unwrap();
assert_eq!(*decrypted.read().unwrap(), plaintext);
}
#[tokio::test]
#[test_log::test]
async fn test_ciphertext_differs_across_entries() {
let db = db::create_test_pool().await;
let mut actor = bootstrapped_actor(&db).await;
let plaintext = b"same content";
let id1 = actor
.create_new(MemSafe::new(plaintext.to_vec()).unwrap())
.await
.unwrap();
let id2 = actor
.create_new(MemSafe::new(plaintext.to_vec()).unwrap())
.await
.unwrap();
// different nonces => different ciphertext, even for identical plaintext
let mut conn = db.get().await.unwrap();
let row1: models::AeadEncrypted = schema::aead_encrypted::table
.filter(schema::aead_encrypted::id.eq(id1))
.select(models::AeadEncrypted::as_select())
.first(&mut conn)
.await
.unwrap();
let row2: models::AeadEncrypted = schema::aead_encrypted::table
.filter(schema::aead_encrypted::id.eq(id2))
.select(models::AeadEncrypted::as_select())
.first(&mut conn)
.await
.unwrap();
assert_ne!(row1.ciphertext, row2.ciphertext);
// but both decrypt to the same plaintext
let mut d1 = actor.decrypt(id1).await.unwrap();
let mut d2 = actor.decrypt(id2).await.unwrap();
assert_eq!(*d1.read().unwrap(), plaintext);
assert_eq!(*d2.read().unwrap(), plaintext);
}
#[tokio::test]
#[test_log::test]
async fn concurrent_create_new_no_duplicate_nonces_() {
let db = db::create_test_pool().await;
let actor = KeyHolder::spawn(bootstrapped_actor(&db).await);
let writes = write_concurrently(actor, "nonce-unique", 32).await;
assert_eq!(writes.len(), 32);
let mut conn = db.get().await.unwrap();
let rows: Vec<models::AeadEncrypted> = schema::aead_encrypted::table
.select(models::AeadEncrypted::as_select())
.load(&mut conn)
.await
.unwrap();
assert_eq!(rows.len(), 32);
let nonces: Vec<&Vec<u8>> = rows.iter().map(|r| &r.current_nonce).collect();
let unique: HashSet<&Vec<u8>> = nonces.iter().copied().collect();
assert_eq!(nonces.len(), unique.len(), "all nonces must be unique");
}
#[tokio::test]
#[test_log::test]
async fn concurrent_create_new_root_nonce_never_moves_backward() {
let db = db::create_test_pool().await;
let actor = KeyHolder::spawn(bootstrapped_actor(&db).await);
write_concurrently(actor, "root-max", 24).await;
let mut conn = db.get().await.unwrap();
let rows: Vec<models::AeadEncrypted> = schema::aead_encrypted::table
.select(models::AeadEncrypted::as_select())
.load(&mut conn)
.await
.unwrap();
let max_nonce = rows
.iter()
.map(|r| r.current_nonce.clone())
.max()
.expect("at least one row");
let root_row: models::RootKeyHistory = schema::root_key_history::table
.select(models::RootKeyHistory::as_select())
.first(&mut conn)
.await
.unwrap();
assert_eq!(root_row.data_encryption_nonce, max_nonce);
}
#[tokio::test]
#[test_log::test]
async fn nonce_monotonic_even_when_nonce_allocation_interleaves() {
let db = db::create_test_pool().await;
let mut actor = bootstrapped_actor(&db).await;
let root_key_history_id = match actor.state {
State::Unsealed {
root_key_history_id,
..
} => root_key_history_id,
_ => panic!("expected unsealed state"),
};
let n1 = KeyHolder::get_new_nonce(&db, root_key_history_id)
.await
.unwrap();
let n2 = KeyHolder::get_new_nonce(&db, root_key_history_id)
.await
.unwrap();
assert!(n2.to_vec() > n1.to_vec(), "nonce must increase");
let mut conn = db.get().await.unwrap();
let root_row: models::RootKeyHistory = schema::root_key_history::table
.select(models::RootKeyHistory::as_select())
.first(&mut conn)
.await
.unwrap();
assert_eq!(root_row.data_encryption_nonce, n2.to_vec());
let id = actor
.create_new(MemSafe::new(b"post-interleave".to_vec()).unwrap())
.await
.unwrap();
let row: models::AeadEncrypted = schema::aead_encrypted::table
.filter(schema::aead_encrypted::id.eq(id))
.select(models::AeadEncrypted::as_select())
.first(&mut conn)
.await
.unwrap();
assert!(
row.current_nonce > n2.to_vec(),
"next write must advance nonce"
);
}
#[tokio::test]
#[test_log::test]
async fn insert_failure_does_not_create_partial_row() {
let db = db::create_test_pool().await;
let mut actor = bootstrapped_actor(&db).await;
let root_key_history_id = match actor.state {
State::Unsealed {
root_key_history_id,
..
} => root_key_history_id,
_ => panic!("expected unsealed state"),
};
let mut conn = db.get().await.unwrap();
let before_count: i64 = schema::aead_encrypted::table
.count()
.get_result(&mut conn)
.await
.unwrap();
let before_root_nonce: Vec<u8> = schema::root_key_history::table
.filter(schema::root_key_history::id.eq(root_key_history_id))
.select(schema::root_key_history::data_encryption_nonce)
.first(&mut conn)
.await
.unwrap();
sql_query(
"CREATE TRIGGER fail_aead_insert BEFORE INSERT ON aead_encrypted BEGIN SELECT RAISE(ABORT, 'forced test failure'); END;",
)
.execute(&mut conn)
.await
.unwrap();
drop(conn);
let err = actor
.create_new(MemSafe::new(b"should fail".to_vec()).unwrap())
.await
.unwrap_err();
assert!(matches!(err, Error::DatabaseTransaction(_)));
let mut conn = db.get().await.unwrap();
sql_query("DROP TRIGGER fail_aead_insert;")
.execute(&mut conn)
.await
.unwrap();
let after_count: i64 = schema::aead_encrypted::table
.count()
.get_result(&mut conn)
.await
.unwrap();
assert_eq!(
before_count, after_count,
"failed insert must not create row"
);
let after_root_nonce: Vec<u8> = schema::root_key_history::table
.filter(schema::root_key_history::id.eq(root_key_history_id))
.select(schema::root_key_history::data_encryption_nonce)
.first(&mut conn)
.await
.unwrap();
assert!(
after_root_nonce > before_root_nonce,
"current behavior allows nonce gap on failed insert"
);
}
#[tokio::test]
#[test_log::test]
async fn decrypt_roundtrip_after_high_concurrency() {
let db = db::create_test_pool().await;
let actor = KeyHolder::spawn(bootstrapped_actor(&db).await);
let writes = write_concurrently(actor, "roundtrip", 40).await;
let expected: HashMap<i32, Vec<u8>> = writes.into_iter().collect();
let mut decryptor = KeyHolder::new(db.clone()).await.unwrap();
decryptor
.try_unseal(MemSafe::new(b"test-seal-key".to_vec()).unwrap())
.await
.unwrap();
for (id, plaintext) in expected {
let mut decrypted = decryptor.decrypt(id).await.unwrap();
assert_eq!(*decrypted.read().unwrap(), plaintext);
}
}
// #[tokio::test]
// #[test_log::test]
// async fn swapping_ciphertext_and_nonce_between_rows_changes_logical_binding() {
// let db = db::create_test_pool().await;
// let mut actor = bootstrapped_actor(&db).await;
// let plaintext1 = b"entry-one";
// let plaintext2 = b"entry-two";
// let id1 = actor
// .create_new(MemSafe::new(plaintext1.to_vec()).unwrap())
// .await
// .unwrap();
// let id2 = actor
// .create_new(MemSafe::new(plaintext2.to_vec()).unwrap())
// .await
// .unwrap();
// let mut conn = db.get().await.unwrap();
// let row1: models::AeadEncrypted = schema::aead_encrypted::table
// .filter(schema::aead_encrypted::id.eq(id1))
// .select(models::AeadEncrypted::as_select())
// .first(&mut conn)
// .await
// .unwrap();
// let row2: models::AeadEncrypted = schema::aead_encrypted::table
// .filter(schema::aead_encrypted::id.eq(id2))
// .select(models::AeadEncrypted::as_select())
// .first(&mut conn)
// .await
// .unwrap();
// update(schema::aead_encrypted::table.filter(schema::aead_encrypted::id.eq(id1)))
// .set((
// schema::aead_encrypted::ciphertext.eq(row2.ciphertext.clone()),
// schema::aead_encrypted::current_nonce.eq(row2.current_nonce.clone()),
// ))
// .execute(&mut conn)
// .await
// .unwrap();
// update(schema::aead_encrypted::table.filter(schema::aead_encrypted::id.eq(id2)))
// .set((
// schema::aead_encrypted::ciphertext.eq(row1.ciphertext.clone()),
// schema::aead_encrypted::current_nonce.eq(row1.current_nonce.clone()),
// ))
// .execute(&mut conn)
// .await
// .unwrap();
// let mut d1 = actor.decrypt(id1).await.unwrap();
// let mut d2 = actor.decrypt(id2).await.unwrap();
// assert_eq!(*d1.read().unwrap(), plaintext2);
// assert_eq!(*d2.read().unwrap(), plaintext1);
// }
#[tokio::test]
#[test_log::test]
async fn broken_db_nonce_format_fails_closed() {
// malformed root_key_history nonce must fail create_new
let db = db::create_test_pool().await;
let mut actor = bootstrapped_actor(&db).await;
let root_key_history_id = match actor.state {
State::Unsealed {
root_key_history_id,
..
} => root_key_history_id,
_ => panic!("expected unsealed state"),
};
let mut conn = db.get().await.unwrap();
update(
schema::root_key_history::table
.filter(schema::root_key_history::id.eq(root_key_history_id)),
)
.set(schema::root_key_history::data_encryption_nonce.eq(vec![1, 2, 3]))
.execute(&mut conn)
.await
.unwrap();
drop(conn);
let err = actor
.create_new(MemSafe::new(b"must fail".to_vec()).unwrap())
.await
.unwrap_err();
assert!(matches!(err, Error::BrokenDatabase));
// malformed per-row nonce must fail decrypt
let db = db::create_test_pool().await;
let mut actor = bootstrapped_actor(&db).await;
let id = actor
.create_new(MemSafe::new(b"decrypt target".to_vec()).unwrap())
.await
.unwrap();
let mut conn = db.get().await.unwrap();
update(schema::aead_encrypted::table.filter(schema::aead_encrypted::id.eq(id)))
.set(schema::aead_encrypted::current_nonce.eq(vec![7, 8]))
.execute(&mut conn)
.await
.unwrap();
drop(conn);
let err = actor.decrypt(id).await.unwrap_err();
assert!(matches!(err, Error::BrokenDatabase));
}
}

View File

@@ -0,0 +1 @@
pub mod v1;

View File

@@ -42,12 +42,12 @@ impl<'a> TryFrom<&'a [u8]> for Nonce {
return Err(()); return Err(());
} }
let mut nonce = [0u8; NONCE_LENGTH]; let mut nonce = [0u8; NONCE_LENGTH];
nonce.copy_from_slice(&value); nonce.copy_from_slice(value);
Ok(Self(nonce)) Ok(Self(nonce))
} }
} }
pub struct KeyCell(pub(super) MemSafe<Key>); pub struct KeyCell(pub MemSafe<Key>);
impl From<MemSafe<Key>> for KeyCell { impl From<MemSafe<Key>> for KeyCell {
fn from(value: MemSafe<Key>) -> Self { fn from(value: MemSafe<Key>) -> Self {
Self(value) Self(value)
@@ -85,10 +85,6 @@ impl KeyCell {
key.into() key.into()
} }
pub fn into_inner(self) -> MemSafe<Key> {
self.0
}
pub fn encrypt_in_place( pub fn encrypt_in_place(
&mut self, &mut self,
nonce: &Nonce, nonce: &Nonce,
@@ -128,9 +124,8 @@ impl KeyCell {
let mut cipher = XChaCha20Poly1305::new(key_ref); let mut cipher = XChaCha20Poly1305::new(key_ref);
let nonce = XNonce::from_slice(nonce.0.as_ref()); let nonce = XNonce::from_slice(nonce.0.as_ref());
let ciphertext = cipher.encrypt( let ciphertext = cipher.encrypt(
&nonce, nonce,
Payload { Payload {
msg: plaintext.as_ref(), msg: plaintext.as_ref(),
aad: associated_data, aad: associated_data,
@@ -142,7 +137,7 @@ impl KeyCell {
pub type Salt = [u8; ArgonSalt::RECOMMENDED_LENGTH]; pub type Salt = [u8; ArgonSalt::RECOMMENDED_LENGTH];
pub(super) fn generate_salt() -> Salt { pub fn generate_salt() -> Salt {
let mut salt = Salt::default(); let mut salt = Salt::default();
let mut rng = StdRng::try_from_rng(&mut SysRng).unwrap(); let mut rng = StdRng::try_from_rng(&mut SysRng).unwrap();
rng.fill_bytes(&mut salt); rng.fill_bytes(&mut salt);
@@ -151,7 +146,7 @@ pub(super) fn generate_salt() -> Salt {
/// User password might be of different length, have not enough entropy, etc... /// User password might be of different length, have not enough entropy, etc...
/// Derive a fixed-length key from the password using Argon2id, which is designed for password hashing and key derivation. /// Derive a fixed-length key from the password using Argon2id, which is designed for password hashing and key derivation.
pub(super) fn derive_seal_key(mut password: MemSafe<Vec<u8>>, salt: &Salt) -> KeyCell { pub fn derive_seal_key(mut password: MemSafe<Vec<u8>>, salt: &Salt) -> KeyCell {
let params = argon2::Params::new(262_144, 3, 4, None).unwrap(); let params = argon2::Params::new(262_144, 3, 4, None).unwrap();
let hasher = Argon2::new(Algorithm::Argon2id, argon2::Version::V0x13, params); let hasher = Argon2::new(Algorithm::Argon2id, argon2::Version::V0x13, params);
let mut key = MemSafe::new(Key::default()).unwrap(); let mut key = MemSafe::new(Key::default()).unwrap();

View File

@@ -0,0 +1,422 @@
use diesel::{
ExpressionMethods as _, OptionalExtension, QueryDsl, SelectableHelper,
dsl::{insert_into, update},
};
use diesel_async::{AsyncConnection, RunQueryDsl};
use kameo::{Actor, Reply, messages};
use memsafe::MemSafe;
use strum::{EnumDiscriminants, IntoDiscriminant};
use tracing::{error, info};
use crate::db::{
self,
models::{self, RootKeyHistory},
schema::{self},
};
use encryption::v1::{self, KeyCell, Nonce};
pub mod encryption;
#[derive(Default, EnumDiscriminants)]
#[strum_discriminants(derive(Reply), vis(pub))]
enum State {
#[default]
Unbootstrapped,
Sealed {
root_key_history_id: i32,
},
Unsealed {
root_key_history_id: i32,
root_key: KeyCell,
},
}
#[derive(Debug, thiserror::Error, miette::Diagnostic)]
pub enum Error {
#[error("Keyholder is already bootstrapped")]
#[diagnostic(code(arbiter::keyholder::already_bootstrapped))]
AlreadyBootstrapped,
#[error("Keyholder is not bootstrapped")]
#[diagnostic(code(arbiter::keyholder::not_bootstrapped))]
NotBootstrapped,
#[error("Invalid key provided")]
#[diagnostic(code(arbiter::keyholder::invalid_key))]
InvalidKey,
#[error("Requested aead entry not found")]
#[diagnostic(code(arbiter::keyholder::aead_not_found))]
NotFound,
#[error("Encryption error: {0}")]
#[diagnostic(code(arbiter::keyholder::encryption_error))]
Encryption(#[from] chacha20poly1305::aead::Error),
#[error("Database error: {0}")]
#[diagnostic(code(arbiter::keyholder::database_error))]
DatabaseConnection(#[from] db::PoolError),
#[error("Database transaction error: {0}")]
#[diagnostic(code(arbiter::keyholder::database_transaction_error))]
DatabaseTransaction(#[from] diesel::result::Error),
#[error("Broken database")]
#[diagnostic(code(arbiter::keyholder::broken_database))]
BrokenDatabase,
}
/// Manages vault root key and tracks current state of the vault (bootstrapped/unbootstrapped, sealed/unsealed).
/// Provides API for encrypting and decrypting data using the vault root key.
/// Abstraction over database to make sure nonces are never reused and encryption keys are never exposed in plaintext outside of this actor.
#[derive(Actor)]
pub struct KeyHolder {
db: db::DatabasePool,
state: State,
}
#[messages]
impl KeyHolder {
pub async fn new(db: db::DatabasePool) -> Result<Self, Error> {
let state = {
let mut conn = db.get().await?;
let (root_key_history,) = schema::arbiter_settings::table
.left_join(schema::root_key_history::table)
.select((Option::<RootKeyHistory>::as_select(),))
.get_result::<(Option<RootKeyHistory>,)>(&mut conn)
.await?;
match root_key_history {
Some(root_key_history) => State::Sealed {
root_key_history_id: root_key_history.id,
},
None => State::Unbootstrapped,
}
};
Ok(Self { db, state })
}
// Exclusive transaction to avoid race condtions if multiple keyholders write
// additional layer of protection against nonce-reuse
async fn get_new_nonce(pool: &db::DatabasePool, root_key_id: i32) -> Result<Nonce, Error> {
let mut conn = pool.get().await?;
let nonce = conn
.exclusive_transaction(|conn| {
Box::pin(async move {
let current_nonce: Vec<u8> = schema::root_key_history::table
.filter(schema::root_key_history::id.eq(root_key_id))
.select(schema::root_key_history::data_encryption_nonce)
.first(conn)
.await?;
let mut nonce =
v1::Nonce::try_from(current_nonce.as_slice()).map_err(|_| {
error!(
"Broken database: invalid nonce for root key history id={}",
root_key_id
);
Error::BrokenDatabase
})?;
nonce.increment();
update(schema::root_key_history::table)
.filter(schema::root_key_history::id.eq(root_key_id))
.set(schema::root_key_history::data_encryption_nonce.eq(nonce.to_vec()))
.execute(conn)
.await?;
Result::<_, Error>::Ok(nonce)
})
})
.await?;
Ok(nonce)
}
#[message]
pub async fn bootstrap(&mut self, seal_key_raw: MemSafe<Vec<u8>>) -> Result<(), Error> {
if !matches!(self.state, State::Unbootstrapped) {
return Err(Error::AlreadyBootstrapped);
}
let salt = v1::generate_salt();
let mut seal_key = v1::derive_seal_key(seal_key_raw, &salt);
let mut root_key = KeyCell::new_secure_random();
// Zero nonces are fine because they are one-time
let root_key_nonce = v1::Nonce::default();
let data_encryption_nonce = v1::Nonce::default();
let root_key_ciphertext: Vec<u8> = {
let root_key_reader = root_key.0.read().unwrap();
let root_key_reader = root_key_reader.as_slice();
seal_key
.encrypt(&root_key_nonce, v1::ROOT_KEY_TAG, root_key_reader)
.map_err(|err| {
error!(?err, "Fatal bootstrap error");
Error::Encryption(err)
})?
};
let mut conn = self.db.get().await?;
let data_encryption_nonce_bytes = data_encryption_nonce.to_vec();
let root_key_history_id = conn
.transaction(|conn| {
Box::pin(async move {
let root_key_history_id: i32 = insert_into(schema::root_key_history::table)
.values(&models::NewRootKeyHistory {
ciphertext: root_key_ciphertext,
tag: v1::ROOT_KEY_TAG.to_vec(),
root_key_encryption_nonce: root_key_nonce.to_vec(),
data_encryption_nonce: data_encryption_nonce_bytes,
schema_version: 1,
salt: salt.to_vec(),
})
.returning(schema::root_key_history::id)
.get_result(conn)
.await?;
update(schema::arbiter_settings::table)
.set(schema::arbiter_settings::root_key_id.eq(root_key_history_id))
.execute(conn)
.await?;
Result::<_, diesel::result::Error>::Ok(root_key_history_id)
})
})
.await?;
self.state = State::Unsealed {
root_key,
root_key_history_id,
};
info!("Keyholder bootstrapped successfully");
Ok(())
}
#[message]
pub async fn try_unseal(&mut self, seal_key_raw: MemSafe<Vec<u8>>) -> Result<(), Error> {
let State::Sealed {
root_key_history_id,
} = &self.state
else {
return Err(Error::NotBootstrapped);
};
// We don't want to hold connection while doing expensive KDF work
let current_key = {
let mut conn = self.db.get().await?;
schema::root_key_history::table
.filter(schema::root_key_history::id.eq(*root_key_history_id))
.select(schema::root_key_history::data_encryption_nonce)
.select(RootKeyHistory::as_select())
.first(&mut conn)
.await?
};
let salt = &current_key.salt;
let salt = v1::Salt::try_from(salt.as_slice()).map_err(|_| {
error!("Broken database: invalid salt for root key");
Error::BrokenDatabase
})?;
let mut seal_key = v1::derive_seal_key(seal_key_raw, &salt);
let mut root_key = MemSafe::new(current_key.ciphertext.clone()).unwrap();
let nonce = v1::Nonce::try_from(current_key.root_key_encryption_nonce.as_slice()).map_err(
|_| {
error!("Broken database: invalid nonce for root key");
Error::BrokenDatabase
},
)?;
seal_key
.decrypt_in_place(&nonce, v1::ROOT_KEY_TAG, &mut root_key)
.map_err(|err| {
error!(?err, "Failed to unseal root key: invalid seal key");
Error::InvalidKey
})?;
self.state = State::Unsealed {
root_key_history_id: current_key.id,
root_key: v1::KeyCell::try_from(root_key).map_err(|err| {
error!(?err, "Broken database: invalid encryption key size");
Error::BrokenDatabase
})?,
};
info!("Keyholder unsealed successfully");
Ok(())
}
// Decrypts the `aead_encrypted` entry with the given ID and returns the plaintext
#[message]
pub async fn decrypt(&mut self, aead_id: i32) -> Result<MemSafe<Vec<u8>>, Error> {
let State::Unsealed { root_key, .. } = &mut self.state else {
return Err(Error::NotBootstrapped);
};
let row: models::AeadEncrypted = {
let mut conn = self.db.get().await?;
schema::aead_encrypted::table
.select(models::AeadEncrypted::as_select())
.filter(schema::aead_encrypted::id.eq(aead_id))
.first(&mut conn)
.await
.optional()?
.ok_or(Error::NotFound)?
};
let nonce = v1::Nonce::try_from(row.current_nonce.as_slice()).map_err(|_| {
error!(
"Broken database: invalid nonce for aead_encrypted id={}",
aead_id
);
Error::BrokenDatabase
})?;
let mut output = MemSafe::new(row.ciphertext).unwrap();
root_key.decrypt_in_place(&nonce, v1::TAG, &mut output)?;
Ok(output)
}
// Creates new `aead_encrypted` entry in the database and returns it's ID
#[message]
pub async fn create_new(&mut self, mut plaintext: MemSafe<Vec<u8>>) -> Result<i32, Error> {
let State::Unsealed {
root_key,
root_key_history_id,
} = &mut self.state
else {
return Err(Error::NotBootstrapped);
};
// Order matters here - `get_new_nonce` acquires connection, so we need to call it before next acquire
// Borrow checker note: &mut borrow a few lines above is disjoint from this field
let nonce = Self::get_new_nonce(&self.db, *root_key_history_id).await?;
let mut ciphertext_buffer = plaintext.write().unwrap();
let ciphertext_buffer: &mut Vec<u8> = ciphertext_buffer.as_mut();
root_key.encrypt_in_place(&nonce, v1::TAG, &mut *ciphertext_buffer)?;
let ciphertext = std::mem::take(ciphertext_buffer);
let mut conn = self.db.get().await?;
let aead_id: i32 = insert_into(schema::aead_encrypted::table)
.values(&models::NewAeadEncrypted {
ciphertext,
tag: v1::TAG.to_vec(),
current_nonce: nonce.to_vec(),
schema_version: 1,
associated_root_key_id: *root_key_history_id,
created_at: chrono::Utc::now().timestamp() as i32,
})
.returning(schema::aead_encrypted::id)
.get_result(&mut conn)
.await?;
Ok(aead_id)
}
#[message]
pub fn get_state(&self) -> StateDiscriminants {
self.state.discriminant()
}
#[message]
pub fn seal(&mut self) -> Result<(), Error> {
let State::Unsealed {
root_key_history_id,
..
} = &self.state
else {
return Err(Error::NotBootstrapped);
};
self.state = State::Sealed {
root_key_history_id: *root_key_history_id,
};
Ok(())
}
}
#[cfg(test)]
mod tests {
use diesel::SelectableHelper;
use diesel::dsl::insert_into;
use diesel_async::RunQueryDsl;
use memsafe::MemSafe;
use crate::db::{self, models::ArbiterSetting};
use super::*;
async fn seed_settings(pool: &db::DatabasePool) {
let mut conn = pool.get().await.unwrap();
insert_into(schema::arbiter_settings::table)
.values(&ArbiterSetting {
id: 1,
root_key_id: None,
cert_key: vec![],
cert: vec![],
})
.execute(&mut conn)
.await
.unwrap();
}
async fn bootstrapped_actor(db: &db::DatabasePool) -> KeyHolder {
seed_settings(db).await;
let mut actor = KeyHolder::new(db.clone()).await.unwrap();
let seal_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap();
actor.bootstrap(seal_key).await.unwrap();
actor
}
#[tokio::test]
#[test_log::test]
async fn nonce_monotonic_even_when_nonce_allocation_interleaves() {
let db = db::create_test_pool().await;
let mut actor = bootstrapped_actor(&db).await;
let root_key_history_id = match actor.state {
State::Unsealed {
root_key_history_id,
..
} => root_key_history_id,
_ => panic!("expected unsealed state"),
};
let n1 = KeyHolder::get_new_nonce(&db, root_key_history_id)
.await
.unwrap();
let n2 = KeyHolder::get_new_nonce(&db, root_key_history_id)
.await
.unwrap();
assert!(n2.to_vec() > n1.to_vec(), "nonce must increase");
let mut conn = db.get().await.unwrap();
let root_row: models::RootKeyHistory = schema::root_key_history::table
.select(models::RootKeyHistory::as_select())
.first(&mut conn)
.await
.unwrap();
assert_eq!(root_row.data_encryption_nonce, n2.to_vec());
let id = actor
.create_new(MemSafe::new(b"post-interleave".to_vec()).unwrap())
.await
.unwrap();
let row: models::AeadEncrypted = schema::aead_encrypted::table
.filter(schema::aead_encrypted::id.eq(id))
.select(models::AeadEncrypted::as_select())
.first(&mut conn)
.await
.unwrap();
assert!(
row.current_nonce > n2.to_vec(),
"next write must advance nonce"
);
}
}

View File

@@ -0,0 +1,40 @@
use kameo::actor::{ActorRef, Spawn};
use miette::Diagnostic;
use thiserror::Error;
use crate::{
actors::{bootstrap::Bootstrapper, keyholder::KeyHolder},
db,
};
pub mod bootstrap;
pub mod client;
pub mod keyholder;
pub mod user_agent;
#[derive(Error, Debug, Diagnostic)]
pub enum SpawnError {
#[error("Failed to spawn Bootstrapper actor")]
#[diagnostic(code(SpawnError::Bootstrapper))]
Bootstrapper(#[from] bootstrap::Error),
#[error("Failed to spawn KeyHolder actor")]
#[diagnostic(code(SpawnError::KeyHolder))]
KeyHolder(#[from] keyholder::Error),
}
/// Long-lived actors that are shared across all connections and handle global state and operations
#[derive(Clone)]
pub struct GlobalActors {
pub key_holder: ActorRef<KeyHolder>,
pub bootstrapper: ActorRef<Bootstrapper>,
}
impl GlobalActors {
pub async fn spawn(db: db::DatabasePool) -> Result<Self, SpawnError> {
Ok(Self {
bootstrapper: Bootstrapper::spawn(Bootstrapper::new(&db).await?),
key_holder: KeyHolder::spawn(KeyHolder::new(db.clone()).await?),
})
}
}

View File

@@ -1,25 +1,18 @@
use std::{ use std::{ops::DerefMut, sync::Mutex};
ops::DerefMut,
sync::Mutex,
};
use arbiter_proto::proto::{ use arbiter_proto::proto::{
UserAgentResponse, UnsealEncryptedKey, UnsealResult, UnsealStart, UnsealStartResponse, UserAgentResponse,
auth::{ auth::{
self, AuthChallengeRequest, AuthOk, ServerMessage as AuthServerMessage, self, AuthChallengeRequest, AuthOk, ServerMessage as AuthServerMessage,
server_message::Payload as ServerAuthPayload, server_message::Payload as ServerAuthPayload,
}, },
unseal::{UnsealEncryptedKey, UnsealResult, UnsealStart, UnsealStartResponse},
user_agent_response::Payload as UserAgentResponsePayload, user_agent_response::Payload as UserAgentResponsePayload,
}; };
use chacha20poly1305::{ use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit};
AeadInPlace, XChaCha20Poly1305, XNonce,
aead::KeyInit,
};
use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, dsl::update}; use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, dsl::update};
use diesel_async::{AsyncConnection, RunQueryDsl}; use diesel_async::RunQueryDsl;
use ed25519_dalek::VerifyingKey; use ed25519_dalek::VerifyingKey;
use kameo::{Actor, actor::ActorRef, messages}; use kameo::{Actor, error::SendError, messages};
use memsafe::MemSafe; use memsafe::MemSafe;
use tokio::sync::mpsc::Sender; use tokio::sync::mpsc::Sender;
use tonic::Status; use tonic::Status;
@@ -29,10 +22,12 @@ use x25519_dalek::{EphemeralSecret, PublicKey};
use crate::{ use crate::{
ServerContext, ServerContext,
actors::{ actors::{
bootstrap::{Bootstrapper, ConsumeToken}, GlobalActors,
bootstrap::ConsumeToken,
keyholder::{self, TryUnseal},
user_agent::state::{ user_agent::state::{
AuthRequestContext, ChallengeContext, DummyContext, UnsealContext, UserAgentEvents, ChallengeContext, DummyContext, UnsealContext, UserAgentEvents, UserAgentStateMachine,
UserAgentStateMachine, UserAgentStates, UserAgentStates,
}, },
}, },
db::{self, schema}, db::{self, schema},
@@ -40,8 +35,6 @@ use crate::{
}; };
mod state; mod state;
#[cfg(test)]
mod tests;
mod transport; mod transport;
pub(crate) use transport::handle_user_agent; pub(crate) use transport::handle_user_agent;
@@ -49,7 +42,7 @@ pub(crate) use transport::handle_user_agent;
#[derive(Actor)] #[derive(Actor)]
pub struct UserAgentActor { pub struct UserAgentActor {
db: db::DatabasePool, db: db::DatabasePool,
bootstapper: ActorRef<Bootstrapper>, actors: GlobalActors,
state: UserAgentStateMachine<DummyContext>, state: UserAgentStateMachine<DummyContext>,
// will be used in future // will be used in future
_tx: Sender<Result<UserAgentResponse, Status>>, _tx: Sender<Result<UserAgentResponse, Status>>,
@@ -62,21 +55,20 @@ impl UserAgentActor {
) -> Self { ) -> Self {
Self { Self {
db: context.db.clone(), db: context.db.clone(),
bootstapper: context.bootstrapper.clone(), actors: context.actors.clone(),
state: UserAgentStateMachine::new(DummyContext), state: UserAgentStateMachine::new(DummyContext),
_tx: tx, _tx: tx,
} }
} }
#[cfg(test)] pub fn new_manual(
pub(crate) fn new_manual(
db: db::DatabasePool, db: db::DatabasePool,
bootstapper: ActorRef<Bootstrapper>, actors: GlobalActors,
tx: Sender<Result<UserAgentResponse, Status>>, tx: Sender<Result<UserAgentResponse, Status>>,
) -> Self { ) -> Self {
Self { Self {
db, db,
bootstapper, actors,
state: UserAgentStateMachine::new(DummyContext), state: UserAgentStateMachine::new(DummyContext),
_tx: tx, _tx: tx,
} }
@@ -96,7 +88,8 @@ impl UserAgentActor {
token: String, token: String,
) -> Result<UserAgentResponse, Status> { ) -> Result<UserAgentResponse, Status> {
let token_ok: bool = self let token_ok: bool = self
.bootstapper .actors
.bootstrapper
.ask(ConsumeToken { token }) .ask(ConsumeToken { token })
.await .await
.map_err(|e| { .map_err(|e| {
@@ -131,7 +124,7 @@ impl UserAgentActor {
let nonce: Option<i32> = { let nonce: Option<i32> = {
let mut db_conn = self.db.get().await.to_status()?; let mut db_conn = self.db.get().await.to_status()?;
db_conn db_conn
.transaction(|conn| { .exclusive_transaction(|conn| {
Box::pin(async move { Box::pin(async move {
let current_nonce = schema::useragent_client::table let current_nonce = schema::useragent_client::table
.filter( .filter(
@@ -164,7 +157,7 @@ impl UserAgentActor {
let challenge = auth::AuthChallenge { let challenge = auth::AuthChallenge {
pubkey: pubkey_bytes, pubkey: pubkey_bytes,
nonce: nonce, nonce,
}; };
self.transition(UserAgentEvents::SentChallenge(ChallengeContext { self.transition(UserAgentEvents::SentChallenge(ChallengeContext {
@@ -239,7 +232,6 @@ impl UserAgentActor {
let client_public_key = PublicKey::from(client_pubkey_bytes); let client_public_key = PublicKey::from(client_pubkey_bytes);
self.transition(UserAgentEvents::UnsealRequest(UnsealContext { self.transition(UserAgentEvents::UnsealRequest(UnsealContext {
server_public_key: public_key,
secret: Mutex::new(Some(secret)), secret: Mutex::new(Some(secret)),
client_public_key, client_public_key,
}))?; }))?;
@@ -280,22 +272,58 @@ impl UserAgentActor {
let shared_secret = ephemeral_secret.diffie_hellman(&unseal_context.client_public_key); let shared_secret = ephemeral_secret.diffie_hellman(&unseal_context.client_public_key);
let cipher = XChaCha20Poly1305::new(shared_secret.as_bytes().into()); let cipher = XChaCha20Poly1305::new(shared_secret.as_bytes().into());
let mut root_key_buffer = MemSafe::new(req.ciphertext.clone()).unwrap(); let mut seal_key_buffer = MemSafe::new(req.ciphertext.clone()).unwrap();
let mut write_handle = root_key_buffer.write().unwrap();
let write_handle = write_handle.deref_mut();
let decryption_result = cipher let decryption_result = {
.decrypt_in_place(nonce, &req.associated_data, write_handle); let mut write_handle = seal_key_buffer.write().unwrap();
let write_handle = write_handle.deref_mut();
cipher.decrypt_in_place(nonce, &req.associated_data, write_handle)
};
match decryption_result { match decryption_result {
Ok(_) => todo!("Send key to the keyguarding"), Ok(_) => {
match self
.actors
.key_holder
.ask(TryUnseal {
seal_key_raw: seal_key_buffer,
})
.await
{
Ok(_) => {
info!("Successfully unsealed key with client-provided key");
self.transition(UserAgentEvents::ReceivedValidKey)?;
Ok(unseal_response(UserAgentResponsePayload::UnsealResult(
UnsealResult::Success.into(),
)))
}
Err(SendError::HandlerError(keyholder::Error::InvalidKey)) => {
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Ok(unseal_response(UserAgentResponsePayload::UnsealResult(
UnsealResult::InvalidKey.into(),
)))
}
Err(SendError::HandlerError(err)) => {
error!(?err, "Keyholder failed to unseal key");
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Ok(unseal_response(UserAgentResponsePayload::UnsealResult(
UnsealResult::InvalidKey.into(),
)))
}
Err(err) => {
error!(?err, "Failed to send unseal request to keyholder");
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Err(Status::internal("Vault is not available"))
}
}
}
Err(err) => { Err(err) => {
error!(?err, "Failed to decrypt unseal key"); error!(?err, "Failed to decrypt unseal key");
self.transition(UserAgentEvents::ReceivedInvalidKey)?; self.transition(UserAgentEvents::ReceivedInvalidKey)?;
return Ok(unseal_response(UserAgentResponsePayload::UnsealResult( Ok(unseal_response(UserAgentResponsePayload::UnsealResult(
UnsealResult::InvalidKey.into(), UnsealResult::InvalidKey.into(),
))); )))
}, }
} }
} }
@@ -309,10 +337,7 @@ impl UserAgentActor {
Status::invalid_argument("Failed to convert pubkey to VerifyingKey") Status::invalid_argument("Failed to convert pubkey to VerifyingKey")
})?; })?;
self.transition(UserAgentEvents::AuthRequest(AuthRequestContext { self.transition(UserAgentEvents::AuthRequest)?;
pubkey,
bootstrap_token: req.bootstrap_token.clone(),
}))?;
match req.bootstrap_token { match req.bootstrap_token {
Some(token) => self.auth_with_bootstrap_token(pubkey, token).await, Some(token) => self.auth_with_bootstrap_token(pubkey, token).await,

View File

@@ -12,31 +12,19 @@ pub struct ChallengeContext {
pub key: VerifyingKey, pub key: VerifyingKey,
} }
// Request context with deserialized public key for state machine.
// This intermediate struct is needed because the state machine branches depending on presence of bootstrap token,
// but we want to have the deserialized key in both branches.
#[derive(Clone, Debug)]
pub struct AuthRequestContext {
pub pubkey: VerifyingKey,
pub bootstrap_token: Option<String>,
}
pub struct UnsealContext { pub struct UnsealContext {
pub server_public_key: PublicKey,
pub client_public_key: PublicKey, pub client_public_key: PublicKey,
pub secret: Mutex<Option<EphemeralSecret>>, pub secret: Mutex<Option<EphemeralSecret>>,
} }
smlang::statemachine!( smlang::statemachine!(
name: UserAgent, name: UserAgent,
custom_error: false, custom_error: false,
transitions: { transitions: {
*Init + AuthRequest(AuthRequestContext) / auth_request_context = ReceivedAuthRequest(AuthRequestContext), *Init + AuthRequest = ReceivedAuthRequest,
ReceivedAuthRequest(AuthRequestContext) + ReceivedBootstrapToken = Idle, ReceivedAuthRequest + ReceivedBootstrapToken = Idle,
ReceivedAuthRequest(AuthRequestContext) + SentChallenge(ChallengeContext) / move_challenge = WaitingForChallengeSolution(ChallengeContext), ReceivedAuthRequest + SentChallenge(ChallengeContext) / move_challenge = WaitingForChallengeSolution(ChallengeContext),
WaitingForChallengeSolution(ChallengeContext) + ReceivedGoodSolution = Idle, WaitingForChallengeSolution(ChallengeContext) + ReceivedGoodSolution = Idle,
WaitingForChallengeSolution(ChallengeContext) + ReceivedBadSolution = AuthError, // block further transitions, but connection should close anyway WaitingForChallengeSolution(ChallengeContext) + ReceivedBadSolution = AuthError, // block further transitions, but connection should close anyway
@@ -49,28 +37,15 @@ smlang::statemachine!(
pub struct DummyContext; pub struct DummyContext;
impl UserAgentStateMachineContext for DummyContext { impl UserAgentStateMachineContext for DummyContext {
#[allow(missing_docs)]
#[allow(clippy::unused_unit)]
fn move_challenge(
&mut self,
_state_data: &AuthRequestContext,
event_data: ChallengeContext,
) -> Result<ChallengeContext, ()> {
Ok(event_data)
}
#[allow(missing_docs)]
#[allow(clippy::unused_unit)]
fn auth_request_context(
&mut self,
event_data: AuthRequestContext,
) -> Result<AuthRequestContext, ()> {
Ok(event_data)
}
#[allow(missing_docs)] #[allow(missing_docs)]
#[allow(clippy::unused_unit)] #[allow(clippy::unused_unit)]
fn generate_temp_keypair(&mut self, event_data: UnsealContext) -> Result<UnsealContext, ()> { fn generate_temp_keypair(&mut self, event_data: UnsealContext) -> Result<UnsealContext, ()> {
Ok(event_data) Ok(event_data)
} }
#[allow(missing_docs)]
#[allow(clippy::unused_unit)]
fn move_challenge(&mut self, event_data: ChallengeContext) -> Result<ChallengeContext, ()> {
Ok(event_data)
}
} }

View File

@@ -1,9 +1,7 @@
use super::UserAgentActor; use super::UserAgentActor;
use arbiter_proto::proto::{ use arbiter_proto::proto::{
UserAgentRequest, UserAgentResponse, UserAgentRequest, UserAgentResponse,
auth::{ auth::{ClientMessage as ClientAuthMessage, client_message::Payload as ClientAuthPayload},
ClientMessage as ClientAuthMessage, client_message::Payload as ClientAuthPayload,
},
user_agent_request::Payload as UserAgentRequestPayload, user_agent_request::Payload as UserAgentRequestPayload,
}; };
use futures::StreamExt; use futures::StreamExt;

View File

@@ -2,15 +2,11 @@ use std::sync::Arc;
use diesel::OptionalExtension as _; use diesel::OptionalExtension as _;
use diesel_async::RunQueryDsl as _; use diesel_async::RunQueryDsl as _;
use kameo::actor::{ActorRef, Spawn};
use miette::Diagnostic; use miette::Diagnostic;
use thiserror::Error; use thiserror::Error;
use crate::{ use crate::{
actors::{ actors::GlobalActors,
bootstrap::{self, Bootstrapper},
keyholder::KeyHolder,
},
context::tls::{TlsDataRaw, TlsManager}, context::tls::{TlsDataRaw, TlsManager},
db::{self, models::ArbiterSetting, schema::arbiter_settings}, db::{self, models::ArbiterSetting, schema::arbiter_settings},
}; };
@@ -35,13 +31,9 @@ pub enum InitError {
#[diagnostic(code(arbiter_server::init::tls_init))] #[diagnostic(code(arbiter_server::init::tls_init))]
Tls(#[from] tls::TlsInitError), Tls(#[from] tls::TlsInitError),
#[error("Bootstrap token generation failed: {0}")] #[error("Actor spawn failed: {0}")]
#[diagnostic(code(arbiter_server::init::bootstrap_token))] #[diagnostic(code(arbiter_server::init::actor_spawn))]
BootstrapToken(#[from] bootstrap::BootstrapError), ActorSpawn(#[from] crate::actors::SpawnError),
#[error("KeyHolder initialization failed: {0}")]
#[diagnostic(code(arbiter_server::init::keyholder_init))]
KeyHolder(#[from] crate::actors::keyholder::Error),
#[error("I/O Error: {0}")] #[error("I/O Error: {0}")]
#[diagnostic(code(arbiter_server::init::io))] #[diagnostic(code(arbiter_server::init::io))]
@@ -51,8 +43,7 @@ pub enum InitError {
pub struct _ServerContextInner { pub struct _ServerContextInner {
pub db: db::DatabasePool, pub db: db::DatabasePool,
pub tls: TlsManager, pub tls: TlsManager,
pub bootstrapper: ActorRef<Bootstrapper>, pub actors: GlobalActors,
pub keyholder: ActorRef<KeyHolder>,
} }
#[derive(Clone)] #[derive(Clone)]
pub struct ServerContext(Arc<_ServerContextInner>); pub struct ServerContext(Arc<_ServerContextInner>);
@@ -111,8 +102,7 @@ impl ServerContext {
drop(conn); drop(conn);
Ok(Self(Arc::new(_ServerContextInner { Ok(Self(Arc::new(_ServerContextInner {
bootstrapper: Bootstrapper::spawn(Bootstrapper::new(&db).await?), actors: GlobalActors::spawn(db.clone()).await?,
keyholder: KeyHolder::spawn(KeyHolder::new(db.clone()).await?),
db, db,
tls, tls,
}))) })))

View File

@@ -5,7 +5,6 @@ use rcgen::{Certificate, KeyPair};
use rustls::pki_types::CertificateDer; use rustls::pki_types::CertificateDer;
use thiserror::Error; use thiserror::Error;
#[derive(Error, Debug, Diagnostic)] #[derive(Error, Debug, Diagnostic)]
pub enum TlsInitError { pub enum TlsInitError {
#[error("Key generation error during TLS initialization: {0}")] #[error("Key generation error during TLS initialization: {0}")]
@@ -41,8 +40,7 @@ impl TlsDataRaw {
pub fn deserialize(&self) -> Result<TlsData, TlsInitError> { pub fn deserialize(&self) -> Result<TlsData, TlsInitError> {
let cert = CertificateDer::from_slice(&self.cert).into_owned(); let cert = CertificateDer::from_slice(&self.cert).into_owned();
let key = let key = String::from_utf8(self.key.clone()).map_err(TlsInitError::KeyInvalidFormat)?;
String::from_utf8(self.key.clone()).map_err(TlsInitError::KeyInvalidFormat)?;
let keypair = KeyPair::from_pem(&key).map_err(TlsInitError::KeyDeserializationError)?; let keypair = KeyPair::from_pem(&key).map_err(TlsInitError::KeyDeserializationError)?;
@@ -51,10 +49,8 @@ impl TlsDataRaw {
} }
fn generate_cert(key: &KeyPair) -> Result<Certificate, rcgen::Error> { fn generate_cert(key: &KeyPair) -> Result<Certificate, rcgen::Error> {
let params = rcgen::CertificateParams::new(vec![ let params =
"arbiter.local".to_string(), rcgen::CertificateParams::new(vec!["arbiter.local".to_string(), "localhost".to_string()])?;
"localhost".to_string(),
])?;
params.self_signed(key) params.self_signed(key)
} }

View File

@@ -1,8 +1,4 @@
use diesel::{Connection as _, SqliteConnection, connection::SimpleConnection as _};
use diesel::{
Connection as _, SqliteConnection,
connection::SimpleConnection as _,
};
use diesel_async::{ use diesel_async::{
AsyncConnection, SimpleAsyncConnection, AsyncConnection, SimpleAsyncConnection,
pooled_connection::{AsyncDieselConnectionManager, ManagerConfig}, pooled_connection::{AsyncDieselConnectionManager, ManagerConfig},
@@ -21,7 +17,7 @@ pub type DatabasePool = diesel_async::pooled_connection::bb8::Pool<DatabaseConne
pub type PoolInitError = diesel_async::pooled_connection::PoolError; pub type PoolInitError = diesel_async::pooled_connection::PoolError;
pub type PoolError = diesel_async::pooled_connection::bb8::RunError; pub type PoolError = diesel_async::pooled_connection::bb8::RunError;
static DB_FILE: &'static str = "arbiter.sqlite"; static DB_FILE: &str = "arbiter.sqlite";
const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations"); const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations");
@@ -133,7 +129,6 @@ pub async fn create_pool(url: Option<&str>) -> Result<DatabasePool, DatabaseSetu
Ok(pool) Ok(pool)
} }
#[cfg(test)]
pub async fn create_test_pool() -> DatabasePool { pub async fn create_test_pool() -> DatabasePool {
use rand::distr::{Alphanumeric, SampleString as _}; use rand::distr::{Alphanumeric, SampleString as _};

View File

@@ -7,7 +7,7 @@ pub trait GrpcStatusExt<T> {
impl<T> GrpcStatusExt<T> for Result<T, diesel::result::Error> { impl<T> GrpcStatusExt<T> for Result<T, diesel::result::Error> {
fn to_status(self) -> Result<T, Status> { fn to_status(self) -> Result<T, Status> {
self.map_err(|e| { self.map_err(|e| {
error!(error = ?e, "Database error"); error!(error = ?e, "Database error");
Status::internal("Database error") Status::internal("Database error")
}) })
@@ -21,4 +21,4 @@ impl<T> GrpcStatusExt<T> for Result<T, crate::db::PoolError> {
Status::internal("Database pool error") Status::internal("Database pool error")
}) })
} }
} }

View File

@@ -0,0 +1,43 @@
use arbiter_server::{
actors::keyholder::KeyHolder,
db::{self, models::ArbiterSetting, schema},
};
use diesel::{QueryDsl, insert_into};
use diesel_async::RunQueryDsl;
use memsafe::MemSafe;
pub async fn seed_settings(pool: &db::DatabasePool) {
let mut conn = pool.get().await.unwrap();
insert_into(schema::arbiter_settings::table)
.values(&ArbiterSetting {
id: 1,
root_key_id: None,
cert_key: vec![],
cert: vec![],
})
.execute(&mut conn)
.await
.unwrap();
}
#[allow(dead_code)]
pub async fn bootstrapped_keyholder(db: &db::DatabasePool) -> KeyHolder {
seed_settings(db).await;
let mut actor = KeyHolder::new(db.clone()).await.unwrap();
actor
.bootstrap(MemSafe::new(b"test-seal-key".to_vec()).unwrap())
.await
.unwrap();
actor
}
#[allow(dead_code)]
pub async fn root_key_history_id(db: &db::DatabasePool) -> i32 {
let mut conn = db.get().await.unwrap();
let id = schema::arbiter_settings::table
.select(schema::arbiter_settings::root_key_id)
.first::<Option<i32>>(&mut conn)
.await
.unwrap();
id.expect("root_key_id should be set after bootstrap")
}

View File

@@ -0,0 +1,8 @@
mod common;
#[path = "keyholder/concurrency.rs"]
mod concurrency;
#[path = "keyholder/lifecycle.rs"]
mod lifecycle;
#[path = "keyholder/storage.rs"]
mod storage;

View File

@@ -0,0 +1,173 @@
use std::collections::{HashMap, HashSet};
use arbiter_server::{
actors::keyholder::{CreateNew, Error, KeyHolder},
db::{self, models, schema},
};
use diesel::{ExpressionMethods as _, QueryDsl, SelectableHelper, dsl::sql_query};
use diesel_async::RunQueryDsl;
use kameo::actor::{ActorRef, Spawn as _};
use memsafe::MemSafe;
use tokio::task::JoinSet;
use crate::common;
async fn write_concurrently(
actor: ActorRef<KeyHolder>,
prefix: &'static str,
count: usize,
) -> Vec<(i32, Vec<u8>)> {
let mut set = JoinSet::new();
for i in 0..count {
let actor = actor.clone();
set.spawn(async move {
let plaintext = format!("{prefix}-{i}").into_bytes();
let id = actor
.ask(CreateNew {
plaintext: MemSafe::new(plaintext.clone()).unwrap(),
})
.await
.unwrap();
(id, plaintext)
});
}
let mut out = Vec::with_capacity(count);
while let Some(res) = set.join_next().await {
out.push(res.unwrap());
}
out
}
#[tokio::test]
#[test_log::test]
async fn concurrent_create_new_no_duplicate_nonces_() {
let db = db::create_test_pool().await;
let actor = KeyHolder::spawn(common::bootstrapped_keyholder(&db).await);
let writes = write_concurrently(actor, "nonce-unique", 32).await;
assert_eq!(writes.len(), 32);
let mut conn = db.get().await.unwrap();
let rows: Vec<models::AeadEncrypted> = schema::aead_encrypted::table
.select(models::AeadEncrypted::as_select())
.load(&mut conn)
.await
.unwrap();
assert_eq!(rows.len(), 32);
let nonces: Vec<&Vec<u8>> = rows.iter().map(|r| &r.current_nonce).collect();
let unique: HashSet<&Vec<u8>> = nonces.iter().copied().collect();
assert_eq!(nonces.len(), unique.len(), "all nonces must be unique");
}
#[tokio::test]
#[test_log::test]
async fn concurrent_create_new_root_nonce_never_moves_backward() {
let db = db::create_test_pool().await;
let actor = KeyHolder::spawn(common::bootstrapped_keyholder(&db).await);
write_concurrently(actor, "root-max", 24).await;
let mut conn = db.get().await.unwrap();
let rows: Vec<models::AeadEncrypted> = schema::aead_encrypted::table
.select(models::AeadEncrypted::as_select())
.load(&mut conn)
.await
.unwrap();
let max_nonce = rows
.iter()
.map(|r| r.current_nonce.clone())
.max()
.expect("at least one row");
let root_row: models::RootKeyHistory = schema::root_key_history::table
.select(models::RootKeyHistory::as_select())
.first(&mut conn)
.await
.unwrap();
assert_eq!(root_row.data_encryption_nonce, max_nonce);
}
#[tokio::test]
#[test_log::test]
async fn insert_failure_does_not_create_partial_row() {
let db = db::create_test_pool().await;
let mut actor = common::bootstrapped_keyholder(&db).await;
let root_key_history_id = common::root_key_history_id(&db).await;
let mut conn = db.get().await.unwrap();
let before_count: i64 = schema::aead_encrypted::table
.count()
.get_result(&mut conn)
.await
.unwrap();
let before_root_nonce: Vec<u8> = schema::root_key_history::table
.filter(schema::root_key_history::id.eq(root_key_history_id))
.select(schema::root_key_history::data_encryption_nonce)
.first(&mut conn)
.await
.unwrap();
sql_query(
"CREATE TRIGGER fail_aead_insert BEFORE INSERT ON aead_encrypted BEGIN SELECT RAISE(ABORT, 'forced test failure'); END;",
)
.execute(&mut conn)
.await
.unwrap();
drop(conn);
let err = actor
.create_new(MemSafe::new(b"should fail".to_vec()).unwrap())
.await
.unwrap_err();
assert!(matches!(err, Error::DatabaseTransaction(_)));
let mut conn = db.get().await.unwrap();
sql_query("DROP TRIGGER fail_aead_insert;")
.execute(&mut conn)
.await
.unwrap();
let after_count: i64 = schema::aead_encrypted::table
.count()
.get_result(&mut conn)
.await
.unwrap();
assert_eq!(
before_count, after_count,
"failed insert must not create row"
);
let after_root_nonce: Vec<u8> = schema::root_key_history::table
.filter(schema::root_key_history::id.eq(root_key_history_id))
.select(schema::root_key_history::data_encryption_nonce)
.first(&mut conn)
.await
.unwrap();
assert!(
after_root_nonce > before_root_nonce,
"current behavior allows nonce gap on failed insert"
);
}
#[tokio::test]
#[test_log::test]
async fn decrypt_roundtrip_after_high_concurrency() {
let db = db::create_test_pool().await;
let actor = KeyHolder::spawn(common::bootstrapped_keyholder(&db).await);
let writes = write_concurrently(actor, "roundtrip", 40).await;
let expected: HashMap<i32, Vec<u8>> = writes.into_iter().collect();
let mut decryptor = KeyHolder::new(db.clone()).await.unwrap();
decryptor
.try_unseal(MemSafe::new(b"test-seal-key".to_vec()).unwrap())
.await
.unwrap();
for (id, plaintext) in expected {
let mut decrypted = decryptor.decrypt(id).await.unwrap();
assert_eq!(*decrypted.read().unwrap(), plaintext);
}
}

View File

@@ -0,0 +1,134 @@
use arbiter_server::{
actors::keyholder::{Error, KeyHolder},
db::{self, models, schema},
};
use diesel::{QueryDsl, SelectableHelper};
use diesel_async::RunQueryDsl;
use memsafe::MemSafe;
use crate::common;
#[tokio::test]
#[test_log::test]
async fn test_bootstrap() {
let db = db::create_test_pool().await;
common::seed_settings(&db).await;
let mut actor = KeyHolder::new(db.clone()).await.unwrap();
let seal_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap();
actor.bootstrap(seal_key).await.unwrap();
let mut conn = db.get().await.unwrap();
let row: models::RootKeyHistory = schema::root_key_history::table
.select(models::RootKeyHistory::as_select())
.first(&mut conn)
.await
.unwrap();
assert_eq!(row.schema_version, 1);
assert_eq!(
row.tag,
arbiter_server::actors::keyholder::encryption::v1::ROOT_KEY_TAG
);
assert!(!row.ciphertext.is_empty());
assert!(!row.salt.is_empty());
assert_eq!(
row.data_encryption_nonce,
arbiter_server::actors::keyholder::encryption::v1::Nonce::default().to_vec()
);
}
#[tokio::test]
#[test_log::test]
async fn test_bootstrap_rejects_double() {
let db = db::create_test_pool().await;
let mut actor = common::bootstrapped_keyholder(&db).await;
let seal_key2 = MemSafe::new(b"test-seal-key".to_vec()).unwrap();
let err = actor.bootstrap(seal_key2).await.unwrap_err();
assert!(matches!(err, Error::AlreadyBootstrapped));
}
#[tokio::test]
#[test_log::test]
async fn test_create_new_before_bootstrap_fails() {
let db = db::create_test_pool().await;
common::seed_settings(&db).await;
let mut actor = KeyHolder::new(db).await.unwrap();
let err = actor
.create_new(MemSafe::new(b"data".to_vec()).unwrap())
.await
.unwrap_err();
assert!(matches!(err, Error::NotBootstrapped));
}
#[tokio::test]
#[test_log::test]
async fn test_decrypt_before_bootstrap_fails() {
let db = db::create_test_pool().await;
common::seed_settings(&db).await;
let mut actor = KeyHolder::new(db).await.unwrap();
let err = actor.decrypt(1).await.unwrap_err();
assert!(matches!(err, Error::NotBootstrapped));
}
#[tokio::test]
#[test_log::test]
async fn test_new_restores_sealed_state() {
let db = db::create_test_pool().await;
let actor = common::bootstrapped_keyholder(&db).await;
drop(actor);
let mut actor2 = KeyHolder::new(db).await.unwrap();
let err = actor2.decrypt(1).await.unwrap_err();
assert!(matches!(err, Error::NotBootstrapped));
}
#[tokio::test]
#[test_log::test]
async fn test_unseal_correct_password() {
let db = db::create_test_pool().await;
let mut actor = common::bootstrapped_keyholder(&db).await;
let plaintext = b"survive a restart";
let aead_id = actor
.create_new(MemSafe::new(plaintext.to_vec()).unwrap())
.await
.unwrap();
drop(actor);
let mut actor = KeyHolder::new(db.clone()).await.unwrap();
let seal_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap();
actor.try_unseal(seal_key).await.unwrap();
let mut decrypted = actor.decrypt(aead_id).await.unwrap();
assert_eq!(*decrypted.read().unwrap(), plaintext);
}
#[tokio::test]
#[test_log::test]
async fn test_unseal_wrong_then_correct_password() {
let db = db::create_test_pool().await;
let mut actor = common::bootstrapped_keyholder(&db).await;
let plaintext = b"important data";
let aead_id = actor
.create_new(MemSafe::new(plaintext.to_vec()).unwrap())
.await
.unwrap();
drop(actor);
let mut actor = KeyHolder::new(db.clone()).await.unwrap();
let bad_key = MemSafe::new(b"wrong-password".to_vec()).unwrap();
let err = actor.try_unseal(bad_key).await.unwrap_err();
assert!(matches!(err, Error::InvalidKey));
let good_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap();
actor.try_unseal(good_key).await.unwrap();
let mut decrypted = actor.decrypt(aead_id).await.unwrap();
assert_eq!(*decrypted.read().unwrap(), plaintext);
}

View File

@@ -0,0 +1,161 @@
use std::collections::HashSet;
use arbiter_server::{
actors::keyholder::{Error, encryption::v1},
db::{self, models, schema},
};
use diesel::{ExpressionMethods as _, QueryDsl, SelectableHelper, dsl::update};
use diesel_async::RunQueryDsl;
use memsafe::MemSafe;
use crate::common;
#[tokio::test]
#[test_log::test]
async fn test_create_decrypt_roundtrip() {
let db = db::create_test_pool().await;
let mut actor = common::bootstrapped_keyholder(&db).await;
let plaintext = b"hello arbiter";
let aead_id = actor
.create_new(MemSafe::new(plaintext.to_vec()).unwrap())
.await
.unwrap();
let mut decrypted = actor.decrypt(aead_id).await.unwrap();
assert_eq!(*decrypted.read().unwrap(), plaintext);
}
#[tokio::test]
#[test_log::test]
async fn test_decrypt_nonexistent_returns_not_found() {
let db = db::create_test_pool().await;
let mut actor = common::bootstrapped_keyholder(&db).await;
let err = actor.decrypt(9999).await.unwrap_err();
assert!(matches!(err, Error::NotFound));
}
#[tokio::test]
#[test_log::test]
async fn test_ciphertext_differs_across_entries() {
let db = db::create_test_pool().await;
let mut actor = common::bootstrapped_keyholder(&db).await;
let plaintext = b"same content";
let id1 = actor
.create_new(MemSafe::new(plaintext.to_vec()).unwrap())
.await
.unwrap();
let id2 = actor
.create_new(MemSafe::new(plaintext.to_vec()).unwrap())
.await
.unwrap();
let mut conn = db.get().await.unwrap();
let row1: models::AeadEncrypted = schema::aead_encrypted::table
.filter(schema::aead_encrypted::id.eq(id1))
.select(models::AeadEncrypted::as_select())
.first(&mut conn)
.await
.unwrap();
let row2: models::AeadEncrypted = schema::aead_encrypted::table
.filter(schema::aead_encrypted::id.eq(id2))
.select(models::AeadEncrypted::as_select())
.first(&mut conn)
.await
.unwrap();
assert_ne!(row1.ciphertext, row2.ciphertext);
let mut d1 = actor.decrypt(id1).await.unwrap();
let mut d2 = actor.decrypt(id2).await.unwrap();
assert_eq!(*d1.read().unwrap(), plaintext);
assert_eq!(*d2.read().unwrap(), plaintext);
}
#[tokio::test]
#[test_log::test]
async fn test_nonce_never_reused() {
let db = db::create_test_pool().await;
let mut actor = common::bootstrapped_keyholder(&db).await;
let n = 5;
for i in 0..n {
actor
.create_new(MemSafe::new(format!("secret {i}").into_bytes()).unwrap())
.await
.unwrap();
}
let mut conn = db.get().await.unwrap();
let rows: Vec<models::AeadEncrypted> = schema::aead_encrypted::table
.select(models::AeadEncrypted::as_select())
.load(&mut conn)
.await
.unwrap();
assert_eq!(rows.len(), n);
let nonces: Vec<&Vec<u8>> = rows.iter().map(|r| &r.current_nonce).collect();
let unique: HashSet<&Vec<u8>> = nonces.iter().copied().collect();
assert_eq!(nonces.len(), unique.len(), "all nonces must be unique");
for (i, row) in rows.iter().enumerate() {
let mut expected = v1::Nonce::default();
for _ in 0..=i {
expected.increment();
}
assert_eq!(row.current_nonce, expected.to_vec(), "nonce {i} mismatch");
}
let root_row: models::RootKeyHistory = schema::root_key_history::table
.select(models::RootKeyHistory::as_select())
.first(&mut conn)
.await
.unwrap();
let last_nonce = &rows.last().unwrap().current_nonce;
assert_eq!(&root_row.data_encryption_nonce, last_nonce);
}
#[tokio::test]
#[test_log::test]
async fn broken_db_nonce_format_fails_closed() {
let db = db::create_test_pool().await;
let mut actor = common::bootstrapped_keyholder(&db).await;
let root_key_history_id = common::root_key_history_id(&db).await;
let mut conn = db.get().await.unwrap();
update(
schema::root_key_history::table
.filter(schema::root_key_history::id.eq(root_key_history_id)),
)
.set(schema::root_key_history::data_encryption_nonce.eq(vec![1, 2, 3]))
.execute(&mut conn)
.await
.unwrap();
drop(conn);
let err = actor
.create_new(MemSafe::new(b"must fail".to_vec()).unwrap())
.await
.unwrap_err();
assert!(matches!(err, Error::BrokenDatabase));
let db = db::create_test_pool().await;
let mut actor = common::bootstrapped_keyholder(&db).await;
let id = actor
.create_new(MemSafe::new(b"decrypt target".to_vec()).unwrap())
.await
.unwrap();
let mut conn = db.get().await.unwrap();
update(schema::aead_encrypted::table.filter(schema::aead_encrypted::id.eq(id)))
.set(schema::aead_encrypted::current_nonce.eq(vec![7, 8]))
.execute(&mut conn)
.await
.unwrap();
drop(conn);
let err = actor.decrypt(id).await.unwrap_err();
assert!(matches!(err, Error::BrokenDatabase));
}

View File

@@ -0,0 +1,6 @@
mod common;
#[path = "user_agent/auth.rs"]
mod auth;
#[path = "user_agent/unseal.rs"]
mod unseal;

View File

@@ -3,39 +3,31 @@ use arbiter_proto::proto::{
auth::{self, AuthChallengeRequest, AuthOk}, auth::{self, AuthChallengeRequest, AuthOk},
user_agent_response::Payload as UserAgentResponsePayload, user_agent_response::Payload as UserAgentResponsePayload,
}; };
use chrono::format; use arbiter_server::{
actors::{
GlobalActors,
bootstrap::GetToken,
user_agent::{HandleAuthChallengeRequest, HandleAuthChallengeSolution, UserAgentActor},
},
db::{self, schema},
};
use diesel::{ExpressionMethods as _, QueryDsl, insert_into}; use diesel::{ExpressionMethods as _, QueryDsl, insert_into};
use diesel_async::RunQueryDsl; use diesel_async::RunQueryDsl;
use ed25519_dalek::Signer as _; use ed25519_dalek::Signer as _;
use kameo::actor::Spawn; use kameo::actor::Spawn;
use crate::{
actors::{
bootstrap::Bootstrapper,
user_agent::{HandleAuthChallengeRequest, HandleAuthChallengeSolution},
},
db::{self, schema},
};
use super::UserAgentActor;
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
pub async fn test_bootstrap_token_auth() { pub async fn test_bootstrap_token_auth() {
let db = db::create_test_pool().await; let db =db::create_test_pool().await;
// explicitly not installing any user_agent pubkeys crate::common::seed_settings(&db).await;
let bootstrapper = Bootstrapper::new(&db).await.unwrap(); // this will create bootstrap token
let token = bootstrapper.get_token().unwrap();
let bootstrapper_ref = Bootstrapper::spawn(bootstrapper); let actors = GlobalActors::spawn(db.clone()).await.unwrap();
let user_agent = UserAgentActor::new_manual( let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap();
db.clone(), let user_agent =
bootstrapper_ref, UserAgentActor::new_manual(db.clone(), actors, tokio::sync::mpsc::channel(1).0);
tokio::sync::mpsc::channel(1).0, // dummy channel, we won't actually send responses in this test
);
let user_agent_ref = UserAgentActor::spawn(user_agent); let user_agent_ref = UserAgentActor::spawn(user_agent);
// simulate client sending auth request with bootstrap token
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec();
@@ -49,7 +41,6 @@ pub async fn test_bootstrap_token_auth() {
.await .await
.expect("Shouldn't fail to send message"); .expect("Shouldn't fail to send message");
// auth succeeded
assert_eq!( assert_eq!(
result, result,
UserAgentResponse { UserAgentResponse {
@@ -63,7 +54,6 @@ pub async fn test_bootstrap_token_auth() {
} }
); );
// key is succesfully recorded in database
let mut conn = db.get().await.unwrap(); let mut conn = db.get().await.unwrap();
let stored_pubkey: Vec<u8> = schema::useragent_client::table let stored_pubkey: Vec<u8> = schema::useragent_client::table
.select(schema::useragent_client::public_key) .select(schema::useragent_client::public_key)
@@ -77,18 +67,13 @@ pub async fn test_bootstrap_token_auth() {
#[test_log::test] #[test_log::test]
pub async fn test_bootstrap_invalid_token_auth() { pub async fn test_bootstrap_invalid_token_auth() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
// explicitly not installing any user_agent pubkeys crate::common::seed_settings(&db).await;
let bootstrapper = Bootstrapper::new(&db).await.unwrap(); // this will create bootstrap token
let bootstrapper_ref = Bootstrapper::spawn(bootstrapper); let actors = GlobalActors::spawn(db.clone()).await.unwrap();
let user_agent = UserAgentActor::new_manual( let user_agent =
db.clone(), UserAgentActor::new_manual(db.clone(), actors, tokio::sync::mpsc::channel(1).0);
bootstrapper_ref,
tokio::sync::mpsc::channel(1).0, // dummy channel, we won't actually send responses in this test
);
let user_agent_ref = UserAgentActor::spawn(user_agent); let user_agent_ref = UserAgentActor::spawn(user_agent);
// simulate client sending auth request with bootstrap token
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec();
@@ -125,20 +110,16 @@ pub async fn test_bootstrap_invalid_token_auth() {
#[test_log::test] #[test_log::test]
pub async fn test_challenge_auth() { pub async fn test_challenge_auth() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
crate::common::seed_settings(&db).await;
let bootstrapper_ref = Bootstrapper::spawn(Bootstrapper::new(&db).await.unwrap()); let actors = GlobalActors::spawn(db.clone()).await.unwrap();
let user_agent = UserAgentActor::new_manual( let user_agent =
db.clone(), UserAgentActor::new_manual(db.clone(), actors, tokio::sync::mpsc::channel(1).0);
bootstrapper_ref,
tokio::sync::mpsc::channel(1).0, // dummy channel, we won't actually send responses in this test
);
let user_agent_ref = UserAgentActor::spawn(user_agent); let user_agent_ref = UserAgentActor::spawn(user_agent);
// simulate client sending auth request with bootstrap token
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec();
// insert pubkey into database to trigger challenge-response auth flow
{ {
let mut conn = db.get().await.unwrap(); let mut conn = db.get().await.unwrap();
insert_into(schema::useragent_client::table) insert_into(schema::useragent_client::table)
@@ -158,7 +139,6 @@ pub async fn test_challenge_auth() {
.await .await
.expect("Shouldn't fail to send message"); .expect("Shouldn't fail to send message");
// auth challenge succeeded
let UserAgentResponse { let UserAgentResponse {
payload: payload:
Some(UserAgentResponsePayload::AuthMessage(arbiter_proto::proto::auth::ServerMessage { Some(UserAgentResponsePayload::AuthMessage(arbiter_proto::proto::auth::ServerMessage {
@@ -183,7 +163,6 @@ pub async fn test_challenge_auth() {
.await .await
.expect("Shouldn't fail to send message"); .expect("Shouldn't fail to send message");
// auth succeeded
assert_eq!( assert_eq!(
result, result,
UserAgentResponse { UserAgentResponse {

View File

@@ -0,0 +1,229 @@
use arbiter_proto::proto::{
UnsealEncryptedKey, UnsealResult, UnsealStart, auth::AuthChallengeRequest,
user_agent_response::Payload as UserAgentResponsePayload,
};
use arbiter_server::{
actors::{
GlobalActors,
bootstrap::GetToken,
keyholder::{Bootstrap, Seal},
user_agent::{
HandleAuthChallengeRequest, HandleUnsealEncryptedKey, HandleUnsealRequest,
UserAgentActor,
},
},
db,
};
use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit};
use kameo::actor::{ActorRef, Spawn};
use memsafe::MemSafe;
use x25519_dalek::{EphemeralSecret, PublicKey};
async fn setup_authenticated_user_agent(
seal_key: &[u8],
) -> (arbiter_server::db::DatabasePool, ActorRef<UserAgentActor>) {
let db = db::create_test_pool().await;
crate::common::seed_settings(&db).await;
let actors = GlobalActors::spawn(db.clone()).await.unwrap();
actors
.key_holder
.ask(Bootstrap {
seal_key_raw: MemSafe::new(seal_key.to_vec()).unwrap(),
})
.await
.unwrap();
actors.key_holder.ask(Seal).await.unwrap();
let user_agent =
UserAgentActor::new_manual(db.clone(), actors.clone(), tokio::sync::mpsc::channel(1).0);
let user_agent_ref = UserAgentActor::spawn(user_agent);
let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap();
let auth_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
user_agent_ref
.ask(HandleAuthChallengeRequest {
req: AuthChallengeRequest {
pubkey: auth_key.verifying_key().to_bytes().to_vec(),
bootstrap_token: Some(token),
},
})
.await
.unwrap();
(db, user_agent_ref)
}
async fn client_dh_encrypt(
user_agent_ref: &ActorRef<UserAgentActor>,
key_to_send: &[u8],
) -> UnsealEncryptedKey {
let client_secret = EphemeralSecret::random();
let client_public = PublicKey::from(&client_secret);
let response = user_agent_ref
.ask(HandleUnsealRequest {
req: UnsealStart {
client_pubkey: client_public.as_bytes().to_vec(),
},
})
.await
.unwrap();
let server_pubkey = match response.payload.unwrap() {
UserAgentResponsePayload::UnsealStartResponse(resp) => resp.server_pubkey,
other => panic!("Expected UnsealStartResponse, got {other:?}"),
};
let server_public = PublicKey::from(<[u8; 32]>::try_from(server_pubkey.as_slice()).unwrap());
let shared_secret = client_secret.diffie_hellman(&server_public);
let cipher = XChaCha20Poly1305::new(shared_secret.as_bytes().into());
let nonce = XNonce::from([0u8; 24]);
let associated_data = b"unseal";
let mut ciphertext = key_to_send.to_vec();
cipher
.encrypt_in_place(&nonce, associated_data, &mut ciphertext)
.unwrap();
UnsealEncryptedKey {
nonce: nonce.to_vec(),
ciphertext,
associated_data: associated_data.to_vec(),
}
}
#[tokio::test]
#[test_log::test]
pub async fn test_unseal_success() {
let seal_key = b"test-seal-key";
let (_db, user_agent_ref) = setup_authenticated_user_agent(seal_key).await;
let encrypted_key = client_dh_encrypt(&user_agent_ref, seal_key).await;
let response = user_agent_ref
.ask(HandleUnsealEncryptedKey { req: encrypted_key })
.await
.unwrap();
assert_eq!(
response.payload.unwrap(),
UserAgentResponsePayload::UnsealResult(UnsealResult::Success.into()),
);
}
#[tokio::test]
#[test_log::test]
pub async fn test_unseal_wrong_seal_key() {
let (_db, user_agent_ref) = setup_authenticated_user_agent(b"correct-key").await;
let encrypted_key = client_dh_encrypt(&user_agent_ref, b"wrong-key").await;
let response = user_agent_ref
.ask(HandleUnsealEncryptedKey { req: encrypted_key })
.await
.unwrap();
assert_eq!(
response.payload.unwrap(),
UserAgentResponsePayload::UnsealResult(UnsealResult::InvalidKey.into()),
);
}
#[tokio::test]
#[test_log::test]
pub async fn test_unseal_corrupted_ciphertext() {
let (_db, user_agent_ref) = setup_authenticated_user_agent(b"test-key").await;
let client_secret = EphemeralSecret::random();
let client_public = PublicKey::from(&client_secret);
user_agent_ref
.ask(HandleUnsealRequest {
req: UnsealStart {
client_pubkey: client_public.as_bytes().to_vec(),
},
})
.await
.unwrap();
let response = user_agent_ref
.ask(HandleUnsealEncryptedKey {
req: UnsealEncryptedKey {
nonce: vec![0u8; 24],
ciphertext: vec![0u8; 32],
associated_data: vec![],
},
})
.await
.unwrap();
assert_eq!(
response.payload.unwrap(),
UserAgentResponsePayload::UnsealResult(UnsealResult::InvalidKey.into()),
);
}
#[tokio::test]
#[test_log::test]
pub async fn test_unseal_start_without_auth_fails() {
let db = db::create_test_pool().await;
crate::common::seed_settings(&db).await;
let actors = GlobalActors::spawn(db.clone()).await.unwrap();
let user_agent =
UserAgentActor::new_manual(db.clone(), actors, tokio::sync::mpsc::channel(1).0);
let user_agent_ref = UserAgentActor::spawn(user_agent);
let client_secret = EphemeralSecret::random();
let client_public = PublicKey::from(&client_secret);
let result = user_agent_ref
.ask(HandleUnsealRequest {
req: UnsealStart {
client_pubkey: client_public.as_bytes().to_vec(),
},
})
.await;
match result {
Err(kameo::error::SendError::HandlerError(status)) => {
assert_eq!(status.code(), tonic::Code::Internal);
}
other => panic!("Expected state machine error, got {other:?}"),
}
}
#[tokio::test]
#[test_log::test]
pub async fn test_unseal_retry_after_invalid_key() {
let seal_key = b"real-seal-key";
let (_db, user_agent_ref) = setup_authenticated_user_agent(seal_key).await;
{
let encrypted_key = client_dh_encrypt(&user_agent_ref, b"wrong-key").await;
let response = user_agent_ref
.ask(HandleUnsealEncryptedKey { req: encrypted_key })
.await
.unwrap();
assert_eq!(
response.payload.unwrap(),
UserAgentResponsePayload::UnsealResult(UnsealResult::InvalidKey.into()),
);
}
{
let encrypted_key = client_dh_encrypt(&user_agent_ref, seal_key).await;
let response = user_agent_ref
.ask(HandleUnsealEncryptedKey { req: encrypted_key })
.await
.unwrap();
assert_eq!(
response.payload.unwrap(),
UserAgentResponsePayload::UnsealResult(UnsealResult::Success.into()),
);
}
}