feat(server::{router, useragent}): inter-actor approval coordination

This commit is contained in:
hdbg
2026-03-11 17:59:32 +01:00
parent b3a67ffc00
commit 606a1f3774
4 changed files with 247 additions and 36 deletions

View File

@@ -1,17 +1,20 @@
use std::{ use std::{collections::HashMap, ops::ControlFlow};
collections::{HashMap},
ops::ControlFlow,
};
use ed25519_dalek::VerifyingKey;
use kameo::{ use kameo::{
Actor, Actor,
actor::{ActorId, ActorRef}, actor::{ActorId, ActorRef},
messages, messages,
prelude::{ActorStopReason, Context, WeakActorRef}, prelude::{ActorStopReason, Context, WeakActorRef},
reply::DelegatedReply,
}; };
use tracing::info; use tokio::{sync::watch, task::JoinSet};
use tracing::{info, warn};
use crate::actors::{client::session::ClientSession, user_agent::session::UserAgentSession}; use crate::actors::{
client::session::ClientSession,
user_agent::session::{RequestNewClientApproval, UserAgentSession},
};
#[derive(Default)] #[derive(Default)]
pub struct MessageRouter { pub struct MessageRouter {
@@ -53,6 +56,74 @@ impl Actor for MessageRouter {
} }
} }
#[derive(Debug, thiserror::Error)]
pub enum ApprovalError {
#[error("No user agents connected")]
NoUserAgentsConnected,
}
async fn request_client_approval(
user_agents: &[WeakActorRef<UserAgentSession>],
client_pubkey: VerifyingKey,
) -> Result<bool, ApprovalError> {
if user_agents.is_empty() {
return Err(ApprovalError::NoUserAgentsConnected).into();
}
let mut pool = JoinSet::new();
let (cancel_tx, cancel_rx) = watch::channel(());
for weak_ref in user_agents {
match weak_ref.upgrade() {
Some(agent) => {
let client_pubkey = client_pubkey.clone();
let cancel_rx = cancel_rx.clone();
pool.spawn(async move {
agent
.ask(RequestNewClientApproval {
client_pubkey,
cancel_flag: cancel_rx.clone(),
})
.await
});
}
None => {
warn!(
id = weak_ref.id().to_string(),
actor = "MessageRouter",
event = "useragent.disconnected_before_approval"
);
}
}
}
while let Some(result) = pool.join_next().await {
match result {
Ok(Ok(approved)) => {
// cancel other pending requests
let _ = cancel_tx.send(());
return Ok(approved);
}
Ok(Err(err)) => {
warn!(
?err,
actor = "MessageRouter",
event = "useragent.approval_error"
);
}
Err(err) => {
warn!(
?err,
actor = "MessageRouter",
event = "useragent.approval_task_failed"
);
}
}
}
Err(ApprovalError::NoUserAgentsConnected)
}
#[messages] #[messages]
impl MessageRouter { impl MessageRouter {
#[message(ctx)] #[message(ctx)]
@@ -76,4 +147,29 @@ impl MessageRouter {
ctx.actor_ref().link(&actor).await; ctx.actor_ref().link(&actor).await;
self.clients.insert(actor.id(), actor); self.clients.insert(actor.id(), actor);
} }
#[message(ctx)]
pub async fn request_client_approval(
&mut self,
client_pubkey: VerifyingKey,
ctx: &mut Context<Self, DelegatedReply<Result<bool, ApprovalError>>>,
) -> DelegatedReply<Result<bool, ApprovalError>> {
let (reply, Some(reply_sender)) = ctx.reply_sender() else {
panic!("Exptected `request_client_approval` to have callback channel");
};
let weak_refs = self
.user_agents
.values()
.map(|agent| agent.downgrade())
.collect::<Vec<_>>();
// handle in subtask to not to lock the actor
tokio::task::spawn(async move {
let result = request_client_approval(&weak_refs, client_pubkey).await;
let _ = reply_sender.send(result);
});
reply
}
} }

View File

@@ -11,7 +11,7 @@ use crate::{
}; };
#[derive(Debug, thiserror::Error, PartialEq)] #[derive(Debug, thiserror::Error, PartialEq)]
pub enum UserAgentError { pub enum TransportResponseError {
#[error("Expected message with payload")] #[error("Expected message with payload")]
MissingRequestPayload, MissingRequestPayload,
#[error("Unexpected request payload")] #[error("Unexpected request payload")]
@@ -31,7 +31,7 @@ pub enum UserAgentError {
} }
pub type Transport = pub type Transport =
Box<dyn Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>> + Send>; Box<dyn Bi<UserAgentRequest, Result<UserAgentResponse, TransportResponseError>> + Send>;
pub struct UserAgentConnection { pub struct UserAgentConnection {
db: db::DatabasePool, db: db::DatabasePool,

View File

@@ -1,27 +1,41 @@
use std::{ops::DerefMut, sync::Mutex}; use std::{ops::DerefMut, sync::Mutex};
use arbiter_proto::proto::user_agent::{ use arbiter_proto::proto::{
UnsealEncryptedKey, UnsealResult, UnsealStart, UnsealStartResponse, UserAgentRequest, client,
UserAgentResponse, user_agent_request::Payload as UserAgentRequestPayload, user_agent::{
user_agent_response::Payload as UserAgentResponsePayload, ClientConnectionCancel, ClientConnectionRequest, UnsealEncryptedKey, UnsealResult,
UnsealStart, UnsealStartResponse, UserAgentRequest, UserAgentResponse,
user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload,
},
}; };
use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit}; use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit};
use ed25519_dalek::VerifyingKey; use ed25519_dalek::VerifyingKey;
use kameo::{Actor, error::SendError}; use kameo::{Actor, error::SendError, message, messages, prelude::Context};
use memsafe::MemSafe; use memsafe::MemSafe;
use tokio::select; use tokio::{select, sync::watch};
use tracing::{error, info}; use tracing::{error, info};
use x25519_dalek::{EphemeralSecret, PublicKey}; use x25519_dalek::{EphemeralSecret, PublicKey};
use crate::actors::{ use crate::actors::{
keyholder::{self, TryUnseal}, keyholder::{self, TryUnseal},
router::RegisterUserAgent, router::RegisterUserAgent,
user_agent::{UserAgentConnection, UserAgentError}, user_agent::{TransportResponseError, UserAgentConnection},
}; };
mod state; mod state;
use state::{DummyContext, UnsealContext, UserAgentEvents, UserAgentStateMachine, UserAgentStates}; use state::{DummyContext, UnsealContext, UserAgentEvents, UserAgentStateMachine, UserAgentStates};
// Error for consumption by other actors
#[derive(Debug, thiserror::Error, PartialEq)]
pub enum Error {
#[error("User agent session ended due to connection loss")]
ConnectionLost,
#[error("User agent session ended due to unexpected message")]
UnexpectedMessage,
}
pub struct UserAgentSession { pub struct UserAgentSession {
props: UserAgentConnection, props: UserAgentConnection,
key: VerifyingKey, key: VerifyingKey,
@@ -29,7 +43,7 @@ pub struct UserAgentSession {
} }
impl UserAgentSession { impl UserAgentSession {
pub(crate) fn new(props: UserAgentConnection, key: VerifyingKey) -> Self { pub(crate) fn new(props: UserAgentConnection, key: VerifyingKey) -> Self {
Self { Self {
props, props,
key, key,
@@ -37,18 +51,119 @@ impl UserAgentSession {
} }
} }
fn transition(&mut self, event: UserAgentEvents) -> Result<(), UserAgentError> { fn transition(&mut self, event: UserAgentEvents) -> Result<(), TransportResponseError> {
self.state.process_event(event).map_err(|e| { self.state.process_event(event).map_err(|e| {
error!(?e, "State transition failed"); error!(?e, "State transition failed");
UserAgentError::StateTransitionFailed TransportResponseError::StateTransitionFailed
})?; })?;
Ok(()) Ok(())
} }
async fn send_msg<Reply: kameo::Reply>(
&mut self,
msg: UserAgentResponsePayload,
ctx: &mut Context<Self, Reply>,
) -> Result<(), Error> {
self.props
.transport
.send(Ok(response(msg)))
.await
.map_err(|_| {
error!(
actor = "useragent",
reason = "channel closed",
"send.failed"
);
Error::ConnectionLost
})
}
async fn expect_msg<Extractor, Msg, Reply>(
&mut self,
extractor: Extractor,
ctx: &mut Context<Self, Reply>,
) -> Result<Msg, Error>
where
Extractor: FnOnce(UserAgentRequestPayload) -> Option<Msg>,
Reply: kameo::Reply,
{
let msg = self.props.transport.recv().await.ok_or_else(|| {
error!(
actor = "useragent",
reason = "channel closed",
"recv.failed"
);
ctx.stop();
Error::ConnectionLost
})?;
msg.payload.and_then(extractor).ok_or_else(|| {
error!(
actor = "useragent",
reason = "unexpected message",
"recv.failed"
);
ctx.stop();
Error::UnexpectedMessage
})
}
}
#[messages]
impl UserAgentSession {
// TODO: Think about refactoring it to state-machine based flow, as we already have one
#[message(ctx)]
pub async fn request_new_client_approval(
&mut self,
client_pubkey: VerifyingKey,
mut cancel_flag: watch::Receiver<()>,
ctx: &mut Context<Self, Result<bool, Error>>,
) -> Result<bool, Error> {
self.send_msg(
UserAgentResponsePayload::ClientConnectionRequest(
ClientConnectionRequest {
pubkey: client_pubkey.as_bytes().to_vec(),
}
.into(),
),
ctx,
)
.await?;
let extractor = |msg| {
if let UserAgentRequestPayload::ClientConnectionResponse(client_connection_response) =
msg
{
Some(client_connection_response)
} else {
None
}
};
tokio::select! {
_ = cancel_flag.changed() => {
info!(actor = "useragent", "client connection approval cancelled");
self.send_msg(
UserAgentResponsePayload::ClientConnectionCancel(ClientConnectionCancel {}),
ctx,
).await?;
return Ok(false);
}
result = self.expect_msg(extractor, ctx) => {
let result = result?;
info!(actor = "useragent", "received client connection approval result: approved={}", result.approved);
return Ok(result.approved);
}
}
}
}
impl UserAgentSession {
pub async fn process_transport_inbound(&mut self, req: UserAgentRequest) -> Output { pub async fn process_transport_inbound(&mut self, req: UserAgentRequest) -> Output {
let msg = req.payload.ok_or_else(|| { let msg = req.payload.ok_or_else(|| {
error!(actor = "useragent", "Received message with no payload"); error!(actor = "useragent", "Received message with no payload");
UserAgentError::MissingRequestPayload TransportResponseError::MissingRequestPayload
})?; })?;
match msg { match msg {
@@ -58,12 +173,12 @@ impl UserAgentSession {
UserAgentRequestPayload::UnsealEncryptedKey(unseal_encrypted_key) => { UserAgentRequestPayload::UnsealEncryptedKey(unseal_encrypted_key) => {
self.handle_unseal_encrypted_key(unseal_encrypted_key).await self.handle_unseal_encrypted_key(unseal_encrypted_key).await
} }
_ => Err(UserAgentError::UnexpectedRequestPayload), _ => Err(TransportResponseError::UnexpectedRequestPayload),
} }
} }
} }
type Output = Result<UserAgentResponse, UserAgentError>; type Output = Result<UserAgentResponse, TransportResponseError>;
fn response(payload: UserAgentResponsePayload) -> UserAgentResponse { fn response(payload: UserAgentResponsePayload) -> UserAgentResponse {
UserAgentResponse { UserAgentResponse {
@@ -79,7 +194,7 @@ impl UserAgentSession {
let client_pubkey_bytes: [u8; 32] = req let client_pubkey_bytes: [u8; 32] = req
.client_pubkey .client_pubkey
.try_into() .try_into()
.map_err(|_| UserAgentError::InvalidClientPubkeyLength)?; .map_err(|_| TransportResponseError::InvalidClientPubkeyLength)?;
let client_public_key = PublicKey::from(client_pubkey_bytes); let client_public_key = PublicKey::from(client_pubkey_bytes);
@@ -98,7 +213,7 @@ impl UserAgentSession {
async fn handle_unseal_encrypted_key(&mut self, req: UnsealEncryptedKey) -> Output { async fn handle_unseal_encrypted_key(&mut self, req: UnsealEncryptedKey) -> Output {
let UserAgentStates::WaitingForUnsealKey(unseal_context) = self.state.state() else { let UserAgentStates::WaitingForUnsealKey(unseal_context) = self.state.state() else {
error!("Received unseal encrypted key in invalid state"); error!("Received unseal encrypted key in invalid state");
return Err(UserAgentError::InvalidStateForUnsealEncryptedKey); return Err(TransportResponseError::InvalidStateForUnsealEncryptedKey);
}; };
let ephemeral_secret = { let ephemeral_secret = {
let mut secret_lock = unseal_context.secret.lock().unwrap(); let mut secret_lock = unseal_context.secret.lock().unwrap();
@@ -163,7 +278,7 @@ impl UserAgentSession {
Err(err) => { Err(err) => {
error!(?err, "Failed to send unseal request to keyholder"); error!(?err, "Failed to send unseal request to keyholder");
self.transition(UserAgentEvents::ReceivedInvalidKey)?; self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Err(UserAgentError::KeyHolderActorUnreachable) Err(TransportResponseError::KeyHolderActorUnreachable)
} }
} }
} }
@@ -181,7 +296,7 @@ impl UserAgentSession {
impl Actor for UserAgentSession { impl Actor for UserAgentSession {
type Args = Self; type Args = Self;
type Error = UserAgentError; type Error = TransportResponseError;
async fn on_start( async fn on_start(
args: Self::Args, args: Self::Args,
@@ -196,7 +311,7 @@ impl Actor for UserAgentSession {
.await .await
.map_err(|err| { .map_err(|err| {
error!(?err, "Failed to register user agent connection with router"); error!(?err, "Failed to register user agent connection with router");
UserAgentError::ConnectionRegistrationFailed TransportResponseError::ConnectionRegistrationFailed
})?; })?;
Ok(args) Ok(args)
} }

View File

@@ -16,7 +16,7 @@ use tracing::info;
use crate::{ use crate::{
actors::{ actors::{
client::{self, ClientError, ClientConnection as ClientConnectionProps, connect_client}, client::{self, ClientError, ClientConnection as ClientConnectionProps, connect_client},
user_agent::{self, UserAgentConnection, UserAgentError, connect_user_agent}, user_agent::{self, UserAgentConnection, TransportResponseError, connect_user_agent},
}, },
context::ServerContext, context::ServerContext,
}; };
@@ -30,7 +30,7 @@ const DEFAULT_CHANNEL_SIZE: usize = 1000;
struct UserAgentGrpcSender; struct UserAgentGrpcSender;
impl SendConverter for UserAgentGrpcSender { impl SendConverter for UserAgentGrpcSender {
type Input = Result<UserAgentResponse, UserAgentError>; type Input = Result<UserAgentResponse, TransportResponseError>;
type Output = Result<UserAgentResponse, Status>; type Output = Result<UserAgentResponse, Status>;
fn convert(&self, item: Self::Input) -> Self::Output { fn convert(&self, item: Self::Input) -> Self::Output {
@@ -87,21 +87,21 @@ fn client_auth_error_status(value: &client::auth::Error) -> Status {
} }
} }
fn user_agent_error_status(value: UserAgentError) -> Status { fn user_agent_error_status(value: TransportResponseError) -> Status {
match value { match value {
UserAgentError::MissingRequestPayload | UserAgentError::UnexpectedRequestPayload => { TransportResponseError::MissingRequestPayload | TransportResponseError::UnexpectedRequestPayload => {
Status::invalid_argument("Expected message with payload") Status::invalid_argument("Expected message with payload")
} }
UserAgentError::InvalidStateForUnsealEncryptedKey => { TransportResponseError::InvalidStateForUnsealEncryptedKey => {
Status::failed_precondition("Invalid state for unseal encrypted key") Status::failed_precondition("Invalid state for unseal encrypted key")
} }
UserAgentError::InvalidClientPubkeyLength => { TransportResponseError::InvalidClientPubkeyLength => {
Status::invalid_argument("client_pubkey must be 32 bytes") Status::invalid_argument("client_pubkey must be 32 bytes")
} }
UserAgentError::StateTransitionFailed => Status::internal("State machine error"), TransportResponseError::StateTransitionFailed => Status::internal("State machine error"),
UserAgentError::KeyHolderActorUnreachable => Status::internal("Vault is not available"), TransportResponseError::KeyHolderActorUnreachable => Status::internal("Vault is not available"),
UserAgentError::Auth(ref err) => auth_error_status(err), TransportResponseError::Auth(ref err) => auth_error_status(err),
UserAgentError::ConnectionRegistrationFailed => { TransportResponseError::ConnectionRegistrationFailed => {
Status::internal("Failed registering connection") Status::internal("Failed registering connection")
} }
} }