refactor: consolidate auth messages into client and user_agent packages
Some checks failed
ci/woodpecker/pr/server-lint Pipeline failed
ci/woodpecker/pr/server-audit Pipeline was successful
ci/woodpecker/pr/server-vet Pipeline failed
ci/woodpecker/pr/server-test Pipeline was successful
ci/woodpecker/push/server-lint Pipeline failed
ci/woodpecker/push/server-audit Pipeline was successful
ci/woodpecker/push/server-vet Pipeline failed
ci/woodpecker/push/server-test Pipeline was successful

This commit was merged in pull request #23.
This commit is contained in:
hdbg
2026-03-01 11:44:34 +01:00
parent 3cc63474a8
commit 4169b2ba42
19 changed files with 686 additions and 264 deletions

View File

@@ -1,12 +0,0 @@
use arbiter_proto::{
proto::{ClientRequest, ClientResponse},
transport::Bi,
};
use crate::ServerContext;
pub(crate) async fn handle_client(
_context: ServerContext,
_bistream: impl Bi<ClientRequest, ClientResponse>,
) {
}

View File

@@ -0,0 +1,289 @@
use arbiter_proto::{
proto::client::{
AuthChallenge, AuthChallengeRequest, AuthChallengeSolution, AuthOk, ClientRequest,
ClientResponse,
client_request::Payload as ClientRequestPayload,
client_response::Payload as ClientResponsePayload,
},
transport::{Bi, DummyTransport},
};
use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, dsl::update};
use diesel_async::RunQueryDsl;
use ed25519_dalek::VerifyingKey;
use kameo::Actor;
use tokio::select;
use tracing::{error, info};
use crate::{
ServerContext,
actors::client::state::{
ChallengeContext, ClientEvents, ClientStateMachine, ClientStates, DummyContext,
},
db::{self, schema},
};
mod state;
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum ClientError {
#[error("Expected message with payload")]
MissingRequestPayload,
#[error("Unexpected request payload")]
UnexpectedRequestPayload,
#[error("Invalid state for challenge solution")]
InvalidStateForChallengeSolution,
#[error("Expected pubkey to have specific length")]
InvalidAuthPubkeyLength,
#[error("Failed to convert pubkey to VerifyingKey")]
InvalidAuthPubkeyEncoding,
#[error("Invalid signature length")]
InvalidSignatureLength,
#[error("Public key not registered")]
PublicKeyNotRegistered,
#[error("Invalid challenge solution")]
InvalidChallengeSolution,
#[error("State machine error")]
StateTransitionFailed,
#[error("Database pool error")]
DatabasePoolUnavailable,
#[error("Database error")]
DatabaseOperationFailed,
}
pub struct ClientActor<Transport>
where
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
{
db: db::DatabasePool,
state: ClientStateMachine<DummyContext>,
transport: Transport,
}
impl<Transport> ClientActor<Transport>
where
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
{
pub(crate) fn new(context: ServerContext, transport: Transport) -> Self {
Self {
db: context.db.clone(),
state: ClientStateMachine::new(DummyContext),
transport,
}
}
fn transition(&mut self, event: ClientEvents) -> Result<(), ClientError> {
self.state.process_event(event).map_err(|e| {
error!(?e, "State transition failed");
ClientError::StateTransitionFailed
})?;
Ok(())
}
pub async fn process_transport_inbound(&mut self, req: ClientRequest) -> Output {
let msg = req.payload.ok_or_else(|| {
error!(actor = "client", "Received message with no payload");
ClientError::MissingRequestPayload
})?;
match msg {
ClientRequestPayload::AuthChallengeRequest(req) => {
self.handle_auth_challenge_request(req).await
}
ClientRequestPayload::AuthChallengeSolution(solution) => {
self.handle_auth_challenge_solution(solution).await
}
}
}
async fn handle_auth_challenge_request(&mut self, req: AuthChallengeRequest) -> Output {
let pubkey = req
.pubkey
.as_array()
.ok_or(ClientError::InvalidAuthPubkeyLength)?;
let pubkey = VerifyingKey::from_bytes(pubkey).map_err(|_err| {
error!(?pubkey, "Failed to convert to VerifyingKey");
ClientError::InvalidAuthPubkeyEncoding
})?;
self.transition(ClientEvents::AuthRequest)?;
self.auth_with_challenge(pubkey, req.pubkey).await
}
async fn auth_with_challenge(&mut self, pubkey: VerifyingKey, pubkey_bytes: Vec<u8>) -> Output {
let nonce: Option<i32> = {
let mut db_conn = self.db.get().await.map_err(|e| {
error!(error = ?e, "Database pool error");
ClientError::DatabasePoolUnavailable
})?;
db_conn
.exclusive_transaction(|conn| {
Box::pin(async move {
let current_nonce = schema::program_client::table
.filter(
schema::program_client::public_key.eq(pubkey.as_bytes().to_vec()),
)
.select(schema::program_client::nonce)
.first::<i32>(conn)
.await?;
update(schema::program_client::table)
.filter(
schema::program_client::public_key.eq(pubkey.as_bytes().to_vec()),
)
.set(schema::program_client::nonce.eq(current_nonce + 1))
.execute(conn)
.await?;
Result::<_, diesel::result::Error>::Ok(current_nonce)
})
})
.await
.optional()
.map_err(|e| {
error!(error = ?e, "Database error");
ClientError::DatabaseOperationFailed
})?
};
let Some(nonce) = nonce else {
error!(?pubkey, "Public key not found in database");
return Err(ClientError::PublicKeyNotRegistered);
};
let challenge = AuthChallenge {
pubkey: pubkey_bytes,
nonce,
};
self.transition(ClientEvents::SentChallenge(ChallengeContext {
challenge: challenge.clone(),
key: pubkey,
}))?;
info!(
?pubkey,
?challenge,
"Sent authentication challenge to client"
);
Ok(response(ClientResponsePayload::AuthChallenge(challenge)))
}
fn verify_challenge_solution(
&self,
solution: &AuthChallengeSolution,
) -> Result<(bool, &ChallengeContext), ClientError> {
let ClientStates::WaitingForChallengeSolution(challenge_context) = self.state.state()
else {
error!("Received challenge solution in invalid state");
return Err(ClientError::InvalidStateForChallengeSolution);
};
let formatted_challenge = arbiter_proto::format_challenge(
challenge_context.challenge.nonce,
&challenge_context.challenge.pubkey,
);
let signature = solution.signature.as_slice().try_into().map_err(|_| {
error!(?solution, "Invalid signature length");
ClientError::InvalidSignatureLength
})?;
let valid = challenge_context
.key
.verify_strict(&formatted_challenge, &signature)
.is_ok();
Ok((valid, challenge_context))
}
async fn handle_auth_challenge_solution(
&mut self,
solution: AuthChallengeSolution,
) -> Output {
let (valid, challenge_context) = self.verify_challenge_solution(&solution)?;
if valid {
info!(
?challenge_context,
"Client provided valid solution to authentication challenge"
);
self.transition(ClientEvents::ReceivedGoodSolution)?;
Ok(response(ClientResponsePayload::AuthOk(AuthOk {})))
} else {
error!("Client provided invalid solution to authentication challenge");
self.transition(ClientEvents::ReceivedBadSolution)?;
Err(ClientError::InvalidChallengeSolution)
}
}
}
type Output = Result<ClientResponse, ClientError>;
fn response(payload: ClientResponsePayload) -> ClientResponse {
ClientResponse {
payload: Some(payload),
}
}
impl<Transport> Actor for ClientActor<Transport>
where
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
{
type Args = Self;
type Error = ();
async fn on_start(
args: Self::Args,
_: kameo::prelude::ActorRef<Self>,
) -> Result<Self, Self::Error> {
Ok(args)
}
async fn next(
&mut self,
_actor_ref: kameo::prelude::WeakActorRef<Self>,
mailbox_rx: &mut kameo::prelude::MailboxReceiver<Self>,
) -> Option<kameo::mailbox::Signal<Self>> {
loop {
select! {
signal = mailbox_rx.recv() => {
return signal;
}
msg = self.transport.recv() => {
match msg {
Some(request) => {
match self.process_transport_inbound(request).await {
Ok(resp) => {
if self.transport.send(Ok(resp)).await.is_err() {
error!(actor = "client", reason = "channel closed", "send.failed");
return Some(kameo::mailbox::Signal::Stop);
}
}
Err(err) => {
let _ = self.transport.send(Err(err)).await;
return Some(kameo::mailbox::Signal::Stop);
}
}
}
None => {
info!(actor = "client", "transport.closed");
return Some(kameo::mailbox::Signal::Stop);
}
}
}
}
}
}
}
impl ClientActor<DummyTransport<ClientRequest, Result<ClientResponse, ClientError>>> {
pub fn new_manual(db: db::DatabasePool) -> Self {
Self {
db,
state: ClientStateMachine::new(DummyContext),
transport: DummyTransport::new(),
}
}
}

View File

@@ -0,0 +1,31 @@
use arbiter_proto::proto::client::AuthChallenge;
use ed25519_dalek::VerifyingKey;
/// Context for state machine with validated key and sent challenge
#[derive(Clone, Debug)]
pub struct ChallengeContext {
pub challenge: AuthChallenge,
pub key: VerifyingKey,
}
smlang::statemachine!(
name: Client,
custom_error: false,
transitions: {
*Init + AuthRequest = ReceivedAuthRequest,
ReceivedAuthRequest + SentChallenge(ChallengeContext) / move_challenge = WaitingForChallengeSolution(ChallengeContext),
WaitingForChallengeSolution(ChallengeContext) + ReceivedGoodSolution = Idle,
WaitingForChallengeSolution(ChallengeContext) + ReceivedBadSolution = AuthError,
}
);
pub struct DummyContext;
impl ClientStateMachineContext for DummyContext {
#[allow(missing_docs)]
#[allow(clippy::unused_unit)]
fn move_challenge(&mut self, event_data: ChallengeContext) -> Result<ChallengeContext, ()> {
Ok(event_data)
}
}

View File

@@ -1,14 +1,9 @@
use std::{ops::DerefMut, sync::Mutex};
use arbiter_proto::{
proto::{
UnsealEncryptedKey, UnsealResult, UnsealStart, UnsealStartResponse, UserAgentRequest,
UserAgentResponse,
auth::{
self, AuthChallengeRequest, AuthOk, ClientMessage as ClientAuthMessage,
ServerMessage as AuthServerMessage, client_message::Payload as ClientAuthPayload,
server_message::Payload as ServerAuthPayload,
},
proto::user_agent::{
AuthChallenge, AuthChallengeRequest, AuthChallengeSolution, AuthOk, UnsealEncryptedKey,
UnsealResult, UnsealStart, UnsealStartResponse, UserAgentRequest, UserAgentResponse,
user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload,
},
@@ -114,12 +109,12 @@ where
})?;
match msg {
UserAgentRequestPayload::AuthMessage(ClientAuthMessage {
payload: Some(ClientAuthPayload::AuthChallengeRequest(req)),
}) => self.handle_auth_challenge_request(req).await,
UserAgentRequestPayload::AuthMessage(ClientAuthMessage {
payload: Some(ClientAuthPayload::AuthChallengeSolution(solution)),
}) => self.handle_auth_challenge_solution(solution).await,
UserAgentRequestPayload::AuthChallengeRequest(req) => {
self.handle_auth_challenge_request(req).await
}
UserAgentRequestPayload::AuthChallengeSolution(solution) => {
self.handle_auth_challenge_solution(solution).await
}
UserAgentRequestPayload::UnsealStart(unseal_start) => {
self.handle_unseal_request(unseal_start).await
}
@@ -171,7 +166,7 @@ where
self.transition(UserAgentEvents::ReceivedBootstrapToken)?;
Ok(auth_response(ServerAuthPayload::AuthOk(AuthOk {})))
Ok(response(UserAgentResponsePayload::AuthOk(AuthOk {})))
}
async fn auth_with_challenge(&mut self, pubkey: VerifyingKey, pubkey_bytes: Vec<u8>) -> Output {
@@ -215,7 +210,7 @@ where
return Err(UserAgentError::PublicKeyNotRegistered);
};
let challenge = auth::AuthChallenge {
let challenge = AuthChallenge {
pubkey: pubkey_bytes,
nonce,
};
@@ -231,19 +226,22 @@ where
"Sent authentication challenge to client"
);
Ok(auth_response(ServerAuthPayload::AuthChallenge(challenge)))
Ok(response(UserAgentResponsePayload::AuthChallenge(challenge)))
}
fn verify_challenge_solution(
&self,
solution: &auth::AuthChallengeSolution,
solution: &AuthChallengeSolution,
) -> Result<(bool, &ChallengeContext), UserAgentError> {
let UserAgentStates::WaitingForChallengeSolution(challenge_context) = self.state.state()
else {
error!("Received challenge solution in invalid state");
return Err(UserAgentError::InvalidStateForChallengeSolution);
};
let formatted_challenge = arbiter_proto::format_challenge(&challenge_context.challenge);
let formatted_challenge = arbiter_proto::format_challenge(
challenge_context.challenge.nonce,
&challenge_context.challenge.pubkey,
);
let signature = solution.signature.as_slice().try_into().map_err(|_| {
error!(?solution, "Invalid signature length");
@@ -261,15 +259,7 @@ where
type Output = Result<UserAgentResponse, UserAgentError>;
fn auth_response(payload: ServerAuthPayload) -> UserAgentResponse {
UserAgentResponse {
payload: Some(UserAgentResponsePayload::AuthMessage(AuthServerMessage {
payload: Some(payload),
})),
}
}
fn unseal_response(payload: UserAgentResponsePayload) -> UserAgentResponse {
fn response(payload: UserAgentResponsePayload) -> UserAgentResponse {
UserAgentResponse {
payload: Some(payload),
}
@@ -295,7 +285,7 @@ where
client_public_key,
}))?;
Ok(unseal_response(
Ok(response(
UserAgentResponsePayload::UnsealStartResponse(UnsealStartResponse {
server_pubkey: public_key.as_bytes().to_vec(),
}),
@@ -316,7 +306,7 @@ where
drop(secret_lock);
error!("Ephemeral secret already taken");
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
return Ok(unseal_response(UserAgentResponsePayload::UnsealResult(
return Ok(response(UserAgentResponsePayload::UnsealResult(
UnsealResult::InvalidKey.into(),
)));
}
@@ -349,20 +339,20 @@ where
Ok(_) => {
info!("Successfully unsealed key with client-provided key");
self.transition(UserAgentEvents::ReceivedValidKey)?;
Ok(unseal_response(UserAgentResponsePayload::UnsealResult(
Ok(response(UserAgentResponsePayload::UnsealResult(
UnsealResult::Success.into(),
)))
}
Err(SendError::HandlerError(keyholder::Error::InvalidKey)) => {
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Ok(unseal_response(UserAgentResponsePayload::UnsealResult(
Ok(response(UserAgentResponsePayload::UnsealResult(
UnsealResult::InvalidKey.into(),
)))
}
Err(SendError::HandlerError(err)) => {
error!(?err, "Keyholder failed to unseal key");
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Ok(unseal_response(UserAgentResponsePayload::UnsealResult(
Ok(response(UserAgentResponsePayload::UnsealResult(
UnsealResult::InvalidKey.into(),
)))
}
@@ -376,7 +366,7 @@ where
Err(err) => {
error!(?err, "Failed to decrypt unseal key");
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Ok(unseal_response(UserAgentResponsePayload::UnsealResult(
Ok(response(UserAgentResponsePayload::UnsealResult(
UnsealResult::InvalidKey.into(),
)))
}
@@ -403,7 +393,7 @@ where
async fn handle_auth_challenge_solution(
&mut self,
solution: auth::AuthChallengeSolution,
solution: AuthChallengeSolution,
) -> Output {
let (valid, challenge_context) = self.verify_challenge_solution(&solution)?;
@@ -413,7 +403,7 @@ where
"Client provided valid solution to authentication challenge"
);
self.transition(UserAgentEvents::ReceivedGoodSolution)?;
Ok(auth_response(ServerAuthPayload::AuthOk(AuthOk {})))
Ok(response(UserAgentResponsePayload::AuthOk(AuthOk {})))
} else {
error!("Client provided invalid solution to authentication challenge");
self.transition(UserAgentEvents::ReceivedBadSolution)?;

View File

@@ -1,6 +1,6 @@
use std::sync::Mutex;
use arbiter_proto::proto::auth::AuthChallenge;
use arbiter_proto::proto::user_agent::AuthChallenge;
use ed25519_dalek::VerifyingKey;
use x25519_dalek::{EphemeralSecret, PublicKey};

View File

@@ -1,6 +1,9 @@
#![forbid(unsafe_code)]
use arbiter_proto::{
proto::{ClientRequest, ClientResponse, UserAgentRequest, UserAgentResponse},
proto::{
client::{ClientRequest, ClientResponse},
user_agent::{UserAgentRequest, UserAgentResponse},
},
transport::{IdentityRecvConverter, SendConverter, grpc},
};
use async_trait::async_trait;
@@ -12,7 +15,10 @@ use tonic::{Request, Response, Status};
use tracing::info;
use crate::{
actors::user_agent::{UserAgentActor, UserAgentError},
actors::{
client::{ClientActor, ClientError},
user_agent::{UserAgentActor, UserAgentError},
},
context::ServerContext,
};
@@ -41,6 +47,56 @@ impl SendConverter for UserAgentGrpcSender {
}
}
/// Converts Client domain outbounds into the tonic stream item emitted by the
/// server.
///
/// The conversion is defined at the server boundary so the actor module remains
/// focused on domain semantics and does not depend on tonic status encoding.
struct ClientGrpcSender;
impl SendConverter for ClientGrpcSender {
type Input = Result<ClientResponse, ClientError>;
type Output = Result<ClientResponse, Status>;
fn convert(&self, item: Self::Input) -> Self::Output {
match item {
Ok(message) => Ok(message),
Err(err) => Err(client_error_status(err)),
}
}
}
/// Maps Client domain errors to public gRPC transport errors for the `client`
/// streaming endpoint.
fn client_error_status(value: ClientError) -> Status {
match value {
ClientError::MissingRequestPayload | ClientError::UnexpectedRequestPayload => {
Status::invalid_argument("Expected message with payload")
}
ClientError::InvalidStateForChallengeSolution => {
Status::invalid_argument("Invalid state for challenge solution")
}
ClientError::InvalidAuthPubkeyLength => {
Status::invalid_argument("Expected pubkey to have specific length")
}
ClientError::InvalidAuthPubkeyEncoding => {
Status::invalid_argument("Failed to convert pubkey to VerifyingKey")
}
ClientError::InvalidSignatureLength => {
Status::invalid_argument("Invalid signature length")
}
ClientError::PublicKeyNotRegistered => {
Status::unauthenticated("Public key not registered")
}
ClientError::InvalidChallengeSolution => {
Status::unauthenticated("Invalid challenge solution")
}
ClientError::StateTransitionFailed => Status::internal("State machine error"),
ClientError::DatabasePoolUnavailable => Status::internal("Database pool error"),
ClientError::DatabaseOperationFailed => Status::internal("Database error"),
}
}
/// Maps User Agent domain errors to public gRPC transport errors for the
/// `user_agent` streaming endpoint.
fn user_agent_error_status(value: UserAgentError) -> Status {
@@ -100,11 +156,25 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server {
type UserAgentStream = ReceiverStream<Result<UserAgentResponse, Status>>;
type ClientStream = ReceiverStream<Result<ClientResponse, Status>>;
#[tracing::instrument(level = "debug", skip(self))]
async fn client(
&self,
_request: Request<tonic::Streaming<ClientRequest>>,
request: Request<tonic::Streaming<ClientRequest>>,
) -> Result<Response<Self::ClientStream>, Status> {
todo!()
let req_stream = request.into_inner();
let (tx, rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);
let transport = grpc::GrpcAdapter::new(
tx,
req_stream,
IdentityRecvConverter::<ClientRequest>::new(),
ClientGrpcSender,
);
ClientActor::spawn(ClientActor::new(self.context.clone(), transport));
info!(event = "connection established", "grpc.client");
Ok(Response::new(ReceiverStream::new(rx)))
}
#[tracing::instrument(level = "debug", skip(self))]