diff --git a/server/crates/arbiter-proto/src/transport.rs b/server/crates/arbiter-proto/src/transport.rs index 02e54c5..9700d49 100644 --- a/server/crates/arbiter-proto/src/transport.rs +++ b/server/crates/arbiter-proto/src/transport.rs @@ -1,7 +1,56 @@ +//! Transport abstraction layer for bridging gRPC bidirectional streaming with kameo actors. +//! +//! This module provides a clean separation between the gRPC transport layer and business logic +//! by modeling the connection as two linked kameo actors: +//! +//! - A **transport actor** ([`GrpcTransportActor`]) that owns the gRPC stream and channel, +//! forwarding inbound messages to the business actor and outbound messages to the client. +//! - A **business logic actor** that receives inbound messages from the transport actor and +//! sends outbound messages back through the transport actor. +//! +//! The [`wire()`] function sets up bidirectional linking between the two actors, ensuring +//! that if either actor dies, the other is notified and can shut down gracefully. +//! +//! # Terminology +//! +//! - **InboundMessage**: a message received by the transport actor from the channel/socket +//! and forwarded to the business actor. +//! - **OutboundMessage**: a message produced by the business actor and sent to the transport +//! actor to be forwarded to the channel/socket. +//! +//! # Architecture +//! +//! ```text +//! gRPC Stream ──InboundMessage──▶ GrpcTransportActor ──tell(InboundMessage)──▶ BusinessActor +//! ▲ │ +//! └─tell(Result)────┘ +//! │ +//! mpsc::Sender ──▶ Client +//! ``` +//! +//! # Example +//! +//! ```rust,ignore +//! let (tx, rx) = mpsc::channel(1000); +//! let context = server_context.clone(); +//! +//! wire( +//! |transport_ref| MyBusinessActor::new(context, transport_ref), +//! |business_recipient, business_id| GrpcTransportActor { +//! sender: tx, +//! receiver: grpc_stream, +//! business_logic_actor: business_recipient, +//! business_logic_actor_id: business_id, +//! }, +//! ).await; +//! +//! Ok(Response::new(ReceiverStream::new(rx))) +//! ``` + use futures::{Stream, StreamExt}; use kameo::{ Actor, - actor::{ActorRef, PreparedActor, Spawn, WeakActorRef}, + actor::{ActorRef, PreparedActor, Recipient, Spawn, WeakActorRef}, mailbox::Signal, prelude::Message, }; @@ -12,7 +61,15 @@ use tokio::{ use tonic::{Status, Streaming}; use tracing::{debug, error}; -// Abstraction for stream for sans-io capabilities +/// A bidirectional stream abstraction for sans-io testing. +/// +/// Combines a [`Stream`] of incoming messages with the ability to [`send`](Bi::send) +/// outgoing responses. This trait allows business logic to be tested without a real +/// gRPC connection by swapping in an in-memory implementation. +/// +/// # Type Parameters +/// - `T`: `InboundMessage` received from the channel/socket (e.g., `UserAgentRequest`) +/// - `U`: `OutboundMessage` sent to the channel/socket (e.g., `UserAgentResponse`) pub trait Bi: Stream> + Send + Sync + 'static { type Error; fn send( @@ -21,7 +78,10 @@ pub trait Bi: Stream> + Send + Sync + 'static { ) -> impl std::future::Future> + Send; } -// Bi-directional stream abstraction for handling gRPC streaming requests and responses +/// Concrete [`Bi`] implementation backed by a tonic gRPC [`Streaming`] and an [`mpsc::Sender`]. +/// +/// This is the production implementation used in gRPC service handlers. The `request_stream` +/// receives messages from the client, and `response_sender` sends responses back. pub struct BiStream { pub request_stream: Streaming, pub response_sender: mpsc::Sender>, @@ -54,23 +114,95 @@ where } } -pub trait TransportActor: Actor + Send + Message {} - -pub struct GrpcTransportActor -where - SendMsg: Send + 'static, - RecvMsg: Send + 'static, - A: TransportActor, +/// Marker trait for transport actors that can receive outbound messages of type `T`. +/// +/// Implement this on your transport actor to indicate it can handle outbound messages +/// produced by the business actor. Requires the actor to implement [`Message>`] +/// so business logic can forward responses via [`tell()`](ActorRef::tell). +/// +/// # Example +/// +/// ```rust,ignore +/// #[derive(Actor)] +/// struct MyTransportActor { /* ... */ } +/// +/// impl Message> for MyTransportActor { +/// type Reply = (); +/// async fn handle(&mut self, msg: Result, _ctx: &mut Context) -> Self::Reply { +/// // forward outbound message to channel/socket +/// } +/// } +/// +/// impl TransportActor for MyTransportActor {} +/// ``` +pub trait TransportActor: + Actor + Send + Message> { - pub sender: mpsc::Sender>, - pub receiver: tonic::Streaming, - pub business_logic_actor: ActorRef, } -impl Actor for GrpcTransportActor + +/// A kameo actor that bridges a gRPC bidirectional stream with a business logic actor. +/// +/// This actor owns the gRPC [`Streaming`] receiver and an [`mpsc::Sender`] for responses. +/// It multiplexes between its own mailbox (for outbound messages from the business actor) +/// and the gRPC stream (for inbound client messages) using [`tokio::select!`]. +/// +/// # Message Flow +/// +/// - **Inbound**: Messages from the gRPC stream are forwarded to `business_logic_actor` +/// via [`tell()`](Recipient::tell). +/// - **Outbound**: The business actor sends `Result` messages to this +/// actor, which forwards them through the `sender` channel to the gRPC response stream. +/// +/// # Lifecycle +/// +/// - If the business logic actor dies (detected via actor linking), this actor stops, +/// which closes the gRPC stream. +/// - If the gRPC stream closes or errors, this actor stops, which (via linking) notifies +/// the business actor. +/// - Error responses (`Err(DomainError)`) are forwarded to the client and then the actor stops, +/// closing the connection. +/// +/// # Type Parameters +/// - `Outbound`: `OutboundMessage` sent to the client (e.g., `UserAgentResponse`) +/// - `Inbound`: `InboundMessage` received from the client (e.g., `UserAgentRequest`) +/// - `E`: The domain error type, must implement `Into` for gRPC conversion +pub struct GrpcTransportActor where - SendMsg: Send + 'static, - RecvMsg: Send + 'static, - A: TransportActor, + Outbound: Send + 'static, + Inbound: Send + 'static, + DomainError: Into + Send + 'static, +{ + sender: mpsc::Sender>, + receiver: tonic::Streaming, + business_logic_actor: Recipient, + _error: std::marker::PhantomData, +} + +impl GrpcTransportActor +where + Outbound: Send + 'static, + Inbound: Send + 'static, + DomainError: Into + Send + 'static, +{ + pub fn new( + sender: mpsc::Sender>, + receiver: tonic::Streaming, + business_logic_actor: Recipient, + ) -> Self { + Self { + sender, + receiver, + business_logic_actor, + _error: std::marker::PhantomData, + } + } +} + +impl Actor for GrpcTransportActor +where + Outbound: Send + 'static, + Inbound: Send + 'static, + E: Into + Send + 'static, { type Args = Self; @@ -139,19 +271,27 @@ where } } -impl> Message - for GrpcTransportActor +impl Message> for GrpcTransportActor +where + Outbound: Send + 'static, + Inbound: Send + 'static, + E: Into + Send + 'static, { type Reply = (); async fn handle( &mut self, - msg: SendMsg, + msg: Result, ctx: &mut kameo::prelude::Context, ) -> Self::Reply { - let err = self.sender.send(Ok(msg)).await; - match err { - Ok(_) => (), + let is_err = msg.is_err(); + let grpc_msg = msg.map_err(Into::into); + match self.sender.send(grpc_msg).await { + Ok(_) => { + if is_err { + ctx.stop(); + } + } Err(e) => { error!("Failed to send message: {}", e); ctx.stop(); @@ -160,20 +300,60 @@ impl( +impl TransportActor for GrpcTransportActor +where + Outbound: Send + 'static, + Inbound: Send + 'static, + E: Into + Send + 'static, +{ +} + +/// Wires together a transport actor and a business logic actor with bidirectional linking. +/// +/// This function handles the chicken-and-egg problem of two actors that need references +/// to each other at construction time. It uses kameo's [`PreparedActor`] to obtain +/// [`ActorRef`]s before spawning, then links both actors so that if either dies, +/// the other is notified via [`on_link_died`](Actor::on_link_died). +/// +/// The business actor receives a type-erased [`Recipient>`] instead of an +/// `ActorRef`, keeping it decoupled from the concrete transport implementation. +/// +/// # Type Parameters +/// - `Transport`: The transport actor type (e.g., [`GrpcTransportActor`]) +/// - `Inbound`: `InboundMessage` received by the business actor from the transport +/// - `Outbound`: `OutboundMessage` sent by the business actor back to the transport +/// - `Business`: The business logic actor +/// - `BusinessCtor`: Closure that receives a prepared business actor and transport recipient, +/// spawns the business actor, and returns its [`ActorRef`] +/// - `TransportCtor`: Closure that receives a prepared transport actor, a recipient for +/// inbound messages, and the business actor id, then spawns the transport actor +/// +/// # Returns +/// A tuple of `(transport_ref, business_ref)` — actor references for both spawned actors. +pub async fn wire< + Transport, + Inbound, + Outbound, + DomainError, + Business, + BusinessCtor, + TransportCtor, +>( business_ctor: BusinessCtor, transport_ctor: TransportCtor, -) -> (ActorRef, ActorRef) +) -> (ActorRef, ActorRef) where - T: TransportActor, - RecvMsg: Send + 'static, - SendMsg: Send + 'static, - BusinessActor: Actor + Send + 'static, - BusinessCtor: FnOnce(ActorRef) -> BusinessActor::Args, - TransportCtor: FnOnce(ActorRef) -> T::Args, + Transport: TransportActor, + Inbound: Send + 'static, + Outbound: Send + 'static, + DomainError: Send + 'static, + Business: Actor + Message + Send + 'static, + BusinessCtor: FnOnce(PreparedActor, Recipient>), + TransportCtor: + FnOnce(PreparedActor, Recipient), { - let prepared_business: PreparedActor = Spawn::prepare(); - let prepared_transport: PreparedActor = Spawn::prepare(); + let prepared_business: PreparedActor = Spawn::prepare(); + let prepared_transport: PreparedActor = Spawn::prepare(); let business_ref = prepared_business.actor_ref().clone(); let transport_ref = prepared_transport.actor_ref().clone(); @@ -181,8 +361,11 @@ where transport_ref.link(&business_ref).await; business_ref.link(&transport_ref).await; - let _ = prepared_business.spawn(business_ctor(transport_ref.clone())); - let _ = prepared_transport.spawn(transport_ctor(business_ref.clone())); + let recipient = transport_ref.clone().recipient(); + business_ctor(prepared_business, recipient); + let business_recipient = business_ref.clone().recipient(); + transport_ctor(prepared_transport, business_recipient); + (transport_ref, business_ref) } diff --git a/server/crates/arbiter-server/src/actors/user_agent/error.rs b/server/crates/arbiter-server/src/actors/user_agent/error.rs new file mode 100644 index 0000000..a983723 --- /dev/null +++ b/server/crates/arbiter-server/src/actors/user_agent/error.rs @@ -0,0 +1,57 @@ +use tonic::Status; + +use crate::db; + +#[derive(Debug, thiserror::Error)] +pub enum UserAgentError { + #[error("Missing payload in request")] + MissingPayload, + + #[error("Invalid bootstrap token")] + InvalidBootstrapToken, + + #[error("Public key not registered")] + PubkeyNotRegistered, + + #[error("Invalid public key format")] + InvalidPubkey, + + #[error("Invalid signature length")] + InvalidSignatureLength, + + #[error("Invalid challenge solution")] + InvalidChallengeSolution, + + #[error("Invalid state for operation")] + InvalidState, + + #[error("Actor unavailable")] + ActorUnavailable, + + #[error("Database error")] + Database(#[from] diesel::result::Error), + + #[error("Database pool error")] + DatabasePool(#[from] db::PoolError), +} + +impl From for Status { + fn from(err: UserAgentError) -> Self { + match err { + UserAgentError::MissingPayload + | UserAgentError::InvalidBootstrapToken + | UserAgentError::InvalidPubkey + | UserAgentError::InvalidSignatureLength => Status::invalid_argument(err.to_string()), + + UserAgentError::PubkeyNotRegistered | UserAgentError::InvalidChallengeSolution => { + Status::unauthenticated(err.to_string()) + } + + UserAgentError::InvalidState => Status::failed_precondition(err.to_string()), + + UserAgentError::ActorUnavailable + | UserAgentError::Database(_) + | UserAgentError::DatabasePool(_) => Status::internal(err.to_string()), + } + } +} diff --git a/server/crates/arbiter-server/src/actors/user_agent/mod.rs b/server/crates/arbiter-server/src/actors/user_agent/mod.rs index 700dd40..13c38be 100644 --- a/server/crates/arbiter-server/src/actors/user_agent/mod.rs +++ b/server/crates/arbiter-server/src/actors/user_agent/mod.rs @@ -1,21 +1,23 @@ use std::{ops::DerefMut, sync::Mutex}; use arbiter_proto::proto::{ - UnsealEncryptedKey, UnsealResult, UnsealStart, UnsealStartResponse, UserAgentResponse, + UnsealEncryptedKey, UnsealResult, UnsealStart, UnsealStartResponse, UserAgentRequest, + UserAgentResponse, auth::{ - self, AuthChallengeRequest, AuthOk, ServerMessage as AuthServerMessage, + self, AuthChallengeRequest, AuthOk, ClientMessage as ClientAuthMessage, + ServerMessage as AuthServerMessage, + client_message::Payload as ClientAuthPayload, server_message::Payload as ServerAuthPayload, }, + user_agent_request::Payload as UserAgentRequestPayload, user_agent_response::Payload as UserAgentResponsePayload, }; use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit}; use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, dsl::update}; use diesel_async::RunQueryDsl; use ed25519_dalek::VerifyingKey; -use kameo::{Actor, error::SendError, messages}; +use kameo::{Actor, actor::Recipient, error::SendError, messages, prelude::Message}; use memsafe::MemSafe; -use tokio::sync::mpsc::Sender; -use tonic::Status; use tracing::{error, info}; use x25519_dalek::{EphemeralSecret, PublicKey}; @@ -31,53 +33,74 @@ use crate::{ }, }, db::{self, schema}, - errors::GrpcStatusExt, }; +mod error; mod state; -mod transport; -pub(crate) use transport::handle_user_agent; +pub use error::UserAgentError; #[derive(Actor)] pub struct UserAgentActor { db: db::DatabasePool, actors: GlobalActors, state: UserAgentStateMachine, - // will be used in future - _tx: Sender>, + transport: Recipient>, } impl UserAgentActor { pub(crate) fn new( context: ServerContext, - tx: Sender>, + transport: Recipient>, ) -> Self { Self { db: context.db.clone(), actors: context.actors.clone(), state: UserAgentStateMachine::new(DummyContext), - _tx: tx, + transport, } } pub fn new_manual( db: db::DatabasePool, actors: GlobalActors, - tx: Sender>, + transport: Recipient>, ) -> Self { Self { db, actors, state: UserAgentStateMachine::new(DummyContext), - _tx: tx, + transport, } } - fn transition(&mut self, event: UserAgentEvents) -> Result<(), Status> { + async fn process_request(&mut self, req: UserAgentRequest) -> Output { + let msg = req.payload.ok_or_else(|| { + error!(actor = "useragent", "Received message with no payload"); + UserAgentError::MissingPayload + })?; + + match msg { + UserAgentRequestPayload::AuthMessage(ClientAuthMessage { + payload: Some(ClientAuthPayload::AuthChallengeRequest(req)), + }) => self.handle_auth_challenge_request(req).await, + UserAgentRequestPayload::AuthMessage(ClientAuthMessage { + payload: Some(ClientAuthPayload::AuthChallengeSolution(solution)), + }) => self.handle_auth_challenge_solution(solution).await, + UserAgentRequestPayload::UnsealStart(unseal_start) => { + self.handle_unseal_request(unseal_start).await + } + UserAgentRequestPayload::UnsealEncryptedKey(unseal_encrypted_key) => { + self.handle_unseal_encrypted_key(unseal_encrypted_key).await + } + _ => Err(UserAgentError::MissingPayload), + } + } + + fn transition(&mut self, event: UserAgentEvents) -> Result<(), UserAgentError> { self.state.process_event(event).map_err(|e| { error!(?e, "State transition failed"); - Status::internal("State machine error") + UserAgentError::InvalidState })?; Ok(()) } @@ -86,7 +109,7 @@ impl UserAgentActor { &mut self, pubkey: ed25519_dalek::VerifyingKey, token: String, - ) -> Result { + ) -> Output { let token_ok: bool = self .actors .bootstrapper @@ -94,16 +117,16 @@ impl UserAgentActor { .await .map_err(|e| { error!(?pubkey, "Failed to consume bootstrap token: {e}"); - Status::internal("Bootstrap token consumption failed") + UserAgentError::ActorUnavailable })?; if !token_ok { error!(?pubkey, "Invalid bootstrap token provided"); - return Err(Status::invalid_argument("Invalid bootstrap token")); + return Err(UserAgentError::InvalidBootstrapToken); } { - let mut conn = self.db.get().await.to_status()?; + let mut conn = self.db.get().await?; diesel::insert_into(schema::useragent_client::table) .values(( @@ -111,8 +134,7 @@ impl UserAgentActor { schema::useragent_client::nonce.eq(1), )) .execute(&mut conn) - .await - .to_status()?; + .await?; } self.transition(UserAgentEvents::ReceivedBootstrapToken)?; @@ -122,7 +144,7 @@ impl UserAgentActor { async fn auth_with_challenge(&mut self, pubkey: VerifyingKey, pubkey_bytes: Vec) -> Output { let nonce: Option = { - let mut db_conn = self.db.get().await.to_status()?; + let mut db_conn = self.db.get().await?; db_conn .exclusive_transaction(|conn| { Box::pin(async move { @@ -146,13 +168,12 @@ impl UserAgentActor { }) }) .await - .optional() - .to_status()? + .optional()? }; let Some(nonce) = nonce else { error!(?pubkey, "Public key not found in database"); - return Err(Status::unauthenticated("Public key not registered")); + return Err(UserAgentError::PubkeyNotRegistered); }; let challenge = auth::AuthChallenge { @@ -177,19 +198,17 @@ impl UserAgentActor { fn verify_challenge_solution( &self, solution: &auth::AuthChallengeSolution, - ) -> Result<(bool, &ChallengeContext), Status> { + ) -> Result<(bool, &ChallengeContext), UserAgentError> { let UserAgentStates::WaitingForChallengeSolution(challenge_context) = self.state.state() else { error!("Received challenge solution in invalid state"); - return Err(Status::invalid_argument( - "Invalid state for challenge solution", - )); + return Err(UserAgentError::InvalidState); }; let formatted_challenge = arbiter_proto::format_challenge(&challenge_context.challenge); let signature = solution.signature.as_slice().try_into().map_err(|_| { error!(?solution, "Invalid signature length"); - Status::invalid_argument("Invalid signature length") + UserAgentError::InvalidSignatureLength })?; let valid = challenge_context @@ -201,7 +220,7 @@ impl UserAgentActor { } } -type Output = Result; +type Output = Result; fn auth_response(payload: ServerAuthPayload) -> UserAgentResponse { UserAgentResponse { @@ -227,7 +246,7 @@ impl UserAgentActor { let client_pubkey_bytes: [u8; 32] = req .client_pubkey .try_into() - .map_err(|_| Status::invalid_argument("client_pubkey must be 32 bytes"))?; + .map_err(|_| UserAgentError::InvalidPubkey)?; let client_public_key = PublicKey::from(client_pubkey_bytes); @@ -247,9 +266,7 @@ impl UserAgentActor { pub async fn handle_unseal_encrypted_key(&mut self, req: UnsealEncryptedKey) -> Output { let UserAgentStates::WaitingForUnsealKey(unseal_context) = self.state.state() else { error!("Received unseal encrypted key in invalid state"); - return Err(Status::failed_precondition( - "Invalid state for unseal encrypted key", - )); + return Err(UserAgentError::InvalidState); }; let ephemeral_secret = { let mut secret_lock = unseal_context.secret.lock().unwrap(); @@ -313,7 +330,7 @@ impl UserAgentActor { Err(err) => { error!(?err, "Failed to send unseal request to keyholder"); self.transition(UserAgentEvents::ReceivedInvalidKey)?; - Err(Status::internal("Vault is not available")) + Err(UserAgentError::ActorUnavailable) } } } @@ -329,12 +346,13 @@ impl UserAgentActor { #[message] pub async fn handle_auth_challenge_request(&mut self, req: AuthChallengeRequest) -> Output { - let pubkey = req.pubkey.as_array().ok_or(Status::invalid_argument( - "Expected pubkey to have specific length", - ))?; + let pubkey = req + .pubkey + .as_array() + .ok_or(UserAgentError::InvalidPubkey)?; let pubkey = VerifyingKey::from_bytes(pubkey).map_err(|_err| { error!(?pubkey, "Failed to convert to VerifyingKey"); - Status::invalid_argument("Failed to convert pubkey to VerifyingKey") + UserAgentError::InvalidPubkey })?; self.transition(UserAgentEvents::AuthRequest)?; @@ -362,7 +380,22 @@ impl UserAgentActor { } else { error!("Client provided invalid solution to authentication challenge"); self.transition(UserAgentEvents::ReceivedBadSolution)?; - Err(Status::unauthenticated("Invalid challenge solution")) + Err(UserAgentError::InvalidChallengeSolution) + } + } +} + +impl Message for UserAgentActor { + type Reply = (); + + async fn handle( + &mut self, + msg: UserAgentRequest, + _ctx: &mut kameo::prelude::Context, + ) -> Self::Reply { + let result = self.process_request(msg).await; + if let Err(e) = self.transport.tell(result).await { + error!(actor = "useragent", "Failed to send response to transport: {}", e); } } } diff --git a/server/crates/arbiter-server/src/actors/user_agent/transport.rs b/server/crates/arbiter-server/src/actors/user_agent/transport.rs deleted file mode 100644 index c1ac84c..0000000 --- a/server/crates/arbiter-server/src/actors/user_agent/transport.rs +++ /dev/null @@ -1,95 +0,0 @@ -use super::UserAgentActor; -use arbiter_proto::proto::{ - UserAgentRequest, UserAgentResponse, - auth::{ClientMessage as ClientAuthMessage, client_message::Payload as ClientAuthPayload}, - user_agent_request::Payload as UserAgentRequestPayload, -}; -use futures::StreamExt; -use kameo::{ - actor::{ActorRef, Spawn as _}, - error::SendError, -}; -use tokio::sync::mpsc; -use tonic::Status; -use tracing::error; - -use crate::{ - actors::user_agent::{ - HandleAuthChallengeRequest, HandleAuthChallengeSolution, HandleUnsealEncryptedKey, - HandleUnsealRequest, - }, - context::ServerContext, -}; - -pub(crate) async fn handle_user_agent( - context: ServerContext, - mut req_stream: tonic::Streaming, - tx: mpsc::Sender>, -) { - let actor = UserAgentActor::spawn(UserAgentActor::new(context, tx.clone())); - - while let Some(Ok(req)) = req_stream.next().await - && actor.is_alive() - { - match process_message(&actor, req).await { - Ok(resp) => { - if tx.send(Ok(resp)).await.is_err() { - error!(actor = "useragent", "Failed to send response to client"); - break; - } - } - Err(status) => { - let _ = tx.send(Err(status)).await; - break; - } - } - } - - actor.kill(); -} - -async fn process_message( - actor: &ActorRef, - req: UserAgentRequest, -) -> Result { - let msg = req.payload.ok_or_else(|| { - error!(actor = "useragent", "Received message with no payload"); - Status::invalid_argument("Expected message with payload") - })?; - - match msg { - UserAgentRequestPayload::AuthMessage(ClientAuthMessage { - payload: Some(ClientAuthPayload::AuthChallengeRequest(req)), - }) => actor - .ask(HandleAuthChallengeRequest { req }) - .await - .map_err(into_status), - UserAgentRequestPayload::AuthMessage(ClientAuthMessage { - payload: Some(ClientAuthPayload::AuthChallengeSolution(solution)), - }) => actor - .ask(HandleAuthChallengeSolution { solution }) - .await - .map_err(into_status), - UserAgentRequestPayload::UnsealStart(unseal_start) => actor - .ask(HandleUnsealRequest { req: unseal_start }) - .await - .map_err(into_status), - UserAgentRequestPayload::UnsealEncryptedKey(unseal_encrypted_key) => actor - .ask(HandleUnsealEncryptedKey { - req: unseal_encrypted_key, - }) - .await - .map_err(into_status), - _ => Err(Status::invalid_argument("Expected message with payload")), - } -} - -fn into_status(e: SendError) -> Status { - match e { - SendError::HandlerError(status) => status, - _ => { - error!(actor = "useragent", "Failed to send message to actor"); - Status::internal("session failure") - } - } -} diff --git a/server/crates/arbiter-server/src/errors.rs b/server/crates/arbiter-server/src/errors.rs deleted file mode 100644 index 98dae76..0000000 --- a/server/crates/arbiter-server/src/errors.rs +++ /dev/null @@ -1,24 +0,0 @@ -use tonic::Status; -use tracing::error; - -pub trait GrpcStatusExt { - fn to_status(self) -> Result; -} - -impl GrpcStatusExt for Result { - fn to_status(self) -> Result { - self.map_err(|e| { - error!(error = ?e, "Database error"); - Status::internal("Database error") - }) - } -} - -impl GrpcStatusExt for Result { - fn to_status(self) -> Result { - self.map_err(|e| { - error!(error = ?e, "Database pool error"); - Status::internal("Database pool error") - }) - } -} diff --git a/server/crates/arbiter-server/src/lib.rs b/server/crates/arbiter-server/src/lib.rs index 9d86e27..af02dd5 100644 --- a/server/crates/arbiter-server/src/lib.rs +++ b/server/crates/arbiter-server/src/lib.rs @@ -1,23 +1,26 @@ #![forbid(unsafe_code)] use arbiter_proto::{ proto::{ClientRequest, ClientResponse, UserAgentRequest, UserAgentResponse}, - transport::BiStream, + transport::{BiStream, GrpcTransportActor, wire}, }; use async_trait::async_trait; +use kameo::actor::PreparedActor; use tokio_stream::wrappers::ReceiverStream; use tokio::sync::mpsc; use tonic::{Request, Response, Status}; use crate::{ - actors::{client::handle_client, user_agent::handle_user_agent}, + actors::{ + client::handle_client, + user_agent::UserAgentActor, + }, context::ServerContext, }; pub mod actors; pub mod context; pub mod db; -mod errors; const DEFAULT_CHANNEL_SIZE: usize = 1000; @@ -59,7 +62,22 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server { ) -> Result, Status> { let req_stream = request.into_inner(); let (tx, rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE); - tokio::spawn(handle_user_agent(self.context.clone(), req_stream, tx)); + let context = self.context.clone(); + + wire( + |prepared: PreparedActor, recipient| { + prepared.spawn(UserAgentActor::new(context, recipient)); + }, + |prepared: PreparedActor>, business_recipient| { + prepared.spawn(GrpcTransportActor::new( + tx, + req_stream, + business_recipient, + )); + }, + ) + .await; + Ok(Response::new(ReceiverStream::new(rx))) } } diff --git a/server/crates/arbiter-server/tests/user_agent.rs b/server/crates/arbiter-server/tests/user_agent.rs index dcd9789..d90bf3a 100644 --- a/server/crates/arbiter-server/tests/user_agent.rs +++ b/server/crates/arbiter-server/tests/user_agent.rs @@ -1,5 +1,30 @@ mod common; +use arbiter_proto::proto::UserAgentResponse; +use arbiter_server::actors::user_agent::UserAgentError; +use kameo::{Actor, actor::Recipient, actor::Spawn, prelude::Message}; + +/// A no-op actor that discards any messages it receives. +#[derive(Actor)] +struct NullSink; + +impl Message> for NullSink { + type Reply = (); + + async fn handle( + &mut self, + _msg: Result, + _ctx: &mut kameo::prelude::Context, + ) -> Self::Reply { + } +} + +/// Creates a `Recipient` that silently discards all messages. +fn null_recipient() -> Recipient> { + let actor_ref = NullSink::spawn(NullSink); + actor_ref.recipient() +} + #[path = "user_agent/auth.rs"] mod auth; #[path = "user_agent/unseal.rs"] diff --git a/server/crates/arbiter-server/tests/user_agent/auth.rs b/server/crates/arbiter-server/tests/user_agent/auth.rs index c79d616..2cf72f9 100644 --- a/server/crates/arbiter-server/tests/user_agent/auth.rs +++ b/server/crates/arbiter-server/tests/user_agent/auth.rs @@ -24,7 +24,7 @@ pub async fn test_bootstrap_token_auth() { let actors = GlobalActors::spawn(db.clone()).await.unwrap(); let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap(); let user_agent = - UserAgentActor::new_manual(db.clone(), actors, tokio::sync::mpsc::channel(1).0); + UserAgentActor::new_manual(db.clone(), actors, super::null_recipient()); let user_agent_ref = UserAgentActor::spawn(user_agent); let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); @@ -69,7 +69,7 @@ pub async fn test_bootstrap_invalid_token_auth() { let actors = GlobalActors::spawn(db.clone()).await.unwrap(); let user_agent = - UserAgentActor::new_manual(db.clone(), actors, tokio::sync::mpsc::channel(1).0); + UserAgentActor::new_manual(db.clone(), actors, super::null_recipient()); let user_agent_ref = UserAgentActor::spawn(user_agent); let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); @@ -85,15 +85,11 @@ pub async fn test_bootstrap_invalid_token_auth() { .await; match result { - Err(kameo::error::SendError::HandlerError(status)) => { - assert_eq!(status.code(), tonic::Code::InvalidArgument); - insta::assert_debug_snapshot!(status, @r#" - Status { - code: InvalidArgument, - message: "Invalid bootstrap token", - source: None, - } - "#); + Err(kameo::error::SendError::HandlerError(err)) => { + assert!( + matches!(err, arbiter_server::actors::user_agent::UserAgentError::InvalidBootstrapToken), + "Expected InvalidBootstrapToken, got {err:?}" + ); } Err(other) => { panic!("Expected SendError::HandlerError, got {other:?}"); @@ -111,7 +107,7 @@ pub async fn test_challenge_auth() { let actors = GlobalActors::spawn(db.clone()).await.unwrap(); let user_agent = - UserAgentActor::new_manual(db.clone(), actors, tokio::sync::mpsc::channel(1).0); + UserAgentActor::new_manual(db.clone(), actors, super::null_recipient()); let user_agent_ref = UserAgentActor::spawn(user_agent); let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); diff --git a/server/crates/arbiter-server/tests/user_agent/unseal.rs b/server/crates/arbiter-server/tests/user_agent/unseal.rs index 9a7c85f..38c17de 100644 --- a/server/crates/arbiter-server/tests/user_agent/unseal.rs +++ b/server/crates/arbiter-server/tests/user_agent/unseal.rs @@ -35,7 +35,7 @@ async fn setup_authenticated_user_agent( actors.key_holder.ask(Seal).await.unwrap(); let user_agent = - UserAgentActor::new_manual(db.clone(), actors.clone(), tokio::sync::mpsc::channel(1).0); + UserAgentActor::new_manual(db.clone(), actors.clone(), super::null_recipient()); let user_agent_ref = UserAgentActor::spawn(user_agent); let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap(); @@ -169,7 +169,7 @@ pub async fn test_unseal_start_without_auth_fails() { let actors = GlobalActors::spawn(db.clone()).await.unwrap(); let user_agent = - UserAgentActor::new_manual(db.clone(), actors, tokio::sync::mpsc::channel(1).0); + UserAgentActor::new_manual(db.clone(), actors, super::null_recipient()); let user_agent_ref = UserAgentActor::spawn(user_agent); let client_secret = EphemeralSecret::random(); @@ -184,8 +184,11 @@ pub async fn test_unseal_start_without_auth_fails() { .await; match result { - Err(kameo::error::SendError::HandlerError(status)) => { - assert_eq!(status.code(), tonic::Code::Internal); + Err(kameo::error::SendError::HandlerError(err)) => { + assert!( + matches!(err, arbiter_server::actors::user_agent::UserAgentError::InvalidState), + "Expected InvalidState, got {err:?}" + ); } other => panic!("Expected state machine error, got {other:?}"), } diff --git a/server/crates/arbiter-useragent/src/lib.rs b/server/crates/arbiter-useragent/src/lib.rs index 9d35ddf..986f6aa 100644 --- a/server/crates/arbiter-useragent/src/lib.rs +++ b/server/crates/arbiter-useragent/src/lib.rs @@ -1,4 +1,4 @@ -use arbiter_proto::proto::UserAgentRequest; +use arbiter_proto::{proto::UserAgentRequest, transport::TransportActor}; use ed25519_dalek::SigningKey; use kameo::{ Actor, Reply, @@ -6,7 +6,6 @@ use kameo::{ prelude::Message, }; use smlang::statemachine; -use tokio::sync::mpsc; use tonic::transport::CertificateDer; use tracing::{debug, error};