8 Commits

Author SHA1 Message Date
hdbg
ba86d18250 refactor(server::client::auth): removed state machine and added approval flow coordination
Some checks failed
ci/woodpecker/pr/server-audit Pipeline was successful
ci/woodpecker/pr/server-vet Pipeline failed
ci/woodpecker/pr/server-lint Pipeline failed
ci/woodpecker/pr/server-test Pipeline was successful
2026-03-12 16:12:19 +01:00
hdbg
606a1f3774 feat(server::{router, useragent}): inter-actor approval coordination 2026-03-11 20:07:06 +01:00
hdbg
b3a67ffc00 feat(server::client): proper connect error 2026-03-11 17:58:44 +01:00
hdbg
168290040c feat(server::client): approval flow through user-agent on first-time client connects 2026-03-11 16:31:58 +01:00
hdbg
cb05407bb6 feat(server): broker agent for inter-actor coordination
Some checks failed
ci/woodpecker/push/server-vet Pipeline failed
ci/woodpecker/push/server-audit Pipeline was successful
ci/woodpecker/push/server-lint Pipeline failed
ci/woodpecker/push/server-test Pipeline was successful
2026-03-11 14:08:15 +01:00
4beb34764d Merge pull request 'refactor(server::{user_agent, client}): move auth part to separate function to not to pollute actor session with one-time concerns' (#24) from push-upvpzwvlwyvs into main
Reviewed-on: #24
2026-03-11 14:08:15 +01:00
hdbg
54d0fe0505 refactor(server::{user_agent, client}): move auth part to separate function to not to pollute actor session with one-time concerns 2026-03-11 14:08:15 +01:00
hdbg
657f47e32f refactor(transport): convert Bi trait to use async_trait 2026-03-11 14:08:15 +01:00
27 changed files with 1650 additions and 1029 deletions

View File

@@ -4,6 +4,52 @@ This document covers concrete technology choices and dependencies. For the archi
--- ---
## Client Connection Flow
### New Client Approval
When a client whose public key is not yet in the database connects, all connected user agents are asked to approve the connection. The first agent to respond determines the outcome; remaining requests are cancelled via a watch channel.
```mermaid
flowchart TD
A([Client connects]) --> B[Receive AuthChallengeRequest]
B --> C{pubkey in DB?}
C -- yes --> D[Read nonce\nIncrement nonce in DB]
D --> G
C -- no --> E[Ask all UserAgents:\nClientConnectionRequest]
E --> F{First response}
F -- denied --> Z([Reject connection])
F -- approved --> F2[Cancel remaining\nUserAgent requests]
F2 --> F3[INSERT client\nnonce = 1]
F3 --> G[Send AuthChallenge\nwith nonce]
G --> H[Receive AuthChallengeSolution]
H --> I{Signature valid?}
I -- no --> Z
I -- yes --> J([Session started])
```
### Known Issue: Concurrent Registration Race (TOCTOU)
Two connections presenting the same previously-unknown public key can race through the approval flow simultaneously:
1. Both check the DB → neither is registered.
2. Both request approval from user agents → both receive approval.
3. Both `INSERT` the client record → the second insert silently overwrites the first, resetting the nonce.
This means the first connection's nonce is invalidated by the second, causing its challenge verification to fail. A fix requires either serialising new-client registration (e.g. an in-memory lock keyed on pubkey) or replacing the separate check + insert with an `INSERT OR IGNORE` / upsert guarded by a unique constraint on `public_key`.
### Nonce Semantics
The `program_client.nonce` column stores the **next usable nonce** — i.e. it is always one ahead of the nonce last issued in a challenge.
- **New client:** inserted with `nonce = 1`; the first challenge is issued with `nonce = 0`.
- **Existing client:** the current DB value is read and used as the challenge nonce, then immediately incremented within the same exclusive transaction, preventing replay.
---
## Cryptography ## Cryptography
### Authentication ### Authentication

View File

@@ -24,9 +24,19 @@ message ClientRequest {
} }
} }
message ClientConnectError {
enum Code {
UNKNOWN = 0;
APPROVAL_DENIED = 1;
NO_USER_AGENTS_ONLINE = 2;
}
Code code = 1;
}
message ClientResponse { message ClientResponse {
oneof payload { oneof payload {
AuthChallenge auth_challenge = 1; AuthChallenge auth_challenge = 1;
AuthOk auth_ok = 2; AuthOk auth_ok = 2;
ClientConnectError client_connect_error = 5;
} }
} }

View File

@@ -48,6 +48,16 @@ enum VaultState {
VAULT_STATE_ERROR = 4; VAULT_STATE_ERROR = 4;
} }
message ClientConnectionRequest {
bytes pubkey = 1;
}
message ClientConnectionResponse {
bool approved = 1;
}
message ClientConnectionCancel {}
message UserAgentRequest { message UserAgentRequest {
oneof payload { oneof payload {
AuthChallengeRequest auth_challenge_request = 1; AuthChallengeRequest auth_challenge_request = 1;
@@ -55,6 +65,7 @@ message UserAgentRequest {
UnsealStart unseal_start = 3; UnsealStart unseal_start = 3;
UnsealEncryptedKey unseal_encrypted_key = 4; UnsealEncryptedKey unseal_encrypted_key = 4;
google.protobuf.Empty query_vault_state = 5; google.protobuf.Empty query_vault_state = 5;
ClientConnectionResponse client_connection_response = 11;
} }
} }
message UserAgentResponse { message UserAgentResponse {
@@ -64,5 +75,7 @@ message UserAgentResponse {
UnsealStartResponse unseal_start_response = 3; UnsealStartResponse unseal_start_response = 3;
UnsealResult unseal_result = 4; UnsealResult unseal_result = 4;
VaultState vault_state = 5; VaultState vault_state = 5;
ClientConnectionRequest client_connection_request = 11;
ClientConnectionCancel client_connection_cancel = 12;
} }
} }

2
server/Cargo.lock generated
View File

@@ -59,6 +59,7 @@ version = "0.1.0"
name = "arbiter-proto" name = "arbiter-proto"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"async-trait",
"base64", "base64",
"futures", "futures",
"hex", "hex",
@@ -122,6 +123,7 @@ name = "arbiter-useragent"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"arbiter-proto", "arbiter-proto",
"async-trait",
"ed25519-dalek", "ed25519-dalek",
"http", "http",
"kameo", "kameo",

View File

@@ -20,6 +20,7 @@ rustls-pki-types.workspace = true
base64 = "0.22.1" base64 = "0.22.1"
prost-types.workspace = true prost-types.workspace = true
tracing.workspace = true tracing.workspace = true
async-trait.workspace = true
[build-dependencies] [build-dependencies]
tonic-prost-build = "0.14.3" tonic-prost-build = "0.14.3"

View File

@@ -76,10 +76,30 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use async_trait::async_trait;
/// Errors returned by transport adapters implementing [`Bi`]. /// Errors returned by transport adapters implementing [`Bi`].
#[derive(thiserror::Error, Debug)]
pub enum Error { pub enum Error {
/// The outbound side of the transport is no longer accepting messages. #[error("Transport channel is closed")]
ChannelClosed, ChannelClosed,
#[error("Unexpected message received")]
UnexpectedMessage,
}
/// Receives one message from `transport` and extracts a value from it using
/// `extractor`. Returns [`Error::ChannelClosed`] if the transport closes and
/// [`Error::UnexpectedMessage`] if `extractor` returns `None`.
pub async fn expect_message<T, Inbound, Outbound, Target, F>(
transport: &mut T,
extractor: F,
) -> Result<Target, Error>
where
T: Bi<Inbound, Outbound> + ?Sized,
F: FnOnce(Inbound) -> Option<Target>,
{
let msg = transport.recv().await.ok_or(Error::ChannelClosed)?;
extractor(msg).ok_or(Error::UnexpectedMessage)
} }
/// Minimal bidirectional transport abstraction used by protocol code. /// Minimal bidirectional transport abstraction used by protocol code.
@@ -87,13 +107,11 @@ pub enum Error {
/// `Bi<Inbound, Outbound>` models a duplex channel with: /// `Bi<Inbound, Outbound>` models a duplex channel with:
/// - inbound items of type `Inbound` read via [`Bi::recv`] /// - inbound items of type `Inbound` read via [`Bi::recv`]
/// - outbound items of type `Outbound` written via [`Bi::send`] /// - outbound items of type `Outbound` written via [`Bi::send`]
#[async_trait]
pub trait Bi<Inbound, Outbound>: Send + Sync + 'static { pub trait Bi<Inbound, Outbound>: Send + Sync + 'static {
fn send( async fn send(&mut self, item: Outbound) -> Result<(), Error>;
&mut self,
item: Outbound,
) -> impl std::future::Future<Output = Result<(), Error>> + Send;
fn recv(&mut self) -> impl std::future::Future<Output = Option<Inbound>> + Send; async fn recv(&mut self) -> Option<Inbound>;
} }
/// Converts transport-facing inbound items into protocol-facing inbound items. /// Converts transport-facing inbound items into protocol-facing inbound items.
@@ -176,6 +194,7 @@ where
/// gRPC-specific transport adapters and helpers. /// gRPC-specific transport adapters and helpers.
pub mod grpc { pub mod grpc {
use async_trait::async_trait;
use futures::StreamExt; use futures::StreamExt;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tonic::Streaming; use tonic::Streaming;
@@ -199,7 +218,6 @@ pub mod grpc {
outbound_converter: OutboundConverter, outbound_converter: OutboundConverter,
} }
impl<InboundTransport, Inbound, InboundConverter, OutboundConverter> impl<InboundTransport, Inbound, InboundConverter, OutboundConverter>
GrpcAdapter<InboundConverter, OutboundConverter> GrpcAdapter<InboundConverter, OutboundConverter>
where where
@@ -221,7 +239,7 @@ pub mod grpc {
} }
} }
#[async_trait]
impl<InboundConverter, OutboundConverter> Bi<InboundConverter::Output, OutboundConverter::Input> impl<InboundConverter, OutboundConverter> Bi<InboundConverter::Output, OutboundConverter::Input>
for GrpcAdapter<InboundConverter, OutboundConverter> for GrpcAdapter<InboundConverter, OutboundConverter>
where where
@@ -275,6 +293,7 @@ impl<Inbound, Outbound> Default for DummyTransport<Inbound, Outbound> {
} }
} }
#[async_trait]
impl<Inbound, Outbound> Bi<Inbound, Outbound> for DummyTransport<Inbound, Outbound> impl<Inbound, Outbound> Bi<Inbound, Outbound> for DummyTransport<Inbound, Outbound>
where where
Inbound: Send + Sync + 'static, Inbound: Send + Sync + 'static,
@@ -284,10 +303,8 @@ where
Ok(()) Ok(())
} }
fn recv(&mut self) -> impl std::future::Future<Output = Option<Inbound>> + Send { async fn recv(&mut self) -> Option<Inbound> {
async {
std::future::pending::<()>().await; std::future::pending::<()>().await;
None None
} }
} }
}

View File

@@ -0,0 +1,251 @@
use arbiter_proto::{
format_challenge,
proto::client::{
AuthChallenge, AuthChallengeSolution, ClientConnectError, ClientRequest, ClientResponse,
client_connect_error::Code as ConnectErrorCode,
client_request::Payload as ClientRequestPayload,
client_response::Payload as ClientResponsePayload,
},
transport::expect_message,
};
use diesel::{
ExpressionMethods as _, OptionalExtension as _, QueryDsl as _, dsl::insert_into, update,
};
use diesel_async::RunQueryDsl as _;
use ed25519_dalek::VerifyingKey;
use kameo::error::SendError;
use tracing::error;
use crate::{
actors::{client::ClientConnection, router::{self, RequestClientApproval}},
db::{self, schema::program_client},
};
use super::session::ClientSession;
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
pub enum Error {
#[error("Unexpected message payload")]
UnexpectedMessagePayload,
#[error("Invalid client public key length")]
InvalidClientPubkeyLength,
#[error("Invalid client public key encoding")]
InvalidAuthPubkeyEncoding,
#[error("Database pool unavailable")]
DatabasePoolUnavailable,
#[error("Database operation failed")]
DatabaseOperationFailed,
#[error("Invalid challenge solution")]
InvalidChallengeSolution,
#[error("Client approval request failed")]
ApproveError(#[from] ApproveError),
#[error("Internal error")]
InternalError,
#[error("Transport error")]
Transport,
}
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
pub enum ApproveError {
#[error("Internal error")]
Internal,
#[error("Client connection denied by user agents")]
Denied,
#[error("Upstream error: {0}")]
Upstream(router::ApprovalError),
}
/// Atomically reads and increments the nonce for a known client.
/// Returns `None` if the pubkey is not registered.
async fn get_nonce(db: &db::DatabasePool, pubkey: &VerifyingKey) -> Result<Option<i32>, Error> {
let pubkey_bytes = pubkey.as_bytes().to_vec();
let mut conn = db.get().await.map_err(|e| {
error!(error = ?e, "Database pool error");
Error::DatabasePoolUnavailable
})?;
conn.exclusive_transaction(|conn| {
let pubkey_bytes = pubkey_bytes.clone();
Box::pin(async move {
let Some(current_nonce) = program_client::table
.filter(program_client::public_key.eq(&pubkey_bytes))
.select(program_client::nonce)
.first::<i32>(conn)
.await
.optional()?
else {
return Result::<_, diesel::result::Error>::Ok(None);
};
update(program_client::table)
.filter(program_client::public_key.eq(&pubkey_bytes))
.set(program_client::nonce.eq(current_nonce + 1))
.execute(conn)
.await?;
Ok(Some(current_nonce))
})
})
.await
.map_err(|e| {
error!(error = ?e, "Database error");
Error::DatabaseOperationFailed
})
}
async fn approve_new_client(
actors: &crate::actors::GlobalActors,
pubkey: VerifyingKey,
) -> Result<(), Error> {
let result = actors
.router
.ask(RequestClientApproval { client_pubkey: pubkey })
.await;
match result {
Ok(true) => Ok(()),
Ok(false) => Err(Error::ApproveError(ApproveError::Denied)),
Err(SendError::HandlerError(e)) => {
error!(error = ?e, "Approval upstream error");
Err(Error::ApproveError(ApproveError::Upstream(e)))
}
Err(e) => {
error!(error = ?e, "Approval request to router failed");
Err(Error::ApproveError(ApproveError::Internal))
}
}
}
async fn insert_client(db: &db::DatabasePool, pubkey: &VerifyingKey) -> Result<(), Error> {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as i32;
let mut conn = db.get().await.map_err(|e| {
error!(error = ?e, "Database pool error");
Error::DatabasePoolUnavailable
})?;
insert_into(program_client::table)
.values((
program_client::public_key.eq(pubkey.as_bytes().to_vec()),
program_client::nonce.eq(1), // pre-incremented; challenge uses 0
program_client::created_at.eq(now),
program_client::updated_at.eq(now),
))
.execute(&mut conn)
.await
.map_err(|e| {
error!(error = ?e, "Failed to insert new client");
Error::DatabaseOperationFailed
})?;
Ok(())
}
async fn challenge_client(
props: &mut ClientConnection,
pubkey: VerifyingKey,
nonce: i32,
) -> Result<(), Error> {
let challenge = AuthChallenge {
pubkey: pubkey.as_bytes().to_vec(),
nonce,
};
props
.transport
.send(Ok(ClientResponse {
payload: Some(ClientResponsePayload::AuthChallenge(challenge.clone())),
}))
.await
.map_err(|e| {
error!(error = ?e, "Failed to send auth challenge");
Error::Transport
})?;
let AuthChallengeSolution { signature } = expect_message(
&mut *props.transport,
|req: ClientRequest| match req.payload? {
ClientRequestPayload::AuthChallengeSolution(s) => Some(s),
_ => None,
},
)
.await
.map_err(|e| {
error!(error = ?e, "Failed to receive challenge solution");
Error::Transport
})?;
let formatted = format_challenge(nonce, &challenge.pubkey);
let sig = signature.as_slice().try_into().map_err(|_| {
error!("Invalid signature length");
Error::InvalidChallengeSolution
})?;
pubkey.verify_strict(&formatted, &sig).map_err(|_| {
error!("Challenge solution verification failed");
Error::InvalidChallengeSolution
})?;
Ok(())
}
fn connect_error_code(err: &Error) -> ConnectErrorCode {
match err {
Error::ApproveError(ApproveError::Denied) => ConnectErrorCode::ApprovalDenied,
Error::ApproveError(ApproveError::Upstream(router::ApprovalError::NoUserAgentsConnected)) => {
ConnectErrorCode::NoUserAgentsOnline
}
_ => ConnectErrorCode::Unknown,
}
}
async fn authenticate(props: &mut ClientConnection) -> Result<VerifyingKey, Error> {
let Some(ClientRequest {
payload: Some(ClientRequestPayload::AuthChallengeRequest(challenge)),
}) = props.transport.recv().await
else {
return Err(Error::Transport);
};
let pubkey_bytes = challenge
.pubkey
.as_array()
.ok_or(Error::InvalidClientPubkeyLength)?;
let pubkey =
VerifyingKey::from_bytes(pubkey_bytes).map_err(|_| Error::InvalidAuthPubkeyEncoding)?;
let nonce = match get_nonce(&props.db, &pubkey).await? {
Some(nonce) => nonce,
None => {
approve_new_client(&props.actors, pubkey).await?;
insert_client(&props.db, &pubkey).await?;
0
}
};
challenge_client(props, pubkey, nonce).await?;
Ok(pubkey)
}
pub async fn authenticate_and_create(mut props: ClientConnection) -> Result<ClientSession, Error> {
match authenticate(&mut props).await {
Ok(pubkey) => Ok(ClientSession::new(props, pubkey)),
Err(err) => {
let code = connect_error_code(&err);
let _ = props
.transport
.send(Ok(ClientResponse {
payload: Some(ClientResponsePayload::ClientConnectError(
ClientConnectError { code: code.into() },
)),
}))
.await;
Err(err)
}
}
}

View File

@@ -1,289 +1,58 @@
use arbiter_proto::{ use arbiter_proto::{
proto::client::{ proto::client::{ClientRequest, ClientResponse},
AuthChallenge, AuthChallengeRequest, AuthChallengeSolution, AuthOk, ClientRequest, transport::Bi,
ClientResponse,
client_request::Payload as ClientRequestPayload,
client_response::Payload as ClientResponsePayload,
},
transport::{Bi, DummyTransport},
}; };
use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, dsl::update}; use kameo::actor::Spawn;
use diesel_async::RunQueryDsl;
use ed25519_dalek::VerifyingKey;
use kameo::Actor;
use tokio::select;
use tracing::{error, info}; use tracing::{error, info};
use crate::{ use crate::{
ServerContext, actors::{GlobalActors, client::session::ClientSession},
actors::client::state::{ db,
ChallengeContext, ClientEvents, ClientStateMachine, ClientStates, DummyContext,
},
db::{self, schema},
}; };
mod state;
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)] #[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum ClientError { pub enum ClientError {
#[error("Expected message with payload")] #[error("Expected message with payload")]
MissingRequestPayload, MissingRequestPayload,
#[error("Unexpected request payload")] #[error("Unexpected request payload")]
UnexpectedRequestPayload, UnexpectedRequestPayload,
#[error("Invalid state for challenge solution")]
InvalidStateForChallengeSolution,
#[error("Expected pubkey to have specific length")]
InvalidAuthPubkeyLength,
#[error("Failed to convert pubkey to VerifyingKey")]
InvalidAuthPubkeyEncoding,
#[error("Invalid signature length")]
InvalidSignatureLength,
#[error("Public key not registered")]
PublicKeyNotRegistered,
#[error("Invalid challenge solution")]
InvalidChallengeSolution,
#[error("State machine error")] #[error("State machine error")]
StateTransitionFailed, StateTransitionFailed,
#[error("Database pool error")] #[error("Connection registration failed")]
DatabasePoolUnavailable, ConnectionRegistrationFailed,
#[error("Database error")] #[error(transparent)]
DatabaseOperationFailed, Auth(#[from] auth::Error),
} }
pub struct ClientActor<Transport> pub type Transport = Box<dyn Bi<ClientRequest, Result<ClientResponse, ClientError>> + Send>;
where
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>, pub struct ClientConnection {
{ pub(crate) db: db::DatabasePool,
db: db::DatabasePool, pub(crate) transport: Transport,
state: ClientStateMachine<DummyContext>, pub(crate) actors: GlobalActors,
transport: Transport,
} }
impl<Transport> ClientActor<Transport> impl ClientConnection {
where pub fn new(db: db::DatabasePool, transport: Transport, actors: GlobalActors) -> Self {
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
{
pub(crate) fn new(context: ServerContext, transport: Transport) -> Self {
Self {
db: context.db.clone(),
state: ClientStateMachine::new(DummyContext),
transport,
}
}
fn transition(&mut self, event: ClientEvents) -> Result<(), ClientError> {
self.state.process_event(event).map_err(|e| {
error!(?e, "State transition failed");
ClientError::StateTransitionFailed
})?;
Ok(())
}
pub async fn process_transport_inbound(&mut self, req: ClientRequest) -> Output {
let msg = req.payload.ok_or_else(|| {
error!(actor = "client", "Received message with no payload");
ClientError::MissingRequestPayload
})?;
match msg {
ClientRequestPayload::AuthChallengeRequest(req) => {
self.handle_auth_challenge_request(req).await
}
ClientRequestPayload::AuthChallengeSolution(solution) => {
self.handle_auth_challenge_solution(solution).await
}
}
}
async fn handle_auth_challenge_request(&mut self, req: AuthChallengeRequest) -> Output {
let pubkey = req
.pubkey
.as_array()
.ok_or(ClientError::InvalidAuthPubkeyLength)?;
let pubkey = VerifyingKey::from_bytes(pubkey).map_err(|_err| {
error!(?pubkey, "Failed to convert to VerifyingKey");
ClientError::InvalidAuthPubkeyEncoding
})?;
self.transition(ClientEvents::AuthRequest)?;
self.auth_with_challenge(pubkey, req.pubkey).await
}
async fn auth_with_challenge(&mut self, pubkey: VerifyingKey, pubkey_bytes: Vec<u8>) -> Output {
let nonce: Option<i32> = {
let mut db_conn = self.db.get().await.map_err(|e| {
error!(error = ?e, "Database pool error");
ClientError::DatabasePoolUnavailable
})?;
db_conn
.exclusive_transaction(|conn| {
Box::pin(async move {
let current_nonce = schema::program_client::table
.filter(
schema::program_client::public_key.eq(pubkey.as_bytes().to_vec()),
)
.select(schema::program_client::nonce)
.first::<i32>(conn)
.await?;
update(schema::program_client::table)
.filter(
schema::program_client::public_key.eq(pubkey.as_bytes().to_vec()),
)
.set(schema::program_client::nonce.eq(current_nonce + 1))
.execute(conn)
.await?;
Result::<_, diesel::result::Error>::Ok(current_nonce)
})
})
.await
.optional()
.map_err(|e| {
error!(error = ?e, "Database error");
ClientError::DatabaseOperationFailed
})?
};
let Some(nonce) = nonce else {
error!(?pubkey, "Public key not found in database");
return Err(ClientError::PublicKeyNotRegistered);
};
let challenge = AuthChallenge {
pubkey: pubkey_bytes,
nonce,
};
self.transition(ClientEvents::SentChallenge(ChallengeContext {
challenge: challenge.clone(),
key: pubkey,
}))?;
info!(
?pubkey,
?challenge,
"Sent authentication challenge to client"
);
Ok(response(ClientResponsePayload::AuthChallenge(challenge)))
}
fn verify_challenge_solution(
&self,
solution: &AuthChallengeSolution,
) -> Result<(bool, &ChallengeContext), ClientError> {
let ClientStates::WaitingForChallengeSolution(challenge_context) = self.state.state()
else {
error!("Received challenge solution in invalid state");
return Err(ClientError::InvalidStateForChallengeSolution);
};
let formatted_challenge = arbiter_proto::format_challenge(
challenge_context.challenge.nonce,
&challenge_context.challenge.pubkey,
);
let signature = solution.signature.as_slice().try_into().map_err(|_| {
error!(?solution, "Invalid signature length");
ClientError::InvalidSignatureLength
})?;
let valid = challenge_context
.key
.verify_strict(&formatted_challenge, &signature)
.is_ok();
Ok((valid, challenge_context))
}
async fn handle_auth_challenge_solution(
&mut self,
solution: AuthChallengeSolution,
) -> Output {
let (valid, challenge_context) = self.verify_challenge_solution(&solution)?;
if valid {
info!(
?challenge_context,
"Client provided valid solution to authentication challenge"
);
self.transition(ClientEvents::ReceivedGoodSolution)?;
Ok(response(ClientResponsePayload::AuthOk(AuthOk {})))
} else {
error!("Client provided invalid solution to authentication challenge");
self.transition(ClientEvents::ReceivedBadSolution)?;
Err(ClientError::InvalidChallengeSolution)
}
}
}
type Output = Result<ClientResponse, ClientError>;
fn response(payload: ClientResponsePayload) -> ClientResponse {
ClientResponse {
payload: Some(payload),
}
}
impl<Transport> Actor for ClientActor<Transport>
where
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
{
type Args = Self;
type Error = ();
async fn on_start(
args: Self::Args,
_: kameo::prelude::ActorRef<Self>,
) -> Result<Self, Self::Error> {
Ok(args)
}
async fn next(
&mut self,
_actor_ref: kameo::prelude::WeakActorRef<Self>,
mailbox_rx: &mut kameo::prelude::MailboxReceiver<Self>,
) -> Option<kameo::mailbox::Signal<Self>> {
loop {
select! {
signal = mailbox_rx.recv() => {
return signal;
}
msg = self.transport.recv() => {
match msg {
Some(request) => {
match self.process_transport_inbound(request).await {
Ok(resp) => {
if self.transport.send(Ok(resp)).await.is_err() {
error!(actor = "client", reason = "channel closed", "send.failed");
return Some(kameo::mailbox::Signal::Stop);
}
}
Err(err) => {
let _ = self.transport.send(Err(err)).await;
return Some(kameo::mailbox::Signal::Stop);
}
}
}
None => {
info!(actor = "client", "transport.closed");
return Some(kameo::mailbox::Signal::Stop);
}
}
}
}
}
}
}
impl ClientActor<DummyTransport<ClientRequest, Result<ClientResponse, ClientError>>> {
pub fn new_manual(db: db::DatabasePool) -> Self {
Self { Self {
db, db,
state: ClientStateMachine::new(DummyContext), transport,
transport: DummyTransport::new(), actors,
}
}
}
pub mod auth;
pub mod session;
pub async fn connect_client(props: ClientConnection) {
match auth::authenticate_and_create(props).await {
Ok(session) => {
ClientSession::spawn(session);
info!("Client authenticated, session started");
}
Err(err) => {
error!(?err, "Authentication failed, closing connection");
} }
} }
} }

View File

@@ -0,0 +1,98 @@
use arbiter_proto::proto::client::{ClientRequest, ClientResponse};
use ed25519_dalek::VerifyingKey;
use kameo::Actor;
use tokio::select;
use tracing::{error, info};
use crate::{actors::{
GlobalActors, client::{ClientError, ClientConnection}, router::RegisterClient
}, db};
pub struct ClientSession {
props: ClientConnection,
key: VerifyingKey,
}
impl ClientSession {
pub(crate) fn new(props: ClientConnection, key: VerifyingKey) -> Self {
Self { props, key }
}
pub async fn process_transport_inbound(&mut self, req: ClientRequest) -> Output {
let msg = req.payload.ok_or_else(|| {
error!(actor = "client", "Received message with no payload");
ClientError::MissingRequestPayload
})?;
match msg {
_ => Err(ClientError::UnexpectedRequestPayload),
}
}
}
type Output = Result<ClientResponse, ClientError>;
impl Actor for ClientSession {
type Args = Self;
type Error = ClientError;
async fn on_start(
args: Self::Args,
this: kameo::prelude::ActorRef<Self>,
) -> Result<Self, Self::Error> {
args.props
.actors
.router
.ask(RegisterClient { actor: this })
.await
.map_err(|_| ClientError::ConnectionRegistrationFailed)?;
Ok(args)
}
async fn next(
&mut self,
_actor_ref: kameo::prelude::WeakActorRef<Self>,
mailbox_rx: &mut kameo::prelude::MailboxReceiver<Self>,
) -> Option<kameo::mailbox::Signal<Self>> {
loop {
select! {
signal = mailbox_rx.recv() => {
return signal;
}
msg = self.props.transport.recv() => {
match msg {
Some(request) => {
match self.process_transport_inbound(request).await {
Ok(resp) => {
if self.props.transport.send(Ok(resp)).await.is_err() {
error!(actor = "client", reason = "channel closed", "send.failed");
return Some(kameo::mailbox::Signal::Stop);
}
}
Err(err) => {
let _ = self.props.transport.send(Err(err)).await;
return Some(kameo::mailbox::Signal::Stop);
}
}
}
None => {
info!(actor = "client", "transport.closed");
return Some(kameo::mailbox::Signal::Stop);
}
}
}
}
}
}
}
impl ClientSession {
pub fn new_test(db: db::DatabasePool, actors: GlobalActors) -> Self {
use arbiter_proto::transport::DummyTransport;
let transport: super::Transport = Box::new(DummyTransport::new());
let props = ClientConnection::new(db, transport, actors);
let key = VerifyingKey::from_bytes(&[0u8; 32]).unwrap();
Self { props, key }
}
}

View File

@@ -1,31 +0,0 @@
use arbiter_proto::proto::client::AuthChallenge;
use ed25519_dalek::VerifyingKey;
/// Context for state machine with validated key and sent challenge
#[derive(Clone, Debug)]
pub struct ChallengeContext {
pub challenge: AuthChallenge,
pub key: VerifyingKey,
}
smlang::statemachine!(
name: Client,
custom_error: false,
transitions: {
*Init + AuthRequest = ReceivedAuthRequest,
ReceivedAuthRequest + SentChallenge(ChallengeContext) / move_challenge = WaitingForChallengeSolution(ChallengeContext),
WaitingForChallengeSolution(ChallengeContext) + ReceivedGoodSolution = Idle,
WaitingForChallengeSolution(ChallengeContext) + ReceivedBadSolution = AuthError,
}
);
pub struct DummyContext;
impl ClientStateMachineContext for DummyContext {
#[allow(missing_docs)]
#[allow(clippy::unused_unit)]
fn move_challenge(&mut self, event_data: ChallengeContext) -> Result<ChallengeContext, ()> {
Ok(event_data)
}
}

View File

@@ -3,14 +3,15 @@ use miette::Diagnostic;
use thiserror::Error; use thiserror::Error;
use crate::{ use crate::{
actors::{bootstrap::Bootstrapper, keyholder::KeyHolder}, actors::{bootstrap::Bootstrapper, keyholder::KeyHolder, router::MessageRouter},
db, db,
}; };
pub mod bootstrap; pub mod bootstrap;
pub mod client; pub mod router;
pub mod keyholder; pub mod keyholder;
pub mod user_agent; pub mod user_agent;
pub mod client;
#[derive(Error, Debug, Diagnostic)] #[derive(Error, Debug, Diagnostic)]
pub enum SpawnError { pub enum SpawnError {
@@ -28,6 +29,7 @@ pub enum SpawnError {
pub struct GlobalActors { pub struct GlobalActors {
pub key_holder: ActorRef<KeyHolder>, pub key_holder: ActorRef<KeyHolder>,
pub bootstrapper: ActorRef<Bootstrapper>, pub bootstrapper: ActorRef<Bootstrapper>,
pub router: ActorRef<MessageRouter>,
} }
impl GlobalActors { impl GlobalActors {
@@ -35,6 +37,7 @@ impl GlobalActors {
Ok(Self { Ok(Self {
bootstrapper: Bootstrapper::spawn(Bootstrapper::new(&db).await?), bootstrapper: Bootstrapper::spawn(Bootstrapper::new(&db).await?),
key_holder: KeyHolder::spawn(KeyHolder::new(db.clone()).await?), key_holder: KeyHolder::spawn(KeyHolder::new(db.clone()).await?),
router: MessageRouter::spawn(MessageRouter::default()),
}) })
} }
} }

View File

@@ -0,0 +1,175 @@
use std::{collections::HashMap, ops::ControlFlow};
use ed25519_dalek::VerifyingKey;
use kameo::{
Actor,
actor::{ActorId, ActorRef},
messages,
prelude::{ActorStopReason, Context, WeakActorRef},
reply::DelegatedReply,
};
use tokio::{sync::watch, task::JoinSet};
use tracing::{info, warn};
use crate::actors::{
client::session::ClientSession,
user_agent::session::{RequestNewClientApproval, UserAgentSession},
};
#[derive(Default)]
pub struct MessageRouter {
pub user_agents: HashMap<ActorId, ActorRef<UserAgentSession>>,
pub clients: HashMap<ActorId, ActorRef<ClientSession>>,
}
impl Actor for MessageRouter {
type Args = Self;
type Error = ();
async fn on_start(args: Self::Args, _: ActorRef<Self>) -> Result<Self, Self::Error> {
Ok(args)
}
async fn on_link_died(
&mut self,
_: WeakActorRef<Self>,
id: ActorId,
_: ActorStopReason,
) -> Result<ControlFlow<ActorStopReason>, Self::Error> {
if self.user_agents.remove(&id).is_some() {
info!(
?id,
actor = "MessageRouter",
event = "useragent.disconnected"
);
} else if self.clients.remove(&id).is_some() {
info!(?id, actor = "MessageRouter", event = "client.disconnected");
} else {
info!(
?id,
actor = "MessageRouter",
event = "unknown.actor.disconnected"
);
}
Ok(ControlFlow::Continue(()))
}
}
#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq, Hash)]
pub enum ApprovalError {
#[error("No user agents connected")]
NoUserAgentsConnected,
}
async fn request_client_approval(
user_agents: &[WeakActorRef<UserAgentSession>],
client_pubkey: VerifyingKey,
) -> Result<bool, ApprovalError> {
if user_agents.is_empty() {
return Err(ApprovalError::NoUserAgentsConnected).into();
}
let mut pool = JoinSet::new();
let (cancel_tx, cancel_rx) = watch::channel(());
for weak_ref in user_agents {
match weak_ref.upgrade() {
Some(agent) => {
let client_pubkey = client_pubkey.clone();
let cancel_rx = cancel_rx.clone();
pool.spawn(async move {
agent
.ask(RequestNewClientApproval {
client_pubkey,
cancel_flag: cancel_rx.clone(),
})
.await
});
}
None => {
warn!(
id = weak_ref.id().to_string(),
actor = "MessageRouter",
event = "useragent.disconnected_before_approval"
);
}
}
}
while let Some(result) = pool.join_next().await {
match result {
Ok(Ok(approved)) => {
// cancel other pending requests
let _ = cancel_tx.send(());
return Ok(approved);
}
Ok(Err(err)) => {
warn!(
?err,
actor = "MessageRouter",
event = "useragent.approval_error"
);
}
Err(err) => {
warn!(
?err,
actor = "MessageRouter",
event = "useragent.approval_task_failed"
);
}
}
}
Err(ApprovalError::NoUserAgentsConnected)
}
#[messages]
impl MessageRouter {
#[message(ctx)]
pub async fn register_user_agent(
&mut self,
actor: ActorRef<UserAgentSession>,
ctx: &mut Context<Self, ()>,
) {
info!(id = %actor.id(), actor = "MessageRouter", event = "useragent.connected");
ctx.actor_ref().link(&actor).await;
self.user_agents.insert(actor.id(), actor);
}
#[message(ctx)]
pub async fn register_client(
&mut self,
actor: ActorRef<ClientSession>,
ctx: &mut Context<Self, ()>,
) {
info!(id = %actor.id(), actor = "MessageRouter", event = "client.connected");
ctx.actor_ref().link(&actor).await;
self.clients.insert(actor.id(), actor);
}
#[message(ctx)]
pub async fn request_client_approval(
&mut self,
client_pubkey: VerifyingKey,
ctx: &mut Context<Self, DelegatedReply<Result<bool, ApprovalError>>>,
) -> DelegatedReply<Result<bool, ApprovalError>> {
let (reply, Some(reply_sender)) = ctx.reply_sender() else {
panic!("Exptected `request_client_approval` to have callback channel");
};
let weak_refs = self
.user_agents
.values()
.map(|agent| agent.downgrade())
.collect::<Vec<_>>();
// handle in subtask to not to lock the actor
tokio::task::spawn(async move {
let result = request_client_approval(&weak_refs, client_pubkey).await;
let _ = reply_sender.send(result);
});
reply
}
}

View File

@@ -0,0 +1,118 @@
use arbiter_proto::proto::user_agent::{
AuthChallengeRequest, AuthChallengeSolution, UserAgentRequest,
user_agent_request::Payload as UserAgentRequestPayload,
};
use ed25519_dalek::VerifyingKey;
use tracing::error;
use crate::actors::user_agent::{
UserAgentConnection,
auth::state::{AuthContext, AuthStateMachine}, session::UserAgentSession,
};
#[derive(thiserror::Error, Debug, PartialEq)]
pub enum Error {
#[error("Unexpected message payload")]
UnexpectedMessagePayload,
#[error("Invalid client public key length")]
InvalidClientPubkeyLength,
#[error("Invalid client public key encoding")]
InvalidAuthPubkeyEncoding,
#[error("Database pool unavailable")]
DatabasePoolUnavailable,
#[error("Database operation failed")]
DatabaseOperationFailed,
#[error("Public key not registered")]
PublicKeyNotRegistered,
#[error("Transport error")]
Transport,
#[error("Invalid bootstrap token")]
InvalidBootstrapToken,
#[error("Bootstrapper actor unreachable")]
BootstrapperActorUnreachable,
#[error("Invalid challenge solution")]
InvalidChallengeSolution,
}
mod state;
use state::*;
fn parse_auth_event(payload: UserAgentRequestPayload) -> Result<AuthEvents, Error> {
match payload {
UserAgentRequestPayload::AuthChallengeRequest(AuthChallengeRequest {
pubkey,
bootstrap_token: None,
}) => {
let pubkey_bytes = pubkey.as_array().ok_or(Error::InvalidClientPubkeyLength)?;
let pubkey = VerifyingKey::from_bytes(pubkey_bytes)
.map_err(|_| Error::InvalidAuthPubkeyEncoding)?;
Ok(AuthEvents::AuthRequest(ChallengeRequest {
pubkey: pubkey.into(),
}))
}
UserAgentRequestPayload::AuthChallengeRequest(AuthChallengeRequest {
pubkey,
bootstrap_token: Some(token),
}) => {
let pubkey_bytes = pubkey.as_array().ok_or(Error::InvalidClientPubkeyLength)?;
let pubkey = VerifyingKey::from_bytes(pubkey_bytes)
.map_err(|_| Error::InvalidAuthPubkeyEncoding)?;
Ok(AuthEvents::BootstrapAuthRequest(BootstrapAuthRequest {
pubkey: pubkey.into(),
token,
}))
}
UserAgentRequestPayload::AuthChallengeSolution(AuthChallengeSolution { signature }) => {
Ok(AuthEvents::ReceivedSolution(ChallengeSolution {
solution: signature,
}))
}
_ => Err(Error::UnexpectedMessagePayload),
}
}
pub async fn authenticate(props: &mut UserAgentConnection) -> Result<VerifyingKey, Error> {
let mut state = AuthStateMachine::new(AuthContext::new(props));
loop {
// This is needed because `state` now holds mutable reference to `ConnectionProps`, so we can't directly access `props` here
let transport = state.context_mut().conn.transport.as_mut();
let Some(UserAgentRequest {
payload: Some(payload),
}) = transport.recv().await
else {
return Err(Error::Transport);
};
let event = parse_auth_event(payload)?;
match state.process_event(event).await {
Ok(AuthStates::AuthOk(key)) => return Ok(key.clone()),
Err(AuthError::ActionFailed(err)) => {
error!(?err, "State machine action failed");
return Err(err);
}
Err(AuthError::GuardFailed(err)) => {
error!(?err, "State machine guard failed");
return Err(err);
}
Err(AuthError::InvalidEvent) => {
error!("Invalid event for current state");
return Err(Error::InvalidChallengeSolution);
}
Err(AuthError::TransitionsFailed) => {
error!("Invalid state transition");
return Err(Error::InvalidChallengeSolution);
}
_ => (),
}
}
}
pub async fn authenticate_and_create(mut props: UserAgentConnection) -> Result<UserAgentSession, Error> {
let key = authenticate(&mut props).await?;
let session = UserAgentSession::new(props, key.clone());
Ok(session)
}

View File

@@ -0,0 +1,202 @@
use arbiter_proto::proto::user_agent::{
AuthChallenge, UserAgentResponse,
user_agent_response::Payload as UserAgentResponsePayload,
};
use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, update};
use diesel_async::RunQueryDsl;
use ed25519_dalek::VerifyingKey;
use tracing::error;
use super::Error;
use crate::{
actors::{bootstrap::ConsumeToken, user_agent::UserAgentConnection},
db::schema,
};
pub struct ChallengeRequest {
pub pubkey: VerifyingKey,
}
pub struct BootstrapAuthRequest {
pub pubkey: VerifyingKey,
pub token: String,
}
pub struct ChallengeContext {
pub challenge: AuthChallenge,
pub key: VerifyingKey,
}
pub struct ChallengeSolution {
pub solution: Vec<u8>,
}
smlang::statemachine!(
name: Auth,
custom_error: true,
transitions: {
*Init + AuthRequest(ChallengeRequest) / async prepare_challenge = SentChallenge(ChallengeContext),
Init + BootstrapAuthRequest(BootstrapAuthRequest) [async verify_bootstrap_token] / provide_key_bootstrap = AuthOk(VerifyingKey),
SentChallenge(ChallengeContext) + ReceivedSolution(ChallengeSolution) [async verify_solution] / provide_key = AuthOk(VerifyingKey),
}
);
async fn create_nonce(db: &crate::db::DatabasePool, pubkey_bytes: &[u8]) -> Result<i32, Error> {
let mut db_conn = db.get().await.map_err(|e| {
error!(error = ?e, "Database pool error");
Error::DatabasePoolUnavailable
})?;
db_conn
.exclusive_transaction(|conn| {
Box::pin(async move {
let current_nonce = schema::useragent_client::table
.filter(schema::useragent_client::public_key.eq(pubkey_bytes.to_vec()))
.select(schema::useragent_client::nonce)
.first::<i32>(conn)
.await?;
update(schema::useragent_client::table)
.filter(schema::useragent_client::public_key.eq(pubkey_bytes.to_vec()))
.set(schema::useragent_client::nonce.eq(current_nonce + 1))
.execute(conn)
.await?;
Result::<_, diesel::result::Error>::Ok(current_nonce)
})
})
.await
.optional()
.map_err(|e| {
error!(error = ?e, "Database error");
Error::DatabaseOperationFailed
})?
.ok_or_else(|| {
error!(?pubkey_bytes, "Public key not found in database");
Error::PublicKeyNotRegistered
})
}
async fn register_key(db: &crate::db::DatabasePool, pubkey_bytes: &[u8]) -> Result<(), Error> {
let mut conn = db.get().await.map_err(|e| {
error!(error = ?e, "Database pool error");
Error::DatabasePoolUnavailable
})?;
diesel::insert_into(schema::useragent_client::table)
.values((
schema::useragent_client::public_key.eq(pubkey_bytes.to_vec()),
schema::useragent_client::nonce.eq(1),
))
.execute(&mut conn)
.await
.map_err(|e| {
error!(error = ?e, "Database error");
Error::DatabaseOperationFailed
})?;
Ok(())
}
pub struct AuthContext<'a> {
pub(super) conn: &'a mut UserAgentConnection,
}
impl<'a> AuthContext<'a> {
pub fn new(conn: &'a mut UserAgentConnection) -> Self {
Self { conn }
}
}
impl AuthStateMachineContext for AuthContext<'_> {
type Error = Error;
async fn verify_solution(
&self,
ChallengeContext { challenge, key }: &ChallengeContext,
ChallengeSolution { solution }: &ChallengeSolution,
) -> Result<bool, Self::Error> {
let formatted_challenge =
arbiter_proto::format_challenge(challenge.nonce, &challenge.pubkey);
let signature = solution.as_slice().try_into().map_err(|_| {
error!(?solution, "Invalid signature length");
Error::InvalidChallengeSolution
})?;
let valid = key.verify_strict(&formatted_challenge, &signature).is_ok();
Ok(valid)
}
async fn prepare_challenge(
&mut self,
ChallengeRequest { pubkey }: ChallengeRequest,
) -> Result<ChallengeContext, Self::Error> {
let nonce = create_nonce(&self.conn.db, pubkey.as_bytes()).await?;
let challenge = AuthChallenge {
pubkey: pubkey.as_bytes().to_vec(),
nonce,
};
self.conn
.transport
.send(Ok(UserAgentResponse {
payload: Some(UserAgentResponsePayload::AuthChallenge(challenge.clone())),
}))
.await
.map_err(|e| {
error!(?e, "Failed to send auth challenge");
Error::Transport
})?;
Ok(ChallengeContext {
challenge,
key: pubkey,
})
}
#[allow(missing_docs)]
#[allow(clippy::result_unit_err)]
async fn verify_bootstrap_token(
&self,
BootstrapAuthRequest { pubkey, token }: &BootstrapAuthRequest,
) -> Result<bool, Self::Error> {
let token_ok: bool = self
.conn
.actors
.bootstrapper
.ask(ConsumeToken {
token: token.clone(),
})
.await
.map_err(|e| {
error!(?pubkey, "Failed to consume bootstrap token: {e}");
Error::BootstrapperActorUnreachable
})?;
if !token_ok {
error!(?pubkey, "Invalid bootstrap token provided");
return Err(Error::InvalidBootstrapToken);
}
register_key(&self.conn.db, pubkey.as_bytes()).await?;
Ok(true)
}
fn provide_key_bootstrap(
&mut self,
event_data: BootstrapAuthRequest,
) -> Result<VerifyingKey, Self::Error> {
Ok(event_data.pubkey)
}
fn provide_key(
&mut self,
state_data: &ChallengeContext,
_: ChallengeSolution,
) -> Result<VerifyingKey, Self::Error> {
Ok(state_data.key)
}
}

View File

@@ -1,478 +1,65 @@
use std::{ops::DerefMut, sync::Mutex};
use arbiter_proto::{ use arbiter_proto::{
proto::user_agent::{ proto::user_agent::{UserAgentRequest, UserAgentResponse},
AuthChallenge, AuthChallengeRequest, AuthChallengeSolution, AuthOk, UnsealEncryptedKey, transport::Bi,
UnsealResult, UnsealStart, UnsealStartResponse, UserAgentRequest, UserAgentResponse,
user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload,
},
transport::{Bi, DummyTransport},
}; };
use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit}; use kameo::actor::Spawn as _;
use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, dsl::update};
use diesel_async::RunQueryDsl;
use ed25519_dalek::VerifyingKey;
use kameo::{Actor, error::SendError};
use memsafe::MemSafe;
use tokio::select;
use tracing::{error, info}; use tracing::{error, info};
use x25519_dalek::{EphemeralSecret, PublicKey};
use crate::{ use crate::{
ServerContext, actors::{GlobalActors, user_agent::session::UserAgentSession},
actors::{ db::{self},
GlobalActors,
bootstrap::ConsumeToken,
keyholder::{self, TryUnseal},
user_agent::state::{
ChallengeContext, DummyContext, UnsealContext, UserAgentEvents, UserAgentStateMachine,
UserAgentStates,
},
},
db::{self, schema},
}; };
mod state; #[derive(Debug, thiserror::Error, PartialEq)]
pub enum TransportResponseError {
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum UserAgentError {
#[error("Expected message with payload")] #[error("Expected message with payload")]
MissingRequestPayload, MissingRequestPayload,
#[error("Expected message with payload")] #[error("Unexpected request payload")]
UnexpectedRequestPayload, UnexpectedRequestPayload,
#[error("Invalid state for challenge solution")]
InvalidStateForChallengeSolution,
#[error("Invalid state for unseal encrypted key")] #[error("Invalid state for unseal encrypted key")]
InvalidStateForUnsealEncryptedKey, InvalidStateForUnsealEncryptedKey,
#[error("client_pubkey must be 32 bytes")] #[error("client_pubkey must be 32 bytes")]
InvalidClientPubkeyLength, InvalidClientPubkeyLength,
#[error("Expected pubkey to have specific length")]
InvalidAuthPubkeyLength,
#[error("Failed to convert pubkey to VerifyingKey")]
InvalidAuthPubkeyEncoding,
#[error("Invalid signature length")]
InvalidSignatureLength,
#[error("Invalid bootstrap token")]
InvalidBootstrapToken,
#[error("Public key not registered")]
PublicKeyNotRegistered,
#[error("Invalid challenge solution")]
InvalidChallengeSolution,
#[error("State machine error")] #[error("State machine error")]
StateTransitionFailed, StateTransitionFailed,
#[error("Bootstrap token consumption failed")]
BootstrapperActorUnreachable,
#[error("Vault is not available")] #[error("Vault is not available")]
KeyHolderActorUnreachable, KeyHolderActorUnreachable,
#[error("Database pool error")] #[error(transparent)]
DatabasePoolUnavailable, Auth(#[from] auth::Error),
#[error("Database error")] #[error("Failed registering connection")]
DatabaseOperationFailed, ConnectionRegistrationFailed,
} }
pub struct UserAgentActor<Transport> pub type Transport =
where Box<dyn Bi<UserAgentRequest, Result<UserAgentResponse, TransportResponseError>> + Send>;
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
{ pub struct UserAgentConnection {
db: db::DatabasePool, db: db::DatabasePool,
actors: GlobalActors, actors: GlobalActors,
state: UserAgentStateMachine<DummyContext>,
transport: Transport, transport: Transport,
} }
impl<Transport> UserAgentActor<Transport> impl UserAgentConnection {
where pub fn new(db: db::DatabasePool, actors: GlobalActors, transport: Transport) -> Self {
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
{
pub(crate) fn new(context: ServerContext, transport: Transport) -> Self {
Self {
db: context.db.clone(),
actors: context.actors.clone(),
state: UserAgentStateMachine::new(DummyContext),
transport,
}
}
fn transition(&mut self, event: UserAgentEvents) -> Result<(), UserAgentError> {
self.state.process_event(event).map_err(|e| {
error!(?e, "State transition failed");
UserAgentError::StateTransitionFailed
})?;
Ok(())
}
pub async fn process_transport_inbound(&mut self, req: UserAgentRequest) -> Output {
let msg = req.payload.ok_or_else(|| {
error!(actor = "useragent", "Received message with no payload");
UserAgentError::MissingRequestPayload
})?;
match msg {
UserAgentRequestPayload::AuthChallengeRequest(req) => {
self.handle_auth_challenge_request(req).await
}
UserAgentRequestPayload::AuthChallengeSolution(solution) => {
self.handle_auth_challenge_solution(solution).await
}
UserAgentRequestPayload::UnsealStart(unseal_start) => {
self.handle_unseal_request(unseal_start).await
}
UserAgentRequestPayload::UnsealEncryptedKey(unseal_encrypted_key) => {
self.handle_unseal_encrypted_key(unseal_encrypted_key).await
}
_ => Err(UserAgentError::UnexpectedRequestPayload),
}
}
async fn auth_with_bootstrap_token(
&mut self,
pubkey: ed25519_dalek::VerifyingKey,
token: String,
) -> Result<UserAgentResponse, UserAgentError> {
let token_ok: bool = self
.actors
.bootstrapper
.ask(ConsumeToken { token })
.await
.map_err(|e| {
error!(?pubkey, "Failed to consume bootstrap token: {e}");
UserAgentError::BootstrapperActorUnreachable
})?;
if !token_ok {
error!(?pubkey, "Invalid bootstrap token provided");
return Err(UserAgentError::InvalidBootstrapToken);
}
{
let mut conn = self.db.get().await.map_err(|e| {
error!(error = ?e, "Database pool error");
UserAgentError::DatabasePoolUnavailable
})?;
diesel::insert_into(schema::useragent_client::table)
.values((
schema::useragent_client::public_key.eq(pubkey.as_bytes().to_vec()),
schema::useragent_client::nonce.eq(1),
))
.execute(&mut conn)
.await
.map_err(|e| {
error!(error = ?e, "Database error");
UserAgentError::DatabaseOperationFailed
})?;
}
self.transition(UserAgentEvents::ReceivedBootstrapToken)?;
Ok(response(UserAgentResponsePayload::AuthOk(AuthOk {})))
}
async fn auth_with_challenge(&mut self, pubkey: VerifyingKey, pubkey_bytes: Vec<u8>) -> Output {
let nonce: Option<i32> = {
let mut db_conn = self.db.get().await.map_err(|e| {
error!(error = ?e, "Database pool error");
UserAgentError::DatabasePoolUnavailable
})?;
db_conn
.exclusive_transaction(|conn| {
Box::pin(async move {
let current_nonce = schema::useragent_client::table
.filter(
schema::useragent_client::public_key.eq(pubkey.as_bytes().to_vec()),
)
.select(schema::useragent_client::nonce)
.first::<i32>(conn)
.await?;
update(schema::useragent_client::table)
.filter(
schema::useragent_client::public_key.eq(pubkey.as_bytes().to_vec()),
)
.set(schema::useragent_client::nonce.eq(current_nonce + 1))
.execute(conn)
.await?;
Result::<_, diesel::result::Error>::Ok(current_nonce)
})
})
.await
.optional()
.map_err(|e| {
error!(error = ?e, "Database error");
UserAgentError::DatabaseOperationFailed
})?
};
let Some(nonce) = nonce else {
error!(?pubkey, "Public key not found in database");
return Err(UserAgentError::PublicKeyNotRegistered);
};
let challenge = AuthChallenge {
pubkey: pubkey_bytes,
nonce,
};
self.transition(UserAgentEvents::SentChallenge(ChallengeContext {
challenge: challenge.clone(),
key: pubkey,
}))?;
info!(
?pubkey,
?challenge,
"Sent authentication challenge to client"
);
Ok(response(UserAgentResponsePayload::AuthChallenge(challenge)))
}
fn verify_challenge_solution(
&self,
solution: &AuthChallengeSolution,
) -> Result<(bool, &ChallengeContext), UserAgentError> {
let UserAgentStates::WaitingForChallengeSolution(challenge_context) = self.state.state()
else {
error!("Received challenge solution in invalid state");
return Err(UserAgentError::InvalidStateForChallengeSolution);
};
let formatted_challenge = arbiter_proto::format_challenge(
challenge_context.challenge.nonce,
&challenge_context.challenge.pubkey,
);
let signature = solution.signature.as_slice().try_into().map_err(|_| {
error!(?solution, "Invalid signature length");
UserAgentError::InvalidSignatureLength
})?;
let valid = challenge_context
.key
.verify_strict(&formatted_challenge, &signature)
.is_ok();
Ok((valid, challenge_context))
}
}
type Output = Result<UserAgentResponse, UserAgentError>;
fn response(payload: UserAgentResponsePayload) -> UserAgentResponse {
UserAgentResponse {
payload: Some(payload),
}
}
impl<Transport> UserAgentActor<Transport>
where
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
{
async fn handle_unseal_request(&mut self, req: UnsealStart) -> Output {
let secret = EphemeralSecret::random();
let public_key = PublicKey::from(&secret);
let client_pubkey_bytes: [u8; 32] = req
.client_pubkey
.try_into()
.map_err(|_| UserAgentError::InvalidClientPubkeyLength)?;
let client_public_key = PublicKey::from(client_pubkey_bytes);
self.transition(UserAgentEvents::UnsealRequest(UnsealContext {
secret: Mutex::new(Some(secret)),
client_public_key,
}))?;
Ok(response(
UserAgentResponsePayload::UnsealStartResponse(UnsealStartResponse {
server_pubkey: public_key.as_bytes().to_vec(),
}),
))
}
async fn handle_unseal_encrypted_key(&mut self, req: UnsealEncryptedKey) -> Output {
let UserAgentStates::WaitingForUnsealKey(unseal_context) = self.state.state() else {
error!("Received unseal encrypted key in invalid state");
return Err(UserAgentError::InvalidStateForUnsealEncryptedKey);
};
let ephemeral_secret = {
let mut secret_lock = unseal_context.secret.lock().unwrap();
let secret = secret_lock.take();
match secret {
Some(secret) => secret,
None => {
drop(secret_lock);
error!("Ephemeral secret already taken");
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
return Ok(response(UserAgentResponsePayload::UnsealResult(
UnsealResult::InvalidKey.into(),
)));
}
}
};
let nonce = XNonce::from_slice(&req.nonce);
let shared_secret = ephemeral_secret.diffie_hellman(&unseal_context.client_public_key);
let cipher = XChaCha20Poly1305::new(shared_secret.as_bytes().into());
let mut seal_key_buffer = MemSafe::new(req.ciphertext.clone()).unwrap();
let decryption_result = {
let mut write_handle = seal_key_buffer.write().unwrap();
let write_handle = write_handle.deref_mut();
cipher.decrypt_in_place(nonce, &req.associated_data, write_handle)
};
match decryption_result {
Ok(_) => {
match self
.actors
.key_holder
.ask(TryUnseal {
seal_key_raw: seal_key_buffer,
})
.await
{
Ok(_) => {
info!("Successfully unsealed key with client-provided key");
self.transition(UserAgentEvents::ReceivedValidKey)?;
Ok(response(UserAgentResponsePayload::UnsealResult(
UnsealResult::Success.into(),
)))
}
Err(SendError::HandlerError(keyholder::Error::InvalidKey)) => {
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Ok(response(UserAgentResponsePayload::UnsealResult(
UnsealResult::InvalidKey.into(),
)))
}
Err(SendError::HandlerError(err)) => {
error!(?err, "Keyholder failed to unseal key");
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Ok(response(UserAgentResponsePayload::UnsealResult(
UnsealResult::InvalidKey.into(),
)))
}
Err(err) => {
error!(?err, "Failed to send unseal request to keyholder");
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Err(UserAgentError::KeyHolderActorUnreachable)
}
}
}
Err(err) => {
error!(?err, "Failed to decrypt unseal key");
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Ok(response(UserAgentResponsePayload::UnsealResult(
UnsealResult::InvalidKey.into(),
)))
}
}
}
async fn handle_auth_challenge_request(&mut self, req: AuthChallengeRequest) -> Output {
let pubkey = req
.pubkey
.as_array()
.ok_or(UserAgentError::InvalidAuthPubkeyLength)?;
let pubkey = VerifyingKey::from_bytes(pubkey).map_err(|_err| {
error!(?pubkey, "Failed to convert to VerifyingKey");
UserAgentError::InvalidAuthPubkeyEncoding
})?;
self.transition(UserAgentEvents::AuthRequest)?;
match req.bootstrap_token {
Some(token) => self.auth_with_bootstrap_token(pubkey, token).await,
None => self.auth_with_challenge(pubkey, req.pubkey).await,
}
}
async fn handle_auth_challenge_solution(
&mut self,
solution: AuthChallengeSolution,
) -> Output {
let (valid, challenge_context) = self.verify_challenge_solution(&solution)?;
if valid {
info!(
?challenge_context,
"Client provided valid solution to authentication challenge"
);
self.transition(UserAgentEvents::ReceivedGoodSolution)?;
Ok(response(UserAgentResponsePayload::AuthOk(AuthOk {})))
} else {
error!("Client provided invalid solution to authentication challenge");
self.transition(UserAgentEvents::ReceivedBadSolution)?;
Err(UserAgentError::InvalidChallengeSolution)
}
}
}
impl<Transport> Actor for UserAgentActor<Transport>
where
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
{
type Args = Self;
type Error = ();
async fn on_start(
args: Self::Args,
_: kameo::prelude::ActorRef<Self>,
) -> Result<Self, Self::Error> {
Ok(args)
}
async fn next(
&mut self,
_actor_ref: kameo::prelude::WeakActorRef<Self>,
mailbox_rx: &mut kameo::prelude::MailboxReceiver<Self>,
) -> Option<kameo::mailbox::Signal<Self>> {
loop {
select! {
signal = mailbox_rx.recv() => {
return signal;
}
msg = self.transport.recv() => {
match msg {
Some(request) => {
match self.process_transport_inbound(request).await {
Ok(response) => {
if self.transport.send(Ok(response)).await.is_err() {
error!(actor = "useragent", reason = "channel closed", "send.failed");
return Some(kameo::mailbox::Signal::Stop);
}
}
Err(err) => {
let _ = self.transport.send(Err(err)).await;
return Some(kameo::mailbox::Signal::Stop);
}
}
}
None => {
info!(actor = "useragent", "transport.closed");
return Some(kameo::mailbox::Signal::Stop);
}
}
}
}
}
}
}
impl UserAgentActor<DummyTransport<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>> {
pub fn new_manual(db: db::DatabasePool, actors: GlobalActors) -> Self {
Self { Self {
db, db,
actors, actors,
state: UserAgentStateMachine::new(DummyContext), transport,
transport: DummyTransport::new(), }
}
}
pub mod auth;
pub mod session;
pub async fn connect_user_agent(props: UserAgentConnection) {
match auth::authenticate_and_create(props).await {
Ok(session) => {
UserAgentSession::spawn(session);
info!("User authenticated, session started");
}
Err(err) => {
error!(?err, "Authentication failed, closing connection");
} }
} }
} }

View File

@@ -0,0 +1,364 @@
use std::{ops::DerefMut, sync::Mutex};
use arbiter_proto::proto::user_agent::{
ClientConnectionCancel, ClientConnectionRequest, UnsealEncryptedKey, UnsealResult,
UnsealStart, UnsealStartResponse, UserAgentRequest, UserAgentResponse,
user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload,
};
use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit};
use ed25519_dalek::VerifyingKey;
use kameo::{Actor, error::SendError, messages, prelude::Context};
use memsafe::MemSafe;
use tokio::{select, sync::watch};
use tracing::{error, info};
use x25519_dalek::{EphemeralSecret, PublicKey};
use crate::actors::{
keyholder::{self, TryUnseal},
router::RegisterUserAgent,
user_agent::{TransportResponseError, UserAgentConnection},
};
mod state;
use state::{DummyContext, UnsealContext, UserAgentEvents, UserAgentStateMachine, UserAgentStates};
// Error for consumption by other actors
#[derive(Debug, thiserror::Error, PartialEq)]
pub enum Error {
#[error("User agent session ended due to connection loss")]
ConnectionLost,
#[error("User agent session ended due to unexpected message")]
UnexpectedMessage,
}
pub struct UserAgentSession {
props: UserAgentConnection,
key: VerifyingKey,
state: UserAgentStateMachine<DummyContext>,
}
impl UserAgentSession {
pub(crate) fn new(props: UserAgentConnection, key: VerifyingKey) -> Self {
Self {
props,
key,
state: UserAgentStateMachine::new(DummyContext),
}
}
fn transition(&mut self, event: UserAgentEvents) -> Result<(), TransportResponseError> {
self.state.process_event(event).map_err(|e| {
error!(?e, "State transition failed");
TransportResponseError::StateTransitionFailed
})?;
Ok(())
}
async fn send_msg<Reply: kameo::Reply>(
&mut self,
msg: UserAgentResponsePayload,
_ctx: &mut Context<Self, Reply>,
) -> Result<(), Error> {
self.props
.transport
.send(Ok(response(msg)))
.await
.map_err(|_| {
error!(
actor = "useragent",
reason = "channel closed",
"send.failed"
);
Error::ConnectionLost
})
}
async fn expect_msg<Extractor, Msg, Reply>(
&mut self,
extractor: Extractor,
ctx: &mut Context<Self, Reply>,
) -> Result<Msg, Error>
where
Extractor: FnOnce(UserAgentRequestPayload) -> Option<Msg>,
Reply: kameo::Reply,
{
let msg = self.props.transport.recv().await.ok_or_else(|| {
error!(
actor = "useragent",
reason = "channel closed",
"recv.failed"
);
ctx.stop();
Error::ConnectionLost
})?;
msg.payload.and_then(extractor).ok_or_else(|| {
error!(
actor = "useragent",
reason = "unexpected message",
"recv.failed"
);
ctx.stop();
Error::UnexpectedMessage
})
}
}
#[messages]
impl UserAgentSession {
// TODO: Think about refactoring it to state-machine based flow, as we already have one
#[message(ctx)]
pub async fn request_new_client_approval(
&mut self,
client_pubkey: VerifyingKey,
mut cancel_flag: watch::Receiver<()>,
ctx: &mut Context<Self, Result<bool, Error>>,
) -> Result<bool, Error> {
self.send_msg(
UserAgentResponsePayload::ClientConnectionRequest(
ClientConnectionRequest {
pubkey: client_pubkey.as_bytes().to_vec(),
}
.into(),
),
ctx,
)
.await?;
let extractor = |msg| {
if let UserAgentRequestPayload::ClientConnectionResponse(client_connection_response) =
msg
{
Some(client_connection_response)
} else {
None
}
};
tokio::select! {
_ = cancel_flag.changed() => {
info!(actor = "useragent", "client connection approval cancelled");
self.send_msg(
UserAgentResponsePayload::ClientConnectionCancel(ClientConnectionCancel {}),
ctx,
).await?;
return Ok(false);
}
result = self.expect_msg(extractor, ctx) => {
let result = result?;
info!(actor = "useragent", "received client connection approval result: approved={}", result.approved);
return Ok(result.approved);
}
}
}
}
impl UserAgentSession {
pub async fn process_transport_inbound(&mut self, req: UserAgentRequest) -> Output {
let msg = req.payload.ok_or_else(|| {
error!(actor = "useragent", "Received message with no payload");
TransportResponseError::MissingRequestPayload
})?;
match msg {
UserAgentRequestPayload::UnsealStart(unseal_start) => {
self.handle_unseal_request(unseal_start).await
}
UserAgentRequestPayload::UnsealEncryptedKey(unseal_encrypted_key) => {
self.handle_unseal_encrypted_key(unseal_encrypted_key).await
}
_ => Err(TransportResponseError::UnexpectedRequestPayload),
}
}
}
type Output = Result<UserAgentResponse, TransportResponseError>;
fn response(payload: UserAgentResponsePayload) -> UserAgentResponse {
UserAgentResponse {
payload: Some(payload),
}
}
impl UserAgentSession {
async fn handle_unseal_request(&mut self, req: UnsealStart) -> Output {
let secret = EphemeralSecret::random();
let public_key = PublicKey::from(&secret);
let client_pubkey_bytes: [u8; 32] = req
.client_pubkey
.try_into()
.map_err(|_| TransportResponseError::InvalidClientPubkeyLength)?;
let client_public_key = PublicKey::from(client_pubkey_bytes);
self.transition(UserAgentEvents::UnsealRequest(UnsealContext {
secret: Mutex::new(Some(secret)),
client_public_key,
}))?;
Ok(response(UserAgentResponsePayload::UnsealStartResponse(
UnsealStartResponse {
server_pubkey: public_key.as_bytes().to_vec(),
},
)))
}
async fn handle_unseal_encrypted_key(&mut self, req: UnsealEncryptedKey) -> Output {
let UserAgentStates::WaitingForUnsealKey(unseal_context) = self.state.state() else {
error!("Received unseal encrypted key in invalid state");
return Err(TransportResponseError::InvalidStateForUnsealEncryptedKey);
};
let ephemeral_secret = {
let mut secret_lock = unseal_context.secret.lock().unwrap();
let secret = secret_lock.take();
match secret {
Some(secret) => secret,
None => {
drop(secret_lock);
error!("Ephemeral secret already taken");
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
return Ok(response(UserAgentResponsePayload::UnsealResult(
UnsealResult::InvalidKey.into(),
)));
}
}
};
let nonce = XNonce::from_slice(&req.nonce);
let shared_secret = ephemeral_secret.diffie_hellman(&unseal_context.client_public_key);
let cipher = XChaCha20Poly1305::new(shared_secret.as_bytes().into());
let mut seal_key_buffer = MemSafe::new(req.ciphertext.clone()).unwrap();
let decryption_result = {
let mut write_handle = seal_key_buffer.write().unwrap();
let write_handle = write_handle.deref_mut();
cipher.decrypt_in_place(nonce, &req.associated_data, write_handle)
};
match decryption_result {
Ok(_) => {
match self
.props
.actors
.key_holder
.ask(TryUnseal {
seal_key_raw: seal_key_buffer,
})
.await
{
Ok(_) => {
info!("Successfully unsealed key with client-provided key");
self.transition(UserAgentEvents::ReceivedValidKey)?;
Ok(response(UserAgentResponsePayload::UnsealResult(
UnsealResult::Success.into(),
)))
}
Err(SendError::HandlerError(keyholder::Error::InvalidKey)) => {
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Ok(response(UserAgentResponsePayload::UnsealResult(
UnsealResult::InvalidKey.into(),
)))
}
Err(SendError::HandlerError(err)) => {
error!(?err, "Keyholder failed to unseal key");
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Ok(response(UserAgentResponsePayload::UnsealResult(
UnsealResult::InvalidKey.into(),
)))
}
Err(err) => {
error!(?err, "Failed to send unseal request to keyholder");
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Err(TransportResponseError::KeyHolderActorUnreachable)
}
}
}
Err(err) => {
error!(?err, "Failed to decrypt unseal key");
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Ok(response(UserAgentResponsePayload::UnsealResult(
UnsealResult::InvalidKey.into(),
)))
}
}
}
}
impl Actor for UserAgentSession {
type Args = Self;
type Error = TransportResponseError;
async fn on_start(
args: Self::Args,
this: kameo::prelude::ActorRef<Self>,
) -> Result<Self, Self::Error> {
args.props
.actors
.router
.ask(RegisterUserAgent {
actor: this.clone(),
})
.await
.map_err(|err| {
error!(?err, "Failed to register user agent connection with router");
TransportResponseError::ConnectionRegistrationFailed
})?;
Ok(args)
}
async fn next(
&mut self,
_actor_ref: kameo::prelude::WeakActorRef<Self>,
mailbox_rx: &mut kameo::prelude::MailboxReceiver<Self>,
) -> Option<kameo::mailbox::Signal<Self>> {
loop {
select! {
signal = mailbox_rx.recv() => {
return signal;
}
msg = self.props.transport.recv() => {
match msg {
Some(request) => {
match self.process_transport_inbound(request).await {
Ok(response) => {
if self.props.transport.send(Ok(response)).await.is_err() {
error!(actor = "useragent", reason = "channel closed", "send.failed");
return Some(kameo::mailbox::Signal::Stop);
}
}
Err(err) => {
let _ = self.props.transport.send(Err(err)).await;
return Some(kameo::mailbox::Signal::Stop);
}
}
}
None => {
info!(actor = "useragent", "transport.closed");
return Some(kameo::mailbox::Signal::Stop);
}
}
}
}
}
}
}
impl UserAgentSession {
pub fn new_test(db: crate::db::DatabasePool, actors: crate::actors::GlobalActors) -> Self {
use arbiter_proto::transport::DummyTransport;
let transport: super::Transport = Box::new(DummyTransport::new());
let props = UserAgentConnection::new(db, actors, transport);
let key = VerifyingKey::from_bytes(&[0u8; 32]).unwrap();
Self {
props,
key,
state: UserAgentStateMachine::new(DummyContext),
}
}
}

View File

@@ -0,0 +1,27 @@
use std::sync::Mutex;
use x25519_dalek::{EphemeralSecret, PublicKey};
pub struct UnsealContext {
pub client_public_key: PublicKey,
pub secret: Mutex<Option<EphemeralSecret>>,
}
smlang::statemachine!(
name: UserAgent,
custom_error: false,
transitions: {
*Idle + UnsealRequest(UnsealContext) / generate_temp_keypair = WaitingForUnsealKey(UnsealContext),
WaitingForUnsealKey(UnsealContext) + ReceivedValidKey = Unsealed,
WaitingForUnsealKey(UnsealContext) + ReceivedInvalidKey = Idle,
}
);
pub struct DummyContext;
impl UserAgentStateMachineContext for DummyContext {
#[allow(missing_docs)]
#[allow(clippy::unused_unit)]
fn generate_temp_keypair(&mut self, event_data: UnsealContext) -> Result<UnsealContext, ()> {
Ok(event_data)
}
}

View File

@@ -1,51 +0,0 @@
use std::sync::Mutex;
use arbiter_proto::proto::user_agent::AuthChallenge;
use ed25519_dalek::VerifyingKey;
use x25519_dalek::{EphemeralSecret, PublicKey};
/// Context for state machine with validated key and sent challenge
/// Challenge is then transformed to bytes using shared function and verified
#[derive(Clone, Debug)]
pub struct ChallengeContext {
pub challenge: AuthChallenge,
pub key: VerifyingKey,
}
pub struct UnsealContext {
pub client_public_key: PublicKey,
pub secret: Mutex<Option<EphemeralSecret>>,
}
smlang::statemachine!(
name: UserAgent,
custom_error: false,
transitions: {
*Init + AuthRequest = ReceivedAuthRequest,
ReceivedAuthRequest + ReceivedBootstrapToken = Idle,
ReceivedAuthRequest + SentChallenge(ChallengeContext) / move_challenge = WaitingForChallengeSolution(ChallengeContext),
WaitingForChallengeSolution(ChallengeContext) + ReceivedGoodSolution = Idle,
WaitingForChallengeSolution(ChallengeContext) + ReceivedBadSolution = AuthError, // block further transitions, but connection should close anyway
Idle + UnsealRequest(UnsealContext) / generate_temp_keypair = WaitingForUnsealKey(UnsealContext),
WaitingForUnsealKey(UnsealContext) + ReceivedValidKey = Unsealed,
WaitingForUnsealKey(UnsealContext) + ReceivedInvalidKey = Idle,
}
);
pub struct DummyContext;
impl UserAgentStateMachineContext for DummyContext {
#[allow(missing_docs)]
#[allow(clippy::unused_unit)]
fn generate_temp_keypair(&mut self, event_data: UnsealContext) -> Result<UnsealContext, ()> {
Ok(event_data)
}
#[allow(missing_docs)]
#[allow(clippy::unused_unit)]
fn move_challenge(&mut self, event_data: ChallengeContext) -> Result<ChallengeContext, ()> {
Ok(event_data)
}
}

View File

@@ -1,12 +1,13 @@
#![allow(unused)] #![allow(unused)]
#![allow(clippy::all)] #![allow(clippy::all)]
use crate::db::schema::{self, aead_encrypted, arbiter_settings, root_key_history, tls_history}; use crate::db::{ schema::{self, aead_encrypted, arbiter_settings, root_key_history, tls_history}};
use diesel::{prelude::*, sqlite::Sqlite}; use diesel::{prelude::*, sql_types::Bool, sqlite::Sqlite};
use restructed::Models; use restructed::Models;
pub mod types { pub mod types {
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use diesel::{deserialize::FromSql, expression::AsExpression, serialize::{IsNull, ToSql}, sql_types::Integer, sqlite::Sqlite};
pub struct SqliteTimestamp(DateTime<Utc>); pub struct SqliteTimestamp(DateTime<Utc>);
} }
@@ -71,7 +72,7 @@ pub struct ArbiterSettings {
pub tls_id: Option<i32>, // references tls_history.id pub tls_id: Option<i32>, // references tls_history.id
} }
#[derive(Queryable, Debug)] #[derive(Queryable, Debug, Insertable, Selectable)]
#[diesel(table_name = schema::program_client, check_for_backend(Sqlite))] #[diesel(table_name = schema::program_client, check_for_backend(Sqlite))]
pub struct ProgramClient { pub struct ProgramClient {
pub id: i32, pub id: i32,

View File

@@ -7,7 +7,6 @@ use arbiter_proto::{
transport::{IdentityRecvConverter, SendConverter, grpc}, transport::{IdentityRecvConverter, SendConverter, grpc},
}; };
use async_trait::async_trait; use async_trait::async_trait;
use kameo::actor::Spawn;
use tokio_stream::wrappers::ReceiverStream; use tokio_stream::wrappers::ReceiverStream;
use tokio::sync::mpsc; use tokio::sync::mpsc;
@@ -16,8 +15,8 @@ use tracing::info;
use crate::{ use crate::{
actors::{ actors::{
client::{ClientActor, ClientError}, client::{self, ClientError, ClientConnection as ClientConnectionProps, connect_client},
user_agent::{UserAgentActor, UserAgentError}, user_agent::{self, UserAgentConnection, TransportResponseError, connect_user_agent},
}, },
context::ServerContext, context::ServerContext,
}; };
@@ -28,15 +27,10 @@ pub mod db;
const DEFAULT_CHANNEL_SIZE: usize = 1000; const DEFAULT_CHANNEL_SIZE: usize = 1000;
/// Converts User Agent domain outbounds into the tonic stream item emitted by
/// the server.§
///
/// The conversion is defined at the server boundary so the actor module remains
/// focused on domain semantics and does not depend on tonic status encoding.
struct UserAgentGrpcSender; struct UserAgentGrpcSender;
impl SendConverter for UserAgentGrpcSender { impl SendConverter for UserAgentGrpcSender {
type Input = Result<UserAgentResponse, UserAgentError>; type Input = Result<UserAgentResponse, TransportResponseError>;
type Output = Result<UserAgentResponse, Status>; type Output = Result<UserAgentResponse, Status>;
fn convert(&self, item: Self::Input) -> Self::Output { fn convert(&self, item: Self::Input) -> Self::Output {
@@ -47,11 +41,6 @@ impl SendConverter for UserAgentGrpcSender {
} }
} }
/// Converts Client domain outbounds into the tonic stream item emitted by the
/// server.
///
/// The conversion is defined at the server boundary so the actor module remains
/// focused on domain semantics and does not depend on tonic status encoding.
struct ClientGrpcSender; struct ClientGrpcSender;
impl SendConverter for ClientGrpcSender { impl SendConverter for ClientGrpcSender {
@@ -66,78 +55,76 @@ impl SendConverter for ClientGrpcSender {
} }
} }
/// Maps Client domain errors to public gRPC transport errors for the `client`
/// streaming endpoint.
fn client_error_status(value: ClientError) -> Status { fn client_error_status(value: ClientError) -> Status {
match value { match value {
ClientError::MissingRequestPayload | ClientError::UnexpectedRequestPayload => { ClientError::MissingRequestPayload | ClientError::UnexpectedRequestPayload => {
Status::invalid_argument("Expected message with payload") Status::invalid_argument("Expected message with payload")
} }
ClientError::InvalidStateForChallengeSolution => {
Status::invalid_argument("Invalid state for challenge solution")
}
ClientError::InvalidAuthPubkeyLength => {
Status::invalid_argument("Expected pubkey to have specific length")
}
ClientError::InvalidAuthPubkeyEncoding => {
Status::invalid_argument("Failed to convert pubkey to VerifyingKey")
}
ClientError::InvalidSignatureLength => {
Status::invalid_argument("Invalid signature length")
}
ClientError::PublicKeyNotRegistered => {
Status::unauthenticated("Public key not registered")
}
ClientError::InvalidChallengeSolution => {
Status::unauthenticated("Invalid challenge solution")
}
ClientError::StateTransitionFailed => Status::internal("State machine error"), ClientError::StateTransitionFailed => Status::internal("State machine error"),
ClientError::DatabasePoolUnavailable => Status::internal("Database pool error"), ClientError::Auth(ref err) => client_auth_error_status(err),
ClientError::DatabaseOperationFailed => Status::internal("Database error"), ClientError::ConnectionRegistrationFailed => {
Status::internal("Connection registration failed")
}
} }
} }
/// Maps User Agent domain errors to public gRPC transport errors for the fn client_auth_error_status(value: &client::auth::Error) -> Status {
/// `user_agent` streaming endpoint. use client::auth::Error;
fn user_agent_error_status(value: UserAgentError) -> Status {
match value { match value {
UserAgentError::MissingRequestPayload | UserAgentError::UnexpectedRequestPayload => { Error::UnexpectedMessagePayload | Error::InvalidClientPubkeyLength => {
Status::invalid_argument("Expected message with payload") Status::invalid_argument(value.to_string())
} }
UserAgentError::InvalidStateForChallengeSolution => { Error::InvalidAuthPubkeyEncoding => {
Status::invalid_argument("Invalid state for challenge solution")
}
UserAgentError::InvalidStateForUnsealEncryptedKey => {
Status::failed_precondition("Invalid state for unseal encrypted key")
}
UserAgentError::InvalidClientPubkeyLength => {
Status::invalid_argument("client_pubkey must be 32 bytes")
}
UserAgentError::InvalidAuthPubkeyLength => {
Status::invalid_argument("Expected pubkey to have specific length")
}
UserAgentError::InvalidAuthPubkeyEncoding => {
Status::invalid_argument("Failed to convert pubkey to VerifyingKey") Status::invalid_argument("Failed to convert pubkey to VerifyingKey")
} }
UserAgentError::InvalidSignatureLength => { Error::InvalidChallengeSolution => Status::unauthenticated(value.to_string()),
Status::invalid_argument("Invalid signature length") Error::ApproveError(_) => Status::permission_denied(value.to_string()),
Error::Transport => Status::internal("Transport error"),
Error::DatabasePoolUnavailable => Status::internal("Database pool error"),
Error::DatabaseOperationFailed => Status::internal("Database error"),
Error::InternalError => Status::internal("Internal error"),
} }
UserAgentError::InvalidBootstrapToken => {
Status::invalid_argument("Invalid bootstrap token")
} }
UserAgentError::PublicKeyNotRegistered => {
Status::unauthenticated("Public key not registered") fn user_agent_error_status(value: TransportResponseError) -> Status {
match value {
TransportResponseError::MissingRequestPayload | TransportResponseError::UnexpectedRequestPayload => {
Status::invalid_argument("Expected message with payload")
} }
UserAgentError::InvalidChallengeSolution => { TransportResponseError::InvalidStateForUnsealEncryptedKey => {
Status::unauthenticated("Invalid challenge solution") Status::failed_precondition("Invalid state for unseal encrypted key")
} }
UserAgentError::StateTransitionFailed => Status::internal("State machine error"), TransportResponseError::InvalidClientPubkeyLength => {
UserAgentError::BootstrapperActorUnreachable => { Status::invalid_argument("client_pubkey must be 32 bytes")
}
TransportResponseError::StateTransitionFailed => Status::internal("State machine error"),
TransportResponseError::KeyHolderActorUnreachable => Status::internal("Vault is not available"),
TransportResponseError::Auth(ref err) => auth_error_status(err),
TransportResponseError::ConnectionRegistrationFailed => {
Status::internal("Failed registering connection")
}
}
}
fn auth_error_status(value: &user_agent::auth::Error) -> Status {
use user_agent::auth::Error;
match value {
Error::UnexpectedMessagePayload | Error::InvalidClientPubkeyLength => {
Status::invalid_argument(value.to_string())
}
Error::InvalidAuthPubkeyEncoding => {
Status::invalid_argument("Failed to convert pubkey to VerifyingKey")
}
Error::PublicKeyNotRegistered | Error::InvalidChallengeSolution => {
Status::unauthenticated(value.to_string())
}
Error::InvalidBootstrapToken => Status::invalid_argument("Invalid bootstrap token"),
Error::Transport => Status::internal("Transport error"),
Error::BootstrapperActorUnreachable => {
Status::internal("Bootstrap token consumption failed") Status::internal("Bootstrap token consumption failed")
} }
UserAgentError::KeyHolderActorUnreachable => Status::internal("Vault is not available"), Error::DatabasePoolUnavailable => Status::internal("Database pool error"),
UserAgentError::DatabasePoolUnavailable => Status::internal("Database pool error"), Error::DatabaseOperationFailed => Status::internal("Database error"),
UserAgentError::DatabaseOperationFailed => Status::internal("Database error"),
} }
} }
@@ -170,7 +157,12 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server {
IdentityRecvConverter::<ClientRequest>::new(), IdentityRecvConverter::<ClientRequest>::new(),
ClientGrpcSender, ClientGrpcSender,
); );
ClientActor::spawn(ClientActor::new(self.context.clone(), transport)); let props = ClientConnectionProps::new(
self.context.db.clone(),
Box::new(transport),
self.context.actors.clone(),
);
tokio::spawn(connect_client(props));
info!(event = "connection established", "grpc.client"); info!(event = "connection established", "grpc.client");
@@ -191,7 +183,12 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server {
IdentityRecvConverter::<UserAgentRequest>::new(), IdentityRecvConverter::<UserAgentRequest>::new(),
UserAgentGrpcSender, UserAgentGrpcSender,
); );
UserAgentActor::spawn(UserAgentActor::new(self.context.clone(), transport)); let props = UserAgentConnection::new(
self.context.db.clone(),
self.context.actors.clone(),
Box::new(transport),
);
tokio::spawn(connect_user_agent(props));
info!(event = "connection established", "grpc.user_agent"); info!(event = "connection established", "grpc.user_agent");

View File

@@ -1,2 +1,4 @@
mod common;
#[path = "client/auth.rs"] #[path = "client/auth.rs"]
mod auth; mod auth;

View File

@@ -1,44 +1,46 @@
use arbiter_proto::proto::client::{ use arbiter_proto::proto::client::{
AuthChallengeRequest, AuthChallengeSolution, AuthOk, ClientRequest, ClientResponse, AuthChallengeRequest, AuthChallengeSolution, ClientRequest,
client_request::Payload as ClientRequestPayload, client_request::Payload as ClientRequestPayload,
client_response::Payload as ClientResponsePayload, client_response::Payload as ClientResponsePayload,
}; };
use arbiter_proto::transport::Bi;
use arbiter_server::actors::GlobalActors;
use arbiter_server::{ use arbiter_server::{
actors::client::{ClientActor, ClientError}, actors::client::{ClientConnection, connect_client},
db::{self, schema}, db::{self, schema},
}; };
use diesel::{ExpressionMethods as _, insert_into}; use diesel::{ExpressionMethods as _, insert_into};
use diesel_async::RunQueryDsl; use diesel_async::RunQueryDsl;
use ed25519_dalek::Signer as _; use ed25519_dalek::Signer as _;
use super::common::ChannelTransport;
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
pub async fn test_unregistered_pubkey_rejected() { pub async fn test_unregistered_pubkey_rejected() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let mut client = ClientActor::new_manual(db.clone()); let (server_transport, mut test_transport) = ChannelTransport::new();
let actors = GlobalActors::spawn(db.clone()).await.unwrap();
let props = ClientConnection::new(db.clone(), Box::new(server_transport), actors);
let task = tokio::spawn(connect_client(props));
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec();
let result = client test_transport
.process_transport_inbound(ClientRequest { .send(ClientRequest {
payload: Some(ClientRequestPayload::AuthChallengeRequest( payload: Some(ClientRequestPayload::AuthChallengeRequest(
AuthChallengeRequest { AuthChallengeRequest {
pubkey: pubkey_bytes, pubkey: pubkey_bytes,
}, },
)), )),
}) })
.await; .await
.unwrap();
match result { // Auth fails, connect_client returns, transport drops
Err(err) => { task.await.unwrap();
assert_eq!(err, ClientError::PublicKeyNotRegistered);
}
Ok(_) => {
panic!("Expected error due to unregistered pubkey, but got success");
}
}
} }
#[tokio::test] #[tokio::test]
@@ -46,8 +48,6 @@ pub async fn test_unregistered_pubkey_rejected() {
pub async fn test_challenge_auth() { pub async fn test_challenge_auth() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let mut client = ClientActor::new_manual(db.clone());
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec();
@@ -60,8 +60,15 @@ pub async fn test_challenge_auth() {
.unwrap(); .unwrap();
} }
let result = client let (server_transport, mut test_transport) = ChannelTransport::new();
.process_transport_inbound(ClientRequest { let actors = GlobalActors::spawn(db.clone()).await.unwrap();
let props = ClientConnection::new(db.clone(), Box::new(server_transport), actors);
let task = tokio::spawn(connect_client(props));
// Send challenge request
test_transport
.send(ClientRequest {
payload: Some(ClientRequestPayload::AuthChallengeRequest( payload: Some(ClientRequestPayload::AuthChallengeRequest(
AuthChallengeRequest { AuthChallengeRequest {
pubkey: pubkey_bytes, pubkey: pubkey_bytes,
@@ -69,34 +76,36 @@ pub async fn test_challenge_auth() {
)), )),
}) })
.await .await
.expect("Shouldn't fail to process message"); .unwrap();
let ClientResponse { // Read the challenge response
payload: Some(ClientResponsePayload::AuthChallenge(challenge)), let response = test_transport
} = result .recv()
else { .await
panic!("Expected auth challenge response, got {result:?}"); .expect("should receive challenge");
let challenge = match response {
Ok(resp) => match resp.payload {
Some(ClientResponsePayload::AuthChallenge(c)) => c,
other => panic!("Expected AuthChallenge, got {other:?}"),
},
Err(err) => panic!("Expected Ok response, got Err({err:?})"),
}; };
// Sign the challenge and send solution
let formatted_challenge = arbiter_proto::format_challenge(challenge.nonce, &challenge.pubkey); let formatted_challenge = arbiter_proto::format_challenge(challenge.nonce, &challenge.pubkey);
let signature = new_key.sign(&formatted_challenge); let signature = new_key.sign(&formatted_challenge);
let serialized_signature = signature.to_bytes().to_vec();
let result = client test_transport
.process_transport_inbound(ClientRequest { .send(ClientRequest {
payload: Some(ClientRequestPayload::AuthChallengeSolution( payload: Some(ClientRequestPayload::AuthChallengeSolution(
AuthChallengeSolution { AuthChallengeSolution {
signature: serialized_signature, signature: signature.to_bytes().to_vec(),
}, },
)), )),
}) })
.await .await
.expect("Shouldn't fail to process message"); .unwrap();
assert_eq!( // Auth completes, session spawned
result, task.await.unwrap();
ClientResponse {
payload: Some(ClientResponsePayload::AuthOk(AuthOk {})),
}
);
} }

View File

@@ -1,10 +1,14 @@
use arbiter_proto::transport::{Bi, Error};
use arbiter_server::{ use arbiter_server::{
actors::keyholder::KeyHolder, actors::keyholder::KeyHolder,
db::{self, schema}, db::{self, schema},
}; };
use async_trait::async_trait;
use diesel::QueryDsl; use diesel::QueryDsl;
use diesel_async::RunQueryDsl; use diesel_async::RunQueryDsl;
use memsafe::MemSafe; use memsafe::MemSafe;
use tokio::sync::mpsc;
#[allow(dead_code)] #[allow(dead_code)]
pub async fn bootstrapped_keyholder(db: &db::DatabasePool) -> KeyHolder { pub async fn bootstrapped_keyholder(db: &db::DatabasePool) -> KeyHolder {
@@ -26,3 +30,46 @@ pub async fn root_key_history_id(db: &db::DatabasePool) -> i32 {
.unwrap(); .unwrap();
id.expect("root_key_id should be set after bootstrap") id.expect("root_key_id should be set after bootstrap")
} }
pub struct ChannelTransport<T, Y> {
receiver: mpsc::Receiver<T>,
sender: mpsc::Sender<Y>,
}
impl<T, Y> ChannelTransport<T, Y> {
pub fn new() -> (Self, ChannelTransport<Y, T>) {
let (tx1, rx1) = mpsc::channel(10);
let (tx2, rx2) = mpsc::channel(10);
(
Self {
receiver: rx1,
sender: tx2,
},
ChannelTransport {
receiver: rx2,
sender: tx1,
},
)
}
}
#[async_trait]
impl<T, Y> Bi<T, Y> for ChannelTransport<T, Y>
where
T: Send + 'static,
Y: Send + 'static,
{
async fn send(&mut self, item: Y) -> Result<(), Error> {
self.sender
.send(item)
.await
.map_err(|_| Error::ChannelClosed)
}
async fn recv(&mut self) -> Option<T> {
self.receiver.recv().await
}
}

View File

@@ -1,13 +1,14 @@
use arbiter_proto::proto::user_agent::{ use arbiter_proto::proto::user_agent::{
AuthChallengeRequest, AuthChallengeSolution, AuthOk, UserAgentRequest, UserAgentResponse, AuthChallengeRequest, AuthChallengeSolution, UserAgentRequest,
user_agent_request::Payload as UserAgentRequestPayload, user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload, user_agent_response::Payload as UserAgentResponsePayload,
}; };
use arbiter_proto::transport::Bi;
use arbiter_server::{ use arbiter_server::{
actors::{ actors::{
GlobalActors, GlobalActors,
bootstrap::GetToken, bootstrap::GetToken,
user_agent::{UserAgentActor, UserAgentError}, user_agent::{UserAgentConnection, connect_user_agent},
}, },
db::{self, schema}, db::{self, schema},
}; };
@@ -15,20 +16,24 @@ use diesel::{ExpressionMethods as _, QueryDsl, insert_into};
use diesel_async::RunQueryDsl; use diesel_async::RunQueryDsl;
use ed25519_dalek::Signer as _; use ed25519_dalek::Signer as _;
use super::common::ChannelTransport;
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
pub async fn test_bootstrap_token_auth() { pub async fn test_bootstrap_token_auth() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let actors = GlobalActors::spawn(db.clone()).await.unwrap(); let actors = GlobalActors::spawn(db.clone()).await.unwrap();
let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap(); let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap();
let mut user_agent = UserAgentActor::new_manual(db.clone(), actors);
let (server_transport, mut test_transport) = ChannelTransport::new();
let props = UserAgentConnection::new(db.clone(), actors, Box::new(server_transport));
let task = tokio::spawn(connect_user_agent(props));
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec();
let result = user_agent test_transport
.process_transport_inbound(UserAgentRequest { .send(UserAgentRequest {
payload: Some(UserAgentRequestPayload::AuthChallengeRequest( payload: Some(UserAgentRequestPayload::AuthChallengeRequest(
AuthChallengeRequest { AuthChallengeRequest {
pubkey: pubkey_bytes, pubkey: pubkey_bytes,
@@ -37,14 +42,9 @@ pub async fn test_bootstrap_token_auth() {
)), )),
}) })
.await .await
.expect("Shouldn't fail to process message"); .unwrap();
assert_eq!( task.await.unwrap();
result,
UserAgentResponse {
payload: Some(UserAgentResponsePayload::AuthOk(AuthOk {})),
}
);
let mut conn = db.get().await.unwrap(); let mut conn = db.get().await.unwrap();
let stored_pubkey: Vec<u8> = schema::useragent_client::table let stored_pubkey: Vec<u8> = schema::useragent_client::table
@@ -59,15 +59,17 @@ pub async fn test_bootstrap_token_auth() {
#[test_log::test] #[test_log::test]
pub async fn test_bootstrap_invalid_token_auth() { pub async fn test_bootstrap_invalid_token_auth() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let actors = GlobalActors::spawn(db.clone()).await.unwrap(); let actors = GlobalActors::spawn(db.clone()).await.unwrap();
let mut user_agent = UserAgentActor::new_manual(db.clone(), actors);
let (server_transport, mut test_transport) = ChannelTransport::new();
let props = UserAgentConnection::new(db.clone(), actors, Box::new(server_transport));
let task = tokio::spawn(connect_user_agent(props));
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec();
let result = user_agent test_transport
.process_transport_inbound(UserAgentRequest { .send(UserAgentRequest {
payload: Some(UserAgentRequestPayload::AuthChallengeRequest( payload: Some(UserAgentRequestPayload::AuthChallengeRequest(
AuthChallengeRequest { AuthChallengeRequest {
pubkey: pubkey_bytes, pubkey: pubkey_bytes,
@@ -75,25 +77,27 @@ pub async fn test_bootstrap_invalid_token_auth() {
}, },
)), )),
}) })
.await; .await
.unwrap();
match result { // Auth fails, connect_user_agent returns, transport drops
Err(err) => { task.await.unwrap();
assert_eq!(err, UserAgentError::InvalidBootstrapToken);
} // Verify no key was registered
Ok(_) => { let mut conn = db.get().await.unwrap();
panic!("Expected error due to invalid bootstrap token, but got success"); let count: i64 = schema::useragent_client::table
} .count()
} .get_result::<i64>(&mut conn)
.await
.unwrap();
assert_eq!(count, 0);
} }
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
pub async fn test_challenge_auth() { pub async fn test_challenge_auth() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let actors = GlobalActors::spawn(db.clone()).await.unwrap(); let actors = GlobalActors::spawn(db.clone()).await.unwrap();
let mut user_agent = UserAgentActor::new_manual(db.clone(), actors);
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec();
@@ -107,8 +111,13 @@ pub async fn test_challenge_auth() {
.unwrap(); .unwrap();
} }
let result = user_agent let (server_transport, mut test_transport) = ChannelTransport::new();
.process_transport_inbound(UserAgentRequest { let props = UserAgentConnection::new(db.clone(), actors, Box::new(server_transport));
let task = tokio::spawn(connect_user_agent(props));
// Send challenge request
test_transport
.send(UserAgentRequest {
payload: Some(UserAgentRequestPayload::AuthChallengeRequest( payload: Some(UserAgentRequestPayload::AuthChallengeRequest(
AuthChallengeRequest { AuthChallengeRequest {
pubkey: pubkey_bytes, pubkey: pubkey_bytes,
@@ -117,34 +126,36 @@ pub async fn test_challenge_auth() {
)), )),
}) })
.await .await
.expect("Shouldn't fail to process message"); .unwrap();
let UserAgentResponse { // Read the challenge response
payload: Some(UserAgentResponsePayload::AuthChallenge(challenge)), let response = test_transport
} = result .recv()
else { .await
panic!("Expected auth challenge response, got {result:?}"); .expect("should receive challenge");
let challenge = match response {
Ok(resp) => match resp.payload {
Some(UserAgentResponsePayload::AuthChallenge(c)) => c,
other => panic!("Expected AuthChallenge, got {other:?}"),
},
Err(err) => panic!("Expected Ok response, got Err({err:?})"),
}; };
// Sign the challenge and send solution
let formatted_challenge = arbiter_proto::format_challenge(challenge.nonce, &challenge.pubkey); let formatted_challenge = arbiter_proto::format_challenge(challenge.nonce, &challenge.pubkey);
let signature = new_key.sign(&formatted_challenge); let signature = new_key.sign(&formatted_challenge);
let serialized_signature = signature.to_bytes().to_vec();
let result = user_agent test_transport
.process_transport_inbound(UserAgentRequest { .send(UserAgentRequest {
payload: Some(UserAgentRequestPayload::AuthChallengeSolution( payload: Some(UserAgentRequestPayload::AuthChallengeSolution(
AuthChallengeSolution { AuthChallengeSolution {
signature: serialized_signature, signature: signature.to_bytes().to_vec(),
}, },
)), )),
}) })
.await .await
.expect("Shouldn't fail to process message"); .unwrap();
assert_eq!( // Auth completes, session spawned
result, task.await.unwrap();
UserAgentResponse {
payload: Some(UserAgentResponsePayload::AuthOk(AuthOk {})),
}
);
} }

View File

@@ -1,16 +1,13 @@
use arbiter_proto::proto::user_agent::{ use arbiter_proto::proto::user_agent::{
AuthChallengeRequest, UnsealEncryptedKey, UnsealResult, UnsealStart, UnsealEncryptedKey, UnsealResult, UnsealStart, UserAgentRequest,
UserAgentRequest, UserAgentResponse,
user_agent_request::Payload as UserAgentRequestPayload, user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload, user_agent_response::Payload as UserAgentResponsePayload,
}; };
use arbiter_proto::transport::DummyTransport;
use arbiter_server::{ use arbiter_server::{
actors::{ actors::{
GlobalActors, GlobalActors,
bootstrap::GetToken,
keyholder::{Bootstrap, Seal}, keyholder::{Bootstrap, Seal},
user_agent::{UserAgentActor, UserAgentError}, user_agent::session::UserAgentSession,
}, },
db, db,
}; };
@@ -18,18 +15,12 @@ use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit};
use memsafe::MemSafe; use memsafe::MemSafe;
use x25519_dalek::{EphemeralSecret, PublicKey}; use x25519_dalek::{EphemeralSecret, PublicKey};
type TestUserAgent = async fn setup_sealed_user_agent(
UserAgentActor<DummyTransport<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>>;
async fn setup_authenticated_user_agent(
seal_key: &[u8], seal_key: &[u8],
) -> ( ) -> (db::DatabasePool, UserAgentSession) {
arbiter_server::db::DatabasePool,
TestUserAgent,
) {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let actors = GlobalActors::spawn(db.clone()).await.unwrap(); let actors = GlobalActors::spawn(db.clone()).await.unwrap();
actors actors
.key_holder .key_holder
.ask(Bootstrap { .ask(Bootstrap {
@@ -39,27 +30,13 @@ async fn setup_authenticated_user_agent(
.unwrap(); .unwrap();
actors.key_holder.ask(Seal).await.unwrap(); actors.key_holder.ask(Seal).await.unwrap();
let mut user_agent = UserAgentActor::new_manual(db.clone(), actors.clone()); let session = UserAgentSession::new_test(db.clone(), actors);
let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap(); (db, session)
let auth_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
user_agent
.process_transport_inbound(UserAgentRequest {
payload: Some(UserAgentRequestPayload::AuthChallengeRequest(
AuthChallengeRequest {
pubkey: auth_key.verifying_key().to_bytes().to_vec(),
bootstrap_token: Some(token),
},
)),
})
.await
.unwrap();
(db, user_agent)
} }
async fn client_dh_encrypt( async fn client_dh_encrypt(
user_agent: &mut TestUserAgent, user_agent: &mut UserAgentSession,
key_to_send: &[u8], key_to_send: &[u8],
) -> UnsealEncryptedKey { ) -> UnsealEncryptedKey {
let client_secret = EphemeralSecret::random(); let client_secret = EphemeralSecret::random();
@@ -106,7 +83,7 @@ fn unseal_key_request(req: UnsealEncryptedKey) -> UserAgentRequest {
#[test_log::test] #[test_log::test]
pub async fn test_unseal_success() { pub async fn test_unseal_success() {
let seal_key = b"test-seal-key"; let seal_key = b"test-seal-key";
let (_db, mut user_agent) = setup_authenticated_user_agent(seal_key).await; let (_db, mut user_agent) = setup_sealed_user_agent(seal_key).await;
let encrypted_key = client_dh_encrypt(&mut user_agent, seal_key).await; let encrypted_key = client_dh_encrypt(&mut user_agent, seal_key).await;
@@ -124,7 +101,7 @@ pub async fn test_unseal_success() {
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
pub async fn test_unseal_wrong_seal_key() { pub async fn test_unseal_wrong_seal_key() {
let (_db, mut user_agent) = setup_authenticated_user_agent(b"correct-key").await; let (_db, mut user_agent) = setup_sealed_user_agent(b"correct-key").await;
let encrypted_key = client_dh_encrypt(&mut user_agent, b"wrong-key").await; let encrypted_key = client_dh_encrypt(&mut user_agent, b"wrong-key").await;
@@ -142,7 +119,7 @@ pub async fn test_unseal_wrong_seal_key() {
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
pub async fn test_unseal_corrupted_ciphertext() { pub async fn test_unseal_corrupted_ciphertext() {
let (_db, mut user_agent) = setup_authenticated_user_agent(b"test-key").await; let (_db, mut user_agent) = setup_sealed_user_agent(b"test-key").await;
let client_secret = EphemeralSecret::random(); let client_secret = EphemeralSecret::random();
let client_public = PublicKey::from(&client_secret); let client_public = PublicKey::from(&client_secret);
@@ -171,38 +148,11 @@ pub async fn test_unseal_corrupted_ciphertext() {
); );
} }
#[tokio::test]
#[test_log::test]
pub async fn test_unseal_start_without_auth_fails() {
let db = db::create_test_pool().await;
let actors = GlobalActors::spawn(db.clone()).await.unwrap();
let mut user_agent = UserAgentActor::new_manual(db.clone(), actors);
let client_secret = EphemeralSecret::random();
let client_public = PublicKey::from(&client_secret);
let result = user_agent
.process_transport_inbound(UserAgentRequest {
payload: Some(UserAgentRequestPayload::UnsealStart(UnsealStart {
client_pubkey: client_public.as_bytes().to_vec(),
})),
})
.await;
match result {
Err(err) => {
assert_eq!(err, UserAgentError::StateTransitionFailed);
}
other => panic!("Expected state machine error, got {other:?}"),
}
}
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
pub async fn test_unseal_retry_after_invalid_key() { pub async fn test_unseal_retry_after_invalid_key() {
let seal_key = b"real-seal-key"; let seal_key = b"real-seal-key";
let (_db, mut user_agent) = setup_authenticated_user_agent(seal_key).await; let (_db, mut user_agent) = setup_sealed_user_agent(seal_key).await;
{ {
let encrypted_key = client_dh_encrypt(&mut user_agent, b"wrong-key").await; let encrypted_key = client_dh_encrypt(&mut user_agent, b"wrong-key").await;

View File

@@ -18,3 +18,4 @@ thiserror.workspace = true
tokio-stream.workspace = true tokio-stream.workspace = true
http = "1.4.0" http = "1.4.0"
rustls-webpki = { version = "0.103.9", features = ["aws-lc-rs"] } rustls-webpki = { version = "0.103.9", features = ["aws-lc-rs"] }
async-trait.workspace = true

View File

@@ -13,12 +13,14 @@ use ed25519_dalek::SigningKey;
use kameo::actor::Spawn; use kameo::actor::Spawn;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio::time::{Duration, timeout}; use tokio::time::{Duration, timeout};
use async_trait::async_trait;
struct TestTransport { struct TestTransport {
inbound_rx: mpsc::Receiver<UserAgentResponse>, inbound_rx: mpsc::Receiver<UserAgentResponse>,
outbound_tx: mpsc::Sender<UserAgentRequest>, outbound_tx: mpsc::Sender<UserAgentRequest>,
} }
#[async_trait]
impl Bi<UserAgentResponse, UserAgentRequest> for TestTransport { impl Bi<UserAgentResponse, UserAgentRequest> for TestTransport {
async fn send(&mut self, item: UserAgentRequest) -> Result<(), arbiter_proto::transport::Error> { async fn send(&mut self, item: UserAgentRequest) -> Result<(), arbiter_proto::transport::Error> {
self.outbound_tx self.outbound_tx