refactor(transport): implemented Bi stream based abstraction for actor communication with next loop override

This commit is contained in:
hdbg
2026-02-26 17:15:35 +01:00
parent 7b57965952
commit b8afd94b21
16 changed files with 1120 additions and 344 deletions

View File

@@ -1,21 +1,26 @@
use std::{ops::DerefMut, sync::Mutex};
use arbiter_proto::proto::{
UnsealEncryptedKey, UnsealResult, UnsealStart, UnsealStartResponse, UserAgentResponse,
auth::{
self, AuthChallengeRequest, AuthOk, ServerMessage as AuthServerMessage,
server_message::Payload as ServerAuthPayload,
use arbiter_proto::{
proto::{
UnsealEncryptedKey, UnsealResult, UnsealStart, UnsealStartResponse, UserAgentRequest,
UserAgentResponse,
auth::{
self, AuthChallengeRequest, AuthOk, ClientMessage as ClientAuthMessage,
ServerMessage as AuthServerMessage, client_message::Payload as ClientAuthPayload,
server_message::Payload as ServerAuthPayload,
},
user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload,
},
user_agent_response::Payload as UserAgentResponsePayload,
transport::{Bi, DummyTransport},
};
use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit};
use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, dsl::update};
use diesel_async::RunQueryDsl;
use ed25519_dalek::VerifyingKey;
use kameo::{Actor, error::SendError, messages};
use kameo::{Actor, error::SendError};
use memsafe::MemSafe;
use tokio::sync::mpsc::Sender;
use tonic::Status;
use tokio::select;
use tracing::{error, info};
use x25519_dalek::{EphemeralSecret, PublicKey};
@@ -31,62 +36,105 @@ use crate::{
},
},
db::{self, schema},
errors::GrpcStatusExt,
};
mod state;
mod transport;
pub(crate) use transport::handle_user_agent;
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum UserAgentError {
#[error("Expected message with payload")]
MissingRequestPayload,
#[error("Expected message with payload")]
UnexpectedRequestPayload,
#[error("Invalid state for challenge solution")]
InvalidStateForChallengeSolution,
#[error("Invalid state for unseal encrypted key")]
InvalidStateForUnsealEncryptedKey,
#[error("client_pubkey must be 32 bytes")]
InvalidClientPubkeyLength,
#[error("Expected pubkey to have specific length")]
InvalidAuthPubkeyLength,
#[error("Failed to convert pubkey to VerifyingKey")]
InvalidAuthPubkeyEncoding,
#[error("Invalid signature length")]
InvalidSignatureLength,
#[error("Invalid bootstrap token")]
InvalidBootstrapToken,
#[error("Public key not registered")]
PublicKeyNotRegistered,
#[error("Invalid challenge solution")]
InvalidChallengeSolution,
#[error("State machine error")]
StateTransitionFailed,
#[error("Bootstrap token consumption failed")]
BootstrapperActorUnreachable,
#[error("Vault is not available")]
KeyHolderActorUnreachable,
#[error("Database pool error")]
DatabasePoolUnavailable,
#[error("Database error")]
DatabaseOperationFailed,
}
#[derive(Actor)]
pub struct UserAgentActor {
pub struct UserAgentActor<Transport>
where
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
{
db: db::DatabasePool,
actors: GlobalActors,
state: UserAgentStateMachine<DummyContext>,
// will be used in future
_tx: Sender<Result<UserAgentResponse, Status>>,
transport: Transport,
}
impl UserAgentActor {
pub(crate) fn new(
context: ServerContext,
tx: Sender<Result<UserAgentResponse, Status>>,
) -> Self {
impl<Transport> UserAgentActor<Transport>
where
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
{
pub(crate) fn new(context: ServerContext, transport: Transport) -> Self {
Self {
db: context.db.clone(),
actors: context.actors.clone(),
state: UserAgentStateMachine::new(DummyContext),
_tx: tx,
transport,
}
}
pub fn new_manual(
db: db::DatabasePool,
actors: GlobalActors,
tx: Sender<Result<UserAgentResponse, Status>>,
) -> Self {
Self {
db,
actors,
state: UserAgentStateMachine::new(DummyContext),
_tx: tx,
}
}
fn transition(&mut self, event: UserAgentEvents) -> Result<(), Status> {
fn transition(&mut self, event: UserAgentEvents) -> Result<(), UserAgentError> {
self.state.process_event(event).map_err(|e| {
error!(?e, "State transition failed");
Status::internal("State machine error")
UserAgentError::StateTransitionFailed
})?;
Ok(())
}
pub async fn process_transport_inbound(&mut self, req: UserAgentRequest) -> Output {
let msg = req.payload.ok_or_else(|| {
error!(actor = "useragent", "Received message with no payload");
UserAgentError::MissingRequestPayload
})?;
match msg {
UserAgentRequestPayload::AuthMessage(ClientAuthMessage {
payload: Some(ClientAuthPayload::AuthChallengeRequest(req)),
}) => self.handle_auth_challenge_request(req).await,
UserAgentRequestPayload::AuthMessage(ClientAuthMessage {
payload: Some(ClientAuthPayload::AuthChallengeSolution(solution)),
}) => self.handle_auth_challenge_solution(solution).await,
UserAgentRequestPayload::UnsealStart(unseal_start) => {
self.handle_unseal_request(unseal_start).await
}
UserAgentRequestPayload::UnsealEncryptedKey(unseal_encrypted_key) => {
self.handle_unseal_encrypted_key(unseal_encrypted_key).await
}
_ => Err(UserAgentError::UnexpectedRequestPayload),
}
}
async fn auth_with_bootstrap_token(
&mut self,
pubkey: ed25519_dalek::VerifyingKey,
token: String,
) -> Result<UserAgentResponse, Status> {
) -> Result<UserAgentResponse, UserAgentError> {
let token_ok: bool = self
.actors
.bootstrapper
@@ -94,16 +142,19 @@ impl UserAgentActor {
.await
.map_err(|e| {
error!(?pubkey, "Failed to consume bootstrap token: {e}");
Status::internal("Bootstrap token consumption failed")
UserAgentError::BootstrapperActorUnreachable
})?;
if !token_ok {
error!(?pubkey, "Invalid bootstrap token provided");
return Err(Status::invalid_argument("Invalid bootstrap token"));
return Err(UserAgentError::InvalidBootstrapToken);
}
{
let mut conn = self.db.get().await.to_status()?;
let mut conn = self.db.get().await.map_err(|e| {
error!(error = ?e, "Database pool error");
UserAgentError::DatabasePoolUnavailable
})?;
diesel::insert_into(schema::useragent_client::table)
.values((
@@ -112,7 +163,10 @@ impl UserAgentActor {
))
.execute(&mut conn)
.await
.to_status()?;
.map_err(|e| {
error!(error = ?e, "Database error");
UserAgentError::DatabaseOperationFailed
})?;
}
self.transition(UserAgentEvents::ReceivedBootstrapToken)?;
@@ -122,7 +176,10 @@ impl UserAgentActor {
async fn auth_with_challenge(&mut self, pubkey: VerifyingKey, pubkey_bytes: Vec<u8>) -> Output {
let nonce: Option<i32> = {
let mut db_conn = self.db.get().await.to_status()?;
let mut db_conn = self.db.get().await.map_err(|e| {
error!(error = ?e, "Database pool error");
UserAgentError::DatabasePoolUnavailable
})?;
db_conn
.exclusive_transaction(|conn| {
Box::pin(async move {
@@ -147,12 +204,15 @@ impl UserAgentActor {
})
.await
.optional()
.to_status()?
.map_err(|e| {
error!(error = ?e, "Database error");
UserAgentError::DatabaseOperationFailed
})?
};
let Some(nonce) = nonce else {
error!(?pubkey, "Public key not found in database");
return Err(Status::unauthenticated("Public key not registered"));
return Err(UserAgentError::PublicKeyNotRegistered);
};
let challenge = auth::AuthChallenge {
@@ -177,19 +237,17 @@ impl UserAgentActor {
fn verify_challenge_solution(
&self,
solution: &auth::AuthChallengeSolution,
) -> Result<(bool, &ChallengeContext), Status> {
) -> Result<(bool, &ChallengeContext), UserAgentError> {
let UserAgentStates::WaitingForChallengeSolution(challenge_context) = self.state.state()
else {
error!("Received challenge solution in invalid state");
return Err(Status::invalid_argument(
"Invalid state for challenge solution",
));
return Err(UserAgentError::InvalidStateForChallengeSolution);
};
let formatted_challenge = arbiter_proto::format_challenge(&challenge_context.challenge);
let signature = solution.signature.as_slice().try_into().map_err(|_| {
error!(?solution, "Invalid signature length");
Status::invalid_argument("Invalid signature length")
UserAgentError::InvalidSignatureLength
})?;
let valid = challenge_context
@@ -201,7 +259,7 @@ impl UserAgentActor {
}
}
type Output = Result<UserAgentResponse, Status>;
type Output = Result<UserAgentResponse, UserAgentError>;
fn auth_response(payload: ServerAuthPayload) -> UserAgentResponse {
UserAgentResponse {
@@ -217,17 +275,18 @@ fn unseal_response(payload: UserAgentResponsePayload) -> UserAgentResponse {
}
}
#[messages]
impl UserAgentActor {
#[message]
pub async fn handle_unseal_request(&mut self, req: UnsealStart) -> Output {
impl<Transport> UserAgentActor<Transport>
where
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
{
async fn handle_unseal_request(&mut self, req: UnsealStart) -> Output {
let secret = EphemeralSecret::random();
let public_key = PublicKey::from(&secret);
let client_pubkey_bytes: [u8; 32] = req
.client_pubkey
.try_into()
.map_err(|_| Status::invalid_argument("client_pubkey must be 32 bytes"))?;
.map_err(|_| UserAgentError::InvalidClientPubkeyLength)?;
let client_public_key = PublicKey::from(client_pubkey_bytes);
@@ -243,13 +302,10 @@ impl UserAgentActor {
))
}
#[message]
pub async fn handle_unseal_encrypted_key(&mut self, req: UnsealEncryptedKey) -> Output {
async fn handle_unseal_encrypted_key(&mut self, req: UnsealEncryptedKey) -> Output {
let UserAgentStates::WaitingForUnsealKey(unseal_context) = self.state.state() else {
error!("Received unseal encrypted key in invalid state");
return Err(Status::failed_precondition(
"Invalid state for unseal encrypted key",
));
return Err(UserAgentError::InvalidStateForUnsealEncryptedKey);
};
let ephemeral_secret = {
let mut secret_lock = unseal_context.secret.lock().unwrap();
@@ -313,7 +369,7 @@ impl UserAgentActor {
Err(err) => {
error!(?err, "Failed to send unseal request to keyholder");
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Err(Status::internal("Vault is not available"))
Err(UserAgentError::KeyHolderActorUnreachable)
}
}
}
@@ -327,14 +383,14 @@ impl UserAgentActor {
}
}
#[message]
pub async fn handle_auth_challenge_request(&mut self, req: AuthChallengeRequest) -> Output {
let pubkey = req.pubkey.as_array().ok_or(Status::invalid_argument(
"Expected pubkey to have specific length",
))?;
async fn handle_auth_challenge_request(&mut self, req: AuthChallengeRequest) -> Output {
let pubkey = req
.pubkey
.as_array()
.ok_or(UserAgentError::InvalidAuthPubkeyLength)?;
let pubkey = VerifyingKey::from_bytes(pubkey).map_err(|_err| {
error!(?pubkey, "Failed to convert to VerifyingKey");
Status::invalid_argument("Failed to convert pubkey to VerifyingKey")
UserAgentError::InvalidAuthPubkeyEncoding
})?;
self.transition(UserAgentEvents::AuthRequest)?;
@@ -345,8 +401,7 @@ impl UserAgentActor {
}
}
#[message]
pub async fn handle_auth_challenge_solution(
async fn handle_auth_challenge_solution(
&mut self,
solution: auth::AuthChallengeSolution,
) -> Output {
@@ -362,7 +417,72 @@ impl UserAgentActor {
} else {
error!("Client provided invalid solution to authentication challenge");
self.transition(UserAgentEvents::ReceivedBadSolution)?;
Err(Status::unauthenticated("Invalid challenge solution"))
Err(UserAgentError::InvalidChallengeSolution)
}
}
}
impl<Transport> Actor for UserAgentActor<Transport>
where
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
{
type Args = Self;
type Error = ();
async fn on_start(
args: Self::Args,
_: kameo::prelude::ActorRef<Self>,
) -> Result<Self, Self::Error> {
Ok(args)
}
async fn next(
&mut self,
_actor_ref: kameo::prelude::WeakActorRef<Self>,
mailbox_rx: &mut kameo::prelude::MailboxReceiver<Self>,
) -> Option<kameo::mailbox::Signal<Self>> {
loop {
select! {
signal = mailbox_rx.recv() => {
return signal;
}
msg = self.transport.recv() => {
match msg {
Some(request) => {
match self.process_transport_inbound(request).await {
Ok(response) => {
if self.transport.send(Ok(response)).await.is_err() {
error!(actor = "useragent", reason = "channel closed", "send.failed");
return Some(kameo::mailbox::Signal::Stop);
}
}
Err(err) => {
let _ = self.transport.send(Err(err)).await;
return Some(kameo::mailbox::Signal::Stop);
}
}
}
None => {
info!(actor = "useragent", "transport.closed");
return Some(kameo::mailbox::Signal::Stop);
}
}
}
}
}
}
}
impl UserAgentActor<DummyTransport<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>> {
pub fn new_manual(db: db::DatabasePool, actors: GlobalActors) -> Self {
Self {
db,
actors,
state: UserAgentStateMachine::new(DummyContext),
transport: DummyTransport::new(),
}
}
}

View File

@@ -1,95 +0,0 @@
use super::UserAgentActor;
use arbiter_proto::proto::{
UserAgentRequest, UserAgentResponse,
auth::{ClientMessage as ClientAuthMessage, client_message::Payload as ClientAuthPayload},
user_agent_request::Payload as UserAgentRequestPayload,
};
use futures::StreamExt;
use kameo::{
actor::{ActorRef, Spawn as _},
error::SendError,
};
use tokio::sync::mpsc;
use tonic::Status;
use tracing::error;
use crate::{
actors::user_agent::{
HandleAuthChallengeRequest, HandleAuthChallengeSolution, HandleUnsealEncryptedKey,
HandleUnsealRequest,
},
context::ServerContext,
};
pub(crate) async fn handle_user_agent(
context: ServerContext,
mut req_stream: tonic::Streaming<UserAgentRequest>,
tx: mpsc::Sender<Result<UserAgentResponse, Status>>,
) {
let actor = UserAgentActor::spawn(UserAgentActor::new(context, tx.clone()));
while let Some(Ok(req)) = req_stream.next().await
&& actor.is_alive()
{
match process_message(&actor, req).await {
Ok(resp) => {
if tx.send(Ok(resp)).await.is_err() {
error!(actor = "useragent", "Failed to send response to client");
break;
}
}
Err(status) => {
let _ = tx.send(Err(status)).await;
break;
}
}
}
actor.kill();
}
async fn process_message(
actor: &ActorRef<UserAgentActor>,
req: UserAgentRequest,
) -> Result<UserAgentResponse, Status> {
let msg = req.payload.ok_or_else(|| {
error!(actor = "useragent", "Received message with no payload");
Status::invalid_argument("Expected message with payload")
})?;
match msg {
UserAgentRequestPayload::AuthMessage(ClientAuthMessage {
payload: Some(ClientAuthPayload::AuthChallengeRequest(req)),
}) => actor
.ask(HandleAuthChallengeRequest { req })
.await
.map_err(into_status),
UserAgentRequestPayload::AuthMessage(ClientAuthMessage {
payload: Some(ClientAuthPayload::AuthChallengeSolution(solution)),
}) => actor
.ask(HandleAuthChallengeSolution { solution })
.await
.map_err(into_status),
UserAgentRequestPayload::UnsealStart(unseal_start) => actor
.ask(HandleUnsealRequest { req: unseal_start })
.await
.map_err(into_status),
UserAgentRequestPayload::UnsealEncryptedKey(unseal_encrypted_key) => actor
.ask(HandleUnsealEncryptedKey {
req: unseal_encrypted_key,
})
.await
.map_err(into_status),
_ => Err(Status::invalid_argument("Expected message with payload")),
}
}
fn into_status<M>(e: SendError<M, Status>) -> Status {
match e {
SendError::HandlerError(status) => status,
_ => {
error!(actor = "useragent", "Failed to send message to actor");
Status::internal("session failure")
}
}
}