refactor(server::useragent): migrated to new connection design

This commit is contained in:
hdbg
2026-03-17 18:39:12 +01:00
committed by Stas
parent c439c9645d
commit d61dab3285
20 changed files with 1151 additions and 958 deletions

View File

@@ -2,8 +2,8 @@ syntax = "proto3";
package arbiter.user_agent; package arbiter.user_agent;
import "google/protobuf/empty.proto";
import "evm.proto"; import "evm.proto";
import "google/protobuf/empty.proto";
enum KeyType { enum KeyType {
KEY_TYPE_UNSPECIFIED = 0; KEY_TYPE_UNSPECIFIED = 0;
@@ -19,15 +19,23 @@ message AuthChallengeRequest {
} }
message AuthChallenge { message AuthChallenge {
bytes pubkey = 1;
int32 nonce = 2; int32 nonce = 2;
reserved 1;
} }
message AuthChallengeSolution { message AuthChallengeSolution {
bytes signature = 1; bytes signature = 1;
} }
message AuthOk {} enum AuthResult {
AUTH_RESULT_UNSPECIFIED = 0;
AUTH_RESULT_SUCCESS = 1;
AUTH_RESULT_INVALID_KEY = 2;
AUTH_RESULT_INVALID_SIGNATURE = 3;
AUTH_RESULT_BOOTSTRAP_REQUIRED = 4;
AUTH_RESULT_TOKEN_INVALID = 5;
AUTH_RESULT_INTERNAL = 6;
}
message UnsealStart { message UnsealStart {
bytes client_pubkey = 1; bytes client_pubkey = 1;
@@ -99,7 +107,7 @@ message UserAgentRequest {
message UserAgentResponse { message UserAgentResponse {
oneof payload { oneof payload {
AuthChallenge auth_challenge = 1; AuthChallenge auth_challenge = 1;
AuthOk auth_ok = 2; AuthResult auth_result = 2;
UnsealStartResponse unseal_start_response = 3; UnsealStartResponse unseal_start_response = 3;
UnsealResult unseal_result = 4; UnsealResult unseal_result = 4;
VaultState vault_state = 5; VaultState vault_state = 5;

1
server/Cargo.lock generated
View File

@@ -697,6 +697,7 @@ dependencies = [
"rustls-pki-types", "rustls-pki-types",
"thiserror", "thiserror",
"tokio", "tokio",
"tokio-stream",
"tonic", "tonic",
"tonic-prost", "tonic-prost",
"tonic-prost-build", "tonic-prost-build",

View File

@@ -21,6 +21,7 @@ base64 = "0.22.1"
prost-types.workspace = true prost-types.workspace = true
tracing.workspace = true tracing.workspace = true
async-trait.workspace = true async-trait.workspace = true
tokio-stream.workspace = true
[build-dependencies] [build-dependencies]
tonic-prost-build = "0.14.3" tonic-prost-build = "0.14.3"

View File

@@ -63,16 +63,29 @@ where
extractor(msg).ok_or(Error::UnexpectedMessage) extractor(msg).ok_or(Error::UnexpectedMessage)
} }
#[async_trait]
pub trait Sender<Outbound>: Send + Sync {
async fn send(&mut self, item: Outbound) -> Result<(), Error>;
}
#[async_trait]
pub trait Receiver<Inbound>: Send + Sync {
async fn recv(&mut self) -> Option<Inbound>;
}
/// Minimal bidirectional transport abstraction used by protocol code. /// Minimal bidirectional transport abstraction used by protocol code.
/// ///
/// `Bi<Inbound, Outbound>` models a duplex channel with: /// `Bi<Inbound, Outbound>` models a duplex channel with:
/// - inbound items of type `Inbound` read via [`Bi::recv`] /// - inbound items of type `Inbound` read via [`Bi::recv`]
/// - outbound items of type `Outbound` written via [`Bi::send`] /// - outbound items of type `Outbound` written via [`Bi::send`]
#[async_trait] pub trait Bi<Inbound, Outbound>: Sender<Outbound> + Receiver<Inbound> + Send + Sync {}
pub trait Bi<Inbound, Outbound>: Send + Sync + 'static {
async fn send(&mut self, item: Outbound) -> Result<(), Error>;
async fn recv(&mut self) -> Option<Inbound>; pub trait SplittableBi<Inbound, Outbound>: Bi<Inbound, Outbound> {
type Sender: Sender<Outbound>;
type Receiver: Receiver<Inbound>;
fn split(self) -> (Self::Sender, Self::Receiver);
fn from_parts(sender: Self::Sender, receiver: Self::Receiver) -> Self;
} }
/// No-op [`Bi`] transport for tests and manual actor usage. /// No-op [`Bi`] transport for tests and manual actor usage.
@@ -83,22 +96,16 @@ pub struct DummyTransport<Inbound, Outbound> {
_marker: PhantomData<(Inbound, Outbound)>, _marker: PhantomData<(Inbound, Outbound)>,
} }
impl<Inbound, Outbound> DummyTransport<Inbound, Outbound> { impl<Inbound, Outbound> Default for DummyTransport<Inbound, Outbound> {
pub fn new() -> Self { fn default() -> Self {
Self { Self {
_marker: PhantomData, _marker: PhantomData,
} }
} }
} }
impl<Inbound, Outbound> Default for DummyTransport<Inbound, Outbound> {
fn default() -> Self {
Self::new()
}
}
#[async_trait] #[async_trait]
impl<Inbound, Outbound> Bi<Inbound, Outbound> for DummyTransport<Inbound, Outbound> impl<Inbound, Outbound> Sender<Outbound> for DummyTransport<Inbound, Outbound>
where where
Inbound: Send + Sync + 'static, Inbound: Send + Sync + 'static,
Outbound: Send + Sync + 'static, Outbound: Send + Sync + 'static,
@@ -106,9 +113,25 @@ where
async fn send(&mut self, _item: Outbound) -> Result<(), Error> { async fn send(&mut self, _item: Outbound) -> Result<(), Error> {
Ok(()) Ok(())
} }
}
#[async_trait]
impl<Inbound, Outbound> Receiver<Inbound> for DummyTransport<Inbound, Outbound>
where
Inbound: Send + Sync + 'static,
Outbound: Send + Sync + 'static,
{
async fn recv(&mut self) -> Option<Inbound> { async fn recv(&mut self) -> Option<Inbound> {
std::future::pending::<()>().await; std::future::pending::<()>().await;
None None
} }
} }
impl<Inbound, Outbound> Bi<Inbound, Outbound> for DummyTransport<Inbound, Outbound>
where
Inbound: Send + Sync + 'static,
Outbound: Send + Sync + 'static,
{
}
pub mod grpc;

View File

@@ -0,0 +1,106 @@
use async_trait::async_trait;
use futures::StreamExt;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use super::{Bi, Receiver, Sender};
pub struct GrpcSender<Outbound> {
tx: mpsc::Sender<Result<Outbound, tonic::Status>>,
}
#[async_trait]
impl<Outbound> Sender<Result<Outbound, tonic::Status>> for GrpcSender<Outbound>
where
Outbound: Send + Sync + 'static,
{
async fn send(&mut self, item: Result<Outbound, tonic::Status>) -> Result<(), super::Error> {
self.tx
.send(item)
.await
.map_err(|_| super::Error::ChannelClosed)
}
}
pub struct GrpcReceiver<Inbound> {
rx: tonic::Streaming<Inbound>,
}
#[async_trait]
impl<Inbound> Receiver<Result<Inbound, tonic::Status>> for GrpcReceiver<Inbound>
where
Inbound: Send + Sync + 'static,
{
async fn recv(&mut self) -> Option<Result<Inbound, tonic::Status>> {
self.rx.next().await
}
}
pub struct GrpcBi<Inbound, Outbound> {
sender: GrpcSender<Outbound>,
receiver: GrpcReceiver<Inbound>,
}
impl<Inbound, Outbound> GrpcBi<Inbound, Outbound>
where
Inbound: Send + Sync + 'static,
Outbound: Send + Sync + 'static,
{
pub fn from_bi_stream(
receiver: tonic::Streaming<Inbound>,
) -> (Self, ReceiverStream<Result<Outbound, tonic::Status>>) {
let (tx, rx) = mpsc::channel(10);
let sender = GrpcSender { tx };
let receiver = GrpcReceiver { rx: receiver };
let bi = GrpcBi { sender, receiver };
(bi, ReceiverStream::new(rx))
}
}
#[async_trait]
impl<Inbound, Outbound> Sender<Result<Outbound, tonic::Status>> for GrpcBi<Inbound, Outbound>
where
Inbound: Send + Sync + 'static,
Outbound: Send + Sync + 'static,
{
async fn send(&mut self, item: Result<Outbound, tonic::Status>) -> Result<(), super::Error> {
self.sender.send(item).await
}
}
#[async_trait]
impl<Inbound, Outbound> Receiver<Result<Inbound, tonic::Status>> for GrpcBi<Inbound, Outbound>
where
Inbound: Send + Sync + 'static,
Outbound: Send + Sync + 'static,
{
async fn recv(&mut self) -> Option<Result<Inbound, tonic::Status>> {
self.receiver.recv().await
}
}
impl<Inbound, Outbound> Bi<Result<Inbound, tonic::Status>, Result<Outbound, tonic::Status>>
for GrpcBi<Inbound, Outbound>
where
Inbound: Send + Sync + 'static,
Outbound: Send + Sync + 'static,
{
}
impl<Inbound, Outbound>
super::SplittableBi<Result<Inbound, tonic::Status>, Result<Outbound, tonic::Status>>
for GrpcBi<Inbound, Outbound>
where
Inbound: Send + Sync + 'static,
Outbound: Send + Sync + 'static,
{
type Sender = GrpcSender<Outbound>;
type Receiver = GrpcReceiver<Inbound>;
fn split(self) -> (Self::Sender, Self::Receiver) {
(self.sender, self.receiver)
}
fn from_parts(sender: Self::Sender, receiver: Self::Receiver) -> Self {
GrpcBi { sender, receiver }
}
}

View File

@@ -22,7 +22,7 @@ use encryption::v1::{self, KeyCell, Nonce};
pub mod encryption; pub mod encryption;
#[derive(Default, EnumDiscriminants)] #[derive(Default, EnumDiscriminants)]
#[strum_discriminants(derive(Reply), vis(pub))] #[strum_discriminants(derive(Reply), vis(pub), name(KeyHolderState))]
enum State { enum State {
#[default] #[default]
Unbootstrapped, Unbootstrapped,
@@ -325,7 +325,7 @@ impl KeyHolder {
} }
#[message] #[message]
pub fn get_state(&self) -> StateDiscriminants { pub fn get_state(&self) -> KeyHolderState {
self.state.discriminant() self.state.discriminant()
} }

View File

@@ -1,74 +1,82 @@
use arbiter_proto::transport::Bi;
use tracing::error; use tracing::error;
use crate::actors::user_agent::{ use crate::actors::user_agent::{
Request, UserAgentConnection, AuthPublicKey, UserAgentConnection,
auth::state::{AuthContext, AuthStateMachine}, auth::state::{AuthContext, AuthStateMachine},
AuthPublicKey,
session::UserAgentSession,
}; };
#[derive(thiserror::Error, Debug, PartialEq)]
pub enum Error {
#[error("Unexpected message payload")]
UnexpectedMessagePayload,
#[error("Invalid client public key length")]
InvalidClientPubkeyLength,
#[error("Invalid client public key encoding")]
InvalidAuthPubkeyEncoding,
#[error("Database pool unavailable")]
DatabasePoolUnavailable,
#[error("Database operation failed")]
DatabaseOperationFailed,
#[error("Public key not registered")]
PublicKeyNotRegistered,
#[error("Transport error")]
Transport,
#[error("Invalid bootstrap token")]
InvalidBootstrapToken,
#[error("Bootstrapper actor unreachable")]
BootstrapperActorUnreachable,
#[error("Invalid challenge solution")]
InvalidChallengeSolution,
}
mod state; mod state;
use state::*; use state::*;
fn parse_auth_event(payload: Request) -> Result<AuthEvents, Error> { #[derive(Debug, Clone)]
match payload { pub enum Inbound {
Request::AuthChallengeRequest { AuthChallengeRequest {
pubkey, pubkey: AuthPublicKey,
bootstrap_token: None, bootstrap_token: Option<String>,
} => Ok(AuthEvents::AuthRequest(ChallengeRequest { pubkey })), },
Request::AuthChallengeRequest { AuthChallengeSolution {
pubkey, signature: Vec<u8>,
bootstrap_token: Some(token), },
} => Ok(AuthEvents::BootstrapAuthRequest(BootstrapAuthRequest { }
pubkey,
token, #[derive(Debug)]
})), pub enum Error {
Request::AuthChallengeSolution { signature } => { UnregisteredPublicKey,
Ok(AuthEvents::ReceivedSolution(ChallengeSolution { InvalidChallengeSolution,
solution: signature, InvalidBootstrapToken,
})) Internal { details: String },
Transport,
}
impl Error {
fn internal(details: impl Into<String>) -> Self {
Self::Internal {
details: details.into(),
} }
_ => Err(Error::UnexpectedMessagePayload),
} }
} }
pub async fn authenticate(props: &mut UserAgentConnection) -> Result<AuthPublicKey, Error> { #[derive(Debug, Clone)]
let mut state = AuthStateMachine::new(AuthContext::new(props)); pub enum Outbound {
AuthChallenge { nonce: i32 },
AuthSuccess,
}
fn parse_auth_event(payload: Inbound) -> AuthEvents {
match payload {
Inbound::AuthChallengeRequest {
pubkey,
bootstrap_token: None,
} => AuthEvents::AuthRequest(ChallengeRequest { pubkey }),
Inbound::AuthChallengeRequest {
pubkey,
bootstrap_token: Some(token),
} => AuthEvents::BootstrapAuthRequest(BootstrapAuthRequest { pubkey, token }),
Inbound::AuthChallengeSolution { signature } => {
AuthEvents::ReceivedSolution(ChallengeSolution {
solution: signature,
})
}
}
}
pub async fn authenticate<T>(
props: &mut UserAgentConnection,
transport: T,
) -> Result<AuthPublicKey, Error>
where
T: Bi<Inbound, Result<Outbound, Error>> + Send,
{
let mut state = AuthStateMachine::new(AuthContext::new(props, transport));
loop { loop {
// `state` holds a mutable reference to `props` so we can't access it directly here // `state` holds a mutable reference to `props` so we can't access it directly here
let transport = state.context_mut().conn.transport.as_mut(); let Some(payload) = state.context_mut().transport.recv().await else {
let Some(payload) = transport.recv().await else {
return Err(Error::Transport); return Err(Error::Transport);
}; };
let event = parse_auth_event(payload)?; match state.process_event(parse_auth_event(payload)).await {
match state.process_event(event).await {
Ok(AuthStates::AuthOk(key)) => return Ok(key.clone()), Ok(AuthStates::AuthOk(key)) => return Ok(key.clone()),
Err(AuthError::ActionFailed(err)) => { Err(AuthError::ActionFailed(err)) => {
error!(?err, "State machine action failed"); error!(?err, "State machine action failed");
@@ -91,11 +99,3 @@ pub async fn authenticate(props: &mut UserAgentConnection) -> Result<AuthPublicK
} }
} }
} }
pub async fn authenticate_and_create(
mut props: UserAgentConnection,
) -> Result<UserAgentSession, Error> {
let _key = authenticate(&mut props).await?;
let session = UserAgentSession::new(props);
Ok(session)
}

View File

@@ -1,3 +1,5 @@
use alloy::transports::Transport;
use arbiter_proto::transport::Bi;
use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, update}; use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, update};
use diesel_async::RunQueryDsl; use diesel_async::RunQueryDsl;
use tracing::error; use tracing::error;
@@ -6,7 +8,7 @@ use super::Error;
use crate::{ use crate::{
actors::{ actors::{
bootstrap::ConsumeToken, bootstrap::ConsumeToken,
user_agent::{AuthPublicKey, Response, UserAgentConnection}, user_agent::{AuthPublicKey, OutOfBand, UserAgentConnection, auth::Outbound},
}, },
db::schema, db::schema,
}; };
@@ -42,7 +44,7 @@ smlang::statemachine!(
async fn create_nonce(db: &crate::db::DatabasePool, pubkey_bytes: &[u8]) -> Result<i32, Error> { async fn create_nonce(db: &crate::db::DatabasePool, pubkey_bytes: &[u8]) -> Result<i32, Error> {
let mut db_conn = db.get().await.map_err(|e| { let mut db_conn = db.get().await.map_err(|e| {
error!(error = ?e, "Database pool error"); error!(error = ?e, "Database pool error");
Error::DatabasePoolUnavailable Error::internal("Database unavailable")
})?; })?;
db_conn db_conn
.exclusive_transaction(|conn| { .exclusive_transaction(|conn| {
@@ -66,11 +68,11 @@ async fn create_nonce(db: &crate::db::DatabasePool, pubkey_bytes: &[u8]) -> Resu
.optional() .optional()
.map_err(|e| { .map_err(|e| {
error!(error = ?e, "Database error"); error!(error = ?e, "Database error");
Error::DatabaseOperationFailed Error::internal("Database operation failed")
})? })?
.ok_or_else(|| { .ok_or_else(|| {
error!(?pubkey_bytes, "Public key not found in database"); error!(?pubkey_bytes, "Public key not found in database");
Error::PublicKeyNotRegistered Error::UnregisteredPublicKey
}) })
} }
@@ -79,7 +81,7 @@ async fn register_key(db: &crate::db::DatabasePool, pubkey: &AuthPublicKey) -> R
let key_type = pubkey.key_type(); let key_type = pubkey.key_type();
let mut conn = db.get().await.map_err(|e| { let mut conn = db.get().await.map_err(|e| {
error!(error = ?e, "Database pool error"); error!(error = ?e, "Database pool error");
Error::DatabasePoolUnavailable Error::internal("Database unavailable")
})?; })?;
diesel::insert_into(schema::useragent_client::table) diesel::insert_into(schema::useragent_client::table)
@@ -92,23 +94,27 @@ async fn register_key(db: &crate::db::DatabasePool, pubkey: &AuthPublicKey) -> R
.await .await
.map_err(|e| { .map_err(|e| {
error!(error = ?e, "Database error"); error!(error = ?e, "Database error");
Error::DatabaseOperationFailed Error::internal("Database operation failed")
})?; })?;
Ok(()) Ok(())
} }
pub struct AuthContext<'a> { pub struct AuthContext<'a, T> {
pub(super) conn: &'a mut UserAgentConnection, pub(super) conn: &'a mut UserAgentConnection,
pub(super) transport: T,
} }
impl<'a> AuthContext<'a> { impl<'a, T> AuthContext<'a, T> {
pub fn new(conn: &'a mut UserAgentConnection) -> Self { pub fn new(conn: &'a mut UserAgentConnection, transport: T) -> Self {
Self { conn } Self { conn, transport }
} }
} }
impl AuthStateMachineContext for AuthContext<'_> { impl<T> AuthStateMachineContext for AuthContext<'_, T>
where
T: Bi<super::Inbound, Result<super::Outbound, Error>> + Send,
{
type Error = Error; type Error = Error;
async fn prepare_challenge( async fn prepare_challenge(
@@ -118,9 +124,9 @@ impl AuthStateMachineContext for AuthContext<'_> {
let stored_bytes = pubkey.to_stored_bytes(); let stored_bytes = pubkey.to_stored_bytes();
let nonce = create_nonce(&self.conn.db, &stored_bytes).await?; let nonce = create_nonce(&self.conn.db, &stored_bytes).await?;
self.conn self
.transport .transport
.send(Ok(Response::AuthChallenge { nonce })) .send(Ok(Outbound::AuthChallenge { nonce }))
.await .await
.map_err(|e| { .map_err(|e| {
error!(?e, "Failed to send auth challenge"); error!(?e, "Failed to send auth challenge");
@@ -149,7 +155,7 @@ impl AuthStateMachineContext for AuthContext<'_> {
.await .await
.map_err(|e| { .map_err(|e| {
error!(?e, "Failed to consume bootstrap token"); error!(?e, "Failed to consume bootstrap token");
Error::BootstrapperActorUnreachable Error::internal("Failed to consume bootstrap token")
})?; })?;
if !token_ok { if !token_ok {
@@ -159,11 +165,11 @@ impl AuthStateMachineContext for AuthContext<'_> {
register_key(&self.conn.db, &pubkey).await?; register_key(&self.conn.db, &pubkey).await?;
self.conn self
.transport .transport
.send(Ok(Response::AuthOk)) .send(Ok(Outbound::AuthSuccess))
.await .await
.map_err(|_| Error::Transport)?; .map_err(|_| Error::Transport)?;
Ok(pubkey) Ok(pubkey)
} }
@@ -172,7 +178,10 @@ impl AuthStateMachineContext for AuthContext<'_> {
#[allow(clippy::unused_unit)] #[allow(clippy::unused_unit)]
async fn verify_solution( async fn verify_solution(
&mut self, &mut self,
ChallengeContext { challenge_nonce, key }: &ChallengeContext, ChallengeContext {
challenge_nonce,
key,
}: &ChallengeContext,
ChallengeSolution { solution }: ChallengeSolution, ChallengeSolution { solution }: ChallengeSolution,
) -> Result<AuthPublicKey, Self::Error> { ) -> Result<AuthPublicKey, Self::Error> {
let formatted = arbiter_proto::format_challenge(*challenge_nonce, &key.to_stored_bytes()); let formatted = arbiter_proto::format_challenge(*challenge_nonce, &key.to_stored_bytes());
@@ -205,9 +214,9 @@ impl AuthStateMachineContext for AuthContext<'_> {
}; };
if valid { if valid {
self.conn self
.transport .transport
.send(Ok(Response::AuthOk)) .send(Ok(Outbound::AuthSuccess))
.await .await
.map_err(|_| Error::Transport)?; .map_err(|_| Error::Transport)?;
} }

View File

@@ -1,33 +1,15 @@
use alloy::primitives::Address; use alloy::primitives::Address;
use arbiter_proto::transport::Bi; use arbiter_proto::transport::{Bi, Sender};
use kameo::actor::Spawn as _; use kameo::actor::Spawn as _;
use tracing::{error, info}; use tracing::{error, info};
use crate::{ use crate::{
actors::{GlobalActors, evm, user_agent::session::UserAgentSession}, actors::{GlobalActors, evm},
db::{self, models::KeyType}, db::{self, models::KeyType},
evm::policies::SharedGrantSettings, evm::policies::SharedGrantSettings,
evm::policies::{Grant, SpecificGrant}, evm::policies::{Grant, SpecificGrant},
}; };
#[derive(Debug, thiserror::Error, PartialEq)]
pub enum TransportResponseError {
#[error("Unexpected request payload")]
UnexpectedRequestPayload,
#[error("Invalid state for unseal encrypted key")]
InvalidStateForUnsealEncryptedKey,
#[error("client_pubkey must be 32 bytes")]
InvalidClientPubkeyLength,
#[error("State machine error")]
StateTransitionFailed,
#[error("Vault is not available")]
KeyHolderActorUnreachable,
#[error(transparent)]
Auth(#[from] auth::Error),
#[error("Failed registering connection")]
ConnectionRegistrationFailed,
}
/// Abstraction over Ed25519 / ECDSA-secp256k1 / RSA public keys used during the auth handshake. /// Abstraction over Ed25519 / ECDSA-secp256k1 / RSA public keys used during the auth handshake.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum AuthPublicKey { pub enum AuthPublicKey {
@@ -65,119 +47,55 @@ impl AuthPublicKey {
} }
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq)] impl TryFrom<(KeyType, Vec<u8>)> for AuthPublicKey {
pub enum UnsealError { type Error = &'static str;
InvalidKey,
Unbootstrapped, fn try_from(value: (KeyType, Vec<u8>)) -> Result<Self, Self::Error> {
} let (key_type, bytes) = value;
match key_type {
#[derive(Debug, Clone, Copy, PartialEq, Eq)] KeyType::Ed25519 => {
pub enum BootstrapError { let bytes: [u8; 32] = bytes.try_into().map_err(|_| "invalid Ed25519 key length")?;
AlreadyBootstrapped, let key = ed25519_dalek::VerifyingKey::from_bytes(&bytes)
InvalidKey, .map_err(|e| "invalid Ed25519 key")?;
} Ok(AuthPublicKey::Ed25519(key))
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)] KeyType::EcdsaSecp256k1 => {
pub enum VaultState { let point =
Unbootstrapped, k256::EncodedPoint::from_bytes(&bytes).map_err(|e| "invalid ECDSA key")?;
Sealed, let key = k256::ecdsa::VerifyingKey::from_encoded_point(&point)
Unsealed, .map_err(|e| "invalid ECDSA key")?;
} Ok(AuthPublicKey::EcdsaSecp256k1(key))
}
#[derive(Debug, Clone)] KeyType::Rsa => {
pub enum Request { use rsa::pkcs8::DecodePublicKey as _;
AuthChallengeRequest { let key = rsa::RsaPublicKey::from_public_key_der(&bytes)
pubkey: AuthPublicKey, .map_err(|e| "invalid RSA key")?;
bootstrap_token: Option<String>, Ok(AuthPublicKey::Rsa(key))
}, }
AuthChallengeSolution { }
signature: Vec<u8>, }
},
UnsealStart {
client_pubkey: x25519_dalek::PublicKey,
},
UnsealEncryptedKey {
nonce: Vec<u8>,
ciphertext: Vec<u8>,
associated_data: Vec<u8>,
},
BootstrapEncryptedKey {
nonce: Vec<u8>,
ciphertext: Vec<u8>,
associated_data: Vec<u8>,
},
QueryVaultState,
EvmWalletCreate,
EvmWalletList,
ClientConnectionResponse {
approved: bool,
},
ListGrants,
EvmGrantCreate {
client_id: i32,
shared: SharedGrantSettings,
specific: SpecificGrant,
},
EvmGrantDelete {
grant_id: i32,
},
} }
// Messages, sent by user agent to connection client without having a request
#[derive(Debug)] #[derive(Debug)]
pub enum Response { pub enum OutOfBand {
AuthChallenge { ClientConnectionRequest { pubkey: ed25519_dalek::VerifyingKey },
nonce: i32,
},
AuthOk,
UnsealStartResponse {
server_pubkey: x25519_dalek::PublicKey,
},
UnsealResult(Result<(), UnsealError>),
BootstrapResult(Result<(), BootstrapError>),
VaultState(VaultState),
ClientConnectionRequest {
pubkey: ed25519_dalek::VerifyingKey,
},
ClientConnectionCancel, ClientConnectionCancel,
EvmWalletCreate(Result<(), evm::Error>),
EvmWalletList(Vec<Address>),
ListGrants(Vec<Grant<SpecificGrant>>),
EvmGrantCreate(Result<i32, evm::Error>),
EvmGrantDelete(Result<(), evm::Error>),
} }
pub type Transport = Box<dyn Bi<Request, Result<Response, TransportResponseError>> + Send>;
pub struct UserAgentConnection { pub struct UserAgentConnection {
db: db::DatabasePool, pub(crate) db: db::DatabasePool,
actors: GlobalActors, pub(crate) actors: GlobalActors,
transport: Transport,
} }
impl UserAgentConnection { impl UserAgentConnection {
pub fn new(db: db::DatabasePool, actors: GlobalActors, transport: Transport) -> Self { pub fn new(db: db::DatabasePool, actors: GlobalActors) -> Self {
Self { Self { db, actors }
db,
actors,
transport,
}
} }
} }
pub mod auth; pub mod auth;
pub mod session; pub mod session;
#[tracing::instrument(skip(props))] pub use auth::authenticate;
pub async fn connect_user_agent(props: UserAgentConnection) { pub use session::UserAgentSession;
match auth::authenticate_and_create(props).await {
Ok(session) => {
UserAgentSession::spawn(session);
info!("User authenticated, session started");
}
Err(err) => {
error!(?err, "Authentication failed, closing connection");
}
}
}

View File

@@ -1,93 +1,63 @@
use std::{borrow::Cow, convert::Infallible};
use arbiter_proto::transport::Sender;
use ed25519_dalek::VerifyingKey; use ed25519_dalek::VerifyingKey;
use kameo::{Actor, messages, prelude::Context}; use kameo::{Actor, messages, prelude::Context};
use thiserror::Error;
use tokio::{select, sync::watch}; use tokio::{select, sync::watch};
use tracing::{error, info}; use tracing::{error, info};
use crate::actors::{ use crate::actors::{
router::RegisterUserAgent, router::RegisterUserAgent,
user_agent::{ user_agent::{OutOfBand, UserAgentConnection},
Request, Response, TransportResponseError,
UserAgentConnection,
},
}; };
mod state; mod state;
use state::{DummyContext, UserAgentEvents, UserAgentStateMachine}; use state::{DummyContext, UserAgentEvents, UserAgentStateMachine};
// Error for consumption by other actors #[derive(Debug, Error)]
#[derive(Debug, thiserror::Error, PartialEq)]
pub enum Error { pub enum Error {
#[error("User agent session ended due to connection loss")] #[error("State transition failed")]
ConnectionLost, State,
#[error("User agent session ended due to unexpected message")] #[error("Internal error: {message}")]
UnexpectedMessage, Internal { message: Cow<'static, str> },
}
impl Error {
pub fn internal(message: impl Into<Cow<'static, str>>) -> Self {
Self::Internal {
message: message.into(),
}
}
} }
pub struct UserAgentSession { pub struct UserAgentSession {
props: UserAgentConnection, props: UserAgentConnection,
state: UserAgentStateMachine<DummyContext>, state: UserAgentStateMachine<DummyContext>,
sender: Box<dyn Sender<OutOfBand>>,
} }
mod connection; mod connection;
pub(crate) use connection::{
BootstrapError, HandleBootstrapEncryptedKey, HandleEvmWalletCreate, HandleEvmWalletList,
HandleGrantCreate, HandleGrantDelete, HandleGrantList, HandleQueryVaultState,
HandleUnsealEncryptedKey, HandleUnsealRequest, UnsealError,
};
impl UserAgentSession { impl UserAgentSession {
pub(crate) fn new(props: UserAgentConnection) -> Self { pub(crate) fn new(props: UserAgentConnection, sender: Box<dyn Sender<OutOfBand>>) -> Self {
Self { Self {
props, props,
state: UserAgentStateMachine::new(DummyContext), state: UserAgentStateMachine::new(DummyContext),
sender,
} }
} }
pub(super) async fn send_msg<Reply: kameo::Reply>( fn transition(&mut self, event: UserAgentEvents) -> Result<(), Error> {
&mut self,
msg: Response,
_ctx: &mut Context<Self, Reply>,
) -> Result<(), Error> {
self.props.transport.send(Ok(msg)).await.map_err(|_| {
error!(
actor = "useragent",
reason = "channel closed",
"send.failed"
);
Error::ConnectionLost
})
}
async fn expect_msg<Extractor, Msg, Reply>(
&mut self,
extractor: Extractor,
ctx: &mut Context<Self, Reply>,
) -> Result<Msg, Error>
where
Extractor: FnOnce(Request) -> Option<Msg>,
Reply: kameo::Reply,
{
let msg = self.props.transport.recv().await.ok_or_else(|| {
error!(
actor = "useragent",
reason = "channel closed",
"recv.failed"
);
ctx.stop();
Error::ConnectionLost
})?;
extractor(msg).ok_or_else(|| {
error!(
actor = "useragent",
reason = "unexpected message",
"recv.failed"
);
ctx.stop();
Error::UnexpectedMessage
})
}
fn transition(&mut self, event: UserAgentEvents) -> Result<(), TransportResponseError> {
self.state.process_event(event).map_err(|e| { self.state.process_event(event).map_err(|e| {
error!(?e, "State transition failed"); error!(?e, "State transition failed");
TransportResponseError::StateTransitionFailed Error::State
})?; })?;
Ok(()) Ok(())
} }
@@ -95,52 +65,21 @@ impl UserAgentSession {
#[messages] #[messages]
impl UserAgentSession { impl UserAgentSession {
// TODO: Think about refactoring it to state-machine based flow, as we already have one
#[message(ctx)] #[message(ctx)]
pub async fn request_new_client_approval( pub async fn request_new_client_approval(
&mut self, &mut self,
client_pubkey: VerifyingKey, client_pubkey: VerifyingKey,
mut cancel_flag: watch::Receiver<()>, mut cancel_flag: watch::Receiver<()>,
ctx: &mut Context<Self, Result<bool, Error>>, ctx: &mut Context<Self, Result<bool, ()>>,
) -> Result<bool, Error> { ) -> Result<bool, ()> {
self.send_msg( todo!("Think about refactoring it to state-machine based flow, as we already have one")
Response::ClientConnectionRequest {
pubkey: client_pubkey,
},
ctx,
)
.await?;
let extractor = |msg| {
if let Request::ClientConnectionResponse { approved } = msg {
Some(approved)
} else {
None
}
};
tokio::select! {
_ = cancel_flag.changed() => {
info!(actor = "useragent", "client connection approval cancelled");
self.send_msg(
Response::ClientConnectionCancel,
ctx,
).await?;
Ok(false)
}
result = self.expect_msg(extractor, ctx) => {
let result = result?;
info!(actor = "useragent", "received client connection approval result: approved={}", result);
Ok(result)
}
}
} }
} }
impl Actor for UserAgentSession { impl Actor for UserAgentSession {
type Args = Self; type Args = Self;
type Error = TransportResponseError; type Error = Error;
async fn on_start( async fn on_start(
args: Self::Args, args: Self::Args,
@@ -155,56 +94,8 @@ impl Actor for UserAgentSession {
.await .await
.map_err(|err| { .map_err(|err| {
error!(?err, "Failed to register user agent connection with router"); error!(?err, "Failed to register user agent connection with router");
TransportResponseError::ConnectionRegistrationFailed Error::internal("Failed to register user agent connection with router")
})?; })?;
Ok(args) Ok(args)
} }
async fn next(
&mut self,
_actor_ref: kameo::prelude::WeakActorRef<Self>,
mailbox_rx: &mut kameo::prelude::MailboxReceiver<Self>,
) -> Option<kameo::mailbox::Signal<Self>> {
loop {
select! {
signal = mailbox_rx.recv() => {
return signal;
}
msg = self.props.transport.recv() => {
match msg {
Some(request) => {
match self.process_transport_inbound(request).await {
Ok(response) => {
if self.props.transport.send(Ok(response)).await.is_err() {
error!(actor = "useragent", reason = "channel closed", "send.failed");
return Some(kameo::mailbox::Signal::Stop);
}
}
Err(err) => {
let _ = self.props.transport.send(Err(err)).await;
return Some(kameo::mailbox::Signal::Stop);
}
}
}
None => {
info!(actor = "useragent", "transport.closed");
return Some(kameo::mailbox::Signal::Stop);
}
}
}
}
}
}
}
impl UserAgentSession {
pub fn new_test(db: crate::db::DatabasePool, actors: crate::actors::GlobalActors) -> Self {
use arbiter_proto::transport::DummyTransport;
let transport: super::Transport = Box::new(DummyTransport::new());
let props = UserAgentConnection::new(db, actors, transport);
Self {
props,
state: UserAgentStateMachine::new(DummyContext),
}
}
} }

View File

@@ -1,10 +1,15 @@
use std::sync::Mutex; use std::sync::Mutex;
use alloy::primitives::Address;
use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit}; use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit};
use kameo::error::SendError; use kameo::error::SendError;
use kameo::messages;
use tracing::{error, info}; use tracing::{error, info};
use x25519_dalek::{EphemeralSecret, PublicKey}; use x25519_dalek::{EphemeralSecret, PublicKey};
use crate::actors::keyholder::KeyHolderState;
use crate::actors::user_agent::session::Error;
use crate::evm::policies::{Grant, SpecificGrant};
use crate::safe_cell::SafeCell; use crate::safe_cell::SafeCell;
use crate::{ use crate::{
actors::{ actors::{
@@ -13,7 +18,7 @@ use crate::{
}, },
keyholder::{self, Bootstrap, TryUnseal}, keyholder::{self, Bootstrap, TryUnseal},
user_agent::{ user_agent::{
BootstrapError, Request, Response, TransportResponseError, UnsealError, VaultState, OutOfBand,
session::{ session::{
UserAgentSession, UserAgentSession,
state::{UnsealContext, UserAgentEvents, UserAgentStates}, state::{UnsealContext, UserAgentEvents, UserAgentStates},
@@ -24,55 +29,10 @@ use crate::{
}; };
impl UserAgentSession { impl UserAgentSession {
pub async fn process_transport_inbound(&mut self, req: Request) -> Output { fn take_unseal_secret(&mut self) -> Result<(EphemeralSecret, PublicKey), Error> {
match req {
Request::UnsealStart { client_pubkey } => {
self.handle_unseal_request(client_pubkey).await
}
Request::UnsealEncryptedKey {
nonce,
ciphertext,
associated_data,
} => {
self.handle_unseal_encrypted_key(nonce, ciphertext, associated_data)
.await
}
Request::BootstrapEncryptedKey {
nonce,
ciphertext,
associated_data,
} => {
self.handle_bootstrap_encrypted_key(nonce, ciphertext, associated_data)
.await
}
Request::ListGrants => self.handle_grant_list().await,
Request::QueryVaultState => self.handle_query_vault_state().await,
Request::EvmWalletCreate => self.handle_evm_wallet_create().await,
Request::EvmWalletList => self.handle_evm_wallet_list().await,
Request::AuthChallengeRequest { .. }
| Request::AuthChallengeSolution { .. }
| Request::ClientConnectionResponse { .. } => {
Err(TransportResponseError::UnexpectedRequestPayload)
}
Request::EvmGrantCreate {
client_id,
shared,
specific,
} => self.handle_grant_create(client_id, shared, specific).await,
Request::EvmGrantDelete { grant_id } => self.handle_grant_delete(grant_id).await,
}
}
}
type Output = Result<Response, TransportResponseError>;
impl UserAgentSession {
fn take_unseal_secret(
&mut self,
) -> Result<(EphemeralSecret, PublicKey), TransportResponseError> {
let UserAgentStates::WaitingForUnsealKey(unseal_context) = self.state.state() else { let UserAgentStates::WaitingForUnsealKey(unseal_context) = self.state.state() else {
error!("Received encrypted key in invalid state"); error!("Received encrypted key in invalid state");
return Err(TransportResponseError::InvalidStateForUnsealEncryptedKey); return Err(Error::internal("Invalid state for unseal encrypted key"));
}; };
let ephemeral_secret = { let ephemeral_secret = {
@@ -87,7 +47,7 @@ impl UserAgentSession {
None => { None => {
drop(secret_lock); drop(secret_lock);
error!("Ephemeral secret already taken"); error!("Ephemeral secret already taken");
return Err(TransportResponseError::StateTransitionFailed); return Err(Error::internal("Ephemeral secret already taken"));
} }
} }
}; };
@@ -121,8 +81,38 @@ impl UserAgentSession {
} }
} }
} }
}
async fn handle_unseal_request(&mut self, client_pubkey: x25519_dalek::PublicKey) -> Output { pub struct UnsealStartResponse {
pub server_pubkey: PublicKey,
}
#[derive(Debug, Error)]
pub enum UnsealError {
#[error("Invalid key provided for unsealing")]
InvalidKey,
#[error("Internal error during unsealing process")]
General(#[from] super::Error),
}
#[derive(Debug, Error)]
pub enum BootstrapError {
#[error("Invalid key provided for bootstrapping")]
InvalidKey,
#[error("Vault is already bootstrapped")]
AlreadyBootstrapped,
#[error("Internal error during bootstrapping process")]
General(#[from] super::Error),
}
#[messages]
impl UserAgentSession {
#[message]
pub(crate) async fn handle_unseal_request(
&mut self,
client_pubkey: x25519_dalek::PublicKey,
) -> Result<UnsealStartResponse, Error> {
let secret = EphemeralSecret::random(); let secret = EphemeralSecret::random();
let public_key = PublicKey::from(&secret); let public_key = PublicKey::from(&secret);
@@ -131,24 +121,27 @@ impl UserAgentSession {
client_public_key: client_pubkey, client_public_key: client_pubkey,
}))?; }))?;
Ok(Response::UnsealStartResponse { Ok(UnsealStartResponse {
server_pubkey: public_key, server_pubkey: public_key,
}) })
} }
async fn handle_unseal_encrypted_key( #[message]
pub(crate) async fn handle_unseal_encrypted_key(
&mut self, &mut self,
nonce: Vec<u8>, nonce: Vec<u8>,
ciphertext: Vec<u8>, ciphertext: Vec<u8>,
associated_data: Vec<u8>, associated_data: Vec<u8>,
) -> Output { ) -> Result<(), UnsealError> {
let (ephemeral_secret, client_public_key) = match self.take_unseal_secret() { let (ephemeral_secret, client_public_key) = match self.take_unseal_secret() {
Ok(values) => values, Ok(values) => values,
Err(TransportResponseError::StateTransitionFailed) => { Err(Error::State) => {
self.transition(UserAgentEvents::ReceivedInvalidKey)?; self.transition(UserAgentEvents::ReceivedInvalidKey)?;
return Ok(Response::UnsealResult(Err(UnsealError::InvalidKey))); return Err(UnsealError::InvalidKey);
}
Err(err) => {
return Err(Error::internal("Failed to take unseal secret").into());
} }
Err(err) => return Err(err),
}; };
let seal_key_buffer = match Self::decrypt_client_key_material( let seal_key_buffer = match Self::decrypt_client_key_material(
@@ -161,7 +154,7 @@ impl UserAgentSession {
Ok(buffer) => buffer, Ok(buffer) => buffer,
Err(()) => { Err(()) => {
self.transition(UserAgentEvents::ReceivedInvalidKey)?; self.transition(UserAgentEvents::ReceivedInvalidKey)?;
return Ok(Response::UnsealResult(Err(UnsealError::InvalidKey))); return Err(UnsealError::InvalidKey);
} }
}; };
@@ -177,38 +170,39 @@ impl UserAgentSession {
Ok(_) => { Ok(_) => {
info!("Successfully unsealed key with client-provided key"); info!("Successfully unsealed key with client-provided key");
self.transition(UserAgentEvents::ReceivedValidKey)?; self.transition(UserAgentEvents::ReceivedValidKey)?;
Ok(Response::UnsealResult(Ok(()))) Ok(())
} }
Err(SendError::HandlerError(keyholder::Error::InvalidKey)) => { Err(SendError::HandlerError(keyholder::Error::InvalidKey)) => {
self.transition(UserAgentEvents::ReceivedInvalidKey)?; self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Ok(Response::UnsealResult(Err(UnsealError::InvalidKey))) Err(UnsealError::InvalidKey)
} }
Err(SendError::HandlerError(err)) => { Err(SendError::HandlerError(err)) => {
error!(?err, "Keyholder failed to unseal key"); error!(?err, "Keyholder failed to unseal key");
self.transition(UserAgentEvents::ReceivedInvalidKey)?; self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Ok(Response::UnsealResult(Err(UnsealError::InvalidKey))) Err(UnsealError::InvalidKey)
} }
Err(err) => { Err(err) => {
error!(?err, "Failed to send unseal request to keyholder"); error!(?err, "Failed to send unseal request to keyholder");
self.transition(UserAgentEvents::ReceivedInvalidKey)?; self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Err(TransportResponseError::KeyHolderActorUnreachable) Err(Error::internal("Vault actor error").into())
} }
} }
} }
async fn handle_bootstrap_encrypted_key( #[message]
pub(crate) async fn handle_bootstrap_encrypted_key(
&mut self, &mut self,
nonce: Vec<u8>, nonce: Vec<u8>,
ciphertext: Vec<u8>, ciphertext: Vec<u8>,
associated_data: Vec<u8>, associated_data: Vec<u8>,
) -> Output { ) -> Result<(), BootstrapError> {
let (ephemeral_secret, client_public_key) = match self.take_unseal_secret() { let (ephemeral_secret, client_public_key) = match self.take_unseal_secret() {
Ok(values) => values, Ok(values) => values,
Err(TransportResponseError::StateTransitionFailed) => { Err(Error::State) => {
self.transition(UserAgentEvents::ReceivedInvalidKey)?; self.transition(UserAgentEvents::ReceivedInvalidKey)?;
return Ok(Response::BootstrapResult(Err(BootstrapError::InvalidKey))); return Err(BootstrapError::InvalidKey);
} }
Err(err) => return Err(err), Err(err) => return Err(err.into()),
}; };
let seal_key_buffer = match Self::decrypt_client_key_material( let seal_key_buffer = match Self::decrypt_client_key_material(
@@ -221,7 +215,7 @@ impl UserAgentSession {
Ok(buffer) => buffer, Ok(buffer) => buffer,
Err(()) => { Err(()) => {
self.transition(UserAgentEvents::ReceivedInvalidKey)?; self.transition(UserAgentEvents::ReceivedInvalidKey)?;
return Ok(Response::BootstrapResult(Err(BootstrapError::InvalidKey))); return Err(BootstrapError::InvalidKey);
} }
}; };
@@ -237,87 +231,94 @@ impl UserAgentSession {
Ok(_) => { Ok(_) => {
info!("Successfully bootstrapped vault with client-provided key"); info!("Successfully bootstrapped vault with client-provided key");
self.transition(UserAgentEvents::ReceivedValidKey)?; self.transition(UserAgentEvents::ReceivedValidKey)?;
Ok(Response::BootstrapResult(Ok(()))) Ok(())
} }
Err(SendError::HandlerError(keyholder::Error::AlreadyBootstrapped)) => { Err(SendError::HandlerError(keyholder::Error::AlreadyBootstrapped)) => {
self.transition(UserAgentEvents::ReceivedInvalidKey)?; self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Ok(Response::BootstrapResult(Err( Err(BootstrapError::AlreadyBootstrapped)
BootstrapError::AlreadyBootstrapped,
)))
} }
Err(SendError::HandlerError(err)) => { Err(SendError::HandlerError(err)) => {
error!(?err, "Keyholder failed to bootstrap vault"); error!(?err, "Keyholder failed to bootstrap vault");
self.transition(UserAgentEvents::ReceivedInvalidKey)?; self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Ok(Response::BootstrapResult(Err(BootstrapError::InvalidKey))) Err(BootstrapError::InvalidKey)
} }
Err(err) => { Err(err) => {
error!(?err, "Failed to send bootstrap request to keyholder"); error!(?err, "Failed to send bootstrap request to keyholder");
self.transition(UserAgentEvents::ReceivedInvalidKey)?; self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Err(TransportResponseError::KeyHolderActorUnreachable) Err(BootstrapError::General(Error::internal(
"Vault actor error",
)))
} }
} }
} }
} }
#[messages]
impl UserAgentSession { impl UserAgentSession {
async fn handle_query_vault_state(&mut self) -> Output { #[message]
use crate::actors::keyholder::{GetState, StateDiscriminants}; pub(crate) async fn handle_query_vault_state(&mut self) -> Result<KeyHolderState, Error> {
use crate::actors::keyholder::GetState;
let vault_state = match self.props.actors.key_holder.ask(GetState {}).await { let vault_state = match self.props.actors.key_holder.ask(GetState {}).await {
Ok(StateDiscriminants::Unbootstrapped) => VaultState::Unbootstrapped, Ok(state) => state,
Ok(StateDiscriminants::Sealed) => VaultState::Sealed,
Ok(StateDiscriminants::Unsealed) => VaultState::Unsealed,
Err(err) => { Err(err) => {
error!(?err, actor = "useragent", "keyholder.query.failed"); error!(?err, actor = "useragent", "keyholder.query.failed");
return Err(TransportResponseError::KeyHolderActorUnreachable); return Err(Error::internal("Vault is in broken state").into());
} }
}; };
Ok(Response::VaultState(vault_state)) Ok(vault_state)
} }
} }
#[messages]
impl UserAgentSession { impl UserAgentSession {
async fn handle_evm_wallet_create(&mut self) -> Output { #[message]
let result = match self.props.actors.evm.ask(Generate {}).await { pub(crate) async fn handle_evm_wallet_create(&mut self) -> Result<Address, Error> {
Ok(_address) => return Ok(Response::EvmWalletCreate(Ok(()))), match self.props.actors.evm.ask(Generate {}).await {
Err(SendError::HandlerError(err)) => Err(err), Ok(address) => return Ok(address),
Err(SendError::HandlerError(err)) => Err(Error::internal(format!(
"EVM wallet generation failed: {err}"
))),
Err(err) => { Err(err) => {
error!(?err, "EVM actor unreachable during wallet create"); error!(?err, "EVM actor unreachable during wallet create");
return Err(TransportResponseError::KeyHolderActorUnreachable); return Err(Error::internal("EVM actor unreachable"));
} }
}; }
Ok(Response::EvmWalletCreate(result))
} }
async fn handle_evm_wallet_list(&mut self) -> Output { #[message]
pub(crate) async fn handle_evm_wallet_list(&mut self) -> Result<Vec<Address>, Error> {
match self.props.actors.evm.ask(ListWallets {}).await { match self.props.actors.evm.ask(ListWallets {}).await {
Ok(wallets) => Ok(Response::EvmWalletList(wallets)), Ok(wallets) => Ok(wallets),
Err(err) => { Err(err) => {
error!(?err, "EVM wallet list failed"); error!(?err, "EVM wallet list failed");
Err(TransportResponseError::KeyHolderActorUnreachable) Err(Error::internal("Failed to list EVM wallets"))
} }
} }
} }
} }
#[messages]
impl UserAgentSession { impl UserAgentSession {
async fn handle_grant_list(&mut self) -> Output { #[message]
pub(crate) async fn handle_grant_list(&mut self) -> Result<Vec<Grant<SpecificGrant>>, Error> {
match self.props.actors.evm.ask(UseragentListGrants {}).await { match self.props.actors.evm.ask(UseragentListGrants {}).await {
Ok(grants) => Ok(Response::ListGrants(grants)), Ok(grants) => Ok(grants),
Err(err) => { Err(err) => {
error!(?err, "EVM grant list failed"); error!(?err, "EVM grant list failed");
Err(TransportResponseError::KeyHolderActorUnreachable) Err(Error::internal("Failed to list EVM grants"))
} }
} }
} }
async fn handle_grant_create( #[message]
pub(crate) async fn handle_grant_create(
&mut self, &mut self,
client_id: i32, client_id: i32,
basic: crate::evm::policies::SharedGrantSettings, basic: crate::evm::policies::SharedGrantSettings,
grant: crate::evm::policies::SpecificGrant, grant: crate::evm::policies::SpecificGrant,
) -> Output { ) -> Result<i32, Error> {
match self match self
.props .props
.actors .actors
@@ -329,15 +330,16 @@ impl UserAgentSession {
}) })
.await .await
{ {
Ok(grant_id) => Ok(Response::EvmGrantCreate(Ok(grant_id))), Ok(grant_id) => Ok(grant_id),
Err(err) => { Err(err) => {
error!(?err, "EVM grant create failed"); error!(?err, "EVM grant create failed");
Err(TransportResponseError::KeyHolderActorUnreachable) Err(Error::internal("Failed to create EVM grant"))
} }
} }
} }
async fn handle_grant_delete(&mut self, grant_id: i32) -> Output { #[message]
pub(crate) async fn handle_grant_delete(&mut self, grant_id: i32) -> Result<(), Error> {
match self match self
.props .props
.actors .actors
@@ -345,10 +347,10 @@ impl UserAgentSession {
.ask(UseragentDeleteGrant { grant_id }) .ask(UseragentDeleteGrant { grant_id })
.await .await
{ {
Ok(()) => Ok(Response::EvmGrantDelete(Ok(()))), Ok(()) => Ok(()),
Err(err) => { Err(err) => {
error!(?err, "EVM grant delete failed"); error!(?err, "EVM grant delete failed");
Err(TransportResponseError::KeyHolderActorUnreachable) Err(Error::internal("Failed to delete EVM grant"))
} }
} }
} }

View File

@@ -44,6 +44,14 @@ pub enum DatabaseSetupError {
Pool(#[from] PoolInitError), Pool(#[from] PoolInitError),
} }
#[derive(Error, Debug)]
pub enum DatabaseError {
#[error("Database connection error")]
Pool(#[from] PoolError),
#[error("Database query error")]
Connection(#[from] diesel::result::Error),
}
#[tracing::instrument(level = "info")] #[tracing::instrument(level = "info")]
fn database_path() -> Result<std::path::PathBuf, DatabaseSetupError> { fn database_path() -> Result<std::path::PathBuf, DatabaseSetupError> {
let arbiter_home = arbiter_proto::home_path().map_err(DatabaseSetupError::HomeDir)?; let arbiter_home = arbiter_proto::home_path().map_err(DatabaseSetupError::HomeDir)?;

View File

@@ -1,14 +1,13 @@
use arbiter_proto::{ use arbiter_proto::{
proto::client::{ proto::client::{
AuthChallenge as ProtoAuthChallenge, AuthChallenge as ProtoAuthChallenge, AuthChallengeRequest as ProtoAuthChallengeRequest,
AuthChallengeRequest as ProtoAuthChallengeRequest,
AuthChallengeSolution as ProtoAuthChallengeSolution, AuthOk as ProtoAuthOk, AuthChallengeSolution as ProtoAuthChallengeSolution, AuthOk as ProtoAuthOk,
ClientConnectError, ClientRequest, ClientResponse, ClientConnectError, ClientRequest, ClientResponse,
client_connect_error::Code as ProtoClientConnectErrorCode, client_connect_error::Code as ProtoClientConnectErrorCode,
client_request::Payload as ClientRequestPayload, client_request::Payload as ClientRequestPayload,
client_response::Payload as ClientResponsePayload, client_response::Payload as ClientResponsePayload,
}, },
transport::{Bi, Error as TransportError}, transport::{Bi, Error as TransportError, Sender},
}; };
use async_trait::async_trait; use async_trait::async_trait;
use futures::StreamExt as _; use futures::StreamExt as _;
@@ -37,9 +36,9 @@ impl GrpcTransport {
Some(ClientRequestPayload::AuthChallengeRequest(ProtoAuthChallengeRequest { Some(ClientRequestPayload::AuthChallengeRequest(ProtoAuthChallengeRequest {
pubkey, pubkey,
})) => Ok(DomainRequest::AuthChallengeRequest { pubkey }), })) => Ok(DomainRequest::AuthChallengeRequest { pubkey }),
Some(ClientRequestPayload::AuthChallengeSolution( Some(ClientRequestPayload::AuthChallengeSolution(ProtoAuthChallengeSolution {
ProtoAuthChallengeSolution { signature }, signature,
)) => Ok(DomainRequest::AuthChallengeSolution { signature }), })) => Ok(DomainRequest::AuthChallengeSolution { signature }),
None => Err(Status::invalid_argument("Missing client request payload")), None => Err(Status::invalid_argument("Missing client request payload")),
} }
} }
@@ -86,8 +85,11 @@ impl GrpcTransport {
} }
#[async_trait] #[async_trait]
impl Bi<DomainRequest, Result<DomainResponse, ClientError>> for GrpcTransport { impl Sender<Result<DomainResponse, ClientError>> for GrpcTransport {
async fn send(&mut self, item: Result<DomainResponse, ClientError>) -> Result<(), TransportError> { async fn send(
&mut self,
item: Result<DomainResponse, ClientError>,
) -> Result<(), TransportError> {
let outbound = match item { let outbound = match item {
Ok(message) => Ok(Self::response_to_proto(message)), Ok(message) => Ok(Self::response_to_proto(message)),
Err(err) => Err(Self::error_to_status(err)), Err(err) => Err(Self::error_to_status(err)),
@@ -98,7 +100,10 @@ impl Bi<DomainRequest, Result<DomainResponse, ClientError>> for GrpcTransport {
.await .await
.map_err(|_| TransportError::ChannelClosed) .map_err(|_| TransportError::ChannelClosed)
} }
}
#[async_trait]
impl Bi<DomainRequest, Result<DomainResponse, ClientError>> for GrpcTransport {
async fn recv(&mut self) -> Option<DomainRequest> { async fn recv(&mut self) -> Option<DomainRequest> {
match self.receiver.next().await { match self.receiver.next().await {
Some(Ok(item)) => match Self::request_to_domain(item) { Some(Ok(item)) => match Self::request_to_domain(item) {

View File

@@ -1,7 +1,9 @@
use arbiter_proto::{
use arbiter_proto::proto::{ proto::{
client::{ClientRequest, ClientResponse}, client::{ClientRequest, ClientResponse},
user_agent::{UserAgentRequest, UserAgentResponse}, user_agent::{UserAgentRequest, UserAgentResponse},
},
transport::grpc::GrpcBi,
}; };
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream; use tokio_stream::wrappers::ReceiverStream;
@@ -10,7 +12,11 @@ use tracing::info;
use crate::{ use crate::{
DEFAULT_CHANNEL_SIZE, DEFAULT_CHANNEL_SIZE,
actors::{client::{ClientConnection, connect_client}, user_agent::{UserAgentConnection, connect_user_agent}}, actors::{
client::{ClientConnection, connect_client},
user_agent::UserAgentConnection,
},
grpc::{self, user_agent::start},
}; };
pub mod client; pub mod client;
@@ -48,18 +54,19 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for super::Ser
request: Request<tonic::Streaming<UserAgentRequest>>, request: Request<tonic::Streaming<UserAgentRequest>>,
) -> Result<Response<Self::UserAgentStream>, Status> { ) -> Result<Response<Self::UserAgentStream>, Status> {
let req_stream = request.into_inner(); let req_stream = request.into_inner();
let (tx, rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);
let transport = user_agent::GrpcTransport::new(tx, req_stream); let (bi, rx) = GrpcBi::from_bi_stream(req_stream);
let props = UserAgentConnection::new(
self.context.db.clone(), tokio::spawn(start(
self.context.actors.clone(), UserAgentConnection {
Box::new(transport), db: self.context.db.clone(),
); actors: self.context.actors.clone(),
tokio::spawn(connect_user_agent(props)); },
bi,
));
info!(event = "connection established", "grpc.user_agent"); info!(event = "connection established", "grpc.user_agent");
Ok(Response::new(ReceiverStream::new(rx))) Ok(Response::new(rx))
} }
} }

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,151 @@
use arbiter_proto::{
proto::{
self,
evm::{
EtherTransferSettings as ProtoEtherTransferSettings, EvmError as ProtoEvmError,
EvmGrantCreateRequest, EvmGrantCreateResponse, EvmGrantDeleteRequest,
EvmGrantDeleteResponse, EvmGrantList, EvmGrantListResponse, GrantEntry,
SharedSettings as ProtoSharedSettings, SpecificGrant as ProtoSpecificGrant,
TokenTransferSettings as ProtoTokenTransferSettings,
VolumeRateLimit as ProtoVolumeRateLimit, WalletCreateResponse, WalletEntry, WalletList,
WalletListResponse, evm_grant_create_response::Result as EvmGrantCreateResult,
evm_grant_delete_response::Result as EvmGrantDeleteResult,
evm_grant_list_response::Result as EvmGrantListResult,
specific_grant::Grant as ProtoSpecificGrantType,
wallet_create_response::Result as WalletCreateResult,
wallet_list_response::Result as WalletListResult,
},
user_agent::{
AuthChallenge as ProtoAuthChallenge, AuthChallengeRequest as ProtoAuthChallengeRequest,
AuthChallengeSolution as ProtoAuthChallengeSolution, AuthResult as ProtoAuthResult,
BootstrapEncryptedKey as ProtoBootstrapEncryptedKey,
BootstrapResult as ProtoBootstrapResult, ClientConnectionCancel,
ClientConnectionRequest, ClientConnectionResponse, KeyType as ProtoKeyType,
UnsealEncryptedKey as ProtoUnsealEncryptedKey, UnsealResult as ProtoUnsealResult,
UnsealStart, UnsealStartResponse, UserAgentRequest, UserAgentResponse,
VaultState as ProtoVaultState, user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload,
},
},
transport::{Bi, Error as TransportError, Receiver, Sender, grpc::GrpcBi},
};
use async_trait::async_trait;
use tonic::{Status, Streaming};
use tracing::{info, warn};
use crate::{
actors::user_agent::{
self, AuthPublicKey, OutOfBand as DomainResponse, UserAgentConnection, auth,
},
db::models::KeyType,
evm::policies::{
Grant, SharedGrantSettings, SpecificGrant, TransactionRateLimit, VolumeRateLimit,
ether_transfer, token_transfers,
},
};
use alloy::primitives::{Address, U256};
use chrono::{DateTime, TimeZone, Utc};
pub struct AuthTransportAdapter<'a>(&'a mut GrpcBi<UserAgentRequest, UserAgentResponse>);
#[async_trait]
impl Sender<Result<auth::Outbound, auth::Error>> for AuthTransportAdapter<'_> {
async fn send(
&mut self,
item: Result<auth::Outbound, auth::Error>,
) -> Result<(), TransportError> {
use auth::{Error, Outbound};
let response = match item {
Ok(Outbound::AuthChallenge { nonce }) => Ok(UserAgentResponsePayload::AuthChallenge(
ProtoAuthChallenge { nonce },
)),
Ok(Outbound::AuthSuccess) => Ok(UserAgentResponsePayload::AuthResult(
ProtoAuthResult::Success.into(),
)),
Err(Error::UnregisteredPublicKey) => Ok(UserAgentResponsePayload::AuthResult(
ProtoAuthResult::InvalidKey.into(),
)),
Err(Error::InvalidChallengeSolution) => Ok(UserAgentResponsePayload::AuthResult(
ProtoAuthResult::InvalidSignature.into(),
)),
Err(Error::InvalidBootstrapToken) => Ok(UserAgentResponsePayload::BootstrapResult(
ProtoAuthResult::TokenInvalid.into(),
)),
Err(Error::Internal { details }) => Err(Status::internal(details)),
Err(Error::Transport) => Err(Status::unavailable("transport error")),
};
self.0
.send(response.map(|r| UserAgentResponse { payload: Some(r) }))
.await
}
}
#[async_trait]
impl Receiver<auth::Inbound> for AuthTransportAdapter<'_> {
async fn recv(&mut self) -> Option<auth::Inbound> {
let Ok(UserAgentRequest {
payload: Some(payload),
}) = self.0.recv().await?
else {
warn!(
event = "received request with empty payload",
"grpc.useragent.auth_adapter"
);
return None;
};
match payload {
UserAgentRequestPayload::AuthChallengeRequest(ProtoAuthChallengeRequest {
pubkey,
bootstrap_token,
key_type,
}) => {
let Ok(key_type) = ProtoKeyType::try_from(key_type) else {
warn!(
event = "received request with invalid key type",
"grpc.useragent.auth_adapter"
);
return None;
};
let key_type = match key_type {
ProtoKeyType::Ed25519 => KeyType::Ed25519,
ProtoKeyType::EcdsaSecp256k1 => KeyType::EcdsaSecp256k1,
ProtoKeyType::Rsa => KeyType::Rsa,
ProtoKeyType::Unspecified => {
warn!(
event = "received request with unspecified key type",
"grpc.useragent.auth_adapter"
);
return None;
}
};
let Ok(pubkey) = AuthPublicKey::try_from((key_type, pubkey)) else {
warn!(
event = "received request with invalid public key",
"grpc.useragent.auth_adapter"
);
return None;
};
Some(auth::Inbound::AuthChallengeRequest {
pubkey,
bootstrap_token,
})
}
UserAgentRequestPayload::AuthChallengeSolution(ProtoAuthChallengeSolution {
signature,
}) => Some(auth::Inbound::AuthChallengeSolution { signature }),
_ => None, // Ignore other request types for this adapter
}
}
}
impl Bi<auth::Inbound, Result<auth::Outbound, auth::Error>> for AuthTransportAdapter<'_> {}
pub async fn start(
conn: &mut UserAgentConnection,
bi: &mut GrpcBi<UserAgentRequest, UserAgentResponse>,
) -> Result<AuthPublicKey, auth::Error> {
let mut transport = AuthTransportAdapter(bi);
auth::authenticate(conn, transport).await
}

View File

@@ -13,6 +13,7 @@ pub mod db;
pub mod evm; pub mod evm;
pub mod grpc; pub mod grpc;
pub mod safe_cell; pub mod safe_cell;
pub mod utils;
const DEFAULT_CHANNEL_SIZE: usize = 1000; const DEFAULT_CHANNEL_SIZE: usize = 1000;

View File

@@ -0,0 +1,16 @@
struct DeferClosure<F: FnOnce()> {
f: Option<F>,
}
impl<F: FnOnce()> Drop for DeferClosure<F> {
fn drop(&mut self) {
if let Some(f) = self.f.take() {
f();
}
}
}
// Run some code when a scope is exited, similar to Go's defer statement
pub fn defer<F: FnOnce()>(f: F) -> impl Drop + Sized {
DeferClosure { f: Some(f) }
}

View File

@@ -3,7 +3,7 @@ use arbiter_server::{
actors::{ actors::{
GlobalActors, GlobalActors,
bootstrap::GetToken, bootstrap::GetToken,
user_agent::{AuthPublicKey, Request, Response, UserAgentConnection, connect_user_agent}, user_agent::{AuthPublicKey, Request, OutOfBand, UserAgentConnection, connect_user_agent},
}, },
db::{self, schema}, db::{self, schema},
}; };
@@ -118,7 +118,7 @@ pub async fn test_challenge_auth() {
.expect("should receive challenge"); .expect("should receive challenge");
let challenge = match response { let challenge = match response {
Ok(resp) => match resp { Ok(resp) => match resp {
Response::AuthChallenge { nonce } => nonce, OutOfBand::AuthChallenge { nonce } => nonce,
other => panic!("Expected AuthChallenge, got {other:?}"), other => panic!("Expected AuthChallenge, got {other:?}"),
}, },
Err(err) => panic!("Expected Ok response, got Err({err:?})"), Err(err) => panic!("Expected Ok response, got Err({err:?})"),

View File

@@ -2,7 +2,7 @@ use arbiter_server::{
actors::{ actors::{
GlobalActors, GlobalActors,
keyholder::{Bootstrap, Seal}, keyholder::{Bootstrap, Seal},
user_agent::{Request, Response, UnsealError, session::UserAgentSession}, user_agent::{Request, OutOfBand, UnsealError, session::UserAgentSession},
}, },
db, db,
safe_cell::{SafeCell, SafeCellHandle as _}, safe_cell::{SafeCell, SafeCellHandle as _},
@@ -40,7 +40,7 @@ async fn client_dh_encrypt(user_agent: &mut UserAgentSession, key_to_send: &[u8]
.unwrap(); .unwrap();
let server_pubkey = match response { let server_pubkey = match response {
Response::UnsealStartResponse { server_pubkey } => server_pubkey, OutOfBand::UnsealStartResponse { server_pubkey } => server_pubkey,
other => panic!("Expected UnsealStartResponse, got {other:?}"), other => panic!("Expected UnsealStartResponse, got {other:?}"),
}; };
@@ -73,7 +73,7 @@ pub async fn test_unseal_success() {
.await .await
.unwrap(); .unwrap();
assert!(matches!(response, Response::UnsealResult(Ok(())))); assert!(matches!(response, OutOfBand::UnsealResult(Ok(()))));
} }
#[tokio::test] #[tokio::test]
@@ -90,7 +90,7 @@ pub async fn test_unseal_wrong_seal_key() {
assert!(matches!( assert!(matches!(
response, response,
Response::UnsealResult(Err(UnsealError::InvalidKey)) OutOfBand::UnsealResult(Err(UnsealError::InvalidKey))
)); ));
} }
@@ -120,7 +120,7 @@ pub async fn test_unseal_corrupted_ciphertext() {
assert!(matches!( assert!(matches!(
response, response,
Response::UnsealResult(Err(UnsealError::InvalidKey)) OutOfBand::UnsealResult(Err(UnsealError::InvalidKey))
)); ));
} }
@@ -140,7 +140,7 @@ pub async fn test_unseal_retry_after_invalid_key() {
assert!(matches!( assert!(matches!(
response, response,
Response::UnsealResult(Err(UnsealError::InvalidKey)) OutOfBand::UnsealResult(Err(UnsealError::InvalidKey))
)); ));
} }
@@ -152,6 +152,6 @@ pub async fn test_unseal_retry_after_invalid_key() {
.await .await
.unwrap(); .unwrap();
assert!(matches!(response, Response::UnsealResult(Ok(())))); assert!(matches!(response, OutOfBand::UnsealResult(Ok(()))));
} }
} }