From cb05407bb69acadbd1fd82489959770e55b054b8 Mon Sep 17 00:00:00 2001 From: hdbg Date: Sun, 1 Mar 2026 11:35:06 +0100 Subject: [PATCH] feat(server): broker agent for inter-actor coordination --- .../src/actors/client/auth/mod.rs | 6 +- .../src/actors/client/auth/state.rs | 6 +- .../arbiter-server/src/actors/client/mod.rs | 22 ++++-- .../src/actors/client/session.rs | 22 ++++-- .../crates/arbiter-server/src/actors/mod.rs | 7 +- .../arbiter-server/src/actors/router/mod.rs | 79 +++++++++++++++++++ .../src/actors/user_agent/auth.rs | 6 +- .../src/actors/user_agent/auth/state.rs | 6 +- .../src/actors/user_agent/mod.rs | 23 +++--- .../src/actors/user_agent/session.rs | 24 ++++-- server/crates/arbiter-server/src/lib.rs | 18 ++++- .../arbiter-server/tests/client/auth.rs | 12 ++- .../arbiter-server/tests/user_agent/auth.rs | 8 +- 13 files changed, 185 insertions(+), 54 deletions(-) create mode 100644 server/crates/arbiter-server/src/actors/router/mod.rs diff --git a/server/crates/arbiter-server/src/actors/client/auth/mod.rs b/server/crates/arbiter-server/src/actors/client/auth/mod.rs index 33a9826..06b9d29 100644 --- a/server/crates/arbiter-server/src/actors/client/auth/mod.rs +++ b/server/crates/arbiter-server/src/actors/client/auth/mod.rs @@ -6,7 +6,7 @@ use ed25519_dalek::VerifyingKey; use tracing::error; use crate::actors::client::{ - ConnectionProps, + ClientConnection, auth::state::{AuthContext, AuthStateMachine}, session::ClientSession, }; @@ -54,7 +54,7 @@ fn parse_auth_event(payload: ClientRequestPayload) -> Result } } -pub async fn authenticate(props: &mut ConnectionProps) -> Result { +pub async fn authenticate(props: &mut ClientConnection) -> Result { let mut state = AuthStateMachine::new(AuthContext::new(props)); loop { @@ -93,7 +93,7 @@ pub async fn authenticate(props: &mut ConnectionProps) -> Result Result { let key = authenticate(&mut props).await?; let session = ClientSession::new(props, key); diff --git a/server/crates/arbiter-server/src/actors/client/auth/state.rs b/server/crates/arbiter-server/src/actors/client/auth/state.rs index 550934f..bfa2dc3 100644 --- a/server/crates/arbiter-server/src/actors/client/auth/state.rs +++ b/server/crates/arbiter-server/src/actors/client/auth/state.rs @@ -8,7 +8,7 @@ use ed25519_dalek::VerifyingKey; use tracing::error; use super::Error; -use crate::{actors::client::ConnectionProps, db::schema}; +use crate::{actors::client::ClientConnection, db::schema}; pub struct ChallengeRequest { pub pubkey: VerifyingKey, @@ -68,11 +68,11 @@ async fn create_nonce(db: &crate::db::DatabasePool, pubkey_bytes: &[u8]) -> Resu } pub struct AuthContext<'a> { - pub(super) conn: &'a mut ConnectionProps, + pub(super) conn: &'a mut ClientConnection, } impl<'a> AuthContext<'a> { - pub fn new(conn: &'a mut ConnectionProps) -> Self { + pub fn new(conn: &'a mut ClientConnection) -> Self { Self { conn } } } diff --git a/server/crates/arbiter-server/src/actors/client/mod.rs b/server/crates/arbiter-server/src/actors/client/mod.rs index 4a9f131..dd7bcc5 100644 --- a/server/crates/arbiter-server/src/actors/client/mod.rs +++ b/server/crates/arbiter-server/src/actors/client/mod.rs @@ -5,7 +5,10 @@ use arbiter_proto::{ use kameo::actor::Spawn; use tracing::{error, info}; -use crate::{actors::client::session::ClientSession, db}; +use crate::{ + actors::{GlobalActors, client::session::ClientSession}, + db, +}; #[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)] pub enum ClientError { @@ -15,27 +18,34 @@ pub enum ClientError { UnexpectedRequestPayload, #[error("State machine error")] StateTransitionFailed, + #[error("Connection registration failed")] + ConnectionRegistrationFailed, #[error(transparent)] Auth(#[from] auth::Error), } pub type Transport = Box> + Send>; -pub struct ConnectionProps { +pub struct ClientConnection { pub(crate) db: db::DatabasePool, pub(crate) transport: Transport, + pub(crate) actors: GlobalActors, } -impl ConnectionProps { - pub fn new(db: db::DatabasePool, transport: Transport) -> Self { - Self { db, transport } +impl ClientConnection { + pub fn new(db: db::DatabasePool, transport: Transport, actors: GlobalActors) -> Self { + Self { + db, + transport, + actors, + } } } pub mod auth; pub mod session; -pub async fn connect_client(props: ConnectionProps) { +pub async fn connect_client(props: ClientConnection) { match auth::authenticate_and_create(props).await { Ok(session) => { ClientSession::spawn(session); diff --git a/server/crates/arbiter-server/src/actors/client/session.rs b/server/crates/arbiter-server/src/actors/client/session.rs index 29a0dd6..a0d21ca 100644 --- a/server/crates/arbiter-server/src/actors/client/session.rs +++ b/server/crates/arbiter-server/src/actors/client/session.rs @@ -4,15 +4,17 @@ use kameo::Actor; use tokio::select; use tracing::{error, info}; -use crate::actors::client::{ClientError, ConnectionProps}; +use crate::{actors::{ + GlobalActors, client::{ClientError, ClientConnection}, router::RegisterClient +}, db}; pub struct ClientSession { - props: ConnectionProps, + props: ClientConnection, key: VerifyingKey, } impl ClientSession { - pub(crate) fn new(props: ConnectionProps, key: VerifyingKey) -> Self { + pub(crate) fn new(props: ClientConnection, key: VerifyingKey) -> Self { Self { props, key } } @@ -33,12 +35,18 @@ type Output = Result; impl Actor for ClientSession { type Args = Self; - type Error = (); + type Error = ClientError; async fn on_start( args: Self::Args, - _: kameo::prelude::ActorRef, + this: kameo::prelude::ActorRef, ) -> Result { + args.props + .actors + .router + .ask(RegisterClient { actor: this }) + .await + .map_err(|_| ClientError::ConnectionRegistrationFailed)?; Ok(args) } @@ -80,10 +88,10 @@ impl Actor for ClientSession { } impl ClientSession { - pub fn new_test(db: crate::db::DatabasePool) -> Self { + pub fn new_test(db: db::DatabasePool, actors: GlobalActors) -> Self { use arbiter_proto::transport::DummyTransport; let transport: super::Transport = Box::new(DummyTransport::new()); - let props = ConnectionProps::new(db, transport); + let props = ClientConnection::new(db, transport, actors); let key = VerifyingKey::from_bytes(&[0u8; 32]).unwrap(); Self { props, key } } diff --git a/server/crates/arbiter-server/src/actors/mod.rs b/server/crates/arbiter-server/src/actors/mod.rs index 80ca4dd..33bdb5e 100644 --- a/server/crates/arbiter-server/src/actors/mod.rs +++ b/server/crates/arbiter-server/src/actors/mod.rs @@ -3,14 +3,15 @@ use miette::Diagnostic; use thiserror::Error; use crate::{ - actors::{bootstrap::Bootstrapper, keyholder::KeyHolder}, + actors::{bootstrap::Bootstrapper, keyholder::KeyHolder, router::MessageRouter}, db, }; pub mod bootstrap; -pub mod client; +pub mod router; pub mod keyholder; pub mod user_agent; +pub mod client; #[derive(Error, Debug, Diagnostic)] pub enum SpawnError { @@ -28,6 +29,7 @@ pub enum SpawnError { pub struct GlobalActors { pub key_holder: ActorRef, pub bootstrapper: ActorRef, + pub router: ActorRef, } impl GlobalActors { @@ -35,6 +37,7 @@ impl GlobalActors { Ok(Self { bootstrapper: Bootstrapper::spawn(Bootstrapper::new(&db).await?), key_holder: KeyHolder::spawn(KeyHolder::new(db.clone()).await?), + router: MessageRouter::spawn(MessageRouter::default()), }) } } diff --git a/server/crates/arbiter-server/src/actors/router/mod.rs b/server/crates/arbiter-server/src/actors/router/mod.rs new file mode 100644 index 0000000..966e1ce --- /dev/null +++ b/server/crates/arbiter-server/src/actors/router/mod.rs @@ -0,0 +1,79 @@ +use std::{ + collections::{HashMap}, + ops::ControlFlow, +}; + +use kameo::{ + Actor, + actor::{ActorId, ActorRef}, + messages, + prelude::{ActorStopReason, Context, WeakActorRef}, +}; +use tracing::info; + +use crate::actors::{client::session::ClientSession, user_agent::session::UserAgentSession}; + +#[derive(Default)] +pub struct MessageRouter { + pub user_agents: HashMap>, + pub clients: HashMap>, +} + +impl Actor for MessageRouter { + type Args = Self; + + type Error = (); + + async fn on_start(args: Self::Args, _: ActorRef) -> Result { + Ok(args) + } + + async fn on_link_died( + &mut self, + _: WeakActorRef, + id: ActorId, + _: ActorStopReason, + ) -> Result, Self::Error> { + if self.user_agents.remove(&id).is_some() { + info!( + ?id, + actor = "MessageRouter", + event = "useragent.disconnected" + ); + } else if self.clients.remove(&id).is_some() { + info!(?id, actor = "MessageRouter", event = "client.disconnected"); + } else { + info!( + ?id, + actor = "MessageRouter", + event = "unknown.actor.disconnected" + ); + } + Ok(ControlFlow::Continue(())) + } +} + +#[messages] +impl MessageRouter { + #[message(ctx)] + pub async fn register_user_agent( + &mut self, + actor: ActorRef, + ctx: &mut Context, + ) { + info!(id = %actor.id(), actor = "MessageRouter", event = "useragent.connected"); + ctx.actor_ref().link(&actor).await; + self.user_agents.insert(actor.id(), actor); + } + + #[message(ctx)] + pub async fn register_client( + &mut self, + actor: ActorRef, + ctx: &mut Context, + ) { + info!(id = %actor.id(), actor = "MessageRouter", event = "client.connected"); + ctx.actor_ref().link(&actor).await; + self.clients.insert(actor.id(), actor); + } +} diff --git a/server/crates/arbiter-server/src/actors/user_agent/auth.rs b/server/crates/arbiter-server/src/actors/user_agent/auth.rs index b34ed38..543dc87 100644 --- a/server/crates/arbiter-server/src/actors/user_agent/auth.rs +++ b/server/crates/arbiter-server/src/actors/user_agent/auth.rs @@ -6,7 +6,7 @@ use ed25519_dalek::VerifyingKey; use tracing::error; use crate::actors::user_agent::{ - ConnectionProps, + UserAgentConnection, auth::state::{AuthContext, AuthStateMachine}, session::UserAgentSession, }; @@ -71,7 +71,7 @@ fn parse_auth_event(payload: UserAgentRequestPayload) -> Result Result { +pub async fn authenticate(props: &mut UserAgentConnection) -> Result { let mut state = AuthStateMachine::new(AuthContext::new(props)); loop { @@ -111,7 +111,7 @@ pub async fn authenticate(props: &mut ConnectionProps) -> Result Result { +pub async fn authenticate_and_create(mut props: UserAgentConnection) -> Result { let key = authenticate(&mut props).await?; let session = UserAgentSession::new(props, key.clone()); Ok(session) diff --git a/server/crates/arbiter-server/src/actors/user_agent/auth/state.rs b/server/crates/arbiter-server/src/actors/user_agent/auth/state.rs index aa39bb6..9a4cf0c 100644 --- a/server/crates/arbiter-server/src/actors/user_agent/auth/state.rs +++ b/server/crates/arbiter-server/src/actors/user_agent/auth/state.rs @@ -9,7 +9,7 @@ use tracing::error; use super::Error; use crate::{ - actors::{bootstrap::ConsumeToken, user_agent::ConnectionProps}, + actors::{bootstrap::ConsumeToken, user_agent::UserAgentConnection}, db::schema, }; @@ -98,11 +98,11 @@ async fn register_key(db: &crate::db::DatabasePool, pubkey_bytes: &[u8]) -> Resu } pub struct AuthContext<'a> { - pub(super) conn: &'a mut ConnectionProps, + pub(super) conn: &'a mut UserAgentConnection, } impl<'a> AuthContext<'a> { - pub fn new(conn: &'a mut ConnectionProps) -> Self { + pub fn new(conn: &'a mut UserAgentConnection) -> Self { Self { conn } } } diff --git a/server/crates/arbiter-server/src/actors/user_agent/mod.rs b/server/crates/arbiter-server/src/actors/user_agent/mod.rs index f9ca3cc..2043b27 100644 --- a/server/crates/arbiter-server/src/actors/user_agent/mod.rs +++ b/server/crates/arbiter-server/src/actors/user_agent/mod.rs @@ -2,10 +2,13 @@ use arbiter_proto::{ proto::user_agent::{UserAgentRequest, UserAgentResponse}, transport::Bi, }; -use kameo::actor::Spawn; +use kameo::actor::Spawn as _; use tracing::{error, info}; -use crate::{actors::{GlobalActors, user_agent::session::UserAgentSession}, db}; +use crate::{ + actors::{GlobalActors, user_agent::session::UserAgentSession}, + db::{self}, +}; #[derive(Debug, thiserror::Error, PartialEq)] pub enum UserAgentError { @@ -23,18 +26,20 @@ pub enum UserAgentError { KeyHolderActorUnreachable, #[error(transparent)] Auth(#[from] auth::Error), + #[error("Failed registering connection")] + ConnectionRegistrationFailed, } pub type Transport = Box> + Send>; -pub struct ConnectionProps { +pub struct UserAgentConnection { db: db::DatabasePool, actors: GlobalActors, transport: Transport, } -impl ConnectionProps { +impl UserAgentConnection { pub fn new(db: db::DatabasePool, actors: GlobalActors, transport: Transport) -> Self { Self { db, @@ -44,17 +49,17 @@ impl ConnectionProps { } } -pub mod session; pub mod auth; +pub mod session; -pub async fn connect_user_agent(mut props: ConnectionProps) { - match auth::authenticate_and_create( props).await { +pub async fn connect_user_agent(props: UserAgentConnection) { + match auth::authenticate_and_create(props).await { Ok(session) => { UserAgentSession::spawn(session); info!("User authenticated, session started"); - }, + } Err(err) => { error!(?err, "Authentication failed, closing connection"); - }, + } } } diff --git a/server/crates/arbiter-server/src/actors/user_agent/session.rs b/server/crates/arbiter-server/src/actors/user_agent/session.rs index 6cc0808..04d3260 100644 --- a/server/crates/arbiter-server/src/actors/user_agent/session.rs +++ b/server/crates/arbiter-server/src/actors/user_agent/session.rs @@ -15,20 +15,21 @@ use x25519_dalek::{EphemeralSecret, PublicKey}; use crate::actors::{ keyholder::{self, TryUnseal}, - user_agent::{ConnectionProps, UserAgentError}, + router::RegisterUserAgent, + user_agent::{UserAgentConnection, UserAgentError}, }; mod state; use state::{DummyContext, UnsealContext, UserAgentEvents, UserAgentStateMachine, UserAgentStates}; pub struct UserAgentSession { - props: ConnectionProps, + props: UserAgentConnection, key: VerifyingKey, state: UserAgentStateMachine, } impl UserAgentSession { - pub(crate) fn new(props: ConnectionProps, key: VerifyingKey) -> Self { + pub(crate) fn new(props: UserAgentConnection, key: VerifyingKey) -> Self { Self { props, key, @@ -180,12 +181,23 @@ impl UserAgentSession { impl Actor for UserAgentSession { type Args = Self; - type Error = (); + type Error = UserAgentError; async fn on_start( args: Self::Args, - _: kameo::prelude::ActorRef, + this: kameo::prelude::ActorRef, ) -> Result { + args.props + .actors + .router + .ask(RegisterUserAgent { + actor: this.clone(), + }) + .await + .map_err(|err| { + error!(?err, "Failed to register user agent connection with router"); + UserAgentError::ConnectionRegistrationFailed + })?; Ok(args) } @@ -230,7 +242,7 @@ impl UserAgentSession { pub fn new_test(db: crate::db::DatabasePool, actors: crate::actors::GlobalActors) -> Self { use arbiter_proto::transport::DummyTransport; let transport: super::Transport = Box::new(DummyTransport::new()); - let props = ConnectionProps::new(db, actors, transport); + let props = UserAgentConnection::new(db, actors, transport); let key = VerifyingKey::from_bytes(&[0u8; 32]).unwrap(); Self { props, diff --git a/server/crates/arbiter-server/src/lib.rs b/server/crates/arbiter-server/src/lib.rs index 1d7fa97..59aeb9f 100644 --- a/server/crates/arbiter-server/src/lib.rs +++ b/server/crates/arbiter-server/src/lib.rs @@ -15,8 +15,8 @@ use tracing::info; use crate::{ actors::{ - client::{self, ClientError, ConnectionProps as ClientConnectionProps, connect_client}, - user_agent::{self, ConnectionProps, UserAgentError, connect_user_agent}, + client::{self, ClientError, ClientConnection as ClientConnectionProps, connect_client}, + user_agent::{self, UserAgentConnection, UserAgentError, connect_user_agent}, }, context::ServerContext, }; @@ -62,6 +62,9 @@ fn client_error_status(value: ClientError) -> Status { } ClientError::StateTransitionFailed => Status::internal("State machine error"), ClientError::Auth(ref err) => client_auth_error_status(err), + ClientError::ConnectionRegistrationFailed => { + Status::internal("Connection registration failed") + } } } @@ -98,6 +101,9 @@ fn user_agent_error_status(value: UserAgentError) -> Status { UserAgentError::StateTransitionFailed => Status::internal("State machine error"), UserAgentError::KeyHolderActorUnreachable => Status::internal("Vault is not available"), UserAgentError::Auth(ref err) => auth_error_status(err), + UserAgentError::ConnectionRegistrationFailed => { + Status::internal("Failed registering connection") + } } } @@ -152,7 +158,11 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server { IdentityRecvConverter::::new(), ClientGrpcSender, ); - let props = ClientConnectionProps::new(self.context.db.clone(), Box::new(transport)); + let props = ClientConnectionProps::new( + self.context.db.clone(), + Box::new(transport), + self.context.actors.clone(), + ); tokio::spawn(connect_client(props)); info!(event = "connection established", "grpc.client"); @@ -174,7 +184,7 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server { IdentityRecvConverter::::new(), UserAgentGrpcSender, ); - let props = ConnectionProps::new( + let props = UserAgentConnection::new( self.context.db.clone(), self.context.actors.clone(), Box::new(transport), diff --git a/server/crates/arbiter-server/tests/client/auth.rs b/server/crates/arbiter-server/tests/client/auth.rs index d7577a6..6228a58 100644 --- a/server/crates/arbiter-server/tests/client/auth.rs +++ b/server/crates/arbiter-server/tests/client/auth.rs @@ -1,11 +1,12 @@ -use arbiter_proto::transport::Bi; use arbiter_proto::proto::client::{ AuthChallengeRequest, AuthChallengeSolution, ClientRequest, client_request::Payload as ClientRequestPayload, client_response::Payload as ClientResponsePayload, }; +use arbiter_proto::transport::Bi; +use arbiter_server::actors::GlobalActors; use arbiter_server::{ - actors::client::{ConnectionProps, connect_client}, + actors::client::{ClientConnection, connect_client}, db::{self, schema}, }; use diesel::{ExpressionMethods as _, insert_into}; @@ -20,7 +21,8 @@ pub async fn test_unregistered_pubkey_rejected() { let db = db::create_test_pool().await; let (server_transport, mut test_transport) = ChannelTransport::new(); - let props = ConnectionProps::new(db.clone(), Box::new(server_transport)); + let actors = GlobalActors::spawn(db.clone()).await.unwrap(); + let props = ClientConnection::new(db.clone(), Box::new(server_transport), actors); let task = tokio::spawn(connect_client(props)); let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); @@ -59,7 +61,9 @@ pub async fn test_challenge_auth() { } let (server_transport, mut test_transport) = ChannelTransport::new(); - let props = ConnectionProps::new(db.clone(), Box::new(server_transport)); + let actors = GlobalActors::spawn(db.clone()).await.unwrap(); + + let props = ClientConnection::new(db.clone(), Box::new(server_transport), actors); let task = tokio::spawn(connect_client(props)); // Send challenge request diff --git a/server/crates/arbiter-server/tests/user_agent/auth.rs b/server/crates/arbiter-server/tests/user_agent/auth.rs index 2704ae6..17af990 100644 --- a/server/crates/arbiter-server/tests/user_agent/auth.rs +++ b/server/crates/arbiter-server/tests/user_agent/auth.rs @@ -8,7 +8,7 @@ use arbiter_server::{ actors::{ GlobalActors, bootstrap::GetToken, - user_agent::{ConnectionProps, connect_user_agent}, + user_agent::{UserAgentConnection, connect_user_agent}, }, db::{self, schema}, }; @@ -26,7 +26,7 @@ pub async fn test_bootstrap_token_auth() { let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap(); let (server_transport, mut test_transport) = ChannelTransport::new(); - let props = ConnectionProps::new(db.clone(), actors, Box::new(server_transport)); + let props = UserAgentConnection::new(db.clone(), actors, Box::new(server_transport)); let task = tokio::spawn(connect_user_agent(props)); let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); @@ -62,7 +62,7 @@ pub async fn test_bootstrap_invalid_token_auth() { let actors = GlobalActors::spawn(db.clone()).await.unwrap(); let (server_transport, mut test_transport) = ChannelTransport::new(); - let props = ConnectionProps::new(db.clone(), actors, Box::new(server_transport)); + let props = UserAgentConnection::new(db.clone(), actors, Box::new(server_transport)); let task = tokio::spawn(connect_user_agent(props)); let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); @@ -112,7 +112,7 @@ pub async fn test_challenge_auth() { } let (server_transport, mut test_transport) = ChannelTransport::new(); - let props = ConnectionProps::new(db.clone(), actors, Box::new(server_transport)); + let props = UserAgentConnection::new(db.clone(), actors, Box::new(server_transport)); let task = tokio::spawn(connect_user_agent(props)); // Send challenge request