From 2ff4d0961c29674652ffcf7d48dc67597eae2a04 Mon Sep 17 00:00:00 2001 From: hdbg Date: Wed, 18 Mar 2026 22:40:07 +0100 Subject: [PATCH] refactor(server::client): migrated to new connection design --- protobufs/client.proto | 33 ++- .../arbiter-server/src/actors/client/auth.rs | 114 ++++------ .../arbiter-server/src/actors/client/mod.rs | 57 +---- .../src/actors/client/session.rs | 84 +++---- .../src/actors/user_agent/auth/state.rs | 11 +- .../crates/arbiter-server/src/grpc/client.rs | 214 ++++++++---------- .../arbiter-server/src/grpc/client/auth.rs | 131 +++++++++++ server/crates/arbiter-server/src/grpc/mod.rs | 19 +- .../arbiter-server/src/grpc/user_agent.rs | 131 ++++++----- server/crates/arbiter-server/src/lib.rs | 7 +- .../arbiter-server/tests/client/auth.rs | 45 ++-- .../crates/arbiter-server/tests/common/mod.rs | 25 +- .../arbiter-server/tests/user_agent/auth.rs | 2 +- .../arbiter-server/tests/user_agent/unseal.rs | 2 +- 14 files changed, 474 insertions(+), 401 deletions(-) create mode 100644 server/crates/arbiter-server/src/grpc/client/auth.rs diff --git a/protobufs/client.proto b/protobufs/client.proto index 62761c3..dbe9708 100644 --- a/protobufs/client.proto +++ b/protobufs/client.proto @@ -3,6 +3,7 @@ syntax = "proto3"; package arbiter.client; import "evm.proto"; +import "google/protobuf/empty.proto"; message AuthChallengeRequest { bytes pubkey = 1; @@ -17,30 +18,38 @@ message AuthChallengeSolution { bytes signature = 1; } -message AuthOk {} +enum AuthResult { + AUTH_RESULT_UNSPECIFIED = 0; + AUTH_RESULT_SUCCESS = 1; + AUTH_RESULT_INVALID_KEY = 2; + AUTH_RESULT_INVALID_SIGNATURE = 3; + AUTH_RESULT_APPROVAL_DENIED = 4; + AUTH_RESULT_NO_USER_AGENTS_ONLINE = 5; + AUTH_RESULT_INTERNAL = 6; +} + +enum VaultState { + VAULT_STATE_UNSPECIFIED = 0; + VAULT_STATE_UNBOOTSTRAPPED = 1; + VAULT_STATE_SEALED = 2; + VAULT_STATE_UNSEALED = 3; + VAULT_STATE_ERROR = 4; +} message ClientRequest { oneof payload { AuthChallengeRequest auth_challenge_request = 1; AuthChallengeSolution auth_challenge_solution = 2; + google.protobuf.Empty query_vault_state = 3; } } -message ClientConnectError { - enum Code { - UNKNOWN = 0; - APPROVAL_DENIED = 1; - NO_USER_AGENTS_ONLINE = 2; - } - Code code = 1; -} - message ClientResponse { oneof payload { AuthChallenge auth_challenge = 1; - AuthOk auth_ok = 2; - ClientConnectError client_connect_error = 5; + AuthResult auth_result = 2; arbiter.evm.EvmSignTransactionResponse evm_sign_transaction = 3; arbiter.evm.EvmAnalyzeTransactionResponse evm_analyze_transaction = 4; + VaultState vault_state = 6; } } diff --git a/server/crates/arbiter-server/src/actors/client/auth.rs b/server/crates/arbiter-server/src/actors/client/auth.rs index c69fb77..ffd425a 100644 --- a/server/crates/arbiter-server/src/actors/client/auth.rs +++ b/server/crates/arbiter-server/src/actors/client/auth.rs @@ -1,30 +1,25 @@ -use arbiter_proto::{format_challenge, transport::expect_message}; +use arbiter_proto::{ + format_challenge, + transport::{Bi, expect_message}, +}; use diesel::{ ExpressionMethods as _, OptionalExtension as _, QueryDsl as _, dsl::insert_into, update, }; use diesel_async::RunQueryDsl as _; -use ed25519_dalek::VerifyingKey; +use ed25519_dalek::{Signature, VerifyingKey}; use kameo::error::SendError; use tracing::error; use crate::{ actors::{ - client::{ClientConnection, ConnectErrorCode, Request, Response}, + client::ClientConnection, router::{self, RequestClientApproval}, }, db::{self, schema::program_client}, }; -use super::session::ClientSession; - #[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)] pub enum Error { - #[error("Unexpected message payload")] - UnexpectedMessagePayload, - #[error("Invalid client public key length")] - InvalidClientPubkeyLength, - #[error("Invalid client public key encoding")] - InvalidAuthPubkeyEncoding, #[error("Database pool unavailable")] DatabasePoolUnavailable, #[error("Database operation failed")] @@ -33,8 +28,6 @@ pub enum Error { InvalidChallengeSolution, #[error("Client approval request failed")] ApproveError(#[from] ApproveError), - #[error("Internal error")] - InternalError, #[error("Transport error")] Transport, } @@ -49,6 +42,18 @@ pub enum ApproveError { Upstream(router::ApprovalError), } +#[derive(Debug, Clone)] +pub enum Inbound { + AuthChallengeRequest { pubkey: VerifyingKey }, + AuthChallengeSolution { signature: Signature }, +} + +#[derive(Debug, Clone)] +pub enum Outbound { + AuthChallenge { pubkey: VerifyingKey, nonce: i32 }, + AuthSuccess, +} + /// Atomically reads and increments the nonce for a known client. /// Returns `None` if the pubkey is not registered. async fn get_nonce(db: &db::DatabasePool, pubkey: &VerifyingKey) -> Result, Error> { @@ -141,27 +146,24 @@ async fn insert_client(db: &db::DatabasePool, pubkey: &VerifyingKey) -> Result<( Ok(()) } -async fn challenge_client( - props: &mut ClientConnection, +async fn challenge_client( + transport: &mut T, pubkey: VerifyingKey, nonce: i32, -) -> Result<(), Error> { - let challenge_pubkey = pubkey.as_bytes().to_vec(); - - props - .transport - .send(Ok(Response::AuthChallenge { - pubkey: challenge_pubkey.clone(), - nonce, - })) +) -> Result<(), Error> +where + T: Bi> + ?Sized, +{ + transport + .send(Ok(Outbound::AuthChallenge { pubkey, nonce })) .await .map_err(|e| { error!(error = ?e, "Failed to send auth challenge"); Error::Transport })?; - let signature = expect_message(&mut *props.transport, |req: Request| match req { - Request::AuthChallengeSolution { signature } => Some(signature), + let signature = expect_message(transport, |req: Inbound| match req { + Inbound::AuthChallengeSolution { signature } => Some(signature), _ => None, }) .await @@ -170,13 +172,9 @@ async fn challenge_client( Error::Transport })?; - let formatted = format_challenge(nonce, &challenge_pubkey); - let sig = signature.as_slice().try_into().map_err(|_| { - error!("Invalid signature length"); - Error::InvalidChallengeSolution - })?; + let formatted = format_challenge(nonce, pubkey.as_bytes()); - pubkey.verify_strict(&formatted, &sig).map_err(|_| { + pubkey.verify_strict(&formatted, &signature).map_err(|_| { error!("Challenge solution verification failed"); Error::InvalidChallengeSolution })?; @@ -184,30 +182,17 @@ async fn challenge_client( Ok(()) } -fn connect_error_code(err: &Error) -> ConnectErrorCode { - match err { - Error::ApproveError(ApproveError::Denied) => ConnectErrorCode::ApprovalDenied, - Error::ApproveError(ApproveError::Upstream( - router::ApprovalError::NoUserAgentsConnected, - )) => ConnectErrorCode::NoUserAgentsOnline, - _ => ConnectErrorCode::Unknown, - } -} - -async fn authenticate(props: &mut ClientConnection) -> Result { - let Some(Request::AuthChallengeRequest { - pubkey: challenge_pubkey, - }) = props.transport.recv().await - else { +pub async fn authenticate( + props: &mut ClientConnection, + transport: &mut T, +) -> Result +where + T: Bi> + Send + ?Sized, +{ + let Some(Inbound::AuthChallengeRequest { pubkey }) = transport.recv().await else { return Err(Error::Transport); }; - let pubkey_bytes = challenge_pubkey - .as_array() - .ok_or(Error::InvalidClientPubkeyLength)?; - let pubkey = - VerifyingKey::from_bytes(pubkey_bytes).map_err(|_| Error::InvalidAuthPubkeyEncoding)?; - let nonce = match get_nonce(&props.db, &pubkey).await? { Some(nonce) => nonce, None => { @@ -217,21 +202,14 @@ async fn authenticate(props: &mut ClientConnection) -> Result Result { - match authenticate(&mut props).await { - Ok(_pubkey) => Ok(ClientSession::new(props)), - Err(err) => { - let code = connect_error_code(&err); - let _ = props - .transport - .send(Ok(Response::ClientConnectError { code })) - .await; - Err(err) - } - } -} diff --git a/server/crates/arbiter-server/src/actors/client/mod.rs b/server/crates/arbiter-server/src/actors/client/mod.rs index 55c0ed7..3fae866 100644 --- a/server/crates/arbiter-server/src/actors/client/mod.rs +++ b/server/crates/arbiter-server/src/actors/client/mod.rs @@ -7,68 +7,31 @@ use crate::{ db, }; -#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)] -pub enum ClientError { - #[error("Expected message with payload")] - MissingRequestPayload, - #[error("Unexpected request payload")] - UnexpectedRequestPayload, - #[error("State machine error")] - StateTransitionFailed, - #[error("Connection registration failed")] - ConnectionRegistrationFailed, - #[error(transparent)] - Auth(#[from] auth::Error), -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum ConnectErrorCode { - Unknown, - ApprovalDenied, - NoUserAgentsOnline, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Request { - AuthChallengeRequest { pubkey: Vec }, - AuthChallengeSolution { signature: Vec }, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Response { - AuthChallenge { pubkey: Vec, nonce: i32 }, - AuthOk, - ClientConnectError { code: ConnectErrorCode }, -} - -pub type Transport = Box> + Send>; - pub struct ClientConnection { pub(crate) db: db::DatabasePool, - pub(crate) transport: Transport, pub(crate) actors: GlobalActors, } impl ClientConnection { - pub fn new(db: db::DatabasePool, transport: Transport, actors: GlobalActors) -> Self { - Self { - db, - transport, - actors, - } + pub fn new(db: db::DatabasePool, actors: GlobalActors) -> Self { + Self { db, actors } } } pub mod auth; pub mod session; -pub async fn connect_client(props: ClientConnection) { - match auth::authenticate_and_create(props).await { - Ok(session) => { - ClientSession::spawn(session); +pub async fn connect_client(mut props: ClientConnection, transport: &mut T) +where + T: Bi> + Send + ?Sized, +{ + match auth::authenticate(&mut props, transport).await { + Ok(_pubkey) => { + ClientSession::spawn(ClientSession::new(props)); info!("Client authenticated, session started"); } Err(err) => { + let _ = transport.send(Err(err.clone())).await; error!(?err, "Authentication failed, closing connection"); } } diff --git a/server/crates/arbiter-server/src/actors/client/session.rs b/server/crates/arbiter-server/src/actors/client/session.rs index fb18feb..93f2c6e 100644 --- a/server/crates/arbiter-server/src/actors/client/session.rs +++ b/server/crates/arbiter-server/src/actors/client/session.rs @@ -1,12 +1,9 @@ -use kameo::Actor; -use tokio::select; -use tracing::{error, info}; +use kameo::{Actor, messages}; +use tracing::error; use crate::{ actors::{ - GlobalActors, - client::{ClientConnection, ClientError, Request, Response}, - router::RegisterClient, + GlobalActors, client::ClientConnection, keyholder::KeyHolderState, router::RegisterClient, }, db, }; @@ -19,19 +16,30 @@ impl ClientSession { pub(crate) fn new(props: ClientConnection) -> Self { Self { props } } - - pub async fn process_transport_inbound(&mut self, req: Request) -> Output { - let _ = req; - Err(ClientError::UnexpectedRequestPayload) - } } -type Output = Result; +#[messages] +impl ClientSession { + #[message] + pub(crate) async fn handle_query_vault_state(&mut self) -> Result { + use crate::actors::keyholder::GetState; + + let vault_state = match self.props.actors.key_holder.ask(GetState {}).await { + Ok(state) => state, + Err(err) => { + error!(?err, actor = "client", "keyholder.query.failed"); + return Err(Error::Internal); + } + }; + + Ok(vault_state) + } +} impl Actor for ClientSession { type Args = Self; - type Error = ClientError; + type Error = Error; async fn on_start( args: Self::Args, @@ -42,52 +50,22 @@ impl Actor for ClientSession { .router .ask(RegisterClient { actor: this }) .await - .map_err(|_| ClientError::ConnectionRegistrationFailed)?; + .map_err(|_| Error::ConnectionRegistrationFailed)?; Ok(args) } - - async fn next( - &mut self, - _actor_ref: kameo::prelude::WeakActorRef, - mailbox_rx: &mut kameo::prelude::MailboxReceiver, - ) -> Option> { - loop { - select! { - signal = mailbox_rx.recv() => { - return signal; - } - msg = self.props.transport.recv() => { - match msg { - Some(request) => { - match self.process_transport_inbound(request).await { - Ok(resp) => { - if self.props.transport.send(Ok(resp)).await.is_err() { - error!(actor = "client", reason = "channel closed", "send.failed"); - return Some(kameo::mailbox::Signal::Stop); - } - } - Err(err) => { - let _ = self.props.transport.send(Err(err)).await; - return Some(kameo::mailbox::Signal::Stop); - } - } - } - None => { - info!(actor = "client", "transport.closed"); - return Some(kameo::mailbox::Signal::Stop); - } - } - } - } - } - } } impl ClientSession { pub fn new_test(db: db::DatabasePool, actors: GlobalActors) -> Self { - use arbiter_proto::transport::DummyTransport; - let transport: super::Transport = Box::new(DummyTransport::new()); - let props = ClientConnection::new(db, transport, actors); + let props = ClientConnection::new(db, actors); Self { props } } } + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("Connection registration failed")] + ConnectionRegistrationFailed, + #[error("Internal error")] + Internal, +} diff --git a/server/crates/arbiter-server/src/actors/user_agent/auth/state.rs b/server/crates/arbiter-server/src/actors/user_agent/auth/state.rs index 7a5991d..2fdd048 100644 --- a/server/crates/arbiter-server/src/actors/user_agent/auth/state.rs +++ b/server/crates/arbiter-server/src/actors/user_agent/auth/state.rs @@ -106,7 +106,7 @@ pub struct AuthContext<'a, T> { } impl<'a, T> AuthContext<'a, T> { - pub fn new(conn: &'a mut UserAgentConnection, transport: T) -> Self { + pub fn new(conn: &'a mut UserAgentConnection, transport: T) -> Self { Self { conn, transport } } } @@ -124,8 +124,7 @@ where let stored_bytes = pubkey.to_stored_bytes(); let nonce = create_nonce(&self.conn.db, &stored_bytes).await?; - self - .transport + self.transport .send(Ok(Outbound::AuthChallenge { nonce })) .await .map_err(|e| { @@ -165,8 +164,7 @@ where register_key(&self.conn.db, &pubkey).await?; - self - .transport + self.transport .send(Ok(Outbound::AuthSuccess)) .await .map_err(|_| Error::Transport)?; @@ -214,8 +212,7 @@ where }; if valid { - self - .transport + self.transport .send(Ok(Outbound::AuthSuccess)) .await .map_err(|_| Error::Transport)?; diff --git a/server/crates/arbiter-server/src/grpc/client.rs b/server/crates/arbiter-server/src/grpc/client.rs index 3d41785..17442a0 100644 --- a/server/crates/arbiter-server/src/grpc/client.rs +++ b/server/crates/arbiter-server/src/grpc/client.rs @@ -1,142 +1,118 @@ use arbiter_proto::{ proto::client::{ - AuthChallenge as ProtoAuthChallenge, AuthChallengeRequest as ProtoAuthChallengeRequest, - AuthChallengeSolution as ProtoAuthChallengeSolution, AuthOk as ProtoAuthOk, - ClientConnectError, ClientRequest, ClientResponse, - client_connect_error::Code as ProtoClientConnectErrorCode, + ClientRequest, ClientResponse, VaultState as ProtoVaultState, client_request::Payload as ClientRequestPayload, client_response::Payload as ClientResponsePayload, }, - transport::{Bi, Error as TransportError, Sender}, + transport::{Receiver, Sender, grpc::GrpcBi}, }; -use async_trait::async_trait; -use futures::StreamExt as _; -use tokio::sync::mpsc; -use tonic::{Status, Streaming}; +use kameo::{ + actor::{ActorRef, Spawn as _}, + error::SendError, +}; +use tracing::{info, warn}; -use crate::actors::client::{ - self, ClientError, ConnectErrorCode, Request as DomainRequest, Response as DomainResponse, +use crate::{ + actors::{ + client::{ + self, ClientConnection, + session::{ClientSession, Error, HandleQueryVaultState}, + }, + keyholder::KeyHolderState, + }, + utils::defer, }; -pub struct GrpcTransport { - sender: mpsc::Sender>, - receiver: Streaming, -} +mod auth; -impl GrpcTransport { - pub fn new( - sender: mpsc::Sender>, - receiver: Streaming, - ) -> Self { - Self { sender, receiver } - } - - fn request_to_domain(request: ClientRequest) -> Result { - match request.payload { - Some(ClientRequestPayload::AuthChallengeRequest(ProtoAuthChallengeRequest { - pubkey, - })) => Ok(DomainRequest::AuthChallengeRequest { pubkey }), - Some(ClientRequestPayload::AuthChallengeSolution(ProtoAuthChallengeSolution { - signature, - })) => Ok(DomainRequest::AuthChallengeSolution { signature }), - None => Err(Status::invalid_argument("Missing client request payload")), - } - } - - fn response_to_proto(response: DomainResponse) -> ClientResponse { - let payload = match response { - DomainResponse::AuthChallenge { pubkey, nonce } => { - ClientResponsePayload::AuthChallenge(ProtoAuthChallenge { pubkey, nonce }) - } - DomainResponse::AuthOk => ClientResponsePayload::AuthOk(ProtoAuthOk {}), - DomainResponse::ClientConnectError { code } => { - ClientResponsePayload::ClientConnectError(ClientConnectError { - code: match code { - ConnectErrorCode::Unknown => ProtoClientConnectErrorCode::Unknown, - ConnectErrorCode::ApprovalDenied => { - ProtoClientConnectErrorCode::ApprovalDenied - } - ConnectErrorCode::NoUserAgentsOnline => { - ProtoClientConnectErrorCode::NoUserAgentsOnline - } - } - .into(), - }) - } +async fn dispatch_loop( + mut bi: GrpcBi, + actor: ActorRef, +) { + loop { + let Some(conn) = bi.recv().await else { + return; }; - ClientResponse { - payload: Some(payload), - } - } - - fn error_to_status(value: ClientError) -> Status { - match value { - ClientError::MissingRequestPayload | ClientError::UnexpectedRequestPayload => { - Status::invalid_argument("Expected message with payload") - } - ClientError::StateTransitionFailed => Status::internal("State machine error"), - ClientError::Auth(ref err) => auth_error_status(err), - ClientError::ConnectionRegistrationFailed => { - Status::internal("Connection registration failed") - } + if dispatch_conn_message(&mut bi, &actor, conn).await.is_err() { + return; } } } -#[async_trait] -impl Sender> for GrpcTransport { - async fn send( - &mut self, - item: Result, - ) -> Result<(), TransportError> { - let outbound = match item { - Ok(message) => Ok(Self::response_to_proto(message)), - Err(err) => Err(Self::error_to_status(err)), - }; +async fn dispatch_conn_message( + bi: &mut GrpcBi, + actor: &ActorRef, + conn: Result, +) -> Result<(), ()> { + let conn = match conn { + Ok(conn) => conn, + Err(err) => { + warn!(error = ?err, "Failed to receive client request"); + return Err(()); + } + }; - self.sender - .send(outbound) - .await - .map_err(|_| TransportError::ChannelClosed) - } -} + let Some(payload) = conn.payload else { + let _ = bi + .send(Err(tonic::Status::invalid_argument( + "Missing client request payload", + ))) + .await; + return Err(()); + }; -#[async_trait] -impl Bi> for GrpcTransport { - async fn recv(&mut self) -> Option { - match self.receiver.next().await { - Some(Ok(item)) => match Self::request_to_domain(item) { - Ok(request) => Some(request), - Err(status) => { - let _ = self.sender.send(Err(status)).await; - None + let payload = match payload { + ClientRequestPayload::QueryVaultState(_) => ClientResponsePayload::VaultState( + match actor.ask(HandleQueryVaultState {}).await { + Ok(KeyHolderState::Unbootstrapped) => ProtoVaultState::Unbootstrapped, + Ok(KeyHolderState::Sealed) => ProtoVaultState::Sealed, + Ok(KeyHolderState::Unsealed) => ProtoVaultState::Unsealed, + Err(SendError::HandlerError(Error::Internal)) => ProtoVaultState::Error, + Err(err) => { + warn!(error = ?err, "Failed to query vault state"); + ProtoVaultState::Error } - }, - Some(Err(error)) => { - tracing::error!(error = ?error, "grpc client recv failed; closing stream"); - None } - None => None, + .into(), + ), + payload => { + warn!(?payload, "Unsupported post-auth client request"); + let _ = bi + .send(Err(tonic::Status::invalid_argument( + "Unsupported client request", + ))) + .await; + return Err(()); + } + }; + + bi.send(Ok(ClientResponse { + payload: Some(payload), + })) + .await + .map_err(|_| ()) +} + +pub async fn start(conn: ClientConnection, mut bi: GrpcBi) { + let mut conn = conn; + match auth::start(&mut conn, &mut bi).await { + Ok(_) => { + let actor = + client::session::ClientSession::spawn(client::session::ClientSession::new(conn)); + let actor_for_cleanup = actor.clone(); + let _ = defer(move || { + actor_for_cleanup.kill(); + }); + + info!("Client authenticated successfully"); + dispatch_loop(bi, actor).await; + } + Err(e) => { + let mut transport = auth::AuthTransportAdapter(&mut bi); + let _ = transport.send(Err(e.clone())).await; + warn!(error = ?e, "Authentication failed"); + return; } } } - -fn auth_error_status(value: &client::auth::Error) -> Status { - use client::auth::Error; - - match value { - Error::UnexpectedMessagePayload | Error::InvalidClientPubkeyLength => { - Status::invalid_argument(value.to_string()) - } - Error::InvalidAuthPubkeyEncoding => { - Status::invalid_argument("Failed to convert pubkey to VerifyingKey") - } - Error::InvalidChallengeSolution => Status::unauthenticated(value.to_string()), - Error::ApproveError(_) => Status::permission_denied(value.to_string()), - Error::Transport => Status::internal("Transport error"), - Error::DatabasePoolUnavailable => Status::internal("Database pool error"), - Error::DatabaseOperationFailed => Status::internal("Database error"), - Error::InternalError => Status::internal("Internal error"), - } -} diff --git a/server/crates/arbiter-server/src/grpc/client/auth.rs b/server/crates/arbiter-server/src/grpc/client/auth.rs new file mode 100644 index 0000000..8374b36 --- /dev/null +++ b/server/crates/arbiter-server/src/grpc/client/auth.rs @@ -0,0 +1,131 @@ +use arbiter_proto::{ + proto::client::{ + AuthChallenge as ProtoAuthChallenge, AuthChallengeRequest as ProtoAuthChallengeRequest, + AuthChallengeSolution as ProtoAuthChallengeSolution, AuthResult as ProtoAuthResult, + ClientRequest, ClientResponse, client_request::Payload as ClientRequestPayload, + client_response::Payload as ClientResponsePayload, + }, + transport::{Bi, Error as TransportError, Receiver, Sender, grpc::GrpcBi}, +}; +use async_trait::async_trait; +use tracing::warn; + +use crate::actors::client::{self, ClientConnection, auth}; + +pub struct AuthTransportAdapter<'a>(pub(super) &'a mut GrpcBi); + +impl AuthTransportAdapter<'_> { + fn response_to_proto(response: auth::Outbound) -> ClientResponse { + let payload = match response { + auth::Outbound::AuthChallenge { pubkey, nonce } => { + ClientResponsePayload::AuthChallenge(ProtoAuthChallenge { + pubkey: pubkey.to_bytes().to_vec(), + nonce, + }) + } + auth::Outbound::AuthSuccess => { + ClientResponsePayload::AuthResult(ProtoAuthResult::Success.into()) + } + }; + + ClientResponse { + payload: Some(payload), + } + } + + fn error_to_proto(error: auth::Error) -> ClientResponse { + ClientResponse { + payload: Some(ClientResponsePayload::AuthResult( + match error { + auth::Error::InvalidChallengeSolution => ProtoAuthResult::InvalidSignature, + auth::Error::ApproveError(auth::ApproveError::Denied) => { + ProtoAuthResult::ApprovalDenied + } + auth::Error::ApproveError(auth::ApproveError::Upstream( + crate::actors::router::ApprovalError::NoUserAgentsConnected, + )) => ProtoAuthResult::NoUserAgentsOnline, + auth::Error::ApproveError(auth::ApproveError::Internal) + | auth::Error::DatabasePoolUnavailable + | auth::Error::DatabaseOperationFailed + | auth::Error::Transport => ProtoAuthResult::Internal, + } + .into(), + )), + } + } + + async fn send_auth_result(&mut self, result: ProtoAuthResult) -> Result<(), TransportError> { + self.0 + .send(Ok(ClientResponse { + payload: Some(ClientResponsePayload::AuthResult(result.into())), + })) + .await + } +} + +#[async_trait] +impl Sender> for AuthTransportAdapter<'_> { + async fn send( + &mut self, + item: Result, + ) -> Result<(), TransportError> { + let outbound = match item { + Ok(message) => Ok(AuthTransportAdapter::response_to_proto(message)), + Err(err) => Ok(AuthTransportAdapter::error_to_proto(err)), + }; + + self.0.send(outbound).await + } +} + +#[async_trait] +impl Receiver for AuthTransportAdapter<'_> { + async fn recv(&mut self) -> Option { + let request = match self.0.recv().await? { + Ok(request) => request, + Err(error) => { + warn!(error = ?error, "grpc client recv failed; closing stream"); + return None; + } + }; + + let payload = request.payload?; + + match payload { + ClientRequestPayload::AuthChallengeRequest(ProtoAuthChallengeRequest { pubkey }) => { + let Ok(pubkey) = <[u8; 32]>::try_from(pubkey) else { + let _ = self.send_auth_result(ProtoAuthResult::InvalidKey).await; + return None; + }; + let Ok(pubkey) = ed25519_dalek::VerifyingKey::from_bytes(&pubkey) else { + let _ = self.send_auth_result(ProtoAuthResult::InvalidKey).await; + return None; + }; + Some(auth::Inbound::AuthChallengeRequest { pubkey }) + } + ClientRequestPayload::AuthChallengeSolution(ProtoAuthChallengeSolution { + signature, + }) => { + let Ok(signature) = ed25519_dalek::Signature::try_from(signature.as_slice()) else { + let _ = self + .send_auth_result(ProtoAuthResult::InvalidSignature) + .await; + return None; + }; + Some(auth::Inbound::AuthChallengeSolution { signature }) + } + _ => None, + } + } +} + +impl Bi> for AuthTransportAdapter<'_> {} + +pub async fn start( + conn: &mut ClientConnection, + bi: &mut GrpcBi, +) -> Result<(), auth::Error> { + let mut transport = AuthTransportAdapter(bi); + client::auth::authenticate(conn, &mut transport).await?; + Ok(()) +} diff --git a/server/crates/arbiter-server/src/grpc/mod.rs b/server/crates/arbiter-server/src/grpc/mod.rs index 204d6b1..7bdedea 100644 --- a/server/crates/arbiter-server/src/grpc/mod.rs +++ b/server/crates/arbiter-server/src/grpc/mod.rs @@ -12,10 +12,7 @@ use tracing::info; use crate::{ DEFAULT_CHANNEL_SIZE, - actors::{ - client::{ClientConnection, connect_client}, - user_agent::UserAgentConnection, - }, + actors::{client::ClientConnection, user_agent::UserAgentConnection}, grpc::{self, user_agent::start}, }; @@ -33,19 +30,13 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for super::Ser request: Request>, ) -> Result, Status> { let req_stream = request.into_inner(); - let (tx, rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE); - - let transport = client::GrpcTransport::new(tx, req_stream); - let props = ClientConnection::new( - self.context.db.clone(), - Box::new(transport), - self.context.actors.clone(), - ); - tokio::spawn(connect_client(props)); + let (bi, rx) = GrpcBi::from_bi_stream(req_stream); + let props = ClientConnection::new(self.context.db.clone(), self.context.actors.clone()); + tokio::spawn(client::start(props, bi)); info!(event = "connection established", "grpc.client"); - Ok(Response::new(ReceiverStream::new(rx))) + Ok(Response::new(rx)) } #[tracing::instrument(level = "debug", skip(self))] diff --git a/server/crates/arbiter-server/src/grpc/user_agent.rs b/server/crates/arbiter-server/src/grpc/user_agent.rs index 4df6317..fe587dc 100644 --- a/server/crates/arbiter-server/src/grpc/user_agent.rs +++ b/server/crates/arbiter-server/src/grpc/user_agent.rs @@ -30,7 +30,10 @@ use arbiter_proto::{ }; use async_trait::async_trait; use chrono::{TimeZone, Utc}; -use kameo::{actor::{ActorRef, Spawn as _}, error::SendError}; +use kameo::{ + actor::{ActorRef, Spawn as _}, + error::SendError, +}; use tonic::Status; use tracing::{info, warn}; @@ -40,7 +43,9 @@ use crate::{ user_agent::{ OutOfBand, UserAgentConnection, UserAgentSession, session::{ - BootstrapError, Error, HandleBootstrapEncryptedKey, HandleEvmWalletCreate, HandleEvmWalletList, HandleGrantCreate, HandleGrantDelete, HandleGrantList, HandleQueryVaultState, HandleUnsealEncryptedKey, HandleUnsealRequest, UnsealError + BootstrapError, Error, HandleBootstrapEncryptedKey, HandleEvmWalletCreate, + HandleEvmWalletList, HandleGrantCreate, HandleGrantDelete, HandleGrantList, + HandleQueryVaultState, HandleUnsealEncryptedKey, HandleUnsealRequest, UnsealError, }, }, }, @@ -109,7 +114,11 @@ async fn dispatch_conn_message( }; let Some(payload) = conn.payload else { - let _ = bi.send(Err(Status::invalid_argument("Missing user-agent request payload"))).await; + let _ = bi + .send(Err(Status::invalid_argument( + "Missing user-agent request payload", + ))) + .await; return Err(()); }; @@ -118,7 +127,9 @@ async fn dispatch_conn_message( let client_pubkey = match <[u8; 32]>::try_from(client_pubkey) { Ok(bytes) => x25519_dalek::PublicKey::from(bytes), Err(_) => { - let _ = bi.send(Err(Status::invalid_argument("Invalid X25519 public key"))).await; + let _ = bi + .send(Err(Status::invalid_argument("Invalid X25519 public key"))) + .await; return Err(()); } }; @@ -131,7 +142,9 @@ async fn dispatch_conn_message( ), Err(err) => { warn!(error = ?err, "Failed to handle unseal start request"); - let _ = bi.send(Err(Status::internal("Failed to start unseal flow"))).await; + let _ = bi + .send(Err(Status::internal("Failed to start unseal flow"))) + .await; return Err(()); } } @@ -155,7 +168,9 @@ async fn dispatch_conn_message( } Err(err) => { warn!(error = ?err, "Failed to handle unseal request"); - let _ = bi.send(Err(Status::internal("Failed to unseal vault"))).await; + let _ = bi + .send(Err(Status::internal("Failed to unseal vault"))) + .await; return Err(()); } } @@ -178,12 +193,14 @@ async fn dispatch_conn_message( Err(SendError::HandlerError(BootstrapError::InvalidKey)) => { ProtoBootstrapResult::InvalidKey } - Err(SendError::HandlerError( - BootstrapError::AlreadyBootstrapped, - )) => ProtoBootstrapResult::AlreadyBootstrapped, + Err(SendError::HandlerError(BootstrapError::AlreadyBootstrapped)) => { + ProtoBootstrapResult::AlreadyBootstrapped + } Err(err) => { warn!(error = ?err, "Failed to handle bootstrap request"); - let _ = bi.send(Err(Status::internal("Failed to bootstrap vault"))).await; + let _ = bi + .send(Err(Status::internal("Failed to bootstrap vault"))) + .await; return Err(()); } } @@ -224,12 +241,13 @@ async fn dispatch_conn_message( }; UserAgentResponsePayload::EvmGrantCreate(EvmGrantOrWallet::grant_create_response( - actor.ask(HandleGrantCreate { - client_id, - basic, - grant, - }) - .await, + actor + .ask(HandleGrantCreate { + client_id, + basic, + grant, + }) + .await, )) } UserAgentRequestPayload::EvmGrantDelete(EvmGrantDeleteRequest { grant_id }) => { @@ -239,7 +257,11 @@ async fn dispatch_conn_message( } payload => { warn!(?payload, "Unsupported post-auth user agent request"); - let _ = bi.send(Err(Status::invalid_argument("Unsupported user-agent request"))).await; + let _ = bi + .send(Err(Status::invalid_argument( + "Unsupported user-agent request", + ))) + .await; return Err(()); } }; @@ -281,7 +303,10 @@ fn parse_grant_request( let specific = specific.ok_or_else(|| Status::invalid_argument("Missing specific grant settings"))?; - Ok((shared_settings_from_proto(shared)?, specific_grant_from_proto(specific)?)) + Ok(( + shared_settings_from_proto(shared)?, + specific_grant_from_proto(specific)?, + )) } fn shared_settings_from_proto(shared: ProtoSharedSettings) -> Result { @@ -289,14 +314,8 @@ fn shared_settings_from_proto(shared: ProtoSharedSettings) -> Result Result Result>()?, - limit: volume_rate_limit_from_proto( - limit.ok_or_else(|| { - Status::invalid_argument("Missing ether transfer volume rate limit") - })?, - )?, + limit: volume_rate_limit_from_proto(limit.ok_or_else(|| { + Status::invalid_argument("Missing ether transfer volume rate limit") + })?)?, })), Some(ProtoSpecificGrantType::TokenTransfer(ProtoTokenTransferSettings { token_contract, @@ -391,12 +406,12 @@ fn shared_settings_to_proto(shared: SharedGrantSettings) -> ProtoSharedSettings seconds: time.timestamp(), nanos: time.timestamp_subsec_nanos() as i32, }), - max_gas_fee_per_gas: shared.max_gas_fee_per_gas.map(|value| { - value.to_be_bytes::<32>().to_vec() - }), - max_priority_fee_per_gas: shared.max_priority_fee_per_gas.map(|value| { - value.to_be_bytes::<32>().to_vec() - }), + max_gas_fee_per_gas: shared + .max_gas_fee_per_gas + .map(|value| value.to_be_bytes::<32>().to_vec()), + max_priority_fee_per_gas: shared + .max_priority_fee_per_gas + .map(|value| value.to_be_bytes::<32>().to_vec()), rate_limit: shared.rate_limit.map(|limit| ProtoTransactionRateLimit { count: limit.count, window_secs: limit.window.num_seconds(), @@ -408,7 +423,11 @@ fn specific_grant_to_proto(grant: SpecificGrant) -> ProtoSpecificGrant { let grant = match grant { SpecificGrant::EtherTransfer(settings) => { ProtoSpecificGrantType::EtherTransfer(ProtoEtherTransferSettings { - targets: settings.target.into_iter().map(|address| address.to_vec()).collect(), + targets: settings + .target + .into_iter() + .map(|address| address.to_vec()) + .collect(), limit: Some(ProtoVolumeRateLimit { max_volume: settings.limit.max_volume.to_be_bytes::<32>().to_vec(), window_secs: settings.limit.window.num_seconds(), @@ -450,7 +469,9 @@ impl EvmGrantOrWallet { } }; - WalletCreateResponse { result: Some(result) } + WalletCreateResponse { + result: Some(result), + } } fn wallet_list_response( @@ -471,7 +492,9 @@ impl EvmGrantOrWallet { } }; - WalletListResponse { result: Some(result) } + WalletListResponse { + result: Some(result), + } } fn grant_create_response( @@ -485,12 +508,12 @@ impl EvmGrantOrWallet { } }; - EvmGrantCreateResponse { result: Some(result) } + EvmGrantCreateResponse { + result: Some(result), + } } - fn grant_delete_response( - result: Result<(), SendError>, - ) -> EvmGrantDeleteResponse { + fn grant_delete_response(result: Result<(), SendError>) -> EvmGrantDeleteResponse { let result = match result { Ok(()) => EvmGrantDeleteResult::Ok(()), Err(err) => { @@ -499,7 +522,9 @@ impl EvmGrantOrWallet { } }; - EvmGrantDeleteResponse { result: Some(result) } + EvmGrantDeleteResponse { + result: Some(result), + } } fn grant_list_response( @@ -523,7 +548,9 @@ impl EvmGrantOrWallet { } }; - EvmGrantListResponse { result: Some(result) } + EvmGrantListResponse { + result: Some(result), + } } } diff --git a/server/crates/arbiter-server/src/lib.rs b/server/crates/arbiter-server/src/lib.rs index 0b255e5..6367164 100644 --- a/server/crates/arbiter-server/src/lib.rs +++ b/server/crates/arbiter-server/src/lib.rs @@ -1,9 +1,5 @@ #![forbid(unsafe_code)] -#![deny( - clippy::unwrap_used, - clippy::expect_used, - clippy::panic -)] +#![deny(clippy::unwrap_used, clippy::expect_used, clippy::panic)] use crate::context::ServerContext; @@ -26,4 +22,3 @@ impl Server { Self { context } } } - diff --git a/server/crates/arbiter-server/tests/client/auth.rs b/server/crates/arbiter-server/tests/client/auth.rs index ca4f38b..ca1d0d0 100644 --- a/server/crates/arbiter-server/tests/client/auth.rs +++ b/server/crates/arbiter-server/tests/client/auth.rs @@ -1,7 +1,7 @@ -use arbiter_proto::transport::Bi; +use arbiter_proto::transport::{Receiver, Sender}; use arbiter_server::actors::GlobalActors; use arbiter_server::{ - actors::client::{ClientConnection, Request, Response, connect_client}, + actors::client::{ClientConnection, auth, connect_client}, db::{self, schema}, }; use diesel::{ExpressionMethods as _, insert_into}; @@ -17,15 +17,17 @@ pub async fn test_unregistered_pubkey_rejected() { let (server_transport, mut test_transport) = ChannelTransport::new(); let actors = GlobalActors::spawn(db.clone()).await.unwrap(); - let props = ClientConnection::new(db.clone(), Box::new(server_transport), actors); - let task = tokio::spawn(connect_client(props)); + let props = ClientConnection::new(db.clone(), actors); + let task = tokio::spawn(async move { + let mut server_transport = server_transport; + connect_client(props, &mut server_transport).await; + }); let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); - let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); test_transport - .send(Request::AuthChallengeRequest { - pubkey: pubkey_bytes, + .send(auth::Inbound::AuthChallengeRequest { + pubkey: new_key.verifying_key(), }) .await .unwrap(); @@ -54,13 +56,16 @@ pub async fn test_challenge_auth() { let (server_transport, mut test_transport) = ChannelTransport::new(); let actors = GlobalActors::spawn(db.clone()).await.unwrap(); - let props = ClientConnection::new(db.clone(), Box::new(server_transport), actors); - let task = tokio::spawn(connect_client(props)); + let props = ClientConnection::new(db.clone(), actors); + let task = tokio::spawn(async move { + let mut server_transport = server_transport; + connect_client(props, &mut server_transport).await; + }); // Send challenge request test_transport - .send(Request::AuthChallengeRequest { - pubkey: pubkey_bytes, + .send(auth::Inbound::AuthChallengeRequest { + pubkey: new_key.verifying_key(), }) .await .unwrap(); @@ -72,23 +77,31 @@ pub async fn test_challenge_auth() { .expect("should receive challenge"); let challenge = match response { Ok(resp) => match resp { - Response::AuthChallenge { pubkey, nonce } => (pubkey, nonce), + auth::Outbound::AuthChallenge { pubkey, nonce } => (pubkey, nonce), other => panic!("Expected AuthChallenge, got {other:?}"), }, Err(err) => panic!("Expected Ok response, got Err({err:?})"), }; // Sign the challenge and send solution - let formatted_challenge = arbiter_proto::format_challenge(challenge.1, &challenge.0); + let formatted_challenge = arbiter_proto::format_challenge(challenge.1, challenge.0.as_bytes()); let signature = new_key.sign(&formatted_challenge); test_transport - .send(Request::AuthChallengeSolution { - signature: signature.to_bytes().to_vec(), - }) + .send(auth::Inbound::AuthChallengeSolution { signature }) .await .unwrap(); + let response = test_transport + .recv() + .await + .expect("should receive auth success"); + match response { + Ok(auth::Outbound::AuthSuccess) => {} + Ok(other) => panic!("Expected AuthSuccess, got {other:?}"), + Err(err) => panic!("Expected Ok response, got Err({err:?})"), + } + // Auth completes, session spawned task.await.unwrap(); } diff --git a/server/crates/arbiter-server/tests/common/mod.rs b/server/crates/arbiter-server/tests/common/mod.rs index 3bc3430..13ccd32 100644 --- a/server/crates/arbiter-server/tests/common/mod.rs +++ b/server/crates/arbiter-server/tests/common/mod.rs @@ -1,7 +1,8 @@ -use arbiter_proto::transport::{Bi, Error}; +use arbiter_proto::transport::{Bi, Error, Receiver, Sender}; use arbiter_server::{ actors::keyholder::KeyHolder, - db::{self, schema}, safe_cell::{SafeCell, SafeCellHandle as _}, + db::{self, schema}, + safe_cell::{SafeCell, SafeCellHandle as _}, }; use async_trait::async_trait; use diesel::QueryDsl; @@ -54,10 +55,10 @@ impl ChannelTransport { } #[async_trait] -impl Bi for ChannelTransport +impl Sender for ChannelTransport where - T: Send + 'static, - Y: Send + 'static, + T: Send + Sync + 'static, + Y: Send + Sync + 'static, { async fn send(&mut self, item: Y) -> Result<(), Error> { self.sender @@ -65,8 +66,22 @@ where .await .map_err(|_| Error::ChannelClosed) } +} +#[async_trait] +impl Receiver for ChannelTransport +where + T: Send + Sync + 'static, + Y: Send + Sync + 'static, +{ async fn recv(&mut self) -> Option { self.receiver.recv().await } } + +impl Bi for ChannelTransport +where + T: Send + Sync + 'static, + Y: Send + Sync + 'static, +{ +} diff --git a/server/crates/arbiter-server/tests/user_agent/auth.rs b/server/crates/arbiter-server/tests/user_agent/auth.rs index 1a7bbad..bfe308a 100644 --- a/server/crates/arbiter-server/tests/user_agent/auth.rs +++ b/server/crates/arbiter-server/tests/user_agent/auth.rs @@ -3,7 +3,7 @@ use arbiter_server::{ actors::{ GlobalActors, bootstrap::GetToken, - user_agent::{AuthPublicKey, Request, OutOfBand, UserAgentConnection, connect_user_agent}, + user_agent::{AuthPublicKey, OutOfBand, Request, UserAgentConnection, connect_user_agent}, }, db::{self, schema}, }; diff --git a/server/crates/arbiter-server/tests/user_agent/unseal.rs b/server/crates/arbiter-server/tests/user_agent/unseal.rs index 4b6d7a3..0b2eea6 100644 --- a/server/crates/arbiter-server/tests/user_agent/unseal.rs +++ b/server/crates/arbiter-server/tests/user_agent/unseal.rs @@ -2,7 +2,7 @@ use arbiter_server::{ actors::{ GlobalActors, keyholder::{Bootstrap, Seal}, - user_agent::{Request, OutOfBand, UnsealError, session::UserAgentSession}, + user_agent::{OutOfBand, Request, UnsealError, session::UserAgentSession}, }, db, safe_cell::{SafeCell, SafeCellHandle as _},