refactor(server): separated global actors into their own handle

This commit is contained in:
hdbg
2026-02-16 21:44:11 +01:00
parent bdb9f01757
commit b3566c8af6
5 changed files with 92 additions and 80 deletions

View File

@@ -1,4 +1,40 @@
pub mod user_agent; use kameo::actor::{ActorRef, Spawn};
pub mod client; use miette::Diagnostic;
use thiserror::Error;
use crate::{
actors::{bootstrap::Bootstrapper, keyholder::KeyHolder},
db,
};
pub(crate) mod bootstrap; pub(crate) mod bootstrap;
pub(crate) mod keyholder; pub mod client;
pub(crate) mod keyholder;
pub mod user_agent;
#[derive(Error, Debug, Diagnostic)]
pub enum SpawnError {
#[error("Failed to spawn Bootstrapper actor")]
#[diagnostic(code(SpawnError::Bootstrapper))]
Bootstrapper(#[from] bootstrap::Error),
#[error("Failed to spawn KeyHolder actor")]
#[diagnostic(code(SpawnError::KeyHolder))]
KeyHolder(#[from] keyholder::Error),
}
/// Long-lived actors that are shared across all connections and handle global state and operations
#[derive(Clone)]
pub struct GlobalActors {
pub key_holder: ActorRef<KeyHolder>,
pub bootstrapper: ActorRef<Bootstrapper>,
}
impl GlobalActors {
pub async fn spawn(db: db::DatabasePool) -> Result<Self, SpawnError> {
Ok(Self {
bootstrapper: Bootstrapper::spawn(Bootstrapper::new(&db).await?),
key_holder: KeyHolder::spawn(KeyHolder::new(db.clone()).await?),
})
}
}

View File

@@ -28,7 +28,7 @@ pub async fn generate_token() -> Result<String, std::io::Error> {
} }
#[derive(Error, Debug, Diagnostic)] #[derive(Error, Debug, Diagnostic)]
pub enum BootstrapError { pub enum Error {
#[error("Database error: {0}")] #[error("Database error: {0}")]
#[diagnostic(code(arbiter_server::bootstrap::database))] #[diagnostic(code(arbiter_server::bootstrap::database))]
Database(#[from] db::PoolError), Database(#[from] db::PoolError),
@@ -48,7 +48,7 @@ pub struct Bootstrapper {
} }
impl Bootstrapper { impl Bootstrapper {
pub async fn new(db: &DatabasePool) -> Result<Self, BootstrapError> { pub async fn new(db: &DatabasePool) -> Result<Self, Error> {
let mut conn = db.get().await?; let mut conn = db.get().await?;
let row_count: i64 = schema::useragent_client::table let row_count: i64 = schema::useragent_client::table
@@ -69,11 +69,6 @@ impl Bootstrapper {
Ok(Self { token }) Ok(Self { token })
} }
#[cfg(test)]
pub fn get_token(&self) -> Option<String> {
self.token.clone()
}
} }
#[messages] #[messages]
@@ -96,3 +91,12 @@ impl Bootstrapper {
} }
} }
} }
#[cfg(test)]
#[messages]
impl Bootstrapper {
#[message]
pub fn get_token(&self) -> Option<String> {
self.token.clone()
}
}

View File

@@ -10,9 +10,9 @@ use arbiter_proto::proto::{
}; };
use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit}; use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit};
use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, dsl::update}; use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, dsl::update};
use diesel_async::{RunQueryDsl}; use diesel_async::RunQueryDsl;
use ed25519_dalek::VerifyingKey; use ed25519_dalek::VerifyingKey;
use kameo::{Actor, actor::ActorRef, error::SendError, messages}; use kameo::{Actor, error::SendError, messages};
use memsafe::MemSafe; use memsafe::MemSafe;
use tokio::sync::mpsc::Sender; use tokio::sync::mpsc::Sender;
use tonic::Status; use tonic::Status;
@@ -22,11 +22,12 @@ use x25519_dalek::{EphemeralSecret, PublicKey};
use crate::{ use crate::{
ServerContext, ServerContext,
actors::{ actors::{
bootstrap::{Bootstrapper, ConsumeToken}, GlobalActors,
keyholder::{self, KeyHolder, TryUnseal}, bootstrap::ConsumeToken,
keyholder::{self, TryUnseal},
user_agent::state::{ user_agent::state::{
ChallengeContext, DummyContext, UnsealContext, UserAgentEvents, ChallengeContext, DummyContext, UnsealContext, UserAgentEvents, UserAgentStateMachine,
UserAgentStateMachine, UserAgentStates, UserAgentStates,
}, },
}, },
db::{self, schema}, db::{self, schema},
@@ -43,8 +44,7 @@ pub(crate) use transport::handle_user_agent;
#[derive(Actor)] #[derive(Actor)]
pub struct UserAgentActor { pub struct UserAgentActor {
db: db::DatabasePool, db: db::DatabasePool,
bootstapper: ActorRef<Bootstrapper>, actors: GlobalActors,
keyholder: ActorRef<KeyHolder>,
state: UserAgentStateMachine<DummyContext>, state: UserAgentStateMachine<DummyContext>,
// will be used in future // will be used in future
_tx: Sender<Result<UserAgentResponse, Status>>, _tx: Sender<Result<UserAgentResponse, Status>>,
@@ -57,8 +57,7 @@ impl UserAgentActor {
) -> Self { ) -> Self {
Self { Self {
db: context.db.clone(), db: context.db.clone(),
bootstapper: context.bootstrapper.clone(), actors: context.actors.clone(),
keyholder: context.keyholder.clone(),
state: UserAgentStateMachine::new(DummyContext), state: UserAgentStateMachine::new(DummyContext),
_tx: tx, _tx: tx,
} }
@@ -67,14 +66,12 @@ impl UserAgentActor {
#[cfg(test)] #[cfg(test)]
pub(crate) fn new_manual( pub(crate) fn new_manual(
db: db::DatabasePool, db: db::DatabasePool,
bootstapper: ActorRef<Bootstrapper>, actors: GlobalActors,
keyholder: ActorRef<KeyHolder>,
tx: Sender<Result<UserAgentResponse, Status>>, tx: Sender<Result<UserAgentResponse, Status>>,
) -> Self { ) -> Self {
Self { Self {
db, db,
bootstapper, actors,
keyholder,
state: UserAgentStateMachine::new(DummyContext), state: UserAgentStateMachine::new(DummyContext),
_tx: tx, _tx: tx,
} }
@@ -94,7 +91,8 @@ impl UserAgentActor {
token: String, token: String,
) -> Result<UserAgentResponse, Status> { ) -> Result<UserAgentResponse, Status> {
let token_ok: bool = self let token_ok: bool = self
.bootstapper .actors
.bootstrapper
.ask(ConsumeToken { token }) .ask(ConsumeToken { token })
.await .await
.map_err(|e| { .map_err(|e| {
@@ -288,7 +286,8 @@ impl UserAgentActor {
match decryption_result { match decryption_result {
Ok(_) => { Ok(_) => {
match self match self
.keyholder .actors
.key_holder
.ask(TryUnseal { .ask(TryUnseal {
seal_key_raw: seal_key_buffer, seal_key_raw: seal_key_buffer,
}) })

View File

@@ -13,8 +13,9 @@ use x25519_dalek::{EphemeralSecret, PublicKey};
use crate::{ use crate::{
actors::{ actors::{
bootstrap::Bootstrapper, GlobalActors,
keyholder::KeyHolder, bootstrap::GetToken,
keyholder::{Bootstrap, Seal},
user_agent::{ user_agent::{
HandleAuthChallengeRequest, HandleAuthChallengeSolution, HandleUnsealEncryptedKey, HandleAuthChallengeRequest, HandleAuthChallengeSolution, HandleUnsealEncryptedKey,
HandleUnsealRequest, HandleUnsealRequest,
@@ -47,25 +48,23 @@ async fn setup_authenticated_user_agent(
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
seed_settings(&db).await; seed_settings(&db).await;
let mut keyholder = KeyHolder::new(db.clone()).await.unwrap(); let actors = GlobalActors::spawn(db.clone()).await.unwrap();
keyholder actors
.bootstrap(MemSafe::new(seal_key.to_vec()).unwrap()) .key_holder
.ask(Bootstrap {
seal_key_raw: MemSafe::new(seal_key.to_vec()).unwrap(),
})
.await .await
.unwrap(); .unwrap();
keyholder.seal().unwrap(); actors.key_holder.ask(Seal).await.unwrap();
let keyholder_ref = KeyHolder::spawn(keyholder);
let bootstrapper = Bootstrapper::new(&db).await.unwrap();
let token = bootstrapper.get_token().unwrap();
let bootstrapper_ref = Bootstrapper::spawn(bootstrapper);
let user_agent = UserAgentActor::new_manual( let user_agent = UserAgentActor::new_manual(
db.clone(), db.clone(),
bootstrapper_ref, actors.clone(),
keyholder_ref,
tokio::sync::mpsc::channel(1).0, tokio::sync::mpsc::channel(1).0,
); );
let user_agent_ref = UserAgentActor::spawn(user_agent); let user_agent_ref = UserAgentActor::spawn(user_agent);
let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap();
let auth_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); let auth_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
user_agent_ref user_agent_ref
@@ -128,17 +127,11 @@ async fn client_dh_encrypt(
pub async fn test_bootstrap_token_auth() { pub async fn test_bootstrap_token_auth() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
seed_settings(&db).await; seed_settings(&db).await;
// explicitly not installing any user_agent pubkeys let actors = GlobalActors::spawn(db.clone()).await.unwrap();
let bootstrapper = Bootstrapper::new(&db).await.unwrap(); // this will create bootstrap token let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap();
let keyholder = KeyHolder::new(db.clone()).await.unwrap();
let token = bootstrapper.get_token().unwrap();
let bootstrapper_ref = Bootstrapper::spawn(bootstrapper);
let keyholder_ref = KeyHolder::spawn(keyholder);
let user_agent = UserAgentActor::new_manual( let user_agent = UserAgentActor::new_manual(
db.clone(), db.clone(),
bootstrapper_ref, actors.clone(),
keyholder_ref,
tokio::sync::mpsc::channel(1).0, // dummy channel, we won't actually send responses in this test tokio::sync::mpsc::channel(1).0, // dummy channel, we won't actually send responses in this test
); );
let user_agent_ref = UserAgentActor::spawn(user_agent); let user_agent_ref = UserAgentActor::spawn(user_agent);
@@ -186,17 +179,11 @@ pub async fn test_bootstrap_token_auth() {
pub async fn test_bootstrap_invalid_token_auth() { pub async fn test_bootstrap_invalid_token_auth() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
seed_settings(&db).await; seed_settings(&db).await;
// explicitly not installing any user_agent pubkeys let actors = GlobalActors::spawn(db.clone()).await.unwrap();
let bootstrapper = Bootstrapper::new(&db).await.unwrap(); // this will create bootstrap token
let keyholder = KeyHolder::new(db.clone()).await.unwrap();
let bootstrapper_ref = Bootstrapper::spawn(bootstrapper);
let keyholder_ref = KeyHolder::spawn(keyholder);
let user_agent = UserAgentActor::new_manual( let user_agent = UserAgentActor::new_manual(
db.clone(), db.clone(),
bootstrapper_ref, actors,
keyholder_ref,
tokio::sync::mpsc::channel(1).0, // dummy channel, we won't actually send responses in this test tokio::sync::mpsc::channel(1).0, // dummy channel, we won't actually send responses in this test
); );
let user_agent_ref = UserAgentActor::spawn(user_agent); let user_agent_ref = UserAgentActor::spawn(user_agent);
@@ -240,12 +227,10 @@ pub async fn test_challenge_auth() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
seed_settings(&db).await; seed_settings(&db).await;
let bootstrapper_ref = Bootstrapper::spawn(Bootstrapper::new(&db).await.unwrap()); let actors = GlobalActors::spawn(db.clone()).await.unwrap();
let keyholder_ref = KeyHolder::spawn(KeyHolder::new(db.clone()).await.unwrap());
let user_agent = UserAgentActor::new_manual( let user_agent = UserAgentActor::new_manual(
db.clone(), db.clone(),
bootstrapper_ref, actors,
keyholder_ref,
tokio::sync::mpsc::channel(1).0, // dummy channel, we won't actually send responses in this test tokio::sync::mpsc::channel(1).0, // dummy channel, we won't actually send responses in this test
); );
let user_agent_ref = UserAgentActor::spawn(user_agent); let user_agent_ref = UserAgentActor::spawn(user_agent);
@@ -394,13 +379,11 @@ pub async fn test_unseal_start_without_auth_fails() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
seed_settings(&db).await; seed_settings(&db).await;
let keyholder_ref = KeyHolder::spawn( KeyHolder::new(db.clone()).await.unwrap()); let actors = GlobalActors::spawn(db.clone()).await.unwrap();
let bootstrapper_ref = Bootstrapper::spawn(Bootstrapper::new(&db).await.unwrap());
let user_agent = UserAgentActor::new_manual( let user_agent = UserAgentActor::new_manual(
db.clone(), db.clone(),
bootstrapper_ref, actors,
keyholder_ref,
tokio::sync::mpsc::channel(1).0, tokio::sync::mpsc::channel(1).0,
); );
let user_agent_ref = UserAgentActor::spawn(user_agent); let user_agent_ref = UserAgentActor::spawn(user_agent);

View File

@@ -2,15 +2,11 @@ use std::sync::Arc;
use diesel::OptionalExtension as _; use diesel::OptionalExtension as _;
use diesel_async::RunQueryDsl as _; use diesel_async::RunQueryDsl as _;
use kameo::actor::{ActorRef, Spawn};
use miette::Diagnostic; use miette::Diagnostic;
use thiserror::Error; use thiserror::Error;
use crate::{ use crate::{
actors::{ actors::GlobalActors,
bootstrap::{self, Bootstrapper},
keyholder::KeyHolder,
},
context::tls::{TlsDataRaw, TlsManager}, context::tls::{TlsDataRaw, TlsManager},
db::{self, models::ArbiterSetting, schema::arbiter_settings}, db::{self, models::ArbiterSetting, schema::arbiter_settings},
}; };
@@ -35,13 +31,9 @@ pub enum InitError {
#[diagnostic(code(arbiter_server::init::tls_init))] #[diagnostic(code(arbiter_server::init::tls_init))]
Tls(#[from] tls::TlsInitError), Tls(#[from] tls::TlsInitError),
#[error("Bootstrap token generation failed: {0}")] #[error("Actor spawn failed: {0}")]
#[diagnostic(code(arbiter_server::init::bootstrap_token))] #[diagnostic(code(arbiter_server::init::actor_spawn))]
BootstrapToken(#[from] bootstrap::BootstrapError), ActorSpawn(#[from] crate::actors::SpawnError),
#[error("KeyHolder initialization failed: {0}")]
#[diagnostic(code(arbiter_server::init::keyholder_init))]
KeyHolder(#[from] crate::actors::keyholder::Error),
#[error("I/O Error: {0}")] #[error("I/O Error: {0}")]
#[diagnostic(code(arbiter_server::init::io))] #[diagnostic(code(arbiter_server::init::io))]
@@ -51,8 +43,7 @@ pub enum InitError {
pub struct _ServerContextInner { pub struct _ServerContextInner {
pub db: db::DatabasePool, pub db: db::DatabasePool,
pub tls: TlsManager, pub tls: TlsManager,
pub bootstrapper: ActorRef<Bootstrapper>, pub actors: GlobalActors,
pub keyholder: ActorRef<KeyHolder>,
} }
#[derive(Clone)] #[derive(Clone)]
pub struct ServerContext(Arc<_ServerContextInner>); pub struct ServerContext(Arc<_ServerContextInner>);
@@ -111,8 +102,7 @@ impl ServerContext {
drop(conn); drop(conn);
Ok(Self(Arc::new(_ServerContextInner { Ok(Self(Arc::new(_ServerContextInner {
bootstrapper: Bootstrapper::spawn(Bootstrapper::new(&db).await?), actors: GlobalActors::spawn(db.clone()).await?,
keyholder: KeyHolder::spawn(KeyHolder::new(db.clone()).await?),
db, db,
tls, tls,
}))) })))