diff --git a/server/crates/arbiter-server/src/actors/router/mod.rs b/server/crates/arbiter-server/src/actors/router/mod.rs index 966e1ce..fbac74c 100644 --- a/server/crates/arbiter-server/src/actors/router/mod.rs +++ b/server/crates/arbiter-server/src/actors/router/mod.rs @@ -1,17 +1,20 @@ -use std::{ - collections::{HashMap}, - ops::ControlFlow, -}; +use std::{collections::HashMap, ops::ControlFlow}; +use ed25519_dalek::VerifyingKey; use kameo::{ Actor, actor::{ActorId, ActorRef}, messages, prelude::{ActorStopReason, Context, WeakActorRef}, + reply::DelegatedReply, }; -use tracing::info; +use tokio::{sync::watch, task::JoinSet}; +use tracing::{info, warn}; -use crate::actors::{client::session::ClientSession, user_agent::session::UserAgentSession}; +use crate::actors::{ + client::session::ClientSession, + user_agent::session::{RequestNewClientApproval, UserAgentSession}, +}; #[derive(Default)] pub struct MessageRouter { @@ -53,6 +56,74 @@ impl Actor for MessageRouter { } } +#[derive(Debug, thiserror::Error)] +pub enum ApprovalError { + #[error("No user agents connected")] + NoUserAgentsConnected, +} + +async fn request_client_approval( + user_agents: &[WeakActorRef], + client_pubkey: VerifyingKey, +) -> Result { + if user_agents.is_empty() { + return Err(ApprovalError::NoUserAgentsConnected).into(); + } + + let mut pool = JoinSet::new(); + let (cancel_tx, cancel_rx) = watch::channel(()); + + for weak_ref in user_agents { + match weak_ref.upgrade() { + Some(agent) => { + let client_pubkey = client_pubkey.clone(); + let cancel_rx = cancel_rx.clone(); + pool.spawn(async move { + agent + .ask(RequestNewClientApproval { + client_pubkey, + cancel_flag: cancel_rx.clone(), + }) + .await + }); + } + None => { + warn!( + id = weak_ref.id().to_string(), + actor = "MessageRouter", + event = "useragent.disconnected_before_approval" + ); + } + } + } + + while let Some(result) = pool.join_next().await { + match result { + Ok(Ok(approved)) => { + // cancel other pending requests + let _ = cancel_tx.send(()); + return Ok(approved); + } + Ok(Err(err)) => { + warn!( + ?err, + actor = "MessageRouter", + event = "useragent.approval_error" + ); + } + Err(err) => { + warn!( + ?err, + actor = "MessageRouter", + event = "useragent.approval_task_failed" + ); + } + } + } + + Err(ApprovalError::NoUserAgentsConnected) +} + #[messages] impl MessageRouter { #[message(ctx)] @@ -76,4 +147,29 @@ impl MessageRouter { ctx.actor_ref().link(&actor).await; self.clients.insert(actor.id(), actor); } + + #[message(ctx)] + pub async fn request_client_approval( + &mut self, + client_pubkey: VerifyingKey, + ctx: &mut Context>>, + ) -> DelegatedReply> { + let (reply, Some(reply_sender)) = ctx.reply_sender() else { + panic!("Exptected `request_client_approval` to have callback channel"); + }; + + let weak_refs = self + .user_agents + .values() + .map(|agent| agent.downgrade()) + .collect::>(); + + // handle in subtask to not to lock the actor + tokio::task::spawn(async move { + let result = request_client_approval(&weak_refs, client_pubkey).await; + let _ = reply_sender.send(result); + }); + + reply + } } diff --git a/server/crates/arbiter-server/src/actors/user_agent/mod.rs b/server/crates/arbiter-server/src/actors/user_agent/mod.rs index 2043b27..4380b72 100644 --- a/server/crates/arbiter-server/src/actors/user_agent/mod.rs +++ b/server/crates/arbiter-server/src/actors/user_agent/mod.rs @@ -11,7 +11,7 @@ use crate::{ }; #[derive(Debug, thiserror::Error, PartialEq)] -pub enum UserAgentError { +pub enum TransportResponseError { #[error("Expected message with payload")] MissingRequestPayload, #[error("Unexpected request payload")] @@ -31,7 +31,7 @@ pub enum UserAgentError { } pub type Transport = - Box> + Send>; + Box> + Send>; pub struct UserAgentConnection { db: db::DatabasePool, diff --git a/server/crates/arbiter-server/src/actors/user_agent/session.rs b/server/crates/arbiter-server/src/actors/user_agent/session.rs index 04d3260..401b3f6 100644 --- a/server/crates/arbiter-server/src/actors/user_agent/session.rs +++ b/server/crates/arbiter-server/src/actors/user_agent/session.rs @@ -1,27 +1,41 @@ use std::{ops::DerefMut, sync::Mutex}; -use arbiter_proto::proto::user_agent::{ - UnsealEncryptedKey, UnsealResult, UnsealStart, UnsealStartResponse, UserAgentRequest, - UserAgentResponse, user_agent_request::Payload as UserAgentRequestPayload, - user_agent_response::Payload as UserAgentResponsePayload, +use arbiter_proto::proto::{ + client, + user_agent::{ + ClientConnectionCancel, ClientConnectionRequest, UnsealEncryptedKey, UnsealResult, + UnsealStart, UnsealStartResponse, UserAgentRequest, UserAgentResponse, + user_agent_request::Payload as UserAgentRequestPayload, + user_agent_response::Payload as UserAgentResponsePayload, + }, }; use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit}; use ed25519_dalek::VerifyingKey; -use kameo::{Actor, error::SendError}; +use kameo::{Actor, error::SendError, message, messages, prelude::Context}; use memsafe::MemSafe; -use tokio::select; +use tokio::{select, sync::watch}; use tracing::{error, info}; use x25519_dalek::{EphemeralSecret, PublicKey}; use crate::actors::{ keyholder::{self, TryUnseal}, router::RegisterUserAgent, - user_agent::{UserAgentConnection, UserAgentError}, + user_agent::{TransportResponseError, UserAgentConnection}, }; mod state; use state::{DummyContext, UnsealContext, UserAgentEvents, UserAgentStateMachine, UserAgentStates}; +// Error for consumption by other actors +#[derive(Debug, thiserror::Error, PartialEq)] +pub enum Error { + #[error("User agent session ended due to connection loss")] + ConnectionLost, + + #[error("User agent session ended due to unexpected message")] + UnexpectedMessage, +} + pub struct UserAgentSession { props: UserAgentConnection, key: VerifyingKey, @@ -29,7 +43,7 @@ pub struct UserAgentSession { } impl UserAgentSession { - pub(crate) fn new(props: UserAgentConnection, key: VerifyingKey) -> Self { + pub(crate) fn new(props: UserAgentConnection, key: VerifyingKey) -> Self { Self { props, key, @@ -37,18 +51,119 @@ impl UserAgentSession { } } - fn transition(&mut self, event: UserAgentEvents) -> Result<(), UserAgentError> { + fn transition(&mut self, event: UserAgentEvents) -> Result<(), TransportResponseError> { self.state.process_event(event).map_err(|e| { error!(?e, "State transition failed"); - UserAgentError::StateTransitionFailed + TransportResponseError::StateTransitionFailed })?; Ok(()) } + async fn send_msg( + &mut self, + msg: UserAgentResponsePayload, + ctx: &mut Context, + ) -> Result<(), Error> { + self.props + .transport + .send(Ok(response(msg))) + .await + .map_err(|_| { + error!( + actor = "useragent", + reason = "channel closed", + "send.failed" + ); + Error::ConnectionLost + }) + } + + async fn expect_msg( + &mut self, + extractor: Extractor, + ctx: &mut Context, + ) -> Result + where + Extractor: FnOnce(UserAgentRequestPayload) -> Option, + Reply: kameo::Reply, + { + let msg = self.props.transport.recv().await.ok_or_else(|| { + error!( + actor = "useragent", + reason = "channel closed", + "recv.failed" + ); + ctx.stop(); + Error::ConnectionLost + })?; + + msg.payload.and_then(extractor).ok_or_else(|| { + error!( + actor = "useragent", + reason = "unexpected message", + "recv.failed" + ); + ctx.stop(); + Error::UnexpectedMessage + }) + } +} + +#[messages] +impl UserAgentSession { + + // TODO: Think about refactoring it to state-machine based flow, as we already have one + #[message(ctx)] + pub async fn request_new_client_approval( + &mut self, + client_pubkey: VerifyingKey, + mut cancel_flag: watch::Receiver<()>, + ctx: &mut Context>, + ) -> Result { + self.send_msg( + UserAgentResponsePayload::ClientConnectionRequest( + ClientConnectionRequest { + pubkey: client_pubkey.as_bytes().to_vec(), + } + .into(), + ), + ctx, + ) + .await?; + + let extractor = |msg| { + if let UserAgentRequestPayload::ClientConnectionResponse(client_connection_response) = + msg + { + Some(client_connection_response) + } else { + None + } + }; + + tokio::select! { + _ = cancel_flag.changed() => { + info!(actor = "useragent", "client connection approval cancelled"); + self.send_msg( + UserAgentResponsePayload::ClientConnectionCancel(ClientConnectionCancel {}), + ctx, + ).await?; + return Ok(false); + } + result = self.expect_msg(extractor, ctx) => { + let result = result?; + info!(actor = "useragent", "received client connection approval result: approved={}", result.approved); + return Ok(result.approved); + } + } + } +} + +impl UserAgentSession { pub async fn process_transport_inbound(&mut self, req: UserAgentRequest) -> Output { let msg = req.payload.ok_or_else(|| { error!(actor = "useragent", "Received message with no payload"); - UserAgentError::MissingRequestPayload + TransportResponseError::MissingRequestPayload })?; match msg { @@ -58,12 +173,12 @@ impl UserAgentSession { UserAgentRequestPayload::UnsealEncryptedKey(unseal_encrypted_key) => { self.handle_unseal_encrypted_key(unseal_encrypted_key).await } - _ => Err(UserAgentError::UnexpectedRequestPayload), + _ => Err(TransportResponseError::UnexpectedRequestPayload), } } } -type Output = Result; +type Output = Result; fn response(payload: UserAgentResponsePayload) -> UserAgentResponse { UserAgentResponse { @@ -79,7 +194,7 @@ impl UserAgentSession { let client_pubkey_bytes: [u8; 32] = req .client_pubkey .try_into() - .map_err(|_| UserAgentError::InvalidClientPubkeyLength)?; + .map_err(|_| TransportResponseError::InvalidClientPubkeyLength)?; let client_public_key = PublicKey::from(client_pubkey_bytes); @@ -98,7 +213,7 @@ impl UserAgentSession { async fn handle_unseal_encrypted_key(&mut self, req: UnsealEncryptedKey) -> Output { let UserAgentStates::WaitingForUnsealKey(unseal_context) = self.state.state() else { error!("Received unseal encrypted key in invalid state"); - return Err(UserAgentError::InvalidStateForUnsealEncryptedKey); + return Err(TransportResponseError::InvalidStateForUnsealEncryptedKey); }; let ephemeral_secret = { let mut secret_lock = unseal_context.secret.lock().unwrap(); @@ -163,7 +278,7 @@ impl UserAgentSession { Err(err) => { error!(?err, "Failed to send unseal request to keyholder"); self.transition(UserAgentEvents::ReceivedInvalidKey)?; - Err(UserAgentError::KeyHolderActorUnreachable) + Err(TransportResponseError::KeyHolderActorUnreachable) } } } @@ -181,7 +296,7 @@ impl UserAgentSession { impl Actor for UserAgentSession { type Args = Self; - type Error = UserAgentError; + type Error = TransportResponseError; async fn on_start( args: Self::Args, @@ -196,7 +311,7 @@ impl Actor for UserAgentSession { .await .map_err(|err| { error!(?err, "Failed to register user agent connection with router"); - UserAgentError::ConnectionRegistrationFailed + TransportResponseError::ConnectionRegistrationFailed })?; Ok(args) } diff --git a/server/crates/arbiter-server/src/lib.rs b/server/crates/arbiter-server/src/lib.rs index 59aeb9f..5af93a0 100644 --- a/server/crates/arbiter-server/src/lib.rs +++ b/server/crates/arbiter-server/src/lib.rs @@ -16,7 +16,7 @@ use tracing::info; use crate::{ actors::{ client::{self, ClientError, ClientConnection as ClientConnectionProps, connect_client}, - user_agent::{self, UserAgentConnection, UserAgentError, connect_user_agent}, + user_agent::{self, UserAgentConnection, TransportResponseError, connect_user_agent}, }, context::ServerContext, }; @@ -30,7 +30,7 @@ const DEFAULT_CHANNEL_SIZE: usize = 1000; struct UserAgentGrpcSender; impl SendConverter for UserAgentGrpcSender { - type Input = Result; + type Input = Result; type Output = Result; fn convert(&self, item: Self::Input) -> Self::Output { @@ -87,21 +87,21 @@ fn client_auth_error_status(value: &client::auth::Error) -> Status { } } -fn user_agent_error_status(value: UserAgentError) -> Status { +fn user_agent_error_status(value: TransportResponseError) -> Status { match value { - UserAgentError::MissingRequestPayload | UserAgentError::UnexpectedRequestPayload => { + TransportResponseError::MissingRequestPayload | TransportResponseError::UnexpectedRequestPayload => { Status::invalid_argument("Expected message with payload") } - UserAgentError::InvalidStateForUnsealEncryptedKey => { + TransportResponseError::InvalidStateForUnsealEncryptedKey => { Status::failed_precondition("Invalid state for unseal encrypted key") } - UserAgentError::InvalidClientPubkeyLength => { + TransportResponseError::InvalidClientPubkeyLength => { Status::invalid_argument("client_pubkey must be 32 bytes") } - UserAgentError::StateTransitionFailed => Status::internal("State machine error"), - UserAgentError::KeyHolderActorUnreachable => Status::internal("Vault is not available"), - UserAgentError::Auth(ref err) => auth_error_status(err), - UserAgentError::ConnectionRegistrationFailed => { + TransportResponseError::StateTransitionFailed => Status::internal("State machine error"), + TransportResponseError::KeyHolderActorUnreachable => Status::internal("Vault is not available"), + TransportResponseError::Auth(ref err) => auth_error_status(err), + TransportResponseError::ConnectionRegistrationFailed => { Status::internal("Failed registering connection") } }