tests(user-agent): basic auth tests similar to server

This commit is contained in:
hdbg
2026-02-26 22:23:52 +01:00
parent 61c65ddbcb
commit 3478204b9f
4 changed files with 212 additions and 79 deletions

View File

@@ -188,20 +188,20 @@ pub mod grpc {
/// Tonic receive errors are logged and treated as stream closure (`None`).
/// The receive converter is only invoked for successful inbound transport
/// items.
pub struct GrpcAdapter<InboundTransport, Inbound, InboundConverter, OutboundConverter>
pub struct GrpcAdapter<InboundConverter, OutboundConverter>
where
InboundConverter: RecvConverter<Input = InboundTransport, Output = Inbound>,
InboundConverter: RecvConverter,
OutboundConverter: SendConverter,
{
sender: mpsc::Sender<OutboundConverter::Output>,
receiver: Streaming<InboundTransport>,
receiver: Streaming<InboundConverter::Input>,
inbound_converter: InboundConverter,
outbound_converter: OutboundConverter,
}
impl<InboundTransport, Inbound, InboundConverter, OutboundConverter>
GrpcAdapter<InboundTransport, Inbound, InboundConverter, OutboundConverter>
GrpcAdapter<InboundConverter, OutboundConverter>
where
InboundConverter: RecvConverter<Input = InboundTransport, Output = Inbound>,
OutboundConverter: SendConverter,
@@ -222,12 +222,10 @@ pub mod grpc {
}
impl<InboundTransport, Inbound, InboundConverter, OutboundConverter> Bi<Inbound, OutboundConverter::Input>
for GrpcAdapter<InboundTransport, Inbound, InboundConverter, OutboundConverter>
impl< InboundConverter, OutboundConverter> Bi<InboundConverter::Output, OutboundConverter::Input>
for GrpcAdapter<InboundConverter, OutboundConverter>
where
InboundTransport: Send + 'static,
Inbound: Send + 'static,
InboundConverter: RecvConverter<Input = InboundTransport, Output = Inbound>,
InboundConverter: RecvConverter,
OutboundConverter: SendConverter,
OutboundConverter::Input: Send + 'static,
OutboundConverter::Output: Send + 'static,
@@ -242,7 +240,7 @@ pub mod grpc {
}
#[tracing::instrument(level = "trace", skip(self))]
async fn recv(&mut self) -> Option<Inbound> {
async fn recv(&mut self) -> Option<InboundConverter::Output> {
match self.receiver.next().await {
Some(Ok(item)) => Some(self.inbound_converter.convert(item)),
Some(Err(error)) => {

View File

@@ -2,7 +2,7 @@ use arbiter_proto::{
proto::{
UserAgentRequest, UserAgentResponse, arbiter_service_client::ArbiterServiceClient,
},
transport::{RecvConverter, IdentitySendConverter, grpc},
transport::{IdentityRecvConverter, IdentitySendConverter, RecvConverter, grpc},
url::ArbiterUrl,
};
use ed25519_dalek::SigningKey;
@@ -15,7 +15,7 @@ use tonic::transport::ClientTlsConfig;
#[derive(Debug, thiserror::Error)]
pub enum InitError {
pub enum ConnectError {
#[error("Could establish connection")]
Connection(#[from] tonic::transport::Error),
@@ -29,26 +29,12 @@ pub enum InitError {
Grpc(#[from] tonic::Status),
}
pub struct InboundConverter;
impl RecvConverter for InboundConverter {
type Input = UserAgentResponse;
type Output = Result<UserAgentResponse, InboundError>;
fn convert(&self, item: Self::Input) -> Self::Output {
Ok(item)
}
}
use crate::InboundError;
use super::UserAgentActor;
use super::UserAgentActor;
pub type UserAgentGrpc = ActorRef<
UserAgentActor<
grpc::GrpcAdapter<
UserAgentResponse,
Result<UserAgentResponse, InboundError>,
InboundConverter,
IdentityRecvConverter<UserAgentResponse>,
IdentitySendConverter<UserAgentRequest>,
>,
>,
@@ -56,7 +42,7 @@ pub type UserAgentGrpc = ActorRef<
pub async fn connect_grpc(
url: ArbiterUrl,
key: SigningKey,
) -> Result<UserAgentGrpc, InitError> {
) -> Result<UserAgentGrpc, ConnectError> {
let bootstrap_token = url.bootstrap_token.clone();
let anchor = webpki::anchor_from_trusted_cert(&url.ca_cert)?.to_owned();
let tls = ClientTlsConfig::new().trust_anchor(anchor);
@@ -75,16 +61,11 @@ pub async fn connect_grpc(
let adapter = grpc::GrpcAdapter::new(
tx,
bistream,
InboundConverter,
IdentityRecvConverter::new(),
IdentitySendConverter::new(),
);
let actor = UserAgentActor::spawn(UserAgentActor {
key,
bootstrap_token,
state: super::UserAgentStateMachine::new(super::DummyContext),
transport: adapter,
});
let actor = UserAgentActor::spawn(UserAgentActor::new(key, bootstrap_token, adapter));
Ok(actor)
}

View File

@@ -13,15 +13,25 @@ use arbiter_proto::{
transport::Bi,
};
use ed25519_dalek::{Signer, SigningKey};
use kameo::{
Actor,
actor::{ActorRef, Spawn},
prelude::Message,
};
use kameo::{Actor, actor::ActorRef};
use smlang::statemachine;
use tokio::select;
use tracing::{error, info};
statemachine! {
name: UserAgent,
custom_error: false,
transitions: {
*Init + SentAuthChallengeRequest = WaitingForServerAuth,
WaitingForServerAuth + ReceivedAuthChallenge = WaitingForAuthOk,
WaitingForServerAuth + ReceivedAuthOk = Authenticated,
WaitingForAuthOk + ReceivedAuthOk = Authenticated,
}
}
pub struct DummyContext;
impl UserAgentStateMachineContext for DummyContext {}
#[derive(Debug, thiserror::Error)]
pub enum InboundError {
#[error("Invalid user agent response")]
@@ -40,23 +50,9 @@ pub enum InboundError {
TransportSendFailed,
}
statemachine! {
name: UserAgent,
custom_error: false,
transitions: {
*Init + SentAuthChallengeRequest = WaitingForServerAuth,
WaitingForServerAuth + ReceivedAuthChallenge = WaitingForAuthOk,
WaitingForServerAuth + ReceivedAuthOk = Authenticated,
WaitingForAuthOk + ReceivedAuthOk = Authenticated,
}
}
pub struct DummyContext;
impl UserAgentStateMachineContext for DummyContext {}
pub struct UserAgentActor<Transport>
where
Transport: Bi<Result<UserAgentResponse, InboundError>, UserAgentRequest>,
Transport: Bi<UserAgentResponse, UserAgentRequest>,
{
key: SigningKey,
bootstrap_token: Option<String>,
@@ -66,8 +62,17 @@ where
impl<Transport> UserAgentActor<Transport>
where
Transport: Bi<Result<UserAgentResponse, InboundError>, UserAgentRequest>,
Transport: Bi<UserAgentResponse, UserAgentRequest>,
{
pub fn new(key: SigningKey, bootstrap_token: Option<String>, transport: Transport) -> Self {
Self {
key,
bootstrap_token,
state: UserAgentStateMachine::new(DummyContext),
transport,
}
}
fn transition(&mut self, event: UserAgentEvents) -> Result<(), InboundError> {
self.state.process_event(event).map_err(|e| {
error!(?e, "useragent state transition failed");
@@ -90,11 +95,15 @@ where
bootstrap_token: self.bootstrap_token.take(),
};
self.transition(UserAgentEvents::SentAuthChallengeRequest)?;
self.transport
.send(Self::auth_request(ClientAuthPayload::AuthChallengeRequest(req)))
.send(Self::auth_request(ClientAuthPayload::AuthChallengeRequest(
req,
)))
.await
.map_err(|_| InboundError::TransportSendFailed)?;
self.transition(UserAgentEvents::SentAuthChallengeRequest)?;
info!(actor = "useragent", "auth.request.sent");
Ok(())
}
@@ -103,10 +112,6 @@ where
&mut self,
challenge: auth::AuthChallenge,
) -> Result<(), InboundError> {
if !matches!(self.state.state(), UserAgentStates::WaitingForServerAuth) {
return Err(InboundError::InvalidStateForAuthChallenge);
}
self.transition(UserAgentEvents::ReceivedAuthChallenge)?;
let formatted = format_challenge(&challenge);
@@ -116,9 +121,9 @@ where
};
self.transport
.send(Self::auth_request(ClientAuthPayload::AuthChallengeSolution(
solution,
)))
.send(Self::auth_request(
ClientAuthPayload::AuthChallengeSolution(solution),
))
.await
.map_err(|_| InboundError::TransportSendFailed)?;
@@ -127,22 +132,16 @@ where
}
fn handle_auth_ok(&mut self, _ok: AuthOk) -> Result<(), InboundError> {
match self.state.state() {
UserAgentStates::WaitingForServerAuth | UserAgentStates::WaitingForAuthOk => {
self.transition(UserAgentEvents::ReceivedAuthOk)?;
info!(actor = "useragent", "auth.ok");
Ok(())
}
_ => Err(InboundError::InvalidStateForAuthOk),
}
self.transition(UserAgentEvents::ReceivedAuthOk)?;
info!(actor = "useragent", "auth.ok");
Ok(())
}
pub async fn process_inbound_transport(
&mut self,
inbound: Result<UserAgentResponse, InboundError>,
inbound: UserAgentResponse
) -> Result<(), InboundError> {
let response = inbound?;
let payload = response
let payload = inbound
.payload
.ok_or(InboundError::MissingResponsePayload)?;
@@ -160,13 +159,16 @@ where
impl<Transport> Actor for UserAgentActor<Transport>
where
Transport: Bi<Result<UserAgentResponse, InboundError>, UserAgentRequest>,
Transport: Bi<UserAgentResponse, UserAgentRequest>,
{
type Args = Self;
type Error = ();
async fn on_start(mut args: Self::Args, _actor_ref: ActorRef<Self>) -> Result<Self, Self::Error> {
async fn on_start(
mut args: Self::Args,
_actor_ref: ActorRef<Self>,
) -> Result<Self, Self::Error> {
if let Err(err) = args.send_auth_challenge_request().await {
error!(?err, actor = "useragent", "auth.start.failed");
return Err(());
@@ -204,3 +206,4 @@ where
}
mod grpc;
pub use grpc::{connect_grpc, ConnectError};

View File

@@ -0,0 +1,151 @@
use arbiter_proto::{
format_challenge,
proto::{
UserAgentRequest, UserAgentResponse,
auth::{
AuthChallenge, AuthOk, ClientMessage as AuthClientMessage,
ServerMessage as AuthServerMessage, client_message::Payload as ClientAuthPayload,
server_message::Payload as ServerAuthPayload,
},
user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload,
},
transport::Bi,
};
use arbiter_useragent::{InboundError, UserAgentActor};
use ed25519_dalek::SigningKey;
use kameo::actor::Spawn;
use tokio::sync::mpsc;
use tokio::time::{Duration, timeout};
struct TestTransport {
inbound_rx: mpsc::Receiver<UserAgentResponse>,
outbound_tx: mpsc::Sender<UserAgentRequest>,
}
impl Bi<UserAgentResponse, UserAgentRequest> for TestTransport {
async fn send(&mut self, item: UserAgentRequest) -> Result<(), arbiter_proto::transport::Error> {
self.outbound_tx
.send(item)
.await
.map_err(|_| arbiter_proto::transport::Error::ChannelClosed)
}
async fn recv(&mut self) -> Option<UserAgentResponse> {
self.inbound_rx.recv().await
}
}
fn make_transport() -> (
TestTransport,
mpsc::Sender<UserAgentResponse>,
mpsc::Receiver<UserAgentRequest>,
) {
let (inbound_tx, inbound_rx) = mpsc::channel(8);
let (outbound_tx, outbound_rx) = mpsc::channel(8);
(
TestTransport {
inbound_rx,
outbound_tx,
},
inbound_tx,
outbound_rx,
)
}
fn test_key() -> SigningKey {
SigningKey::from_bytes(&[7u8; 32])
}
fn auth_response(payload: ServerAuthPayload) -> UserAgentResponse {
UserAgentResponse {
payload: Some(UserAgentResponsePayload::AuthMessage(AuthServerMessage {
payload: Some(payload),
})),
}
}
#[tokio::test]
async fn sends_auth_request_on_start_with_bootstrap_token() {
let key = test_key();
let pubkey = key.verifying_key().to_bytes().to_vec();
let bootstrap_token = Some("bootstrap-123".to_string());
let (transport, inbound_tx, mut outbound_rx) = make_transport();
let actor = UserAgentActor::spawn(UserAgentActor::new(key, bootstrap_token.clone(), transport));
let outbound = timeout(Duration::from_secs(1), outbound_rx.recv())
.await
.expect("timed out waiting for auth request")
.expect("channel closed before auth request");
let UserAgentRequest {
payload: Some(UserAgentRequestPayload::AuthMessage(AuthClientMessage {
payload: Some(ClientAuthPayload::AuthChallengeRequest(req)),
})),
} = outbound
else {
panic!("expected auth challenge request");
};
assert_eq!(req.pubkey, pubkey);
assert_eq!(req.bootstrap_token, bootstrap_token);
drop(inbound_tx);
drop(actor);
}
#[tokio::test]
async fn challenge_flow_sends_solution_from_transport_inbound() {
let key = test_key();
let verify_key = key.verifying_key();
let (transport, inbound_tx, mut outbound_rx) = make_transport();
let actor = UserAgentActor::spawn(UserAgentActor::new(key, None, transport));
let _initial_auth_request = timeout(Duration::from_secs(1), outbound_rx.recv())
.await
.expect("timed out waiting for initial auth request")
.expect("missing initial auth request");
let challenge = AuthChallenge {
pubkey: verify_key.to_bytes().to_vec(),
nonce: 42,
};
inbound_tx
.send(auth_response(ServerAuthPayload::AuthChallenge(challenge.clone())))
.await
.unwrap();
let outbound = timeout(Duration::from_secs(1), outbound_rx.recv())
.await
.expect("timed out waiting for challenge solution")
.expect("missing challenge solution");
let UserAgentRequest {
payload: Some(UserAgentRequestPayload::AuthMessage(AuthClientMessage {
payload: Some(ClientAuthPayload::AuthChallengeSolution(solution)),
})),
} = outbound
else {
panic!("expected auth challenge solution");
};
let formatted = format_challenge(&challenge);
let sig: ed25519_dalek::Signature = solution
.signature
.as_slice()
.try_into()
.expect("signature bytes length");
verify_key
.verify_strict(&formatted, &sig)
.expect("solution signature should verify");
inbound_tx
.send(auth_response(ServerAuthPayload::AuthOk(AuthOk {})))
.await
.unwrap();
drop(inbound_tx);
drop(actor);
}