From 86f8feb2912f3ef1a7473fa3a8d12408c45ca486 Mon Sep 17 00:00:00 2001 From: hdbg Date: Thu, 26 Feb 2026 22:23:52 +0100 Subject: [PATCH] tests(user-agent): basic auth tests similar to `server` --- server/crates/arbiter-proto/src/transport.rs | 18 +-- server/crates/arbiter-useragent/src/grpc.rs | 33 +--- server/crates/arbiter-useragent/src/lib.rs | 89 ++++++----- server/crates/arbiter-useragent/tests/auth.rs | 151 ++++++++++++++++++ 4 files changed, 212 insertions(+), 79 deletions(-) create mode 100644 server/crates/arbiter-useragent/tests/auth.rs diff --git a/server/crates/arbiter-proto/src/transport.rs b/server/crates/arbiter-proto/src/transport.rs index 9ab433b..48bb9a3 100644 --- a/server/crates/arbiter-proto/src/transport.rs +++ b/server/crates/arbiter-proto/src/transport.rs @@ -188,20 +188,20 @@ pub mod grpc { /// Tonic receive errors are logged and treated as stream closure (`None`). /// The receive converter is only invoked for successful inbound transport /// items. - pub struct GrpcAdapter + pub struct GrpcAdapter where - InboundConverter: RecvConverter, + InboundConverter: RecvConverter, OutboundConverter: SendConverter, { sender: mpsc::Sender, - receiver: Streaming, + receiver: Streaming, inbound_converter: InboundConverter, outbound_converter: OutboundConverter, } impl - GrpcAdapter + GrpcAdapter where InboundConverter: RecvConverter, OutboundConverter: SendConverter, @@ -222,12 +222,10 @@ pub mod grpc { } - impl Bi - for GrpcAdapter + impl< InboundConverter, OutboundConverter> Bi + for GrpcAdapter where - InboundTransport: Send + 'static, - Inbound: Send + 'static, - InboundConverter: RecvConverter, + InboundConverter: RecvConverter, OutboundConverter: SendConverter, OutboundConverter::Input: Send + 'static, OutboundConverter::Output: Send + 'static, @@ -242,7 +240,7 @@ pub mod grpc { } #[tracing::instrument(level = "trace", skip(self))] - async fn recv(&mut self) -> Option { + async fn recv(&mut self) -> Option { match self.receiver.next().await { Some(Ok(item)) => Some(self.inbound_converter.convert(item)), Some(Err(error)) => { diff --git a/server/crates/arbiter-useragent/src/grpc.rs b/server/crates/arbiter-useragent/src/grpc.rs index ef523a9..9a00ad7 100644 --- a/server/crates/arbiter-useragent/src/grpc.rs +++ b/server/crates/arbiter-useragent/src/grpc.rs @@ -2,7 +2,7 @@ use arbiter_proto::{ proto::{ UserAgentRequest, UserAgentResponse, arbiter_service_client::ArbiterServiceClient, }, - transport::{RecvConverter, IdentitySendConverter, grpc}, + transport::{IdentityRecvConverter, IdentitySendConverter, RecvConverter, grpc}, url::ArbiterUrl, }; use ed25519_dalek::SigningKey; @@ -15,7 +15,7 @@ use tonic::transport::ClientTlsConfig; #[derive(Debug, thiserror::Error)] -pub enum InitError { +pub enum ConnectError { #[error("Could establish connection")] Connection(#[from] tonic::transport::Error), @@ -29,26 +29,12 @@ pub enum InitError { Grpc(#[from] tonic::Status), } -pub struct InboundConverter; -impl RecvConverter for InboundConverter { - type Input = UserAgentResponse; - type Output = Result; - - fn convert(&self, item: Self::Input) -> Self::Output { - Ok(item) - } -} - -use crate::InboundError; - -use super::UserAgentActor; + use super::UserAgentActor; pub type UserAgentGrpc = ActorRef< UserAgentActor< grpc::GrpcAdapter< - UserAgentResponse, - Result, - InboundConverter, + IdentityRecvConverter, IdentitySendConverter, >, >, @@ -56,7 +42,7 @@ pub type UserAgentGrpc = ActorRef< pub async fn connect_grpc( url: ArbiterUrl, key: SigningKey, -) -> Result { +) -> Result { let bootstrap_token = url.bootstrap_token.clone(); let anchor = webpki::anchor_from_trusted_cert(&url.ca_cert)?.to_owned(); let tls = ClientTlsConfig::new().trust_anchor(anchor); @@ -75,16 +61,11 @@ pub async fn connect_grpc( let adapter = grpc::GrpcAdapter::new( tx, bistream, - InboundConverter, + IdentityRecvConverter::new(), IdentitySendConverter::new(), ); - let actor = UserAgentActor::spawn(UserAgentActor { - key, - bootstrap_token, - state: super::UserAgentStateMachine::new(super::DummyContext), - transport: adapter, - }); + let actor = UserAgentActor::spawn(UserAgentActor::new(key, bootstrap_token, adapter)); Ok(actor) } diff --git a/server/crates/arbiter-useragent/src/lib.rs b/server/crates/arbiter-useragent/src/lib.rs index fbdec08..5972d7e 100644 --- a/server/crates/arbiter-useragent/src/lib.rs +++ b/server/crates/arbiter-useragent/src/lib.rs @@ -13,15 +13,25 @@ use arbiter_proto::{ transport::Bi, }; use ed25519_dalek::{Signer, SigningKey}; -use kameo::{ - Actor, - actor::{ActorRef, Spawn}, - prelude::Message, -}; +use kameo::{Actor, actor::ActorRef}; use smlang::statemachine; use tokio::select; use tracing::{error, info}; +statemachine! { + name: UserAgent, + custom_error: false, + transitions: { + *Init + SentAuthChallengeRequest = WaitingForServerAuth, + WaitingForServerAuth + ReceivedAuthChallenge = WaitingForAuthOk, + WaitingForServerAuth + ReceivedAuthOk = Authenticated, + WaitingForAuthOk + ReceivedAuthOk = Authenticated, + } +} + +pub struct DummyContext; +impl UserAgentStateMachineContext for DummyContext {} + #[derive(Debug, thiserror::Error)] pub enum InboundError { #[error("Invalid user agent response")] @@ -40,23 +50,9 @@ pub enum InboundError { TransportSendFailed, } -statemachine! { - name: UserAgent, - custom_error: false, - transitions: { - *Init + SentAuthChallengeRequest = WaitingForServerAuth, - WaitingForServerAuth + ReceivedAuthChallenge = WaitingForAuthOk, - WaitingForServerAuth + ReceivedAuthOk = Authenticated, - WaitingForAuthOk + ReceivedAuthOk = Authenticated, - } -} - -pub struct DummyContext; -impl UserAgentStateMachineContext for DummyContext {} - pub struct UserAgentActor where - Transport: Bi, UserAgentRequest>, + Transport: Bi, { key: SigningKey, bootstrap_token: Option, @@ -66,8 +62,17 @@ where impl UserAgentActor where - Transport: Bi, UserAgentRequest>, + Transport: Bi, { + pub fn new(key: SigningKey, bootstrap_token: Option, transport: Transport) -> Self { + Self { + key, + bootstrap_token, + state: UserAgentStateMachine::new(DummyContext), + transport, + } + } + fn transition(&mut self, event: UserAgentEvents) -> Result<(), InboundError> { self.state.process_event(event).map_err(|e| { error!(?e, "useragent state transition failed"); @@ -90,11 +95,15 @@ where bootstrap_token: self.bootstrap_token.take(), }; + self.transition(UserAgentEvents::SentAuthChallengeRequest)?; + self.transport - .send(Self::auth_request(ClientAuthPayload::AuthChallengeRequest(req))) + .send(Self::auth_request(ClientAuthPayload::AuthChallengeRequest( + req, + ))) .await .map_err(|_| InboundError::TransportSendFailed)?; - self.transition(UserAgentEvents::SentAuthChallengeRequest)?; + info!(actor = "useragent", "auth.request.sent"); Ok(()) } @@ -103,10 +112,6 @@ where &mut self, challenge: auth::AuthChallenge, ) -> Result<(), InboundError> { - if !matches!(self.state.state(), UserAgentStates::WaitingForServerAuth) { - return Err(InboundError::InvalidStateForAuthChallenge); - } - self.transition(UserAgentEvents::ReceivedAuthChallenge)?; let formatted = format_challenge(&challenge); @@ -116,9 +121,9 @@ where }; self.transport - .send(Self::auth_request(ClientAuthPayload::AuthChallengeSolution( - solution, - ))) + .send(Self::auth_request( + ClientAuthPayload::AuthChallengeSolution(solution), + )) .await .map_err(|_| InboundError::TransportSendFailed)?; @@ -127,22 +132,16 @@ where } fn handle_auth_ok(&mut self, _ok: AuthOk) -> Result<(), InboundError> { - match self.state.state() { - UserAgentStates::WaitingForServerAuth | UserAgentStates::WaitingForAuthOk => { - self.transition(UserAgentEvents::ReceivedAuthOk)?; - info!(actor = "useragent", "auth.ok"); - Ok(()) - } - _ => Err(InboundError::InvalidStateForAuthOk), - } + self.transition(UserAgentEvents::ReceivedAuthOk)?; + info!(actor = "useragent", "auth.ok"); + Ok(()) } pub async fn process_inbound_transport( &mut self, - inbound: Result, + inbound: UserAgentResponse ) -> Result<(), InboundError> { - let response = inbound?; - let payload = response + let payload = inbound .payload .ok_or(InboundError::MissingResponsePayload)?; @@ -160,13 +159,16 @@ where impl Actor for UserAgentActor where - Transport: Bi, UserAgentRequest>, + Transport: Bi, { type Args = Self; type Error = (); - async fn on_start(mut args: Self::Args, _actor_ref: ActorRef) -> Result { + async fn on_start( + mut args: Self::Args, + _actor_ref: ActorRef, + ) -> Result { if let Err(err) = args.send_auth_challenge_request().await { error!(?err, actor = "useragent", "auth.start.failed"); return Err(()); @@ -204,3 +206,4 @@ where } mod grpc; +pub use grpc::{connect_grpc, ConnectError}; \ No newline at end of file diff --git a/server/crates/arbiter-useragent/tests/auth.rs b/server/crates/arbiter-useragent/tests/auth.rs new file mode 100644 index 0000000..a883f15 --- /dev/null +++ b/server/crates/arbiter-useragent/tests/auth.rs @@ -0,0 +1,151 @@ +use arbiter_proto::{ + format_challenge, + proto::{ + UserAgentRequest, UserAgentResponse, + auth::{ + AuthChallenge, AuthOk, ClientMessage as AuthClientMessage, + ServerMessage as AuthServerMessage, client_message::Payload as ClientAuthPayload, + server_message::Payload as ServerAuthPayload, + }, + user_agent_request::Payload as UserAgentRequestPayload, + user_agent_response::Payload as UserAgentResponsePayload, + }, + transport::Bi, +}; +use arbiter_useragent::{InboundError, UserAgentActor}; +use ed25519_dalek::SigningKey; +use kameo::actor::Spawn; +use tokio::sync::mpsc; +use tokio::time::{Duration, timeout}; + +struct TestTransport { + inbound_rx: mpsc::Receiver, + outbound_tx: mpsc::Sender, +} + +impl Bi for TestTransport { + async fn send(&mut self, item: UserAgentRequest) -> Result<(), arbiter_proto::transport::Error> { + self.outbound_tx + .send(item) + .await + .map_err(|_| arbiter_proto::transport::Error::ChannelClosed) + } + + async fn recv(&mut self) -> Option { + self.inbound_rx.recv().await + } +} + +fn make_transport() -> ( + TestTransport, + mpsc::Sender, + mpsc::Receiver, +) { + let (inbound_tx, inbound_rx) = mpsc::channel(8); + let (outbound_tx, outbound_rx) = mpsc::channel(8); + ( + TestTransport { + inbound_rx, + outbound_tx, + }, + inbound_tx, + outbound_rx, + ) +} + +fn test_key() -> SigningKey { + SigningKey::from_bytes(&[7u8; 32]) +} + +fn auth_response(payload: ServerAuthPayload) -> UserAgentResponse { + UserAgentResponse { + payload: Some(UserAgentResponsePayload::AuthMessage(AuthServerMessage { + payload: Some(payload), + })), + } +} + +#[tokio::test] +async fn sends_auth_request_on_start_with_bootstrap_token() { + let key = test_key(); + let pubkey = key.verifying_key().to_bytes().to_vec(); + let bootstrap_token = Some("bootstrap-123".to_string()); + let (transport, inbound_tx, mut outbound_rx) = make_transport(); + + let actor = UserAgentActor::spawn(UserAgentActor::new(key, bootstrap_token.clone(), transport)); + + let outbound = timeout(Duration::from_secs(1), outbound_rx.recv()) + .await + .expect("timed out waiting for auth request") + .expect("channel closed before auth request"); + + let UserAgentRequest { + payload: Some(UserAgentRequestPayload::AuthMessage(AuthClientMessage { + payload: Some(ClientAuthPayload::AuthChallengeRequest(req)), + })), + } = outbound + else { + panic!("expected auth challenge request"); + }; + + assert_eq!(req.pubkey, pubkey); + assert_eq!(req.bootstrap_token, bootstrap_token); + + drop(inbound_tx); + drop(actor); +} + +#[tokio::test] +async fn challenge_flow_sends_solution_from_transport_inbound() { + let key = test_key(); + let verify_key = key.verifying_key(); + let (transport, inbound_tx, mut outbound_rx) = make_transport(); + + let actor = UserAgentActor::spawn(UserAgentActor::new(key, None, transport)); + + let _initial_auth_request = timeout(Duration::from_secs(1), outbound_rx.recv()) + .await + .expect("timed out waiting for initial auth request") + .expect("missing initial auth request"); + + let challenge = AuthChallenge { + pubkey: verify_key.to_bytes().to_vec(), + nonce: 42, + }; + inbound_tx + .send(auth_response(ServerAuthPayload::AuthChallenge(challenge.clone()))) + .await + .unwrap(); + + let outbound = timeout(Duration::from_secs(1), outbound_rx.recv()) + .await + .expect("timed out waiting for challenge solution") + .expect("missing challenge solution"); + + let UserAgentRequest { + payload: Some(UserAgentRequestPayload::AuthMessage(AuthClientMessage { + payload: Some(ClientAuthPayload::AuthChallengeSolution(solution)), + })), + } = outbound + else { + panic!("expected auth challenge solution"); + }; + + let formatted = format_challenge(&challenge); + let sig: ed25519_dalek::Signature = solution + .signature + .as_slice() + .try_into() + .expect("signature bytes length"); + verify_key + .verify_strict(&formatted, &sig) + .expect("solution signature should verify"); + + inbound_tx + .send(auth_response(ServerAuthPayload::AuthOk(AuthOk {}))) + .await + .unwrap(); + + drop(inbound_tx); + drop(actor); +}