feat(server::{router, useragent}): inter-actor approval coordination
This commit is contained in:
@@ -1,17 +1,20 @@
|
||||
use std::{
|
||||
collections::{HashMap},
|
||||
ops::ControlFlow,
|
||||
};
|
||||
use std::{collections::HashMap, ops::ControlFlow};
|
||||
|
||||
use ed25519_dalek::VerifyingKey;
|
||||
use kameo::{
|
||||
Actor,
|
||||
actor::{ActorId, ActorRef},
|
||||
messages,
|
||||
prelude::{ActorStopReason, Context, WeakActorRef},
|
||||
reply::DelegatedReply,
|
||||
};
|
||||
use tracing::info;
|
||||
use tokio::{sync::watch, task::JoinSet};
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::actors::{client::session::ClientSession, user_agent::session::UserAgentSession};
|
||||
use crate::actors::{
|
||||
client::session::ClientSession,
|
||||
user_agent::session::{RequestNewClientApproval, UserAgentSession},
|
||||
};
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct MessageRouter {
|
||||
@@ -53,6 +56,74 @@ impl Actor for MessageRouter {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ApprovalError {
|
||||
#[error("No user agents connected")]
|
||||
NoUserAgentsConnected,
|
||||
}
|
||||
|
||||
async fn request_client_approval(
|
||||
user_agents: &[WeakActorRef<UserAgentSession>],
|
||||
client_pubkey: VerifyingKey,
|
||||
) -> Result<bool, ApprovalError> {
|
||||
if user_agents.is_empty() {
|
||||
return Err(ApprovalError::NoUserAgentsConnected).into();
|
||||
}
|
||||
|
||||
let mut pool = JoinSet::new();
|
||||
let (cancel_tx, cancel_rx) = watch::channel(());
|
||||
|
||||
for weak_ref in user_agents {
|
||||
match weak_ref.upgrade() {
|
||||
Some(agent) => {
|
||||
let client_pubkey = client_pubkey.clone();
|
||||
let cancel_rx = cancel_rx.clone();
|
||||
pool.spawn(async move {
|
||||
agent
|
||||
.ask(RequestNewClientApproval {
|
||||
client_pubkey,
|
||||
cancel_flag: cancel_rx.clone(),
|
||||
})
|
||||
.await
|
||||
});
|
||||
}
|
||||
None => {
|
||||
warn!(
|
||||
id = weak_ref.id().to_string(),
|
||||
actor = "MessageRouter",
|
||||
event = "useragent.disconnected_before_approval"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
while let Some(result) = pool.join_next().await {
|
||||
match result {
|
||||
Ok(Ok(approved)) => {
|
||||
// cancel other pending requests
|
||||
let _ = cancel_tx.send(());
|
||||
return Ok(approved);
|
||||
}
|
||||
Ok(Err(err)) => {
|
||||
warn!(
|
||||
?err,
|
||||
actor = "MessageRouter",
|
||||
event = "useragent.approval_error"
|
||||
);
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
?err,
|
||||
actor = "MessageRouter",
|
||||
event = "useragent.approval_task_failed"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(ApprovalError::NoUserAgentsConnected)
|
||||
}
|
||||
|
||||
#[messages]
|
||||
impl MessageRouter {
|
||||
#[message(ctx)]
|
||||
@@ -76,4 +147,29 @@ impl MessageRouter {
|
||||
ctx.actor_ref().link(&actor).await;
|
||||
self.clients.insert(actor.id(), actor);
|
||||
}
|
||||
|
||||
#[message(ctx)]
|
||||
pub async fn request_client_approval(
|
||||
&mut self,
|
||||
client_pubkey: VerifyingKey,
|
||||
ctx: &mut Context<Self, DelegatedReply<Result<bool, ApprovalError>>>,
|
||||
) -> DelegatedReply<Result<bool, ApprovalError>> {
|
||||
let (reply, Some(reply_sender)) = ctx.reply_sender() else {
|
||||
panic!("Exptected `request_client_approval` to have callback channel");
|
||||
};
|
||||
|
||||
let weak_refs = self
|
||||
.user_agents
|
||||
.values()
|
||||
.map(|agent| agent.downgrade())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// handle in subtask to not to lock the actor
|
||||
tokio::task::spawn(async move {
|
||||
let result = request_client_approval(&weak_refs, client_pubkey).await;
|
||||
let _ = reply_sender.send(result);
|
||||
});
|
||||
|
||||
reply
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ use crate::{
|
||||
};
|
||||
|
||||
#[derive(Debug, thiserror::Error, PartialEq)]
|
||||
pub enum UserAgentError {
|
||||
pub enum TransportResponseError {
|
||||
#[error("Expected message with payload")]
|
||||
MissingRequestPayload,
|
||||
#[error("Unexpected request payload")]
|
||||
@@ -31,7 +31,7 @@ pub enum UserAgentError {
|
||||
}
|
||||
|
||||
pub type Transport =
|
||||
Box<dyn Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>> + Send>;
|
||||
Box<dyn Bi<UserAgentRequest, Result<UserAgentResponse, TransportResponseError>> + Send>;
|
||||
|
||||
pub struct UserAgentConnection {
|
||||
db: db::DatabasePool,
|
||||
|
||||
@@ -1,27 +1,41 @@
|
||||
use std::{ops::DerefMut, sync::Mutex};
|
||||
|
||||
use arbiter_proto::proto::user_agent::{
|
||||
UnsealEncryptedKey, UnsealResult, UnsealStart, UnsealStartResponse, UserAgentRequest,
|
||||
UserAgentResponse, user_agent_request::Payload as UserAgentRequestPayload,
|
||||
user_agent_response::Payload as UserAgentResponsePayload,
|
||||
use arbiter_proto::proto::{
|
||||
client,
|
||||
user_agent::{
|
||||
ClientConnectionCancel, ClientConnectionRequest, UnsealEncryptedKey, UnsealResult,
|
||||
UnsealStart, UnsealStartResponse, UserAgentRequest, UserAgentResponse,
|
||||
user_agent_request::Payload as UserAgentRequestPayload,
|
||||
user_agent_response::Payload as UserAgentResponsePayload,
|
||||
},
|
||||
};
|
||||
use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit};
|
||||
use ed25519_dalek::VerifyingKey;
|
||||
use kameo::{Actor, error::SendError};
|
||||
use kameo::{Actor, error::SendError, message, messages, prelude::Context};
|
||||
use memsafe::MemSafe;
|
||||
use tokio::select;
|
||||
use tokio::{select, sync::watch};
|
||||
use tracing::{error, info};
|
||||
use x25519_dalek::{EphemeralSecret, PublicKey};
|
||||
|
||||
use crate::actors::{
|
||||
keyholder::{self, TryUnseal},
|
||||
router::RegisterUserAgent,
|
||||
user_agent::{UserAgentConnection, UserAgentError},
|
||||
user_agent::{TransportResponseError, UserAgentConnection},
|
||||
};
|
||||
|
||||
mod state;
|
||||
use state::{DummyContext, UnsealContext, UserAgentEvents, UserAgentStateMachine, UserAgentStates};
|
||||
|
||||
// Error for consumption by other actors
|
||||
#[derive(Debug, thiserror::Error, PartialEq)]
|
||||
pub enum Error {
|
||||
#[error("User agent session ended due to connection loss")]
|
||||
ConnectionLost,
|
||||
|
||||
#[error("User agent session ended due to unexpected message")]
|
||||
UnexpectedMessage,
|
||||
}
|
||||
|
||||
pub struct UserAgentSession {
|
||||
props: UserAgentConnection,
|
||||
key: VerifyingKey,
|
||||
@@ -29,7 +43,7 @@ pub struct UserAgentSession {
|
||||
}
|
||||
|
||||
impl UserAgentSession {
|
||||
pub(crate) fn new(props: UserAgentConnection, key: VerifyingKey) -> Self {
|
||||
pub(crate) fn new(props: UserAgentConnection, key: VerifyingKey) -> Self {
|
||||
Self {
|
||||
props,
|
||||
key,
|
||||
@@ -37,18 +51,119 @@ impl UserAgentSession {
|
||||
}
|
||||
}
|
||||
|
||||
fn transition(&mut self, event: UserAgentEvents) -> Result<(), UserAgentError> {
|
||||
fn transition(&mut self, event: UserAgentEvents) -> Result<(), TransportResponseError> {
|
||||
self.state.process_event(event).map_err(|e| {
|
||||
error!(?e, "State transition failed");
|
||||
UserAgentError::StateTransitionFailed
|
||||
TransportResponseError::StateTransitionFailed
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_msg<Reply: kameo::Reply>(
|
||||
&mut self,
|
||||
msg: UserAgentResponsePayload,
|
||||
ctx: &mut Context<Self, Reply>,
|
||||
) -> Result<(), Error> {
|
||||
self.props
|
||||
.transport
|
||||
.send(Ok(response(msg)))
|
||||
.await
|
||||
.map_err(|_| {
|
||||
error!(
|
||||
actor = "useragent",
|
||||
reason = "channel closed",
|
||||
"send.failed"
|
||||
);
|
||||
Error::ConnectionLost
|
||||
})
|
||||
}
|
||||
|
||||
async fn expect_msg<Extractor, Msg, Reply>(
|
||||
&mut self,
|
||||
extractor: Extractor,
|
||||
ctx: &mut Context<Self, Reply>,
|
||||
) -> Result<Msg, Error>
|
||||
where
|
||||
Extractor: FnOnce(UserAgentRequestPayload) -> Option<Msg>,
|
||||
Reply: kameo::Reply,
|
||||
{
|
||||
let msg = self.props.transport.recv().await.ok_or_else(|| {
|
||||
error!(
|
||||
actor = "useragent",
|
||||
reason = "channel closed",
|
||||
"recv.failed"
|
||||
);
|
||||
ctx.stop();
|
||||
Error::ConnectionLost
|
||||
})?;
|
||||
|
||||
msg.payload.and_then(extractor).ok_or_else(|| {
|
||||
error!(
|
||||
actor = "useragent",
|
||||
reason = "unexpected message",
|
||||
"recv.failed"
|
||||
);
|
||||
ctx.stop();
|
||||
Error::UnexpectedMessage
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[messages]
|
||||
impl UserAgentSession {
|
||||
|
||||
// TODO: Think about refactoring it to state-machine based flow, as we already have one
|
||||
#[message(ctx)]
|
||||
pub async fn request_new_client_approval(
|
||||
&mut self,
|
||||
client_pubkey: VerifyingKey,
|
||||
mut cancel_flag: watch::Receiver<()>,
|
||||
ctx: &mut Context<Self, Result<bool, Error>>,
|
||||
) -> Result<bool, Error> {
|
||||
self.send_msg(
|
||||
UserAgentResponsePayload::ClientConnectionRequest(
|
||||
ClientConnectionRequest {
|
||||
pubkey: client_pubkey.as_bytes().to_vec(),
|
||||
}
|
||||
.into(),
|
||||
),
|
||||
ctx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let extractor = |msg| {
|
||||
if let UserAgentRequestPayload::ClientConnectionResponse(client_connection_response) =
|
||||
msg
|
||||
{
|
||||
Some(client_connection_response)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
tokio::select! {
|
||||
_ = cancel_flag.changed() => {
|
||||
info!(actor = "useragent", "client connection approval cancelled");
|
||||
self.send_msg(
|
||||
UserAgentResponsePayload::ClientConnectionCancel(ClientConnectionCancel {}),
|
||||
ctx,
|
||||
).await?;
|
||||
return Ok(false);
|
||||
}
|
||||
result = self.expect_msg(extractor, ctx) => {
|
||||
let result = result?;
|
||||
info!(actor = "useragent", "received client connection approval result: approved={}", result.approved);
|
||||
return Ok(result.approved);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserAgentSession {
|
||||
pub async fn process_transport_inbound(&mut self, req: UserAgentRequest) -> Output {
|
||||
let msg = req.payload.ok_or_else(|| {
|
||||
error!(actor = "useragent", "Received message with no payload");
|
||||
UserAgentError::MissingRequestPayload
|
||||
TransportResponseError::MissingRequestPayload
|
||||
})?;
|
||||
|
||||
match msg {
|
||||
@@ -58,12 +173,12 @@ impl UserAgentSession {
|
||||
UserAgentRequestPayload::UnsealEncryptedKey(unseal_encrypted_key) => {
|
||||
self.handle_unseal_encrypted_key(unseal_encrypted_key).await
|
||||
}
|
||||
_ => Err(UserAgentError::UnexpectedRequestPayload),
|
||||
_ => Err(TransportResponseError::UnexpectedRequestPayload),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type Output = Result<UserAgentResponse, UserAgentError>;
|
||||
type Output = Result<UserAgentResponse, TransportResponseError>;
|
||||
|
||||
fn response(payload: UserAgentResponsePayload) -> UserAgentResponse {
|
||||
UserAgentResponse {
|
||||
@@ -79,7 +194,7 @@ impl UserAgentSession {
|
||||
let client_pubkey_bytes: [u8; 32] = req
|
||||
.client_pubkey
|
||||
.try_into()
|
||||
.map_err(|_| UserAgentError::InvalidClientPubkeyLength)?;
|
||||
.map_err(|_| TransportResponseError::InvalidClientPubkeyLength)?;
|
||||
|
||||
let client_public_key = PublicKey::from(client_pubkey_bytes);
|
||||
|
||||
@@ -98,7 +213,7 @@ impl UserAgentSession {
|
||||
async fn handle_unseal_encrypted_key(&mut self, req: UnsealEncryptedKey) -> Output {
|
||||
let UserAgentStates::WaitingForUnsealKey(unseal_context) = self.state.state() else {
|
||||
error!("Received unseal encrypted key in invalid state");
|
||||
return Err(UserAgentError::InvalidStateForUnsealEncryptedKey);
|
||||
return Err(TransportResponseError::InvalidStateForUnsealEncryptedKey);
|
||||
};
|
||||
let ephemeral_secret = {
|
||||
let mut secret_lock = unseal_context.secret.lock().unwrap();
|
||||
@@ -163,7 +278,7 @@ impl UserAgentSession {
|
||||
Err(err) => {
|
||||
error!(?err, "Failed to send unseal request to keyholder");
|
||||
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
|
||||
Err(UserAgentError::KeyHolderActorUnreachable)
|
||||
Err(TransportResponseError::KeyHolderActorUnreachable)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -181,7 +296,7 @@ impl UserAgentSession {
|
||||
impl Actor for UserAgentSession {
|
||||
type Args = Self;
|
||||
|
||||
type Error = UserAgentError;
|
||||
type Error = TransportResponseError;
|
||||
|
||||
async fn on_start(
|
||||
args: Self::Args,
|
||||
@@ -196,7 +311,7 @@ impl Actor for UserAgentSession {
|
||||
.await
|
||||
.map_err(|err| {
|
||||
error!(?err, "Failed to register user agent connection with router");
|
||||
UserAgentError::ConnectionRegistrationFailed
|
||||
TransportResponseError::ConnectionRegistrationFailed
|
||||
})?;
|
||||
Ok(args)
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ use tracing::info;
|
||||
use crate::{
|
||||
actors::{
|
||||
client::{self, ClientError, ClientConnection as ClientConnectionProps, connect_client},
|
||||
user_agent::{self, UserAgentConnection, UserAgentError, connect_user_agent},
|
||||
user_agent::{self, UserAgentConnection, TransportResponseError, connect_user_agent},
|
||||
},
|
||||
context::ServerContext,
|
||||
};
|
||||
@@ -30,7 +30,7 @@ const DEFAULT_CHANNEL_SIZE: usize = 1000;
|
||||
struct UserAgentGrpcSender;
|
||||
|
||||
impl SendConverter for UserAgentGrpcSender {
|
||||
type Input = Result<UserAgentResponse, UserAgentError>;
|
||||
type Input = Result<UserAgentResponse, TransportResponseError>;
|
||||
type Output = Result<UserAgentResponse, Status>;
|
||||
|
||||
fn convert(&self, item: Self::Input) -> Self::Output {
|
||||
@@ -87,21 +87,21 @@ fn client_auth_error_status(value: &client::auth::Error) -> Status {
|
||||
}
|
||||
}
|
||||
|
||||
fn user_agent_error_status(value: UserAgentError) -> Status {
|
||||
fn user_agent_error_status(value: TransportResponseError) -> Status {
|
||||
match value {
|
||||
UserAgentError::MissingRequestPayload | UserAgentError::UnexpectedRequestPayload => {
|
||||
TransportResponseError::MissingRequestPayload | TransportResponseError::UnexpectedRequestPayload => {
|
||||
Status::invalid_argument("Expected message with payload")
|
||||
}
|
||||
UserAgentError::InvalidStateForUnsealEncryptedKey => {
|
||||
TransportResponseError::InvalidStateForUnsealEncryptedKey => {
|
||||
Status::failed_precondition("Invalid state for unseal encrypted key")
|
||||
}
|
||||
UserAgentError::InvalidClientPubkeyLength => {
|
||||
TransportResponseError::InvalidClientPubkeyLength => {
|
||||
Status::invalid_argument("client_pubkey must be 32 bytes")
|
||||
}
|
||||
UserAgentError::StateTransitionFailed => Status::internal("State machine error"),
|
||||
UserAgentError::KeyHolderActorUnreachable => Status::internal("Vault is not available"),
|
||||
UserAgentError::Auth(ref err) => auth_error_status(err),
|
||||
UserAgentError::ConnectionRegistrationFailed => {
|
||||
TransportResponseError::StateTransitionFailed => Status::internal("State machine error"),
|
||||
TransportResponseError::KeyHolderActorUnreachable => Status::internal("Vault is not available"),
|
||||
TransportResponseError::Auth(ref err) => auth_error_status(err),
|
||||
TransportResponseError::ConnectionRegistrationFailed => {
|
||||
Status::internal("Failed registering connection")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user