refactor(server::client): migrated to new connection design

This commit is contained in:
hdbg
2026-03-18 22:40:07 +01:00
committed by Stas
parent d61dab3285
commit 2ff4d0961c
14 changed files with 474 additions and 401 deletions

View File

@@ -1,30 +1,25 @@
use arbiter_proto::{format_challenge, transport::expect_message};
use arbiter_proto::{
format_challenge,
transport::{Bi, expect_message},
};
use diesel::{
ExpressionMethods as _, OptionalExtension as _, QueryDsl as _, dsl::insert_into, update,
};
use diesel_async::RunQueryDsl as _;
use ed25519_dalek::VerifyingKey;
use ed25519_dalek::{Signature, VerifyingKey};
use kameo::error::SendError;
use tracing::error;
use crate::{
actors::{
client::{ClientConnection, ConnectErrorCode, Request, Response},
client::ClientConnection,
router::{self, RequestClientApproval},
},
db::{self, schema::program_client},
};
use super::session::ClientSession;
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
pub enum Error {
#[error("Unexpected message payload")]
UnexpectedMessagePayload,
#[error("Invalid client public key length")]
InvalidClientPubkeyLength,
#[error("Invalid client public key encoding")]
InvalidAuthPubkeyEncoding,
#[error("Database pool unavailable")]
DatabasePoolUnavailable,
#[error("Database operation failed")]
@@ -33,8 +28,6 @@ pub enum Error {
InvalidChallengeSolution,
#[error("Client approval request failed")]
ApproveError(#[from] ApproveError),
#[error("Internal error")]
InternalError,
#[error("Transport error")]
Transport,
}
@@ -49,6 +42,18 @@ pub enum ApproveError {
Upstream(router::ApprovalError),
}
#[derive(Debug, Clone)]
pub enum Inbound {
AuthChallengeRequest { pubkey: VerifyingKey },
AuthChallengeSolution { signature: Signature },
}
#[derive(Debug, Clone)]
pub enum Outbound {
AuthChallenge { pubkey: VerifyingKey, nonce: i32 },
AuthSuccess,
}
/// Atomically reads and increments the nonce for a known client.
/// Returns `None` if the pubkey is not registered.
async fn get_nonce(db: &db::DatabasePool, pubkey: &VerifyingKey) -> Result<Option<i32>, Error> {
@@ -141,27 +146,24 @@ async fn insert_client(db: &db::DatabasePool, pubkey: &VerifyingKey) -> Result<(
Ok(())
}
async fn challenge_client(
props: &mut ClientConnection,
async fn challenge_client<T>(
transport: &mut T,
pubkey: VerifyingKey,
nonce: i32,
) -> Result<(), Error> {
let challenge_pubkey = pubkey.as_bytes().to_vec();
props
.transport
.send(Ok(Response::AuthChallenge {
pubkey: challenge_pubkey.clone(),
nonce,
}))
) -> Result<(), Error>
where
T: Bi<Inbound, Result<Outbound, Error>> + ?Sized,
{
transport
.send(Ok(Outbound::AuthChallenge { pubkey, nonce }))
.await
.map_err(|e| {
error!(error = ?e, "Failed to send auth challenge");
Error::Transport
})?;
let signature = expect_message(&mut *props.transport, |req: Request| match req {
Request::AuthChallengeSolution { signature } => Some(signature),
let signature = expect_message(transport, |req: Inbound| match req {
Inbound::AuthChallengeSolution { signature } => Some(signature),
_ => None,
})
.await
@@ -170,13 +172,9 @@ async fn challenge_client(
Error::Transport
})?;
let formatted = format_challenge(nonce, &challenge_pubkey);
let sig = signature.as_slice().try_into().map_err(|_| {
error!("Invalid signature length");
Error::InvalidChallengeSolution
})?;
let formatted = format_challenge(nonce, pubkey.as_bytes());
pubkey.verify_strict(&formatted, &sig).map_err(|_| {
pubkey.verify_strict(&formatted, &signature).map_err(|_| {
error!("Challenge solution verification failed");
Error::InvalidChallengeSolution
})?;
@@ -184,30 +182,17 @@ async fn challenge_client(
Ok(())
}
fn connect_error_code(err: &Error) -> ConnectErrorCode {
match err {
Error::ApproveError(ApproveError::Denied) => ConnectErrorCode::ApprovalDenied,
Error::ApproveError(ApproveError::Upstream(
router::ApprovalError::NoUserAgentsConnected,
)) => ConnectErrorCode::NoUserAgentsOnline,
_ => ConnectErrorCode::Unknown,
}
}
async fn authenticate(props: &mut ClientConnection) -> Result<VerifyingKey, Error> {
let Some(Request::AuthChallengeRequest {
pubkey: challenge_pubkey,
}) = props.transport.recv().await
else {
pub async fn authenticate<T>(
props: &mut ClientConnection,
transport: &mut T,
) -> Result<VerifyingKey, Error>
where
T: Bi<Inbound, Result<Outbound, Error>> + Send + ?Sized,
{
let Some(Inbound::AuthChallengeRequest { pubkey }) = transport.recv().await else {
return Err(Error::Transport);
};
let pubkey_bytes = challenge_pubkey
.as_array()
.ok_or(Error::InvalidClientPubkeyLength)?;
let pubkey =
VerifyingKey::from_bytes(pubkey_bytes).map_err(|_| Error::InvalidAuthPubkeyEncoding)?;
let nonce = match get_nonce(&props.db, &pubkey).await? {
Some(nonce) => nonce,
None => {
@@ -217,21 +202,14 @@ async fn authenticate(props: &mut ClientConnection) -> Result<VerifyingKey, Erro
}
};
challenge_client(props, pubkey, nonce).await?;
challenge_client(transport, pubkey, nonce).await?;
transport
.send(Ok(Outbound::AuthSuccess))
.await
.map_err(|e| {
error!(error = ?e, "Failed to send auth success");
Error::Transport
})?;
Ok(pubkey)
}
pub async fn authenticate_and_create(mut props: ClientConnection) -> Result<ClientSession, Error> {
match authenticate(&mut props).await {
Ok(_pubkey) => Ok(ClientSession::new(props)),
Err(err) => {
let code = connect_error_code(&err);
let _ = props
.transport
.send(Ok(Response::ClientConnectError { code }))
.await;
Err(err)
}
}
}

View File

@@ -7,68 +7,31 @@ use crate::{
db,
};
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum ClientError {
#[error("Expected message with payload")]
MissingRequestPayload,
#[error("Unexpected request payload")]
UnexpectedRequestPayload,
#[error("State machine error")]
StateTransitionFailed,
#[error("Connection registration failed")]
ConnectionRegistrationFailed,
#[error(transparent)]
Auth(#[from] auth::Error),
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectErrorCode {
Unknown,
ApprovalDenied,
NoUserAgentsOnline,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Request {
AuthChallengeRequest { pubkey: Vec<u8> },
AuthChallengeSolution { signature: Vec<u8> },
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Response {
AuthChallenge { pubkey: Vec<u8>, nonce: i32 },
AuthOk,
ClientConnectError { code: ConnectErrorCode },
}
pub type Transport = Box<dyn Bi<Request, Result<Response, ClientError>> + Send>;
pub struct ClientConnection {
pub(crate) db: db::DatabasePool,
pub(crate) transport: Transport,
pub(crate) actors: GlobalActors,
}
impl ClientConnection {
pub fn new(db: db::DatabasePool, transport: Transport, actors: GlobalActors) -> Self {
Self {
db,
transport,
actors,
}
pub fn new(db: db::DatabasePool, actors: GlobalActors) -> Self {
Self { db, actors }
}
}
pub mod auth;
pub mod session;
pub async fn connect_client(props: ClientConnection) {
match auth::authenticate_and_create(props).await {
Ok(session) => {
ClientSession::spawn(session);
pub async fn connect_client<T>(mut props: ClientConnection, transport: &mut T)
where
T: Bi<auth::Inbound, Result<auth::Outbound, auth::Error>> + Send + ?Sized,
{
match auth::authenticate(&mut props, transport).await {
Ok(_pubkey) => {
ClientSession::spawn(ClientSession::new(props));
info!("Client authenticated, session started");
}
Err(err) => {
let _ = transport.send(Err(err.clone())).await;
error!(?err, "Authentication failed, closing connection");
}
}

View File

@@ -1,12 +1,9 @@
use kameo::Actor;
use tokio::select;
use tracing::{error, info};
use kameo::{Actor, messages};
use tracing::error;
use crate::{
actors::{
GlobalActors,
client::{ClientConnection, ClientError, Request, Response},
router::RegisterClient,
GlobalActors, client::ClientConnection, keyholder::KeyHolderState, router::RegisterClient,
},
db,
};
@@ -19,19 +16,30 @@ impl ClientSession {
pub(crate) fn new(props: ClientConnection) -> Self {
Self { props }
}
pub async fn process_transport_inbound(&mut self, req: Request) -> Output {
let _ = req;
Err(ClientError::UnexpectedRequestPayload)
}
}
type Output = Result<Response, ClientError>;
#[messages]
impl ClientSession {
#[message]
pub(crate) async fn handle_query_vault_state(&mut self) -> Result<KeyHolderState, Error> {
use crate::actors::keyholder::GetState;
let vault_state = match self.props.actors.key_holder.ask(GetState {}).await {
Ok(state) => state,
Err(err) => {
error!(?err, actor = "client", "keyholder.query.failed");
return Err(Error::Internal);
}
};
Ok(vault_state)
}
}
impl Actor for ClientSession {
type Args = Self;
type Error = ClientError;
type Error = Error;
async fn on_start(
args: Self::Args,
@@ -42,52 +50,22 @@ impl Actor for ClientSession {
.router
.ask(RegisterClient { actor: this })
.await
.map_err(|_| ClientError::ConnectionRegistrationFailed)?;
.map_err(|_| Error::ConnectionRegistrationFailed)?;
Ok(args)
}
async fn next(
&mut self,
_actor_ref: kameo::prelude::WeakActorRef<Self>,
mailbox_rx: &mut kameo::prelude::MailboxReceiver<Self>,
) -> Option<kameo::mailbox::Signal<Self>> {
loop {
select! {
signal = mailbox_rx.recv() => {
return signal;
}
msg = self.props.transport.recv() => {
match msg {
Some(request) => {
match self.process_transport_inbound(request).await {
Ok(resp) => {
if self.props.transport.send(Ok(resp)).await.is_err() {
error!(actor = "client", reason = "channel closed", "send.failed");
return Some(kameo::mailbox::Signal::Stop);
}
}
Err(err) => {
let _ = self.props.transport.send(Err(err)).await;
return Some(kameo::mailbox::Signal::Stop);
}
}
}
None => {
info!(actor = "client", "transport.closed");
return Some(kameo::mailbox::Signal::Stop);
}
}
}
}
}
}
}
impl ClientSession {
pub fn new_test(db: db::DatabasePool, actors: GlobalActors) -> Self {
use arbiter_proto::transport::DummyTransport;
let transport: super::Transport = Box::new(DummyTransport::new());
let props = ClientConnection::new(db, transport, actors);
let props = ClientConnection::new(db, actors);
Self { props }
}
}
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("Connection registration failed")]
ConnectionRegistrationFailed,
#[error("Internal error")]
Internal,
}

View File

@@ -106,7 +106,7 @@ pub struct AuthContext<'a, T> {
}
impl<'a, T> AuthContext<'a, T> {
pub fn new(conn: &'a mut UserAgentConnection, transport: T) -> Self {
pub fn new(conn: &'a mut UserAgentConnection, transport: T) -> Self {
Self { conn, transport }
}
}
@@ -124,8 +124,7 @@ where
let stored_bytes = pubkey.to_stored_bytes();
let nonce = create_nonce(&self.conn.db, &stored_bytes).await?;
self
.transport
self.transport
.send(Ok(Outbound::AuthChallenge { nonce }))
.await
.map_err(|e| {
@@ -165,8 +164,7 @@ where
register_key(&self.conn.db, &pubkey).await?;
self
.transport
self.transport
.send(Ok(Outbound::AuthSuccess))
.await
.map_err(|_| Error::Transport)?;
@@ -214,8 +212,7 @@ where
};
if valid {
self
.transport
self.transport
.send(Ok(Outbound::AuthSuccess))
.await
.map_err(|_| Error::Transport)?;

View File

@@ -1,142 +1,118 @@
use arbiter_proto::{
proto::client::{
AuthChallenge as ProtoAuthChallenge, AuthChallengeRequest as ProtoAuthChallengeRequest,
AuthChallengeSolution as ProtoAuthChallengeSolution, AuthOk as ProtoAuthOk,
ClientConnectError, ClientRequest, ClientResponse,
client_connect_error::Code as ProtoClientConnectErrorCode,
ClientRequest, ClientResponse, VaultState as ProtoVaultState,
client_request::Payload as ClientRequestPayload,
client_response::Payload as ClientResponsePayload,
},
transport::{Bi, Error as TransportError, Sender},
transport::{Receiver, Sender, grpc::GrpcBi},
};
use async_trait::async_trait;
use futures::StreamExt as _;
use tokio::sync::mpsc;
use tonic::{Status, Streaming};
use kameo::{
actor::{ActorRef, Spawn as _},
error::SendError,
};
use tracing::{info, warn};
use crate::actors::client::{
self, ClientError, ConnectErrorCode, Request as DomainRequest, Response as DomainResponse,
use crate::{
actors::{
client::{
self, ClientConnection,
session::{ClientSession, Error, HandleQueryVaultState},
},
keyholder::KeyHolderState,
},
utils::defer,
};
pub struct GrpcTransport {
sender: mpsc::Sender<Result<ClientResponse, Status>>,
receiver: Streaming<ClientRequest>,
}
mod auth;
impl GrpcTransport {
pub fn new(
sender: mpsc::Sender<Result<ClientResponse, Status>>,
receiver: Streaming<ClientRequest>,
) -> Self {
Self { sender, receiver }
}
fn request_to_domain(request: ClientRequest) -> Result<DomainRequest, Status> {
match request.payload {
Some(ClientRequestPayload::AuthChallengeRequest(ProtoAuthChallengeRequest {
pubkey,
})) => Ok(DomainRequest::AuthChallengeRequest { pubkey }),
Some(ClientRequestPayload::AuthChallengeSolution(ProtoAuthChallengeSolution {
signature,
})) => Ok(DomainRequest::AuthChallengeSolution { signature }),
None => Err(Status::invalid_argument("Missing client request payload")),
}
}
fn response_to_proto(response: DomainResponse) -> ClientResponse {
let payload = match response {
DomainResponse::AuthChallenge { pubkey, nonce } => {
ClientResponsePayload::AuthChallenge(ProtoAuthChallenge { pubkey, nonce })
}
DomainResponse::AuthOk => ClientResponsePayload::AuthOk(ProtoAuthOk {}),
DomainResponse::ClientConnectError { code } => {
ClientResponsePayload::ClientConnectError(ClientConnectError {
code: match code {
ConnectErrorCode::Unknown => ProtoClientConnectErrorCode::Unknown,
ConnectErrorCode::ApprovalDenied => {
ProtoClientConnectErrorCode::ApprovalDenied
}
ConnectErrorCode::NoUserAgentsOnline => {
ProtoClientConnectErrorCode::NoUserAgentsOnline
}
}
.into(),
})
}
async fn dispatch_loop(
mut bi: GrpcBi<ClientRequest, ClientResponse>,
actor: ActorRef<ClientSession>,
) {
loop {
let Some(conn) = bi.recv().await else {
return;
};
ClientResponse {
payload: Some(payload),
}
}
fn error_to_status(value: ClientError) -> Status {
match value {
ClientError::MissingRequestPayload | ClientError::UnexpectedRequestPayload => {
Status::invalid_argument("Expected message with payload")
}
ClientError::StateTransitionFailed => Status::internal("State machine error"),
ClientError::Auth(ref err) => auth_error_status(err),
ClientError::ConnectionRegistrationFailed => {
Status::internal("Connection registration failed")
}
if dispatch_conn_message(&mut bi, &actor, conn).await.is_err() {
return;
}
}
}
#[async_trait]
impl Sender<Result<DomainResponse, ClientError>> for GrpcTransport {
async fn send(
&mut self,
item: Result<DomainResponse, ClientError>,
) -> Result<(), TransportError> {
let outbound = match item {
Ok(message) => Ok(Self::response_to_proto(message)),
Err(err) => Err(Self::error_to_status(err)),
};
async fn dispatch_conn_message(
bi: &mut GrpcBi<ClientRequest, ClientResponse>,
actor: &ActorRef<ClientSession>,
conn: Result<ClientRequest, tonic::Status>,
) -> Result<(), ()> {
let conn = match conn {
Ok(conn) => conn,
Err(err) => {
warn!(error = ?err, "Failed to receive client request");
return Err(());
}
};
self.sender
.send(outbound)
.await
.map_err(|_| TransportError::ChannelClosed)
}
}
let Some(payload) = conn.payload else {
let _ = bi
.send(Err(tonic::Status::invalid_argument(
"Missing client request payload",
)))
.await;
return Err(());
};
#[async_trait]
impl Bi<DomainRequest, Result<DomainResponse, ClientError>> for GrpcTransport {
async fn recv(&mut self) -> Option<DomainRequest> {
match self.receiver.next().await {
Some(Ok(item)) => match Self::request_to_domain(item) {
Ok(request) => Some(request),
Err(status) => {
let _ = self.sender.send(Err(status)).await;
None
let payload = match payload {
ClientRequestPayload::QueryVaultState(_) => ClientResponsePayload::VaultState(
match actor.ask(HandleQueryVaultState {}).await {
Ok(KeyHolderState::Unbootstrapped) => ProtoVaultState::Unbootstrapped,
Ok(KeyHolderState::Sealed) => ProtoVaultState::Sealed,
Ok(KeyHolderState::Unsealed) => ProtoVaultState::Unsealed,
Err(SendError::HandlerError(Error::Internal)) => ProtoVaultState::Error,
Err(err) => {
warn!(error = ?err, "Failed to query vault state");
ProtoVaultState::Error
}
},
Some(Err(error)) => {
tracing::error!(error = ?error, "grpc client recv failed; closing stream");
None
}
None => None,
.into(),
),
payload => {
warn!(?payload, "Unsupported post-auth client request");
let _ = bi
.send(Err(tonic::Status::invalid_argument(
"Unsupported client request",
)))
.await;
return Err(());
}
};
bi.send(Ok(ClientResponse {
payload: Some(payload),
}))
.await
.map_err(|_| ())
}
pub async fn start(conn: ClientConnection, mut bi: GrpcBi<ClientRequest, ClientResponse>) {
let mut conn = conn;
match auth::start(&mut conn, &mut bi).await {
Ok(_) => {
let actor =
client::session::ClientSession::spawn(client::session::ClientSession::new(conn));
let actor_for_cleanup = actor.clone();
let _ = defer(move || {
actor_for_cleanup.kill();
});
info!("Client authenticated successfully");
dispatch_loop(bi, actor).await;
}
Err(e) => {
let mut transport = auth::AuthTransportAdapter(&mut bi);
let _ = transport.send(Err(e.clone())).await;
warn!(error = ?e, "Authentication failed");
return;
}
}
}
fn auth_error_status(value: &client::auth::Error) -> Status {
use client::auth::Error;
match value {
Error::UnexpectedMessagePayload | Error::InvalidClientPubkeyLength => {
Status::invalid_argument(value.to_string())
}
Error::InvalidAuthPubkeyEncoding => {
Status::invalid_argument("Failed to convert pubkey to VerifyingKey")
}
Error::InvalidChallengeSolution => Status::unauthenticated(value.to_string()),
Error::ApproveError(_) => Status::permission_denied(value.to_string()),
Error::Transport => Status::internal("Transport error"),
Error::DatabasePoolUnavailable => Status::internal("Database pool error"),
Error::DatabaseOperationFailed => Status::internal("Database error"),
Error::InternalError => Status::internal("Internal error"),
}
}

View File

@@ -0,0 +1,131 @@
use arbiter_proto::{
proto::client::{
AuthChallenge as ProtoAuthChallenge, AuthChallengeRequest as ProtoAuthChallengeRequest,
AuthChallengeSolution as ProtoAuthChallengeSolution, AuthResult as ProtoAuthResult,
ClientRequest, ClientResponse, client_request::Payload as ClientRequestPayload,
client_response::Payload as ClientResponsePayload,
},
transport::{Bi, Error as TransportError, Receiver, Sender, grpc::GrpcBi},
};
use async_trait::async_trait;
use tracing::warn;
use crate::actors::client::{self, ClientConnection, auth};
pub struct AuthTransportAdapter<'a>(pub(super) &'a mut GrpcBi<ClientRequest, ClientResponse>);
impl AuthTransportAdapter<'_> {
fn response_to_proto(response: auth::Outbound) -> ClientResponse {
let payload = match response {
auth::Outbound::AuthChallenge { pubkey, nonce } => {
ClientResponsePayload::AuthChallenge(ProtoAuthChallenge {
pubkey: pubkey.to_bytes().to_vec(),
nonce,
})
}
auth::Outbound::AuthSuccess => {
ClientResponsePayload::AuthResult(ProtoAuthResult::Success.into())
}
};
ClientResponse {
payload: Some(payload),
}
}
fn error_to_proto(error: auth::Error) -> ClientResponse {
ClientResponse {
payload: Some(ClientResponsePayload::AuthResult(
match error {
auth::Error::InvalidChallengeSolution => ProtoAuthResult::InvalidSignature,
auth::Error::ApproveError(auth::ApproveError::Denied) => {
ProtoAuthResult::ApprovalDenied
}
auth::Error::ApproveError(auth::ApproveError::Upstream(
crate::actors::router::ApprovalError::NoUserAgentsConnected,
)) => ProtoAuthResult::NoUserAgentsOnline,
auth::Error::ApproveError(auth::ApproveError::Internal)
| auth::Error::DatabasePoolUnavailable
| auth::Error::DatabaseOperationFailed
| auth::Error::Transport => ProtoAuthResult::Internal,
}
.into(),
)),
}
}
async fn send_auth_result(&mut self, result: ProtoAuthResult) -> Result<(), TransportError> {
self.0
.send(Ok(ClientResponse {
payload: Some(ClientResponsePayload::AuthResult(result.into())),
}))
.await
}
}
#[async_trait]
impl Sender<Result<auth::Outbound, auth::Error>> for AuthTransportAdapter<'_> {
async fn send(
&mut self,
item: Result<auth::Outbound, auth::Error>,
) -> Result<(), TransportError> {
let outbound = match item {
Ok(message) => Ok(AuthTransportAdapter::response_to_proto(message)),
Err(err) => Ok(AuthTransportAdapter::error_to_proto(err)),
};
self.0.send(outbound).await
}
}
#[async_trait]
impl Receiver<auth::Inbound> for AuthTransportAdapter<'_> {
async fn recv(&mut self) -> Option<auth::Inbound> {
let request = match self.0.recv().await? {
Ok(request) => request,
Err(error) => {
warn!(error = ?error, "grpc client recv failed; closing stream");
return None;
}
};
let payload = request.payload?;
match payload {
ClientRequestPayload::AuthChallengeRequest(ProtoAuthChallengeRequest { pubkey }) => {
let Ok(pubkey) = <[u8; 32]>::try_from(pubkey) else {
let _ = self.send_auth_result(ProtoAuthResult::InvalidKey).await;
return None;
};
let Ok(pubkey) = ed25519_dalek::VerifyingKey::from_bytes(&pubkey) else {
let _ = self.send_auth_result(ProtoAuthResult::InvalidKey).await;
return None;
};
Some(auth::Inbound::AuthChallengeRequest { pubkey })
}
ClientRequestPayload::AuthChallengeSolution(ProtoAuthChallengeSolution {
signature,
}) => {
let Ok(signature) = ed25519_dalek::Signature::try_from(signature.as_slice()) else {
let _ = self
.send_auth_result(ProtoAuthResult::InvalidSignature)
.await;
return None;
};
Some(auth::Inbound::AuthChallengeSolution { signature })
}
_ => None,
}
}
}
impl Bi<auth::Inbound, Result<auth::Outbound, auth::Error>> for AuthTransportAdapter<'_> {}
pub async fn start(
conn: &mut ClientConnection,
bi: &mut GrpcBi<ClientRequest, ClientResponse>,
) -> Result<(), auth::Error> {
let mut transport = AuthTransportAdapter(bi);
client::auth::authenticate(conn, &mut transport).await?;
Ok(())
}

View File

@@ -12,10 +12,7 @@ use tracing::info;
use crate::{
DEFAULT_CHANNEL_SIZE,
actors::{
client::{ClientConnection, connect_client},
user_agent::UserAgentConnection,
},
actors::{client::ClientConnection, user_agent::UserAgentConnection},
grpc::{self, user_agent::start},
};
@@ -33,19 +30,13 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for super::Ser
request: Request<tonic::Streaming<ClientRequest>>,
) -> Result<Response<Self::ClientStream>, Status> {
let req_stream = request.into_inner();
let (tx, rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);
let transport = client::GrpcTransport::new(tx, req_stream);
let props = ClientConnection::new(
self.context.db.clone(),
Box::new(transport),
self.context.actors.clone(),
);
tokio::spawn(connect_client(props));
let (bi, rx) = GrpcBi::from_bi_stream(req_stream);
let props = ClientConnection::new(self.context.db.clone(), self.context.actors.clone());
tokio::spawn(client::start(props, bi));
info!(event = "connection established", "grpc.client");
Ok(Response::new(ReceiverStream::new(rx)))
Ok(Response::new(rx))
}
#[tracing::instrument(level = "debug", skip(self))]

View File

@@ -30,7 +30,10 @@ use arbiter_proto::{
};
use async_trait::async_trait;
use chrono::{TimeZone, Utc};
use kameo::{actor::{ActorRef, Spawn as _}, error::SendError};
use kameo::{
actor::{ActorRef, Spawn as _},
error::SendError,
};
use tonic::Status;
use tracing::{info, warn};
@@ -40,7 +43,9 @@ use crate::{
user_agent::{
OutOfBand, UserAgentConnection, UserAgentSession,
session::{
BootstrapError, Error, HandleBootstrapEncryptedKey, HandleEvmWalletCreate, HandleEvmWalletList, HandleGrantCreate, HandleGrantDelete, HandleGrantList, HandleQueryVaultState, HandleUnsealEncryptedKey, HandleUnsealRequest, UnsealError
BootstrapError, Error, HandleBootstrapEncryptedKey, HandleEvmWalletCreate,
HandleEvmWalletList, HandleGrantCreate, HandleGrantDelete, HandleGrantList,
HandleQueryVaultState, HandleUnsealEncryptedKey, HandleUnsealRequest, UnsealError,
},
},
},
@@ -109,7 +114,11 @@ async fn dispatch_conn_message(
};
let Some(payload) = conn.payload else {
let _ = bi.send(Err(Status::invalid_argument("Missing user-agent request payload"))).await;
let _ = bi
.send(Err(Status::invalid_argument(
"Missing user-agent request payload",
)))
.await;
return Err(());
};
@@ -118,7 +127,9 @@ async fn dispatch_conn_message(
let client_pubkey = match <[u8; 32]>::try_from(client_pubkey) {
Ok(bytes) => x25519_dalek::PublicKey::from(bytes),
Err(_) => {
let _ = bi.send(Err(Status::invalid_argument("Invalid X25519 public key"))).await;
let _ = bi
.send(Err(Status::invalid_argument("Invalid X25519 public key")))
.await;
return Err(());
}
};
@@ -131,7 +142,9 @@ async fn dispatch_conn_message(
),
Err(err) => {
warn!(error = ?err, "Failed to handle unseal start request");
let _ = bi.send(Err(Status::internal("Failed to start unseal flow"))).await;
let _ = bi
.send(Err(Status::internal("Failed to start unseal flow")))
.await;
return Err(());
}
}
@@ -155,7 +168,9 @@ async fn dispatch_conn_message(
}
Err(err) => {
warn!(error = ?err, "Failed to handle unseal request");
let _ = bi.send(Err(Status::internal("Failed to unseal vault"))).await;
let _ = bi
.send(Err(Status::internal("Failed to unseal vault")))
.await;
return Err(());
}
}
@@ -178,12 +193,14 @@ async fn dispatch_conn_message(
Err(SendError::HandlerError(BootstrapError::InvalidKey)) => {
ProtoBootstrapResult::InvalidKey
}
Err(SendError::HandlerError(
BootstrapError::AlreadyBootstrapped,
)) => ProtoBootstrapResult::AlreadyBootstrapped,
Err(SendError::HandlerError(BootstrapError::AlreadyBootstrapped)) => {
ProtoBootstrapResult::AlreadyBootstrapped
}
Err(err) => {
warn!(error = ?err, "Failed to handle bootstrap request");
let _ = bi.send(Err(Status::internal("Failed to bootstrap vault"))).await;
let _ = bi
.send(Err(Status::internal("Failed to bootstrap vault")))
.await;
return Err(());
}
}
@@ -224,12 +241,13 @@ async fn dispatch_conn_message(
};
UserAgentResponsePayload::EvmGrantCreate(EvmGrantOrWallet::grant_create_response(
actor.ask(HandleGrantCreate {
client_id,
basic,
grant,
})
.await,
actor
.ask(HandleGrantCreate {
client_id,
basic,
grant,
})
.await,
))
}
UserAgentRequestPayload::EvmGrantDelete(EvmGrantDeleteRequest { grant_id }) => {
@@ -239,7 +257,11 @@ async fn dispatch_conn_message(
}
payload => {
warn!(?payload, "Unsupported post-auth user agent request");
let _ = bi.send(Err(Status::invalid_argument("Unsupported user-agent request"))).await;
let _ = bi
.send(Err(Status::invalid_argument(
"Unsupported user-agent request",
)))
.await;
return Err(());
}
};
@@ -281,7 +303,10 @@ fn parse_grant_request(
let specific =
specific.ok_or_else(|| Status::invalid_argument("Missing specific grant settings"))?;
Ok((shared_settings_from_proto(shared)?, specific_grant_from_proto(specific)?))
Ok((
shared_settings_from_proto(shared)?,
specific_grant_from_proto(specific)?,
))
}
fn shared_settings_from_proto(shared: ProtoSharedSettings) -> Result<SharedGrantSettings, Status> {
@@ -289,14 +314,8 @@ fn shared_settings_from_proto(shared: ProtoSharedSettings) -> Result<SharedGrant
wallet_id: shared.wallet_id,
client_id: 0,
chain: shared.chain_id,
valid_from: shared
.valid_from
.map(proto_timestamp_to_utc)
.transpose()?,
valid_until: shared
.valid_until
.map(proto_timestamp_to_utc)
.transpose()?,
valid_from: shared.valid_from.map(proto_timestamp_to_utc).transpose()?,
valid_until: shared.valid_until.map(proto_timestamp_to_utc).transpose()?,
max_gas_fee_per_gas: shared
.max_gas_fee_per_gas
.as_deref()
@@ -307,12 +326,10 @@ fn shared_settings_from_proto(shared: ProtoSharedSettings) -> Result<SharedGrant
.as_deref()
.map(u256_from_proto_bytes)
.transpose()?,
rate_limit: shared
.rate_limit
.map(|limit| TransactionRateLimit {
count: limit.count,
window: chrono::Duration::seconds(limit.window_secs),
}),
rate_limit: shared.rate_limit.map(|limit| TransactionRateLimit {
count: limit.count,
window: chrono::Duration::seconds(limit.window_secs),
}),
})
}
@@ -326,11 +343,9 @@ fn specific_grant_from_proto(specific: ProtoSpecificGrant) -> Result<SpecificGra
.into_iter()
.map(address_from_bytes)
.collect::<Result<_, _>>()?,
limit: volume_rate_limit_from_proto(
limit.ok_or_else(|| {
Status::invalid_argument("Missing ether transfer volume rate limit")
})?,
)?,
limit: volume_rate_limit_from_proto(limit.ok_or_else(|| {
Status::invalid_argument("Missing ether transfer volume rate limit")
})?)?,
})),
Some(ProtoSpecificGrantType::TokenTransfer(ProtoTokenTransferSettings {
token_contract,
@@ -391,12 +406,12 @@ fn shared_settings_to_proto(shared: SharedGrantSettings) -> ProtoSharedSettings
seconds: time.timestamp(),
nanos: time.timestamp_subsec_nanos() as i32,
}),
max_gas_fee_per_gas: shared.max_gas_fee_per_gas.map(|value| {
value.to_be_bytes::<32>().to_vec()
}),
max_priority_fee_per_gas: shared.max_priority_fee_per_gas.map(|value| {
value.to_be_bytes::<32>().to_vec()
}),
max_gas_fee_per_gas: shared
.max_gas_fee_per_gas
.map(|value| value.to_be_bytes::<32>().to_vec()),
max_priority_fee_per_gas: shared
.max_priority_fee_per_gas
.map(|value| value.to_be_bytes::<32>().to_vec()),
rate_limit: shared.rate_limit.map(|limit| ProtoTransactionRateLimit {
count: limit.count,
window_secs: limit.window.num_seconds(),
@@ -408,7 +423,11 @@ fn specific_grant_to_proto(grant: SpecificGrant) -> ProtoSpecificGrant {
let grant = match grant {
SpecificGrant::EtherTransfer(settings) => {
ProtoSpecificGrantType::EtherTransfer(ProtoEtherTransferSettings {
targets: settings.target.into_iter().map(|address| address.to_vec()).collect(),
targets: settings
.target
.into_iter()
.map(|address| address.to_vec())
.collect(),
limit: Some(ProtoVolumeRateLimit {
max_volume: settings.limit.max_volume.to_be_bytes::<32>().to_vec(),
window_secs: settings.limit.window.num_seconds(),
@@ -450,7 +469,9 @@ impl EvmGrantOrWallet {
}
};
WalletCreateResponse { result: Some(result) }
WalletCreateResponse {
result: Some(result),
}
}
fn wallet_list_response<M>(
@@ -471,7 +492,9 @@ impl EvmGrantOrWallet {
}
};
WalletListResponse { result: Some(result) }
WalletListResponse {
result: Some(result),
}
}
fn grant_create_response<M>(
@@ -485,12 +508,12 @@ impl EvmGrantOrWallet {
}
};
EvmGrantCreateResponse { result: Some(result) }
EvmGrantCreateResponse {
result: Some(result),
}
}
fn grant_delete_response<M>(
result: Result<(), SendError<M, Error>>,
) -> EvmGrantDeleteResponse {
fn grant_delete_response<M>(result: Result<(), SendError<M, Error>>) -> EvmGrantDeleteResponse {
let result = match result {
Ok(()) => EvmGrantDeleteResult::Ok(()),
Err(err) => {
@@ -499,7 +522,9 @@ impl EvmGrantOrWallet {
}
};
EvmGrantDeleteResponse { result: Some(result) }
EvmGrantDeleteResponse {
result: Some(result),
}
}
fn grant_list_response<M>(
@@ -523,7 +548,9 @@ impl EvmGrantOrWallet {
}
};
EvmGrantListResponse { result: Some(result) }
EvmGrantListResponse {
result: Some(result),
}
}
}

View File

@@ -1,9 +1,5 @@
#![forbid(unsafe_code)]
#![deny(
clippy::unwrap_used,
clippy::expect_used,
clippy::panic
)]
#![deny(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
use crate::context::ServerContext;
@@ -26,4 +22,3 @@ impl Server {
Self { context }
}
}

View File

@@ -1,7 +1,7 @@
use arbiter_proto::transport::Bi;
use arbiter_proto::transport::{Receiver, Sender};
use arbiter_server::actors::GlobalActors;
use arbiter_server::{
actors::client::{ClientConnection, Request, Response, connect_client},
actors::client::{ClientConnection, auth, connect_client},
db::{self, schema},
};
use diesel::{ExpressionMethods as _, insert_into};
@@ -17,15 +17,17 @@ pub async fn test_unregistered_pubkey_rejected() {
let (server_transport, mut test_transport) = ChannelTransport::new();
let actors = GlobalActors::spawn(db.clone()).await.unwrap();
let props = ClientConnection::new(db.clone(), Box::new(server_transport), actors);
let task = tokio::spawn(connect_client(props));
let props = ClientConnection::new(db.clone(), actors);
let task = tokio::spawn(async move {
let mut server_transport = server_transport;
connect_client(props, &mut server_transport).await;
});
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec();
test_transport
.send(Request::AuthChallengeRequest {
pubkey: pubkey_bytes,
.send(auth::Inbound::AuthChallengeRequest {
pubkey: new_key.verifying_key(),
})
.await
.unwrap();
@@ -54,13 +56,16 @@ pub async fn test_challenge_auth() {
let (server_transport, mut test_transport) = ChannelTransport::new();
let actors = GlobalActors::spawn(db.clone()).await.unwrap();
let props = ClientConnection::new(db.clone(), Box::new(server_transport), actors);
let task = tokio::spawn(connect_client(props));
let props = ClientConnection::new(db.clone(), actors);
let task = tokio::spawn(async move {
let mut server_transport = server_transport;
connect_client(props, &mut server_transport).await;
});
// Send challenge request
test_transport
.send(Request::AuthChallengeRequest {
pubkey: pubkey_bytes,
.send(auth::Inbound::AuthChallengeRequest {
pubkey: new_key.verifying_key(),
})
.await
.unwrap();
@@ -72,23 +77,31 @@ pub async fn test_challenge_auth() {
.expect("should receive challenge");
let challenge = match response {
Ok(resp) => match resp {
Response::AuthChallenge { pubkey, nonce } => (pubkey, nonce),
auth::Outbound::AuthChallenge { pubkey, nonce } => (pubkey, nonce),
other => panic!("Expected AuthChallenge, got {other:?}"),
},
Err(err) => panic!("Expected Ok response, got Err({err:?})"),
};
// Sign the challenge and send solution
let formatted_challenge = arbiter_proto::format_challenge(challenge.1, &challenge.0);
let formatted_challenge = arbiter_proto::format_challenge(challenge.1, challenge.0.as_bytes());
let signature = new_key.sign(&formatted_challenge);
test_transport
.send(Request::AuthChallengeSolution {
signature: signature.to_bytes().to_vec(),
})
.send(auth::Inbound::AuthChallengeSolution { signature })
.await
.unwrap();
let response = test_transport
.recv()
.await
.expect("should receive auth success");
match response {
Ok(auth::Outbound::AuthSuccess) => {}
Ok(other) => panic!("Expected AuthSuccess, got {other:?}"),
Err(err) => panic!("Expected Ok response, got Err({err:?})"),
}
// Auth completes, session spawned
task.await.unwrap();
}

View File

@@ -1,7 +1,8 @@
use arbiter_proto::transport::{Bi, Error};
use arbiter_proto::transport::{Bi, Error, Receiver, Sender};
use arbiter_server::{
actors::keyholder::KeyHolder,
db::{self, schema}, safe_cell::{SafeCell, SafeCellHandle as _},
db::{self, schema},
safe_cell::{SafeCell, SafeCellHandle as _},
};
use async_trait::async_trait;
use diesel::QueryDsl;
@@ -54,10 +55,10 @@ impl<T, Y> ChannelTransport<T, Y> {
}
#[async_trait]
impl<T, Y> Bi<T, Y> for ChannelTransport<T, Y>
impl<T, Y> Sender<Y> for ChannelTransport<T, Y>
where
T: Send + 'static,
Y: Send + 'static,
T: Send + Sync + 'static,
Y: Send + Sync + 'static,
{
async fn send(&mut self, item: Y) -> Result<(), Error> {
self.sender
@@ -65,8 +66,22 @@ where
.await
.map_err(|_| Error::ChannelClosed)
}
}
#[async_trait]
impl<T, Y> Receiver<T> for ChannelTransport<T, Y>
where
T: Send + Sync + 'static,
Y: Send + Sync + 'static,
{
async fn recv(&mut self) -> Option<T> {
self.receiver.recv().await
}
}
impl<T, Y> Bi<T, Y> for ChannelTransport<T, Y>
where
T: Send + Sync + 'static,
Y: Send + Sync + 'static,
{
}

View File

@@ -3,7 +3,7 @@ use arbiter_server::{
actors::{
GlobalActors,
bootstrap::GetToken,
user_agent::{AuthPublicKey, Request, OutOfBand, UserAgentConnection, connect_user_agent},
user_agent::{AuthPublicKey, OutOfBand, Request, UserAgentConnection, connect_user_agent},
},
db::{self, schema},
};

View File

@@ -2,7 +2,7 @@ use arbiter_server::{
actors::{
GlobalActors,
keyholder::{Bootstrap, Seal},
user_agent::{Request, OutOfBand, UnsealError, session::UserAgentSession},
user_agent::{OutOfBand, Request, UnsealError, session::UserAgentSession},
},
db,
safe_cell::{SafeCell, SafeCellHandle as _},