From d61dab3285b8d9f2b30b8e0bec7f5e98c043bd23 Mon Sep 17 00:00:00 2001 From: hdbg Date: Tue, 17 Mar 2026 18:39:12 +0100 Subject: [PATCH] refactor(server::useragent): migrated to new connection design --- protobufs/user_agent.proto | 16 +- server/Cargo.lock | 1 + server/crates/arbiter-proto/Cargo.toml | 1 + server/crates/arbiter-proto/src/transport.rs | 49 +- .../arbiter-proto/src/transport/grpc.rs | 106 ++ .../src/actors/keyholder/mod.rs | 4 +- .../src/actors/user_agent/auth.rs | 120 +-- .../src/actors/user_agent/auth/state.rs | 53 +- .../src/actors/user_agent/mod.rs | 158 +-- .../src/actors/user_agent/session.rs | 175 +--- .../actors/user_agent/session/connection.rs | 202 ++-- server/crates/arbiter-server/src/db/mod.rs | 8 + .../crates/arbiter-server/src/grpc/client.rs | 21 +- server/crates/arbiter-server/src/grpc/mod.rs | 35 +- .../arbiter-server/src/grpc/user_agent.rs | 974 +++++++++--------- .../src/grpc/user_agent/auth.rs | 151 +++ server/crates/arbiter-server/src/lib.rs | 1 + server/crates/arbiter-server/src/utils.rs | 16 + .../arbiter-server/tests/user_agent/auth.rs | 4 +- .../arbiter-server/tests/user_agent/unseal.rs | 14 +- 20 files changed, 1151 insertions(+), 958 deletions(-) create mode 100644 server/crates/arbiter-proto/src/transport/grpc.rs create mode 100644 server/crates/arbiter-server/src/grpc/user_agent/auth.rs create mode 100644 server/crates/arbiter-server/src/utils.rs diff --git a/protobufs/user_agent.proto b/protobufs/user_agent.proto index 821575e..6fb77e4 100644 --- a/protobufs/user_agent.proto +++ b/protobufs/user_agent.proto @@ -2,8 +2,8 @@ syntax = "proto3"; package arbiter.user_agent; -import "google/protobuf/empty.proto"; import "evm.proto"; +import "google/protobuf/empty.proto"; enum KeyType { KEY_TYPE_UNSPECIFIED = 0; @@ -19,15 +19,23 @@ message AuthChallengeRequest { } message AuthChallenge { - bytes pubkey = 1; int32 nonce = 2; + reserved 1; } message AuthChallengeSolution { bytes signature = 1; } -message AuthOk {} +enum AuthResult { + AUTH_RESULT_UNSPECIFIED = 0; + AUTH_RESULT_SUCCESS = 1; + AUTH_RESULT_INVALID_KEY = 2; + AUTH_RESULT_INVALID_SIGNATURE = 3; + AUTH_RESULT_BOOTSTRAP_REQUIRED = 4; + AUTH_RESULT_TOKEN_INVALID = 5; + AUTH_RESULT_INTERNAL = 6; +} message UnsealStart { bytes client_pubkey = 1; @@ -99,7 +107,7 @@ message UserAgentRequest { message UserAgentResponse { oneof payload { AuthChallenge auth_challenge = 1; - AuthOk auth_ok = 2; + AuthResult auth_result = 2; UnsealStartResponse unseal_start_response = 3; UnsealResult unseal_result = 4; VaultState vault_state = 5; diff --git a/server/Cargo.lock b/server/Cargo.lock index 30ec3d7..057a88b 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -697,6 +697,7 @@ dependencies = [ "rustls-pki-types", "thiserror", "tokio", + "tokio-stream", "tonic", "tonic-prost", "tonic-prost-build", diff --git a/server/crates/arbiter-proto/Cargo.toml b/server/crates/arbiter-proto/Cargo.toml index 0673f8a..88676a0 100644 --- a/server/crates/arbiter-proto/Cargo.toml +++ b/server/crates/arbiter-proto/Cargo.toml @@ -21,6 +21,7 @@ base64 = "0.22.1" prost-types.workspace = true tracing.workspace = true async-trait.workspace = true +tokio-stream.workspace = true [build-dependencies] tonic-prost-build = "0.14.3" diff --git a/server/crates/arbiter-proto/src/transport.rs b/server/crates/arbiter-proto/src/transport.rs index 55415c8..b31aa61 100644 --- a/server/crates/arbiter-proto/src/transport.rs +++ b/server/crates/arbiter-proto/src/transport.rs @@ -63,16 +63,29 @@ where extractor(msg).ok_or(Error::UnexpectedMessage) } +#[async_trait] +pub trait Sender: Send + Sync { + async fn send(&mut self, item: Outbound) -> Result<(), Error>; +} + +#[async_trait] +pub trait Receiver: Send + Sync { + async fn recv(&mut self) -> Option; +} + /// Minimal bidirectional transport abstraction used by protocol code. /// /// `Bi` models a duplex channel with: /// - inbound items of type `Inbound` read via [`Bi::recv`] /// - outbound items of type `Outbound` written via [`Bi::send`] -#[async_trait] -pub trait Bi: Send + Sync + 'static { - async fn send(&mut self, item: Outbound) -> Result<(), Error>; +pub trait Bi: Sender + Receiver + Send + Sync {} - async fn recv(&mut self) -> Option; +pub trait SplittableBi: Bi { + type Sender: Sender; + type Receiver: Receiver; + + fn split(self) -> (Self::Sender, Self::Receiver); + fn from_parts(sender: Self::Sender, receiver: Self::Receiver) -> Self; } /// No-op [`Bi`] transport for tests and manual actor usage. @@ -83,22 +96,16 @@ pub struct DummyTransport { _marker: PhantomData<(Inbound, Outbound)>, } -impl DummyTransport { - pub fn new() -> Self { +impl Default for DummyTransport { + fn default() -> Self { Self { _marker: PhantomData, } } } -impl Default for DummyTransport { - fn default() -> Self { - Self::new() - } -} - #[async_trait] -impl Bi for DummyTransport +impl Sender for DummyTransport where Inbound: Send + Sync + 'static, Outbound: Send + Sync + 'static, @@ -106,9 +113,25 @@ where async fn send(&mut self, _item: Outbound) -> Result<(), Error> { Ok(()) } +} +#[async_trait] +impl Receiver for DummyTransport +where + Inbound: Send + Sync + 'static, + Outbound: Send + Sync + 'static, +{ async fn recv(&mut self) -> Option { std::future::pending::<()>().await; None } } + +impl Bi for DummyTransport +where + Inbound: Send + Sync + 'static, + Outbound: Send + Sync + 'static, +{ +} + +pub mod grpc; diff --git a/server/crates/arbiter-proto/src/transport/grpc.rs b/server/crates/arbiter-proto/src/transport/grpc.rs new file mode 100644 index 0000000..e0959e0 --- /dev/null +++ b/server/crates/arbiter-proto/src/transport/grpc.rs @@ -0,0 +1,106 @@ +use async_trait::async_trait; +use futures::StreamExt; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; + +use super::{Bi, Receiver, Sender}; + +pub struct GrpcSender { + tx: mpsc::Sender>, +} + +#[async_trait] +impl Sender> for GrpcSender +where + Outbound: Send + Sync + 'static, +{ + async fn send(&mut self, item: Result) -> Result<(), super::Error> { + self.tx + .send(item) + .await + .map_err(|_| super::Error::ChannelClosed) + } +} + +pub struct GrpcReceiver { + rx: tonic::Streaming, +} +#[async_trait] +impl Receiver> for GrpcReceiver +where + Inbound: Send + Sync + 'static, +{ + async fn recv(&mut self) -> Option> { + self.rx.next().await + } +} + +pub struct GrpcBi { + sender: GrpcSender, + receiver: GrpcReceiver, +} + +impl GrpcBi +where + Inbound: Send + Sync + 'static, + Outbound: Send + Sync + 'static, +{ + pub fn from_bi_stream( + receiver: tonic::Streaming, + ) -> (Self, ReceiverStream>) { + let (tx, rx) = mpsc::channel(10); + let sender = GrpcSender { tx }; + let receiver = GrpcReceiver { rx: receiver }; + let bi = GrpcBi { sender, receiver }; + (bi, ReceiverStream::new(rx)) + } +} + +#[async_trait] +impl Sender> for GrpcBi +where + Inbound: Send + Sync + 'static, + Outbound: Send + Sync + 'static, +{ + async fn send(&mut self, item: Result) -> Result<(), super::Error> { + self.sender.send(item).await + } +} + +#[async_trait] +impl Receiver> for GrpcBi +where + Inbound: Send + Sync + 'static, + Outbound: Send + Sync + 'static, +{ + async fn recv(&mut self) -> Option> { + self.receiver.recv().await + } +} + +impl Bi, Result> + for GrpcBi +where + Inbound: Send + Sync + 'static, + Outbound: Send + Sync + 'static, +{ +} + +impl + super::SplittableBi, Result> + for GrpcBi +where + Inbound: Send + Sync + 'static, + Outbound: Send + Sync + 'static, +{ + type Sender = GrpcSender; + type Receiver = GrpcReceiver; + + fn split(self) -> (Self::Sender, Self::Receiver) { + (self.sender, self.receiver) + } + + fn from_parts(sender: Self::Sender, receiver: Self::Receiver) -> Self { + GrpcBi { sender, receiver } + } +} diff --git a/server/crates/arbiter-server/src/actors/keyholder/mod.rs b/server/crates/arbiter-server/src/actors/keyholder/mod.rs index f37284a..3a245af 100644 --- a/server/crates/arbiter-server/src/actors/keyholder/mod.rs +++ b/server/crates/arbiter-server/src/actors/keyholder/mod.rs @@ -22,7 +22,7 @@ use encryption::v1::{self, KeyCell, Nonce}; pub mod encryption; #[derive(Default, EnumDiscriminants)] -#[strum_discriminants(derive(Reply), vis(pub))] +#[strum_discriminants(derive(Reply), vis(pub), name(KeyHolderState))] enum State { #[default] Unbootstrapped, @@ -325,7 +325,7 @@ impl KeyHolder { } #[message] - pub fn get_state(&self) -> StateDiscriminants { + pub fn get_state(&self) -> KeyHolderState { self.state.discriminant() } diff --git a/server/crates/arbiter-server/src/actors/user_agent/auth.rs b/server/crates/arbiter-server/src/actors/user_agent/auth.rs index eab7acf..7e2cf9c 100644 --- a/server/crates/arbiter-server/src/actors/user_agent/auth.rs +++ b/server/crates/arbiter-server/src/actors/user_agent/auth.rs @@ -1,74 +1,82 @@ +use arbiter_proto::transport::Bi; use tracing::error; use crate::actors::user_agent::{ - Request, UserAgentConnection, + AuthPublicKey, UserAgentConnection, auth::state::{AuthContext, AuthStateMachine}, - AuthPublicKey, - session::UserAgentSession, }; -#[derive(thiserror::Error, Debug, PartialEq)] -pub enum Error { - #[error("Unexpected message payload")] - UnexpectedMessagePayload, - #[error("Invalid client public key length")] - InvalidClientPubkeyLength, - #[error("Invalid client public key encoding")] - InvalidAuthPubkeyEncoding, - #[error("Database pool unavailable")] - DatabasePoolUnavailable, - #[error("Database operation failed")] - DatabaseOperationFailed, - #[error("Public key not registered")] - PublicKeyNotRegistered, - #[error("Transport error")] - Transport, - #[error("Invalid bootstrap token")] - InvalidBootstrapToken, - #[error("Bootstrapper actor unreachable")] - BootstrapperActorUnreachable, - #[error("Invalid challenge solution")] - InvalidChallengeSolution, -} - mod state; use state::*; -fn parse_auth_event(payload: Request) -> Result { - match payload { - Request::AuthChallengeRequest { - pubkey, - bootstrap_token: None, - } => Ok(AuthEvents::AuthRequest(ChallengeRequest { pubkey })), - Request::AuthChallengeRequest { - pubkey, - bootstrap_token: Some(token), - } => Ok(AuthEvents::BootstrapAuthRequest(BootstrapAuthRequest { - pubkey, - token, - })), - Request::AuthChallengeSolution { signature } => { - Ok(AuthEvents::ReceivedSolution(ChallengeSolution { - solution: signature, - })) +#[derive(Debug, Clone)] +pub enum Inbound { + AuthChallengeRequest { + pubkey: AuthPublicKey, + bootstrap_token: Option, + }, + AuthChallengeSolution { + signature: Vec, + }, +} + +#[derive(Debug)] +pub enum Error { + UnregisteredPublicKey, + InvalidChallengeSolution, + InvalidBootstrapToken, + Internal { details: String }, + Transport, +} + +impl Error { + fn internal(details: impl Into) -> Self { + Self::Internal { + details: details.into(), } - _ => Err(Error::UnexpectedMessagePayload), } } -pub async fn authenticate(props: &mut UserAgentConnection) -> Result { - let mut state = AuthStateMachine::new(AuthContext::new(props)); +#[derive(Debug, Clone)] +pub enum Outbound { + AuthChallenge { nonce: i32 }, + AuthSuccess, +} + +fn parse_auth_event(payload: Inbound) -> AuthEvents { + match payload { + Inbound::AuthChallengeRequest { + pubkey, + bootstrap_token: None, + } => AuthEvents::AuthRequest(ChallengeRequest { pubkey }), + Inbound::AuthChallengeRequest { + pubkey, + bootstrap_token: Some(token), + } => AuthEvents::BootstrapAuthRequest(BootstrapAuthRequest { pubkey, token }), + Inbound::AuthChallengeSolution { signature } => { + AuthEvents::ReceivedSolution(ChallengeSolution { + solution: signature, + }) + } + } +} + +pub async fn authenticate( + props: &mut UserAgentConnection, + transport: T, +) -> Result +where + T: Bi> + Send, +{ + let mut state = AuthStateMachine::new(AuthContext::new(props, transport)); loop { // `state` holds a mutable reference to `props` so we can't access it directly here - let transport = state.context_mut().conn.transport.as_mut(); - let Some(payload) = transport.recv().await else { + let Some(payload) = state.context_mut().transport.recv().await else { return Err(Error::Transport); }; - let event = parse_auth_event(payload)?; - - match state.process_event(event).await { + match state.process_event(parse_auth_event(payload)).await { Ok(AuthStates::AuthOk(key)) => return Ok(key.clone()), Err(AuthError::ActionFailed(err)) => { error!(?err, "State machine action failed"); @@ -91,11 +99,3 @@ pub async fn authenticate(props: &mut UserAgentConnection) -> Result Result { - let _key = authenticate(&mut props).await?; - let session = UserAgentSession::new(props); - Ok(session) -} diff --git a/server/crates/arbiter-server/src/actors/user_agent/auth/state.rs b/server/crates/arbiter-server/src/actors/user_agent/auth/state.rs index 608f3a7..7a5991d 100644 --- a/server/crates/arbiter-server/src/actors/user_agent/auth/state.rs +++ b/server/crates/arbiter-server/src/actors/user_agent/auth/state.rs @@ -1,3 +1,5 @@ +use alloy::transports::Transport; +use arbiter_proto::transport::Bi; use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, update}; use diesel_async::RunQueryDsl; use tracing::error; @@ -6,7 +8,7 @@ use super::Error; use crate::{ actors::{ bootstrap::ConsumeToken, - user_agent::{AuthPublicKey, Response, UserAgentConnection}, + user_agent::{AuthPublicKey, OutOfBand, UserAgentConnection, auth::Outbound}, }, db::schema, }; @@ -42,7 +44,7 @@ smlang::statemachine!( async fn create_nonce(db: &crate::db::DatabasePool, pubkey_bytes: &[u8]) -> Result { let mut db_conn = db.get().await.map_err(|e| { error!(error = ?e, "Database pool error"); - Error::DatabasePoolUnavailable + Error::internal("Database unavailable") })?; db_conn .exclusive_transaction(|conn| { @@ -66,11 +68,11 @@ async fn create_nonce(db: &crate::db::DatabasePool, pubkey_bytes: &[u8]) -> Resu .optional() .map_err(|e| { error!(error = ?e, "Database error"); - Error::DatabaseOperationFailed + Error::internal("Database operation failed") })? .ok_or_else(|| { error!(?pubkey_bytes, "Public key not found in database"); - Error::PublicKeyNotRegistered + Error::UnregisteredPublicKey }) } @@ -79,7 +81,7 @@ async fn register_key(db: &crate::db::DatabasePool, pubkey: &AuthPublicKey) -> R let key_type = pubkey.key_type(); let mut conn = db.get().await.map_err(|e| { error!(error = ?e, "Database pool error"); - Error::DatabasePoolUnavailable + Error::internal("Database unavailable") })?; diesel::insert_into(schema::useragent_client::table) @@ -92,23 +94,27 @@ async fn register_key(db: &crate::db::DatabasePool, pubkey: &AuthPublicKey) -> R .await .map_err(|e| { error!(error = ?e, "Database error"); - Error::DatabaseOperationFailed + Error::internal("Database operation failed") })?; Ok(()) } -pub struct AuthContext<'a> { +pub struct AuthContext<'a, T> { pub(super) conn: &'a mut UserAgentConnection, + pub(super) transport: T, } -impl<'a> AuthContext<'a> { - pub fn new(conn: &'a mut UserAgentConnection) -> Self { - Self { conn } +impl<'a, T> AuthContext<'a, T> { + pub fn new(conn: &'a mut UserAgentConnection, transport: T) -> Self { + Self { conn, transport } } } -impl AuthStateMachineContext for AuthContext<'_> { +impl AuthStateMachineContext for AuthContext<'_, T> +where + T: Bi> + Send, +{ type Error = Error; async fn prepare_challenge( @@ -118,9 +124,9 @@ impl AuthStateMachineContext for AuthContext<'_> { let stored_bytes = pubkey.to_stored_bytes(); let nonce = create_nonce(&self.conn.db, &stored_bytes).await?; - self.conn + self .transport - .send(Ok(Response::AuthChallenge { nonce })) + .send(Ok(Outbound::AuthChallenge { nonce })) .await .map_err(|e| { error!(?e, "Failed to send auth challenge"); @@ -149,7 +155,7 @@ impl AuthStateMachineContext for AuthContext<'_> { .await .map_err(|e| { error!(?e, "Failed to consume bootstrap token"); - Error::BootstrapperActorUnreachable + Error::internal("Failed to consume bootstrap token") })?; if !token_ok { @@ -159,11 +165,11 @@ impl AuthStateMachineContext for AuthContext<'_> { register_key(&self.conn.db, &pubkey).await?; - self.conn - .transport - .send(Ok(Response::AuthOk)) - .await - .map_err(|_| Error::Transport)?; + self + .transport + .send(Ok(Outbound::AuthSuccess)) + .await + .map_err(|_| Error::Transport)?; Ok(pubkey) } @@ -172,7 +178,10 @@ impl AuthStateMachineContext for AuthContext<'_> { #[allow(clippy::unused_unit)] async fn verify_solution( &mut self, - ChallengeContext { challenge_nonce, key }: &ChallengeContext, + ChallengeContext { + challenge_nonce, + key, + }: &ChallengeContext, ChallengeSolution { solution }: ChallengeSolution, ) -> Result { let formatted = arbiter_proto::format_challenge(*challenge_nonce, &key.to_stored_bytes()); @@ -205,9 +214,9 @@ impl AuthStateMachineContext for AuthContext<'_> { }; if valid { - self.conn + self .transport - .send(Ok(Response::AuthOk)) + .send(Ok(Outbound::AuthSuccess)) .await .map_err(|_| Error::Transport)?; } diff --git a/server/crates/arbiter-server/src/actors/user_agent/mod.rs b/server/crates/arbiter-server/src/actors/user_agent/mod.rs index 6b4a7d6..7f980b8 100644 --- a/server/crates/arbiter-server/src/actors/user_agent/mod.rs +++ b/server/crates/arbiter-server/src/actors/user_agent/mod.rs @@ -1,33 +1,15 @@ use alloy::primitives::Address; -use arbiter_proto::transport::Bi; +use arbiter_proto::transport::{Bi, Sender}; use kameo::actor::Spawn as _; use tracing::{error, info}; use crate::{ - actors::{GlobalActors, evm, user_agent::session::UserAgentSession}, + actors::{GlobalActors, evm}, db::{self, models::KeyType}, evm::policies::SharedGrantSettings, evm::policies::{Grant, SpecificGrant}, }; -#[derive(Debug, thiserror::Error, PartialEq)] -pub enum TransportResponseError { - #[error("Unexpected request payload")] - UnexpectedRequestPayload, - #[error("Invalid state for unseal encrypted key")] - InvalidStateForUnsealEncryptedKey, - #[error("client_pubkey must be 32 bytes")] - InvalidClientPubkeyLength, - #[error("State machine error")] - StateTransitionFailed, - #[error("Vault is not available")] - KeyHolderActorUnreachable, - #[error(transparent)] - Auth(#[from] auth::Error), - #[error("Failed registering connection")] - ConnectionRegistrationFailed, -} - /// Abstraction over Ed25519 / ECDSA-secp256k1 / RSA public keys used during the auth handshake. #[derive(Clone, Debug)] pub enum AuthPublicKey { @@ -65,119 +47,55 @@ impl AuthPublicKey { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum UnsealError { - InvalidKey, - Unbootstrapped, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum BootstrapError { - AlreadyBootstrapped, - InvalidKey, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum VaultState { - Unbootstrapped, - Sealed, - Unsealed, -} - -#[derive(Debug, Clone)] -pub enum Request { - AuthChallengeRequest { - pubkey: AuthPublicKey, - bootstrap_token: Option, - }, - AuthChallengeSolution { - signature: Vec, - }, - UnsealStart { - client_pubkey: x25519_dalek::PublicKey, - }, - UnsealEncryptedKey { - nonce: Vec, - ciphertext: Vec, - associated_data: Vec, - }, - BootstrapEncryptedKey { - nonce: Vec, - ciphertext: Vec, - associated_data: Vec, - }, - QueryVaultState, - EvmWalletCreate, - EvmWalletList, - ClientConnectionResponse { - approved: bool, - }, - - ListGrants, - EvmGrantCreate { - client_id: i32, - shared: SharedGrantSettings, - specific: SpecificGrant, - }, - EvmGrantDelete { - grant_id: i32, - }, +impl TryFrom<(KeyType, Vec)> for AuthPublicKey { + type Error = &'static str; + + fn try_from(value: (KeyType, Vec)) -> Result { + let (key_type, bytes) = value; + match key_type { + KeyType::Ed25519 => { + let bytes: [u8; 32] = bytes.try_into().map_err(|_| "invalid Ed25519 key length")?; + let key = ed25519_dalek::VerifyingKey::from_bytes(&bytes) + .map_err(|e| "invalid Ed25519 key")?; + Ok(AuthPublicKey::Ed25519(key)) + } + KeyType::EcdsaSecp256k1 => { + let point = + k256::EncodedPoint::from_bytes(&bytes).map_err(|e| "invalid ECDSA key")?; + let key = k256::ecdsa::VerifyingKey::from_encoded_point(&point) + .map_err(|e| "invalid ECDSA key")?; + Ok(AuthPublicKey::EcdsaSecp256k1(key)) + } + KeyType::Rsa => { + use rsa::pkcs8::DecodePublicKey as _; + let key = rsa::RsaPublicKey::from_public_key_der(&bytes) + .map_err(|e| "invalid RSA key")?; + Ok(AuthPublicKey::Rsa(key)) + } + } + } } +// Messages, sent by user agent to connection client without having a request #[derive(Debug)] -pub enum Response { - AuthChallenge { - nonce: i32, - }, - AuthOk, - UnsealStartResponse { - server_pubkey: x25519_dalek::PublicKey, - }, - UnsealResult(Result<(), UnsealError>), - BootstrapResult(Result<(), BootstrapError>), - VaultState(VaultState), - ClientConnectionRequest { - pubkey: ed25519_dalek::VerifyingKey, - }, +pub enum OutOfBand { + ClientConnectionRequest { pubkey: ed25519_dalek::VerifyingKey }, ClientConnectionCancel, - EvmWalletCreate(Result<(), evm::Error>), - EvmWalletList(Vec
), - - ListGrants(Vec>), - EvmGrantCreate(Result), - EvmGrantDelete(Result<(), evm::Error>), } -pub type Transport = Box> + Send>; - pub struct UserAgentConnection { - db: db::DatabasePool, - actors: GlobalActors, - transport: Transport, + pub(crate) db: db::DatabasePool, + pub(crate) actors: GlobalActors, } impl UserAgentConnection { - pub fn new(db: db::DatabasePool, actors: GlobalActors, transport: Transport) -> Self { - Self { - db, - actors, - transport, - } + pub fn new(db: db::DatabasePool, actors: GlobalActors) -> Self { + Self { db, actors } } } pub mod auth; pub mod session; -#[tracing::instrument(skip(props))] -pub async fn connect_user_agent(props: UserAgentConnection) { - match auth::authenticate_and_create(props).await { - Ok(session) => { - UserAgentSession::spawn(session); - info!("User authenticated, session started"); - } - Err(err) => { - error!(?err, "Authentication failed, closing connection"); - } - } -} +pub use auth::authenticate; +pub use session::UserAgentSession; diff --git a/server/crates/arbiter-server/src/actors/user_agent/session.rs b/server/crates/arbiter-server/src/actors/user_agent/session.rs index d568cc5..382165a 100644 --- a/server/crates/arbiter-server/src/actors/user_agent/session.rs +++ b/server/crates/arbiter-server/src/actors/user_agent/session.rs @@ -1,93 +1,63 @@ +use std::{borrow::Cow, convert::Infallible}; + +use arbiter_proto::transport::Sender; use ed25519_dalek::VerifyingKey; use kameo::{Actor, messages, prelude::Context}; +use thiserror::Error; use tokio::{select, sync::watch}; use tracing::{error, info}; use crate::actors::{ router::RegisterUserAgent, - user_agent::{ - Request, Response, TransportResponseError, - UserAgentConnection, - }, + user_agent::{OutOfBand, UserAgentConnection}, }; mod state; use state::{DummyContext, UserAgentEvents, UserAgentStateMachine}; -// Error for consumption by other actors -#[derive(Debug, thiserror::Error, PartialEq)] +#[derive(Debug, Error)] pub enum Error { - #[error("User agent session ended due to connection loss")] - ConnectionLost, + #[error("State transition failed")] + State, - #[error("User agent session ended due to unexpected message")] - UnexpectedMessage, + #[error("Internal error: {message}")] + Internal { message: Cow<'static, str> }, +} + +impl Error { + pub fn internal(message: impl Into>) -> Self { + Self::Internal { + message: message.into(), + } + } } pub struct UserAgentSession { props: UserAgentConnection, state: UserAgentStateMachine, + sender: Box>, } mod connection; +pub(crate) use connection::{ + BootstrapError, HandleBootstrapEncryptedKey, HandleEvmWalletCreate, HandleEvmWalletList, + HandleGrantCreate, HandleGrantDelete, HandleGrantList, HandleQueryVaultState, + HandleUnsealEncryptedKey, HandleUnsealRequest, UnsealError, +}; impl UserAgentSession { - pub(crate) fn new(props: UserAgentConnection) -> Self { + pub(crate) fn new(props: UserAgentConnection, sender: Box>) -> Self { Self { props, state: UserAgentStateMachine::new(DummyContext), + sender, } } - pub(super) async fn send_msg( - &mut self, - msg: Response, - _ctx: &mut Context, - ) -> Result<(), Error> { - self.props.transport.send(Ok(msg)).await.map_err(|_| { - error!( - actor = "useragent", - reason = "channel closed", - "send.failed" - ); - Error::ConnectionLost - }) - } - - async fn expect_msg( - &mut self, - extractor: Extractor, - ctx: &mut Context, - ) -> Result - where - Extractor: FnOnce(Request) -> Option, - Reply: kameo::Reply, - { - let msg = self.props.transport.recv().await.ok_or_else(|| { - error!( - actor = "useragent", - reason = "channel closed", - "recv.failed" - ); - ctx.stop(); - Error::ConnectionLost - })?; - - extractor(msg).ok_or_else(|| { - error!( - actor = "useragent", - reason = "unexpected message", - "recv.failed" - ); - ctx.stop(); - Error::UnexpectedMessage - }) - } - - fn transition(&mut self, event: UserAgentEvents) -> Result<(), TransportResponseError> { + fn transition(&mut self, event: UserAgentEvents) -> Result<(), Error> { self.state.process_event(event).map_err(|e| { error!(?e, "State transition failed"); - TransportResponseError::StateTransitionFailed + Error::State })?; Ok(()) } @@ -95,52 +65,21 @@ impl UserAgentSession { #[messages] impl UserAgentSession { - // TODO: Think about refactoring it to state-machine based flow, as we already have one #[message(ctx)] pub async fn request_new_client_approval( &mut self, client_pubkey: VerifyingKey, mut cancel_flag: watch::Receiver<()>, - ctx: &mut Context>, - ) -> Result { - self.send_msg( - Response::ClientConnectionRequest { - pubkey: client_pubkey, - }, - ctx, - ) - .await?; - - let extractor = |msg| { - if let Request::ClientConnectionResponse { approved } = msg { - Some(approved) - } else { - None - } - }; - - tokio::select! { - _ = cancel_flag.changed() => { - info!(actor = "useragent", "client connection approval cancelled"); - self.send_msg( - Response::ClientConnectionCancel, - ctx, - ).await?; - Ok(false) - } - result = self.expect_msg(extractor, ctx) => { - let result = result?; - info!(actor = "useragent", "received client connection approval result: approved={}", result); - Ok(result) - } - } + ctx: &mut Context>, + ) -> Result { + todo!("Think about refactoring it to state-machine based flow, as we already have one") } } impl Actor for UserAgentSession { type Args = Self; - type Error = TransportResponseError; + type Error = Error; async fn on_start( args: Self::Args, @@ -155,56 +94,8 @@ impl Actor for UserAgentSession { .await .map_err(|err| { error!(?err, "Failed to register user agent connection with router"); - TransportResponseError::ConnectionRegistrationFailed + Error::internal("Failed to register user agent connection with router") })?; Ok(args) } - - async fn next( - &mut self, - _actor_ref: kameo::prelude::WeakActorRef, - mailbox_rx: &mut kameo::prelude::MailboxReceiver, - ) -> Option> { - loop { - select! { - signal = mailbox_rx.recv() => { - return signal; - } - msg = self.props.transport.recv() => { - match msg { - Some(request) => { - match self.process_transport_inbound(request).await { - Ok(response) => { - if self.props.transport.send(Ok(response)).await.is_err() { - error!(actor = "useragent", reason = "channel closed", "send.failed"); - return Some(kameo::mailbox::Signal::Stop); - } - } - Err(err) => { - let _ = self.props.transport.send(Err(err)).await; - return Some(kameo::mailbox::Signal::Stop); - } - } - } - None => { - info!(actor = "useragent", "transport.closed"); - return Some(kameo::mailbox::Signal::Stop); - } - } - } - } - } - } -} - -impl UserAgentSession { - pub fn new_test(db: crate::db::DatabasePool, actors: crate::actors::GlobalActors) -> Self { - use arbiter_proto::transport::DummyTransport; - let transport: super::Transport = Box::new(DummyTransport::new()); - let props = UserAgentConnection::new(db, actors, transport); - Self { - props, - state: UserAgentStateMachine::new(DummyContext), - } - } } diff --git a/server/crates/arbiter-server/src/actors/user_agent/session/connection.rs b/server/crates/arbiter-server/src/actors/user_agent/session/connection.rs index f7cf2be..ed9a107 100644 --- a/server/crates/arbiter-server/src/actors/user_agent/session/connection.rs +++ b/server/crates/arbiter-server/src/actors/user_agent/session/connection.rs @@ -1,10 +1,15 @@ use std::sync::Mutex; +use alloy::primitives::Address; use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit}; use kameo::error::SendError; +use kameo::messages; use tracing::{error, info}; use x25519_dalek::{EphemeralSecret, PublicKey}; +use crate::actors::keyholder::KeyHolderState; +use crate::actors::user_agent::session::Error; +use crate::evm::policies::{Grant, SpecificGrant}; use crate::safe_cell::SafeCell; use crate::{ actors::{ @@ -13,7 +18,7 @@ use crate::{ }, keyholder::{self, Bootstrap, TryUnseal}, user_agent::{ - BootstrapError, Request, Response, TransportResponseError, UnsealError, VaultState, + OutOfBand, session::{ UserAgentSession, state::{UnsealContext, UserAgentEvents, UserAgentStates}, @@ -24,55 +29,10 @@ use crate::{ }; impl UserAgentSession { - pub async fn process_transport_inbound(&mut self, req: Request) -> Output { - match req { - Request::UnsealStart { client_pubkey } => { - self.handle_unseal_request(client_pubkey).await - } - Request::UnsealEncryptedKey { - nonce, - ciphertext, - associated_data, - } => { - self.handle_unseal_encrypted_key(nonce, ciphertext, associated_data) - .await - } - Request::BootstrapEncryptedKey { - nonce, - ciphertext, - associated_data, - } => { - self.handle_bootstrap_encrypted_key(nonce, ciphertext, associated_data) - .await - } - Request::ListGrants => self.handle_grant_list().await, - Request::QueryVaultState => self.handle_query_vault_state().await, - Request::EvmWalletCreate => self.handle_evm_wallet_create().await, - Request::EvmWalletList => self.handle_evm_wallet_list().await, - Request::AuthChallengeRequest { .. } - | Request::AuthChallengeSolution { .. } - | Request::ClientConnectionResponse { .. } => { - Err(TransportResponseError::UnexpectedRequestPayload) - } - Request::EvmGrantCreate { - client_id, - shared, - specific, - } => self.handle_grant_create(client_id, shared, specific).await, - Request::EvmGrantDelete { grant_id } => self.handle_grant_delete(grant_id).await, - } - } -} - -type Output = Result; - -impl UserAgentSession { - fn take_unseal_secret( - &mut self, - ) -> Result<(EphemeralSecret, PublicKey), TransportResponseError> { + fn take_unseal_secret(&mut self) -> Result<(EphemeralSecret, PublicKey), Error> { let UserAgentStates::WaitingForUnsealKey(unseal_context) = self.state.state() else { error!("Received encrypted key in invalid state"); - return Err(TransportResponseError::InvalidStateForUnsealEncryptedKey); + return Err(Error::internal("Invalid state for unseal encrypted key")); }; let ephemeral_secret = { @@ -87,7 +47,7 @@ impl UserAgentSession { None => { drop(secret_lock); error!("Ephemeral secret already taken"); - return Err(TransportResponseError::StateTransitionFailed); + return Err(Error::internal("Ephemeral secret already taken")); } } }; @@ -121,8 +81,38 @@ impl UserAgentSession { } } } +} - async fn handle_unseal_request(&mut self, client_pubkey: x25519_dalek::PublicKey) -> Output { +pub struct UnsealStartResponse { + pub server_pubkey: PublicKey, +} + +#[derive(Debug, Error)] +pub enum UnsealError { + #[error("Invalid key provided for unsealing")] + InvalidKey, + #[error("Internal error during unsealing process")] + General(#[from] super::Error), +} + +#[derive(Debug, Error)] +pub enum BootstrapError { + #[error("Invalid key provided for bootstrapping")] + InvalidKey, + #[error("Vault is already bootstrapped")] + AlreadyBootstrapped, + + #[error("Internal error during bootstrapping process")] + General(#[from] super::Error), +} + +#[messages] +impl UserAgentSession { + #[message] + pub(crate) async fn handle_unseal_request( + &mut self, + client_pubkey: x25519_dalek::PublicKey, + ) -> Result { let secret = EphemeralSecret::random(); let public_key = PublicKey::from(&secret); @@ -131,24 +121,27 @@ impl UserAgentSession { client_public_key: client_pubkey, }))?; - Ok(Response::UnsealStartResponse { + Ok(UnsealStartResponse { server_pubkey: public_key, }) } - async fn handle_unseal_encrypted_key( + #[message] + pub(crate) async fn handle_unseal_encrypted_key( &mut self, nonce: Vec, ciphertext: Vec, associated_data: Vec, - ) -> Output { + ) -> Result<(), UnsealError> { let (ephemeral_secret, client_public_key) = match self.take_unseal_secret() { Ok(values) => values, - Err(TransportResponseError::StateTransitionFailed) => { + Err(Error::State) => { self.transition(UserAgentEvents::ReceivedInvalidKey)?; - return Ok(Response::UnsealResult(Err(UnsealError::InvalidKey))); + return Err(UnsealError::InvalidKey); + } + Err(err) => { + return Err(Error::internal("Failed to take unseal secret").into()); } - Err(err) => return Err(err), }; let seal_key_buffer = match Self::decrypt_client_key_material( @@ -161,7 +154,7 @@ impl UserAgentSession { Ok(buffer) => buffer, Err(()) => { self.transition(UserAgentEvents::ReceivedInvalidKey)?; - return Ok(Response::UnsealResult(Err(UnsealError::InvalidKey))); + return Err(UnsealError::InvalidKey); } }; @@ -177,38 +170,39 @@ impl UserAgentSession { Ok(_) => { info!("Successfully unsealed key with client-provided key"); self.transition(UserAgentEvents::ReceivedValidKey)?; - Ok(Response::UnsealResult(Ok(()))) + Ok(()) } Err(SendError::HandlerError(keyholder::Error::InvalidKey)) => { self.transition(UserAgentEvents::ReceivedInvalidKey)?; - Ok(Response::UnsealResult(Err(UnsealError::InvalidKey))) + Err(UnsealError::InvalidKey) } Err(SendError::HandlerError(err)) => { error!(?err, "Keyholder failed to unseal key"); self.transition(UserAgentEvents::ReceivedInvalidKey)?; - Ok(Response::UnsealResult(Err(UnsealError::InvalidKey))) + Err(UnsealError::InvalidKey) } Err(err) => { error!(?err, "Failed to send unseal request to keyholder"); self.transition(UserAgentEvents::ReceivedInvalidKey)?; - Err(TransportResponseError::KeyHolderActorUnreachable) + Err(Error::internal("Vault actor error").into()) } } } - async fn handle_bootstrap_encrypted_key( + #[message] + pub(crate) async fn handle_bootstrap_encrypted_key( &mut self, nonce: Vec, ciphertext: Vec, associated_data: Vec, - ) -> Output { + ) -> Result<(), BootstrapError> { let (ephemeral_secret, client_public_key) = match self.take_unseal_secret() { Ok(values) => values, - Err(TransportResponseError::StateTransitionFailed) => { + Err(Error::State) => { self.transition(UserAgentEvents::ReceivedInvalidKey)?; - return Ok(Response::BootstrapResult(Err(BootstrapError::InvalidKey))); + return Err(BootstrapError::InvalidKey); } - Err(err) => return Err(err), + Err(err) => return Err(err.into()), }; let seal_key_buffer = match Self::decrypt_client_key_material( @@ -221,7 +215,7 @@ impl UserAgentSession { Ok(buffer) => buffer, Err(()) => { self.transition(UserAgentEvents::ReceivedInvalidKey)?; - return Ok(Response::BootstrapResult(Err(BootstrapError::InvalidKey))); + return Err(BootstrapError::InvalidKey); } }; @@ -237,87 +231,94 @@ impl UserAgentSession { Ok(_) => { info!("Successfully bootstrapped vault with client-provided key"); self.transition(UserAgentEvents::ReceivedValidKey)?; - Ok(Response::BootstrapResult(Ok(()))) + Ok(()) } Err(SendError::HandlerError(keyholder::Error::AlreadyBootstrapped)) => { self.transition(UserAgentEvents::ReceivedInvalidKey)?; - Ok(Response::BootstrapResult(Err( - BootstrapError::AlreadyBootstrapped, - ))) + Err(BootstrapError::AlreadyBootstrapped) } Err(SendError::HandlerError(err)) => { error!(?err, "Keyholder failed to bootstrap vault"); self.transition(UserAgentEvents::ReceivedInvalidKey)?; - Ok(Response::BootstrapResult(Err(BootstrapError::InvalidKey))) + Err(BootstrapError::InvalidKey) } Err(err) => { error!(?err, "Failed to send bootstrap request to keyholder"); self.transition(UserAgentEvents::ReceivedInvalidKey)?; - Err(TransportResponseError::KeyHolderActorUnreachable) + Err(BootstrapError::General(Error::internal( + "Vault actor error", + ))) } } } } +#[messages] impl UserAgentSession { - async fn handle_query_vault_state(&mut self) -> Output { - use crate::actors::keyholder::{GetState, StateDiscriminants}; + #[message] + pub(crate) async fn handle_query_vault_state(&mut self) -> Result { + use crate::actors::keyholder::GetState; let vault_state = match self.props.actors.key_holder.ask(GetState {}).await { - Ok(StateDiscriminants::Unbootstrapped) => VaultState::Unbootstrapped, - Ok(StateDiscriminants::Sealed) => VaultState::Sealed, - Ok(StateDiscriminants::Unsealed) => VaultState::Unsealed, + Ok(state) => state, Err(err) => { error!(?err, actor = "useragent", "keyholder.query.failed"); - return Err(TransportResponseError::KeyHolderActorUnreachable); + return Err(Error::internal("Vault is in broken state").into()); } }; - Ok(Response::VaultState(vault_state)) + Ok(vault_state) } } +#[messages] impl UserAgentSession { - async fn handle_evm_wallet_create(&mut self) -> Output { - let result = match self.props.actors.evm.ask(Generate {}).await { - Ok(_address) => return Ok(Response::EvmWalletCreate(Ok(()))), - Err(SendError::HandlerError(err)) => Err(err), + #[message] + pub(crate) async fn handle_evm_wallet_create(&mut self) -> Result { + match self.props.actors.evm.ask(Generate {}).await { + Ok(address) => return Ok(address), + Err(SendError::HandlerError(err)) => Err(Error::internal(format!( + "EVM wallet generation failed: {err}" + ))), Err(err) => { error!(?err, "EVM actor unreachable during wallet create"); - return Err(TransportResponseError::KeyHolderActorUnreachable); + return Err(Error::internal("EVM actor unreachable")); } - }; - Ok(Response::EvmWalletCreate(result)) + } } - async fn handle_evm_wallet_list(&mut self) -> Output { + #[message] + pub(crate) async fn handle_evm_wallet_list(&mut self) -> Result, Error> { match self.props.actors.evm.ask(ListWallets {}).await { - Ok(wallets) => Ok(Response::EvmWalletList(wallets)), + Ok(wallets) => Ok(wallets), Err(err) => { error!(?err, "EVM wallet list failed"); - Err(TransportResponseError::KeyHolderActorUnreachable) + Err(Error::internal("Failed to list EVM wallets")) } } } } +#[messages] impl UserAgentSession { - async fn handle_grant_list(&mut self) -> Output { + #[message] + pub(crate) async fn handle_grant_list(&mut self) -> Result>, Error> { match self.props.actors.evm.ask(UseragentListGrants {}).await { - Ok(grants) => Ok(Response::ListGrants(grants)), + Ok(grants) => Ok(grants), Err(err) => { error!(?err, "EVM grant list failed"); - Err(TransportResponseError::KeyHolderActorUnreachable) + Err(Error::internal("Failed to list EVM grants")) } } } - async fn handle_grant_create( + #[message] + pub(crate) async fn handle_grant_create( &mut self, client_id: i32, basic: crate::evm::policies::SharedGrantSettings, grant: crate::evm::policies::SpecificGrant, - ) -> Output { + ) -> Result { match self .props .actors @@ -329,15 +330,16 @@ impl UserAgentSession { }) .await { - Ok(grant_id) => Ok(Response::EvmGrantCreate(Ok(grant_id))), + Ok(grant_id) => Ok(grant_id), Err(err) => { error!(?err, "EVM grant create failed"); - Err(TransportResponseError::KeyHolderActorUnreachable) + Err(Error::internal("Failed to create EVM grant")) } } } - async fn handle_grant_delete(&mut self, grant_id: i32) -> Output { + #[message] + pub(crate) async fn handle_grant_delete(&mut self, grant_id: i32) -> Result<(), Error> { match self .props .actors @@ -345,10 +347,10 @@ impl UserAgentSession { .ask(UseragentDeleteGrant { grant_id }) .await { - Ok(()) => Ok(Response::EvmGrantDelete(Ok(()))), + Ok(()) => Ok(()), Err(err) => { error!(?err, "EVM grant delete failed"); - Err(TransportResponseError::KeyHolderActorUnreachable) + Err(Error::internal("Failed to delete EVM grant")) } } } diff --git a/server/crates/arbiter-server/src/db/mod.rs b/server/crates/arbiter-server/src/db/mod.rs index 616bd92..ba7ef0e 100644 --- a/server/crates/arbiter-server/src/db/mod.rs +++ b/server/crates/arbiter-server/src/db/mod.rs @@ -44,6 +44,14 @@ pub enum DatabaseSetupError { Pool(#[from] PoolInitError), } +#[derive(Error, Debug)] +pub enum DatabaseError { + #[error("Database connection error")] + Pool(#[from] PoolError), + #[error("Database query error")] + Connection(#[from] diesel::result::Error), +} + #[tracing::instrument(level = "info")] fn database_path() -> Result { let arbiter_home = arbiter_proto::home_path().map_err(DatabaseSetupError::HomeDir)?; diff --git a/server/crates/arbiter-server/src/grpc/client.rs b/server/crates/arbiter-server/src/grpc/client.rs index 1e9e072..3d41785 100644 --- a/server/crates/arbiter-server/src/grpc/client.rs +++ b/server/crates/arbiter-server/src/grpc/client.rs @@ -1,14 +1,13 @@ use arbiter_proto::{ proto::client::{ - AuthChallenge as ProtoAuthChallenge, - AuthChallengeRequest as ProtoAuthChallengeRequest, + AuthChallenge as ProtoAuthChallenge, AuthChallengeRequest as ProtoAuthChallengeRequest, AuthChallengeSolution as ProtoAuthChallengeSolution, AuthOk as ProtoAuthOk, ClientConnectError, ClientRequest, ClientResponse, client_connect_error::Code as ProtoClientConnectErrorCode, client_request::Payload as ClientRequestPayload, client_response::Payload as ClientResponsePayload, }, - transport::{Bi, Error as TransportError}, + transport::{Bi, Error as TransportError, Sender}, }; use async_trait::async_trait; use futures::StreamExt as _; @@ -37,9 +36,9 @@ impl GrpcTransport { Some(ClientRequestPayload::AuthChallengeRequest(ProtoAuthChallengeRequest { pubkey, })) => Ok(DomainRequest::AuthChallengeRequest { pubkey }), - Some(ClientRequestPayload::AuthChallengeSolution( - ProtoAuthChallengeSolution { signature }, - )) => Ok(DomainRequest::AuthChallengeSolution { signature }), + Some(ClientRequestPayload::AuthChallengeSolution(ProtoAuthChallengeSolution { + signature, + })) => Ok(DomainRequest::AuthChallengeSolution { signature }), None => Err(Status::invalid_argument("Missing client request payload")), } } @@ -86,8 +85,11 @@ impl GrpcTransport { } #[async_trait] -impl Bi> for GrpcTransport { - async fn send(&mut self, item: Result) -> Result<(), TransportError> { +impl Sender> for GrpcTransport { + async fn send( + &mut self, + item: Result, + ) -> Result<(), TransportError> { let outbound = match item { Ok(message) => Ok(Self::response_to_proto(message)), Err(err) => Err(Self::error_to_status(err)), @@ -98,7 +100,10 @@ impl Bi> for GrpcTransport { .await .map_err(|_| TransportError::ChannelClosed) } +} +#[async_trait] +impl Bi> for GrpcTransport { async fn recv(&mut self) -> Option { match self.receiver.next().await { Some(Ok(item)) => match Self::request_to_domain(item) { diff --git a/server/crates/arbiter-server/src/grpc/mod.rs b/server/crates/arbiter-server/src/grpc/mod.rs index 18b9f70..204d6b1 100644 --- a/server/crates/arbiter-server/src/grpc/mod.rs +++ b/server/crates/arbiter-server/src/grpc/mod.rs @@ -1,7 +1,9 @@ - -use arbiter_proto::proto::{ - client::{ClientRequest, ClientResponse}, - user_agent::{UserAgentRequest, UserAgentResponse}, +use arbiter_proto::{ + proto::{ + client::{ClientRequest, ClientResponse}, + user_agent::{UserAgentRequest, UserAgentResponse}, + }, + transport::grpc::GrpcBi, }; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; @@ -10,7 +12,11 @@ use tracing::info; use crate::{ DEFAULT_CHANNEL_SIZE, - actors::{client::{ClientConnection, connect_client}, user_agent::{UserAgentConnection, connect_user_agent}}, + actors::{ + client::{ClientConnection, connect_client}, + user_agent::UserAgentConnection, + }, + grpc::{self, user_agent::start}, }; pub mod client; @@ -48,18 +54,19 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for super::Ser request: Request>, ) -> Result, Status> { let req_stream = request.into_inner(); - let (tx, rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE); - let transport = user_agent::GrpcTransport::new(tx, req_stream); - let props = UserAgentConnection::new( - self.context.db.clone(), - self.context.actors.clone(), - Box::new(transport), - ); - tokio::spawn(connect_user_agent(props)); + let (bi, rx) = GrpcBi::from_bi_stream(req_stream); + + tokio::spawn(start( + UserAgentConnection { + db: self.context.db.clone(), + actors: self.context.actors.clone(), + }, + bi, + )); info!(event = "connection established", "grpc.user_agent"); - Ok(Response::new(ReceiverStream::new(rx))) + Ok(Response::new(rx)) } } diff --git a/server/crates/arbiter-server/src/grpc/user_agent.rs b/server/crates/arbiter-server/src/grpc/user_agent.rs index 39c6afb..4df6317 100644 --- a/server/crates/arbiter-server/src/grpc/user_agent.rs +++ b/server/crates/arbiter-server/src/grpc/user_agent.rs @@ -1,12 +1,14 @@ +use tokio::sync::mpsc; + use arbiter_proto::{ proto::{ - self, evm::{ EtherTransferSettings as ProtoEtherTransferSettings, EvmError as ProtoEvmError, EvmGrantCreateRequest, EvmGrantCreateResponse, EvmGrantDeleteRequest, EvmGrantDeleteResponse, EvmGrantList, EvmGrantListResponse, GrantEntry, SharedSettings as ProtoSharedSettings, SpecificGrant as ProtoSpecificGrant, TokenTransferSettings as ProtoTokenTransferSettings, + TransactionRateLimit as ProtoTransactionRateLimit, VolumeRateLimit as ProtoVolumeRateLimit, WalletCreateResponse, WalletEntry, WalletList, WalletListResponse, evm_grant_create_response::Result as EvmGrantCreateResult, evm_grant_delete_response::Result as EvmGrantDeleteResult, @@ -16,494 +18,538 @@ use arbiter_proto::{ wallet_list_response::Result as WalletListResult, }, user_agent::{ - AuthChallenge as ProtoAuthChallenge, AuthChallengeRequest as ProtoAuthChallengeRequest, - AuthChallengeSolution as ProtoAuthChallengeSolution, AuthOk as ProtoAuthOk, BootstrapEncryptedKey as ProtoBootstrapEncryptedKey, BootstrapResult as ProtoBootstrapResult, ClientConnectionCancel, - ClientConnectionRequest, ClientConnectionResponse, KeyType as ProtoKeyType, - UnsealEncryptedKey as ProtoUnsealEncryptedKey, UnsealResult as ProtoUnsealResult, - UnsealStart, UnsealStartResponse, UserAgentRequest, UserAgentResponse, + ClientConnectionRequest, UnsealEncryptedKey as ProtoUnsealEncryptedKey, + UnsealResult as ProtoUnsealResult, UnsealStart, UserAgentRequest, UserAgentResponse, VaultState as ProtoVaultState, user_agent_request::Payload as UserAgentRequestPayload, user_agent_response::Payload as UserAgentResponsePayload, }, }, - transport::{Bi, Error as TransportError}, + transport::{Error as TransportError, Receiver, Sender, grpc::GrpcBi}, }; use async_trait::async_trait; -use futures::StreamExt as _; -use prost_types::Timestamp; -use tokio::sync::mpsc; -use tonic::{Status, Streaming}; +use chrono::{TimeZone, Utc}; +use kameo::{actor::{ActorRef, Spawn as _}, error::SendError}; +use tonic::Status; +use tracing::{info, warn}; use crate::{ - actors::user_agent::{ - self, AuthPublicKey, BootstrapError, Request as DomainRequest, Response as DomainResponse, - TransportResponseError, UnsealError, VaultState, - }, - evm::{ - policies::{Grant, SpecificGrant}, - policies::{ - SharedGrantSettings, TransactionRateLimit, VolumeRateLimit, ether_transfer, - token_transfers, + actors::{ + keyholder::KeyHolderState, + user_agent::{ + OutOfBand, UserAgentConnection, UserAgentSession, + session::{ + BootstrapError, Error, HandleBootstrapEncryptedKey, HandleEvmWalletCreate, HandleEvmWalletList, HandleGrantCreate, HandleGrantDelete, HandleGrantList, HandleQueryVaultState, HandleUnsealEncryptedKey, HandleUnsealRequest, UnsealError + }, }, }, + evm::policies::{ + Grant, SharedGrantSettings, SpecificGrant, TransactionRateLimit, VolumeRateLimit, + ether_transfer, token_transfers, + }, + utils::defer, }; use alloy::primitives::{Address, U256}; -use chrono::{DateTime, TimeZone, Utc}; +mod auth; -pub struct GrpcTransport { - sender: mpsc::Sender>, - receiver: Streaming, -} - -impl GrpcTransport { - pub fn new( - sender: mpsc::Sender>, - receiver: Streaming, - ) -> Self { - Self { sender, receiver } - } - - fn request_to_domain(request: UserAgentRequest) -> Result { - match request.payload { - Some(UserAgentRequestPayload::AuthChallengeRequest(ProtoAuthChallengeRequest { - pubkey, - bootstrap_token, - key_type, - })) => Ok(DomainRequest::AuthChallengeRequest { - pubkey: parse_auth_pubkey(key_type, pubkey)?, - bootstrap_token, - }), - Some(UserAgentRequestPayload::AuthChallengeSolution(ProtoAuthChallengeSolution { - signature, - })) => Ok(DomainRequest::AuthChallengeSolution { signature }), - Some(UserAgentRequestPayload::UnsealStart(UnsealStart { client_pubkey })) => { - let client_pubkey: [u8; 32] = client_pubkey - .as_slice() - .try_into() - .map_err(|_| Status::invalid_argument("client_pubkey must be 32 bytes"))?; - Ok(DomainRequest::UnsealStart { - client_pubkey: x25519_dalek::PublicKey::from(client_pubkey), - }) - } - Some(UserAgentRequestPayload::UnsealEncryptedKey(ProtoUnsealEncryptedKey { - nonce, - ciphertext, - associated_data, - })) => Ok(DomainRequest::UnsealEncryptedKey { - nonce, - ciphertext, - associated_data, - }), - Some(UserAgentRequestPayload::BootstrapEncryptedKey(ProtoBootstrapEncryptedKey { - nonce, - ciphertext, - associated_data, - })) => Ok(DomainRequest::BootstrapEncryptedKey { - nonce, - ciphertext, - associated_data, - }), - Some(UserAgentRequestPayload::QueryVaultState(_)) => Ok(DomainRequest::QueryVaultState), - Some(UserAgentRequestPayload::EvmWalletCreate(_)) => Ok(DomainRequest::EvmWalletCreate), - Some(UserAgentRequestPayload::EvmWalletList(_)) => Ok(DomainRequest::EvmWalletList), - Some(UserAgentRequestPayload::ClientConnectionResponse(ClientConnectionResponse { - approved, - })) => Ok(DomainRequest::ClientConnectionResponse { approved }), - - Some(UserAgentRequestPayload::EvmGrantList(_)) => Ok(DomainRequest::ListGrants), - Some(UserAgentRequestPayload::EvmGrantCreate(EvmGrantCreateRequest { - client_id, - shared, - specific, - })) => { - let shared = parse_shared_settings(client_id, shared)?; - let specific = parse_specific_grant(specific)?; - Ok(DomainRequest::EvmGrantCreate { - client_id, - shared, - specific, - }) - } - Some(UserAgentRequestPayload::EvmGrantDelete(EvmGrantDeleteRequest { grant_id })) => { - Ok(DomainRequest::EvmGrantDelete { grant_id }) - } - None => Err(Status::invalid_argument( - "Missing user-agent request payload", - )), - } - } - - fn response_to_proto(response: DomainResponse) -> UserAgentResponse { - let payload = match response { - DomainResponse::AuthChallenge { nonce } => { - UserAgentResponsePayload::AuthChallenge(ProtoAuthChallenge { - pubkey: Vec::new(), - nonce, - }) - } - DomainResponse::AuthOk => UserAgentResponsePayload::AuthOk(ProtoAuthOk {}), - DomainResponse::UnsealStartResponse { server_pubkey } => { - UserAgentResponsePayload::UnsealStartResponse(UnsealStartResponse { - server_pubkey: server_pubkey.as_bytes().to_vec(), - }) - } - DomainResponse::UnsealResult(result) => UserAgentResponsePayload::UnsealResult( - match result { - Ok(()) => ProtoUnsealResult::Success, - Err(UnsealError::InvalidKey) => ProtoUnsealResult::InvalidKey, - Err(UnsealError::Unbootstrapped) => ProtoUnsealResult::Unbootstrapped, - } - .into(), - ), - DomainResponse::BootstrapResult(result) => UserAgentResponsePayload::BootstrapResult( - match result { - Ok(()) => ProtoBootstrapResult::Success, - Err(BootstrapError::AlreadyBootstrapped) => { - ProtoBootstrapResult::AlreadyBootstrapped - } - Err(BootstrapError::InvalidKey) => ProtoBootstrapResult::InvalidKey, - } - .into(), - ), - DomainResponse::VaultState(state) => UserAgentResponsePayload::VaultState( - match state { - VaultState::Unbootstrapped => ProtoVaultState::Unbootstrapped, - VaultState::Sealed => ProtoVaultState::Sealed, - VaultState::Unsealed => ProtoVaultState::Unsealed, - } - .into(), - ), - DomainResponse::ClientConnectionRequest { pubkey } => { - UserAgentResponsePayload::ClientConnectionRequest(ClientConnectionRequest { - pubkey: pubkey.to_bytes().to_vec(), - }) - } - DomainResponse::ClientConnectionCancel => { - UserAgentResponsePayload::ClientConnectionCancel(ClientConnectionCancel {}) - } - DomainResponse::EvmWalletCreate(result) => { - UserAgentResponsePayload::EvmWalletCreate(WalletCreateResponse { - result: Some(match result { - Ok(()) => WalletCreateResult::Wallet(WalletEntry { - address: Vec::new(), - }), - Err(_) => WalletCreateResult::Error(ProtoEvmError::Internal.into()), - }), - }) - } - DomainResponse::EvmWalletList(wallets) => { - UserAgentResponsePayload::EvmWalletList(WalletListResponse { - result: Some(WalletListResult::Wallets(WalletList { - wallets: wallets - .into_iter() - .map(|addr| WalletEntry { - address: addr.as_slice().to_vec(), - }) - .collect(), - })), - }) - } - DomainResponse::ListGrants(grants) => { - UserAgentResponsePayload::EvmGrantList(EvmGrantListResponse { - result: Some(EvmGrantListResult::Grants(EvmGrantList { - grants: grants.into_iter().map(grant_to_proto).collect(), - })), - }) - } - DomainResponse::EvmGrantCreate(result) => { - UserAgentResponsePayload::EvmGrantCreate(EvmGrantCreateResponse { - result: Some(match result { - Ok(grant_id) => EvmGrantCreateResult::GrantId(grant_id), - Err(_) => EvmGrantCreateResult::Error(ProtoEvmError::Internal.into()), - }), - }) - } - DomainResponse::EvmGrantDelete(result) => { - UserAgentResponsePayload::EvmGrantDelete(EvmGrantDeleteResponse { - result: Some(match result { - Ok(()) => EvmGrantDeleteResult::Ok(()), - Err(_) => EvmGrantDeleteResult::Error(ProtoEvmError::Internal.into()), - }), - }) - } - }; - - UserAgentResponse { - payload: Some(payload), - } - } - - fn error_to_status(value: TransportResponseError) -> Status { - match value { - TransportResponseError::UnexpectedRequestPayload => { - Status::invalid_argument("Expected message with payload") - } - TransportResponseError::InvalidStateForUnsealEncryptedKey => { - Status::failed_precondition("Invalid state for unseal encrypted key") - } - TransportResponseError::InvalidClientPubkeyLength => { - Status::invalid_argument("client_pubkey must be 32 bytes") - } - TransportResponseError::StateTransitionFailed => { - Status::internal("State machine error") - } - TransportResponseError::KeyHolderActorUnreachable => { - Status::internal("Vault is not available") - } - TransportResponseError::Auth(ref err) => auth_error_status(err), - TransportResponseError::ConnectionRegistrationFailed => { - Status::internal("Failed registering connection") - } - } - } -} +pub struct OutOfBandAdapter(mpsc::Sender); #[async_trait] -impl Bi> for GrpcTransport { - async fn send( - &mut self, - item: Result, - ) -> Result<(), TransportError> { - let outbound = match item { - Ok(message) => Ok(Self::response_to_proto(message)), - Err(err) => Err(Self::error_to_status(err)), +impl Sender for OutOfBandAdapter { + async fn send(&mut self, item: OutOfBand) -> Result<(), TransportError> { + self.0.send(item).await.map_err(|e| { + warn!(error = ?e, "Failed to send out-of-band message"); + TransportError::ChannelClosed + }) + } +} + +async fn dispatch_loop( + mut bi: GrpcBi, + actor: ActorRef, + mut receiver: mpsc::Receiver, +) { + loop { + tokio::select! { + oob = receiver.recv() => { + let Some(oob) = oob else { + return; + }; + + if send_out_of_band(&mut bi, oob).await.is_err() { + return; + } + } + + conn = bi.recv() => { + let Some(conn) = conn else { + return; + }; + + if dispatch_conn_message(&mut bi, &actor, conn).await.is_err() { + return; + } + } + } + } +} + +async fn dispatch_conn_message( + bi: &mut GrpcBi, + actor: &ActorRef, + conn: Result, +) -> Result<(), ()> { + let conn = match conn { + Ok(conn) => conn, + Err(err) => { + warn!(error = ?err, "Failed to receive user agent request"); + return Err(()); + } + }; + + let Some(payload) = conn.payload else { + let _ = bi.send(Err(Status::invalid_argument("Missing user-agent request payload"))).await; + return Err(()); + }; + + let payload = match payload { + UserAgentRequestPayload::UnsealStart(UnsealStart { client_pubkey }) => { + let client_pubkey = match <[u8; 32]>::try_from(client_pubkey) { + Ok(bytes) => x25519_dalek::PublicKey::from(bytes), + Err(_) => { + let _ = bi.send(Err(Status::invalid_argument("Invalid X25519 public key"))).await; + return Err(()); + } + }; + + match actor.ask(HandleUnsealRequest { client_pubkey }).await { + Ok(response) => UserAgentResponsePayload::UnsealStartResponse( + arbiter_proto::proto::user_agent::UnsealStartResponse { + server_pubkey: response.server_pubkey.as_bytes().to_vec(), + }, + ), + Err(err) => { + warn!(error = ?err, "Failed to handle unseal start request"); + let _ = bi.send(Err(Status::internal("Failed to start unseal flow"))).await; + return Err(()); + } + } + } + UserAgentRequestPayload::UnsealEncryptedKey(ProtoUnsealEncryptedKey { + nonce, + ciphertext, + associated_data, + }) => UserAgentResponsePayload::UnsealResult( + match actor + .ask(HandleUnsealEncryptedKey { + nonce, + ciphertext, + associated_data, + }) + .await + { + Ok(()) => ProtoUnsealResult::Success, + Err(SendError::HandlerError(UnsealError::InvalidKey)) => { + ProtoUnsealResult::InvalidKey + } + Err(err) => { + warn!(error = ?err, "Failed to handle unseal request"); + let _ = bi.send(Err(Status::internal("Failed to unseal vault"))).await; + return Err(()); + } + } + .into(), + ), + UserAgentRequestPayload::BootstrapEncryptedKey(ProtoBootstrapEncryptedKey { + nonce, + ciphertext, + associated_data, + }) => UserAgentResponsePayload::BootstrapResult( + match actor + .ask(HandleBootstrapEncryptedKey { + nonce, + ciphertext, + associated_data, + }) + .await + { + Ok(()) => ProtoBootstrapResult::Success, + Err(SendError::HandlerError(BootstrapError::InvalidKey)) => { + ProtoBootstrapResult::InvalidKey + } + Err(SendError::HandlerError( + BootstrapError::AlreadyBootstrapped, + )) => ProtoBootstrapResult::AlreadyBootstrapped, + Err(err) => { + warn!(error = ?err, "Failed to handle bootstrap request"); + let _ = bi.send(Err(Status::internal("Failed to bootstrap vault"))).await; + return Err(()); + } + } + .into(), + ), + UserAgentRequestPayload::QueryVaultState(_) => UserAgentResponsePayload::VaultState( + match actor.ask(HandleQueryVaultState {}).await { + Ok(KeyHolderState::Unbootstrapped) => ProtoVaultState::Unbootstrapped, + Ok(KeyHolderState::Sealed) => ProtoVaultState::Sealed, + Ok(KeyHolderState::Unsealed) => ProtoVaultState::Unsealed, + Err(err) => { + warn!(error = ?err, "Failed to query vault state"); + ProtoVaultState::Error + } + } + .into(), + ), + UserAgentRequestPayload::EvmWalletCreate(_) => UserAgentResponsePayload::EvmWalletCreate( + EvmGrantOrWallet::wallet_create_response(actor.ask(HandleEvmWalletCreate {}).await), + ), + UserAgentRequestPayload::EvmWalletList(_) => UserAgentResponsePayload::EvmWalletList( + EvmGrantOrWallet::wallet_list_response(actor.ask(HandleEvmWalletList {}).await), + ), + UserAgentRequestPayload::EvmGrantList(_) => UserAgentResponsePayload::EvmGrantList( + EvmGrantOrWallet::grant_list_response(actor.ask(HandleGrantList {}).await), + ), + UserAgentRequestPayload::EvmGrantCreate(EvmGrantCreateRequest { + client_id, + shared, + specific, + }) => { + let (basic, grant) = match parse_grant_request(shared, specific) { + Ok(values) => values, + Err(status) => { + let _ = bi.send(Err(status)).await; + return Err(()); + } + }; + + UserAgentResponsePayload::EvmGrantCreate(EvmGrantOrWallet::grant_create_response( + actor.ask(HandleGrantCreate { + client_id, + basic, + grant, + }) + .await, + )) + } + UserAgentRequestPayload::EvmGrantDelete(EvmGrantDeleteRequest { grant_id }) => { + UserAgentResponsePayload::EvmGrantDelete(EvmGrantOrWallet::grant_delete_response( + actor.ask(HandleGrantDelete { grant_id }).await, + )) + } + payload => { + warn!(?payload, "Unsupported post-auth user agent request"); + let _ = bi.send(Err(Status::invalid_argument("Unsupported user-agent request"))).await; + return Err(()); + } + }; + + bi.send(Ok(UserAgentResponse { + payload: Some(payload), + })) + .await + .map_err(|_| ()) +} + +async fn send_out_of_band( + bi: &mut GrpcBi, + oob: OutOfBand, +) -> Result<(), ()> { + let payload = match oob { + OutOfBand::ClientConnectionRequest { pubkey } => { + UserAgentResponsePayload::ClientConnectionRequest(ClientConnectionRequest { + pubkey: pubkey.to_bytes().to_vec(), + }) + } + OutOfBand::ClientConnectionCancel => { + UserAgentResponsePayload::ClientConnectionCancel(ClientConnectionCancel {}) + } + }; + + bi.send(Ok(UserAgentResponse { + payload: Some(payload), + })) + .await + .map_err(|_| ()) +} + +fn parse_grant_request( + shared: Option, + specific: Option, +) -> Result<(SharedGrantSettings, SpecificGrant), Status> { + let shared = shared.ok_or_else(|| Status::invalid_argument("Missing shared grant settings"))?; + let specific = + specific.ok_or_else(|| Status::invalid_argument("Missing specific grant settings"))?; + + Ok((shared_settings_from_proto(shared)?, specific_grant_from_proto(specific)?)) +} + +fn shared_settings_from_proto(shared: ProtoSharedSettings) -> Result { + Ok(SharedGrantSettings { + wallet_id: shared.wallet_id, + client_id: 0, + chain: shared.chain_id, + valid_from: shared + .valid_from + .map(proto_timestamp_to_utc) + .transpose()?, + valid_until: shared + .valid_until + .map(proto_timestamp_to_utc) + .transpose()?, + max_gas_fee_per_gas: shared + .max_gas_fee_per_gas + .as_deref() + .map(u256_from_proto_bytes) + .transpose()?, + max_priority_fee_per_gas: shared + .max_priority_fee_per_gas + .as_deref() + .map(u256_from_proto_bytes) + .transpose()?, + rate_limit: shared + .rate_limit + .map(|limit| TransactionRateLimit { + count: limit.count, + window: chrono::Duration::seconds(limit.window_secs), + }), + }) +} + +fn specific_grant_from_proto(specific: ProtoSpecificGrant) -> Result { + match specific.grant { + Some(ProtoSpecificGrantType::EtherTransfer(ProtoEtherTransferSettings { + targets, + limit, + })) => Ok(SpecificGrant::EtherTransfer(ether_transfer::Settings { + target: targets + .into_iter() + .map(address_from_bytes) + .collect::>()?, + limit: volume_rate_limit_from_proto( + limit.ok_or_else(|| { + Status::invalid_argument("Missing ether transfer volume rate limit") + })?, + )?, + })), + Some(ProtoSpecificGrantType::TokenTransfer(ProtoTokenTransferSettings { + token_contract, + target, + volume_limits, + })) => Ok(SpecificGrant::TokenTransfer(token_transfers::Settings { + token_contract: address_from_bytes(token_contract)?, + target: target.map(address_from_bytes).transpose()?, + volume_limits: volume_limits + .into_iter() + .map(volume_rate_limit_from_proto) + .collect::>()?, + })), + None => Err(Status::invalid_argument("Missing specific grant kind")), + } +} + +fn volume_rate_limit_from_proto(limit: ProtoVolumeRateLimit) -> Result { + Ok(VolumeRateLimit { + max_volume: u256_from_proto_bytes(&limit.max_volume)?, + window: chrono::Duration::seconds(limit.window_secs), + }) +} + +fn address_from_bytes(bytes: Vec) -> Result { + if bytes.len() != 20 { + return Err(Status::invalid_argument("Invalid EVM address")); + } + + Ok(Address::from_slice(&bytes)) +} + +fn u256_from_proto_bytes(bytes: &[u8]) -> Result { + if bytes.len() > 32 { + return Err(Status::invalid_argument("Invalid U256 byte length")); + } + + Ok(U256::from_be_slice(bytes)) +} + +fn proto_timestamp_to_utc( + timestamp: prost_types::Timestamp, +) -> Result, Status> { + Utc.timestamp_opt(timestamp.seconds, timestamp.nanos as u32) + .single() + .ok_or_else(|| Status::invalid_argument("Invalid timestamp")) +} + +fn shared_settings_to_proto(shared: SharedGrantSettings) -> ProtoSharedSettings { + ProtoSharedSettings { + wallet_id: shared.wallet_id, + chain_id: shared.chain, + valid_from: shared.valid_from.map(|time| prost_types::Timestamp { + seconds: time.timestamp(), + nanos: time.timestamp_subsec_nanos() as i32, + }), + valid_until: shared.valid_until.map(|time| prost_types::Timestamp { + seconds: time.timestamp(), + nanos: time.timestamp_subsec_nanos() as i32, + }), + max_gas_fee_per_gas: shared.max_gas_fee_per_gas.map(|value| { + value.to_be_bytes::<32>().to_vec() + }), + max_priority_fee_per_gas: shared.max_priority_fee_per_gas.map(|value| { + value.to_be_bytes::<32>().to_vec() + }), + rate_limit: shared.rate_limit.map(|limit| ProtoTransactionRateLimit { + count: limit.count, + window_secs: limit.window.num_seconds(), + }), + } +} + +fn specific_grant_to_proto(grant: SpecificGrant) -> ProtoSpecificGrant { + let grant = match grant { + SpecificGrant::EtherTransfer(settings) => { + ProtoSpecificGrantType::EtherTransfer(ProtoEtherTransferSettings { + targets: settings.target.into_iter().map(|address| address.to_vec()).collect(), + limit: Some(ProtoVolumeRateLimit { + max_volume: settings.limit.max_volume.to_be_bytes::<32>().to_vec(), + window_secs: settings.limit.window.num_seconds(), + }), + }) + } + SpecificGrant::TokenTransfer(settings) => { + ProtoSpecificGrantType::TokenTransfer(ProtoTokenTransferSettings { + token_contract: settings.token_contract.to_vec(), + target: settings.target.map(|address| address.to_vec()), + volume_limits: settings + .volume_limits + .into_iter() + .map(|limit| ProtoVolumeRateLimit { + max_volume: limit.max_volume.to_be_bytes::<32>().to_vec(), + window_secs: limit.window.num_seconds(), + }) + .collect(), + }) + } + }; + + ProtoSpecificGrant { grant: Some(grant) } +} + +struct EvmGrantOrWallet; + +impl EvmGrantOrWallet { + fn wallet_create_response( + result: Result>, + ) -> WalletCreateResponse { + let result = match result { + Ok(wallet) => WalletCreateResult::Wallet(WalletEntry { + address: wallet.to_vec(), + }), + Err(err) => { + warn!(error = ?err, "Failed to create EVM wallet"); + WalletCreateResult::Error(ProtoEvmError::Internal.into()) + } }; - self.sender - .send(outbound) - .await - .map_err(|_| TransportError::ChannelClosed) + WalletCreateResponse { result: Some(result) } } - async fn recv(&mut self) -> Option { - match self.receiver.next().await { - Some(Ok(item)) => match Self::request_to_domain(item) { - Ok(request) => Some(request), - Err(status) => { - let _ = self.sender.send(Err(status)).await; - None - } - }, - Some(Err(error)) => { - tracing::error!(error = ?error, "grpc user-agent recv failed; closing stream"); - None + fn wallet_list_response( + result: Result, SendError>, + ) -> WalletListResponse { + let result = match result { + Ok(wallets) => WalletListResult::Wallets(WalletList { + wallets: wallets + .into_iter() + .map(|wallet| WalletEntry { + address: wallet.to_vec(), + }) + .collect(), + }), + Err(err) => { + warn!(error = ?err, "Failed to list EVM wallets"); + WalletListResult::Error(ProtoEvmError::Internal.into()) } - None => None, - } + }; + + WalletListResponse { result: Some(result) } + } + + fn grant_create_response( + result: Result>, + ) -> EvmGrantCreateResponse { + let result = match result { + Ok(grant_id) => EvmGrantCreateResult::GrantId(grant_id), + Err(err) => { + warn!(error = ?err, "Failed to create EVM grant"); + EvmGrantCreateResult::Error(ProtoEvmError::Internal.into()) + } + }; + + EvmGrantCreateResponse { result: Some(result) } + } + + fn grant_delete_response( + result: Result<(), SendError>, + ) -> EvmGrantDeleteResponse { + let result = match result { + Ok(()) => EvmGrantDeleteResult::Ok(()), + Err(err) => { + warn!(error = ?err, "Failed to delete EVM grant"); + EvmGrantDeleteResult::Error(ProtoEvmError::Internal.into()) + } + }; + + EvmGrantDeleteResponse { result: Some(result) } + } + + fn grant_list_response( + result: Result>, SendError>, + ) -> EvmGrantListResponse { + let result = match result { + Ok(grants) => EvmGrantListResult::Grants(EvmGrantList { + grants: grants + .into_iter() + .map(|grant| GrantEntry { + id: grant.id, + client_id: grant.shared.client_id, + shared: Some(shared_settings_to_proto(grant.shared)), + specific: Some(specific_grant_to_proto(grant.settings)), + }) + .collect(), + }), + Err(err) => { + warn!(error = ?err, "Failed to list EVM grants"); + EvmGrantListResult::Error(ProtoEvmError::Internal.into()) + } + }; + + EvmGrantListResponse { result: Some(result) } } } -fn grant_to_proto(grant: Grant) -> proto::evm::GrantEntry { - GrantEntry { - id: grant.id, - specific: Some(match grant.settings { - SpecificGrant::EtherTransfer(settings) => ProtoSpecificGrant { - grant: Some(ProtoSpecificGrantType::EtherTransfer( - ProtoEtherTransferSettings { - targets: settings - .target - .into_iter() - .map(|addr| addr.as_slice().to_vec()) - .collect(), - limit: Some(proto::evm::VolumeRateLimit { - max_volume: settings.limit.max_volume.to_be_bytes_vec(), - window_secs: settings.limit.window.num_seconds(), - }), - }, - )), - }, - SpecificGrant::TokenTransfer(settings) => ProtoSpecificGrant { - grant: Some(ProtoSpecificGrantType::TokenTransfer( - ProtoTokenTransferSettings { - token_contract: settings.token_contract.as_slice().to_vec(), - target: settings.target.map(|addr| addr.as_slice().to_vec()), - volume_limits: settings - .volume_limits - .into_iter() - .map(|vrl| proto::evm::VolumeRateLimit { - max_volume: vrl.max_volume.to_be_bytes_vec(), - window_secs: vrl.window.num_seconds(), - }) - .collect(), - }, - )), - }, - }), - client_id: grant.shared.client_id, - shared: Some(proto::evm::SharedSettings { - wallet_id: grant.shared.wallet_id, - chain_id: grant.shared.chain, - valid_from: grant.shared.valid_from.map(|dt| Timestamp { - seconds: dt.timestamp(), - nanos: 0, - }), - valid_until: grant.shared.valid_until.map(|dt| Timestamp { - seconds: dt.timestamp(), - nanos: 0, - }), - max_gas_fee_per_gas: grant - .shared - .max_gas_fee_per_gas - .map(|fee| fee.to_be_bytes_vec()), - max_priority_fee_per_gas: grant - .shared - .max_priority_fee_per_gas - .map(|fee| fee.to_be_bytes_vec()), - rate_limit: grant - .shared - .rate_limit - .map(|limit| proto::evm::TransactionRateLimit { - count: limit.count, - window_secs: limit.window.num_seconds(), - }), - }), - } -} - -fn parse_volume_rate_limit(vrl: ProtoVolumeRateLimit) -> Result { - Ok(VolumeRateLimit { - max_volume: U256::from_be_slice(&vrl.max_volume), - window: chrono::Duration::seconds(vrl.window_secs), - }) -} - -fn parse_shared_settings( - client_id: i32, - proto: Option, -) -> Result { - let s = proto.ok_or_else(|| Status::invalid_argument("missing shared settings"))?; - let parse_u256 = |b: Vec| -> Result { - if b.is_empty() { - Err(Status::invalid_argument("U256 bytes must not be empty")) - } else { - Ok(U256::from_be_slice(&b)) +pub async fn start( + mut conn: UserAgentConnection, + mut bi: GrpcBi, +) { + let pubkey = match auth::start(&mut conn, &mut bi).await { + Ok(pubkey) => pubkey, + Err(e) => { + warn!(error = ?e, "Authentication failed"); + return; } }; - let parse_ts = |ts: prost_types::Timestamp| -> Result, Status> { - Utc.timestamp_opt(ts.seconds, ts.nanos as u32) - .single() - .ok_or_else(|| Status::invalid_argument("invalid timestamp")) - }; - Ok(SharedGrantSettings { - wallet_id: s.wallet_id, - client_id, - chain: s.chain_id, - valid_from: s.valid_from.map(parse_ts).transpose()?, - valid_until: s.valid_until.map(parse_ts).transpose()?, - max_gas_fee_per_gas: s.max_gas_fee_per_gas.map(parse_u256).transpose()?, - max_priority_fee_per_gas: s.max_priority_fee_per_gas.map(parse_u256).transpose()?, - rate_limit: s.rate_limit.map(|rl| TransactionRateLimit { - count: rl.count, - window: chrono::Duration::seconds(rl.window_secs), - }), - }) -} - -fn parse_specific_grant(proto: Option) -> Result { - use proto::evm::specific_grant::Grant as ProtoGrant; - let g = proto - .and_then(|sg| sg.grant) - .ok_or_else(|| Status::invalid_argument("missing specific grant"))?; - match g { - ProtoGrant::EtherTransfer(s) => { - let limit = parse_volume_rate_limit( - s.limit - .ok_or_else(|| Status::invalid_argument("missing ether transfer limit"))?, - )?; - let target = s - .targets - .into_iter() - .map(|b| { - if b.len() == 20 { - Ok(Address::from_slice(&b)) - } else { - Err(Status::invalid_argument( - "ether transfer target must be 20 bytes", - )) - } - }) - .collect::, _>>()?; - Ok(SpecificGrant::EtherTransfer(ether_transfer::Settings { - target, - limit, - })) - } - ProtoGrant::TokenTransfer(s) => { - if s.token_contract.len() != 20 { - return Err(Status::invalid_argument("token_contract must be 20 bytes")); - } - let target = s - .target - .map(|b| { - if b.len() == 20 { - Ok(Address::from_slice(&b)) - } else { - Err(Status::invalid_argument( - "token transfer target must be 20 bytes", - )) - } - }) - .transpose()?; - let volume_limits = s - .volume_limits - .into_iter() - .map(parse_volume_rate_limit) - .collect::, _>>()?; - Ok(SpecificGrant::TokenTransfer(token_transfers::Settings { - token_contract: Address::from_slice(&s.token_contract), - target, - volume_limits, - })) - } - } -} - -fn parse_auth_pubkey(key_type: i32, pubkey: Vec) -> Result { - match ProtoKeyType::try_from(key_type).unwrap_or(ProtoKeyType::Unspecified) { - ProtoKeyType::Unspecified | ProtoKeyType::Ed25519 => { - let bytes: [u8; 32] = pubkey - .as_slice() - .try_into() - .map_err(|_| Status::invalid_argument("invalid Ed25519 public key length"))?; - let key = ed25519_dalek::VerifyingKey::from_bytes(&bytes) - .map_err(|_| Status::invalid_argument("invalid Ed25519 public key encoding"))?; - Ok(AuthPublicKey::Ed25519(key)) - } - ProtoKeyType::EcdsaSecp256k1 => { - let key = k256::ecdsa::VerifyingKey::from_sec1_bytes(&pubkey) - .map_err(|_| Status::invalid_argument("invalid secp256k1 public key encoding"))?; - Ok(AuthPublicKey::EcdsaSecp256k1(key)) - } - ProtoKeyType::Rsa => { - use rsa::pkcs8::DecodePublicKey as _; - - let key = rsa::RsaPublicKey::from_public_key_der(&pubkey) - .map_err(|_| Status::invalid_argument("invalid RSA public key encoding"))?; - Ok(AuthPublicKey::Rsa(key)) - } - } -} - -fn auth_error_status(value: &user_agent::auth::Error) -> Status { - use user_agent::auth::Error; - - match value { - Error::UnexpectedMessagePayload | Error::InvalidClientPubkeyLength => { - Status::invalid_argument(value.to_string()) - } - Error::InvalidAuthPubkeyEncoding => { - Status::invalid_argument("Failed to convert pubkey to VerifyingKey") - } - Error::PublicKeyNotRegistered | Error::InvalidChallengeSolution => { - Status::unauthenticated(value.to_string()) - } - Error::InvalidBootstrapToken => Status::invalid_argument("Invalid bootstrap token"), - Error::Transport => Status::internal("Transport error"), - Error::BootstrapperActorUnreachable => { - Status::internal("Bootstrap token consumption failed") - } - Error::DatabasePoolUnavailable => Status::internal("Database pool error"), - Error::DatabaseOperationFailed => Status::internal("Database error"), - } + + let (oob_sender, oob_receiver) = mpsc::channel(16); + let oob_adapter = OutOfBandAdapter(oob_sender); + + let actor = UserAgentSession::spawn(UserAgentSession::new(conn, Box::new(oob_adapter))); + let actor_for_cleanup = actor.clone(); + + // when connection closes + let _ = defer(move || { + actor_for_cleanup.kill(); + }); + + info!(?pubkey, "User authenticated successfully"); + dispatch_loop(bi, actor, oob_receiver).await; } diff --git a/server/crates/arbiter-server/src/grpc/user_agent/auth.rs b/server/crates/arbiter-server/src/grpc/user_agent/auth.rs new file mode 100644 index 0000000..0e648a8 --- /dev/null +++ b/server/crates/arbiter-server/src/grpc/user_agent/auth.rs @@ -0,0 +1,151 @@ +use arbiter_proto::{ + proto::{ + self, + evm::{ + EtherTransferSettings as ProtoEtherTransferSettings, EvmError as ProtoEvmError, + EvmGrantCreateRequest, EvmGrantCreateResponse, EvmGrantDeleteRequest, + EvmGrantDeleteResponse, EvmGrantList, EvmGrantListResponse, GrantEntry, + SharedSettings as ProtoSharedSettings, SpecificGrant as ProtoSpecificGrant, + TokenTransferSettings as ProtoTokenTransferSettings, + VolumeRateLimit as ProtoVolumeRateLimit, WalletCreateResponse, WalletEntry, WalletList, + WalletListResponse, evm_grant_create_response::Result as EvmGrantCreateResult, + evm_grant_delete_response::Result as EvmGrantDeleteResult, + evm_grant_list_response::Result as EvmGrantListResult, + specific_grant::Grant as ProtoSpecificGrantType, + wallet_create_response::Result as WalletCreateResult, + wallet_list_response::Result as WalletListResult, + }, + user_agent::{ + AuthChallenge as ProtoAuthChallenge, AuthChallengeRequest as ProtoAuthChallengeRequest, + AuthChallengeSolution as ProtoAuthChallengeSolution, AuthResult as ProtoAuthResult, + BootstrapEncryptedKey as ProtoBootstrapEncryptedKey, + BootstrapResult as ProtoBootstrapResult, ClientConnectionCancel, + ClientConnectionRequest, ClientConnectionResponse, KeyType as ProtoKeyType, + UnsealEncryptedKey as ProtoUnsealEncryptedKey, UnsealResult as ProtoUnsealResult, + UnsealStart, UnsealStartResponse, UserAgentRequest, UserAgentResponse, + VaultState as ProtoVaultState, user_agent_request::Payload as UserAgentRequestPayload, + user_agent_response::Payload as UserAgentResponsePayload, + }, + }, + transport::{Bi, Error as TransportError, Receiver, Sender, grpc::GrpcBi}, +}; +use async_trait::async_trait; +use tonic::{Status, Streaming}; +use tracing::{info, warn}; + +use crate::{ + actors::user_agent::{ + self, AuthPublicKey, OutOfBand as DomainResponse, UserAgentConnection, auth, + }, + db::models::KeyType, + evm::policies::{ + Grant, SharedGrantSettings, SpecificGrant, TransactionRateLimit, VolumeRateLimit, + ether_transfer, token_transfers, + }, +}; +use alloy::primitives::{Address, U256}; +use chrono::{DateTime, TimeZone, Utc}; + +pub struct AuthTransportAdapter<'a>(&'a mut GrpcBi); + +#[async_trait] +impl Sender> for AuthTransportAdapter<'_> { + async fn send( + &mut self, + item: Result, + ) -> Result<(), TransportError> { + use auth::{Error, Outbound}; + let response = match item { + Ok(Outbound::AuthChallenge { nonce }) => Ok(UserAgentResponsePayload::AuthChallenge( + ProtoAuthChallenge { nonce }, + )), + Ok(Outbound::AuthSuccess) => Ok(UserAgentResponsePayload::AuthResult( + ProtoAuthResult::Success.into(), + )), + + Err(Error::UnregisteredPublicKey) => Ok(UserAgentResponsePayload::AuthResult( + ProtoAuthResult::InvalidKey.into(), + )), + Err(Error::InvalidChallengeSolution) => Ok(UserAgentResponsePayload::AuthResult( + ProtoAuthResult::InvalidSignature.into(), + )), + Err(Error::InvalidBootstrapToken) => Ok(UserAgentResponsePayload::BootstrapResult( + ProtoAuthResult::TokenInvalid.into(), + )), + Err(Error::Internal { details }) => Err(Status::internal(details)), + Err(Error::Transport) => Err(Status::unavailable("transport error")), + }; + self.0 + .send(response.map(|r| UserAgentResponse { payload: Some(r) })) + .await + } +} + +#[async_trait] +impl Receiver for AuthTransportAdapter<'_> { + async fn recv(&mut self) -> Option { + let Ok(UserAgentRequest { + payload: Some(payload), + }) = self.0.recv().await? + else { + warn!( + event = "received request with empty payload", + "grpc.useragent.auth_adapter" + ); + return None; + }; + + match payload { + UserAgentRequestPayload::AuthChallengeRequest(ProtoAuthChallengeRequest { + pubkey, + bootstrap_token, + key_type, + }) => { + let Ok(key_type) = ProtoKeyType::try_from(key_type) else { + warn!( + event = "received request with invalid key type", + "grpc.useragent.auth_adapter" + ); + return None; + }; + let key_type = match key_type { + ProtoKeyType::Ed25519 => KeyType::Ed25519, + ProtoKeyType::EcdsaSecp256k1 => KeyType::EcdsaSecp256k1, + ProtoKeyType::Rsa => KeyType::Rsa, + ProtoKeyType::Unspecified => { + warn!( + event = "received request with unspecified key type", + "grpc.useragent.auth_adapter" + ); + return None; + } + }; + let Ok(pubkey) = AuthPublicKey::try_from((key_type, pubkey)) else { + warn!( + event = "received request with invalid public key", + "grpc.useragent.auth_adapter" + ); + return None; + }; + + Some(auth::Inbound::AuthChallengeRequest { + pubkey, + bootstrap_token, + }) + } + UserAgentRequestPayload::AuthChallengeSolution(ProtoAuthChallengeSolution { + signature, + }) => Some(auth::Inbound::AuthChallengeSolution { signature }), + _ => None, // Ignore other request types for this adapter + } + } +} +impl Bi> for AuthTransportAdapter<'_> {} + +pub async fn start( + conn: &mut UserAgentConnection, + bi: &mut GrpcBi, +) -> Result { + let mut transport = AuthTransportAdapter(bi); + auth::authenticate(conn, transport).await +} diff --git a/server/crates/arbiter-server/src/lib.rs b/server/crates/arbiter-server/src/lib.rs index 410e499..0b255e5 100644 --- a/server/crates/arbiter-server/src/lib.rs +++ b/server/crates/arbiter-server/src/lib.rs @@ -13,6 +13,7 @@ pub mod db; pub mod evm; pub mod grpc; pub mod safe_cell; +pub mod utils; const DEFAULT_CHANNEL_SIZE: usize = 1000; diff --git a/server/crates/arbiter-server/src/utils.rs b/server/crates/arbiter-server/src/utils.rs new file mode 100644 index 0000000..d072aa7 --- /dev/null +++ b/server/crates/arbiter-server/src/utils.rs @@ -0,0 +1,16 @@ +struct DeferClosure { + f: Option, +} + +impl Drop for DeferClosure { + fn drop(&mut self) { + if let Some(f) = self.f.take() { + f(); + } + } +} + +// Run some code when a scope is exited, similar to Go's defer statement +pub fn defer(f: F) -> impl Drop + Sized { + DeferClosure { f: Some(f) } +} diff --git a/server/crates/arbiter-server/tests/user_agent/auth.rs b/server/crates/arbiter-server/tests/user_agent/auth.rs index 4f23e9d..1a7bbad 100644 --- a/server/crates/arbiter-server/tests/user_agent/auth.rs +++ b/server/crates/arbiter-server/tests/user_agent/auth.rs @@ -3,7 +3,7 @@ use arbiter_server::{ actors::{ GlobalActors, bootstrap::GetToken, - user_agent::{AuthPublicKey, Request, Response, UserAgentConnection, connect_user_agent}, + user_agent::{AuthPublicKey, Request, OutOfBand, UserAgentConnection, connect_user_agent}, }, db::{self, schema}, }; @@ -118,7 +118,7 @@ pub async fn test_challenge_auth() { .expect("should receive challenge"); let challenge = match response { Ok(resp) => match resp { - Response::AuthChallenge { nonce } => nonce, + OutOfBand::AuthChallenge { nonce } => nonce, other => panic!("Expected AuthChallenge, got {other:?}"), }, Err(err) => panic!("Expected Ok response, got Err({err:?})"), diff --git a/server/crates/arbiter-server/tests/user_agent/unseal.rs b/server/crates/arbiter-server/tests/user_agent/unseal.rs index ec5de37..4b6d7a3 100644 --- a/server/crates/arbiter-server/tests/user_agent/unseal.rs +++ b/server/crates/arbiter-server/tests/user_agent/unseal.rs @@ -2,7 +2,7 @@ use arbiter_server::{ actors::{ GlobalActors, keyholder::{Bootstrap, Seal}, - user_agent::{Request, Response, UnsealError, session::UserAgentSession}, + user_agent::{Request, OutOfBand, UnsealError, session::UserAgentSession}, }, db, safe_cell::{SafeCell, SafeCellHandle as _}, @@ -40,7 +40,7 @@ async fn client_dh_encrypt(user_agent: &mut UserAgentSession, key_to_send: &[u8] .unwrap(); let server_pubkey = match response { - Response::UnsealStartResponse { server_pubkey } => server_pubkey, + OutOfBand::UnsealStartResponse { server_pubkey } => server_pubkey, other => panic!("Expected UnsealStartResponse, got {other:?}"), }; @@ -73,7 +73,7 @@ pub async fn test_unseal_success() { .await .unwrap(); - assert!(matches!(response, Response::UnsealResult(Ok(())))); + assert!(matches!(response, OutOfBand::UnsealResult(Ok(())))); } #[tokio::test] @@ -90,7 +90,7 @@ pub async fn test_unseal_wrong_seal_key() { assert!(matches!( response, - Response::UnsealResult(Err(UnsealError::InvalidKey)) + OutOfBand::UnsealResult(Err(UnsealError::InvalidKey)) )); } @@ -120,7 +120,7 @@ pub async fn test_unseal_corrupted_ciphertext() { assert!(matches!( response, - Response::UnsealResult(Err(UnsealError::InvalidKey)) + OutOfBand::UnsealResult(Err(UnsealError::InvalidKey)) )); } @@ -140,7 +140,7 @@ pub async fn test_unseal_retry_after_invalid_key() { assert!(matches!( response, - Response::UnsealResult(Err(UnsealError::InvalidKey)) + OutOfBand::UnsealResult(Err(UnsealError::InvalidKey)) )); } @@ -152,6 +152,6 @@ pub async fn test_unseal_retry_after_invalid_key() { .await .unwrap(); - assert!(matches!(response, Response::UnsealResult(Ok(())))); + assert!(matches!(response, OutOfBand::UnsealResult(Ok(())))); } }