From 004b14a168ec66b5af59230527b927977da31be8 Mon Sep 17 00:00:00 2001 From: hdbg Date: Sun, 1 Mar 2026 13:11:15 +0100 Subject: [PATCH 1/2] refactor(transport): convert Bi trait to use async_trait --- server/Cargo.lock | 2 ++ server/crates/arbiter-proto/Cargo.toml | 1 + server/crates/arbiter-proto/src/transport.rs | 25 ++++++++-------- .../arbiter-server/src/actors/client/mod.rs | 29 ++++++------------- .../src/actors/user_agent/mod.rs | 25 +++++----------- server/crates/arbiter-server/src/lib.rs | 4 +-- .../arbiter-server/tests/user_agent/unseal.rs | 9 ++---- server/crates/arbiter-useragent/Cargo.toml | 1 + server/crates/arbiter-useragent/tests/auth.rs | 2 ++ 9 files changed, 40 insertions(+), 58 deletions(-) diff --git a/server/Cargo.lock b/server/Cargo.lock index e6b7973..1cb154d 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -59,6 +59,7 @@ version = "0.1.0" name = "arbiter-proto" version = "0.1.0" dependencies = [ + "async-trait", "base64", "futures", "kameo", @@ -120,6 +121,7 @@ name = "arbiter-useragent" version = "0.1.0" dependencies = [ "arbiter-proto", + "async-trait", "ed25519-dalek", "http", "kameo", diff --git a/server/crates/arbiter-proto/Cargo.toml b/server/crates/arbiter-proto/Cargo.toml index 0640004..acc13fc 100644 --- a/server/crates/arbiter-proto/Cargo.toml +++ b/server/crates/arbiter-proto/Cargo.toml @@ -18,6 +18,7 @@ thiserror.workspace = true rustls-pki-types.workspace = true base64 = "0.22.1" tracing.workspace = true +async-trait.workspace = true [build-dependencies] tonic-prost-build = "0.14.3" diff --git a/server/crates/arbiter-proto/src/transport.rs b/server/crates/arbiter-proto/src/transport.rs index 48bb9a3..02ae72c 100644 --- a/server/crates/arbiter-proto/src/transport.rs +++ b/server/crates/arbiter-proto/src/transport.rs @@ -76,6 +76,8 @@ use std::marker::PhantomData; +use async_trait::async_trait; + /// Errors returned by transport adapters implementing [`Bi`]. pub enum Error { /// The outbound side of the transport is no longer accepting messages. @@ -87,13 +89,11 @@ pub enum Error { /// `Bi` models a duplex channel with: /// - inbound items of type `Inbound` read via [`Bi::recv`] /// - outbound items of type `Outbound` written via [`Bi::send`] +#[async_trait] pub trait Bi: Send + Sync + 'static { - fn send( - &mut self, - item: Outbound, - ) -> impl std::future::Future> + Send; + async fn send(&mut self, item: Outbound) -> Result<(), Error>; - fn recv(&mut self) -> impl std::future::Future> + Send; + async fn recv(&mut self) -> Option; } /// Converts transport-facing inbound items into protocol-facing inbound items. @@ -176,6 +176,7 @@ where /// gRPC-specific transport adapters and helpers. pub mod grpc { + use async_trait::async_trait; use futures::StreamExt; use tokio::sync::mpsc; use tonic::Streaming; @@ -199,7 +200,6 @@ pub mod grpc { outbound_converter: OutboundConverter, } - impl GrpcAdapter where @@ -221,8 +221,8 @@ pub mod grpc { } } - - impl< InboundConverter, OutboundConverter> Bi + #[async_trait] + impl Bi for GrpcAdapter where InboundConverter: RecvConverter, @@ -275,6 +275,7 @@ impl Default for DummyTransport { } } +#[async_trait] impl Bi for DummyTransport where Inbound: Send + Sync + 'static, @@ -284,10 +285,8 @@ where Ok(()) } - fn recv(&mut self) -> impl std::future::Future> + Send { - async { - std::future::pending::<()>().await; - None - } + async fn recv(&mut self) -> Option { + std::future::pending::<()>().await; + None } } diff --git a/server/crates/arbiter-server/src/actors/client/mod.rs b/server/crates/arbiter-server/src/actors/client/mod.rs index 8698abb..405bd54 100644 --- a/server/crates/arbiter-server/src/actors/client/mod.rs +++ b/server/crates/arbiter-server/src/actors/client/mod.rs @@ -1,8 +1,7 @@ use arbiter_proto::{ proto::client::{ AuthChallenge, AuthChallengeRequest, AuthChallengeSolution, AuthOk, ClientRequest, - ClientResponse, - client_request::Payload as ClientRequestPayload, + ClientResponse, client_request::Payload as ClientRequestPayload, client_response::Payload as ClientResponsePayload, }, transport::{Bi, DummyTransport}, @@ -50,19 +49,15 @@ pub enum ClientError { DatabaseOperationFailed, } -pub struct ClientActor -where - Transport: Bi>, -{ +pub type Transport = Box> + Send>; + +pub struct ClientActor { db: db::DatabasePool, state: ClientStateMachine, transport: Transport, } -impl ClientActor -where - Transport: Bi>, -{ +impl ClientActor { pub(crate) fn new(context: ServerContext, transport: Transport) -> Self { Self { db: context.db.clone(), @@ -197,10 +192,7 @@ where Ok((valid, challenge_context)) } - async fn handle_auth_challenge_solution( - &mut self, - solution: AuthChallengeSolution, - ) -> Output { + async fn handle_auth_challenge_solution(&mut self, solution: AuthChallengeSolution) -> Output { let (valid, challenge_context) = self.verify_challenge_solution(&solution)?; if valid { @@ -226,10 +218,7 @@ fn response(payload: ClientResponsePayload) -> ClientResponse { } } -impl Actor for ClientActor -where - Transport: Bi>, -{ +impl Actor for ClientActor { type Args = Self; type Error = (); @@ -278,12 +267,12 @@ where } } -impl ClientActor>> { +impl ClientActor { pub fn new_manual(db: db::DatabasePool) -> Self { Self { db, state: ClientStateMachine::new(DummyContext), - transport: DummyTransport::new(), + transport: Box::new(DummyTransport::new()), } } } diff --git a/server/crates/arbiter-server/src/actors/user_agent/mod.rs b/server/crates/arbiter-server/src/actors/user_agent/mod.rs index ba801ee..762ae6d 100644 --- a/server/crates/arbiter-server/src/actors/user_agent/mod.rs +++ b/server/crates/arbiter-server/src/actors/user_agent/mod.rs @@ -71,9 +71,9 @@ pub enum UserAgentError { DatabaseOperationFailed, } -pub struct UserAgentActor -where - Transport: Bi>, +pub type Transport = Box> + Send>; + +pub struct UserAgentActor { db: db::DatabasePool, actors: GlobalActors, @@ -81,10 +81,7 @@ where transport: Transport, } -impl UserAgentActor -where - Transport: Bi>, -{ +impl UserAgentActor { pub(crate) fn new(context: ServerContext, transport: Transport) -> Self { Self { db: context.db.clone(), @@ -265,10 +262,7 @@ fn response(payload: UserAgentResponsePayload) -> UserAgentResponse { } } -impl UserAgentActor -where - Transport: Bi>, -{ +impl UserAgentActor { async fn handle_unseal_request(&mut self, req: UnsealStart) -> Output { let secret = EphemeralSecret::random(); let public_key = PublicKey::from(&secret); @@ -413,10 +407,7 @@ where } -impl Actor for UserAgentActor -where - Transport: Bi>, -{ +impl Actor for UserAgentActor { type Args = Self; type Error = (); @@ -466,13 +457,13 @@ where } -impl UserAgentActor>> { +impl UserAgentActor { pub fn new_manual(db: db::DatabasePool, actors: GlobalActors) -> Self { Self { db, actors, state: UserAgentStateMachine::new(DummyContext), - transport: DummyTransport::new(), + transport: Box::new(DummyTransport::new()), } } } diff --git a/server/crates/arbiter-server/src/lib.rs b/server/crates/arbiter-server/src/lib.rs index e6cb5c5..a7b5ebe 100644 --- a/server/crates/arbiter-server/src/lib.rs +++ b/server/crates/arbiter-server/src/lib.rs @@ -170,7 +170,7 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server { IdentityRecvConverter::::new(), ClientGrpcSender, ); - ClientActor::spawn(ClientActor::new(self.context.clone(), transport)); + ClientActor::spawn(ClientActor::new(self.context.clone(), Box::new(transport))); info!(event = "connection established", "grpc.client"); @@ -191,7 +191,7 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server { IdentityRecvConverter::::new(), UserAgentGrpcSender, ); - UserAgentActor::spawn(UserAgentActor::new(self.context.clone(), transport)); + UserAgentActor::spawn(UserAgentActor::new(self.context.clone(), Box::new(transport))); info!(event = "connection established", "grpc.user_agent"); diff --git a/server/crates/arbiter-server/tests/user_agent/unseal.rs b/server/crates/arbiter-server/tests/user_agent/unseal.rs index 9128a6c..b0f5d1c 100644 --- a/server/crates/arbiter-server/tests/user_agent/unseal.rs +++ b/server/crates/arbiter-server/tests/user_agent/unseal.rs @@ -1,10 +1,9 @@ use arbiter_proto::proto::user_agent::{ AuthChallengeRequest, UnsealEncryptedKey, UnsealResult, UnsealStart, - UserAgentRequest, UserAgentResponse, + UserAgentRequest, user_agent_request::Payload as UserAgentRequestPayload, user_agent_response::Payload as UserAgentResponsePayload, }; -use arbiter_proto::transport::DummyTransport; use arbiter_server::{ actors::{ GlobalActors, @@ -18,14 +17,12 @@ use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit}; use memsafe::MemSafe; use x25519_dalek::{EphemeralSecret, PublicKey}; -type TestUserAgent = - UserAgentActor>>; async fn setup_authenticated_user_agent( seal_key: &[u8], ) -> ( arbiter_server::db::DatabasePool, - TestUserAgent, + UserAgentActor, ) { let db = db::create_test_pool().await; @@ -59,7 +56,7 @@ async fn setup_authenticated_user_agent( } async fn client_dh_encrypt( - user_agent: &mut TestUserAgent, + user_agent: &mut UserAgentActor, key_to_send: &[u8], ) -> UnsealEncryptedKey { let client_secret = EphemeralSecret::random(); diff --git a/server/crates/arbiter-useragent/Cargo.toml b/server/crates/arbiter-useragent/Cargo.toml index de46f67..8b6b85b 100644 --- a/server/crates/arbiter-useragent/Cargo.toml +++ b/server/crates/arbiter-useragent/Cargo.toml @@ -18,3 +18,4 @@ thiserror.workspace = true tokio-stream.workspace = true http = "1.4.0" rustls-webpki = { version = "0.103.9", features = ["aws-lc-rs"] } +async-trait.workspace = true diff --git a/server/crates/arbiter-useragent/tests/auth.rs b/server/crates/arbiter-useragent/tests/auth.rs index cd9dac1..8d79bbe 100644 --- a/server/crates/arbiter-useragent/tests/auth.rs +++ b/server/crates/arbiter-useragent/tests/auth.rs @@ -13,12 +13,14 @@ use ed25519_dalek::SigningKey; use kameo::actor::Spawn; use tokio::sync::mpsc; use tokio::time::{Duration, timeout}; +use async_trait::async_trait; struct TestTransport { inbound_rx: mpsc::Receiver, outbound_tx: mpsc::Sender, } +#[async_trait] impl Bi for TestTransport { async fn send(&mut self, item: UserAgentRequest) -> Result<(), arbiter_proto::transport::Error> { self.outbound_tx From f709650bb1ef5f81900e2d86ff88d67249d6df8c Mon Sep 17 00:00:00 2001 From: hdbg Date: Sun, 1 Mar 2026 19:59:42 +0100 Subject: [PATCH 2/2] refactor(server::{user_agent, client}): move auth part to separate function to not to pollute `actor session` with one-time concerns --- server/crates/arbiter-proto/src/transport.rs | 3 +- .../src/actors/client/auth/mod.rs | 101 ++++ .../src/actors/client/auth/state.rs | 136 +++++ .../arbiter-server/src/actors/client/mod.rs | 274 +--------- .../src/actors/client/session.rs | 90 ++++ .../arbiter-server/src/actors/client/state.rs | 31 -- .../src/actors/user_agent/auth.rs | 118 +++++ .../src/actors/user_agent/auth/state.rs | 202 ++++++++ .../src/actors/user_agent/mod.rs | 467 ++---------------- .../src/actors/user_agent/session.rs | 241 +++++++++ .../src/actors/user_agent/session/state.rs | 27 + .../src/actors/user_agent/state.rs | 51 -- server/crates/arbiter-server/src/lib.rs | 110 ++--- server/crates/arbiter-server/tests/client.rs | 2 + .../arbiter-server/tests/client/auth.rs | 75 +-- .../crates/arbiter-server/tests/common/mod.rs | 47 ++ .../arbiter-server/tests/user_agent/auth.rs | 105 ++-- .../arbiter-server/tests/user_agent/unseal.rs | 71 +-- 18 files changed, 1176 insertions(+), 975 deletions(-) create mode 100644 server/crates/arbiter-server/src/actors/client/auth/mod.rs create mode 100644 server/crates/arbiter-server/src/actors/client/auth/state.rs create mode 100644 server/crates/arbiter-server/src/actors/client/session.rs delete mode 100644 server/crates/arbiter-server/src/actors/client/state.rs create mode 100644 server/crates/arbiter-server/src/actors/user_agent/auth.rs create mode 100644 server/crates/arbiter-server/src/actors/user_agent/auth/state.rs create mode 100644 server/crates/arbiter-server/src/actors/user_agent/session.rs create mode 100644 server/crates/arbiter-server/src/actors/user_agent/session/state.rs delete mode 100644 server/crates/arbiter-server/src/actors/user_agent/state.rs diff --git a/server/crates/arbiter-proto/src/transport.rs b/server/crates/arbiter-proto/src/transport.rs index 02ae72c..a38b892 100644 --- a/server/crates/arbiter-proto/src/transport.rs +++ b/server/crates/arbiter-proto/src/transport.rs @@ -79,8 +79,9 @@ use std::marker::PhantomData; use async_trait::async_trait; /// Errors returned by transport adapters implementing [`Bi`]. +#[derive(thiserror::Error, Debug)] pub enum Error { - /// The outbound side of the transport is no longer accepting messages. + #[error("Transport channel is closed")] ChannelClosed, } diff --git a/server/crates/arbiter-server/src/actors/client/auth/mod.rs b/server/crates/arbiter-server/src/actors/client/auth/mod.rs new file mode 100644 index 0000000..33a9826 --- /dev/null +++ b/server/crates/arbiter-server/src/actors/client/auth/mod.rs @@ -0,0 +1,101 @@ +use arbiter_proto::proto::client::{ + AuthChallengeRequest, AuthChallengeSolution, ClientRequest, + client_request::Payload as ClientRequestPayload, +}; +use ed25519_dalek::VerifyingKey; +use tracing::error; + +use crate::actors::client::{ + ConnectionProps, + auth::state::{AuthContext, AuthStateMachine}, + session::ClientSession, +}; + +#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)] +pub enum Error { + #[error("Unexpected message payload")] + UnexpectedMessagePayload, + #[error("Invalid client public key length")] + InvalidClientPubkeyLength, + #[error("Invalid client public key encoding")] + InvalidAuthPubkeyEncoding, + #[error("Database pool unavailable")] + DatabasePoolUnavailable, + #[error("Database operation failed")] + DatabaseOperationFailed, + #[error("Public key not registered")] + PublicKeyNotRegistered, + #[error("Invalid signature length")] + InvalidSignatureLength, + #[error("Invalid challenge solution")] + InvalidChallengeSolution, + #[error("Transport error")] + Transport, +} + +mod state; +use state::*; + +fn parse_auth_event(payload: ClientRequestPayload) -> Result { + match payload { + ClientRequestPayload::AuthChallengeRequest(AuthChallengeRequest { pubkey }) => { + let pubkey_bytes = pubkey.as_array().ok_or(Error::InvalidClientPubkeyLength)?; + let pubkey = VerifyingKey::from_bytes(pubkey_bytes) + .map_err(|_| Error::InvalidAuthPubkeyEncoding)?; + Ok(AuthEvents::AuthRequest(ChallengeRequest { + pubkey: pubkey.into(), + })) + } + ClientRequestPayload::AuthChallengeSolution(AuthChallengeSolution { signature }) => { + Ok(AuthEvents::ReceivedSolution(ChallengeSolution { + solution: signature, + })) + } + } +} + +pub async fn authenticate(props: &mut ConnectionProps) -> Result { + let mut state = AuthStateMachine::new(AuthContext::new(props)); + + loop { + let transport = state.context_mut().conn.transport.as_mut(); + let Some(ClientRequest { + payload: Some(payload), + }) = transport.recv().await + else { + return Err(Error::Transport); + }; + + let event = parse_auth_event(payload)?; + + match state.process_event(event).await { + Ok(AuthStates::AuthOk(key)) => return Ok(key.clone()), + Err(AuthError::ActionFailed(err)) => { + error!(?err, "State machine action failed"); + return Err(err); + } + Err(AuthError::GuardFailed(err)) => { + error!(?err, "State machine guard failed"); + return Err(err); + } + Err(AuthError::InvalidEvent) => { + error!("Invalid event for current state"); + return Err(Error::InvalidChallengeSolution); + } + Err(AuthError::TransitionsFailed) => { + error!("Invalid state transition"); + return Err(Error::InvalidChallengeSolution); + } + + _ => (), + } + } +} + +pub async fn authenticate_and_create( + mut props: ConnectionProps, +) -> Result { + let key = authenticate(&mut props).await?; + let session = ClientSession::new(props, key); + Ok(session) +} diff --git a/server/crates/arbiter-server/src/actors/client/auth/state.rs b/server/crates/arbiter-server/src/actors/client/auth/state.rs new file mode 100644 index 0000000..550934f --- /dev/null +++ b/server/crates/arbiter-server/src/actors/client/auth/state.rs @@ -0,0 +1,136 @@ +use arbiter_proto::proto::client::{ + AuthChallenge, ClientResponse, + client_response::Payload as ClientResponsePayload, +}; +use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, update}; +use diesel_async::RunQueryDsl; +use ed25519_dalek::VerifyingKey; +use tracing::error; + +use super::Error; +use crate::{actors::client::ConnectionProps, db::schema}; + +pub struct ChallengeRequest { + pub pubkey: VerifyingKey, +} + +pub struct ChallengeContext { + pub challenge: AuthChallenge, + pub key: VerifyingKey, +} + +pub struct ChallengeSolution { + pub solution: Vec, +} + +smlang::statemachine!( + name: Auth, + custom_error: true, + transitions: { + *Init + AuthRequest(ChallengeRequest) / async prepare_challenge = SentChallenge(ChallengeContext), + SentChallenge(ChallengeContext) + ReceivedSolution(ChallengeSolution) [async verify_solution] / provide_key = AuthOk(VerifyingKey), + } +); + +async fn create_nonce(db: &crate::db::DatabasePool, pubkey_bytes: &[u8]) -> Result { + let mut db_conn = db.get().await.map_err(|e| { + error!(error = ?e, "Database pool error"); + Error::DatabasePoolUnavailable + })?; + db_conn + .exclusive_transaction(|conn| { + Box::pin(async move { + let current_nonce = schema::program_client::table + .filter(schema::program_client::public_key.eq(pubkey_bytes.to_vec())) + .select(schema::program_client::nonce) + .first::(conn) + .await?; + + update(schema::program_client::table) + .filter(schema::program_client::public_key.eq(pubkey_bytes.to_vec())) + .set(schema::program_client::nonce.eq(current_nonce + 1)) + .execute(conn) + .await?; + + Result::<_, diesel::result::Error>::Ok(current_nonce) + }) + }) + .await + .optional() + .map_err(|e| { + error!(error = ?e, "Database error"); + Error::DatabaseOperationFailed + })? + .ok_or_else(|| { + error!(?pubkey_bytes, "Public key not found in database"); + Error::PublicKeyNotRegistered + }) +} + +pub struct AuthContext<'a> { + pub(super) conn: &'a mut ConnectionProps, +} + +impl<'a> AuthContext<'a> { + pub fn new(conn: &'a mut ConnectionProps) -> Self { + Self { conn } + } +} + +impl AuthStateMachineContext for AuthContext<'_> { + type Error = Error; + + async fn verify_solution( + &self, + ChallengeContext { challenge, key }: &ChallengeContext, + ChallengeSolution { solution }: &ChallengeSolution, + ) -> Result { + let formatted_challenge = + arbiter_proto::format_challenge(challenge.nonce, &challenge.pubkey); + + let signature = solution.as_slice().try_into().map_err(|_| { + error!(?solution, "Invalid signature length"); + Error::InvalidChallengeSolution + })?; + + let valid = key.verify_strict(&formatted_challenge, &signature).is_ok(); + + Ok(valid) + } + + async fn prepare_challenge( + &mut self, + ChallengeRequest { pubkey }: ChallengeRequest, + ) -> Result { + let nonce = create_nonce(&self.conn.db, pubkey.as_bytes()).await?; + + let challenge = AuthChallenge { + pubkey: pubkey.as_bytes().to_vec(), + nonce, + }; + + self.conn + .transport + .send(Ok(ClientResponse { + payload: Some(ClientResponsePayload::AuthChallenge(challenge.clone())), + })) + .await + .map_err(|e| { + error!(?e, "Failed to send auth challenge"); + Error::Transport + })?; + + Ok(ChallengeContext { + challenge, + key: pubkey, + }) + } + + fn provide_key( + &mut self, + state_data: &ChallengeContext, + _: ChallengeSolution, + ) -> Result { + Ok(state_data.key) + } +} diff --git a/server/crates/arbiter-server/src/actors/client/mod.rs b/server/crates/arbiter-server/src/actors/client/mod.rs index 405bd54..4a9f131 100644 --- a/server/crates/arbiter-server/src/actors/client/mod.rs +++ b/server/crates/arbiter-server/src/actors/client/mod.rs @@ -1,27 +1,11 @@ use arbiter_proto::{ - proto::client::{ - AuthChallenge, AuthChallengeRequest, AuthChallengeSolution, AuthOk, ClientRequest, - ClientResponse, client_request::Payload as ClientRequestPayload, - client_response::Payload as ClientResponsePayload, - }, - transport::{Bi, DummyTransport}, + proto::client::{ClientRequest, ClientResponse}, + transport::Bi, }; -use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, dsl::update}; -use diesel_async::RunQueryDsl; -use ed25519_dalek::VerifyingKey; -use kameo::Actor; -use tokio::select; +use kameo::actor::Spawn; use tracing::{error, info}; -use crate::{ - ServerContext, - actors::client::state::{ - ChallengeContext, ClientEvents, ClientStateMachine, ClientStates, DummyContext, - }, - db::{self, schema}, -}; - -mod state; +use crate::{actors::client::session::ClientSession, db}; #[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)] pub enum ClientError { @@ -29,250 +13,36 @@ pub enum ClientError { MissingRequestPayload, #[error("Unexpected request payload")] UnexpectedRequestPayload, - #[error("Invalid state for challenge solution")] - InvalidStateForChallengeSolution, - #[error("Expected pubkey to have specific length")] - InvalidAuthPubkeyLength, - #[error("Failed to convert pubkey to VerifyingKey")] - InvalidAuthPubkeyEncoding, - #[error("Invalid signature length")] - InvalidSignatureLength, - #[error("Public key not registered")] - PublicKeyNotRegistered, - #[error("Invalid challenge solution")] - InvalidChallengeSolution, #[error("State machine error")] StateTransitionFailed, - #[error("Database pool error")] - DatabasePoolUnavailable, - #[error("Database error")] - DatabaseOperationFailed, + #[error(transparent)] + Auth(#[from] auth::Error), } pub type Transport = Box> + Send>; -pub struct ClientActor { - db: db::DatabasePool, - state: ClientStateMachine, - transport: Transport, +pub struct ConnectionProps { + pub(crate) db: db::DatabasePool, + pub(crate) transport: Transport, } -impl ClientActor { - pub(crate) fn new(context: ServerContext, transport: Transport) -> Self { - Self { - db: context.db.clone(), - state: ClientStateMachine::new(DummyContext), - transport, - } - } - - fn transition(&mut self, event: ClientEvents) -> Result<(), ClientError> { - self.state.process_event(event).map_err(|e| { - error!(?e, "State transition failed"); - ClientError::StateTransitionFailed - })?; - Ok(()) - } - - pub async fn process_transport_inbound(&mut self, req: ClientRequest) -> Output { - let msg = req.payload.ok_or_else(|| { - error!(actor = "client", "Received message with no payload"); - ClientError::MissingRequestPayload - })?; - - match msg { - ClientRequestPayload::AuthChallengeRequest(req) => { - self.handle_auth_challenge_request(req).await - } - ClientRequestPayload::AuthChallengeSolution(solution) => { - self.handle_auth_challenge_solution(solution).await - } - } - } - - async fn handle_auth_challenge_request(&mut self, req: AuthChallengeRequest) -> Output { - let pubkey = req - .pubkey - .as_array() - .ok_or(ClientError::InvalidAuthPubkeyLength)?; - let pubkey = VerifyingKey::from_bytes(pubkey).map_err(|_err| { - error!(?pubkey, "Failed to convert to VerifyingKey"); - ClientError::InvalidAuthPubkeyEncoding - })?; - - self.transition(ClientEvents::AuthRequest)?; - - self.auth_with_challenge(pubkey, req.pubkey).await - } - - async fn auth_with_challenge(&mut self, pubkey: VerifyingKey, pubkey_bytes: Vec) -> Output { - let nonce: Option = { - let mut db_conn = self.db.get().await.map_err(|e| { - error!(error = ?e, "Database pool error"); - ClientError::DatabasePoolUnavailable - })?; - db_conn - .exclusive_transaction(|conn| { - Box::pin(async move { - let current_nonce = schema::program_client::table - .filter( - schema::program_client::public_key.eq(pubkey.as_bytes().to_vec()), - ) - .select(schema::program_client::nonce) - .first::(conn) - .await?; - - update(schema::program_client::table) - .filter( - schema::program_client::public_key.eq(pubkey.as_bytes().to_vec()), - ) - .set(schema::program_client::nonce.eq(current_nonce + 1)) - .execute(conn) - .await?; - - Result::<_, diesel::result::Error>::Ok(current_nonce) - }) - }) - .await - .optional() - .map_err(|e| { - error!(error = ?e, "Database error"); - ClientError::DatabaseOperationFailed - })? - }; - - let Some(nonce) = nonce else { - error!(?pubkey, "Public key not found in database"); - return Err(ClientError::PublicKeyNotRegistered); - }; - - let challenge = AuthChallenge { - pubkey: pubkey_bytes, - nonce, - }; - - self.transition(ClientEvents::SentChallenge(ChallengeContext { - challenge: challenge.clone(), - key: pubkey, - }))?; - - info!( - ?pubkey, - ?challenge, - "Sent authentication challenge to client" - ); - - Ok(response(ClientResponsePayload::AuthChallenge(challenge))) - } - - fn verify_challenge_solution( - &self, - solution: &AuthChallengeSolution, - ) -> Result<(bool, &ChallengeContext), ClientError> { - let ClientStates::WaitingForChallengeSolution(challenge_context) = self.state.state() - else { - error!("Received challenge solution in invalid state"); - return Err(ClientError::InvalidStateForChallengeSolution); - }; - let formatted_challenge = arbiter_proto::format_challenge( - challenge_context.challenge.nonce, - &challenge_context.challenge.pubkey, - ); - - let signature = solution.signature.as_slice().try_into().map_err(|_| { - error!(?solution, "Invalid signature length"); - ClientError::InvalidSignatureLength - })?; - - let valid = challenge_context - .key - .verify_strict(&formatted_challenge, &signature) - .is_ok(); - - Ok((valid, challenge_context)) - } - - async fn handle_auth_challenge_solution(&mut self, solution: AuthChallengeSolution) -> Output { - let (valid, challenge_context) = self.verify_challenge_solution(&solution)?; - - if valid { - info!( - ?challenge_context, - "Client provided valid solution to authentication challenge" - ); - self.transition(ClientEvents::ReceivedGoodSolution)?; - Ok(response(ClientResponsePayload::AuthOk(AuthOk {}))) - } else { - error!("Client provided invalid solution to authentication challenge"); - self.transition(ClientEvents::ReceivedBadSolution)?; - Err(ClientError::InvalidChallengeSolution) - } +impl ConnectionProps { + pub fn new(db: db::DatabasePool, transport: Transport) -> Self { + Self { db, transport } } } -type Output = Result; +pub mod auth; +pub mod session; -fn response(payload: ClientResponsePayload) -> ClientResponse { - ClientResponse { - payload: Some(payload), - } -} - -impl Actor for ClientActor { - type Args = Self; - - type Error = (); - - async fn on_start( - args: Self::Args, - _: kameo::prelude::ActorRef, - ) -> Result { - Ok(args) - } - - async fn next( - &mut self, - _actor_ref: kameo::prelude::WeakActorRef, - mailbox_rx: &mut kameo::prelude::MailboxReceiver, - ) -> Option> { - loop { - select! { - signal = mailbox_rx.recv() => { - return signal; - } - msg = self.transport.recv() => { - match msg { - Some(request) => { - match self.process_transport_inbound(request).await { - Ok(resp) => { - if self.transport.send(Ok(resp)).await.is_err() { - error!(actor = "client", reason = "channel closed", "send.failed"); - return Some(kameo::mailbox::Signal::Stop); - } - } - Err(err) => { - let _ = self.transport.send(Err(err)).await; - return Some(kameo::mailbox::Signal::Stop); - } - } - } - None => { - info!(actor = "client", "transport.closed"); - return Some(kameo::mailbox::Signal::Stop); - } - } - } - } - } - } -} - -impl ClientActor { - pub fn new_manual(db: db::DatabasePool) -> Self { - Self { - db, - state: ClientStateMachine::new(DummyContext), - transport: Box::new(DummyTransport::new()), +pub async fn connect_client(props: ConnectionProps) { + match auth::authenticate_and_create(props).await { + Ok(session) => { + ClientSession::spawn(session); + info!("Client authenticated, session started"); + } + Err(err) => { + error!(?err, "Authentication failed, closing connection"); } } } diff --git a/server/crates/arbiter-server/src/actors/client/session.rs b/server/crates/arbiter-server/src/actors/client/session.rs new file mode 100644 index 0000000..29a0dd6 --- /dev/null +++ b/server/crates/arbiter-server/src/actors/client/session.rs @@ -0,0 +1,90 @@ +use arbiter_proto::proto::client::{ClientRequest, ClientResponse}; +use ed25519_dalek::VerifyingKey; +use kameo::Actor; +use tokio::select; +use tracing::{error, info}; + +use crate::actors::client::{ClientError, ConnectionProps}; + +pub struct ClientSession { + props: ConnectionProps, + key: VerifyingKey, +} + +impl ClientSession { + pub(crate) fn new(props: ConnectionProps, key: VerifyingKey) -> Self { + Self { props, key } + } + + pub async fn process_transport_inbound(&mut self, req: ClientRequest) -> Output { + let msg = req.payload.ok_or_else(|| { + error!(actor = "client", "Received message with no payload"); + ClientError::MissingRequestPayload + })?; + + match msg { + _ => Err(ClientError::UnexpectedRequestPayload), + } + } +} + +type Output = Result; + +impl Actor for ClientSession { + type Args = Self; + + type Error = (); + + async fn on_start( + args: Self::Args, + _: kameo::prelude::ActorRef, + ) -> Result { + Ok(args) + } + + async fn next( + &mut self, + _actor_ref: kameo::prelude::WeakActorRef, + mailbox_rx: &mut kameo::prelude::MailboxReceiver, + ) -> Option> { + loop { + select! { + signal = mailbox_rx.recv() => { + return signal; + } + msg = self.props.transport.recv() => { + match msg { + Some(request) => { + match self.process_transport_inbound(request).await { + Ok(resp) => { + if self.props.transport.send(Ok(resp)).await.is_err() { + error!(actor = "client", reason = "channel closed", "send.failed"); + return Some(kameo::mailbox::Signal::Stop); + } + } + Err(err) => { + let _ = self.props.transport.send(Err(err)).await; + return Some(kameo::mailbox::Signal::Stop); + } + } + } + None => { + info!(actor = "client", "transport.closed"); + return Some(kameo::mailbox::Signal::Stop); + } + } + } + } + } + } +} + +impl ClientSession { + pub fn new_test(db: crate::db::DatabasePool) -> Self { + use arbiter_proto::transport::DummyTransport; + let transport: super::Transport = Box::new(DummyTransport::new()); + let props = ConnectionProps::new(db, transport); + let key = VerifyingKey::from_bytes(&[0u8; 32]).unwrap(); + Self { props, key } + } +} diff --git a/server/crates/arbiter-server/src/actors/client/state.rs b/server/crates/arbiter-server/src/actors/client/state.rs deleted file mode 100644 index 50382a4..0000000 --- a/server/crates/arbiter-server/src/actors/client/state.rs +++ /dev/null @@ -1,31 +0,0 @@ -use arbiter_proto::proto::client::AuthChallenge; -use ed25519_dalek::VerifyingKey; - -/// Context for state machine with validated key and sent challenge -#[derive(Clone, Debug)] -pub struct ChallengeContext { - pub challenge: AuthChallenge, - pub key: VerifyingKey, -} - -smlang::statemachine!( - name: Client, - custom_error: false, - transitions: { - *Init + AuthRequest = ReceivedAuthRequest, - - ReceivedAuthRequest + SentChallenge(ChallengeContext) / move_challenge = WaitingForChallengeSolution(ChallengeContext), - - WaitingForChallengeSolution(ChallengeContext) + ReceivedGoodSolution = Idle, - WaitingForChallengeSolution(ChallengeContext) + ReceivedBadSolution = AuthError, - } -); - -pub struct DummyContext; -impl ClientStateMachineContext for DummyContext { - #[allow(missing_docs)] - #[allow(clippy::unused_unit)] - fn move_challenge(&mut self, event_data: ChallengeContext) -> Result { - Ok(event_data) - } -} diff --git a/server/crates/arbiter-server/src/actors/user_agent/auth.rs b/server/crates/arbiter-server/src/actors/user_agent/auth.rs new file mode 100644 index 0000000..b34ed38 --- /dev/null +++ b/server/crates/arbiter-server/src/actors/user_agent/auth.rs @@ -0,0 +1,118 @@ +use arbiter_proto::proto::user_agent::{ + AuthChallengeRequest, AuthChallengeSolution, UserAgentRequest, + user_agent_request::Payload as UserAgentRequestPayload, +}; +use ed25519_dalek::VerifyingKey; +use tracing::error; + +use crate::actors::user_agent::{ + ConnectionProps, + auth::state::{AuthContext, AuthStateMachine}, session::UserAgentSession, +}; + +#[derive(thiserror::Error, Debug, PartialEq)] +pub enum Error { + #[error("Unexpected message payload")] + UnexpectedMessagePayload, + #[error("Invalid client public key length")] + InvalidClientPubkeyLength, + #[error("Invalid client public key encoding")] + InvalidAuthPubkeyEncoding, + #[error("Database pool unavailable")] + DatabasePoolUnavailable, + #[error("Database operation failed")] + DatabaseOperationFailed, + #[error("Public key not registered")] + PublicKeyNotRegistered, + #[error("Transport error")] + Transport, + #[error("Invalid bootstrap token")] + InvalidBootstrapToken, + #[error("Bootstrapper actor unreachable")] + BootstrapperActorUnreachable, + #[error("Invalid challenge solution")] + InvalidChallengeSolution, +} + +mod state; +use state::*; + +fn parse_auth_event(payload: UserAgentRequestPayload) -> Result { + match payload { + UserAgentRequestPayload::AuthChallengeRequest(AuthChallengeRequest { + pubkey, + bootstrap_token: None, + }) => { + let pubkey_bytes = pubkey.as_array().ok_or(Error::InvalidClientPubkeyLength)?; + let pubkey = VerifyingKey::from_bytes(pubkey_bytes) + .map_err(|_| Error::InvalidAuthPubkeyEncoding)?; + Ok(AuthEvents::AuthRequest(ChallengeRequest { + pubkey: pubkey.into(), + })) + } + UserAgentRequestPayload::AuthChallengeRequest(AuthChallengeRequest { + pubkey, + bootstrap_token: Some(token), + }) => { + let pubkey_bytes = pubkey.as_array().ok_or(Error::InvalidClientPubkeyLength)?; + let pubkey = VerifyingKey::from_bytes(pubkey_bytes) + .map_err(|_| Error::InvalidAuthPubkeyEncoding)?; + Ok(AuthEvents::BootstrapAuthRequest(BootstrapAuthRequest { + pubkey: pubkey.into(), + token, + })) + } + UserAgentRequestPayload::AuthChallengeSolution(AuthChallengeSolution { signature }) => { + Ok(AuthEvents::ReceivedSolution(ChallengeSolution { + solution: signature, + })) + } + _ => Err(Error::UnexpectedMessagePayload), + } +} + +pub async fn authenticate(props: &mut ConnectionProps) -> Result { + let mut state = AuthStateMachine::new(AuthContext::new(props)); + + loop { + // This is needed because `state` now holds mutable reference to `ConnectionProps`, so we can't directly access `props` here + let transport = state.context_mut().conn.transport.as_mut(); + let Some(UserAgentRequest { + payload: Some(payload), + }) = transport.recv().await + else { + return Err(Error::Transport); + }; + + let event = parse_auth_event(payload)?; + + match state.process_event(event).await { + Ok(AuthStates::AuthOk(key)) => return Ok(key.clone()), + Err(AuthError::ActionFailed(err)) => { + error!(?err, "State machine action failed"); + return Err(err); + } + Err(AuthError::GuardFailed(err)) => { + error!(?err, "State machine guard failed"); + return Err(err); + } + Err(AuthError::InvalidEvent) => { + error!("Invalid event for current state"); + return Err(Error::InvalidChallengeSolution); + } + Err(AuthError::TransitionsFailed) => { + error!("Invalid state transition"); + return Err(Error::InvalidChallengeSolution); + } + + _ => (), + } + } +} + + +pub async fn authenticate_and_create(mut props: ConnectionProps) -> Result { + let key = authenticate(&mut props).await?; + let session = UserAgentSession::new(props, key.clone()); + Ok(session) +} diff --git a/server/crates/arbiter-server/src/actors/user_agent/auth/state.rs b/server/crates/arbiter-server/src/actors/user_agent/auth/state.rs new file mode 100644 index 0000000..aa39bb6 --- /dev/null +++ b/server/crates/arbiter-server/src/actors/user_agent/auth/state.rs @@ -0,0 +1,202 @@ +use arbiter_proto::proto::user_agent::{ + AuthChallenge, UserAgentResponse, + user_agent_response::Payload as UserAgentResponsePayload, +}; +use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, update}; +use diesel_async::RunQueryDsl; +use ed25519_dalek::VerifyingKey; +use tracing::error; + +use super::Error; +use crate::{ + actors::{bootstrap::ConsumeToken, user_agent::ConnectionProps}, + db::schema, +}; + +pub struct ChallengeRequest { + pub pubkey: VerifyingKey, +} + +pub struct BootstrapAuthRequest { + pub pubkey: VerifyingKey, + pub token: String, +} + +pub struct ChallengeContext { + pub challenge: AuthChallenge, + pub key: VerifyingKey, +} + +pub struct ChallengeSolution { + pub solution: Vec, +} + +smlang::statemachine!( + name: Auth, + custom_error: true, + transitions: { + *Init + AuthRequest(ChallengeRequest) / async prepare_challenge = SentChallenge(ChallengeContext), + Init + BootstrapAuthRequest(BootstrapAuthRequest) [async verify_bootstrap_token] / provide_key_bootstrap = AuthOk(VerifyingKey), + SentChallenge(ChallengeContext) + ReceivedSolution(ChallengeSolution) [async verify_solution] / provide_key = AuthOk(VerifyingKey), + } +); + +async fn create_nonce(db: &crate::db::DatabasePool, pubkey_bytes: &[u8]) -> Result { + let mut db_conn = db.get().await.map_err(|e| { + error!(error = ?e, "Database pool error"); + Error::DatabasePoolUnavailable + })?; + db_conn + .exclusive_transaction(|conn| { + Box::pin(async move { + let current_nonce = schema::useragent_client::table + .filter(schema::useragent_client::public_key.eq(pubkey_bytes.to_vec())) + .select(schema::useragent_client::nonce) + .first::(conn) + .await?; + + update(schema::useragent_client::table) + .filter(schema::useragent_client::public_key.eq(pubkey_bytes.to_vec())) + .set(schema::useragent_client::nonce.eq(current_nonce + 1)) + .execute(conn) + .await?; + + Result::<_, diesel::result::Error>::Ok(current_nonce) + }) + }) + .await + .optional() + .map_err(|e| { + error!(error = ?e, "Database error"); + Error::DatabaseOperationFailed + })? + .ok_or_else(|| { + error!(?pubkey_bytes, "Public key not found in database"); + Error::PublicKeyNotRegistered + }) +} + +async fn register_key(db: &crate::db::DatabasePool, pubkey_bytes: &[u8]) -> Result<(), Error> { + let mut conn = db.get().await.map_err(|e| { + error!(error = ?e, "Database pool error"); + Error::DatabasePoolUnavailable + })?; + + diesel::insert_into(schema::useragent_client::table) + .values(( + schema::useragent_client::public_key.eq(pubkey_bytes.to_vec()), + schema::useragent_client::nonce.eq(1), + )) + .execute(&mut conn) + .await + .map_err(|e| { + error!(error = ?e, "Database error"); + Error::DatabaseOperationFailed + })?; + + Ok(()) +} + +pub struct AuthContext<'a> { + pub(super) conn: &'a mut ConnectionProps, +} + +impl<'a> AuthContext<'a> { + pub fn new(conn: &'a mut ConnectionProps) -> Self { + Self { conn } + } +} + +impl AuthStateMachineContext for AuthContext<'_> { + type Error = Error; + + async fn verify_solution( + &self, + ChallengeContext { challenge, key }: &ChallengeContext, + ChallengeSolution { solution }: &ChallengeSolution, + ) -> Result { + let formatted_challenge = + arbiter_proto::format_challenge(challenge.nonce, &challenge.pubkey); + + let signature = solution.as_slice().try_into().map_err(|_| { + error!(?solution, "Invalid signature length"); + Error::InvalidChallengeSolution + })?; + + let valid = key.verify_strict(&formatted_challenge, &signature).is_ok(); + + Ok(valid) + } + + async fn prepare_challenge( + &mut self, + ChallengeRequest { pubkey }: ChallengeRequest, + ) -> Result { + let nonce = create_nonce(&self.conn.db, pubkey.as_bytes()).await?; + + let challenge = AuthChallenge { + pubkey: pubkey.as_bytes().to_vec(), + nonce, + }; + + self.conn + .transport + .send(Ok(UserAgentResponse { + payload: Some(UserAgentResponsePayload::AuthChallenge(challenge.clone())), + })) + .await + .map_err(|e| { + error!(?e, "Failed to send auth challenge"); + Error::Transport + })?; + + Ok(ChallengeContext { + challenge, + key: pubkey, + }) + } + + #[allow(missing_docs)] + #[allow(clippy::result_unit_err)] + async fn verify_bootstrap_token( + &self, + BootstrapAuthRequest { pubkey, token }: &BootstrapAuthRequest, + ) -> Result { + let token_ok: bool = self + .conn + .actors + .bootstrapper + .ask(ConsumeToken { + token: token.clone(), + }) + .await + .map_err(|e| { + error!(?pubkey, "Failed to consume bootstrap token: {e}"); + Error::BootstrapperActorUnreachable + })?; + + if !token_ok { + error!(?pubkey, "Invalid bootstrap token provided"); + return Err(Error::InvalidBootstrapToken); + } + + register_key(&self.conn.db, pubkey.as_bytes()).await?; + + Ok(true) + } + + fn provide_key_bootstrap( + &mut self, + event_data: BootstrapAuthRequest, + ) -> Result { + Ok(event_data.pubkey) + } + + fn provide_key( + &mut self, + state_data: &ChallengeContext, + _: ChallengeSolution, + ) -> Result { + Ok(state_data.key) + } +} diff --git a/server/crates/arbiter-server/src/actors/user_agent/mod.rs b/server/crates/arbiter-server/src/actors/user_agent/mod.rs index 762ae6d..f9ca3cc 100644 --- a/server/crates/arbiter-server/src/actors/user_agent/mod.rs +++ b/server/crates/arbiter-server/src/actors/user_agent/mod.rs @@ -1,469 +1,60 @@ -use std::{ops::DerefMut, sync::Mutex}; - use arbiter_proto::{ - proto::user_agent::{ - AuthChallenge, AuthChallengeRequest, AuthChallengeSolution, AuthOk, UnsealEncryptedKey, - UnsealResult, UnsealStart, UnsealStartResponse, UserAgentRequest, UserAgentResponse, - user_agent_request::Payload as UserAgentRequestPayload, - user_agent_response::Payload as UserAgentResponsePayload, - }, - transport::{Bi, DummyTransport}, + proto::user_agent::{UserAgentRequest, UserAgentResponse}, + transport::Bi, }; -use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit}; -use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, dsl::update}; -use diesel_async::RunQueryDsl; -use ed25519_dalek::VerifyingKey; -use kameo::{Actor, error::SendError}; -use memsafe::MemSafe; -use tokio::select; +use kameo::actor::Spawn; use tracing::{error, info}; -use x25519_dalek::{EphemeralSecret, PublicKey}; -use crate::{ - ServerContext, - actors::{ - GlobalActors, - bootstrap::ConsumeToken, - keyholder::{self, TryUnseal}, - user_agent::state::{ - ChallengeContext, DummyContext, UnsealContext, UserAgentEvents, UserAgentStateMachine, - UserAgentStates, - }, - }, - db::{self, schema}, -}; +use crate::{actors::{GlobalActors, user_agent::session::UserAgentSession}, db}; -mod state; - -#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)] +#[derive(Debug, thiserror::Error, PartialEq)] pub enum UserAgentError { #[error("Expected message with payload")] MissingRequestPayload, - #[error("Expected message with payload")] + #[error("Unexpected request payload")] UnexpectedRequestPayload, - #[error("Invalid state for challenge solution")] - InvalidStateForChallengeSolution, #[error("Invalid state for unseal encrypted key")] InvalidStateForUnsealEncryptedKey, #[error("client_pubkey must be 32 bytes")] InvalidClientPubkeyLength, - #[error("Expected pubkey to have specific length")] - InvalidAuthPubkeyLength, - #[error("Failed to convert pubkey to VerifyingKey")] - InvalidAuthPubkeyEncoding, - #[error("Invalid signature length")] - InvalidSignatureLength, - #[error("Invalid bootstrap token")] - InvalidBootstrapToken, - #[error("Public key not registered")] - PublicKeyNotRegistered, - #[error("Invalid challenge solution")] - InvalidChallengeSolution, #[error("State machine error")] StateTransitionFailed, - #[error("Bootstrap token consumption failed")] - BootstrapperActorUnreachable, #[error("Vault is not available")] KeyHolderActorUnreachable, - #[error("Database pool error")] - DatabasePoolUnavailable, - #[error("Database error")] - DatabaseOperationFailed, + #[error(transparent)] + Auth(#[from] auth::Error), } -pub type Transport = Box> + Send>; +pub type Transport = + Box> + Send>; -pub struct UserAgentActor -{ +pub struct ConnectionProps { db: db::DatabasePool, actors: GlobalActors, - state: UserAgentStateMachine, transport: Transport, } -impl UserAgentActor { - pub(crate) fn new(context: ServerContext, transport: Transport) -> Self { - Self { - db: context.db.clone(), - actors: context.actors.clone(), - state: UserAgentStateMachine::new(DummyContext), - transport, - } - } - - fn transition(&mut self, event: UserAgentEvents) -> Result<(), UserAgentError> { - self.state.process_event(event).map_err(|e| { - error!(?e, "State transition failed"); - UserAgentError::StateTransitionFailed - })?; - Ok(()) - } - - pub async fn process_transport_inbound(&mut self, req: UserAgentRequest) -> Output { - let msg = req.payload.ok_or_else(|| { - error!(actor = "useragent", "Received message with no payload"); - UserAgentError::MissingRequestPayload - })?; - - match msg { - UserAgentRequestPayload::AuthChallengeRequest(req) => { - self.handle_auth_challenge_request(req).await - } - UserAgentRequestPayload::AuthChallengeSolution(solution) => { - self.handle_auth_challenge_solution(solution).await - } - UserAgentRequestPayload::UnsealStart(unseal_start) => { - self.handle_unseal_request(unseal_start).await - } - UserAgentRequestPayload::UnsealEncryptedKey(unseal_encrypted_key) => { - self.handle_unseal_encrypted_key(unseal_encrypted_key).await - } - _ => Err(UserAgentError::UnexpectedRequestPayload), - } - } - - async fn auth_with_bootstrap_token( - &mut self, - pubkey: ed25519_dalek::VerifyingKey, - token: String, - ) -> Result { - let token_ok: bool = self - .actors - .bootstrapper - .ask(ConsumeToken { token }) - .await - .map_err(|e| { - error!(?pubkey, "Failed to consume bootstrap token: {e}"); - UserAgentError::BootstrapperActorUnreachable - })?; - - if !token_ok { - error!(?pubkey, "Invalid bootstrap token provided"); - return Err(UserAgentError::InvalidBootstrapToken); - } - - { - let mut conn = self.db.get().await.map_err(|e| { - error!(error = ?e, "Database pool error"); - UserAgentError::DatabasePoolUnavailable - })?; - - diesel::insert_into(schema::useragent_client::table) - .values(( - schema::useragent_client::public_key.eq(pubkey.as_bytes().to_vec()), - schema::useragent_client::nonce.eq(1), - )) - .execute(&mut conn) - .await - .map_err(|e| { - error!(error = ?e, "Database error"); - UserAgentError::DatabaseOperationFailed - })?; - } - - self.transition(UserAgentEvents::ReceivedBootstrapToken)?; - - Ok(response(UserAgentResponsePayload::AuthOk(AuthOk {}))) - } - - async fn auth_with_challenge(&mut self, pubkey: VerifyingKey, pubkey_bytes: Vec) -> Output { - let nonce: Option = { - let mut db_conn = self.db.get().await.map_err(|e| { - error!(error = ?e, "Database pool error"); - UserAgentError::DatabasePoolUnavailable - })?; - db_conn - .exclusive_transaction(|conn| { - Box::pin(async move { - let current_nonce = schema::useragent_client::table - .filter( - schema::useragent_client::public_key.eq(pubkey.as_bytes().to_vec()), - ) - .select(schema::useragent_client::nonce) - .first::(conn) - .await?; - - update(schema::useragent_client::table) - .filter( - schema::useragent_client::public_key.eq(pubkey.as_bytes().to_vec()), - ) - .set(schema::useragent_client::nonce.eq(current_nonce + 1)) - .execute(conn) - .await?; - - Result::<_, diesel::result::Error>::Ok(current_nonce) - }) - }) - .await - .optional() - .map_err(|e| { - error!(error = ?e, "Database error"); - UserAgentError::DatabaseOperationFailed - })? - }; - - let Some(nonce) = nonce else { - error!(?pubkey, "Public key not found in database"); - return Err(UserAgentError::PublicKeyNotRegistered); - }; - - let challenge = AuthChallenge { - pubkey: pubkey_bytes, - nonce, - }; - - self.transition(UserAgentEvents::SentChallenge(ChallengeContext { - challenge: challenge.clone(), - key: pubkey, - }))?; - - info!( - ?pubkey, - ?challenge, - "Sent authentication challenge to client" - ); - - Ok(response(UserAgentResponsePayload::AuthChallenge(challenge))) - } - - fn verify_challenge_solution( - &self, - solution: &AuthChallengeSolution, - ) -> Result<(bool, &ChallengeContext), UserAgentError> { - let UserAgentStates::WaitingForChallengeSolution(challenge_context) = self.state.state() - else { - error!("Received challenge solution in invalid state"); - return Err(UserAgentError::InvalidStateForChallengeSolution); - }; - let formatted_challenge = arbiter_proto::format_challenge( - challenge_context.challenge.nonce, - &challenge_context.challenge.pubkey, - ); - - let signature = solution.signature.as_slice().try_into().map_err(|_| { - error!(?solution, "Invalid signature length"); - UserAgentError::InvalidSignatureLength - })?; - - let valid = challenge_context - .key - .verify_strict(&formatted_challenge, &signature) - .is_ok(); - - Ok((valid, challenge_context)) - } -} - -type Output = Result; - -fn response(payload: UserAgentResponsePayload) -> UserAgentResponse { - UserAgentResponse { - payload: Some(payload), - } -} - -impl UserAgentActor { - async fn handle_unseal_request(&mut self, req: UnsealStart) -> Output { - let secret = EphemeralSecret::random(); - let public_key = PublicKey::from(&secret); - - let client_pubkey_bytes: [u8; 32] = req - .client_pubkey - .try_into() - .map_err(|_| UserAgentError::InvalidClientPubkeyLength)?; - - let client_public_key = PublicKey::from(client_pubkey_bytes); - - self.transition(UserAgentEvents::UnsealRequest(UnsealContext { - secret: Mutex::new(Some(secret)), - client_public_key, - }))?; - - Ok(response( - UserAgentResponsePayload::UnsealStartResponse(UnsealStartResponse { - server_pubkey: public_key.as_bytes().to_vec(), - }), - )) - } - - async fn handle_unseal_encrypted_key(&mut self, req: UnsealEncryptedKey) -> Output { - let UserAgentStates::WaitingForUnsealKey(unseal_context) = self.state.state() else { - error!("Received unseal encrypted key in invalid state"); - return Err(UserAgentError::InvalidStateForUnsealEncryptedKey); - }; - let ephemeral_secret = { - let mut secret_lock = unseal_context.secret.lock().unwrap(); - let secret = secret_lock.take(); - match secret { - Some(secret) => secret, - None => { - drop(secret_lock); - error!("Ephemeral secret already taken"); - self.transition(UserAgentEvents::ReceivedInvalidKey)?; - return Ok(response(UserAgentResponsePayload::UnsealResult( - UnsealResult::InvalidKey.into(), - ))); - } - } - }; - - let nonce = XNonce::from_slice(&req.nonce); - - let shared_secret = ephemeral_secret.diffie_hellman(&unseal_context.client_public_key); - let cipher = XChaCha20Poly1305::new(shared_secret.as_bytes().into()); - - let mut seal_key_buffer = MemSafe::new(req.ciphertext.clone()).unwrap(); - - let decryption_result = { - let mut write_handle = seal_key_buffer.write().unwrap(); - let write_handle = write_handle.deref_mut(); - cipher.decrypt_in_place(nonce, &req.associated_data, write_handle) - }; - - match decryption_result { - Ok(_) => { - match self - .actors - .key_holder - .ask(TryUnseal { - seal_key_raw: seal_key_buffer, - }) - .await - { - Ok(_) => { - info!("Successfully unsealed key with client-provided key"); - self.transition(UserAgentEvents::ReceivedValidKey)?; - Ok(response(UserAgentResponsePayload::UnsealResult( - UnsealResult::Success.into(), - ))) - } - Err(SendError::HandlerError(keyholder::Error::InvalidKey)) => { - self.transition(UserAgentEvents::ReceivedInvalidKey)?; - Ok(response(UserAgentResponsePayload::UnsealResult( - UnsealResult::InvalidKey.into(), - ))) - } - Err(SendError::HandlerError(err)) => { - error!(?err, "Keyholder failed to unseal key"); - self.transition(UserAgentEvents::ReceivedInvalidKey)?; - Ok(response(UserAgentResponsePayload::UnsealResult( - UnsealResult::InvalidKey.into(), - ))) - } - Err(err) => { - error!(?err, "Failed to send unseal request to keyholder"); - self.transition(UserAgentEvents::ReceivedInvalidKey)?; - Err(UserAgentError::KeyHolderActorUnreachable) - } - } - } - Err(err) => { - error!(?err, "Failed to decrypt unseal key"); - self.transition(UserAgentEvents::ReceivedInvalidKey)?; - Ok(response(UserAgentResponsePayload::UnsealResult( - UnsealResult::InvalidKey.into(), - ))) - } - } - } - - async fn handle_auth_challenge_request(&mut self, req: AuthChallengeRequest) -> Output { - let pubkey = req - .pubkey - .as_array() - .ok_or(UserAgentError::InvalidAuthPubkeyLength)?; - let pubkey = VerifyingKey::from_bytes(pubkey).map_err(|_err| { - error!(?pubkey, "Failed to convert to VerifyingKey"); - UserAgentError::InvalidAuthPubkeyEncoding - })?; - - self.transition(UserAgentEvents::AuthRequest)?; - - match req.bootstrap_token { - Some(token) => self.auth_with_bootstrap_token(pubkey, token).await, - None => self.auth_with_challenge(pubkey, req.pubkey).await, - } - } - - async fn handle_auth_challenge_solution( - &mut self, - solution: AuthChallengeSolution, - ) -> Output { - let (valid, challenge_context) = self.verify_challenge_solution(&solution)?; - - if valid { - info!( - ?challenge_context, - "Client provided valid solution to authentication challenge" - ); - self.transition(UserAgentEvents::ReceivedGoodSolution)?; - Ok(response(UserAgentResponsePayload::AuthOk(AuthOk {}))) - } else { - error!("Client provided invalid solution to authentication challenge"); - self.transition(UserAgentEvents::ReceivedBadSolution)?; - Err(UserAgentError::InvalidChallengeSolution) - } - } -} - - -impl Actor for UserAgentActor { - type Args = Self; - - type Error = (); - - async fn on_start( - args: Self::Args, - _: kameo::prelude::ActorRef, - ) -> Result { - Ok(args) - } - - async fn next( - &mut self, - _actor_ref: kameo::prelude::WeakActorRef, - mailbox_rx: &mut kameo::prelude::MailboxReceiver, - ) -> Option> { - loop { - select! { - signal = mailbox_rx.recv() => { - return signal; - } - msg = self.transport.recv() => { - match msg { - Some(request) => { - match self.process_transport_inbound(request).await { - Ok(response) => { - if self.transport.send(Ok(response)).await.is_err() { - error!(actor = "useragent", reason = "channel closed", "send.failed"); - return Some(kameo::mailbox::Signal::Stop); - } - } - Err(err) => { - let _ = self.transport.send(Err(err)).await; - return Some(kameo::mailbox::Signal::Stop); - } - } - } - None => { - info!(actor = "useragent", "transport.closed"); - return Some(kameo::mailbox::Signal::Stop); - } - } - } - } - } - } -} - - -impl UserAgentActor { - pub fn new_manual(db: db::DatabasePool, actors: GlobalActors) -> Self { +impl ConnectionProps { + pub fn new(db: db::DatabasePool, actors: GlobalActors, transport: Transport) -> Self { Self { db, actors, - state: UserAgentStateMachine::new(DummyContext), - transport: Box::new(DummyTransport::new()), + transport, } } } + +pub mod session; +pub mod auth; + +pub async fn connect_user_agent(mut props: ConnectionProps) { + match auth::authenticate_and_create( props).await { + Ok(session) => { + UserAgentSession::spawn(session); + info!("User authenticated, session started"); + }, + Err(err) => { + error!(?err, "Authentication failed, closing connection"); + }, + } +} diff --git a/server/crates/arbiter-server/src/actors/user_agent/session.rs b/server/crates/arbiter-server/src/actors/user_agent/session.rs new file mode 100644 index 0000000..6cc0808 --- /dev/null +++ b/server/crates/arbiter-server/src/actors/user_agent/session.rs @@ -0,0 +1,241 @@ +use std::{ops::DerefMut, sync::Mutex}; + +use arbiter_proto::proto::user_agent::{ + UnsealEncryptedKey, UnsealResult, UnsealStart, UnsealStartResponse, UserAgentRequest, + UserAgentResponse, user_agent_request::Payload as UserAgentRequestPayload, + user_agent_response::Payload as UserAgentResponsePayload, +}; +use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit}; +use ed25519_dalek::VerifyingKey; +use kameo::{Actor, error::SendError}; +use memsafe::MemSafe; +use tokio::select; +use tracing::{error, info}; +use x25519_dalek::{EphemeralSecret, PublicKey}; + +use crate::actors::{ + keyholder::{self, TryUnseal}, + user_agent::{ConnectionProps, UserAgentError}, +}; + +mod state; +use state::{DummyContext, UnsealContext, UserAgentEvents, UserAgentStateMachine, UserAgentStates}; + +pub struct UserAgentSession { + props: ConnectionProps, + key: VerifyingKey, + state: UserAgentStateMachine, +} + +impl UserAgentSession { + pub(crate) fn new(props: ConnectionProps, key: VerifyingKey) -> Self { + Self { + props, + key, + state: UserAgentStateMachine::new(DummyContext), + } + } + + fn transition(&mut self, event: UserAgentEvents) -> Result<(), UserAgentError> { + self.state.process_event(event).map_err(|e| { + error!(?e, "State transition failed"); + UserAgentError::StateTransitionFailed + })?; + Ok(()) + } + + pub async fn process_transport_inbound(&mut self, req: UserAgentRequest) -> Output { + let msg = req.payload.ok_or_else(|| { + error!(actor = "useragent", "Received message with no payload"); + UserAgentError::MissingRequestPayload + })?; + + match msg { + UserAgentRequestPayload::UnsealStart(unseal_start) => { + self.handle_unseal_request(unseal_start).await + } + UserAgentRequestPayload::UnsealEncryptedKey(unseal_encrypted_key) => { + self.handle_unseal_encrypted_key(unseal_encrypted_key).await + } + _ => Err(UserAgentError::UnexpectedRequestPayload), + } + } +} + +type Output = Result; + +fn response(payload: UserAgentResponsePayload) -> UserAgentResponse { + UserAgentResponse { + payload: Some(payload), + } +} + +impl UserAgentSession { + async fn handle_unseal_request(&mut self, req: UnsealStart) -> Output { + let secret = EphemeralSecret::random(); + let public_key = PublicKey::from(&secret); + + let client_pubkey_bytes: [u8; 32] = req + .client_pubkey + .try_into() + .map_err(|_| UserAgentError::InvalidClientPubkeyLength)?; + + let client_public_key = PublicKey::from(client_pubkey_bytes); + + self.transition(UserAgentEvents::UnsealRequest(UnsealContext { + secret: Mutex::new(Some(secret)), + client_public_key, + }))?; + + Ok(response(UserAgentResponsePayload::UnsealStartResponse( + UnsealStartResponse { + server_pubkey: public_key.as_bytes().to_vec(), + }, + ))) + } + + async fn handle_unseal_encrypted_key(&mut self, req: UnsealEncryptedKey) -> Output { + let UserAgentStates::WaitingForUnsealKey(unseal_context) = self.state.state() else { + error!("Received unseal encrypted key in invalid state"); + return Err(UserAgentError::InvalidStateForUnsealEncryptedKey); + }; + let ephemeral_secret = { + let mut secret_lock = unseal_context.secret.lock().unwrap(); + let secret = secret_lock.take(); + match secret { + Some(secret) => secret, + None => { + drop(secret_lock); + error!("Ephemeral secret already taken"); + self.transition(UserAgentEvents::ReceivedInvalidKey)?; + return Ok(response(UserAgentResponsePayload::UnsealResult( + UnsealResult::InvalidKey.into(), + ))); + } + } + }; + + let nonce = XNonce::from_slice(&req.nonce); + + let shared_secret = ephemeral_secret.diffie_hellman(&unseal_context.client_public_key); + let cipher = XChaCha20Poly1305::new(shared_secret.as_bytes().into()); + + let mut seal_key_buffer = MemSafe::new(req.ciphertext.clone()).unwrap(); + + let decryption_result = { + let mut write_handle = seal_key_buffer.write().unwrap(); + let write_handle = write_handle.deref_mut(); + cipher.decrypt_in_place(nonce, &req.associated_data, write_handle) + }; + + match decryption_result { + Ok(_) => { + match self + .props + .actors + .key_holder + .ask(TryUnseal { + seal_key_raw: seal_key_buffer, + }) + .await + { + Ok(_) => { + info!("Successfully unsealed key with client-provided key"); + self.transition(UserAgentEvents::ReceivedValidKey)?; + Ok(response(UserAgentResponsePayload::UnsealResult( + UnsealResult::Success.into(), + ))) + } + Err(SendError::HandlerError(keyholder::Error::InvalidKey)) => { + self.transition(UserAgentEvents::ReceivedInvalidKey)?; + Ok(response(UserAgentResponsePayload::UnsealResult( + UnsealResult::InvalidKey.into(), + ))) + } + Err(SendError::HandlerError(err)) => { + error!(?err, "Keyholder failed to unseal key"); + self.transition(UserAgentEvents::ReceivedInvalidKey)?; + Ok(response(UserAgentResponsePayload::UnsealResult( + UnsealResult::InvalidKey.into(), + ))) + } + Err(err) => { + error!(?err, "Failed to send unseal request to keyholder"); + self.transition(UserAgentEvents::ReceivedInvalidKey)?; + Err(UserAgentError::KeyHolderActorUnreachable) + } + } + } + Err(err) => { + error!(?err, "Failed to decrypt unseal key"); + self.transition(UserAgentEvents::ReceivedInvalidKey)?; + Ok(response(UserAgentResponsePayload::UnsealResult( + UnsealResult::InvalidKey.into(), + ))) + } + } + } +} + +impl Actor for UserAgentSession { + type Args = Self; + + type Error = (); + + async fn on_start( + args: Self::Args, + _: kameo::prelude::ActorRef, + ) -> Result { + Ok(args) + } + + async fn next( + &mut self, + _actor_ref: kameo::prelude::WeakActorRef, + mailbox_rx: &mut kameo::prelude::MailboxReceiver, + ) -> Option> { + loop { + select! { + signal = mailbox_rx.recv() => { + return signal; + } + msg = self.props.transport.recv() => { + match msg { + Some(request) => { + match self.process_transport_inbound(request).await { + Ok(response) => { + if self.props.transport.send(Ok(response)).await.is_err() { + error!(actor = "useragent", reason = "channel closed", "send.failed"); + return Some(kameo::mailbox::Signal::Stop); + } + } + Err(err) => { + let _ = self.props.transport.send(Err(err)).await; + return Some(kameo::mailbox::Signal::Stop); + } + } + } + None => { + info!(actor = "useragent", "transport.closed"); + return Some(kameo::mailbox::Signal::Stop); + } + } + } + } + } + } +} + +impl UserAgentSession { + pub fn new_test(db: crate::db::DatabasePool, actors: crate::actors::GlobalActors) -> Self { + use arbiter_proto::transport::DummyTransport; + let transport: super::Transport = Box::new(DummyTransport::new()); + let props = ConnectionProps::new(db, actors, transport); + let key = VerifyingKey::from_bytes(&[0u8; 32]).unwrap(); + Self { + props, + key, + state: UserAgentStateMachine::new(DummyContext), + } + } +} diff --git a/server/crates/arbiter-server/src/actors/user_agent/session/state.rs b/server/crates/arbiter-server/src/actors/user_agent/session/state.rs new file mode 100644 index 0000000..23ab674 --- /dev/null +++ b/server/crates/arbiter-server/src/actors/user_agent/session/state.rs @@ -0,0 +1,27 @@ +use std::sync::Mutex; + +use x25519_dalek::{EphemeralSecret, PublicKey}; + +pub struct UnsealContext { + pub client_public_key: PublicKey, + pub secret: Mutex>, +} + +smlang::statemachine!( + name: UserAgent, + custom_error: false, + transitions: { + *Idle + UnsealRequest(UnsealContext) / generate_temp_keypair = WaitingForUnsealKey(UnsealContext), + WaitingForUnsealKey(UnsealContext) + ReceivedValidKey = Unsealed, + WaitingForUnsealKey(UnsealContext) + ReceivedInvalidKey = Idle, + } +); + +pub struct DummyContext; +impl UserAgentStateMachineContext for DummyContext { + #[allow(missing_docs)] + #[allow(clippy::unused_unit)] + fn generate_temp_keypair(&mut self, event_data: UnsealContext) -> Result { + Ok(event_data) + } +} diff --git a/server/crates/arbiter-server/src/actors/user_agent/state.rs b/server/crates/arbiter-server/src/actors/user_agent/state.rs deleted file mode 100644 index e158a16..0000000 --- a/server/crates/arbiter-server/src/actors/user_agent/state.rs +++ /dev/null @@ -1,51 +0,0 @@ -use std::sync::Mutex; - -use arbiter_proto::proto::user_agent::AuthChallenge; -use ed25519_dalek::VerifyingKey; -use x25519_dalek::{EphemeralSecret, PublicKey}; - -/// Context for state machine with validated key and sent challenge -/// Challenge is then transformed to bytes using shared function and verified -#[derive(Clone, Debug)] -pub struct ChallengeContext { - pub challenge: AuthChallenge, - pub key: VerifyingKey, -} - -pub struct UnsealContext { - pub client_public_key: PublicKey, - pub secret: Mutex>, -} - -smlang::statemachine!( - name: UserAgent, - custom_error: false, - transitions: { - *Init + AuthRequest = ReceivedAuthRequest, - ReceivedAuthRequest + ReceivedBootstrapToken = Idle, - - ReceivedAuthRequest + SentChallenge(ChallengeContext) / move_challenge = WaitingForChallengeSolution(ChallengeContext), - - WaitingForChallengeSolution(ChallengeContext) + ReceivedGoodSolution = Idle, - WaitingForChallengeSolution(ChallengeContext) + ReceivedBadSolution = AuthError, // block further transitions, but connection should close anyway - - Idle + UnsealRequest(UnsealContext) / generate_temp_keypair = WaitingForUnsealKey(UnsealContext), - WaitingForUnsealKey(UnsealContext) + ReceivedValidKey = Unsealed, - WaitingForUnsealKey(UnsealContext) + ReceivedInvalidKey = Idle, - } -); - -pub struct DummyContext; -impl UserAgentStateMachineContext for DummyContext { - #[allow(missing_docs)] - #[allow(clippy::unused_unit)] - fn generate_temp_keypair(&mut self, event_data: UnsealContext) -> Result { - Ok(event_data) - } - - #[allow(missing_docs)] - #[allow(clippy::unused_unit)] - fn move_challenge(&mut self, event_data: ChallengeContext) -> Result { - Ok(event_data) - } -} diff --git a/server/crates/arbiter-server/src/lib.rs b/server/crates/arbiter-server/src/lib.rs index a7b5ebe..1d7fa97 100644 --- a/server/crates/arbiter-server/src/lib.rs +++ b/server/crates/arbiter-server/src/lib.rs @@ -7,7 +7,6 @@ use arbiter_proto::{ transport::{IdentityRecvConverter, SendConverter, grpc}, }; use async_trait::async_trait; -use kameo::actor::Spawn; use tokio_stream::wrappers::ReceiverStream; use tokio::sync::mpsc; @@ -16,8 +15,8 @@ use tracing::info; use crate::{ actors::{ - client::{ClientActor, ClientError}, - user_agent::{UserAgentActor, UserAgentError}, + client::{self, ClientError, ConnectionProps as ClientConnectionProps, connect_client}, + user_agent::{self, ConnectionProps, UserAgentError, connect_user_agent}, }, context::ServerContext, }; @@ -28,11 +27,6 @@ pub mod db; const DEFAULT_CHANNEL_SIZE: usize = 1000; -/// Converts User Agent domain outbounds into the tonic stream item emitted by -/// the server.ยง -/// -/// The conversion is defined at the server boundary so the actor module remains -/// focused on domain semantics and does not depend on tonic status encoding. struct UserAgentGrpcSender; impl SendConverter for UserAgentGrpcSender { @@ -47,11 +41,6 @@ impl SendConverter for UserAgentGrpcSender { } } -/// Converts Client domain outbounds into the tonic stream item emitted by the -/// server. -/// -/// The conversion is defined at the server boundary so the actor module remains -/// focused on domain semantics and does not depend on tonic status encoding. struct ClientGrpcSender; impl SendConverter for ClientGrpcSender { @@ -66,78 +55,71 @@ impl SendConverter for ClientGrpcSender { } } -/// Maps Client domain errors to public gRPC transport errors for the `client` -/// streaming endpoint. fn client_error_status(value: ClientError) -> Status { match value { ClientError::MissingRequestPayload | ClientError::UnexpectedRequestPayload => { Status::invalid_argument("Expected message with payload") } - ClientError::InvalidStateForChallengeSolution => { - Status::invalid_argument("Invalid state for challenge solution") - } - ClientError::InvalidAuthPubkeyLength => { - Status::invalid_argument("Expected pubkey to have specific length") - } - ClientError::InvalidAuthPubkeyEncoding => { - Status::invalid_argument("Failed to convert pubkey to VerifyingKey") - } - ClientError::InvalidSignatureLength => { - Status::invalid_argument("Invalid signature length") - } - ClientError::PublicKeyNotRegistered => { - Status::unauthenticated("Public key not registered") - } - ClientError::InvalidChallengeSolution => { - Status::unauthenticated("Invalid challenge solution") - } ClientError::StateTransitionFailed => Status::internal("State machine error"), - ClientError::DatabasePoolUnavailable => Status::internal("Database pool error"), - ClientError::DatabaseOperationFailed => Status::internal("Database error"), + ClientError::Auth(ref err) => client_auth_error_status(err), + } +} + +fn client_auth_error_status(value: &client::auth::Error) -> Status { + use client::auth::Error; + match value { + Error::UnexpectedMessagePayload | Error::InvalidClientPubkeyLength => { + Status::invalid_argument(value.to_string()) + } + Error::InvalidAuthPubkeyEncoding => { + Status::invalid_argument("Failed to convert pubkey to VerifyingKey") + } + Error::InvalidSignatureLength => Status::invalid_argument("Invalid signature length"), + Error::PublicKeyNotRegistered | Error::InvalidChallengeSolution => { + Status::unauthenticated(value.to_string()) + } + Error::Transport => Status::internal("Transport error"), + Error::DatabasePoolUnavailable => Status::internal("Database pool error"), + Error::DatabaseOperationFailed => Status::internal("Database error"), } } -/// Maps User Agent domain errors to public gRPC transport errors for the -/// `user_agent` streaming endpoint. fn user_agent_error_status(value: UserAgentError) -> Status { match value { UserAgentError::MissingRequestPayload | UserAgentError::UnexpectedRequestPayload => { Status::invalid_argument("Expected message with payload") } - UserAgentError::InvalidStateForChallengeSolution => { - Status::invalid_argument("Invalid state for challenge solution") - } UserAgentError::InvalidStateForUnsealEncryptedKey => { Status::failed_precondition("Invalid state for unseal encrypted key") } UserAgentError::InvalidClientPubkeyLength => { Status::invalid_argument("client_pubkey must be 32 bytes") } - UserAgentError::InvalidAuthPubkeyLength => { - Status::invalid_argument("Expected pubkey to have specific length") + UserAgentError::StateTransitionFailed => Status::internal("State machine error"), + UserAgentError::KeyHolderActorUnreachable => Status::internal("Vault is not available"), + UserAgentError::Auth(ref err) => auth_error_status(err), + } +} + +fn auth_error_status(value: &user_agent::auth::Error) -> Status { + use user_agent::auth::Error; + match value { + Error::UnexpectedMessagePayload | Error::InvalidClientPubkeyLength => { + Status::invalid_argument(value.to_string()) } - UserAgentError::InvalidAuthPubkeyEncoding => { + Error::InvalidAuthPubkeyEncoding => { Status::invalid_argument("Failed to convert pubkey to VerifyingKey") } - UserAgentError::InvalidSignatureLength => { - Status::invalid_argument("Invalid signature length") + Error::PublicKeyNotRegistered | Error::InvalidChallengeSolution => { + Status::unauthenticated(value.to_string()) } - UserAgentError::InvalidBootstrapToken => { - Status::invalid_argument("Invalid bootstrap token") - } - UserAgentError::PublicKeyNotRegistered => { - Status::unauthenticated("Public key not registered") - } - UserAgentError::InvalidChallengeSolution => { - Status::unauthenticated("Invalid challenge solution") - } - UserAgentError::StateTransitionFailed => Status::internal("State machine error"), - UserAgentError::BootstrapperActorUnreachable => { + Error::InvalidBootstrapToken => Status::invalid_argument("Invalid bootstrap token"), + Error::Transport => Status::internal("Transport error"), + Error::BootstrapperActorUnreachable => { Status::internal("Bootstrap token consumption failed") } - UserAgentError::KeyHolderActorUnreachable => Status::internal("Vault is not available"), - UserAgentError::DatabasePoolUnavailable => Status::internal("Database pool error"), - UserAgentError::DatabaseOperationFailed => Status::internal("Database error"), + Error::DatabasePoolUnavailable => Status::internal("Database pool error"), + Error::DatabaseOperationFailed => Status::internal("Database error"), } } @@ -170,7 +152,8 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server { IdentityRecvConverter::::new(), ClientGrpcSender, ); - ClientActor::spawn(ClientActor::new(self.context.clone(), Box::new(transport))); + let props = ClientConnectionProps::new(self.context.db.clone(), Box::new(transport)); + tokio::spawn(connect_client(props)); info!(event = "connection established", "grpc.client"); @@ -191,7 +174,12 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server { IdentityRecvConverter::::new(), UserAgentGrpcSender, ); - UserAgentActor::spawn(UserAgentActor::new(self.context.clone(), Box::new(transport))); + let props = ConnectionProps::new( + self.context.db.clone(), + self.context.actors.clone(), + Box::new(transport), + ); + tokio::spawn(connect_user_agent(props)); info!(event = "connection established", "grpc.user_agent"); diff --git a/server/crates/arbiter-server/tests/client.rs b/server/crates/arbiter-server/tests/client.rs index 2bef733..de40258 100644 --- a/server/crates/arbiter-server/tests/client.rs +++ b/server/crates/arbiter-server/tests/client.rs @@ -1,2 +1,4 @@ +mod common; + #[path = "client/auth.rs"] mod auth; diff --git a/server/crates/arbiter-server/tests/client/auth.rs b/server/crates/arbiter-server/tests/client/auth.rs index 7f84ba4..d7577a6 100644 --- a/server/crates/arbiter-server/tests/client/auth.rs +++ b/server/crates/arbiter-server/tests/client/auth.rs @@ -1,44 +1,44 @@ +use arbiter_proto::transport::Bi; use arbiter_proto::proto::client::{ - AuthChallengeRequest, AuthChallengeSolution, AuthOk, ClientRequest, ClientResponse, + AuthChallengeRequest, AuthChallengeSolution, ClientRequest, client_request::Payload as ClientRequestPayload, client_response::Payload as ClientResponsePayload, }; use arbiter_server::{ - actors::client::{ClientActor, ClientError}, + actors::client::{ConnectionProps, connect_client}, db::{self, schema}, }; use diesel::{ExpressionMethods as _, insert_into}; use diesel_async::RunQueryDsl; use ed25519_dalek::Signer as _; +use super::common::ChannelTransport; + #[tokio::test] #[test_log::test] pub async fn test_unregistered_pubkey_rejected() { let db = db::create_test_pool().await; - let mut client = ClientActor::new_manual(db.clone()); + let (server_transport, mut test_transport) = ChannelTransport::new(); + let props = ConnectionProps::new(db.clone(), Box::new(server_transport)); + let task = tokio::spawn(connect_client(props)); let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); - let result = client - .process_transport_inbound(ClientRequest { + test_transport + .send(ClientRequest { payload: Some(ClientRequestPayload::AuthChallengeRequest( AuthChallengeRequest { pubkey: pubkey_bytes, }, )), }) - .await; + .await + .unwrap(); - match result { - Err(err) => { - assert_eq!(err, ClientError::PublicKeyNotRegistered); - } - Ok(_) => { - panic!("Expected error due to unregistered pubkey, but got success"); - } - } + // Auth fails, connect_client returns, transport drops + task.await.unwrap(); } #[tokio::test] @@ -46,8 +46,6 @@ pub async fn test_unregistered_pubkey_rejected() { pub async fn test_challenge_auth() { let db = db::create_test_pool().await; - let mut client = ClientActor::new_manual(db.clone()); - let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); @@ -60,8 +58,13 @@ pub async fn test_challenge_auth() { .unwrap(); } - let result = client - .process_transport_inbound(ClientRequest { + let (server_transport, mut test_transport) = ChannelTransport::new(); + let props = ConnectionProps::new(db.clone(), Box::new(server_transport)); + let task = tokio::spawn(connect_client(props)); + + // Send challenge request + test_transport + .send(ClientRequest { payload: Some(ClientRequestPayload::AuthChallengeRequest( AuthChallengeRequest { pubkey: pubkey_bytes, @@ -69,34 +72,36 @@ pub async fn test_challenge_auth() { )), }) .await - .expect("Shouldn't fail to process message"); + .unwrap(); - let ClientResponse { - payload: Some(ClientResponsePayload::AuthChallenge(challenge)), - } = result - else { - panic!("Expected auth challenge response, got {result:?}"); + // Read the challenge response + let response = test_transport + .recv() + .await + .expect("should receive challenge"); + let challenge = match response { + Ok(resp) => match resp.payload { + Some(ClientResponsePayload::AuthChallenge(c)) => c, + other => panic!("Expected AuthChallenge, got {other:?}"), + }, + Err(err) => panic!("Expected Ok response, got Err({err:?})"), }; + // Sign the challenge and send solution let formatted_challenge = arbiter_proto::format_challenge(challenge.nonce, &challenge.pubkey); let signature = new_key.sign(&formatted_challenge); - let serialized_signature = signature.to_bytes().to_vec(); - let result = client - .process_transport_inbound(ClientRequest { + test_transport + .send(ClientRequest { payload: Some(ClientRequestPayload::AuthChallengeSolution( AuthChallengeSolution { - signature: serialized_signature, + signature: signature.to_bytes().to_vec(), }, )), }) .await - .expect("Shouldn't fail to process message"); + .unwrap(); - assert_eq!( - result, - ClientResponse { - payload: Some(ClientResponsePayload::AuthOk(AuthOk {})), - } - ); + // Auth completes, session spawned + task.await.unwrap(); } diff --git a/server/crates/arbiter-server/tests/common/mod.rs b/server/crates/arbiter-server/tests/common/mod.rs index ce01412..e23360f 100644 --- a/server/crates/arbiter-server/tests/common/mod.rs +++ b/server/crates/arbiter-server/tests/common/mod.rs @@ -1,10 +1,14 @@ +use arbiter_proto::transport::{Bi, Error}; use arbiter_server::{ actors::keyholder::KeyHolder, db::{self, schema}, }; +use async_trait::async_trait; use diesel::QueryDsl; use diesel_async::RunQueryDsl; use memsafe::MemSafe; +use tokio::sync::mpsc; + #[allow(dead_code)] pub async fn bootstrapped_keyholder(db: &db::DatabasePool) -> KeyHolder { @@ -26,3 +30,46 @@ pub async fn root_key_history_id(db: &db::DatabasePool) -> i32 { .unwrap(); id.expect("root_key_id should be set after bootstrap") } + + +pub struct ChannelTransport { + receiver: mpsc::Receiver, + sender: mpsc::Sender, +} + +impl ChannelTransport { + pub fn new() -> (Self, ChannelTransport) { + let (tx1, rx1) = mpsc::channel(10); + let (tx2, rx2) = mpsc::channel(10); + ( + Self { + receiver: rx1, + sender: tx2, + }, + ChannelTransport { + receiver: rx2, + sender: tx1, + }, + ) + } +} + + + +#[async_trait] +impl Bi for ChannelTransport +where + T: Send + 'static, + Y: Send + 'static, +{ + async fn send(&mut self, item: Y) -> Result<(), Error> { + self.sender + .send(item) + .await + .map_err(|_| Error::ChannelClosed) + } + + async fn recv(&mut self) -> Option { + self.receiver.recv().await + } +} diff --git a/server/crates/arbiter-server/tests/user_agent/auth.rs b/server/crates/arbiter-server/tests/user_agent/auth.rs index 93b8388..2704ae6 100644 --- a/server/crates/arbiter-server/tests/user_agent/auth.rs +++ b/server/crates/arbiter-server/tests/user_agent/auth.rs @@ -1,13 +1,14 @@ use arbiter_proto::proto::user_agent::{ - AuthChallengeRequest, AuthChallengeSolution, AuthOk, UserAgentRequest, UserAgentResponse, + AuthChallengeRequest, AuthChallengeSolution, UserAgentRequest, user_agent_request::Payload as UserAgentRequestPayload, user_agent_response::Payload as UserAgentResponsePayload, }; +use arbiter_proto::transport::Bi; use arbiter_server::{ actors::{ GlobalActors, bootstrap::GetToken, - user_agent::{UserAgentActor, UserAgentError}, + user_agent::{ConnectionProps, connect_user_agent}, }, db::{self, schema}, }; @@ -15,20 +16,24 @@ use diesel::{ExpressionMethods as _, QueryDsl, insert_into}; use diesel_async::RunQueryDsl; use ed25519_dalek::Signer as _; +use super::common::ChannelTransport; + #[tokio::test] #[test_log::test] pub async fn test_bootstrap_token_auth() { let db = db::create_test_pool().await; - let actors = GlobalActors::spawn(db.clone()).await.unwrap(); let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap(); - let mut user_agent = UserAgentActor::new_manual(db.clone(), actors); + + let (server_transport, mut test_transport) = ChannelTransport::new(); + let props = ConnectionProps::new(db.clone(), actors, Box::new(server_transport)); + let task = tokio::spawn(connect_user_agent(props)); let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); - let result = user_agent - .process_transport_inbound(UserAgentRequest { + test_transport + .send(UserAgentRequest { payload: Some(UserAgentRequestPayload::AuthChallengeRequest( AuthChallengeRequest { pubkey: pubkey_bytes, @@ -37,14 +42,9 @@ pub async fn test_bootstrap_token_auth() { )), }) .await - .expect("Shouldn't fail to process message"); + .unwrap(); - assert_eq!( - result, - UserAgentResponse { - payload: Some(UserAgentResponsePayload::AuthOk(AuthOk {})), - } - ); + task.await.unwrap(); let mut conn = db.get().await.unwrap(); let stored_pubkey: Vec = schema::useragent_client::table @@ -59,15 +59,17 @@ pub async fn test_bootstrap_token_auth() { #[test_log::test] pub async fn test_bootstrap_invalid_token_auth() { let db = db::create_test_pool().await; - let actors = GlobalActors::spawn(db.clone()).await.unwrap(); - let mut user_agent = UserAgentActor::new_manual(db.clone(), actors); + + let (server_transport, mut test_transport) = ChannelTransport::new(); + let props = ConnectionProps::new(db.clone(), actors, Box::new(server_transport)); + let task = tokio::spawn(connect_user_agent(props)); let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); - let result = user_agent - .process_transport_inbound(UserAgentRequest { + test_transport + .send(UserAgentRequest { payload: Some(UserAgentRequestPayload::AuthChallengeRequest( AuthChallengeRequest { pubkey: pubkey_bytes, @@ -75,25 +77,27 @@ pub async fn test_bootstrap_invalid_token_auth() { }, )), }) - .await; + .await + .unwrap(); - match result { - Err(err) => { - assert_eq!(err, UserAgentError::InvalidBootstrapToken); - } - Ok(_) => { - panic!("Expected error due to invalid bootstrap token, but got success"); - } - } + // Auth fails, connect_user_agent returns, transport drops + task.await.unwrap(); + + // Verify no key was registered + let mut conn = db.get().await.unwrap(); + let count: i64 = schema::useragent_client::table + .count() + .get_result::(&mut conn) + .await + .unwrap(); + assert_eq!(count, 0); } #[tokio::test] #[test_log::test] pub async fn test_challenge_auth() { let db = db::create_test_pool().await; - let actors = GlobalActors::spawn(db.clone()).await.unwrap(); - let mut user_agent = UserAgentActor::new_manual(db.clone(), actors); let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); @@ -107,8 +111,13 @@ pub async fn test_challenge_auth() { .unwrap(); } - let result = user_agent - .process_transport_inbound(UserAgentRequest { + let (server_transport, mut test_transport) = ChannelTransport::new(); + let props = ConnectionProps::new(db.clone(), actors, Box::new(server_transport)); + let task = tokio::spawn(connect_user_agent(props)); + + // Send challenge request + test_transport + .send(UserAgentRequest { payload: Some(UserAgentRequestPayload::AuthChallengeRequest( AuthChallengeRequest { pubkey: pubkey_bytes, @@ -117,34 +126,36 @@ pub async fn test_challenge_auth() { )), }) .await - .expect("Shouldn't fail to process message"); + .unwrap(); - let UserAgentResponse { - payload: Some(UserAgentResponsePayload::AuthChallenge(challenge)), - } = result - else { - panic!("Expected auth challenge response, got {result:?}"); + // Read the challenge response + let response = test_transport + .recv() + .await + .expect("should receive challenge"); + let challenge = match response { + Ok(resp) => match resp.payload { + Some(UserAgentResponsePayload::AuthChallenge(c)) => c, + other => panic!("Expected AuthChallenge, got {other:?}"), + }, + Err(err) => panic!("Expected Ok response, got Err({err:?})"), }; + // Sign the challenge and send solution let formatted_challenge = arbiter_proto::format_challenge(challenge.nonce, &challenge.pubkey); let signature = new_key.sign(&formatted_challenge); - let serialized_signature = signature.to_bytes().to_vec(); - let result = user_agent - .process_transport_inbound(UserAgentRequest { + test_transport + .send(UserAgentRequest { payload: Some(UserAgentRequestPayload::AuthChallengeSolution( AuthChallengeSolution { - signature: serialized_signature, + signature: signature.to_bytes().to_vec(), }, )), }) .await - .expect("Shouldn't fail to process message"); + .unwrap(); - assert_eq!( - result, - UserAgentResponse { - payload: Some(UserAgentResponsePayload::AuthOk(AuthOk {})), - } - ); + // Auth completes, session spawned + task.await.unwrap(); } diff --git a/server/crates/arbiter-server/tests/user_agent/unseal.rs b/server/crates/arbiter-server/tests/user_agent/unseal.rs index b0f5d1c..4e30ff4 100644 --- a/server/crates/arbiter-server/tests/user_agent/unseal.rs +++ b/server/crates/arbiter-server/tests/user_agent/unseal.rs @@ -1,15 +1,13 @@ use arbiter_proto::proto::user_agent::{ - AuthChallengeRequest, UnsealEncryptedKey, UnsealResult, UnsealStart, - UserAgentRequest, + UnsealEncryptedKey, UnsealResult, UnsealStart, UserAgentRequest, user_agent_request::Payload as UserAgentRequestPayload, user_agent_response::Payload as UserAgentResponsePayload, }; use arbiter_server::{ actors::{ GlobalActors, - bootstrap::GetToken, keyholder::{Bootstrap, Seal}, - user_agent::{UserAgentActor, UserAgentError}, + user_agent::session::UserAgentSession, }, db, }; @@ -17,16 +15,12 @@ use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit}; use memsafe::MemSafe; use x25519_dalek::{EphemeralSecret, PublicKey}; - -async fn setup_authenticated_user_agent( +async fn setup_sealed_user_agent( seal_key: &[u8], -) -> ( - arbiter_server::db::DatabasePool, - UserAgentActor, -) { +) -> (db::DatabasePool, UserAgentSession) { let db = db::create_test_pool().await; - let actors = GlobalActors::spawn(db.clone()).await.unwrap(); + actors .key_holder .ask(Bootstrap { @@ -36,27 +30,13 @@ async fn setup_authenticated_user_agent( .unwrap(); actors.key_holder.ask(Seal).await.unwrap(); - let mut user_agent = UserAgentActor::new_manual(db.clone(), actors.clone()); + let session = UserAgentSession::new_test(db.clone(), actors); - let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap(); - let auth_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); - user_agent - .process_transport_inbound(UserAgentRequest { - payload: Some(UserAgentRequestPayload::AuthChallengeRequest( - AuthChallengeRequest { - pubkey: auth_key.verifying_key().to_bytes().to_vec(), - bootstrap_token: Some(token), - }, - )), - }) - .await - .unwrap(); - - (db, user_agent) + (db, session) } async fn client_dh_encrypt( - user_agent: &mut UserAgentActor, + user_agent: &mut UserAgentSession, key_to_send: &[u8], ) -> UnsealEncryptedKey { let client_secret = EphemeralSecret::random(); @@ -103,7 +83,7 @@ fn unseal_key_request(req: UnsealEncryptedKey) -> UserAgentRequest { #[test_log::test] pub async fn test_unseal_success() { let seal_key = b"test-seal-key"; - let (_db, mut user_agent) = setup_authenticated_user_agent(seal_key).await; + let (_db, mut user_agent) = setup_sealed_user_agent(seal_key).await; let encrypted_key = client_dh_encrypt(&mut user_agent, seal_key).await; @@ -121,7 +101,7 @@ pub async fn test_unseal_success() { #[tokio::test] #[test_log::test] pub async fn test_unseal_wrong_seal_key() { - let (_db, mut user_agent) = setup_authenticated_user_agent(b"correct-key").await; + let (_db, mut user_agent) = setup_sealed_user_agent(b"correct-key").await; let encrypted_key = client_dh_encrypt(&mut user_agent, b"wrong-key").await; @@ -139,7 +119,7 @@ pub async fn test_unseal_wrong_seal_key() { #[tokio::test] #[test_log::test] pub async fn test_unseal_corrupted_ciphertext() { - let (_db, mut user_agent) = setup_authenticated_user_agent(b"test-key").await; + let (_db, mut user_agent) = setup_sealed_user_agent(b"test-key").await; let client_secret = EphemeralSecret::random(); let client_public = PublicKey::from(&client_secret); @@ -168,38 +148,11 @@ pub async fn test_unseal_corrupted_ciphertext() { ); } -#[tokio::test] -#[test_log::test] -pub async fn test_unseal_start_without_auth_fails() { - let db = db::create_test_pool().await; - - let actors = GlobalActors::spawn(db.clone()).await.unwrap(); - let mut user_agent = UserAgentActor::new_manual(db.clone(), actors); - - let client_secret = EphemeralSecret::random(); - let client_public = PublicKey::from(&client_secret); - - let result = user_agent - .process_transport_inbound(UserAgentRequest { - payload: Some(UserAgentRequestPayload::UnsealStart(UnsealStart { - client_pubkey: client_public.as_bytes().to_vec(), - })), - }) - .await; - - match result { - Err(err) => { - assert_eq!(err, UserAgentError::StateTransitionFailed); - } - other => panic!("Expected state machine error, got {other:?}"), - } -} - #[tokio::test] #[test_log::test] pub async fn test_unseal_retry_after_invalid_key() { let seal_key = b"real-seal-key"; - let (_db, mut user_agent) = setup_authenticated_user_agent(seal_key).await; + let (_db, mut user_agent) = setup_sealed_user_agent(seal_key).await; { let encrypted_key = client_dh_encrypt(&mut user_agent, b"wrong-key").await;