feat(transport): add domain error type to GrpcTransportActor
This commit is contained in:
@@ -1,7 +1,56 @@
|
||||
//! Transport abstraction layer for bridging gRPC bidirectional streaming with kameo actors.
|
||||
//!
|
||||
//! This module provides a clean separation between the gRPC transport layer and business logic
|
||||
//! by modeling the connection as two linked kameo actors:
|
||||
//!
|
||||
//! - A **transport actor** ([`GrpcTransportActor`]) that owns the gRPC stream and channel,
|
||||
//! forwarding inbound messages to the business actor and outbound messages to the client.
|
||||
//! - A **business logic actor** that receives inbound messages from the transport actor and
|
||||
//! sends outbound messages back through the transport actor.
|
||||
//!
|
||||
//! The [`wire()`] function sets up bidirectional linking between the two actors, ensuring
|
||||
//! that if either actor dies, the other is notified and can shut down gracefully.
|
||||
//!
|
||||
//! # Terminology
|
||||
//!
|
||||
//! - **InboundMessage**: a message received by the transport actor from the channel/socket
|
||||
//! and forwarded to the business actor.
|
||||
//! - **OutboundMessage**: a message produced by the business actor and sent to the transport
|
||||
//! actor to be forwarded to the channel/socket.
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! ```text
|
||||
//! gRPC Stream ──InboundMessage──▶ GrpcTransportActor ──tell(InboundMessage)──▶ BusinessActor
|
||||
//! ▲ │
|
||||
//! └─tell(Result<OutboundMessage, _>)────┘
|
||||
//! │
|
||||
//! mpsc::Sender ──▶ Client
|
||||
//! ```
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! let (tx, rx) = mpsc::channel(1000);
|
||||
//! let context = server_context.clone();
|
||||
//!
|
||||
//! wire(
|
||||
//! |transport_ref| MyBusinessActor::new(context, transport_ref),
|
||||
//! |business_recipient, business_id| GrpcTransportActor {
|
||||
//! sender: tx,
|
||||
//! receiver: grpc_stream,
|
||||
//! business_logic_actor: business_recipient,
|
||||
//! business_logic_actor_id: business_id,
|
||||
//! },
|
||||
//! ).await;
|
||||
//!
|
||||
//! Ok(Response::new(ReceiverStream::new(rx)))
|
||||
//! ```
|
||||
|
||||
use futures::{Stream, StreamExt};
|
||||
use kameo::{
|
||||
Actor,
|
||||
actor::{ActorRef, PreparedActor, Spawn, WeakActorRef},
|
||||
actor::{ActorRef, PreparedActor, Recipient, Spawn, WeakActorRef},
|
||||
mailbox::Signal,
|
||||
prelude::Message,
|
||||
};
|
||||
@@ -12,7 +61,15 @@ use tokio::{
|
||||
use tonic::{Status, Streaming};
|
||||
use tracing::{debug, error};
|
||||
|
||||
// Abstraction for stream for sans-io capabilities
|
||||
/// A bidirectional stream abstraction for sans-io testing.
|
||||
///
|
||||
/// Combines a [`Stream`] of incoming messages with the ability to [`send`](Bi::send)
|
||||
/// outgoing responses. This trait allows business logic to be tested without a real
|
||||
/// gRPC connection by swapping in an in-memory implementation.
|
||||
///
|
||||
/// # Type Parameters
|
||||
/// - `T`: `InboundMessage` received from the channel/socket (e.g., `UserAgentRequest`)
|
||||
/// - `U`: `OutboundMessage` sent to the channel/socket (e.g., `UserAgentResponse`)
|
||||
pub trait Bi<T, U>: Stream<Item = Result<T, Status>> + Send + Sync + 'static {
|
||||
type Error;
|
||||
fn send(
|
||||
@@ -21,7 +78,10 @@ pub trait Bi<T, U>: Stream<Item = Result<T, Status>> + Send + Sync + 'static {
|
||||
) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;
|
||||
}
|
||||
|
||||
// Bi-directional stream abstraction for handling gRPC streaming requests and responses
|
||||
/// Concrete [`Bi`] implementation backed by a tonic gRPC [`Streaming`] and an [`mpsc::Sender`].
|
||||
///
|
||||
/// This is the production implementation used in gRPC service handlers. The `request_stream`
|
||||
/// receives messages from the client, and `response_sender` sends responses back.
|
||||
pub struct BiStream<T, U> {
|
||||
pub request_stream: Streaming<T>,
|
||||
pub response_sender: mpsc::Sender<Result<U, Status>>,
|
||||
@@ -54,23 +114,95 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub trait TransportActor<T: Send + 'static>: Actor + Send + Message<T> {}
|
||||
|
||||
pub struct GrpcTransportActor<SendMsg, RecvMsg, A>
|
||||
where
|
||||
SendMsg: Send + 'static,
|
||||
RecvMsg: Send + 'static,
|
||||
A: TransportActor<RecvMsg>,
|
||||
/// Marker trait for transport actors that can receive outbound messages of type `T`.
|
||||
///
|
||||
/// Implement this on your transport actor to indicate it can handle outbound messages
|
||||
/// produced by the business actor. Requires the actor to implement [`Message<Result<T, E>>`]
|
||||
/// so business logic can forward responses via [`tell()`](ActorRef::tell).
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// #[derive(Actor)]
|
||||
/// struct MyTransportActor { /* ... */ }
|
||||
///
|
||||
/// impl Message<Result<MyResponse, MyError>> for MyTransportActor {
|
||||
/// type Reply = ();
|
||||
/// async fn handle(&mut self, msg: Result<MyResponse, MyError>, _ctx: &mut Context<Self, Self::Reply>) -> Self::Reply {
|
||||
/// // forward outbound message to channel/socket
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// impl TransportActor<MyResponse, MyError> for MyTransportActor {}
|
||||
/// ```
|
||||
pub trait TransportActor<Outbound: Send + 'static, DomainError: Send + 'static>:
|
||||
Actor + Send + Message<Result<Outbound, DomainError>>
|
||||
{
|
||||
pub sender: mpsc::Sender<Result<SendMsg, tonic::Status>>,
|
||||
pub receiver: tonic::Streaming<RecvMsg>,
|
||||
pub business_logic_actor: ActorRef<A>,
|
||||
}
|
||||
impl<SendMsg, RecvMsg, A> Actor for GrpcTransportActor<SendMsg, RecvMsg, A>
|
||||
|
||||
/// A kameo actor that bridges a gRPC bidirectional stream with a business logic actor.
|
||||
///
|
||||
/// This actor owns the gRPC [`Streaming`] receiver and an [`mpsc::Sender`] for responses.
|
||||
/// It multiplexes between its own mailbox (for outbound messages from the business actor)
|
||||
/// and the gRPC stream (for inbound client messages) using [`tokio::select!`].
|
||||
///
|
||||
/// # Message Flow
|
||||
///
|
||||
/// - **Inbound**: Messages from the gRPC stream are forwarded to `business_logic_actor`
|
||||
/// via [`tell()`](Recipient::tell).
|
||||
/// - **Outbound**: The business actor sends `Result<Outbound, DomainError>` messages to this
|
||||
/// actor, which forwards them through the `sender` channel to the gRPC response stream.
|
||||
///
|
||||
/// # Lifecycle
|
||||
///
|
||||
/// - If the business logic actor dies (detected via actor linking), this actor stops,
|
||||
/// which closes the gRPC stream.
|
||||
/// - If the gRPC stream closes or errors, this actor stops, which (via linking) notifies
|
||||
/// the business actor.
|
||||
/// - Error responses (`Err(DomainError)`) are forwarded to the client and then the actor stops,
|
||||
/// closing the connection.
|
||||
///
|
||||
/// # Type Parameters
|
||||
/// - `Outbound`: `OutboundMessage` sent to the client (e.g., `UserAgentResponse`)
|
||||
/// - `Inbound`: `InboundMessage` received from the client (e.g., `UserAgentRequest`)
|
||||
/// - `E`: The domain error type, must implement `Into<tonic::Status>` for gRPC conversion
|
||||
pub struct GrpcTransportActor<Outbound, Inbound, DomainError>
|
||||
where
|
||||
SendMsg: Send + 'static,
|
||||
RecvMsg: Send + 'static,
|
||||
A: TransportActor<RecvMsg>,
|
||||
Outbound: Send + 'static,
|
||||
Inbound: Send + 'static,
|
||||
DomainError: Into<tonic::Status> + Send + 'static,
|
||||
{
|
||||
sender: mpsc::Sender<Result<Outbound, tonic::Status>>,
|
||||
receiver: tonic::Streaming<Inbound>,
|
||||
business_logic_actor: Recipient<Inbound>,
|
||||
_error: std::marker::PhantomData<DomainError>,
|
||||
}
|
||||
|
||||
impl<Outbound, Inbound, DomainError> GrpcTransportActor<Outbound, Inbound, DomainError>
|
||||
where
|
||||
Outbound: Send + 'static,
|
||||
Inbound: Send + 'static,
|
||||
DomainError: Into<tonic::Status> + Send + 'static,
|
||||
{
|
||||
pub fn new(
|
||||
sender: mpsc::Sender<Result<Outbound, tonic::Status>>,
|
||||
receiver: tonic::Streaming<Inbound>,
|
||||
business_logic_actor: Recipient<Inbound>,
|
||||
) -> Self {
|
||||
Self {
|
||||
sender,
|
||||
receiver,
|
||||
business_logic_actor,
|
||||
_error: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Outbound, Inbound, E> Actor for GrpcTransportActor<Outbound, Inbound, E>
|
||||
where
|
||||
Outbound: Send + 'static,
|
||||
Inbound: Send + 'static,
|
||||
E: Into<tonic::Status> + Send + 'static,
|
||||
{
|
||||
type Args = Self;
|
||||
|
||||
@@ -139,19 +271,27 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<SendMsg: Send + 'static, RecvMsg: Send + 'static, A: TransportActor<RecvMsg>> Message<SendMsg>
|
||||
for GrpcTransportActor<SendMsg, RecvMsg, A>
|
||||
impl<Outbound, Inbound, E> Message<Result<Outbound, E>> for GrpcTransportActor<Outbound, Inbound, E>
|
||||
where
|
||||
Outbound: Send + 'static,
|
||||
Inbound: Send + 'static,
|
||||
E: Into<tonic::Status> + Send + 'static,
|
||||
{
|
||||
type Reply = ();
|
||||
|
||||
async fn handle(
|
||||
&mut self,
|
||||
msg: SendMsg,
|
||||
msg: Result<Outbound, E>,
|
||||
ctx: &mut kameo::prelude::Context<Self, Self::Reply>,
|
||||
) -> Self::Reply {
|
||||
let err = self.sender.send(Ok(msg)).await;
|
||||
match err {
|
||||
Ok(_) => (),
|
||||
let is_err = msg.is_err();
|
||||
let grpc_msg = msg.map_err(Into::into);
|
||||
match self.sender.send(grpc_msg).await {
|
||||
Ok(_) => {
|
||||
if is_err {
|
||||
ctx.stop();
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to send message: {}", e);
|
||||
ctx.stop();
|
||||
@@ -160,20 +300,60 @@ impl<SendMsg: Send + 'static, RecvMsg: Send + 'static, A: TransportActor<RecvMsg
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn wire<T, RecvMsg, SendMsg, BusinessActor, BusinessCtor, TransportCtor>(
|
||||
impl<Outbound, Inbound, E> TransportActor<Outbound, E> for GrpcTransportActor<Outbound, Inbound, E>
|
||||
where
|
||||
Outbound: Send + 'static,
|
||||
Inbound: Send + 'static,
|
||||
E: Into<tonic::Status> + Send + 'static,
|
||||
{
|
||||
}
|
||||
|
||||
/// Wires together a transport actor and a business logic actor with bidirectional linking.
|
||||
///
|
||||
/// This function handles the chicken-and-egg problem of two actors that need references
|
||||
/// to each other at construction time. It uses kameo's [`PreparedActor`] to obtain
|
||||
/// [`ActorRef`]s before spawning, then links both actors so that if either dies,
|
||||
/// the other is notified via [`on_link_died`](Actor::on_link_died).
|
||||
///
|
||||
/// The business actor receives a type-erased [`Recipient<Result<Outbound, DomainError>>`] instead of an
|
||||
/// `ActorRef<Transport>`, keeping it decoupled from the concrete transport implementation.
|
||||
///
|
||||
/// # Type Parameters
|
||||
/// - `Transport`: The transport actor type (e.g., [`GrpcTransportActor`])
|
||||
/// - `Inbound`: `InboundMessage` received by the business actor from the transport
|
||||
/// - `Outbound`: `OutboundMessage` sent by the business actor back to the transport
|
||||
/// - `Business`: The business logic actor
|
||||
/// - `BusinessCtor`: Closure that receives a prepared business actor and transport recipient,
|
||||
/// spawns the business actor, and returns its [`ActorRef`]
|
||||
/// - `TransportCtor`: Closure that receives a prepared transport actor, a recipient for
|
||||
/// inbound messages, and the business actor id, then spawns the transport actor
|
||||
///
|
||||
/// # Returns
|
||||
/// A tuple of `(transport_ref, business_ref)` — actor references for both spawned actors.
|
||||
pub async fn wire<
|
||||
Transport,
|
||||
Inbound,
|
||||
Outbound,
|
||||
DomainError,
|
||||
Business,
|
||||
BusinessCtor,
|
||||
TransportCtor,
|
||||
>(
|
||||
business_ctor: BusinessCtor,
|
||||
transport_ctor: TransportCtor,
|
||||
) -> (ActorRef<T>, ActorRef<BusinessActor>)
|
||||
) -> (ActorRef<Transport>, ActorRef<Business>)
|
||||
where
|
||||
T: TransportActor<RecvMsg>,
|
||||
RecvMsg: Send + 'static,
|
||||
SendMsg: Send + 'static,
|
||||
BusinessActor: Actor + Send + 'static,
|
||||
BusinessCtor: FnOnce(ActorRef<T>) -> BusinessActor::Args,
|
||||
TransportCtor: FnOnce(ActorRef<BusinessActor>) -> T::Args,
|
||||
Transport: TransportActor<Outbound, DomainError>,
|
||||
Inbound: Send + 'static,
|
||||
Outbound: Send + 'static,
|
||||
DomainError: Send + 'static,
|
||||
Business: Actor + Message<Inbound> + Send + 'static,
|
||||
BusinessCtor: FnOnce(PreparedActor<Business>, Recipient<Result<Outbound, DomainError>>),
|
||||
TransportCtor:
|
||||
FnOnce(PreparedActor<Transport>, Recipient<Inbound>),
|
||||
{
|
||||
let prepared_business: PreparedActor<BusinessActor> = Spawn::prepare();
|
||||
let prepared_transport: PreparedActor<T> = Spawn::prepare();
|
||||
let prepared_business: PreparedActor<Business> = Spawn::prepare();
|
||||
let prepared_transport: PreparedActor<Transport> = Spawn::prepare();
|
||||
|
||||
let business_ref = prepared_business.actor_ref().clone();
|
||||
let transport_ref = prepared_transport.actor_ref().clone();
|
||||
@@ -181,8 +361,11 @@ where
|
||||
transport_ref.link(&business_ref).await;
|
||||
business_ref.link(&transport_ref).await;
|
||||
|
||||
let _ = prepared_business.spawn(business_ctor(transport_ref.clone()));
|
||||
let _ = prepared_transport.spawn(transport_ctor(business_ref.clone()));
|
||||
let recipient = transport_ref.clone().recipient();
|
||||
business_ctor(prepared_business, recipient);
|
||||
let business_recipient = business_ref.clone().recipient();
|
||||
transport_ctor(prepared_transport, business_recipient);
|
||||
|
||||
|
||||
(transport_ref, business_ref)
|
||||
}
|
||||
|
||||
57
server/crates/arbiter-server/src/actors/user_agent/error.rs
Normal file
57
server/crates/arbiter-server/src/actors/user_agent/error.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
use tonic::Status;
|
||||
|
||||
use crate::db;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum UserAgentError {
|
||||
#[error("Missing payload in request")]
|
||||
MissingPayload,
|
||||
|
||||
#[error("Invalid bootstrap token")]
|
||||
InvalidBootstrapToken,
|
||||
|
||||
#[error("Public key not registered")]
|
||||
PubkeyNotRegistered,
|
||||
|
||||
#[error("Invalid public key format")]
|
||||
InvalidPubkey,
|
||||
|
||||
#[error("Invalid signature length")]
|
||||
InvalidSignatureLength,
|
||||
|
||||
#[error("Invalid challenge solution")]
|
||||
InvalidChallengeSolution,
|
||||
|
||||
#[error("Invalid state for operation")]
|
||||
InvalidState,
|
||||
|
||||
#[error("Actor unavailable")]
|
||||
ActorUnavailable,
|
||||
|
||||
#[error("Database error")]
|
||||
Database(#[from] diesel::result::Error),
|
||||
|
||||
#[error("Database pool error")]
|
||||
DatabasePool(#[from] db::PoolError),
|
||||
}
|
||||
|
||||
impl From<UserAgentError> for Status {
|
||||
fn from(err: UserAgentError) -> Self {
|
||||
match err {
|
||||
UserAgentError::MissingPayload
|
||||
| UserAgentError::InvalidBootstrapToken
|
||||
| UserAgentError::InvalidPubkey
|
||||
| UserAgentError::InvalidSignatureLength => Status::invalid_argument(err.to_string()),
|
||||
|
||||
UserAgentError::PubkeyNotRegistered | UserAgentError::InvalidChallengeSolution => {
|
||||
Status::unauthenticated(err.to_string())
|
||||
}
|
||||
|
||||
UserAgentError::InvalidState => Status::failed_precondition(err.to_string()),
|
||||
|
||||
UserAgentError::ActorUnavailable
|
||||
| UserAgentError::Database(_)
|
||||
| UserAgentError::DatabasePool(_) => Status::internal(err.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,21 +1,23 @@
|
||||
use std::{ops::DerefMut, sync::Mutex};
|
||||
|
||||
use arbiter_proto::proto::{
|
||||
UnsealEncryptedKey, UnsealResult, UnsealStart, UnsealStartResponse, UserAgentResponse,
|
||||
UnsealEncryptedKey, UnsealResult, UnsealStart, UnsealStartResponse, UserAgentRequest,
|
||||
UserAgentResponse,
|
||||
auth::{
|
||||
self, AuthChallengeRequest, AuthOk, ServerMessage as AuthServerMessage,
|
||||
self, AuthChallengeRequest, AuthOk, ClientMessage as ClientAuthMessage,
|
||||
ServerMessage as AuthServerMessage,
|
||||
client_message::Payload as ClientAuthPayload,
|
||||
server_message::Payload as ServerAuthPayload,
|
||||
},
|
||||
user_agent_request::Payload as UserAgentRequestPayload,
|
||||
user_agent_response::Payload as UserAgentResponsePayload,
|
||||
};
|
||||
use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit};
|
||||
use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, dsl::update};
|
||||
use diesel_async::RunQueryDsl;
|
||||
use ed25519_dalek::VerifyingKey;
|
||||
use kameo::{Actor, error::SendError, messages};
|
||||
use kameo::{Actor, actor::Recipient, error::SendError, messages, prelude::Message};
|
||||
use memsafe::MemSafe;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use tonic::Status;
|
||||
use tracing::{error, info};
|
||||
use x25519_dalek::{EphemeralSecret, PublicKey};
|
||||
|
||||
@@ -31,53 +33,74 @@ use crate::{
|
||||
},
|
||||
},
|
||||
db::{self, schema},
|
||||
errors::GrpcStatusExt,
|
||||
};
|
||||
|
||||
mod error;
|
||||
mod state;
|
||||
|
||||
mod transport;
|
||||
pub(crate) use transport::handle_user_agent;
|
||||
pub use error::UserAgentError;
|
||||
|
||||
#[derive(Actor)]
|
||||
pub struct UserAgentActor {
|
||||
db: db::DatabasePool,
|
||||
actors: GlobalActors,
|
||||
state: UserAgentStateMachine<DummyContext>,
|
||||
// will be used in future
|
||||
_tx: Sender<Result<UserAgentResponse, Status>>,
|
||||
transport: Recipient<Result<UserAgentResponse, UserAgentError>>,
|
||||
}
|
||||
|
||||
impl UserAgentActor {
|
||||
pub(crate) fn new(
|
||||
context: ServerContext,
|
||||
tx: Sender<Result<UserAgentResponse, Status>>,
|
||||
transport: Recipient<Result<UserAgentResponse, UserAgentError>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
db: context.db.clone(),
|
||||
actors: context.actors.clone(),
|
||||
state: UserAgentStateMachine::new(DummyContext),
|
||||
_tx: tx,
|
||||
transport,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_manual(
|
||||
db: db::DatabasePool,
|
||||
actors: GlobalActors,
|
||||
tx: Sender<Result<UserAgentResponse, Status>>,
|
||||
transport: Recipient<Result<UserAgentResponse, UserAgentError>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
db,
|
||||
actors,
|
||||
state: UserAgentStateMachine::new(DummyContext),
|
||||
_tx: tx,
|
||||
transport,
|
||||
}
|
||||
}
|
||||
|
||||
fn transition(&mut self, event: UserAgentEvents) -> Result<(), Status> {
|
||||
async fn process_request(&mut self, req: UserAgentRequest) -> Output {
|
||||
let msg = req.payload.ok_or_else(|| {
|
||||
error!(actor = "useragent", "Received message with no payload");
|
||||
UserAgentError::MissingPayload
|
||||
})?;
|
||||
|
||||
match msg {
|
||||
UserAgentRequestPayload::AuthMessage(ClientAuthMessage {
|
||||
payload: Some(ClientAuthPayload::AuthChallengeRequest(req)),
|
||||
}) => self.handle_auth_challenge_request(req).await,
|
||||
UserAgentRequestPayload::AuthMessage(ClientAuthMessage {
|
||||
payload: Some(ClientAuthPayload::AuthChallengeSolution(solution)),
|
||||
}) => self.handle_auth_challenge_solution(solution).await,
|
||||
UserAgentRequestPayload::UnsealStart(unseal_start) => {
|
||||
self.handle_unseal_request(unseal_start).await
|
||||
}
|
||||
UserAgentRequestPayload::UnsealEncryptedKey(unseal_encrypted_key) => {
|
||||
self.handle_unseal_encrypted_key(unseal_encrypted_key).await
|
||||
}
|
||||
_ => Err(UserAgentError::MissingPayload),
|
||||
}
|
||||
}
|
||||
|
||||
fn transition(&mut self, event: UserAgentEvents) -> Result<(), UserAgentError> {
|
||||
self.state.process_event(event).map_err(|e| {
|
||||
error!(?e, "State transition failed");
|
||||
Status::internal("State machine error")
|
||||
UserAgentError::InvalidState
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -86,7 +109,7 @@ impl UserAgentActor {
|
||||
&mut self,
|
||||
pubkey: ed25519_dalek::VerifyingKey,
|
||||
token: String,
|
||||
) -> Result<UserAgentResponse, Status> {
|
||||
) -> Output {
|
||||
let token_ok: bool = self
|
||||
.actors
|
||||
.bootstrapper
|
||||
@@ -94,16 +117,16 @@ impl UserAgentActor {
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!(?pubkey, "Failed to consume bootstrap token: {e}");
|
||||
Status::internal("Bootstrap token consumption failed")
|
||||
UserAgentError::ActorUnavailable
|
||||
})?;
|
||||
|
||||
if !token_ok {
|
||||
error!(?pubkey, "Invalid bootstrap token provided");
|
||||
return Err(Status::invalid_argument("Invalid bootstrap token"));
|
||||
return Err(UserAgentError::InvalidBootstrapToken);
|
||||
}
|
||||
|
||||
{
|
||||
let mut conn = self.db.get().await.to_status()?;
|
||||
let mut conn = self.db.get().await?;
|
||||
|
||||
diesel::insert_into(schema::useragent_client::table)
|
||||
.values((
|
||||
@@ -111,8 +134,7 @@ impl UserAgentActor {
|
||||
schema::useragent_client::nonce.eq(1),
|
||||
))
|
||||
.execute(&mut conn)
|
||||
.await
|
||||
.to_status()?;
|
||||
.await?;
|
||||
}
|
||||
|
||||
self.transition(UserAgentEvents::ReceivedBootstrapToken)?;
|
||||
@@ -122,7 +144,7 @@ impl UserAgentActor {
|
||||
|
||||
async fn auth_with_challenge(&mut self, pubkey: VerifyingKey, pubkey_bytes: Vec<u8>) -> Output {
|
||||
let nonce: Option<i32> = {
|
||||
let mut db_conn = self.db.get().await.to_status()?;
|
||||
let mut db_conn = self.db.get().await?;
|
||||
db_conn
|
||||
.exclusive_transaction(|conn| {
|
||||
Box::pin(async move {
|
||||
@@ -146,13 +168,12 @@ impl UserAgentActor {
|
||||
})
|
||||
})
|
||||
.await
|
||||
.optional()
|
||||
.to_status()?
|
||||
.optional()?
|
||||
};
|
||||
|
||||
let Some(nonce) = nonce else {
|
||||
error!(?pubkey, "Public key not found in database");
|
||||
return Err(Status::unauthenticated("Public key not registered"));
|
||||
return Err(UserAgentError::PubkeyNotRegistered);
|
||||
};
|
||||
|
||||
let challenge = auth::AuthChallenge {
|
||||
@@ -177,19 +198,17 @@ impl UserAgentActor {
|
||||
fn verify_challenge_solution(
|
||||
&self,
|
||||
solution: &auth::AuthChallengeSolution,
|
||||
) -> Result<(bool, &ChallengeContext), Status> {
|
||||
) -> Result<(bool, &ChallengeContext), UserAgentError> {
|
||||
let UserAgentStates::WaitingForChallengeSolution(challenge_context) = self.state.state()
|
||||
else {
|
||||
error!("Received challenge solution in invalid state");
|
||||
return Err(Status::invalid_argument(
|
||||
"Invalid state for challenge solution",
|
||||
));
|
||||
return Err(UserAgentError::InvalidState);
|
||||
};
|
||||
let formatted_challenge = arbiter_proto::format_challenge(&challenge_context.challenge);
|
||||
|
||||
let signature = solution.signature.as_slice().try_into().map_err(|_| {
|
||||
error!(?solution, "Invalid signature length");
|
||||
Status::invalid_argument("Invalid signature length")
|
||||
UserAgentError::InvalidSignatureLength
|
||||
})?;
|
||||
|
||||
let valid = challenge_context
|
||||
@@ -201,7 +220,7 @@ impl UserAgentActor {
|
||||
}
|
||||
}
|
||||
|
||||
type Output = Result<UserAgentResponse, Status>;
|
||||
type Output = Result<UserAgentResponse, UserAgentError>;
|
||||
|
||||
fn auth_response(payload: ServerAuthPayload) -> UserAgentResponse {
|
||||
UserAgentResponse {
|
||||
@@ -227,7 +246,7 @@ impl UserAgentActor {
|
||||
let client_pubkey_bytes: [u8; 32] = req
|
||||
.client_pubkey
|
||||
.try_into()
|
||||
.map_err(|_| Status::invalid_argument("client_pubkey must be 32 bytes"))?;
|
||||
.map_err(|_| UserAgentError::InvalidPubkey)?;
|
||||
|
||||
let client_public_key = PublicKey::from(client_pubkey_bytes);
|
||||
|
||||
@@ -247,9 +266,7 @@ impl UserAgentActor {
|
||||
pub async fn handle_unseal_encrypted_key(&mut self, req: UnsealEncryptedKey) -> Output {
|
||||
let UserAgentStates::WaitingForUnsealKey(unseal_context) = self.state.state() else {
|
||||
error!("Received unseal encrypted key in invalid state");
|
||||
return Err(Status::failed_precondition(
|
||||
"Invalid state for unseal encrypted key",
|
||||
));
|
||||
return Err(UserAgentError::InvalidState);
|
||||
};
|
||||
let ephemeral_secret = {
|
||||
let mut secret_lock = unseal_context.secret.lock().unwrap();
|
||||
@@ -313,7 +330,7 @@ impl UserAgentActor {
|
||||
Err(err) => {
|
||||
error!(?err, "Failed to send unseal request to keyholder");
|
||||
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
|
||||
Err(Status::internal("Vault is not available"))
|
||||
Err(UserAgentError::ActorUnavailable)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -329,12 +346,13 @@ impl UserAgentActor {
|
||||
|
||||
#[message]
|
||||
pub async fn handle_auth_challenge_request(&mut self, req: AuthChallengeRequest) -> Output {
|
||||
let pubkey = req.pubkey.as_array().ok_or(Status::invalid_argument(
|
||||
"Expected pubkey to have specific length",
|
||||
))?;
|
||||
let pubkey = req
|
||||
.pubkey
|
||||
.as_array()
|
||||
.ok_or(UserAgentError::InvalidPubkey)?;
|
||||
let pubkey = VerifyingKey::from_bytes(pubkey).map_err(|_err| {
|
||||
error!(?pubkey, "Failed to convert to VerifyingKey");
|
||||
Status::invalid_argument("Failed to convert pubkey to VerifyingKey")
|
||||
UserAgentError::InvalidPubkey
|
||||
})?;
|
||||
|
||||
self.transition(UserAgentEvents::AuthRequest)?;
|
||||
@@ -362,7 +380,22 @@ impl UserAgentActor {
|
||||
} else {
|
||||
error!("Client provided invalid solution to authentication challenge");
|
||||
self.transition(UserAgentEvents::ReceivedBadSolution)?;
|
||||
Err(Status::unauthenticated("Invalid challenge solution"))
|
||||
Err(UserAgentError::InvalidChallengeSolution)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Message<UserAgentRequest> for UserAgentActor {
|
||||
type Reply = ();
|
||||
|
||||
async fn handle(
|
||||
&mut self,
|
||||
msg: UserAgentRequest,
|
||||
_ctx: &mut kameo::prelude::Context<Self, Self::Reply>,
|
||||
) -> Self::Reply {
|
||||
let result = self.process_request(msg).await;
|
||||
if let Err(e) = self.transport.tell(result).await {
|
||||
error!(actor = "useragent", "Failed to send response to transport: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
use super::UserAgentActor;
|
||||
use arbiter_proto::proto::{
|
||||
UserAgentRequest, UserAgentResponse,
|
||||
auth::{ClientMessage as ClientAuthMessage, client_message::Payload as ClientAuthPayload},
|
||||
user_agent_request::Payload as UserAgentRequestPayload,
|
||||
};
|
||||
use futures::StreamExt;
|
||||
use kameo::{
|
||||
actor::{ActorRef, Spawn as _},
|
||||
error::SendError,
|
||||
};
|
||||
use tokio::sync::mpsc;
|
||||
use tonic::Status;
|
||||
use tracing::error;
|
||||
|
||||
use crate::{
|
||||
actors::user_agent::{
|
||||
HandleAuthChallengeRequest, HandleAuthChallengeSolution, HandleUnsealEncryptedKey,
|
||||
HandleUnsealRequest,
|
||||
},
|
||||
context::ServerContext,
|
||||
};
|
||||
|
||||
pub(crate) async fn handle_user_agent(
|
||||
context: ServerContext,
|
||||
mut req_stream: tonic::Streaming<UserAgentRequest>,
|
||||
tx: mpsc::Sender<Result<UserAgentResponse, Status>>,
|
||||
) {
|
||||
let actor = UserAgentActor::spawn(UserAgentActor::new(context, tx.clone()));
|
||||
|
||||
while let Some(Ok(req)) = req_stream.next().await
|
||||
&& actor.is_alive()
|
||||
{
|
||||
match process_message(&actor, req).await {
|
||||
Ok(resp) => {
|
||||
if tx.send(Ok(resp)).await.is_err() {
|
||||
error!(actor = "useragent", "Failed to send response to client");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(status) => {
|
||||
let _ = tx.send(Err(status)).await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
actor.kill();
|
||||
}
|
||||
|
||||
async fn process_message(
|
||||
actor: &ActorRef<UserAgentActor>,
|
||||
req: UserAgentRequest,
|
||||
) -> Result<UserAgentResponse, Status> {
|
||||
let msg = req.payload.ok_or_else(|| {
|
||||
error!(actor = "useragent", "Received message with no payload");
|
||||
Status::invalid_argument("Expected message with payload")
|
||||
})?;
|
||||
|
||||
match msg {
|
||||
UserAgentRequestPayload::AuthMessage(ClientAuthMessage {
|
||||
payload: Some(ClientAuthPayload::AuthChallengeRequest(req)),
|
||||
}) => actor
|
||||
.ask(HandleAuthChallengeRequest { req })
|
||||
.await
|
||||
.map_err(into_status),
|
||||
UserAgentRequestPayload::AuthMessage(ClientAuthMessage {
|
||||
payload: Some(ClientAuthPayload::AuthChallengeSolution(solution)),
|
||||
}) => actor
|
||||
.ask(HandleAuthChallengeSolution { solution })
|
||||
.await
|
||||
.map_err(into_status),
|
||||
UserAgentRequestPayload::UnsealStart(unseal_start) => actor
|
||||
.ask(HandleUnsealRequest { req: unseal_start })
|
||||
.await
|
||||
.map_err(into_status),
|
||||
UserAgentRequestPayload::UnsealEncryptedKey(unseal_encrypted_key) => actor
|
||||
.ask(HandleUnsealEncryptedKey {
|
||||
req: unseal_encrypted_key,
|
||||
})
|
||||
.await
|
||||
.map_err(into_status),
|
||||
_ => Err(Status::invalid_argument("Expected message with payload")),
|
||||
}
|
||||
}
|
||||
|
||||
fn into_status<M>(e: SendError<M, Status>) -> Status {
|
||||
match e {
|
||||
SendError::HandlerError(status) => status,
|
||||
_ => {
|
||||
error!(actor = "useragent", "Failed to send message to actor");
|
||||
Status::internal("session failure")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
use tonic::Status;
|
||||
use tracing::error;
|
||||
|
||||
pub trait GrpcStatusExt<T> {
|
||||
fn to_status(self) -> Result<T, Status>;
|
||||
}
|
||||
|
||||
impl<T> GrpcStatusExt<T> for Result<T, diesel::result::Error> {
|
||||
fn to_status(self) -> Result<T, Status> {
|
||||
self.map_err(|e| {
|
||||
error!(error = ?e, "Database error");
|
||||
Status::internal("Database error")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> GrpcStatusExt<T> for Result<T, crate::db::PoolError> {
|
||||
fn to_status(self) -> Result<T, Status> {
|
||||
self.map_err(|e| {
|
||||
error!(error = ?e, "Database pool error");
|
||||
Status::internal("Database pool error")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,23 +1,26 @@
|
||||
#![forbid(unsafe_code)]
|
||||
use arbiter_proto::{
|
||||
proto::{ClientRequest, ClientResponse, UserAgentRequest, UserAgentResponse},
|
||||
transport::BiStream,
|
||||
transport::{BiStream, GrpcTransportActor, wire},
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use kameo::actor::PreparedActor;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
|
||||
use tokio::sync::mpsc;
|
||||
use tonic::{Request, Response, Status};
|
||||
|
||||
use crate::{
|
||||
actors::{client::handle_client, user_agent::handle_user_agent},
|
||||
actors::{
|
||||
client::handle_client,
|
||||
user_agent::UserAgentActor,
|
||||
},
|
||||
context::ServerContext,
|
||||
};
|
||||
|
||||
pub mod actors;
|
||||
pub mod context;
|
||||
pub mod db;
|
||||
mod errors;
|
||||
|
||||
const DEFAULT_CHANNEL_SIZE: usize = 1000;
|
||||
|
||||
@@ -59,7 +62,22 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server {
|
||||
) -> Result<Response<Self::UserAgentStream>, Status> {
|
||||
let req_stream = request.into_inner();
|
||||
let (tx, rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);
|
||||
tokio::spawn(handle_user_agent(self.context.clone(), req_stream, tx));
|
||||
let context = self.context.clone();
|
||||
|
||||
wire(
|
||||
|prepared: PreparedActor<UserAgentActor>, recipient| {
|
||||
prepared.spawn(UserAgentActor::new(context, recipient));
|
||||
},
|
||||
|prepared: PreparedActor<GrpcTransportActor<_, _, _>>, business_recipient| {
|
||||
prepared.spawn(GrpcTransportActor::new(
|
||||
tx,
|
||||
req_stream,
|
||||
business_recipient,
|
||||
));
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(Response::new(ReceiverStream::new(rx)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,30 @@
|
||||
mod common;
|
||||
|
||||
use arbiter_proto::proto::UserAgentResponse;
|
||||
use arbiter_server::actors::user_agent::UserAgentError;
|
||||
use kameo::{Actor, actor::Recipient, actor::Spawn, prelude::Message};
|
||||
|
||||
/// A no-op actor that discards any messages it receives.
|
||||
#[derive(Actor)]
|
||||
struct NullSink;
|
||||
|
||||
impl Message<Result<UserAgentResponse, UserAgentError>> for NullSink {
|
||||
type Reply = ();
|
||||
|
||||
async fn handle(
|
||||
&mut self,
|
||||
_msg: Result<UserAgentResponse, UserAgentError>,
|
||||
_ctx: &mut kameo::prelude::Context<Self, Self::Reply>,
|
||||
) -> Self::Reply {
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a `Recipient` that silently discards all messages.
|
||||
fn null_recipient() -> Recipient<Result<UserAgentResponse, UserAgentError>> {
|
||||
let actor_ref = NullSink::spawn(NullSink);
|
||||
actor_ref.recipient()
|
||||
}
|
||||
|
||||
#[path = "user_agent/auth.rs"]
|
||||
mod auth;
|
||||
#[path = "user_agent/unseal.rs"]
|
||||
|
||||
@@ -24,7 +24,7 @@ pub async fn test_bootstrap_token_auth() {
|
||||
let actors = GlobalActors::spawn(db.clone()).await.unwrap();
|
||||
let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap();
|
||||
let user_agent =
|
||||
UserAgentActor::new_manual(db.clone(), actors, tokio::sync::mpsc::channel(1).0);
|
||||
UserAgentActor::new_manual(db.clone(), actors, super::null_recipient());
|
||||
let user_agent_ref = UserAgentActor::spawn(user_agent);
|
||||
|
||||
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
|
||||
@@ -69,7 +69,7 @@ pub async fn test_bootstrap_invalid_token_auth() {
|
||||
|
||||
let actors = GlobalActors::spawn(db.clone()).await.unwrap();
|
||||
let user_agent =
|
||||
UserAgentActor::new_manual(db.clone(), actors, tokio::sync::mpsc::channel(1).0);
|
||||
UserAgentActor::new_manual(db.clone(), actors, super::null_recipient());
|
||||
let user_agent_ref = UserAgentActor::spawn(user_agent);
|
||||
|
||||
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
|
||||
@@ -85,15 +85,11 @@ pub async fn test_bootstrap_invalid_token_auth() {
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Err(kameo::error::SendError::HandlerError(status)) => {
|
||||
assert_eq!(status.code(), tonic::Code::InvalidArgument);
|
||||
insta::assert_debug_snapshot!(status, @r#"
|
||||
Status {
|
||||
code: InvalidArgument,
|
||||
message: "Invalid bootstrap token",
|
||||
source: None,
|
||||
}
|
||||
"#);
|
||||
Err(kameo::error::SendError::HandlerError(err)) => {
|
||||
assert!(
|
||||
matches!(err, arbiter_server::actors::user_agent::UserAgentError::InvalidBootstrapToken),
|
||||
"Expected InvalidBootstrapToken, got {err:?}"
|
||||
);
|
||||
}
|
||||
Err(other) => {
|
||||
panic!("Expected SendError::HandlerError, got {other:?}");
|
||||
@@ -111,7 +107,7 @@ pub async fn test_challenge_auth() {
|
||||
|
||||
let actors = GlobalActors::spawn(db.clone()).await.unwrap();
|
||||
let user_agent =
|
||||
UserAgentActor::new_manual(db.clone(), actors, tokio::sync::mpsc::channel(1).0);
|
||||
UserAgentActor::new_manual(db.clone(), actors, super::null_recipient());
|
||||
let user_agent_ref = UserAgentActor::spawn(user_agent);
|
||||
|
||||
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
|
||||
|
||||
@@ -35,7 +35,7 @@ async fn setup_authenticated_user_agent(
|
||||
actors.key_holder.ask(Seal).await.unwrap();
|
||||
|
||||
let user_agent =
|
||||
UserAgentActor::new_manual(db.clone(), actors.clone(), tokio::sync::mpsc::channel(1).0);
|
||||
UserAgentActor::new_manual(db.clone(), actors.clone(), super::null_recipient());
|
||||
let user_agent_ref = UserAgentActor::spawn(user_agent);
|
||||
|
||||
let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap();
|
||||
@@ -169,7 +169,7 @@ pub async fn test_unseal_start_without_auth_fails() {
|
||||
|
||||
let actors = GlobalActors::spawn(db.clone()).await.unwrap();
|
||||
let user_agent =
|
||||
UserAgentActor::new_manual(db.clone(), actors, tokio::sync::mpsc::channel(1).0);
|
||||
UserAgentActor::new_manual(db.clone(), actors, super::null_recipient());
|
||||
let user_agent_ref = UserAgentActor::spawn(user_agent);
|
||||
|
||||
let client_secret = EphemeralSecret::random();
|
||||
@@ -184,8 +184,11 @@ pub async fn test_unseal_start_without_auth_fails() {
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Err(kameo::error::SendError::HandlerError(status)) => {
|
||||
assert_eq!(status.code(), tonic::Code::Internal);
|
||||
Err(kameo::error::SendError::HandlerError(err)) => {
|
||||
assert!(
|
||||
matches!(err, arbiter_server::actors::user_agent::UserAgentError::InvalidState),
|
||||
"Expected InvalidState, got {err:?}"
|
||||
);
|
||||
}
|
||||
other => panic!("Expected state machine error, got {other:?}"),
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use arbiter_proto::proto::UserAgentRequest;
|
||||
use arbiter_proto::{proto::UserAgentRequest, transport::TransportActor};
|
||||
use ed25519_dalek::SigningKey;
|
||||
use kameo::{
|
||||
Actor, Reply,
|
||||
@@ -6,7 +6,6 @@ use kameo::{
|
||||
prelude::Message,
|
||||
};
|
||||
use smlang::statemachine;
|
||||
use tokio::sync::mpsc;
|
||||
use tonic::transport::CertificateDer;
|
||||
use tracing::{debug, error};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user