diff --git a/server/Cargo.lock b/server/Cargo.lock index 3e8ae9e..b7e6b44 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -73,6 +73,7 @@ dependencies = [ "tonic", "tonic-prost", "tonic-prost-build", + "tracing", "url", ] diff --git a/server/crates/arbiter-proto/Cargo.toml b/server/crates/arbiter-proto/Cargo.toml index 60d27cf..0640004 100644 --- a/server/crates/arbiter-proto/Cargo.toml +++ b/server/crates/arbiter-proto/Cargo.toml @@ -17,7 +17,7 @@ miette.workspace = true thiserror.workspace = true rustls-pki-types.workspace = true base64 = "0.22.1" - +tracing.workspace = true [build-dependencies] tonic-prost-build = "0.14.3" diff --git a/server/crates/arbiter-proto/src/transport.rs b/server/crates/arbiter-proto/src/transport.rs index 691ef9a..becbf0c 100644 --- a/server/crates/arbiter-proto/src/transport.rs +++ b/server/crates/arbiter-proto/src/transport.rs @@ -1,46 +1,125 @@ -use futures::{Stream, StreamExt}; -use tokio::sync::mpsc::{self, error::SendError}; +use std::marker::PhantomData; + +use futures::StreamExt; +use tokio::sync::mpsc; use tonic::{Status, Streaming}; +/// Errors returned by transport adapters implementing [`Bi`]. +pub enum Error { + /// The outbound side of the transport is no longer accepting messages. + ChannelClosed, +} -// Abstraction for stream for sans-io capabilities -pub trait Bi: Stream> + Send + Sync + 'static { - type Error; +/// Minimal bidirectional transport abstraction used by protocol code. +/// +/// `Bi` models a duplex channel with: +/// - inbound items of type `T` read via [`Bi::recv`] +/// - outbound success items of type `U` or domain errors of type `E` written via [`Bi::send`] +/// +/// The trait intentionally exposes only the operations the protocol layer needs, +/// allowing it to work with gRPC streams and other transport implementations. +/// +/// # Stream termination and errors +/// +/// [`Bi::recv`] returns: +/// - `Some(item)` when a new inbound message is available +/// - `None` when the inbound stream ends or the underlying transport reports an error +/// +/// Implementations may collapse transport-specific receive errors into `None` +/// when the protocol does not need to distinguish them from normal stream +/// termination. +pub trait Bi: Send + Sync + 'static { + /// Sends one outbound result to the peer. fn send( &mut self, - item: Result, - ) -> impl std::future::Future> + Send; + item: Result, + ) -> impl std::future::Future> + Send; + + /// Receives the next inbound item. + /// + /// Returns `None` when the inbound stream is finished or can no longer + /// produce items. + fn recv(&mut self) -> impl std::future::Future> + Send; } -// Bi-directional stream abstraction for handling gRPC streaming requests and responses -pub struct BiStream { - pub request_stream: Streaming, - pub response_sender: mpsc::Sender>, +/// [`Bi`] adapter backed by a tonic gRPC bidirectional stream. +/// +/// Outbound items are sent through a Tokio MPSC sender, while inbound items are +/// read from tonic [`Streaming`]. +pub struct GrpcAdapter { + sender: mpsc::Sender>, + receiver: Streaming, + _error: PhantomData, } -impl Stream for BiStream -where - T: Send + 'static, - U: Send + 'static, -{ - type Item = Result; +impl GrpcAdapter { - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.request_stream.poll_next_unpin(cx) + /// Creates a new gRPC-backed [`Bi`] adapter. + pub fn new(sender: mpsc::Sender>, receiver: Streaming) -> Self { + Self { + sender, + receiver, + _error: PhantomData, + } } } -impl Bi for BiStream +impl Bi for GrpcAdapter where - T: Send + 'static, - U: Send + 'static, + Inbound: Send + 'static, + Outbound: Send + 'static, + E: Into + Send + Sync + 'static, { - type Error = SendError>; + #[tracing::instrument(level = "trace", skip(self, item))] + async fn send(&mut self, item: Result) -> Result<(), Error> { + self.sender + .send(item.map_err(Into::into)) + .await + .map_err(|_| Error::ChannelClosed) + } - async fn send(&mut self, item: Result) -> Result<(), Self::Error> { - self.response_sender.send(item).await + #[tracing::instrument(level = "trace", skip(self))] + async fn recv(&mut self) -> Option { + self.receiver.next().await.transpose().ok().flatten() + } +} + +/// No-op [`Bi`] transport for tests and manual actor usage. +/// +/// `send` drops all items and succeeds. [`Bi::recv`] never resolves and therefore +/// does not busy-wait or spuriously close the stream. +pub struct DummyTransport { + _marker: PhantomData<(T, U, E)>, +} + +impl DummyTransport { + pub fn new() -> Self { + Self { + _marker: PhantomData, + } + } +} + +impl Default for DummyTransport { + fn default() -> Self { + Self::new() + } +} + +impl Bi for DummyTransport +where + T: Send + Sync + 'static, + U: Send + Sync + 'static, + E: Send + Sync + 'static, +{ + async fn send(&mut self, _item: Result) -> Result<(), Error> { + Ok(()) + } + + fn recv(&mut self) -> impl std::future::Future> + Send { + async { + std::future::pending::<()>().await; + None + } } } diff --git a/server/crates/arbiter-server/src/actors/client.rs b/server/crates/arbiter-server/src/actors/client.rs index 3828821..b7e2543 100644 --- a/server/crates/arbiter-server/src/actors/client.rs +++ b/server/crates/arbiter-server/src/actors/client.rs @@ -7,6 +7,6 @@ use crate::ServerContext; pub(crate) async fn handle_client( _context: ServerContext, - _bistream: impl Bi, + _bistream: impl Bi, ) { } diff --git a/server/crates/arbiter-server/src/actors/user_agent/mod.rs b/server/crates/arbiter-server/src/actors/user_agent/mod.rs index 700dd40..5299457 100644 --- a/server/crates/arbiter-server/src/actors/user_agent/mod.rs +++ b/server/crates/arbiter-server/src/actors/user_agent/mod.rs @@ -1,20 +1,26 @@ use std::{ops::DerefMut, sync::Mutex}; -use arbiter_proto::proto::{ - UnsealEncryptedKey, UnsealResult, UnsealStart, UnsealStartResponse, UserAgentResponse, - auth::{ - self, AuthChallengeRequest, AuthOk, ServerMessage as AuthServerMessage, - server_message::Payload as ServerAuthPayload, +use arbiter_proto::{ + proto::{ + UnsealEncryptedKey, UnsealResult, UnsealStart, UnsealStartResponse, UserAgentRequest, + UserAgentResponse, + auth::{ + self, AuthChallengeRequest, AuthOk, ClientMessage as ClientAuthMessage, + ServerMessage as AuthServerMessage, client_message::Payload as ClientAuthPayload, + server_message::Payload as ServerAuthPayload, + }, + user_agent_request::Payload as UserAgentRequestPayload, + user_agent_response::Payload as UserAgentResponsePayload, }, - user_agent_response::Payload as UserAgentResponsePayload, + transport::{Bi, DummyTransport}, }; use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit}; use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, dsl::update}; use diesel_async::RunQueryDsl; use ed25519_dalek::VerifyingKey; -use kameo::{Actor, error::SendError, messages}; +use kameo::{Actor, error::SendError}; use memsafe::MemSafe; -use tokio::sync::mpsc::Sender; +use tokio::select; use tonic::Status; use tracing::{error, info}; use x25519_dalek::{EphemeralSecret, PublicKey}; @@ -31,62 +37,149 @@ use crate::{ }, }, db::{self, schema}, - errors::GrpcStatusExt, }; mod state; -mod transport; -pub(crate) use transport::handle_user_agent; +#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)] +pub enum UserAgentError { + #[error("Expected message with payload")] + MissingRequestPayload, + #[error("Expected message with payload")] + UnexpectedRequestPayload, + #[error("Invalid state for challenge solution")] + InvalidStateForChallengeSolution, + #[error("Invalid state for unseal encrypted key")] + InvalidStateForUnsealEncryptedKey, + #[error("client_pubkey must be 32 bytes")] + InvalidClientPubkeyLength, + #[error("Expected pubkey to have specific length")] + InvalidAuthPubkeyLength, + #[error("Failed to convert pubkey to VerifyingKey")] + InvalidAuthPubkeyEncoding, + #[error("Invalid signature length")] + InvalidSignatureLength, + #[error("Invalid bootstrap token")] + InvalidBootstrapToken, + #[error("Public key not registered")] + PublicKeyNotRegistered, + #[error("Invalid challenge solution")] + InvalidChallengeSolution, + #[error("State machine error")] + StateTransitionFailed, + #[error("Bootstrap token consumption failed")] + BootstrapperActorUnreachable, + #[error("Vault is not available")] + KeyHolderActorUnreachable, + #[error("Database pool error")] + DatabasePoolUnavailable, + #[error("Database error")] + DatabaseOperationFailed, +} -#[derive(Actor)] -pub struct UserAgentActor { +impl From for Status { + fn from(value: UserAgentError) -> Self { + match value { + UserAgentError::MissingRequestPayload | UserAgentError::UnexpectedRequestPayload => { + Status::invalid_argument("Expected message with payload") + } + UserAgentError::InvalidStateForChallengeSolution => { + Status::invalid_argument("Invalid state for challenge solution") + } + UserAgentError::InvalidStateForUnsealEncryptedKey => { + Status::failed_precondition("Invalid state for unseal encrypted key") + } + UserAgentError::InvalidClientPubkeyLength => { + Status::invalid_argument("client_pubkey must be 32 bytes") + } + UserAgentError::InvalidAuthPubkeyLength => { + Status::invalid_argument("Expected pubkey to have specific length") + } + UserAgentError::InvalidAuthPubkeyEncoding => { + Status::invalid_argument("Failed to convert pubkey to VerifyingKey") + } + UserAgentError::InvalidSignatureLength => { + Status::invalid_argument("Invalid signature length") + } + UserAgentError::InvalidBootstrapToken => { + Status::invalid_argument("Invalid bootstrap token") + } + UserAgentError::PublicKeyNotRegistered => { + Status::unauthenticated("Public key not registered") + } + UserAgentError::InvalidChallengeSolution => { + Status::unauthenticated("Invalid challenge solution") + } + UserAgentError::StateTransitionFailed => Status::internal("State machine error"), + UserAgentError::BootstrapperActorUnreachable => { + Status::internal("Bootstrap token consumption failed") + } + UserAgentError::KeyHolderActorUnreachable => Status::internal("Vault is not available"), + UserAgentError::DatabasePoolUnavailable => Status::internal("Database pool error"), + UserAgentError::DatabaseOperationFailed => Status::internal("Database error"), + } + } +} + +pub struct UserAgentActor +where + Transport: Bi, +{ db: db::DatabasePool, actors: GlobalActors, state: UserAgentStateMachine, - // will be used in future - _tx: Sender>, + transport: Transport, } -impl UserAgentActor { - pub(crate) fn new( - context: ServerContext, - tx: Sender>, - ) -> Self { +impl UserAgentActor +where + Transport: Bi, +{ + pub(crate) fn new(context: ServerContext, transport: Transport) -> Self { Self { db: context.db.clone(), actors: context.actors.clone(), state: UserAgentStateMachine::new(DummyContext), - _tx: tx, + transport, } } - pub fn new_manual( - db: db::DatabasePool, - actors: GlobalActors, - tx: Sender>, - ) -> Self { - Self { - db, - actors, - state: UserAgentStateMachine::new(DummyContext), - _tx: tx, - } - } - - fn transition(&mut self, event: UserAgentEvents) -> Result<(), Status> { + fn transition(&mut self, event: UserAgentEvents) -> Result<(), UserAgentError> { self.state.process_event(event).map_err(|e| { error!(?e, "State transition failed"); - Status::internal("State machine error") + UserAgentError::StateTransitionFailed })?; Ok(()) } + pub async fn process_transport_inbound(&mut self, req: UserAgentRequest) -> Output { + let msg = req.payload.ok_or_else(|| { + error!(actor = "useragent", "Received message with no payload"); + UserAgentError::MissingRequestPayload + })?; + + match msg { + UserAgentRequestPayload::AuthMessage(ClientAuthMessage { + payload: Some(ClientAuthPayload::AuthChallengeRequest(req)), + }) => self.handle_auth_challenge_request(req).await, + UserAgentRequestPayload::AuthMessage(ClientAuthMessage { + payload: Some(ClientAuthPayload::AuthChallengeSolution(solution)), + }) => self.handle_auth_challenge_solution(solution).await, + UserAgentRequestPayload::UnsealStart(unseal_start) => { + self.handle_unseal_request(unseal_start).await + } + UserAgentRequestPayload::UnsealEncryptedKey(unseal_encrypted_key) => { + self.handle_unseal_encrypted_key(unseal_encrypted_key).await + } + _ => Err(UserAgentError::UnexpectedRequestPayload), + } + } + async fn auth_with_bootstrap_token( &mut self, pubkey: ed25519_dalek::VerifyingKey, token: String, - ) -> Result { + ) -> Result { let token_ok: bool = self .actors .bootstrapper @@ -94,16 +187,19 @@ impl UserAgentActor { .await .map_err(|e| { error!(?pubkey, "Failed to consume bootstrap token: {e}"); - Status::internal("Bootstrap token consumption failed") + UserAgentError::BootstrapperActorUnreachable })?; if !token_ok { error!(?pubkey, "Invalid bootstrap token provided"); - return Err(Status::invalid_argument("Invalid bootstrap token")); + return Err(UserAgentError::InvalidBootstrapToken); } { - let mut conn = self.db.get().await.to_status()?; + let mut conn = self.db.get().await.map_err(|e| { + error!(error = ?e, "Database pool error"); + UserAgentError::DatabasePoolUnavailable + })?; diesel::insert_into(schema::useragent_client::table) .values(( @@ -112,7 +208,10 @@ impl UserAgentActor { )) .execute(&mut conn) .await - .to_status()?; + .map_err(|e| { + error!(error = ?e, "Database error"); + UserAgentError::DatabaseOperationFailed + })?; } self.transition(UserAgentEvents::ReceivedBootstrapToken)?; @@ -122,7 +221,10 @@ impl UserAgentActor { async fn auth_with_challenge(&mut self, pubkey: VerifyingKey, pubkey_bytes: Vec) -> Output { let nonce: Option = { - let mut db_conn = self.db.get().await.to_status()?; + let mut db_conn = self.db.get().await.map_err(|e| { + error!(error = ?e, "Database pool error"); + UserAgentError::DatabasePoolUnavailable + })?; db_conn .exclusive_transaction(|conn| { Box::pin(async move { @@ -147,12 +249,15 @@ impl UserAgentActor { }) .await .optional() - .to_status()? + .map_err(|e| { + error!(error = ?e, "Database error"); + UserAgentError::DatabaseOperationFailed + })? }; let Some(nonce) = nonce else { error!(?pubkey, "Public key not found in database"); - return Err(Status::unauthenticated("Public key not registered")); + return Err(UserAgentError::PublicKeyNotRegistered); }; let challenge = auth::AuthChallenge { @@ -177,19 +282,17 @@ impl UserAgentActor { fn verify_challenge_solution( &self, solution: &auth::AuthChallengeSolution, - ) -> Result<(bool, &ChallengeContext), Status> { + ) -> Result<(bool, &ChallengeContext), UserAgentError> { let UserAgentStates::WaitingForChallengeSolution(challenge_context) = self.state.state() else { error!("Received challenge solution in invalid state"); - return Err(Status::invalid_argument( - "Invalid state for challenge solution", - )); + return Err(UserAgentError::InvalidStateForChallengeSolution); }; let formatted_challenge = arbiter_proto::format_challenge(&challenge_context.challenge); let signature = solution.signature.as_slice().try_into().map_err(|_| { error!(?solution, "Invalid signature length"); - Status::invalid_argument("Invalid signature length") + UserAgentError::InvalidSignatureLength })?; let valid = challenge_context @@ -201,7 +304,7 @@ impl UserAgentActor { } } -type Output = Result; +type Output = Result; fn auth_response(payload: ServerAuthPayload) -> UserAgentResponse { UserAgentResponse { @@ -217,17 +320,18 @@ fn unseal_response(payload: UserAgentResponsePayload) -> UserAgentResponse { } } -#[messages] -impl UserAgentActor { - #[message] - pub async fn handle_unseal_request(&mut self, req: UnsealStart) -> Output { +impl UserAgentActor +where + Transport: Bi, +{ + async fn handle_unseal_request(&mut self, req: UnsealStart) -> Output { let secret = EphemeralSecret::random(); let public_key = PublicKey::from(&secret); let client_pubkey_bytes: [u8; 32] = req .client_pubkey .try_into() - .map_err(|_| Status::invalid_argument("client_pubkey must be 32 bytes"))?; + .map_err(|_| UserAgentError::InvalidClientPubkeyLength)?; let client_public_key = PublicKey::from(client_pubkey_bytes); @@ -243,13 +347,10 @@ impl UserAgentActor { )) } - #[message] - pub async fn handle_unseal_encrypted_key(&mut self, req: UnsealEncryptedKey) -> Output { + async fn handle_unseal_encrypted_key(&mut self, req: UnsealEncryptedKey) -> Output { let UserAgentStates::WaitingForUnsealKey(unseal_context) = self.state.state() else { error!("Received unseal encrypted key in invalid state"); - return Err(Status::failed_precondition( - "Invalid state for unseal encrypted key", - )); + return Err(UserAgentError::InvalidStateForUnsealEncryptedKey); }; let ephemeral_secret = { let mut secret_lock = unseal_context.secret.lock().unwrap(); @@ -313,7 +414,7 @@ impl UserAgentActor { Err(err) => { error!(?err, "Failed to send unseal request to keyholder"); self.transition(UserAgentEvents::ReceivedInvalidKey)?; - Err(Status::internal("Vault is not available")) + Err(UserAgentError::KeyHolderActorUnreachable) } } } @@ -327,14 +428,14 @@ impl UserAgentActor { } } - #[message] - pub async fn handle_auth_challenge_request(&mut self, req: AuthChallengeRequest) -> Output { - let pubkey = req.pubkey.as_array().ok_or(Status::invalid_argument( - "Expected pubkey to have specific length", - ))?; + async fn handle_auth_challenge_request(&mut self, req: AuthChallengeRequest) -> Output { + let pubkey = req + .pubkey + .as_array() + .ok_or(UserAgentError::InvalidAuthPubkeyLength)?; let pubkey = VerifyingKey::from_bytes(pubkey).map_err(|_err| { error!(?pubkey, "Failed to convert to VerifyingKey"); - Status::invalid_argument("Failed to convert pubkey to VerifyingKey") + UserAgentError::InvalidAuthPubkeyEncoding })?; self.transition(UserAgentEvents::AuthRequest)?; @@ -345,8 +446,7 @@ impl UserAgentActor { } } - #[message] - pub async fn handle_auth_challenge_solution( + async fn handle_auth_challenge_solution( &mut self, solution: auth::AuthChallengeSolution, ) -> Output { @@ -362,7 +462,72 @@ impl UserAgentActor { } else { error!("Client provided invalid solution to authentication challenge"); self.transition(UserAgentEvents::ReceivedBadSolution)?; - Err(Status::unauthenticated("Invalid challenge solution")) + Err(UserAgentError::InvalidChallengeSolution) } } } + + +impl Actor for UserAgentActor +where + Transport: Bi, +{ + type Args = Self; + + type Error = (); + + async fn on_start( + args: Self::Args, + _: kameo::prelude::ActorRef, + ) -> Result { + Ok(args) + } + + async fn next( + &mut self, + _actor_ref: kameo::prelude::WeakActorRef, + mailbox_rx: &mut kameo::prelude::MailboxReceiver, + ) -> Option> { + loop { + select! { + signal = mailbox_rx.recv() => { + return signal; + } + msg = self.transport.recv() => { + match msg { + Some(request) => { + match self.process_transport_inbound(request).await { + Ok(response) => { + if self.transport.send(Ok(response)).await.is_err() { + error!(actor = "useragent", reason = "channel closed", "send.failed"); + return Some(kameo::mailbox::Signal::Stop); + } + } + Err(err) => { + let _ = self.transport.send(Err(err)).await; + return Some(kameo::mailbox::Signal::Stop); + } + } + } + None => { + info!(actor = "useragent", "transport.closed"); + return Some(kameo::mailbox::Signal::Stop); + } + } + } + } + } + } +} + + +impl UserAgentActor> { + pub fn new_manual(db: db::DatabasePool, actors: GlobalActors) -> Self { + Self { + db, + actors, + state: UserAgentStateMachine::new(DummyContext), + transport: DummyTransport::new(), + } + } +} \ No newline at end of file diff --git a/server/crates/arbiter-server/src/actors/user_agent/transport.rs b/server/crates/arbiter-server/src/actors/user_agent/transport.rs deleted file mode 100644 index c1ac84c..0000000 --- a/server/crates/arbiter-server/src/actors/user_agent/transport.rs +++ /dev/null @@ -1,95 +0,0 @@ -use super::UserAgentActor; -use arbiter_proto::proto::{ - UserAgentRequest, UserAgentResponse, - auth::{ClientMessage as ClientAuthMessage, client_message::Payload as ClientAuthPayload}, - user_agent_request::Payload as UserAgentRequestPayload, -}; -use futures::StreamExt; -use kameo::{ - actor::{ActorRef, Spawn as _}, - error::SendError, -}; -use tokio::sync::mpsc; -use tonic::Status; -use tracing::error; - -use crate::{ - actors::user_agent::{ - HandleAuthChallengeRequest, HandleAuthChallengeSolution, HandleUnsealEncryptedKey, - HandleUnsealRequest, - }, - context::ServerContext, -}; - -pub(crate) async fn handle_user_agent( - context: ServerContext, - mut req_stream: tonic::Streaming, - tx: mpsc::Sender>, -) { - let actor = UserAgentActor::spawn(UserAgentActor::new(context, tx.clone())); - - while let Some(Ok(req)) = req_stream.next().await - && actor.is_alive() - { - match process_message(&actor, req).await { - Ok(resp) => { - if tx.send(Ok(resp)).await.is_err() { - error!(actor = "useragent", "Failed to send response to client"); - break; - } - } - Err(status) => { - let _ = tx.send(Err(status)).await; - break; - } - } - } - - actor.kill(); -} - -async fn process_message( - actor: &ActorRef, - req: UserAgentRequest, -) -> Result { - let msg = req.payload.ok_or_else(|| { - error!(actor = "useragent", "Received message with no payload"); - Status::invalid_argument("Expected message with payload") - })?; - - match msg { - UserAgentRequestPayload::AuthMessage(ClientAuthMessage { - payload: Some(ClientAuthPayload::AuthChallengeRequest(req)), - }) => actor - .ask(HandleAuthChallengeRequest { req }) - .await - .map_err(into_status), - UserAgentRequestPayload::AuthMessage(ClientAuthMessage { - payload: Some(ClientAuthPayload::AuthChallengeSolution(solution)), - }) => actor - .ask(HandleAuthChallengeSolution { solution }) - .await - .map_err(into_status), - UserAgentRequestPayload::UnsealStart(unseal_start) => actor - .ask(HandleUnsealRequest { req: unseal_start }) - .await - .map_err(into_status), - UserAgentRequestPayload::UnsealEncryptedKey(unseal_encrypted_key) => actor - .ask(HandleUnsealEncryptedKey { - req: unseal_encrypted_key, - }) - .await - .map_err(into_status), - _ => Err(Status::invalid_argument("Expected message with payload")), - } -} - -fn into_status(e: SendError) -> Status { - match e { - SendError::HandlerError(status) => status, - _ => { - error!(actor = "useragent", "Failed to send message to actor"); - Status::internal("session failure") - } - } -} diff --git a/server/crates/arbiter-server/src/lib.rs b/server/crates/arbiter-server/src/lib.rs index 9d86e27..9f6fcd9 100644 --- a/server/crates/arbiter-server/src/lib.rs +++ b/server/crates/arbiter-server/src/lib.rs @@ -1,16 +1,18 @@ #![forbid(unsafe_code)] use arbiter_proto::{ proto::{ClientRequest, ClientResponse, UserAgentRequest, UserAgentResponse}, - transport::BiStream, + transport::GrpcAdapter, }; use async_trait::async_trait; +use kameo::actor::Spawn; use tokio_stream::wrappers::ReceiverStream; use tokio::sync::mpsc; use tonic::{Request, Response, Status}; +use tracing::info; use crate::{ - actors::{client::handle_client, user_agent::handle_user_agent}, + actors::user_agent::UserAgentActor, context::ServerContext, }; @@ -38,28 +40,25 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server { async fn client( &self, - request: Request>, + _request: Request>, ) -> Result, Status> { - let req_stream = request.into_inner(); - let (tx, rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE); - tokio::spawn(handle_client( - self.context.clone(), - BiStream { - request_stream: req_stream, - response_sender: tx, - }, - )); - - Ok(Response::new(ReceiverStream::new(rx))) + todo!() } + #[tracing::instrument(level = "debug", skip(self))] async fn user_agent( &self, request: Request>, ) -> Result, Status> { let req_stream = request.into_inner(); let (tx, rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE); - tokio::spawn(handle_user_agent(self.context.clone(), req_stream, tx)); + + let adapter = GrpcAdapter::new(tx, req_stream); + + UserAgentActor::spawn(UserAgentActor::new(self.context.clone(), adapter)); + + info!(event = "connection established", "grpc.user_agent"); + Ok(Response::new(ReceiverStream::new(rx))) } } diff --git a/server/crates/arbiter-server/tests/user_agent/auth.rs b/server/crates/arbiter-server/tests/user_agent/auth.rs index c79d616..6604a52 100644 --- a/server/crates/arbiter-server/tests/user_agent/auth.rs +++ b/server/crates/arbiter-server/tests/user_agent/auth.rs @@ -1,20 +1,29 @@ use arbiter_proto::proto::{ UserAgentResponse, - auth::{self, AuthChallengeRequest, AuthOk}, + UserAgentRequest, + auth::{self, AuthChallengeRequest, AuthOk, ClientMessage, client_message::Payload as ClientAuthPayload}, + user_agent_request::Payload as UserAgentRequestPayload, user_agent_response::Payload as UserAgentResponsePayload, }; use arbiter_server::{ actors::{ GlobalActors, bootstrap::GetToken, - user_agent::{HandleAuthChallengeRequest, HandleAuthChallengeSolution, UserAgentActor}, + user_agent::{UserAgentActor, UserAgentError}, }, db::{self, schema}, }; use diesel::{ExpressionMethods as _, QueryDsl, insert_into}; use diesel_async::RunQueryDsl; use ed25519_dalek::Signer as _; -use kameo::actor::Spawn; + +fn auth_request(payload: ClientAuthPayload) -> UserAgentRequest { + UserAgentRequest { + payload: Some(UserAgentRequestPayload::AuthMessage(ClientMessage { + payload: Some(payload), + })), + } +} #[tokio::test] #[test_log::test] @@ -23,22 +32,20 @@ pub async fn test_bootstrap_token_auth() { let actors = GlobalActors::spawn(db.clone()).await.unwrap(); let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap(); - let user_agent = - UserAgentActor::new_manual(db.clone(), actors, tokio::sync::mpsc::channel(1).0); - let user_agent_ref = UserAgentActor::spawn(user_agent); + let mut user_agent = UserAgentActor::new_manual(db.clone(), actors); let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); - let result = user_agent_ref - .ask(HandleAuthChallengeRequest { - req: AuthChallengeRequest { + let result = user_agent + .process_transport_inbound(auth_request(ClientAuthPayload::AuthChallengeRequest( + AuthChallengeRequest { pubkey: pubkey_bytes, bootstrap_token: Some(token), }, - }) + ))) .await - .expect("Shouldn't fail to send message"); + .expect("Shouldn't fail to process message"); assert_eq!( result, @@ -68,35 +75,25 @@ pub async fn test_bootstrap_invalid_token_auth() { let db = db::create_test_pool().await; let actors = GlobalActors::spawn(db.clone()).await.unwrap(); - let user_agent = - UserAgentActor::new_manual(db.clone(), actors, tokio::sync::mpsc::channel(1).0); - let user_agent_ref = UserAgentActor::spawn(user_agent); + let mut user_agent = UserAgentActor::new_manual(db.clone(), actors); let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); - let result = user_agent_ref - .ask(HandleAuthChallengeRequest { - req: AuthChallengeRequest { + let result = user_agent + .process_transport_inbound(auth_request(ClientAuthPayload::AuthChallengeRequest( + AuthChallengeRequest { pubkey: pubkey_bytes, bootstrap_token: Some("invalid_token".to_string()), }, - }) + ))) .await; match result { - Err(kameo::error::SendError::HandlerError(status)) => { + Err(err) => { + assert_eq!(err, UserAgentError::InvalidBootstrapToken); + let status: tonic::Status = err.into(); assert_eq!(status.code(), tonic::Code::InvalidArgument); - insta::assert_debug_snapshot!(status, @r#" - Status { - code: InvalidArgument, - message: "Invalid bootstrap token", - source: None, - } - "#); - } - Err(other) => { - panic!("Expected SendError::HandlerError, got {other:?}"); } Ok(_) => { panic!("Expected error due to invalid bootstrap token, but got success"); @@ -110,9 +107,7 @@ pub async fn test_challenge_auth() { let db = db::create_test_pool().await; let actors = GlobalActors::spawn(db.clone()).await.unwrap(); - let user_agent = - UserAgentActor::new_manual(db.clone(), actors, tokio::sync::mpsc::channel(1).0); - let user_agent_ref = UserAgentActor::spawn(user_agent); + let mut user_agent = UserAgentActor::new_manual(db.clone(), actors); let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); @@ -126,15 +121,15 @@ pub async fn test_challenge_auth() { .unwrap(); } - let result = user_agent_ref - .ask(HandleAuthChallengeRequest { - req: AuthChallengeRequest { + let result = user_agent + .process_transport_inbound(auth_request(ClientAuthPayload::AuthChallengeRequest( + AuthChallengeRequest { pubkey: pubkey_bytes, bootstrap_token: None, }, - }) + ))) .await - .expect("Shouldn't fail to send message"); + .expect("Shouldn't fail to process message"); let UserAgentResponse { payload: @@ -151,14 +146,14 @@ pub async fn test_challenge_auth() { let signature = new_key.sign(&formatted_challenge); let serialized_signature = signature.to_bytes().to_vec(); - let result = user_agent_ref - .ask(HandleAuthChallengeSolution { - solution: auth::AuthChallengeSolution { + let result = user_agent + .process_transport_inbound(auth_request(ClientAuthPayload::AuthChallengeSolution( + auth::AuthChallengeSolution { signature: serialized_signature, }, - }) + ))) .await - .expect("Shouldn't fail to send message"); + .expect("Shouldn't fail to process message"); assert_eq!( result, diff --git a/server/crates/arbiter-server/tests/user_agent/unseal.rs b/server/crates/arbiter-server/tests/user_agent/unseal.rs index 9a7c85f..dc22270 100644 --- a/server/crates/arbiter-server/tests/user_agent/unseal.rs +++ b/server/crates/arbiter-server/tests/user_agent/unseal.rs @@ -1,27 +1,51 @@ use arbiter_proto::proto::{ - UnsealEncryptedKey, UnsealResult, UnsealStart, auth::AuthChallengeRequest, + UnsealEncryptedKey, UnsealResult, UnsealStart, UserAgentRequest, UserAgentResponse, + auth::{AuthChallengeRequest, ClientMessage, client_message::Payload as ClientAuthPayload}, + user_agent_request::Payload as UserAgentRequestPayload, user_agent_response::Payload as UserAgentResponsePayload, }; +use arbiter_proto::transport::DummyTransport; use arbiter_server::{ actors::{ GlobalActors, bootstrap::GetToken, keyholder::{Bootstrap, Seal}, - user_agent::{ - HandleAuthChallengeRequest, HandleUnsealEncryptedKey, HandleUnsealRequest, - UserAgentActor, - }, + user_agent::{UserAgentActor, UserAgentError}, }, db, }; use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit}; -use kameo::actor::{ActorRef, Spawn}; use memsafe::MemSafe; use x25519_dalek::{EphemeralSecret, PublicKey}; +type TestUserAgent = UserAgentActor>; + +fn auth_request(payload: ClientAuthPayload) -> UserAgentRequest { + UserAgentRequest { + payload: Some(UserAgentRequestPayload::AuthMessage(ClientMessage { + payload: Some(payload), + })), + } +} + +fn unseal_start_request(req: UnsealStart) -> UserAgentRequest { + UserAgentRequest { + payload: Some(UserAgentRequestPayload::UnsealStart(req)), + } +} + +fn unseal_key_request(req: UnsealEncryptedKey) -> UserAgentRequest { + UserAgentRequest { + payload: Some(UserAgentRequestPayload::UnsealEncryptedKey(req)), + } +} + async fn setup_authenticated_user_agent( seal_key: &[u8], -) -> (arbiter_server::db::DatabasePool, ActorRef) { +) -> ( + arbiter_server::db::DatabasePool, + TestUserAgent, +) { let db = db::create_test_pool().await; let actors = GlobalActors::spawn(db.clone()).await.unwrap(); @@ -34,38 +58,34 @@ async fn setup_authenticated_user_agent( .unwrap(); actors.key_holder.ask(Seal).await.unwrap(); - let user_agent = - UserAgentActor::new_manual(db.clone(), actors.clone(), tokio::sync::mpsc::channel(1).0); - let user_agent_ref = UserAgentActor::spawn(user_agent); + let mut user_agent = UserAgentActor::new_manual(db.clone(), actors.clone()); let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap(); let auth_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); - user_agent_ref - .ask(HandleAuthChallengeRequest { - req: AuthChallengeRequest { + user_agent + .process_transport_inbound(auth_request(ClientAuthPayload::AuthChallengeRequest( + AuthChallengeRequest { pubkey: auth_key.verifying_key().to_bytes().to_vec(), bootstrap_token: Some(token), }, - }) + ))) .await .unwrap(); - (db, user_agent_ref) + (db, user_agent) } async fn client_dh_encrypt( - user_agent_ref: &ActorRef, + user_agent: &mut TestUserAgent, key_to_send: &[u8], ) -> UnsealEncryptedKey { let client_secret = EphemeralSecret::random(); let client_public = PublicKey::from(&client_secret); - let response = user_agent_ref - .ask(HandleUnsealRequest { - req: UnsealStart { - client_pubkey: client_public.as_bytes().to_vec(), - }, - }) + let response = user_agent + .process_transport_inbound(unseal_start_request(UnsealStart { + client_pubkey: client_public.as_bytes().to_vec(), + })) .await .unwrap(); @@ -95,12 +115,12 @@ async fn client_dh_encrypt( #[test_log::test] pub async fn test_unseal_success() { let seal_key = b"test-seal-key"; - let (_db, user_agent_ref) = setup_authenticated_user_agent(seal_key).await; + let (_db, mut user_agent) = setup_authenticated_user_agent(seal_key).await; - let encrypted_key = client_dh_encrypt(&user_agent_ref, seal_key).await; + let encrypted_key = client_dh_encrypt(&mut user_agent, seal_key).await; - let response = user_agent_ref - .ask(HandleUnsealEncryptedKey { req: encrypted_key }) + let response = user_agent + .process_transport_inbound(unseal_key_request(encrypted_key)) .await .unwrap(); @@ -113,12 +133,12 @@ pub async fn test_unseal_success() { #[tokio::test] #[test_log::test] pub async fn test_unseal_wrong_seal_key() { - let (_db, user_agent_ref) = setup_authenticated_user_agent(b"correct-key").await; + let (_db, mut user_agent) = setup_authenticated_user_agent(b"correct-key").await; - let encrypted_key = client_dh_encrypt(&user_agent_ref, b"wrong-key").await; + let encrypted_key = client_dh_encrypt(&mut user_agent, b"wrong-key").await; - let response = user_agent_ref - .ask(HandleUnsealEncryptedKey { req: encrypted_key }) + let response = user_agent + .process_transport_inbound(unseal_key_request(encrypted_key)) .await .unwrap(); @@ -131,28 +151,24 @@ pub async fn test_unseal_wrong_seal_key() { #[tokio::test] #[test_log::test] pub async fn test_unseal_corrupted_ciphertext() { - let (_db, user_agent_ref) = setup_authenticated_user_agent(b"test-key").await; + let (_db, mut user_agent) = setup_authenticated_user_agent(b"test-key").await; let client_secret = EphemeralSecret::random(); let client_public = PublicKey::from(&client_secret); - user_agent_ref - .ask(HandleUnsealRequest { - req: UnsealStart { - client_pubkey: client_public.as_bytes().to_vec(), - }, - }) + user_agent + .process_transport_inbound(unseal_start_request(UnsealStart { + client_pubkey: client_public.as_bytes().to_vec(), + })) .await .unwrap(); - let response = user_agent_ref - .ask(HandleUnsealEncryptedKey { - req: UnsealEncryptedKey { - nonce: vec![0u8; 24], - ciphertext: vec![0u8; 32], - associated_data: vec![], - }, - }) + let response = user_agent + .process_transport_inbound(unseal_key_request(UnsealEncryptedKey { + nonce: vec![0u8; 24], + ciphertext: vec![0u8; 32], + associated_data: vec![], + })) .await .unwrap(); @@ -168,24 +184,20 @@ pub async fn test_unseal_start_without_auth_fails() { let db = db::create_test_pool().await; let actors = GlobalActors::spawn(db.clone()).await.unwrap(); - let user_agent = - UserAgentActor::new_manual(db.clone(), actors, tokio::sync::mpsc::channel(1).0); - let user_agent_ref = UserAgentActor::spawn(user_agent); + let mut user_agent = UserAgentActor::new_manual(db.clone(), actors); let client_secret = EphemeralSecret::random(); let client_public = PublicKey::from(&client_secret); - let result = user_agent_ref - .ask(HandleUnsealRequest { - req: UnsealStart { - client_pubkey: client_public.as_bytes().to_vec(), - }, - }) + let result = user_agent + .process_transport_inbound(unseal_start_request(UnsealStart { + client_pubkey: client_public.as_bytes().to_vec(), + })) .await; match result { - Err(kameo::error::SendError::HandlerError(status)) => { - assert_eq!(status.code(), tonic::Code::Internal); + Err(err) => { + assert_eq!(err, UserAgentError::StateTransitionFailed); } other => panic!("Expected state machine error, got {other:?}"), } @@ -195,13 +207,13 @@ pub async fn test_unseal_start_without_auth_fails() { #[test_log::test] pub async fn test_unseal_retry_after_invalid_key() { let seal_key = b"real-seal-key"; - let (_db, user_agent_ref) = setup_authenticated_user_agent(seal_key).await; + let (_db, mut user_agent) = setup_authenticated_user_agent(seal_key).await; { - let encrypted_key = client_dh_encrypt(&user_agent_ref, b"wrong-key").await; + let encrypted_key = client_dh_encrypt(&mut user_agent, b"wrong-key").await; - let response = user_agent_ref - .ask(HandleUnsealEncryptedKey { req: encrypted_key }) + let response = user_agent + .process_transport_inbound(unseal_key_request(encrypted_key)) .await .unwrap(); @@ -212,10 +224,10 @@ pub async fn test_unseal_retry_after_invalid_key() { } { - let encrypted_key = client_dh_encrypt(&user_agent_ref, seal_key).await; + let encrypted_key = client_dh_encrypt(&mut user_agent, seal_key).await; - let response = user_agent_ref - .ask(HandleUnsealEncryptedKey { req: encrypted_key }) + let response = user_agent + .process_transport_inbound(unseal_key_request(encrypted_key)) .await .unwrap();