From 82b5b85f5279ea1fe7406d01952842d2417c2787 Mon Sep 17 00:00:00 2001 From: hdbg Date: Fri, 3 Apr 2026 19:15:53 +0200 Subject: [PATCH] refactor(proto): nest client protocol and extract shared schemas --- server/crates/arbiter-client/src/auth.rs | 58 +++++++++------ server/crates/arbiter-proto/src/lib.rs | 20 ++++++ .../crates/arbiter-server/src/grpc/client.rs | 39 +++++++--- .../arbiter-server/src/grpc/client/auth.rs | 72 +++++++++++-------- .../src/grpc/user_agent/sdk_client.rs | 2 +- .../src/grpc/user_agent/vault.rs | 3 +- 6 files changed, 133 insertions(+), 61 deletions(-) diff --git a/server/crates/arbiter-client/src/auth.rs b/server/crates/arbiter-client/src/auth.rs index a0e2b5c..b42f82a 100644 --- a/server/crates/arbiter-client/src/auth.rs +++ b/server/crates/arbiter-client/src/auth.rs @@ -1,9 +1,17 @@ use arbiter_proto::{ ClientMetadata, format_challenge, - proto::client::{ - AuthChallengeRequest, AuthChallengeSolution, AuthResult, ClientInfo as ProtoClientInfo, - ClientRequest, client_request::Payload as ClientRequestPayload, - client_response::Payload as ClientResponsePayload, + proto::{ + client::{ + ClientRequest, + auth::{ + self as proto_auth, AuthChallenge, AuthChallengeRequest, AuthChallengeSolution, + AuthResult, request::Payload as AuthRequestPayload, + response::Payload as AuthResponsePayload, + }, + client_request::Payload as ClientRequestPayload, + client_response::Payload as ClientResponsePayload, + }, + shared::ClientInfo as ProtoClientInfo, }, }; use ed25519_dalek::Signer as _; @@ -51,16 +59,16 @@ async fn send_auth_challenge_request( transport .send(ClientRequest { request_id: next_request_id(), - payload: Some(ClientRequestPayload::AuthChallengeRequest( - AuthChallengeRequest { + payload: Some(ClientRequestPayload::Auth(proto_auth::Request { + payload: Some(AuthRequestPayload::ChallengeRequest(AuthChallengeRequest { pubkey: key.verifying_key().to_bytes().to_vec(), client_info: Some(ProtoClientInfo { name: metadata.name, description: metadata.description, version: metadata.version, }), - }, - )), + })), + })), }) .await .map_err(|_| AuthError::UnexpectedAuthResponse) @@ -68,7 +76,7 @@ async fn send_auth_challenge_request( async fn receive_auth_challenge( transport: &mut ClientTransport, -) -> std::result::Result { +) -> std::result::Result { let response = transport .recv() .await @@ -76,8 +84,11 @@ async fn receive_auth_challenge( let payload = response.payload.ok_or(AuthError::MissingAuthChallenge)?; match payload { - ClientResponsePayload::AuthChallenge(challenge) => Ok(challenge), - ClientResponsePayload::AuthResult(result) => Err(map_auth_result(result)), + ClientResponsePayload::Auth(response) => match response.payload { + Some(AuthResponsePayload::Challenge(challenge)) => Ok(challenge), + Some(AuthResponsePayload::Result(result)) => Err(map_auth_result(result)), + None => Err(AuthError::MissingAuthChallenge), + }, _ => Err(AuthError::UnexpectedAuthResponse), } } @@ -85,7 +96,7 @@ async fn receive_auth_challenge( async fn send_auth_challenge_solution( transport: &mut ClientTransport, key: &ed25519_dalek::SigningKey, - challenge: arbiter_proto::proto::client::AuthChallenge, + challenge: AuthChallenge, ) -> std::result::Result<(), AuthError> { let challenge_payload = format_challenge(challenge.nonce, &challenge.pubkey); let signature = key.sign(&challenge_payload).to_bytes().to_vec(); @@ -93,9 +104,11 @@ async fn send_auth_challenge_solution( transport .send(ClientRequest { request_id: next_request_id(), - payload: Some(ClientRequestPayload::AuthChallengeSolution( - AuthChallengeSolution { signature }, - )), + payload: Some(ClientRequestPayload::Auth(proto_auth::Request { + payload: Some(AuthRequestPayload::ChallengeSolution( + AuthChallengeSolution { signature }, + )), + })), }) .await .map_err(|_| AuthError::UnexpectedAuthResponse) @@ -113,12 +126,15 @@ async fn receive_auth_confirmation( .payload .ok_or(AuthError::UnexpectedAuthResponse)?; match payload { - ClientResponsePayload::AuthResult(result) - if AuthResult::try_from(result).ok() == Some(AuthResult::Success) => - { - Ok(()) - } - ClientResponsePayload::AuthResult(result) => Err(map_auth_result(result)), + ClientResponsePayload::Auth(response) => match response.payload { + Some(AuthResponsePayload::Result(result)) + if AuthResult::try_from(result).ok() == Some(AuthResult::Success) => + { + Ok(()) + } + Some(AuthResponsePayload::Result(result)) => Err(map_auth_result(result)), + _ => Err(AuthError::UnexpectedAuthResponse), + }, _ => Err(AuthError::UnexpectedAuthResponse), } } diff --git a/server/crates/arbiter-proto/src/lib.rs b/server/crates/arbiter-proto/src/lib.rs index 323254a..141b231 100644 --- a/server/crates/arbiter-proto/src/lib.rs +++ b/server/crates/arbiter-proto/src/lib.rs @@ -6,6 +6,14 @@ use base64::{Engine, prelude::BASE64_STANDARD}; pub mod proto { tonic::include_proto!("arbiter"); + pub mod shared { + tonic::include_proto!("arbiter.shared"); + + pub mod evm { + tonic::include_proto!("arbiter.shared.evm"); + } + } + pub mod user_agent { tonic::include_proto!("arbiter.user_agent"); @@ -36,6 +44,18 @@ pub mod proto { pub mod client { tonic::include_proto!("arbiter.client"); + + pub mod auth { + tonic::include_proto!("arbiter.client.auth"); + } + + pub mod evm { + tonic::include_proto!("arbiter.client.evm"); + } + + pub mod vault { + tonic::include_proto!("arbiter.client.vault"); + } } pub mod evm { diff --git a/server/crates/arbiter-server/src/grpc/client.rs b/server/crates/arbiter-server/src/grpc/client.rs index cd032f4..7fff51c 100644 --- a/server/crates/arbiter-server/src/grpc/client.rs +++ b/server/crates/arbiter-server/src/grpc/client.rs @@ -1,8 +1,12 @@ use arbiter_proto::{ - proto::client::{ - ClientRequest, ClientResponse, VaultState as ProtoVaultState, - client_request::Payload as ClientRequestPayload, - client_response::Payload as ClientResponsePayload, + proto::{ + client::{ + ClientRequest, ClientResponse, + client_request::Payload as ClientRequestPayload, + client_response::Payload as ClientResponsePayload, + vault::{self as proto_vault, request::Payload as VaultRequestPayload, response::Payload as VaultResponsePayload}, + }, + shared::VaultState as ProtoVaultState, }, transport::{Receiver, Sender, grpc::GrpcBi}, }; @@ -79,7 +83,24 @@ async fn dispatch_inner( payload: ClientRequestPayload, ) -> Result { match payload { - ClientRequestPayload::QueryVaultState(_) => { + ClientRequestPayload::Vault(req) => dispatch_vault_request(actor, req).await, + payload => { + warn!(?payload, "Unsupported post-auth client request"); + Err(Status::invalid_argument("Unsupported client request")) + } + } +} + +async fn dispatch_vault_request( + actor: &ActorRef, + req: proto_vault::Request, +) -> Result { + let Some(payload) = req.payload else { + return Err(Status::invalid_argument("Missing client vault request payload")); + }; + + match payload { + VaultRequestPayload::QueryState(_) => { let state = match actor.ask(HandleQueryVaultState {}).await { Ok(KeyHolderState::Unbootstrapped) => ProtoVaultState::Unbootstrapped, Ok(KeyHolderState::Sealed) => ProtoVaultState::Sealed, @@ -90,11 +111,9 @@ async fn dispatch_inner( ProtoVaultState::Error } }; - Ok(ClientResponsePayload::VaultState(state.into())) - } - payload => { - warn!(?payload, "Unsupported post-auth client request"); - Err(Status::invalid_argument("Unsupported client request")) + Ok(ClientResponsePayload::Vault(proto_vault::Response { + payload: Some(VaultResponsePayload::State(state.into())), + })) } } } diff --git a/server/crates/arbiter-server/src/grpc/client/auth.rs b/server/crates/arbiter-server/src/grpc/client/auth.rs index c711520..e5e141d 100644 --- a/server/crates/arbiter-server/src/grpc/client/auth.rs +++ b/server/crates/arbiter-server/src/grpc/client/auth.rs @@ -1,11 +1,20 @@ use arbiter_proto::{ - ClientMetadata, proto::client::{ - AuthChallenge as ProtoAuthChallenge, AuthChallengeRequest as ProtoAuthChallengeRequest, - AuthChallengeSolution as ProtoAuthChallengeSolution, AuthResult as ProtoAuthResult, - ClientInfo as ProtoClientInfo, ClientRequest, ClientResponse, - client_request::Payload as ClientRequestPayload, - client_response::Payload as ClientResponsePayload, - }, transport::{Bi, Error as TransportError, Receiver, Sender, grpc::GrpcBi} + ClientMetadata, + proto::{ + client::{ + ClientRequest, ClientResponse, + auth::{ + self as proto_auth, AuthChallenge as ProtoAuthChallenge, + AuthChallengeRequest as ProtoAuthChallengeRequest, + AuthChallengeSolution as ProtoAuthChallengeSolution, AuthResult as ProtoAuthResult, + request::Payload as AuthRequestPayload, response::Payload as AuthResponsePayload, + }, + client_request::Payload as ClientRequestPayload, + client_response::Payload as ClientResponsePayload, + }, + shared::ClientInfo as ProtoClientInfo, + }, + transport::{Bi, Error as TransportError, Receiver, Sender, grpc::GrpcBi} }; use async_trait::async_trait; use tonic::Status; @@ -32,22 +41,20 @@ impl<'a> AuthTransportAdapter<'a> { } } - fn response_to_proto(response: auth::Outbound) -> ClientResponsePayload { + fn response_to_proto(response: auth::Outbound) -> AuthResponsePayload { match response { auth::Outbound::AuthChallenge { pubkey, nonce } => { - ClientResponsePayload::AuthChallenge(ProtoAuthChallenge { + AuthResponsePayload::Challenge(ProtoAuthChallenge { pubkey: pubkey.to_bytes().to_vec(), nonce, }) } - auth::Outbound::AuthSuccess => { - ClientResponsePayload::AuthResult(ProtoAuthResult::Success.into()) - } + auth::Outbound::AuthSuccess => AuthResponsePayload::Result(ProtoAuthResult::Success.into()), } } - fn error_to_proto(error: auth::Error) -> ClientResponsePayload { - ClientResponsePayload::AuthResult( + fn error_to_proto(error: auth::Error) -> AuthResponsePayload { + AuthResponsePayload::Result( match error { auth::Error::InvalidChallengeSolution => ProtoAuthResult::InvalidSignature, auth::Error::ApproveError(auth::ApproveError::Denied) => { @@ -67,18 +74,20 @@ impl<'a> AuthTransportAdapter<'a> { async fn send_client_response( &mut self, - payload: ClientResponsePayload, + payload: AuthResponsePayload, ) -> Result<(), TransportError> { self.bi .send(Ok(ClientResponse { request_id: Some(self.request_tracker.current_request_id()), - payload: Some(payload), + payload: Some(ClientResponsePayload::Auth(proto_auth::Response { + payload: Some(payload), + })), })) .await } async fn send_auth_result(&mut self, result: ProtoAuthResult) -> Result<(), TransportError> { - self.send_client_response(ClientResponsePayload::AuthResult(result.into())) + self.send_client_response(AuthResponsePayload::Result(result.into())) .await } } @@ -117,9 +126,25 @@ impl Receiver for AuthTransportAdapter<'_> { } }; let payload = request.payload?; + let ClientRequestPayload::Auth(auth_request) = payload else { + let _ = self + .bi + .send(Err(Status::invalid_argument( + "Unsupported client auth request", + ))) + .await; + return None; + }; + let Some(payload) = auth_request.payload else { + let _ = self + .bi + .send(Err(Status::invalid_argument("Missing client auth request payload"))) + .await; + return None; + }; match payload { - ClientRequestPayload::AuthChallengeRequest(ProtoAuthChallengeRequest { + AuthRequestPayload::ChallengeRequest(ProtoAuthChallengeRequest { pubkey, client_info, }) => { @@ -143,7 +168,7 @@ impl Receiver for AuthTransportAdapter<'_> { metadata: client_metadata_from_proto(client_info), }) } - ClientRequestPayload::AuthChallengeSolution(ProtoAuthChallengeSolution { + AuthRequestPayload::ChallengeSolution(ProtoAuthChallengeSolution { signature, }) => { let Ok(signature) = ed25519_dalek::Signature::try_from(signature.as_slice()) else { @@ -154,15 +179,6 @@ impl Receiver for AuthTransportAdapter<'_> { }; Some(auth::Inbound::AuthChallengeSolution { signature }) } - _ => { - let _ = self - .bi - .send(Err(Status::invalid_argument( - "Unsupported client auth request", - ))) - .await; - None - } } } } diff --git a/server/crates/arbiter-server/src/grpc/user_agent/sdk_client.rs b/server/crates/arbiter-server/src/grpc/user_agent/sdk_client.rs index 6e40514..f1827af 100644 --- a/server/crates/arbiter-server/src/grpc/user_agent/sdk_client.rs +++ b/server/crates/arbiter-server/src/grpc/user_agent/sdk_client.rs @@ -1,5 +1,4 @@ use arbiter_proto::proto::{ - client::ClientInfo as ProtoClientMetadata, user_agent::{ sdk_client::{ self as proto_sdk_client, ConnectionCancel as ProtoSdkClientConnectionCancel, @@ -14,6 +13,7 @@ use arbiter_proto::proto::{ }, user_agent_response::Payload as UserAgentResponsePayload, }, + shared::ClientInfo as ProtoClientMetadata, }; use kameo::actor::ActorRef; use tonic::Status; diff --git a/server/crates/arbiter-server/src/grpc/user_agent/vault.rs b/server/crates/arbiter-server/src/grpc/user_agent/vault.rs index 5aad751..669d35c 100644 --- a/server/crates/arbiter-server/src/grpc/user_agent/vault.rs +++ b/server/crates/arbiter-server/src/grpc/user_agent/vault.rs @@ -1,7 +1,7 @@ use arbiter_proto::proto::user_agent::{ user_agent_response::Payload as UserAgentResponsePayload, vault::{ - self as proto_vault, VaultState as ProtoVaultState, + self as proto_vault, bootstrap::{ self as proto_bootstrap, BootstrapEncryptedKey as ProtoBootstrapEncryptedKey, BootstrapResult as ProtoBootstrapResult, @@ -16,6 +16,7 @@ use arbiter_proto::proto::user_agent::{ }, }, }; +use arbiter_proto::proto::shared::VaultState as ProtoVaultState; use kameo::{actor::ActorRef, error::SendError}; use tonic::Status; use tracing::warn;