diff --git a/server/crates/arbiter-server/src/actors/user_agent/mod.rs b/server/crates/arbiter-server/src/actors/user_agent/mod.rs index 4380b72..6793981 100644 --- a/server/crates/arbiter-server/src/actors/user_agent/mod.rs +++ b/server/crates/arbiter-server/src/actors/user_agent/mod.rs @@ -1,5 +1,7 @@ use arbiter_proto::{ - proto::user_agent::{UserAgentRequest, UserAgentResponse}, + proto::user_agent::{ + SdkClientError as ProtoSdkClientError, UserAgentRequest, UserAgentResponse, + }, transport::Bi, }; use kameo::actor::Spawn as _; @@ -24,12 +26,27 @@ pub enum TransportResponseError { StateTransitionFailed, #[error("Vault is not available")] KeyHolderActorUnreachable, + #[error("SDK client approve failed: {0:?}")] + SdkClientApprove(ProtoSdkClientError), + #[error("SDK client list failed: {0:?}")] + SdkClientList(ProtoSdkClientError), + #[error("SDK client revoke failed: {0:?}")] + SdkClientRevoke(ProtoSdkClientError), #[error(transparent)] Auth(#[from] auth::Error), #[error("Failed registering connection")] ConnectionRegistrationFailed, } +impl TransportResponseError { + pub fn is_terminal(&self) -> bool { + !matches!( + self, + Self::SdkClientApprove(_) | Self::SdkClientList(_) | Self::SdkClientRevoke(_) + ) + } +} + pub type Transport = Box> + Send>; diff --git a/server/crates/arbiter-server/src/actors/user_agent/session.rs b/server/crates/arbiter-server/src/actors/user_agent/session.rs index a19e85b..da53f3a 100644 --- a/server/crates/arbiter-server/src/actors/user_agent/session.rs +++ b/server/crates/arbiter-server/src/actors/user_agent/session.rs @@ -304,11 +304,9 @@ impl UserAgentSession { use sdk_client_approve_response::Result as ApproveResult; if req.pubkey.len() != 32 { - return Ok(response(UserAgentResponsePayload::SdkClientApprove( - SdkClientApproveResponse { - result: Some(ApproveResult::Error(ProtoSdkClientError::Internal.into())), - }, - ))); + return Err(TransportResponseError::SdkClientApprove( + ProtoSdkClientError::Internal, + )); } let now = std::time::SystemTime::now() @@ -320,11 +318,9 @@ impl UserAgentSession { Ok(c) => c, Err(e) => { error!(?e, "Failed to get DB connection for sdk_client_approve"); - return Ok(response(UserAgentResponsePayload::SdkClientApprove( - SdkClientApproveResponse { - result: Some(ApproveResult::Error(ProtoSdkClientError::Internal.into())), - }, - ))); + return Err(TransportResponseError::SdkClientApprove( + ProtoSdkClientError::Internal, + )); } }; @@ -363,33 +359,23 @@ impl UserAgentSession { )), Err(e) => { error!(?e, "Failed to fetch inserted SDK client"); - Ok(response(UserAgentResponsePayload::SdkClientApprove( - SdkClientApproveResponse { - result: Some(ApproveResult::Error( - ProtoSdkClientError::Internal.into(), - )), - }, - ))) + Err(TransportResponseError::SdkClientApprove( + ProtoSdkClientError::Internal, + )) } } } Err(diesel::result::Error::DatabaseError( diesel::result::DatabaseErrorKind::UniqueViolation, _, - )) => Ok(response(UserAgentResponsePayload::SdkClientApprove( - SdkClientApproveResponse { - result: Some(ApproveResult::Error( - ProtoSdkClientError::AlreadyExists.into(), - )), - }, - ))), + )) => Err(TransportResponseError::SdkClientApprove( + ProtoSdkClientError::AlreadyExists, + )), Err(e) => { error!(?e, "Failed to insert SDK client"); - Ok(response(UserAgentResponsePayload::SdkClientApprove( - SdkClientApproveResponse { - result: Some(ApproveResult::Error(ProtoSdkClientError::Internal.into())), - }, - ))) + Err(TransportResponseError::SdkClientApprove( + ProtoSdkClientError::Internal, + )) } } } @@ -399,13 +385,9 @@ impl UserAgentSession { Ok(c) => c, Err(e) => { error!(?e, "Failed to get DB connection for sdk_client_list"); - return Ok(response(UserAgentResponsePayload::SdkClientList( - SdkClientListResponse { - result: Some(sdk_client_list_response::Result::Error( - ProtoSdkClientError::Internal.into(), - )), - }, - ))); + return Err(TransportResponseError::SdkClientList( + ProtoSdkClientError::Internal, + )); } }; @@ -434,13 +416,9 @@ impl UserAgentSession { ))), Err(e) => { error!(?e, "Failed to list SDK clients"); - Ok(response(UserAgentResponsePayload::SdkClientList( - SdkClientListResponse { - result: Some(sdk_client_list_response::Result::Error( - ProtoSdkClientError::Internal.into(), - )), - }, - ))) + Err(TransportResponseError::SdkClientList( + ProtoSdkClientError::Internal, + )) } } } @@ -452,11 +430,9 @@ impl UserAgentSession { Ok(c) => c, Err(e) => { error!(?e, "Failed to get DB connection for sdk_client_revoke"); - return Ok(response(UserAgentResponsePayload::SdkClientRevoke( - SdkClientRevokeResponse { - result: Some(RevokeResult::Error(ProtoSdkClientError::Internal.into())), - }, - ))); + return Err(TransportResponseError::SdkClientRevoke( + ProtoSdkClientError::Internal, + )); } }; @@ -465,11 +441,9 @@ impl UserAgentSession { .execute(&mut conn) .await { - Ok(0) => Ok(response(UserAgentResponsePayload::SdkClientRevoke( - SdkClientRevokeResponse { - result: Some(RevokeResult::Error(ProtoSdkClientError::NotFound.into())), - }, - ))), + Ok(0) => Err(TransportResponseError::SdkClientRevoke( + ProtoSdkClientError::NotFound, + )), Ok(_) => Ok(response(UserAgentResponsePayload::SdkClientRevoke( SdkClientRevokeResponse { result: Some(RevokeResult::Ok(())), @@ -478,20 +452,14 @@ impl UserAgentSession { Err(diesel::result::Error::DatabaseError( diesel::result::DatabaseErrorKind::ForeignKeyViolation, _, - )) => Ok(response(UserAgentResponsePayload::SdkClientRevoke( - SdkClientRevokeResponse { - result: Some(RevokeResult::Error( - ProtoSdkClientError::HasRelatedData.into(), - )), - }, - ))), + )) => Err(TransportResponseError::SdkClientRevoke( + ProtoSdkClientError::HasRelatedData, + )), Err(e) => { error!(?e, "Failed to delete SDK client"); - Ok(response(UserAgentResponsePayload::SdkClientRevoke( - SdkClientRevokeResponse { - result: Some(RevokeResult::Error(ProtoSdkClientError::Internal.into())), - }, - ))) + Err(TransportResponseError::SdkClientRevoke( + ProtoSdkClientError::Internal, + )) } } } @@ -558,8 +526,15 @@ impl Actor for UserAgentSession { } } Err(err) => { - let _ = self.props.transport.send(Err(err)).await; - return Some(kameo::mailbox::Signal::Stop); + let should_stop = err.is_terminal(); + if self.props.transport.send(Err(err)).await.is_err() { + error!(actor = "useragent", reason = "channel closed", "send.failed"); + return Some(kameo::mailbox::Signal::Stop); + } + + if should_stop { + return Some(kameo::mailbox::Signal::Stop); + } } } } diff --git a/server/crates/arbiter-server/src/lib.rs b/server/crates/arbiter-server/src/lib.rs index abb51a5..12eb13e 100644 --- a/server/crates/arbiter-server/src/lib.rs +++ b/server/crates/arbiter-server/src/lib.rs @@ -2,7 +2,12 @@ use arbiter_proto::{ proto::{ client::{ClientRequest, ClientResponse}, - user_agent::{UserAgentRequest, UserAgentResponse}, + user_agent::{ + SdkClientApproveResponse, SdkClientListResponse, SdkClientRevokeResponse, + UserAgentRequest, UserAgentResponse, sdk_client_approve_response, + sdk_client_list_response, sdk_client_revoke_response, + user_agent_response::Payload as UserAgentResponsePayload, + }, }, transport::{IdentityRecvConverter, SendConverter, grpc}, }; @@ -37,6 +42,27 @@ impl SendConverter for UserAgentGrpcSender { fn convert(&self, item: Self::Input) -> Self::Output { match item { Ok(message) => Ok(message), + Err(TransportResponseError::SdkClientApprove(code)) => Ok(UserAgentResponse { + payload: Some(UserAgentResponsePayload::SdkClientApprove( + SdkClientApproveResponse { + result: Some(sdk_client_approve_response::Result::Error(code.into())), + }, + )), + }), + Err(TransportResponseError::SdkClientList(code)) => Ok(UserAgentResponse { + payload: Some(UserAgentResponsePayload::SdkClientList( + SdkClientListResponse { + result: Some(sdk_client_list_response::Result::Error(code.into())), + }, + )), + }), + Err(TransportResponseError::SdkClientRevoke(code)) => Ok(UserAgentResponse { + payload: Some(UserAgentResponsePayload::SdkClientRevoke( + SdkClientRevokeResponse { + result: Some(sdk_client_revoke_response::Result::Error(code.into())), + }, + )), + }), Err(err) => Err(user_agent_error_status(err)), } } @@ -103,6 +129,11 @@ fn user_agent_error_status(value: TransportResponseError) -> Status { TransportResponseError::KeyHolderActorUnreachable => { Status::internal("Vault is not available") } + TransportResponseError::SdkClientApprove(_) + | TransportResponseError::SdkClientList(_) + | TransportResponseError::SdkClientRevoke(_) => { + Status::internal("SDK client operation failed") + } TransportResponseError::Auth(ref err) => auth_error_status(err), TransportResponseError::ConnectionRegistrationFailed => { Status::internal("Failed registering connection") diff --git a/server/crates/arbiter-server/tests/user_agent/sdk_client.rs b/server/crates/arbiter-server/tests/user_agent/sdk_client.rs index 3e2734a..08904dc 100644 --- a/server/crates/arbiter-server/tests/user_agent/sdk_client.rs +++ b/server/crates/arbiter-server/tests/user_agent/sdk_client.rs @@ -5,7 +5,10 @@ use arbiter_proto::proto::user_agent::{ user_agent_response::Payload as UserAgentResponsePayload, }; use arbiter_server::{ - actors::{GlobalActors, user_agent::session::UserAgentSession}, + actors::{ + GlobalActors, + user_agent::{TransportResponseError, session::UserAgentSession}, + }, db, }; @@ -68,22 +71,15 @@ async fn test_sdk_client_approve_duplicate_returns_already_exists() { .await .unwrap(); - let response = session + let err = session .process_transport_inbound(req) .await - .expect("second insert should not panic"); + .expect_err("second insert should return typed TransportResponseError"); - match response.payload.unwrap() { - UserAgentResponsePayload::SdkClientApprove(resp) => match resp.result.unwrap() { - sdk_client_approve_response::Result::Error(code) => { - assert_eq!(code, ProtoSdkClientError::AlreadyExists as i32); - } - sdk_client_approve_response::Result::Client(_) => { - panic!("Expected AlreadyExists error for duplicate pubkey") - } - }, - other => panic!("Expected SdkClientApprove, got {other:?}"), - } + assert_eq!( + err, + TransportResponseError::SdkClientApprove(ProtoSdkClientError::AlreadyExists) + ); } #[tokio::test] @@ -203,26 +199,19 @@ async fn test_sdk_client_revoke_not_found_returns_error() { let db = db::create_test_pool().await; let mut session = make_session(&db).await; - let response = session + let err = session .process_transport_inbound(UserAgentRequest { payload: Some(UserAgentRequestPayload::SdkClientRevoke( SdkClientRevokeRequest { client_id: 9999 }, )), }) .await - .unwrap(); + .expect_err("missing client should return typed TransportResponseError"); - match response.payload.unwrap() { - UserAgentResponsePayload::SdkClientRevoke(resp) => match resp.result.unwrap() { - sdk_client_revoke_response::Result::Error(code) => { - assert_eq!(code, ProtoSdkClientError::NotFound as i32); - } - sdk_client_revoke_response::Result::Ok(_) => { - panic!("Expected NotFound error for missing client_id") - } - }, - other => panic!("Expected SdkClientRevoke, got {other:?}"), - } + assert_eq!( + err, + TransportResponseError::SdkClientRevoke(ProtoSdkClientError::NotFound) + ); } #[tokio::test]