From 3e8b26418ae36b38e556ac451f730b7ed1a5d43e Mon Sep 17 00:00:00 2001 From: hdbg Date: Wed, 18 Mar 2026 23:43:44 +0100 Subject: [PATCH] feat(proto): request / response pair tracking by assigning id --- protobufs/client.proto | 2 + protobufs/user_agent.proto | 2 + .../crates/arbiter-server/src/grpc/client.rs | 38 +++- .../arbiter-server/src/grpc/client/auth.rs | 118 +++++++++---- server/crates/arbiter-server/src/grpc/mod.rs | 1 + .../src/grpc/request_tracker.rs | 20 +++ .../arbiter-server/src/grpc/user_agent.rs | 33 +++- .../src/grpc/user_agent/auth.rs | 163 +++++++++++------- 8 files changed, 259 insertions(+), 118 deletions(-) create mode 100644 server/crates/arbiter-server/src/grpc/request_tracker.rs diff --git a/protobufs/client.proto b/protobufs/client.proto index dbe9708..c090a0d 100644 --- a/protobufs/client.proto +++ b/protobufs/client.proto @@ -37,6 +37,7 @@ enum VaultState { } message ClientRequest { + int32 request_id = 4; oneof payload { AuthChallengeRequest auth_challenge_request = 1; AuthChallengeSolution auth_challenge_solution = 2; @@ -45,6 +46,7 @@ message ClientRequest { } message ClientResponse { + optional int32 request_id = 7; oneof payload { AuthChallenge auth_challenge = 1; AuthResult auth_result = 2; diff --git a/protobufs/user_agent.proto b/protobufs/user_agent.proto index 6fb77e4..f54f05a 100644 --- a/protobufs/user_agent.proto +++ b/protobufs/user_agent.proto @@ -89,6 +89,7 @@ message ClientConnectionResponse { message ClientConnectionCancel {} message UserAgentRequest { + int32 id = 14; oneof payload { AuthChallengeRequest auth_challenge_request = 1; AuthChallengeSolution auth_challenge_solution = 2; @@ -105,6 +106,7 @@ message UserAgentRequest { } } message UserAgentResponse { + optional int32 id = 14; oneof payload { AuthChallenge auth_challenge = 1; AuthResult auth_result = 2; diff --git a/server/crates/arbiter-server/src/grpc/client.rs b/server/crates/arbiter-server/src/grpc/client.rs index 17442a0..653c7a8 100644 --- a/server/crates/arbiter-server/src/grpc/client.rs +++ b/server/crates/arbiter-server/src/grpc/client.rs @@ -10,6 +10,7 @@ use kameo::{ actor::{ActorRef, Spawn as _}, error::SendError, }; +use tonic::Status; use tracing::{info, warn}; use crate::{ @@ -20,6 +21,7 @@ use crate::{ }, keyholder::KeyHolderState, }, + grpc::request_tracker::RequestTracker, utils::defer, }; @@ -28,13 +30,17 @@ mod auth; async fn dispatch_loop( mut bi: GrpcBi, actor: ActorRef, + mut request_tracker: RequestTracker, ) { loop { let Some(conn) = bi.recv().await else { return; }; - if dispatch_conn_message(&mut bi, &actor, conn).await.is_err() { + if dispatch_conn_message(&mut bi, &actor, &mut request_tracker, conn) + .await + .is_err() + { return; } } @@ -43,7 +49,8 @@ async fn dispatch_loop( async fn dispatch_conn_message( bi: &mut GrpcBi, actor: &ActorRef, - conn: Result, + request_tracker: &mut RequestTracker, + conn: Result, ) -> Result<(), ()> { let conn = match conn { Ok(conn) => conn, @@ -53,9 +60,16 @@ async fn dispatch_conn_message( } }; + let request_id = match request_tracker.request(conn.request_id) { + Ok(request_id) => request_id, + Err(err) => { + let _ = bi.send(Err(err)).await; + return Err(()); + } + }; let Some(payload) = conn.payload else { let _ = bi - .send(Err(tonic::Status::invalid_argument( + .send(Err(Status::invalid_argument( "Missing client request payload", ))) .await; @@ -79,15 +93,14 @@ async fn dispatch_conn_message( payload => { warn!(?payload, "Unsupported post-auth client request"); let _ = bi - .send(Err(tonic::Status::invalid_argument( - "Unsupported client request", - ))) + .send(Err(Status::invalid_argument("Unsupported client request"))) .await; return Err(()); } }; bi.send(Ok(ClientResponse { + request_id: Some(request_id), payload: Some(payload), })) .await @@ -96,7 +109,10 @@ async fn dispatch_conn_message( pub async fn start(conn: ClientConnection, mut bi: GrpcBi) { let mut conn = conn; - match auth::start(&mut conn, &mut bi).await { + let mut request_tracker = RequestTracker::default(); + let mut response_id = None; + + match auth::start(&mut conn, &mut bi, &mut request_tracker, &mut response_id).await { Ok(_) => { let actor = client::session::ClientSession::spawn(client::session::ClientSession::new(conn)); @@ -106,10 +122,14 @@ pub async fn start(conn: ClientConnection, mut bi: GrpcBi { - let mut transport = auth::AuthTransportAdapter(&mut bi); + let mut transport = auth::AuthTransportAdapter::new( + &mut bi, + &mut request_tracker, + &mut response_id, + ); let _ = transport.send(Err(e.clone())).await; warn!(error = ?e, "Authentication failed"); return; diff --git a/server/crates/arbiter-server/src/grpc/client/auth.rs b/server/crates/arbiter-server/src/grpc/client/auth.rs index 8374b36..49d8d55 100644 --- a/server/crates/arbiter-server/src/grpc/client/auth.rs +++ b/server/crates/arbiter-server/src/grpc/client/auth.rs @@ -8,15 +8,35 @@ use arbiter_proto::{ transport::{Bi, Error as TransportError, Receiver, Sender, grpc::GrpcBi}, }; use async_trait::async_trait; +use tonic::Status; use tracing::warn; -use crate::actors::client::{self, ClientConnection, auth}; +use crate::{ + actors::client::{self, ClientConnection, auth}, + grpc::request_tracker::RequestTracker, +}; -pub struct AuthTransportAdapter<'a>(pub(super) &'a mut GrpcBi); +pub struct AuthTransportAdapter<'a> { + bi: &'a mut GrpcBi, + request_tracker: &'a mut RequestTracker, + response_id: &'a mut Option, +} -impl AuthTransportAdapter<'_> { - fn response_to_proto(response: auth::Outbound) -> ClientResponse { - let payload = match response { +impl<'a> AuthTransportAdapter<'a> { + pub fn new( + bi: &'a mut GrpcBi, + request_tracker: &'a mut RequestTracker, + response_id: &'a mut Option, + ) -> Self { + Self { + bi, + request_tracker, + response_id, + } + } + + fn response_to_proto(response: auth::Outbound) -> ClientResponsePayload { + match response { auth::Outbound::AuthChallenge { pubkey, nonce } => { ClientResponsePayload::AuthChallenge(ProtoAuthChallenge { pubkey: pubkey.to_bytes().to_vec(), @@ -26,39 +46,44 @@ impl AuthTransportAdapter<'_> { auth::Outbound::AuthSuccess => { ClientResponsePayload::AuthResult(ProtoAuthResult::Success.into()) } - }; - - ClientResponse { - payload: Some(payload), } } - fn error_to_proto(error: auth::Error) -> ClientResponse { - ClientResponse { - payload: Some(ClientResponsePayload::AuthResult( - match error { - auth::Error::InvalidChallengeSolution => ProtoAuthResult::InvalidSignature, - auth::Error::ApproveError(auth::ApproveError::Denied) => { - ProtoAuthResult::ApprovalDenied - } - auth::Error::ApproveError(auth::ApproveError::Upstream( - crate::actors::router::ApprovalError::NoUserAgentsConnected, - )) => ProtoAuthResult::NoUserAgentsOnline, - auth::Error::ApproveError(auth::ApproveError::Internal) - | auth::Error::DatabasePoolUnavailable - | auth::Error::DatabaseOperationFailed - | auth::Error::Transport => ProtoAuthResult::Internal, + fn error_to_proto(error: auth::Error) -> ClientResponsePayload { + ClientResponsePayload::AuthResult( + match error { + auth::Error::InvalidChallengeSolution => ProtoAuthResult::InvalidSignature, + auth::Error::ApproveError(auth::ApproveError::Denied) => { + ProtoAuthResult::ApprovalDenied } - .into(), - )), - } + auth::Error::ApproveError(auth::ApproveError::Upstream( + crate::actors::router::ApprovalError::NoUserAgentsConnected, + )) => ProtoAuthResult::NoUserAgentsOnline, + auth::Error::ApproveError(auth::ApproveError::Internal) + | auth::Error::DatabasePoolUnavailable + | auth::Error::DatabaseOperationFailed + | auth::Error::Transport => ProtoAuthResult::Internal, + } + .into(), + ) + } + + async fn send_client_response( + &mut self, + payload: ClientResponsePayload, + ) -> Result<(), TransportError> { + let request_id = self.response_id.take(); + + self.bi + .send(Ok(ClientResponse { + request_id, + payload: Some(payload), + })) + .await } async fn send_auth_result(&mut self, result: ProtoAuthResult) -> Result<(), TransportError> { - self.0 - .send(Ok(ClientResponse { - payload: Some(ClientResponsePayload::AuthResult(result.into())), - })) + self.send_client_response(ClientResponsePayload::AuthResult(result.into())) .await } } @@ -69,19 +94,19 @@ impl Sender> for AuthTransportAdapter<'_> { &mut self, item: Result, ) -> Result<(), TransportError> { - let outbound = match item { - Ok(message) => Ok(AuthTransportAdapter::response_to_proto(message)), - Err(err) => Ok(AuthTransportAdapter::error_to_proto(err)), + let payload = match item { + Ok(message) => AuthTransportAdapter::response_to_proto(message), + Err(err) => AuthTransportAdapter::error_to_proto(err), }; - self.0.send(outbound).await + self.send_client_response(payload).await } } #[async_trait] impl Receiver for AuthTransportAdapter<'_> { async fn recv(&mut self) -> Option { - let request = match self.0.recv().await? { + let request = match self.bi.recv().await? { Ok(request) => request, Err(error) => { warn!(error = ?error, "grpc client recv failed; closing stream"); @@ -89,6 +114,15 @@ impl Receiver for AuthTransportAdapter<'_> { } }; + let request_id = match self.request_tracker.request(request.request_id) { + Ok(request_id) => request_id, + Err(error) => { + let _ = self.bi.send(Err(error)).await; + return None; + } + }; + *self.response_id = Some(request_id); + let payload = request.payload?; match payload { @@ -114,7 +148,13 @@ impl Receiver for AuthTransportAdapter<'_> { }; Some(auth::Inbound::AuthChallengeSolution { signature }) } - _ => None, + _ => { + let _ = self + .bi + .send(Err(Status::invalid_argument("Unsupported client auth request"))) + .await; + None + } } } } @@ -124,8 +164,10 @@ impl Bi> for AuthTransportAda pub async fn start( conn: &mut ClientConnection, bi: &mut GrpcBi, + request_tracker: &mut RequestTracker, + response_id: &mut Option, ) -> Result<(), auth::Error> { - let mut transport = AuthTransportAdapter(bi); + let mut transport = AuthTransportAdapter::new(bi, request_tracker, response_id); client::auth::authenticate(conn, &mut transport).await?; Ok(()) } diff --git a/server/crates/arbiter-server/src/grpc/mod.rs b/server/crates/arbiter-server/src/grpc/mod.rs index 7bdedea..709d24d 100644 --- a/server/crates/arbiter-server/src/grpc/mod.rs +++ b/server/crates/arbiter-server/src/grpc/mod.rs @@ -17,6 +17,7 @@ use crate::{ }; pub mod client; +mod request_tracker; pub mod user_agent; #[async_trait] diff --git a/server/crates/arbiter-server/src/grpc/request_tracker.rs b/server/crates/arbiter-server/src/grpc/request_tracker.rs new file mode 100644 index 0000000..e282343 --- /dev/null +++ b/server/crates/arbiter-server/src/grpc/request_tracker.rs @@ -0,0 +1,20 @@ +use tonic::Status; + +#[derive(Default)] +pub struct RequestTracker { + next_request_id: i32, +} + +impl RequestTracker { + pub fn request(&mut self, id: i32) -> Result { + if id < self.next_request_id { + return Err(Status::invalid_argument("Duplicate request id")); + } + + self.next_request_id = id + .checked_add(1) + .ok_or_else(|| Status::invalid_argument("Invalid request id"))?; + + Ok(id) + } +} diff --git a/server/crates/arbiter-server/src/grpc/user_agent.rs b/server/crates/arbiter-server/src/grpc/user_agent.rs index fe587dc..c3fb347 100644 --- a/server/crates/arbiter-server/src/grpc/user_agent.rs +++ b/server/crates/arbiter-server/src/grpc/user_agent.rs @@ -53,6 +53,7 @@ use crate::{ Grant, SharedGrantSettings, SpecificGrant, TransactionRateLimit, VolumeRateLimit, ether_transfer, token_transfers, }, + grpc::request_tracker::RequestTracker, utils::defer, }; use alloy::primitives::{Address, U256}; @@ -74,6 +75,7 @@ async fn dispatch_loop( mut bi: GrpcBi, actor: ActorRef, mut receiver: mpsc::Receiver, + mut request_tracker: RequestTracker, ) { loop { tokio::select! { @@ -92,7 +94,10 @@ async fn dispatch_loop( return; }; - if dispatch_conn_message(&mut bi, &actor, conn).await.is_err() { + if dispatch_conn_message(&mut bi, &actor, &mut request_tracker, conn) + .await + .is_err() + { return; } } @@ -103,6 +108,7 @@ async fn dispatch_loop( async fn dispatch_conn_message( bi: &mut GrpcBi, actor: &ActorRef, + request_tracker: &mut RequestTracker, conn: Result, ) -> Result<(), ()> { let conn = match conn { @@ -113,6 +119,14 @@ async fn dispatch_conn_message( } }; + let request_id = match request_tracker.request(conn.id) { + Ok(request_id) => request_id, + Err(err) => { + let _ = bi.send(Err(err)).await; + return Err(()); + } + }; + let Some(payload) = conn.payload else { let _ = bi .send(Err(Status::invalid_argument( @@ -267,6 +281,7 @@ async fn dispatch_conn_message( }; bi.send(Ok(UserAgentResponse { + id: Some(request_id), payload: Some(payload), })) .await @@ -289,6 +304,7 @@ async fn send_out_of_band( }; bi.send(Ok(UserAgentResponse { + id: None, payload: Some(payload), })) .await @@ -558,7 +574,17 @@ pub async fn start( mut conn: UserAgentConnection, mut bi: GrpcBi, ) { - let pubkey = match auth::start(&mut conn, &mut bi).await { + let mut request_tracker = RequestTracker::default(); + let mut response_id = None; + + let pubkey = match auth::start( + &mut conn, + &mut bi, + &mut request_tracker, + &mut response_id, + ) + .await + { Ok(pubkey) => pubkey, Err(e) => { warn!(error = ?e, "Authentication failed"); @@ -572,11 +598,10 @@ pub async fn start( let actor = UserAgentSession::spawn(UserAgentSession::new(conn, Box::new(oob_adapter))); let actor_for_cleanup = actor.clone(); - // when connection closes let _ = defer(move || { actor_for_cleanup.kill(); }); info!(?pubkey, "User authenticated successfully"); - dispatch_loop(bi, actor, oob_receiver).await; + dispatch_loop(bi, actor, oob_receiver, request_tracker).await; } diff --git a/server/crates/arbiter-server/src/grpc/user_agent/auth.rs b/server/crates/arbiter-server/src/grpc/user_agent/auth.rs index 0e648a8..024190d 100644 --- a/server/crates/arbiter-server/src/grpc/user_agent/auth.rs +++ b/server/crates/arbiter-server/src/grpc/user_agent/auth.rs @@ -1,52 +1,56 @@ use arbiter_proto::{ - proto::{ - self, - evm::{ - EtherTransferSettings as ProtoEtherTransferSettings, EvmError as ProtoEvmError, - EvmGrantCreateRequest, EvmGrantCreateResponse, EvmGrantDeleteRequest, - EvmGrantDeleteResponse, EvmGrantList, EvmGrantListResponse, GrantEntry, - SharedSettings as ProtoSharedSettings, SpecificGrant as ProtoSpecificGrant, - TokenTransferSettings as ProtoTokenTransferSettings, - VolumeRateLimit as ProtoVolumeRateLimit, WalletCreateResponse, WalletEntry, WalletList, - WalletListResponse, evm_grant_create_response::Result as EvmGrantCreateResult, - evm_grant_delete_response::Result as EvmGrantDeleteResult, - evm_grant_list_response::Result as EvmGrantListResult, - specific_grant::Grant as ProtoSpecificGrantType, - wallet_create_response::Result as WalletCreateResult, - wallet_list_response::Result as WalletListResult, - }, - user_agent::{ - AuthChallenge as ProtoAuthChallenge, AuthChallengeRequest as ProtoAuthChallengeRequest, - AuthChallengeSolution as ProtoAuthChallengeSolution, AuthResult as ProtoAuthResult, - BootstrapEncryptedKey as ProtoBootstrapEncryptedKey, - BootstrapResult as ProtoBootstrapResult, ClientConnectionCancel, - ClientConnectionRequest, ClientConnectionResponse, KeyType as ProtoKeyType, - UnsealEncryptedKey as ProtoUnsealEncryptedKey, UnsealResult as ProtoUnsealResult, - UnsealStart, UnsealStartResponse, UserAgentRequest, UserAgentResponse, - VaultState as ProtoVaultState, user_agent_request::Payload as UserAgentRequestPayload, - user_agent_response::Payload as UserAgentResponsePayload, - }, + proto::user_agent::{ + AuthChallenge as ProtoAuthChallenge, AuthChallengeRequest as ProtoAuthChallengeRequest, + AuthChallengeSolution as ProtoAuthChallengeSolution, AuthResult as ProtoAuthResult, + KeyType as ProtoKeyType, UserAgentRequest, UserAgentResponse, + user_agent_request::Payload as UserAgentRequestPayload, + user_agent_response::Payload as UserAgentResponsePayload, }, transport::{Bi, Error as TransportError, Receiver, Sender, grpc::GrpcBi}, }; use async_trait::async_trait; -use tonic::{Status, Streaming}; -use tracing::{info, warn}; +use tonic::Status; +use tracing::warn; use crate::{ - actors::user_agent::{ - self, AuthPublicKey, OutOfBand as DomainResponse, UserAgentConnection, auth, - }, + actors::user_agent::{AuthPublicKey, UserAgentConnection, auth}, db::models::KeyType, - evm::policies::{ - Grant, SharedGrantSettings, SpecificGrant, TransactionRateLimit, VolumeRateLimit, - ether_transfer, token_transfers, - }, + grpc::request_tracker::RequestTracker, }; -use alloy::primitives::{Address, U256}; -use chrono::{DateTime, TimeZone, Utc}; -pub struct AuthTransportAdapter<'a>(&'a mut GrpcBi); +pub struct AuthTransportAdapter<'a> { + bi: &'a mut GrpcBi, + request_tracker: &'a mut RequestTracker, + response_id: &'a mut Option, +} + +impl<'a> AuthTransportAdapter<'a> { + pub fn new( + bi: &'a mut GrpcBi, + request_tracker: &'a mut RequestTracker, + response_id: &'a mut Option, + ) -> Self { + Self { + bi, + request_tracker, + response_id, + } + } + + async fn send_user_agent_response( + &mut self, + payload: UserAgentResponsePayload, + ) -> Result<(), TransportError> { + let id = self.response_id.take(); + + self.bi + .send(Ok(UserAgentResponse { + id, + payload: Some(payload), + })) + .await + } +} #[async_trait] impl Sender> for AuthTransportAdapter<'_> { @@ -55,39 +59,53 @@ impl Sender> for AuthTransportAdapter<'_> { item: Result, ) -> Result<(), TransportError> { use auth::{Error, Outbound}; - let response = match item { - Ok(Outbound::AuthChallenge { nonce }) => Ok(UserAgentResponsePayload::AuthChallenge( - ProtoAuthChallenge { nonce }, - )), - Ok(Outbound::AuthSuccess) => Ok(UserAgentResponsePayload::AuthResult( - ProtoAuthResult::Success.into(), - )), - - Err(Error::UnregisteredPublicKey) => Ok(UserAgentResponsePayload::AuthResult( - ProtoAuthResult::InvalidKey.into(), - )), - Err(Error::InvalidChallengeSolution) => Ok(UserAgentResponsePayload::AuthResult( - ProtoAuthResult::InvalidSignature.into(), - )), - Err(Error::InvalidBootstrapToken) => Ok(UserAgentResponsePayload::BootstrapResult( - ProtoAuthResult::TokenInvalid.into(), - )), - Err(Error::Internal { details }) => Err(Status::internal(details)), - Err(Error::Transport) => Err(Status::unavailable("transport error")), + let payload = match item { + Ok(Outbound::AuthChallenge { nonce }) => { + UserAgentResponsePayload::AuthChallenge(ProtoAuthChallenge { nonce }) + } + Ok(Outbound::AuthSuccess) => { + UserAgentResponsePayload::AuthResult(ProtoAuthResult::Success.into()) + } + Err(Error::UnregisteredPublicKey) => { + UserAgentResponsePayload::AuthResult(ProtoAuthResult::InvalidKey.into()) + } + Err(Error::InvalidChallengeSolution) => { + UserAgentResponsePayload::AuthResult(ProtoAuthResult::InvalidSignature.into()) + } + Err(Error::InvalidBootstrapToken) => { + UserAgentResponsePayload::AuthResult(ProtoAuthResult::TokenInvalid.into()) + } + Err(Error::Internal { details }) => return self.bi.send(Err(Status::internal(details))).await, + Err(Error::Transport) => { + return self.bi.send(Err(Status::unavailable("transport error"))).await; + } }; - self.0 - .send(response.map(|r| UserAgentResponse { payload: Some(r) })) - .await + + self.send_user_agent_response(payload).await } } #[async_trait] impl Receiver for AuthTransportAdapter<'_> { async fn recv(&mut self) -> Option { - let Ok(UserAgentRequest { - payload: Some(payload), - }) = self.0.recv().await? - else { + let request = match self.bi.recv().await? { + Ok(request) => request, + Err(error) => { + warn!(error = ?error, "Failed to receive user agent auth request"); + return None; + } + }; + + let request_id = match self.request_tracker.request(request.id) { + Ok(request_id) => request_id, + Err(error) => { + let _ = self.bi.send(Err(error)).await; + return None; + } + }; + *self.response_id = Some(request_id); + + let Some(payload) = request.payload else { warn!( event = "received request with empty payload", "grpc.useragent.auth_adapter" @@ -136,16 +154,27 @@ impl Receiver for AuthTransportAdapter<'_> { UserAgentRequestPayload::AuthChallengeSolution(ProtoAuthChallengeSolution { signature, }) => Some(auth::Inbound::AuthChallengeSolution { signature }), - _ => None, // Ignore other request types for this adapter + _ => { + let _ = self + .bi + .send(Err(Status::invalid_argument( + "Unsupported user-agent auth request", + ))) + .await; + None + } } } } + impl Bi> for AuthTransportAdapter<'_> {} pub async fn start( conn: &mut UserAgentConnection, bi: &mut GrpcBi, + request_tracker: &mut RequestTracker, + response_id: &mut Option, ) -> Result { - let mut transport = AuthTransportAdapter(bi); + let transport = AuthTransportAdapter::new(bi, request_tracker, response_id); auth::authenticate(conn, transport).await }