feat(proto): request / response pair tracking by assigning id

This commit is contained in:
hdbg
2026-03-18 23:43:44 +01:00
committed by Stas
parent 60ce1cc110
commit 3e8b26418a
8 changed files with 259 additions and 118 deletions

View File

@@ -37,6 +37,7 @@ enum VaultState {
} }
message ClientRequest { message ClientRequest {
int32 request_id = 4;
oneof payload { oneof payload {
AuthChallengeRequest auth_challenge_request = 1; AuthChallengeRequest auth_challenge_request = 1;
AuthChallengeSolution auth_challenge_solution = 2; AuthChallengeSolution auth_challenge_solution = 2;
@@ -45,6 +46,7 @@ message ClientRequest {
} }
message ClientResponse { message ClientResponse {
optional int32 request_id = 7;
oneof payload { oneof payload {
AuthChallenge auth_challenge = 1; AuthChallenge auth_challenge = 1;
AuthResult auth_result = 2; AuthResult auth_result = 2;

View File

@@ -89,6 +89,7 @@ message ClientConnectionResponse {
message ClientConnectionCancel {} message ClientConnectionCancel {}
message UserAgentRequest { message UserAgentRequest {
int32 id = 14;
oneof payload { oneof payload {
AuthChallengeRequest auth_challenge_request = 1; AuthChallengeRequest auth_challenge_request = 1;
AuthChallengeSolution auth_challenge_solution = 2; AuthChallengeSolution auth_challenge_solution = 2;
@@ -105,6 +106,7 @@ message UserAgentRequest {
} }
} }
message UserAgentResponse { message UserAgentResponse {
optional int32 id = 14;
oneof payload { oneof payload {
AuthChallenge auth_challenge = 1; AuthChallenge auth_challenge = 1;
AuthResult auth_result = 2; AuthResult auth_result = 2;

View File

@@ -10,6 +10,7 @@ use kameo::{
actor::{ActorRef, Spawn as _}, actor::{ActorRef, Spawn as _},
error::SendError, error::SendError,
}; };
use tonic::Status;
use tracing::{info, warn}; use tracing::{info, warn};
use crate::{ use crate::{
@@ -20,6 +21,7 @@ use crate::{
}, },
keyholder::KeyHolderState, keyholder::KeyHolderState,
}, },
grpc::request_tracker::RequestTracker,
utils::defer, utils::defer,
}; };
@@ -28,13 +30,17 @@ mod auth;
async fn dispatch_loop( async fn dispatch_loop(
mut bi: GrpcBi<ClientRequest, ClientResponse>, mut bi: GrpcBi<ClientRequest, ClientResponse>,
actor: ActorRef<ClientSession>, actor: ActorRef<ClientSession>,
mut request_tracker: RequestTracker,
) { ) {
loop { loop {
let Some(conn) = bi.recv().await else { let Some(conn) = bi.recv().await else {
return; return;
}; };
if dispatch_conn_message(&mut bi, &actor, conn).await.is_err() { if dispatch_conn_message(&mut bi, &actor, &mut request_tracker, conn)
.await
.is_err()
{
return; return;
} }
} }
@@ -43,7 +49,8 @@ async fn dispatch_loop(
async fn dispatch_conn_message( async fn dispatch_conn_message(
bi: &mut GrpcBi<ClientRequest, ClientResponse>, bi: &mut GrpcBi<ClientRequest, ClientResponse>,
actor: &ActorRef<ClientSession>, actor: &ActorRef<ClientSession>,
conn: Result<ClientRequest, tonic::Status>, request_tracker: &mut RequestTracker,
conn: Result<ClientRequest, Status>,
) -> Result<(), ()> { ) -> Result<(), ()> {
let conn = match conn { let conn = match conn {
Ok(conn) => conn, Ok(conn) => conn,
@@ -53,9 +60,16 @@ async fn dispatch_conn_message(
} }
}; };
let request_id = match request_tracker.request(conn.request_id) {
Ok(request_id) => request_id,
Err(err) => {
let _ = bi.send(Err(err)).await;
return Err(());
}
};
let Some(payload) = conn.payload else { let Some(payload) = conn.payload else {
let _ = bi let _ = bi
.send(Err(tonic::Status::invalid_argument( .send(Err(Status::invalid_argument(
"Missing client request payload", "Missing client request payload",
))) )))
.await; .await;
@@ -79,15 +93,14 @@ async fn dispatch_conn_message(
payload => { payload => {
warn!(?payload, "Unsupported post-auth client request"); warn!(?payload, "Unsupported post-auth client request");
let _ = bi let _ = bi
.send(Err(tonic::Status::invalid_argument( .send(Err(Status::invalid_argument("Unsupported client request")))
"Unsupported client request",
)))
.await; .await;
return Err(()); return Err(());
} }
}; };
bi.send(Ok(ClientResponse { bi.send(Ok(ClientResponse {
request_id: Some(request_id),
payload: Some(payload), payload: Some(payload),
})) }))
.await .await
@@ -96,7 +109,10 @@ async fn dispatch_conn_message(
pub async fn start(conn: ClientConnection, mut bi: GrpcBi<ClientRequest, ClientResponse>) { pub async fn start(conn: ClientConnection, mut bi: GrpcBi<ClientRequest, ClientResponse>) {
let mut conn = conn; let mut conn = conn;
match auth::start(&mut conn, &mut bi).await { let mut request_tracker = RequestTracker::default();
let mut response_id = None;
match auth::start(&mut conn, &mut bi, &mut request_tracker, &mut response_id).await {
Ok(_) => { Ok(_) => {
let actor = let actor =
client::session::ClientSession::spawn(client::session::ClientSession::new(conn)); client::session::ClientSession::spawn(client::session::ClientSession::new(conn));
@@ -106,10 +122,14 @@ pub async fn start(conn: ClientConnection, mut bi: GrpcBi<ClientRequest, ClientR
}); });
info!("Client authenticated successfully"); info!("Client authenticated successfully");
dispatch_loop(bi, actor).await; dispatch_loop(bi, actor, request_tracker).await;
} }
Err(e) => { Err(e) => {
let mut transport = auth::AuthTransportAdapter(&mut bi); let mut transport = auth::AuthTransportAdapter::new(
&mut bi,
&mut request_tracker,
&mut response_id,
);
let _ = transport.send(Err(e.clone())).await; let _ = transport.send(Err(e.clone())).await;
warn!(error = ?e, "Authentication failed"); warn!(error = ?e, "Authentication failed");
return; return;

View File

@@ -8,15 +8,35 @@ use arbiter_proto::{
transport::{Bi, Error as TransportError, Receiver, Sender, grpc::GrpcBi}, transport::{Bi, Error as TransportError, Receiver, Sender, grpc::GrpcBi},
}; };
use async_trait::async_trait; use async_trait::async_trait;
use tonic::Status;
use tracing::warn; use tracing::warn;
use crate::actors::client::{self, ClientConnection, auth}; use crate::{
actors::client::{self, ClientConnection, auth},
grpc::request_tracker::RequestTracker,
};
pub struct AuthTransportAdapter<'a>(pub(super) &'a mut GrpcBi<ClientRequest, ClientResponse>); pub struct AuthTransportAdapter<'a> {
bi: &'a mut GrpcBi<ClientRequest, ClientResponse>,
request_tracker: &'a mut RequestTracker,
response_id: &'a mut Option<i32>,
}
impl AuthTransportAdapter<'_> { impl<'a> AuthTransportAdapter<'a> {
fn response_to_proto(response: auth::Outbound) -> ClientResponse { pub fn new(
let payload = match response { bi: &'a mut GrpcBi<ClientRequest, ClientResponse>,
request_tracker: &'a mut RequestTracker,
response_id: &'a mut Option<i32>,
) -> Self {
Self {
bi,
request_tracker,
response_id,
}
}
fn response_to_proto(response: auth::Outbound) -> ClientResponsePayload {
match response {
auth::Outbound::AuthChallenge { pubkey, nonce } => { auth::Outbound::AuthChallenge { pubkey, nonce } => {
ClientResponsePayload::AuthChallenge(ProtoAuthChallenge { ClientResponsePayload::AuthChallenge(ProtoAuthChallenge {
pubkey: pubkey.to_bytes().to_vec(), pubkey: pubkey.to_bytes().to_vec(),
@@ -26,39 +46,44 @@ impl AuthTransportAdapter<'_> {
auth::Outbound::AuthSuccess => { auth::Outbound::AuthSuccess => {
ClientResponsePayload::AuthResult(ProtoAuthResult::Success.into()) ClientResponsePayload::AuthResult(ProtoAuthResult::Success.into())
} }
};
ClientResponse {
payload: Some(payload),
} }
} }
fn error_to_proto(error: auth::Error) -> ClientResponse { fn error_to_proto(error: auth::Error) -> ClientResponsePayload {
ClientResponse { ClientResponsePayload::AuthResult(
payload: Some(ClientResponsePayload::AuthResult( match error {
match error { auth::Error::InvalidChallengeSolution => ProtoAuthResult::InvalidSignature,
auth::Error::InvalidChallengeSolution => ProtoAuthResult::InvalidSignature, auth::Error::ApproveError(auth::ApproveError::Denied) => {
auth::Error::ApproveError(auth::ApproveError::Denied) => { ProtoAuthResult::ApprovalDenied
ProtoAuthResult::ApprovalDenied
}
auth::Error::ApproveError(auth::ApproveError::Upstream(
crate::actors::router::ApprovalError::NoUserAgentsConnected,
)) => ProtoAuthResult::NoUserAgentsOnline,
auth::Error::ApproveError(auth::ApproveError::Internal)
| auth::Error::DatabasePoolUnavailable
| auth::Error::DatabaseOperationFailed
| auth::Error::Transport => ProtoAuthResult::Internal,
} }
.into(), auth::Error::ApproveError(auth::ApproveError::Upstream(
)), crate::actors::router::ApprovalError::NoUserAgentsConnected,
} )) => ProtoAuthResult::NoUserAgentsOnline,
auth::Error::ApproveError(auth::ApproveError::Internal)
| auth::Error::DatabasePoolUnavailable
| auth::Error::DatabaseOperationFailed
| auth::Error::Transport => ProtoAuthResult::Internal,
}
.into(),
)
}
async fn send_client_response(
&mut self,
payload: ClientResponsePayload,
) -> Result<(), TransportError> {
let request_id = self.response_id.take();
self.bi
.send(Ok(ClientResponse {
request_id,
payload: Some(payload),
}))
.await
} }
async fn send_auth_result(&mut self, result: ProtoAuthResult) -> Result<(), TransportError> { async fn send_auth_result(&mut self, result: ProtoAuthResult) -> Result<(), TransportError> {
self.0 self.send_client_response(ClientResponsePayload::AuthResult(result.into()))
.send(Ok(ClientResponse {
payload: Some(ClientResponsePayload::AuthResult(result.into())),
}))
.await .await
} }
} }
@@ -69,19 +94,19 @@ impl Sender<Result<auth::Outbound, auth::Error>> for AuthTransportAdapter<'_> {
&mut self, &mut self,
item: Result<auth::Outbound, auth::Error>, item: Result<auth::Outbound, auth::Error>,
) -> Result<(), TransportError> { ) -> Result<(), TransportError> {
let outbound = match item { let payload = match item {
Ok(message) => Ok(AuthTransportAdapter::response_to_proto(message)), Ok(message) => AuthTransportAdapter::response_to_proto(message),
Err(err) => Ok(AuthTransportAdapter::error_to_proto(err)), Err(err) => AuthTransportAdapter::error_to_proto(err),
}; };
self.0.send(outbound).await self.send_client_response(payload).await
} }
} }
#[async_trait] #[async_trait]
impl Receiver<auth::Inbound> for AuthTransportAdapter<'_> { impl Receiver<auth::Inbound> for AuthTransportAdapter<'_> {
async fn recv(&mut self) -> Option<auth::Inbound> { async fn recv(&mut self) -> Option<auth::Inbound> {
let request = match self.0.recv().await? { let request = match self.bi.recv().await? {
Ok(request) => request, Ok(request) => request,
Err(error) => { Err(error) => {
warn!(error = ?error, "grpc client recv failed; closing stream"); warn!(error = ?error, "grpc client recv failed; closing stream");
@@ -89,6 +114,15 @@ impl Receiver<auth::Inbound> for AuthTransportAdapter<'_> {
} }
}; };
let request_id = match self.request_tracker.request(request.request_id) {
Ok(request_id) => request_id,
Err(error) => {
let _ = self.bi.send(Err(error)).await;
return None;
}
};
*self.response_id = Some(request_id);
let payload = request.payload?; let payload = request.payload?;
match payload { match payload {
@@ -114,7 +148,13 @@ impl Receiver<auth::Inbound> for AuthTransportAdapter<'_> {
}; };
Some(auth::Inbound::AuthChallengeSolution { signature }) Some(auth::Inbound::AuthChallengeSolution { signature })
} }
_ => None, _ => {
let _ = self
.bi
.send(Err(Status::invalid_argument("Unsupported client auth request")))
.await;
None
}
} }
} }
} }
@@ -124,8 +164,10 @@ impl Bi<auth::Inbound, Result<auth::Outbound, auth::Error>> for AuthTransportAda
pub async fn start( pub async fn start(
conn: &mut ClientConnection, conn: &mut ClientConnection,
bi: &mut GrpcBi<ClientRequest, ClientResponse>, bi: &mut GrpcBi<ClientRequest, ClientResponse>,
request_tracker: &mut RequestTracker,
response_id: &mut Option<i32>,
) -> Result<(), auth::Error> { ) -> Result<(), auth::Error> {
let mut transport = AuthTransportAdapter(bi); let mut transport = AuthTransportAdapter::new(bi, request_tracker, response_id);
client::auth::authenticate(conn, &mut transport).await?; client::auth::authenticate(conn, &mut transport).await?;
Ok(()) Ok(())
} }

View File

@@ -17,6 +17,7 @@ use crate::{
}; };
pub mod client; pub mod client;
mod request_tracker;
pub mod user_agent; pub mod user_agent;
#[async_trait] #[async_trait]

View File

@@ -0,0 +1,20 @@
use tonic::Status;
#[derive(Default)]
pub struct RequestTracker {
next_request_id: i32,
}
impl RequestTracker {
pub fn request(&mut self, id: i32) -> Result<i32, Status> {
if id < self.next_request_id {
return Err(Status::invalid_argument("Duplicate request id"));
}
self.next_request_id = id
.checked_add(1)
.ok_or_else(|| Status::invalid_argument("Invalid request id"))?;
Ok(id)
}
}

View File

@@ -53,6 +53,7 @@ use crate::{
Grant, SharedGrantSettings, SpecificGrant, TransactionRateLimit, VolumeRateLimit, Grant, SharedGrantSettings, SpecificGrant, TransactionRateLimit, VolumeRateLimit,
ether_transfer, token_transfers, ether_transfer, token_transfers,
}, },
grpc::request_tracker::RequestTracker,
utils::defer, utils::defer,
}; };
use alloy::primitives::{Address, U256}; use alloy::primitives::{Address, U256};
@@ -74,6 +75,7 @@ async fn dispatch_loop(
mut bi: GrpcBi<UserAgentRequest, UserAgentResponse>, mut bi: GrpcBi<UserAgentRequest, UserAgentResponse>,
actor: ActorRef<UserAgentSession>, actor: ActorRef<UserAgentSession>,
mut receiver: mpsc::Receiver<OutOfBand>, mut receiver: mpsc::Receiver<OutOfBand>,
mut request_tracker: RequestTracker,
) { ) {
loop { loop {
tokio::select! { tokio::select! {
@@ -92,7 +94,10 @@ async fn dispatch_loop(
return; return;
}; };
if dispatch_conn_message(&mut bi, &actor, conn).await.is_err() { if dispatch_conn_message(&mut bi, &actor, &mut request_tracker, conn)
.await
.is_err()
{
return; return;
} }
} }
@@ -103,6 +108,7 @@ async fn dispatch_loop(
async fn dispatch_conn_message( async fn dispatch_conn_message(
bi: &mut GrpcBi<UserAgentRequest, UserAgentResponse>, bi: &mut GrpcBi<UserAgentRequest, UserAgentResponse>,
actor: &ActorRef<UserAgentSession>, actor: &ActorRef<UserAgentSession>,
request_tracker: &mut RequestTracker,
conn: Result<UserAgentRequest, Status>, conn: Result<UserAgentRequest, Status>,
) -> Result<(), ()> { ) -> Result<(), ()> {
let conn = match conn { let conn = match conn {
@@ -113,6 +119,14 @@ async fn dispatch_conn_message(
} }
}; };
let request_id = match request_tracker.request(conn.id) {
Ok(request_id) => request_id,
Err(err) => {
let _ = bi.send(Err(err)).await;
return Err(());
}
};
let Some(payload) = conn.payload else { let Some(payload) = conn.payload else {
let _ = bi let _ = bi
.send(Err(Status::invalid_argument( .send(Err(Status::invalid_argument(
@@ -267,6 +281,7 @@ async fn dispatch_conn_message(
}; };
bi.send(Ok(UserAgentResponse { bi.send(Ok(UserAgentResponse {
id: Some(request_id),
payload: Some(payload), payload: Some(payload),
})) }))
.await .await
@@ -289,6 +304,7 @@ async fn send_out_of_band(
}; };
bi.send(Ok(UserAgentResponse { bi.send(Ok(UserAgentResponse {
id: None,
payload: Some(payload), payload: Some(payload),
})) }))
.await .await
@@ -558,7 +574,17 @@ pub async fn start(
mut conn: UserAgentConnection, mut conn: UserAgentConnection,
mut bi: GrpcBi<UserAgentRequest, UserAgentResponse>, mut bi: GrpcBi<UserAgentRequest, UserAgentResponse>,
) { ) {
let pubkey = match auth::start(&mut conn, &mut bi).await { let mut request_tracker = RequestTracker::default();
let mut response_id = None;
let pubkey = match auth::start(
&mut conn,
&mut bi,
&mut request_tracker,
&mut response_id,
)
.await
{
Ok(pubkey) => pubkey, Ok(pubkey) => pubkey,
Err(e) => { Err(e) => {
warn!(error = ?e, "Authentication failed"); warn!(error = ?e, "Authentication failed");
@@ -572,11 +598,10 @@ pub async fn start(
let actor = UserAgentSession::spawn(UserAgentSession::new(conn, Box::new(oob_adapter))); let actor = UserAgentSession::spawn(UserAgentSession::new(conn, Box::new(oob_adapter)));
let actor_for_cleanup = actor.clone(); let actor_for_cleanup = actor.clone();
// when connection closes
let _ = defer(move || { let _ = defer(move || {
actor_for_cleanup.kill(); actor_for_cleanup.kill();
}); });
info!(?pubkey, "User authenticated successfully"); info!(?pubkey, "User authenticated successfully");
dispatch_loop(bi, actor, oob_receiver).await; dispatch_loop(bi, actor, oob_receiver, request_tracker).await;
} }

View File

@@ -1,52 +1,56 @@
use arbiter_proto::{ use arbiter_proto::{
proto::{ proto::user_agent::{
self, AuthChallenge as ProtoAuthChallenge, AuthChallengeRequest as ProtoAuthChallengeRequest,
evm::{ AuthChallengeSolution as ProtoAuthChallengeSolution, AuthResult as ProtoAuthResult,
EtherTransferSettings as ProtoEtherTransferSettings, EvmError as ProtoEvmError, KeyType as ProtoKeyType, UserAgentRequest, UserAgentResponse,
EvmGrantCreateRequest, EvmGrantCreateResponse, EvmGrantDeleteRequest, user_agent_request::Payload as UserAgentRequestPayload,
EvmGrantDeleteResponse, EvmGrantList, EvmGrantListResponse, GrantEntry, user_agent_response::Payload as UserAgentResponsePayload,
SharedSettings as ProtoSharedSettings, SpecificGrant as ProtoSpecificGrant,
TokenTransferSettings as ProtoTokenTransferSettings,
VolumeRateLimit as ProtoVolumeRateLimit, WalletCreateResponse, WalletEntry, WalletList,
WalletListResponse, evm_grant_create_response::Result as EvmGrantCreateResult,
evm_grant_delete_response::Result as EvmGrantDeleteResult,
evm_grant_list_response::Result as EvmGrantListResult,
specific_grant::Grant as ProtoSpecificGrantType,
wallet_create_response::Result as WalletCreateResult,
wallet_list_response::Result as WalletListResult,
},
user_agent::{
AuthChallenge as ProtoAuthChallenge, AuthChallengeRequest as ProtoAuthChallengeRequest,
AuthChallengeSolution as ProtoAuthChallengeSolution, AuthResult as ProtoAuthResult,
BootstrapEncryptedKey as ProtoBootstrapEncryptedKey,
BootstrapResult as ProtoBootstrapResult, ClientConnectionCancel,
ClientConnectionRequest, ClientConnectionResponse, KeyType as ProtoKeyType,
UnsealEncryptedKey as ProtoUnsealEncryptedKey, UnsealResult as ProtoUnsealResult,
UnsealStart, UnsealStartResponse, UserAgentRequest, UserAgentResponse,
VaultState as ProtoVaultState, user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload,
},
}, },
transport::{Bi, Error as TransportError, Receiver, Sender, grpc::GrpcBi}, transport::{Bi, Error as TransportError, Receiver, Sender, grpc::GrpcBi},
}; };
use async_trait::async_trait; use async_trait::async_trait;
use tonic::{Status, Streaming}; use tonic::Status;
use tracing::{info, warn}; use tracing::warn;
use crate::{ use crate::{
actors::user_agent::{ actors::user_agent::{AuthPublicKey, UserAgentConnection, auth},
self, AuthPublicKey, OutOfBand as DomainResponse, UserAgentConnection, auth,
},
db::models::KeyType, db::models::KeyType,
evm::policies::{ grpc::request_tracker::RequestTracker,
Grant, SharedGrantSettings, SpecificGrant, TransactionRateLimit, VolumeRateLimit,
ether_transfer, token_transfers,
},
}; };
use alloy::primitives::{Address, U256};
use chrono::{DateTime, TimeZone, Utc};
pub struct AuthTransportAdapter<'a>(&'a mut GrpcBi<UserAgentRequest, UserAgentResponse>); pub struct AuthTransportAdapter<'a> {
bi: &'a mut GrpcBi<UserAgentRequest, UserAgentResponse>,
request_tracker: &'a mut RequestTracker,
response_id: &'a mut Option<i32>,
}
impl<'a> AuthTransportAdapter<'a> {
pub fn new(
bi: &'a mut GrpcBi<UserAgentRequest, UserAgentResponse>,
request_tracker: &'a mut RequestTracker,
response_id: &'a mut Option<i32>,
) -> Self {
Self {
bi,
request_tracker,
response_id,
}
}
async fn send_user_agent_response(
&mut self,
payload: UserAgentResponsePayload,
) -> Result<(), TransportError> {
let id = self.response_id.take();
self.bi
.send(Ok(UserAgentResponse {
id,
payload: Some(payload),
}))
.await
}
}
#[async_trait] #[async_trait]
impl Sender<Result<auth::Outbound, auth::Error>> for AuthTransportAdapter<'_> { impl Sender<Result<auth::Outbound, auth::Error>> for AuthTransportAdapter<'_> {
@@ -55,39 +59,53 @@ impl Sender<Result<auth::Outbound, auth::Error>> for AuthTransportAdapter<'_> {
item: Result<auth::Outbound, auth::Error>, item: Result<auth::Outbound, auth::Error>,
) -> Result<(), TransportError> { ) -> Result<(), TransportError> {
use auth::{Error, Outbound}; use auth::{Error, Outbound};
let response = match item { let payload = match item {
Ok(Outbound::AuthChallenge { nonce }) => Ok(UserAgentResponsePayload::AuthChallenge( Ok(Outbound::AuthChallenge { nonce }) => {
ProtoAuthChallenge { nonce }, UserAgentResponsePayload::AuthChallenge(ProtoAuthChallenge { nonce })
)), }
Ok(Outbound::AuthSuccess) => Ok(UserAgentResponsePayload::AuthResult( Ok(Outbound::AuthSuccess) => {
ProtoAuthResult::Success.into(), UserAgentResponsePayload::AuthResult(ProtoAuthResult::Success.into())
)), }
Err(Error::UnregisteredPublicKey) => {
Err(Error::UnregisteredPublicKey) => Ok(UserAgentResponsePayload::AuthResult( UserAgentResponsePayload::AuthResult(ProtoAuthResult::InvalidKey.into())
ProtoAuthResult::InvalidKey.into(), }
)), Err(Error::InvalidChallengeSolution) => {
Err(Error::InvalidChallengeSolution) => Ok(UserAgentResponsePayload::AuthResult( UserAgentResponsePayload::AuthResult(ProtoAuthResult::InvalidSignature.into())
ProtoAuthResult::InvalidSignature.into(), }
)), Err(Error::InvalidBootstrapToken) => {
Err(Error::InvalidBootstrapToken) => Ok(UserAgentResponsePayload::BootstrapResult( UserAgentResponsePayload::AuthResult(ProtoAuthResult::TokenInvalid.into())
ProtoAuthResult::TokenInvalid.into(), }
)), Err(Error::Internal { details }) => return self.bi.send(Err(Status::internal(details))).await,
Err(Error::Internal { details }) => Err(Status::internal(details)), Err(Error::Transport) => {
Err(Error::Transport) => Err(Status::unavailable("transport error")), return self.bi.send(Err(Status::unavailable("transport error"))).await;
}
}; };
self.0
.send(response.map(|r| UserAgentResponse { payload: Some(r) })) self.send_user_agent_response(payload).await
.await
} }
} }
#[async_trait] #[async_trait]
impl Receiver<auth::Inbound> for AuthTransportAdapter<'_> { impl Receiver<auth::Inbound> for AuthTransportAdapter<'_> {
async fn recv(&mut self) -> Option<auth::Inbound> { async fn recv(&mut self) -> Option<auth::Inbound> {
let Ok(UserAgentRequest { let request = match self.bi.recv().await? {
payload: Some(payload), Ok(request) => request,
}) = self.0.recv().await? Err(error) => {
else { warn!(error = ?error, "Failed to receive user agent auth request");
return None;
}
};
let request_id = match self.request_tracker.request(request.id) {
Ok(request_id) => request_id,
Err(error) => {
let _ = self.bi.send(Err(error)).await;
return None;
}
};
*self.response_id = Some(request_id);
let Some(payload) = request.payload else {
warn!( warn!(
event = "received request with empty payload", event = "received request with empty payload",
"grpc.useragent.auth_adapter" "grpc.useragent.auth_adapter"
@@ -136,16 +154,27 @@ impl Receiver<auth::Inbound> for AuthTransportAdapter<'_> {
UserAgentRequestPayload::AuthChallengeSolution(ProtoAuthChallengeSolution { UserAgentRequestPayload::AuthChallengeSolution(ProtoAuthChallengeSolution {
signature, signature,
}) => Some(auth::Inbound::AuthChallengeSolution { signature }), }) => Some(auth::Inbound::AuthChallengeSolution { signature }),
_ => None, // Ignore other request types for this adapter _ => {
let _ = self
.bi
.send(Err(Status::invalid_argument(
"Unsupported user-agent auth request",
)))
.await;
None
}
} }
} }
} }
impl Bi<auth::Inbound, Result<auth::Outbound, auth::Error>> for AuthTransportAdapter<'_> {} impl Bi<auth::Inbound, Result<auth::Outbound, auth::Error>> for AuthTransportAdapter<'_> {}
pub async fn start( pub async fn start(
conn: &mut UserAgentConnection, conn: &mut UserAgentConnection,
bi: &mut GrpcBi<UserAgentRequest, UserAgentResponse>, bi: &mut GrpcBi<UserAgentRequest, UserAgentResponse>,
request_tracker: &mut RequestTracker,
response_id: &mut Option<i32>,
) -> Result<AuthPublicKey, auth::Error> { ) -> Result<AuthPublicKey, auth::Error> {
let mut transport = AuthTransportAdapter(bi); let transport = AuthTransportAdapter::new(bi, request_tracker, response_id);
auth::authenticate(conn, transport).await auth::authenticate(conn, transport).await
} }