feat(proto): request / response pair tracking by assigning id

This commit is contained in:
hdbg
2026-03-18 23:43:44 +01:00
committed by Stas
parent 60ce1cc110
commit 3e8b26418a
8 changed files with 259 additions and 118 deletions

View File

@@ -37,6 +37,7 @@ enum VaultState {
}
message ClientRequest {
int32 request_id = 4;
oneof payload {
AuthChallengeRequest auth_challenge_request = 1;
AuthChallengeSolution auth_challenge_solution = 2;
@@ -45,6 +46,7 @@ message ClientRequest {
}
message ClientResponse {
optional int32 request_id = 7;
oneof payload {
AuthChallenge auth_challenge = 1;
AuthResult auth_result = 2;

View File

@@ -89,6 +89,7 @@ message ClientConnectionResponse {
message ClientConnectionCancel {}
message UserAgentRequest {
int32 id = 14;
oneof payload {
AuthChallengeRequest auth_challenge_request = 1;
AuthChallengeSolution auth_challenge_solution = 2;
@@ -105,6 +106,7 @@ message UserAgentRequest {
}
}
message UserAgentResponse {
optional int32 id = 14;
oneof payload {
AuthChallenge auth_challenge = 1;
AuthResult auth_result = 2;

View File

@@ -10,6 +10,7 @@ use kameo::{
actor::{ActorRef, Spawn as _},
error::SendError,
};
use tonic::Status;
use tracing::{info, warn};
use crate::{
@@ -20,6 +21,7 @@ use crate::{
},
keyholder::KeyHolderState,
},
grpc::request_tracker::RequestTracker,
utils::defer,
};
@@ -28,13 +30,17 @@ mod auth;
async fn dispatch_loop(
mut bi: GrpcBi<ClientRequest, ClientResponse>,
actor: ActorRef<ClientSession>,
mut request_tracker: RequestTracker,
) {
loop {
let Some(conn) = bi.recv().await else {
return;
};
if dispatch_conn_message(&mut bi, &actor, conn).await.is_err() {
if dispatch_conn_message(&mut bi, &actor, &mut request_tracker, conn)
.await
.is_err()
{
return;
}
}
@@ -43,7 +49,8 @@ async fn dispatch_loop(
async fn dispatch_conn_message(
bi: &mut GrpcBi<ClientRequest, ClientResponse>,
actor: &ActorRef<ClientSession>,
conn: Result<ClientRequest, tonic::Status>,
request_tracker: &mut RequestTracker,
conn: Result<ClientRequest, Status>,
) -> Result<(), ()> {
let conn = match conn {
Ok(conn) => conn,
@@ -53,9 +60,16 @@ async fn dispatch_conn_message(
}
};
let request_id = match request_tracker.request(conn.request_id) {
Ok(request_id) => request_id,
Err(err) => {
let _ = bi.send(Err(err)).await;
return Err(());
}
};
let Some(payload) = conn.payload else {
let _ = bi
.send(Err(tonic::Status::invalid_argument(
.send(Err(Status::invalid_argument(
"Missing client request payload",
)))
.await;
@@ -79,15 +93,14 @@ async fn dispatch_conn_message(
payload => {
warn!(?payload, "Unsupported post-auth client request");
let _ = bi
.send(Err(tonic::Status::invalid_argument(
"Unsupported client request",
)))
.send(Err(Status::invalid_argument("Unsupported client request")))
.await;
return Err(());
}
};
bi.send(Ok(ClientResponse {
request_id: Some(request_id),
payload: Some(payload),
}))
.await
@@ -96,7 +109,10 @@ async fn dispatch_conn_message(
pub async fn start(conn: ClientConnection, mut bi: GrpcBi<ClientRequest, ClientResponse>) {
let mut conn = conn;
match auth::start(&mut conn, &mut bi).await {
let mut request_tracker = RequestTracker::default();
let mut response_id = None;
match auth::start(&mut conn, &mut bi, &mut request_tracker, &mut response_id).await {
Ok(_) => {
let actor =
client::session::ClientSession::spawn(client::session::ClientSession::new(conn));
@@ -106,10 +122,14 @@ pub async fn start(conn: ClientConnection, mut bi: GrpcBi<ClientRequest, ClientR
});
info!("Client authenticated successfully");
dispatch_loop(bi, actor).await;
dispatch_loop(bi, actor, request_tracker).await;
}
Err(e) => {
let mut transport = auth::AuthTransportAdapter(&mut bi);
let mut transport = auth::AuthTransportAdapter::new(
&mut bi,
&mut request_tracker,
&mut response_id,
);
let _ = transport.send(Err(e.clone())).await;
warn!(error = ?e, "Authentication failed");
return;

View File

@@ -8,15 +8,35 @@ use arbiter_proto::{
transport::{Bi, Error as TransportError, Receiver, Sender, grpc::GrpcBi},
};
use async_trait::async_trait;
use tonic::Status;
use tracing::warn;
use crate::actors::client::{self, ClientConnection, auth};
use crate::{
actors::client::{self, ClientConnection, auth},
grpc::request_tracker::RequestTracker,
};
pub struct AuthTransportAdapter<'a>(pub(super) &'a mut GrpcBi<ClientRequest, ClientResponse>);
pub struct AuthTransportAdapter<'a> {
bi: &'a mut GrpcBi<ClientRequest, ClientResponse>,
request_tracker: &'a mut RequestTracker,
response_id: &'a mut Option<i32>,
}
impl AuthTransportAdapter<'_> {
fn response_to_proto(response: auth::Outbound) -> ClientResponse {
let payload = match response {
impl<'a> AuthTransportAdapter<'a> {
pub fn new(
bi: &'a mut GrpcBi<ClientRequest, ClientResponse>,
request_tracker: &'a mut RequestTracker,
response_id: &'a mut Option<i32>,
) -> Self {
Self {
bi,
request_tracker,
response_id,
}
}
fn response_to_proto(response: auth::Outbound) -> ClientResponsePayload {
match response {
auth::Outbound::AuthChallenge { pubkey, nonce } => {
ClientResponsePayload::AuthChallenge(ProtoAuthChallenge {
pubkey: pubkey.to_bytes().to_vec(),
@@ -26,39 +46,44 @@ impl AuthTransportAdapter<'_> {
auth::Outbound::AuthSuccess => {
ClientResponsePayload::AuthResult(ProtoAuthResult::Success.into())
}
};
ClientResponse {
payload: Some(payload),
}
}
fn error_to_proto(error: auth::Error) -> ClientResponse {
ClientResponse {
payload: Some(ClientResponsePayload::AuthResult(
match error {
auth::Error::InvalidChallengeSolution => ProtoAuthResult::InvalidSignature,
auth::Error::ApproveError(auth::ApproveError::Denied) => {
ProtoAuthResult::ApprovalDenied
}
auth::Error::ApproveError(auth::ApproveError::Upstream(
crate::actors::router::ApprovalError::NoUserAgentsConnected,
)) => ProtoAuthResult::NoUserAgentsOnline,
auth::Error::ApproveError(auth::ApproveError::Internal)
| auth::Error::DatabasePoolUnavailable
| auth::Error::DatabaseOperationFailed
| auth::Error::Transport => ProtoAuthResult::Internal,
fn error_to_proto(error: auth::Error) -> ClientResponsePayload {
ClientResponsePayload::AuthResult(
match error {
auth::Error::InvalidChallengeSolution => ProtoAuthResult::InvalidSignature,
auth::Error::ApproveError(auth::ApproveError::Denied) => {
ProtoAuthResult::ApprovalDenied
}
.into(),
)),
}
auth::Error::ApproveError(auth::ApproveError::Upstream(
crate::actors::router::ApprovalError::NoUserAgentsConnected,
)) => ProtoAuthResult::NoUserAgentsOnline,
auth::Error::ApproveError(auth::ApproveError::Internal)
| auth::Error::DatabasePoolUnavailable
| auth::Error::DatabaseOperationFailed
| auth::Error::Transport => ProtoAuthResult::Internal,
}
.into(),
)
}
async fn send_client_response(
&mut self,
payload: ClientResponsePayload,
) -> Result<(), TransportError> {
let request_id = self.response_id.take();
self.bi
.send(Ok(ClientResponse {
request_id,
payload: Some(payload),
}))
.await
}
async fn send_auth_result(&mut self, result: ProtoAuthResult) -> Result<(), TransportError> {
self.0
.send(Ok(ClientResponse {
payload: Some(ClientResponsePayload::AuthResult(result.into())),
}))
self.send_client_response(ClientResponsePayload::AuthResult(result.into()))
.await
}
}
@@ -69,19 +94,19 @@ impl Sender<Result<auth::Outbound, auth::Error>> for AuthTransportAdapter<'_> {
&mut self,
item: Result<auth::Outbound, auth::Error>,
) -> Result<(), TransportError> {
let outbound = match item {
Ok(message) => Ok(AuthTransportAdapter::response_to_proto(message)),
Err(err) => Ok(AuthTransportAdapter::error_to_proto(err)),
let payload = match item {
Ok(message) => AuthTransportAdapter::response_to_proto(message),
Err(err) => AuthTransportAdapter::error_to_proto(err),
};
self.0.send(outbound).await
self.send_client_response(payload).await
}
}
#[async_trait]
impl Receiver<auth::Inbound> for AuthTransportAdapter<'_> {
async fn recv(&mut self) -> Option<auth::Inbound> {
let request = match self.0.recv().await? {
let request = match self.bi.recv().await? {
Ok(request) => request,
Err(error) => {
warn!(error = ?error, "grpc client recv failed; closing stream");
@@ -89,6 +114,15 @@ impl Receiver<auth::Inbound> for AuthTransportAdapter<'_> {
}
};
let request_id = match self.request_tracker.request(request.request_id) {
Ok(request_id) => request_id,
Err(error) => {
let _ = self.bi.send(Err(error)).await;
return None;
}
};
*self.response_id = Some(request_id);
let payload = request.payload?;
match payload {
@@ -114,7 +148,13 @@ impl Receiver<auth::Inbound> for AuthTransportAdapter<'_> {
};
Some(auth::Inbound::AuthChallengeSolution { signature })
}
_ => None,
_ => {
let _ = self
.bi
.send(Err(Status::invalid_argument("Unsupported client auth request")))
.await;
None
}
}
}
}
@@ -124,8 +164,10 @@ impl Bi<auth::Inbound, Result<auth::Outbound, auth::Error>> for AuthTransportAda
pub async fn start(
conn: &mut ClientConnection,
bi: &mut GrpcBi<ClientRequest, ClientResponse>,
request_tracker: &mut RequestTracker,
response_id: &mut Option<i32>,
) -> Result<(), auth::Error> {
let mut transport = AuthTransportAdapter(bi);
let mut transport = AuthTransportAdapter::new(bi, request_tracker, response_id);
client::auth::authenticate(conn, &mut transport).await?;
Ok(())
}

View File

@@ -17,6 +17,7 @@ use crate::{
};
pub mod client;
mod request_tracker;
pub mod user_agent;
#[async_trait]

View File

@@ -0,0 +1,20 @@
use tonic::Status;
#[derive(Default)]
pub struct RequestTracker {
next_request_id: i32,
}
impl RequestTracker {
pub fn request(&mut self, id: i32) -> Result<i32, Status> {
if id < self.next_request_id {
return Err(Status::invalid_argument("Duplicate request id"));
}
self.next_request_id = id
.checked_add(1)
.ok_or_else(|| Status::invalid_argument("Invalid request id"))?;
Ok(id)
}
}

View File

@@ -53,6 +53,7 @@ use crate::{
Grant, SharedGrantSettings, SpecificGrant, TransactionRateLimit, VolumeRateLimit,
ether_transfer, token_transfers,
},
grpc::request_tracker::RequestTracker,
utils::defer,
};
use alloy::primitives::{Address, U256};
@@ -74,6 +75,7 @@ async fn dispatch_loop(
mut bi: GrpcBi<UserAgentRequest, UserAgentResponse>,
actor: ActorRef<UserAgentSession>,
mut receiver: mpsc::Receiver<OutOfBand>,
mut request_tracker: RequestTracker,
) {
loop {
tokio::select! {
@@ -92,7 +94,10 @@ async fn dispatch_loop(
return;
};
if dispatch_conn_message(&mut bi, &actor, conn).await.is_err() {
if dispatch_conn_message(&mut bi, &actor, &mut request_tracker, conn)
.await
.is_err()
{
return;
}
}
@@ -103,6 +108,7 @@ async fn dispatch_loop(
async fn dispatch_conn_message(
bi: &mut GrpcBi<UserAgentRequest, UserAgentResponse>,
actor: &ActorRef<UserAgentSession>,
request_tracker: &mut RequestTracker,
conn: Result<UserAgentRequest, Status>,
) -> Result<(), ()> {
let conn = match conn {
@@ -113,6 +119,14 @@ async fn dispatch_conn_message(
}
};
let request_id = match request_tracker.request(conn.id) {
Ok(request_id) => request_id,
Err(err) => {
let _ = bi.send(Err(err)).await;
return Err(());
}
};
let Some(payload) = conn.payload else {
let _ = bi
.send(Err(Status::invalid_argument(
@@ -267,6 +281,7 @@ async fn dispatch_conn_message(
};
bi.send(Ok(UserAgentResponse {
id: Some(request_id),
payload: Some(payload),
}))
.await
@@ -289,6 +304,7 @@ async fn send_out_of_band(
};
bi.send(Ok(UserAgentResponse {
id: None,
payload: Some(payload),
}))
.await
@@ -558,7 +574,17 @@ pub async fn start(
mut conn: UserAgentConnection,
mut bi: GrpcBi<UserAgentRequest, UserAgentResponse>,
) {
let pubkey = match auth::start(&mut conn, &mut bi).await {
let mut request_tracker = RequestTracker::default();
let mut response_id = None;
let pubkey = match auth::start(
&mut conn,
&mut bi,
&mut request_tracker,
&mut response_id,
)
.await
{
Ok(pubkey) => pubkey,
Err(e) => {
warn!(error = ?e, "Authentication failed");
@@ -572,11 +598,10 @@ pub async fn start(
let actor = UserAgentSession::spawn(UserAgentSession::new(conn, Box::new(oob_adapter)));
let actor_for_cleanup = actor.clone();
// when connection closes
let _ = defer(move || {
actor_for_cleanup.kill();
});
info!(?pubkey, "User authenticated successfully");
dispatch_loop(bi, actor, oob_receiver).await;
dispatch_loop(bi, actor, oob_receiver, request_tracker).await;
}

View File

@@ -1,52 +1,56 @@
use arbiter_proto::{
proto::{
self,
evm::{
EtherTransferSettings as ProtoEtherTransferSettings, EvmError as ProtoEvmError,
EvmGrantCreateRequest, EvmGrantCreateResponse, EvmGrantDeleteRequest,
EvmGrantDeleteResponse, EvmGrantList, EvmGrantListResponse, GrantEntry,
SharedSettings as ProtoSharedSettings, SpecificGrant as ProtoSpecificGrant,
TokenTransferSettings as ProtoTokenTransferSettings,
VolumeRateLimit as ProtoVolumeRateLimit, WalletCreateResponse, WalletEntry, WalletList,
WalletListResponse, evm_grant_create_response::Result as EvmGrantCreateResult,
evm_grant_delete_response::Result as EvmGrantDeleteResult,
evm_grant_list_response::Result as EvmGrantListResult,
specific_grant::Grant as ProtoSpecificGrantType,
wallet_create_response::Result as WalletCreateResult,
wallet_list_response::Result as WalletListResult,
},
user_agent::{
AuthChallenge as ProtoAuthChallenge, AuthChallengeRequest as ProtoAuthChallengeRequest,
AuthChallengeSolution as ProtoAuthChallengeSolution, AuthResult as ProtoAuthResult,
BootstrapEncryptedKey as ProtoBootstrapEncryptedKey,
BootstrapResult as ProtoBootstrapResult, ClientConnectionCancel,
ClientConnectionRequest, ClientConnectionResponse, KeyType as ProtoKeyType,
UnsealEncryptedKey as ProtoUnsealEncryptedKey, UnsealResult as ProtoUnsealResult,
UnsealStart, UnsealStartResponse, UserAgentRequest, UserAgentResponse,
VaultState as ProtoVaultState, user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload,
},
proto::user_agent::{
AuthChallenge as ProtoAuthChallenge, AuthChallengeRequest as ProtoAuthChallengeRequest,
AuthChallengeSolution as ProtoAuthChallengeSolution, AuthResult as ProtoAuthResult,
KeyType as ProtoKeyType, UserAgentRequest, UserAgentResponse,
user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload,
},
transport::{Bi, Error as TransportError, Receiver, Sender, grpc::GrpcBi},
};
use async_trait::async_trait;
use tonic::{Status, Streaming};
use tracing::{info, warn};
use tonic::Status;
use tracing::warn;
use crate::{
actors::user_agent::{
self, AuthPublicKey, OutOfBand as DomainResponse, UserAgentConnection, auth,
},
actors::user_agent::{AuthPublicKey, UserAgentConnection, auth},
db::models::KeyType,
evm::policies::{
Grant, SharedGrantSettings, SpecificGrant, TransactionRateLimit, VolumeRateLimit,
ether_transfer, token_transfers,
},
grpc::request_tracker::RequestTracker,
};
use alloy::primitives::{Address, U256};
use chrono::{DateTime, TimeZone, Utc};
pub struct AuthTransportAdapter<'a>(&'a mut GrpcBi<UserAgentRequest, UserAgentResponse>);
pub struct AuthTransportAdapter<'a> {
bi: &'a mut GrpcBi<UserAgentRequest, UserAgentResponse>,
request_tracker: &'a mut RequestTracker,
response_id: &'a mut Option<i32>,
}
impl<'a> AuthTransportAdapter<'a> {
pub fn new(
bi: &'a mut GrpcBi<UserAgentRequest, UserAgentResponse>,
request_tracker: &'a mut RequestTracker,
response_id: &'a mut Option<i32>,
) -> Self {
Self {
bi,
request_tracker,
response_id,
}
}
async fn send_user_agent_response(
&mut self,
payload: UserAgentResponsePayload,
) -> Result<(), TransportError> {
let id = self.response_id.take();
self.bi
.send(Ok(UserAgentResponse {
id,
payload: Some(payload),
}))
.await
}
}
#[async_trait]
impl Sender<Result<auth::Outbound, auth::Error>> for AuthTransportAdapter<'_> {
@@ -55,39 +59,53 @@ impl Sender<Result<auth::Outbound, auth::Error>> for AuthTransportAdapter<'_> {
item: Result<auth::Outbound, auth::Error>,
) -> Result<(), TransportError> {
use auth::{Error, Outbound};
let response = match item {
Ok(Outbound::AuthChallenge { nonce }) => Ok(UserAgentResponsePayload::AuthChallenge(
ProtoAuthChallenge { nonce },
)),
Ok(Outbound::AuthSuccess) => Ok(UserAgentResponsePayload::AuthResult(
ProtoAuthResult::Success.into(),
)),
Err(Error::UnregisteredPublicKey) => Ok(UserAgentResponsePayload::AuthResult(
ProtoAuthResult::InvalidKey.into(),
)),
Err(Error::InvalidChallengeSolution) => Ok(UserAgentResponsePayload::AuthResult(
ProtoAuthResult::InvalidSignature.into(),
)),
Err(Error::InvalidBootstrapToken) => Ok(UserAgentResponsePayload::BootstrapResult(
ProtoAuthResult::TokenInvalid.into(),
)),
Err(Error::Internal { details }) => Err(Status::internal(details)),
Err(Error::Transport) => Err(Status::unavailable("transport error")),
let payload = match item {
Ok(Outbound::AuthChallenge { nonce }) => {
UserAgentResponsePayload::AuthChallenge(ProtoAuthChallenge { nonce })
}
Ok(Outbound::AuthSuccess) => {
UserAgentResponsePayload::AuthResult(ProtoAuthResult::Success.into())
}
Err(Error::UnregisteredPublicKey) => {
UserAgentResponsePayload::AuthResult(ProtoAuthResult::InvalidKey.into())
}
Err(Error::InvalidChallengeSolution) => {
UserAgentResponsePayload::AuthResult(ProtoAuthResult::InvalidSignature.into())
}
Err(Error::InvalidBootstrapToken) => {
UserAgentResponsePayload::AuthResult(ProtoAuthResult::TokenInvalid.into())
}
Err(Error::Internal { details }) => return self.bi.send(Err(Status::internal(details))).await,
Err(Error::Transport) => {
return self.bi.send(Err(Status::unavailable("transport error"))).await;
}
};
self.0
.send(response.map(|r| UserAgentResponse { payload: Some(r) }))
.await
self.send_user_agent_response(payload).await
}
}
#[async_trait]
impl Receiver<auth::Inbound> for AuthTransportAdapter<'_> {
async fn recv(&mut self) -> Option<auth::Inbound> {
let Ok(UserAgentRequest {
payload: Some(payload),
}) = self.0.recv().await?
else {
let request = match self.bi.recv().await? {
Ok(request) => request,
Err(error) => {
warn!(error = ?error, "Failed to receive user agent auth request");
return None;
}
};
let request_id = match self.request_tracker.request(request.id) {
Ok(request_id) => request_id,
Err(error) => {
let _ = self.bi.send(Err(error)).await;
return None;
}
};
*self.response_id = Some(request_id);
let Some(payload) = request.payload else {
warn!(
event = "received request with empty payload",
"grpc.useragent.auth_adapter"
@@ -136,16 +154,27 @@ impl Receiver<auth::Inbound> for AuthTransportAdapter<'_> {
UserAgentRequestPayload::AuthChallengeSolution(ProtoAuthChallengeSolution {
signature,
}) => Some(auth::Inbound::AuthChallengeSolution { signature }),
_ => None, // Ignore other request types for this adapter
_ => {
let _ = self
.bi
.send(Err(Status::invalid_argument(
"Unsupported user-agent auth request",
)))
.await;
None
}
}
}
}
impl Bi<auth::Inbound, Result<auth::Outbound, auth::Error>> for AuthTransportAdapter<'_> {}
pub async fn start(
conn: &mut UserAgentConnection,
bi: &mut GrpcBi<UserAgentRequest, UserAgentResponse>,
request_tracker: &mut RequestTracker,
response_id: &mut Option<i32>,
) -> Result<AuthPublicKey, auth::Error> {
let mut transport = AuthTransportAdapter(bi);
let transport = AuthTransportAdapter::new(bi, request_tracker, response_id);
auth::authenticate(conn, transport).await
}