refactor(server::useragent): migrated to new connection design

This commit is contained in:
hdbg
2026-03-17 18:39:12 +01:00
committed by Stas
parent c439c9645d
commit d61dab3285
20 changed files with 1151 additions and 958 deletions

1
server/Cargo.lock generated
View File

@@ -697,6 +697,7 @@ dependencies = [
"rustls-pki-types",
"thiserror",
"tokio",
"tokio-stream",
"tonic",
"tonic-prost",
"tonic-prost-build",

View File

@@ -21,6 +21,7 @@ base64 = "0.22.1"
prost-types.workspace = true
tracing.workspace = true
async-trait.workspace = true
tokio-stream.workspace = true
[build-dependencies]
tonic-prost-build = "0.14.3"

View File

@@ -63,16 +63,29 @@ where
extractor(msg).ok_or(Error::UnexpectedMessage)
}
#[async_trait]
pub trait Sender<Outbound>: Send + Sync {
async fn send(&mut self, item: Outbound) -> Result<(), Error>;
}
#[async_trait]
pub trait Receiver<Inbound>: Send + Sync {
async fn recv(&mut self) -> Option<Inbound>;
}
/// Minimal bidirectional transport abstraction used by protocol code.
///
/// `Bi<Inbound, Outbound>` models a duplex channel with:
/// - inbound items of type `Inbound` read via [`Bi::recv`]
/// - outbound items of type `Outbound` written via [`Bi::send`]
#[async_trait]
pub trait Bi<Inbound, Outbound>: Send + Sync + 'static {
async fn send(&mut self, item: Outbound) -> Result<(), Error>;
pub trait Bi<Inbound, Outbound>: Sender<Outbound> + Receiver<Inbound> + Send + Sync {}
async fn recv(&mut self) -> Option<Inbound>;
pub trait SplittableBi<Inbound, Outbound>: Bi<Inbound, Outbound> {
type Sender: Sender<Outbound>;
type Receiver: Receiver<Inbound>;
fn split(self) -> (Self::Sender, Self::Receiver);
fn from_parts(sender: Self::Sender, receiver: Self::Receiver) -> Self;
}
/// No-op [`Bi`] transport for tests and manual actor usage.
@@ -83,22 +96,16 @@ pub struct DummyTransport<Inbound, Outbound> {
_marker: PhantomData<(Inbound, Outbound)>,
}
impl<Inbound, Outbound> DummyTransport<Inbound, Outbound> {
pub fn new() -> Self {
impl<Inbound, Outbound> Default for DummyTransport<Inbound, Outbound> {
fn default() -> Self {
Self {
_marker: PhantomData,
}
}
}
impl<Inbound, Outbound> Default for DummyTransport<Inbound, Outbound> {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl<Inbound, Outbound> Bi<Inbound, Outbound> for DummyTransport<Inbound, Outbound>
impl<Inbound, Outbound> Sender<Outbound> for DummyTransport<Inbound, Outbound>
where
Inbound: Send + Sync + 'static,
Outbound: Send + Sync + 'static,
@@ -106,9 +113,25 @@ where
async fn send(&mut self, _item: Outbound) -> Result<(), Error> {
Ok(())
}
}
#[async_trait]
impl<Inbound, Outbound> Receiver<Inbound> for DummyTransport<Inbound, Outbound>
where
Inbound: Send + Sync + 'static,
Outbound: Send + Sync + 'static,
{
async fn recv(&mut self) -> Option<Inbound> {
std::future::pending::<()>().await;
None
}
}
impl<Inbound, Outbound> Bi<Inbound, Outbound> for DummyTransport<Inbound, Outbound>
where
Inbound: Send + Sync + 'static,
Outbound: Send + Sync + 'static,
{
}
pub mod grpc;

View File

@@ -0,0 +1,106 @@
use async_trait::async_trait;
use futures::StreamExt;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use super::{Bi, Receiver, Sender};
pub struct GrpcSender<Outbound> {
tx: mpsc::Sender<Result<Outbound, tonic::Status>>,
}
#[async_trait]
impl<Outbound> Sender<Result<Outbound, tonic::Status>> for GrpcSender<Outbound>
where
Outbound: Send + Sync + 'static,
{
async fn send(&mut self, item: Result<Outbound, tonic::Status>) -> Result<(), super::Error> {
self.tx
.send(item)
.await
.map_err(|_| super::Error::ChannelClosed)
}
}
pub struct GrpcReceiver<Inbound> {
rx: tonic::Streaming<Inbound>,
}
#[async_trait]
impl<Inbound> Receiver<Result<Inbound, tonic::Status>> for GrpcReceiver<Inbound>
where
Inbound: Send + Sync + 'static,
{
async fn recv(&mut self) -> Option<Result<Inbound, tonic::Status>> {
self.rx.next().await
}
}
pub struct GrpcBi<Inbound, Outbound> {
sender: GrpcSender<Outbound>,
receiver: GrpcReceiver<Inbound>,
}
impl<Inbound, Outbound> GrpcBi<Inbound, Outbound>
where
Inbound: Send + Sync + 'static,
Outbound: Send + Sync + 'static,
{
pub fn from_bi_stream(
receiver: tonic::Streaming<Inbound>,
) -> (Self, ReceiverStream<Result<Outbound, tonic::Status>>) {
let (tx, rx) = mpsc::channel(10);
let sender = GrpcSender { tx };
let receiver = GrpcReceiver { rx: receiver };
let bi = GrpcBi { sender, receiver };
(bi, ReceiverStream::new(rx))
}
}
#[async_trait]
impl<Inbound, Outbound> Sender<Result<Outbound, tonic::Status>> for GrpcBi<Inbound, Outbound>
where
Inbound: Send + Sync + 'static,
Outbound: Send + Sync + 'static,
{
async fn send(&mut self, item: Result<Outbound, tonic::Status>) -> Result<(), super::Error> {
self.sender.send(item).await
}
}
#[async_trait]
impl<Inbound, Outbound> Receiver<Result<Inbound, tonic::Status>> for GrpcBi<Inbound, Outbound>
where
Inbound: Send + Sync + 'static,
Outbound: Send + Sync + 'static,
{
async fn recv(&mut self) -> Option<Result<Inbound, tonic::Status>> {
self.receiver.recv().await
}
}
impl<Inbound, Outbound> Bi<Result<Inbound, tonic::Status>, Result<Outbound, tonic::Status>>
for GrpcBi<Inbound, Outbound>
where
Inbound: Send + Sync + 'static,
Outbound: Send + Sync + 'static,
{
}
impl<Inbound, Outbound>
super::SplittableBi<Result<Inbound, tonic::Status>, Result<Outbound, tonic::Status>>
for GrpcBi<Inbound, Outbound>
where
Inbound: Send + Sync + 'static,
Outbound: Send + Sync + 'static,
{
type Sender = GrpcSender<Outbound>;
type Receiver = GrpcReceiver<Inbound>;
fn split(self) -> (Self::Sender, Self::Receiver) {
(self.sender, self.receiver)
}
fn from_parts(sender: Self::Sender, receiver: Self::Receiver) -> Self {
GrpcBi { sender, receiver }
}
}

View File

@@ -22,7 +22,7 @@ use encryption::v1::{self, KeyCell, Nonce};
pub mod encryption;
#[derive(Default, EnumDiscriminants)]
#[strum_discriminants(derive(Reply), vis(pub))]
#[strum_discriminants(derive(Reply), vis(pub), name(KeyHolderState))]
enum State {
#[default]
Unbootstrapped,
@@ -325,7 +325,7 @@ impl KeyHolder {
}
#[message]
pub fn get_state(&self) -> StateDiscriminants {
pub fn get_state(&self) -> KeyHolderState {
self.state.discriminant()
}

View File

@@ -1,74 +1,82 @@
use arbiter_proto::transport::Bi;
use tracing::error;
use crate::actors::user_agent::{
Request, UserAgentConnection,
AuthPublicKey, UserAgentConnection,
auth::state::{AuthContext, AuthStateMachine},
AuthPublicKey,
session::UserAgentSession,
};
#[derive(thiserror::Error, Debug, PartialEq)]
pub enum Error {
#[error("Unexpected message payload")]
UnexpectedMessagePayload,
#[error("Invalid client public key length")]
InvalidClientPubkeyLength,
#[error("Invalid client public key encoding")]
InvalidAuthPubkeyEncoding,
#[error("Database pool unavailable")]
DatabasePoolUnavailable,
#[error("Database operation failed")]
DatabaseOperationFailed,
#[error("Public key not registered")]
PublicKeyNotRegistered,
#[error("Transport error")]
Transport,
#[error("Invalid bootstrap token")]
InvalidBootstrapToken,
#[error("Bootstrapper actor unreachable")]
BootstrapperActorUnreachable,
#[error("Invalid challenge solution")]
InvalidChallengeSolution,
}
mod state;
use state::*;
fn parse_auth_event(payload: Request) -> Result<AuthEvents, Error> {
match payload {
Request::AuthChallengeRequest {
pubkey,
bootstrap_token: None,
} => Ok(AuthEvents::AuthRequest(ChallengeRequest { pubkey })),
Request::AuthChallengeRequest {
pubkey,
bootstrap_token: Some(token),
} => Ok(AuthEvents::BootstrapAuthRequest(BootstrapAuthRequest {
pubkey,
token,
})),
Request::AuthChallengeSolution { signature } => {
Ok(AuthEvents::ReceivedSolution(ChallengeSolution {
solution: signature,
}))
#[derive(Debug, Clone)]
pub enum Inbound {
AuthChallengeRequest {
pubkey: AuthPublicKey,
bootstrap_token: Option<String>,
},
AuthChallengeSolution {
signature: Vec<u8>,
},
}
#[derive(Debug)]
pub enum Error {
UnregisteredPublicKey,
InvalidChallengeSolution,
InvalidBootstrapToken,
Internal { details: String },
Transport,
}
impl Error {
fn internal(details: impl Into<String>) -> Self {
Self::Internal {
details: details.into(),
}
_ => Err(Error::UnexpectedMessagePayload),
}
}
pub async fn authenticate(props: &mut UserAgentConnection) -> Result<AuthPublicKey, Error> {
let mut state = AuthStateMachine::new(AuthContext::new(props));
#[derive(Debug, Clone)]
pub enum Outbound {
AuthChallenge { nonce: i32 },
AuthSuccess,
}
fn parse_auth_event(payload: Inbound) -> AuthEvents {
match payload {
Inbound::AuthChallengeRequest {
pubkey,
bootstrap_token: None,
} => AuthEvents::AuthRequest(ChallengeRequest { pubkey }),
Inbound::AuthChallengeRequest {
pubkey,
bootstrap_token: Some(token),
} => AuthEvents::BootstrapAuthRequest(BootstrapAuthRequest { pubkey, token }),
Inbound::AuthChallengeSolution { signature } => {
AuthEvents::ReceivedSolution(ChallengeSolution {
solution: signature,
})
}
}
}
pub async fn authenticate<T>(
props: &mut UserAgentConnection,
transport: T,
) -> Result<AuthPublicKey, Error>
where
T: Bi<Inbound, Result<Outbound, Error>> + Send,
{
let mut state = AuthStateMachine::new(AuthContext::new(props, transport));
loop {
// `state` holds a mutable reference to `props` so we can't access it directly here
let transport = state.context_mut().conn.transport.as_mut();
let Some(payload) = transport.recv().await else {
let Some(payload) = state.context_mut().transport.recv().await else {
return Err(Error::Transport);
};
let event = parse_auth_event(payload)?;
match state.process_event(event).await {
match state.process_event(parse_auth_event(payload)).await {
Ok(AuthStates::AuthOk(key)) => return Ok(key.clone()),
Err(AuthError::ActionFailed(err)) => {
error!(?err, "State machine action failed");
@@ -91,11 +99,3 @@ pub async fn authenticate(props: &mut UserAgentConnection) -> Result<AuthPublicK
}
}
}
pub async fn authenticate_and_create(
mut props: UserAgentConnection,
) -> Result<UserAgentSession, Error> {
let _key = authenticate(&mut props).await?;
let session = UserAgentSession::new(props);
Ok(session)
}

View File

@@ -1,3 +1,5 @@
use alloy::transports::Transport;
use arbiter_proto::transport::Bi;
use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, update};
use diesel_async::RunQueryDsl;
use tracing::error;
@@ -6,7 +8,7 @@ use super::Error;
use crate::{
actors::{
bootstrap::ConsumeToken,
user_agent::{AuthPublicKey, Response, UserAgentConnection},
user_agent::{AuthPublicKey, OutOfBand, UserAgentConnection, auth::Outbound},
},
db::schema,
};
@@ -42,7 +44,7 @@ smlang::statemachine!(
async fn create_nonce(db: &crate::db::DatabasePool, pubkey_bytes: &[u8]) -> Result<i32, Error> {
let mut db_conn = db.get().await.map_err(|e| {
error!(error = ?e, "Database pool error");
Error::DatabasePoolUnavailable
Error::internal("Database unavailable")
})?;
db_conn
.exclusive_transaction(|conn| {
@@ -66,11 +68,11 @@ async fn create_nonce(db: &crate::db::DatabasePool, pubkey_bytes: &[u8]) -> Resu
.optional()
.map_err(|e| {
error!(error = ?e, "Database error");
Error::DatabaseOperationFailed
Error::internal("Database operation failed")
})?
.ok_or_else(|| {
error!(?pubkey_bytes, "Public key not found in database");
Error::PublicKeyNotRegistered
Error::UnregisteredPublicKey
})
}
@@ -79,7 +81,7 @@ async fn register_key(db: &crate::db::DatabasePool, pubkey: &AuthPublicKey) -> R
let key_type = pubkey.key_type();
let mut conn = db.get().await.map_err(|e| {
error!(error = ?e, "Database pool error");
Error::DatabasePoolUnavailable
Error::internal("Database unavailable")
})?;
diesel::insert_into(schema::useragent_client::table)
@@ -92,23 +94,27 @@ async fn register_key(db: &crate::db::DatabasePool, pubkey: &AuthPublicKey) -> R
.await
.map_err(|e| {
error!(error = ?e, "Database error");
Error::DatabaseOperationFailed
Error::internal("Database operation failed")
})?;
Ok(())
}
pub struct AuthContext<'a> {
pub struct AuthContext<'a, T> {
pub(super) conn: &'a mut UserAgentConnection,
pub(super) transport: T,
}
impl<'a> AuthContext<'a> {
pub fn new(conn: &'a mut UserAgentConnection) -> Self {
Self { conn }
impl<'a, T> AuthContext<'a, T> {
pub fn new(conn: &'a mut UserAgentConnection, transport: T) -> Self {
Self { conn, transport }
}
}
impl AuthStateMachineContext for AuthContext<'_> {
impl<T> AuthStateMachineContext for AuthContext<'_, T>
where
T: Bi<super::Inbound, Result<super::Outbound, Error>> + Send,
{
type Error = Error;
async fn prepare_challenge(
@@ -118,9 +124,9 @@ impl AuthStateMachineContext for AuthContext<'_> {
let stored_bytes = pubkey.to_stored_bytes();
let nonce = create_nonce(&self.conn.db, &stored_bytes).await?;
self.conn
self
.transport
.send(Ok(Response::AuthChallenge { nonce }))
.send(Ok(Outbound::AuthChallenge { nonce }))
.await
.map_err(|e| {
error!(?e, "Failed to send auth challenge");
@@ -149,7 +155,7 @@ impl AuthStateMachineContext for AuthContext<'_> {
.await
.map_err(|e| {
error!(?e, "Failed to consume bootstrap token");
Error::BootstrapperActorUnreachable
Error::internal("Failed to consume bootstrap token")
})?;
if !token_ok {
@@ -159,11 +165,11 @@ impl AuthStateMachineContext for AuthContext<'_> {
register_key(&self.conn.db, &pubkey).await?;
self.conn
.transport
.send(Ok(Response::AuthOk))
.await
.map_err(|_| Error::Transport)?;
self
.transport
.send(Ok(Outbound::AuthSuccess))
.await
.map_err(|_| Error::Transport)?;
Ok(pubkey)
}
@@ -172,7 +178,10 @@ impl AuthStateMachineContext for AuthContext<'_> {
#[allow(clippy::unused_unit)]
async fn verify_solution(
&mut self,
ChallengeContext { challenge_nonce, key }: &ChallengeContext,
ChallengeContext {
challenge_nonce,
key,
}: &ChallengeContext,
ChallengeSolution { solution }: ChallengeSolution,
) -> Result<AuthPublicKey, Self::Error> {
let formatted = arbiter_proto::format_challenge(*challenge_nonce, &key.to_stored_bytes());
@@ -205,9 +214,9 @@ impl AuthStateMachineContext for AuthContext<'_> {
};
if valid {
self.conn
self
.transport
.send(Ok(Response::AuthOk))
.send(Ok(Outbound::AuthSuccess))
.await
.map_err(|_| Error::Transport)?;
}

View File

@@ -1,33 +1,15 @@
use alloy::primitives::Address;
use arbiter_proto::transport::Bi;
use arbiter_proto::transport::{Bi, Sender};
use kameo::actor::Spawn as _;
use tracing::{error, info};
use crate::{
actors::{GlobalActors, evm, user_agent::session::UserAgentSession},
actors::{GlobalActors, evm},
db::{self, models::KeyType},
evm::policies::SharedGrantSettings,
evm::policies::{Grant, SpecificGrant},
};
#[derive(Debug, thiserror::Error, PartialEq)]
pub enum TransportResponseError {
#[error("Unexpected request payload")]
UnexpectedRequestPayload,
#[error("Invalid state for unseal encrypted key")]
InvalidStateForUnsealEncryptedKey,
#[error("client_pubkey must be 32 bytes")]
InvalidClientPubkeyLength,
#[error("State machine error")]
StateTransitionFailed,
#[error("Vault is not available")]
KeyHolderActorUnreachable,
#[error(transparent)]
Auth(#[from] auth::Error),
#[error("Failed registering connection")]
ConnectionRegistrationFailed,
}
/// Abstraction over Ed25519 / ECDSA-secp256k1 / RSA public keys used during the auth handshake.
#[derive(Clone, Debug)]
pub enum AuthPublicKey {
@@ -65,119 +47,55 @@ impl AuthPublicKey {
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnsealError {
InvalidKey,
Unbootstrapped,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BootstrapError {
AlreadyBootstrapped,
InvalidKey,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VaultState {
Unbootstrapped,
Sealed,
Unsealed,
}
#[derive(Debug, Clone)]
pub enum Request {
AuthChallengeRequest {
pubkey: AuthPublicKey,
bootstrap_token: Option<String>,
},
AuthChallengeSolution {
signature: Vec<u8>,
},
UnsealStart {
client_pubkey: x25519_dalek::PublicKey,
},
UnsealEncryptedKey {
nonce: Vec<u8>,
ciphertext: Vec<u8>,
associated_data: Vec<u8>,
},
BootstrapEncryptedKey {
nonce: Vec<u8>,
ciphertext: Vec<u8>,
associated_data: Vec<u8>,
},
QueryVaultState,
EvmWalletCreate,
EvmWalletList,
ClientConnectionResponse {
approved: bool,
},
ListGrants,
EvmGrantCreate {
client_id: i32,
shared: SharedGrantSettings,
specific: SpecificGrant,
},
EvmGrantDelete {
grant_id: i32,
},
impl TryFrom<(KeyType, Vec<u8>)> for AuthPublicKey {
type Error = &'static str;
fn try_from(value: (KeyType, Vec<u8>)) -> Result<Self, Self::Error> {
let (key_type, bytes) = value;
match key_type {
KeyType::Ed25519 => {
let bytes: [u8; 32] = bytes.try_into().map_err(|_| "invalid Ed25519 key length")?;
let key = ed25519_dalek::VerifyingKey::from_bytes(&bytes)
.map_err(|e| "invalid Ed25519 key")?;
Ok(AuthPublicKey::Ed25519(key))
}
KeyType::EcdsaSecp256k1 => {
let point =
k256::EncodedPoint::from_bytes(&bytes).map_err(|e| "invalid ECDSA key")?;
let key = k256::ecdsa::VerifyingKey::from_encoded_point(&point)
.map_err(|e| "invalid ECDSA key")?;
Ok(AuthPublicKey::EcdsaSecp256k1(key))
}
KeyType::Rsa => {
use rsa::pkcs8::DecodePublicKey as _;
let key = rsa::RsaPublicKey::from_public_key_der(&bytes)
.map_err(|e| "invalid RSA key")?;
Ok(AuthPublicKey::Rsa(key))
}
}
}
}
// Messages, sent by user agent to connection client without having a request
#[derive(Debug)]
pub enum Response {
AuthChallenge {
nonce: i32,
},
AuthOk,
UnsealStartResponse {
server_pubkey: x25519_dalek::PublicKey,
},
UnsealResult(Result<(), UnsealError>),
BootstrapResult(Result<(), BootstrapError>),
VaultState(VaultState),
ClientConnectionRequest {
pubkey: ed25519_dalek::VerifyingKey,
},
pub enum OutOfBand {
ClientConnectionRequest { pubkey: ed25519_dalek::VerifyingKey },
ClientConnectionCancel,
EvmWalletCreate(Result<(), evm::Error>),
EvmWalletList(Vec<Address>),
ListGrants(Vec<Grant<SpecificGrant>>),
EvmGrantCreate(Result<i32, evm::Error>),
EvmGrantDelete(Result<(), evm::Error>),
}
pub type Transport = Box<dyn Bi<Request, Result<Response, TransportResponseError>> + Send>;
pub struct UserAgentConnection {
db: db::DatabasePool,
actors: GlobalActors,
transport: Transport,
pub(crate) db: db::DatabasePool,
pub(crate) actors: GlobalActors,
}
impl UserAgentConnection {
pub fn new(db: db::DatabasePool, actors: GlobalActors, transport: Transport) -> Self {
Self {
db,
actors,
transport,
}
pub fn new(db: db::DatabasePool, actors: GlobalActors) -> Self {
Self { db, actors }
}
}
pub mod auth;
pub mod session;
#[tracing::instrument(skip(props))]
pub async fn connect_user_agent(props: UserAgentConnection) {
match auth::authenticate_and_create(props).await {
Ok(session) => {
UserAgentSession::spawn(session);
info!("User authenticated, session started");
}
Err(err) => {
error!(?err, "Authentication failed, closing connection");
}
}
}
pub use auth::authenticate;
pub use session::UserAgentSession;

View File

@@ -1,93 +1,63 @@
use std::{borrow::Cow, convert::Infallible};
use arbiter_proto::transport::Sender;
use ed25519_dalek::VerifyingKey;
use kameo::{Actor, messages, prelude::Context};
use thiserror::Error;
use tokio::{select, sync::watch};
use tracing::{error, info};
use crate::actors::{
router::RegisterUserAgent,
user_agent::{
Request, Response, TransportResponseError,
UserAgentConnection,
},
user_agent::{OutOfBand, UserAgentConnection},
};
mod state;
use state::{DummyContext, UserAgentEvents, UserAgentStateMachine};
// Error for consumption by other actors
#[derive(Debug, thiserror::Error, PartialEq)]
#[derive(Debug, Error)]
pub enum Error {
#[error("User agent session ended due to connection loss")]
ConnectionLost,
#[error("State transition failed")]
State,
#[error("User agent session ended due to unexpected message")]
UnexpectedMessage,
#[error("Internal error: {message}")]
Internal { message: Cow<'static, str> },
}
impl Error {
pub fn internal(message: impl Into<Cow<'static, str>>) -> Self {
Self::Internal {
message: message.into(),
}
}
}
pub struct UserAgentSession {
props: UserAgentConnection,
state: UserAgentStateMachine<DummyContext>,
sender: Box<dyn Sender<OutOfBand>>,
}
mod connection;
pub(crate) use connection::{
BootstrapError, HandleBootstrapEncryptedKey, HandleEvmWalletCreate, HandleEvmWalletList,
HandleGrantCreate, HandleGrantDelete, HandleGrantList, HandleQueryVaultState,
HandleUnsealEncryptedKey, HandleUnsealRequest, UnsealError,
};
impl UserAgentSession {
pub(crate) fn new(props: UserAgentConnection) -> Self {
pub(crate) fn new(props: UserAgentConnection, sender: Box<dyn Sender<OutOfBand>>) -> Self {
Self {
props,
state: UserAgentStateMachine::new(DummyContext),
sender,
}
}
pub(super) async fn send_msg<Reply: kameo::Reply>(
&mut self,
msg: Response,
_ctx: &mut Context<Self, Reply>,
) -> Result<(), Error> {
self.props.transport.send(Ok(msg)).await.map_err(|_| {
error!(
actor = "useragent",
reason = "channel closed",
"send.failed"
);
Error::ConnectionLost
})
}
async fn expect_msg<Extractor, Msg, Reply>(
&mut self,
extractor: Extractor,
ctx: &mut Context<Self, Reply>,
) -> Result<Msg, Error>
where
Extractor: FnOnce(Request) -> Option<Msg>,
Reply: kameo::Reply,
{
let msg = self.props.transport.recv().await.ok_or_else(|| {
error!(
actor = "useragent",
reason = "channel closed",
"recv.failed"
);
ctx.stop();
Error::ConnectionLost
})?;
extractor(msg).ok_or_else(|| {
error!(
actor = "useragent",
reason = "unexpected message",
"recv.failed"
);
ctx.stop();
Error::UnexpectedMessage
})
}
fn transition(&mut self, event: UserAgentEvents) -> Result<(), TransportResponseError> {
fn transition(&mut self, event: UserAgentEvents) -> Result<(), Error> {
self.state.process_event(event).map_err(|e| {
error!(?e, "State transition failed");
TransportResponseError::StateTransitionFailed
Error::State
})?;
Ok(())
}
@@ -95,52 +65,21 @@ impl UserAgentSession {
#[messages]
impl UserAgentSession {
// TODO: Think about refactoring it to state-machine based flow, as we already have one
#[message(ctx)]
pub async fn request_new_client_approval(
&mut self,
client_pubkey: VerifyingKey,
mut cancel_flag: watch::Receiver<()>,
ctx: &mut Context<Self, Result<bool, Error>>,
) -> Result<bool, Error> {
self.send_msg(
Response::ClientConnectionRequest {
pubkey: client_pubkey,
},
ctx,
)
.await?;
let extractor = |msg| {
if let Request::ClientConnectionResponse { approved } = msg {
Some(approved)
} else {
None
}
};
tokio::select! {
_ = cancel_flag.changed() => {
info!(actor = "useragent", "client connection approval cancelled");
self.send_msg(
Response::ClientConnectionCancel,
ctx,
).await?;
Ok(false)
}
result = self.expect_msg(extractor, ctx) => {
let result = result?;
info!(actor = "useragent", "received client connection approval result: approved={}", result);
Ok(result)
}
}
ctx: &mut Context<Self, Result<bool, ()>>,
) -> Result<bool, ()> {
todo!("Think about refactoring it to state-machine based flow, as we already have one")
}
}
impl Actor for UserAgentSession {
type Args = Self;
type Error = TransportResponseError;
type Error = Error;
async fn on_start(
args: Self::Args,
@@ -155,56 +94,8 @@ impl Actor for UserAgentSession {
.await
.map_err(|err| {
error!(?err, "Failed to register user agent connection with router");
TransportResponseError::ConnectionRegistrationFailed
Error::internal("Failed to register user agent connection with router")
})?;
Ok(args)
}
async fn next(
&mut self,
_actor_ref: kameo::prelude::WeakActorRef<Self>,
mailbox_rx: &mut kameo::prelude::MailboxReceiver<Self>,
) -> Option<kameo::mailbox::Signal<Self>> {
loop {
select! {
signal = mailbox_rx.recv() => {
return signal;
}
msg = self.props.transport.recv() => {
match msg {
Some(request) => {
match self.process_transport_inbound(request).await {
Ok(response) => {
if self.props.transport.send(Ok(response)).await.is_err() {
error!(actor = "useragent", reason = "channel closed", "send.failed");
return Some(kameo::mailbox::Signal::Stop);
}
}
Err(err) => {
let _ = self.props.transport.send(Err(err)).await;
return Some(kameo::mailbox::Signal::Stop);
}
}
}
None => {
info!(actor = "useragent", "transport.closed");
return Some(kameo::mailbox::Signal::Stop);
}
}
}
}
}
}
}
impl UserAgentSession {
pub fn new_test(db: crate::db::DatabasePool, actors: crate::actors::GlobalActors) -> Self {
use arbiter_proto::transport::DummyTransport;
let transport: super::Transport = Box::new(DummyTransport::new());
let props = UserAgentConnection::new(db, actors, transport);
Self {
props,
state: UserAgentStateMachine::new(DummyContext),
}
}
}

View File

@@ -1,10 +1,15 @@
use std::sync::Mutex;
use alloy::primitives::Address;
use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit};
use kameo::error::SendError;
use kameo::messages;
use tracing::{error, info};
use x25519_dalek::{EphemeralSecret, PublicKey};
use crate::actors::keyholder::KeyHolderState;
use crate::actors::user_agent::session::Error;
use crate::evm::policies::{Grant, SpecificGrant};
use crate::safe_cell::SafeCell;
use crate::{
actors::{
@@ -13,7 +18,7 @@ use crate::{
},
keyholder::{self, Bootstrap, TryUnseal},
user_agent::{
BootstrapError, Request, Response, TransportResponseError, UnsealError, VaultState,
OutOfBand,
session::{
UserAgentSession,
state::{UnsealContext, UserAgentEvents, UserAgentStates},
@@ -24,55 +29,10 @@ use crate::{
};
impl UserAgentSession {
pub async fn process_transport_inbound(&mut self, req: Request) -> Output {
match req {
Request::UnsealStart { client_pubkey } => {
self.handle_unseal_request(client_pubkey).await
}
Request::UnsealEncryptedKey {
nonce,
ciphertext,
associated_data,
} => {
self.handle_unseal_encrypted_key(nonce, ciphertext, associated_data)
.await
}
Request::BootstrapEncryptedKey {
nonce,
ciphertext,
associated_data,
} => {
self.handle_bootstrap_encrypted_key(nonce, ciphertext, associated_data)
.await
}
Request::ListGrants => self.handle_grant_list().await,
Request::QueryVaultState => self.handle_query_vault_state().await,
Request::EvmWalletCreate => self.handle_evm_wallet_create().await,
Request::EvmWalletList => self.handle_evm_wallet_list().await,
Request::AuthChallengeRequest { .. }
| Request::AuthChallengeSolution { .. }
| Request::ClientConnectionResponse { .. } => {
Err(TransportResponseError::UnexpectedRequestPayload)
}
Request::EvmGrantCreate {
client_id,
shared,
specific,
} => self.handle_grant_create(client_id, shared, specific).await,
Request::EvmGrantDelete { grant_id } => self.handle_grant_delete(grant_id).await,
}
}
}
type Output = Result<Response, TransportResponseError>;
impl UserAgentSession {
fn take_unseal_secret(
&mut self,
) -> Result<(EphemeralSecret, PublicKey), TransportResponseError> {
fn take_unseal_secret(&mut self) -> Result<(EphemeralSecret, PublicKey), Error> {
let UserAgentStates::WaitingForUnsealKey(unseal_context) = self.state.state() else {
error!("Received encrypted key in invalid state");
return Err(TransportResponseError::InvalidStateForUnsealEncryptedKey);
return Err(Error::internal("Invalid state for unseal encrypted key"));
};
let ephemeral_secret = {
@@ -87,7 +47,7 @@ impl UserAgentSession {
None => {
drop(secret_lock);
error!("Ephemeral secret already taken");
return Err(TransportResponseError::StateTransitionFailed);
return Err(Error::internal("Ephemeral secret already taken"));
}
}
};
@@ -121,8 +81,38 @@ impl UserAgentSession {
}
}
}
}
async fn handle_unseal_request(&mut self, client_pubkey: x25519_dalek::PublicKey) -> Output {
pub struct UnsealStartResponse {
pub server_pubkey: PublicKey,
}
#[derive(Debug, Error)]
pub enum UnsealError {
#[error("Invalid key provided for unsealing")]
InvalidKey,
#[error("Internal error during unsealing process")]
General(#[from] super::Error),
}
#[derive(Debug, Error)]
pub enum BootstrapError {
#[error("Invalid key provided for bootstrapping")]
InvalidKey,
#[error("Vault is already bootstrapped")]
AlreadyBootstrapped,
#[error("Internal error during bootstrapping process")]
General(#[from] super::Error),
}
#[messages]
impl UserAgentSession {
#[message]
pub(crate) async fn handle_unseal_request(
&mut self,
client_pubkey: x25519_dalek::PublicKey,
) -> Result<UnsealStartResponse, Error> {
let secret = EphemeralSecret::random();
let public_key = PublicKey::from(&secret);
@@ -131,24 +121,27 @@ impl UserAgentSession {
client_public_key: client_pubkey,
}))?;
Ok(Response::UnsealStartResponse {
Ok(UnsealStartResponse {
server_pubkey: public_key,
})
}
async fn handle_unseal_encrypted_key(
#[message]
pub(crate) async fn handle_unseal_encrypted_key(
&mut self,
nonce: Vec<u8>,
ciphertext: Vec<u8>,
associated_data: Vec<u8>,
) -> Output {
) -> Result<(), UnsealError> {
let (ephemeral_secret, client_public_key) = match self.take_unseal_secret() {
Ok(values) => values,
Err(TransportResponseError::StateTransitionFailed) => {
Err(Error::State) => {
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
return Ok(Response::UnsealResult(Err(UnsealError::InvalidKey)));
return Err(UnsealError::InvalidKey);
}
Err(err) => {
return Err(Error::internal("Failed to take unseal secret").into());
}
Err(err) => return Err(err),
};
let seal_key_buffer = match Self::decrypt_client_key_material(
@@ -161,7 +154,7 @@ impl UserAgentSession {
Ok(buffer) => buffer,
Err(()) => {
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
return Ok(Response::UnsealResult(Err(UnsealError::InvalidKey)));
return Err(UnsealError::InvalidKey);
}
};
@@ -177,38 +170,39 @@ impl UserAgentSession {
Ok(_) => {
info!("Successfully unsealed key with client-provided key");
self.transition(UserAgentEvents::ReceivedValidKey)?;
Ok(Response::UnsealResult(Ok(())))
Ok(())
}
Err(SendError::HandlerError(keyholder::Error::InvalidKey)) => {
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Ok(Response::UnsealResult(Err(UnsealError::InvalidKey)))
Err(UnsealError::InvalidKey)
}
Err(SendError::HandlerError(err)) => {
error!(?err, "Keyholder failed to unseal key");
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Ok(Response::UnsealResult(Err(UnsealError::InvalidKey)))
Err(UnsealError::InvalidKey)
}
Err(err) => {
error!(?err, "Failed to send unseal request to keyholder");
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Err(TransportResponseError::KeyHolderActorUnreachable)
Err(Error::internal("Vault actor error").into())
}
}
}
async fn handle_bootstrap_encrypted_key(
#[message]
pub(crate) async fn handle_bootstrap_encrypted_key(
&mut self,
nonce: Vec<u8>,
ciphertext: Vec<u8>,
associated_data: Vec<u8>,
) -> Output {
) -> Result<(), BootstrapError> {
let (ephemeral_secret, client_public_key) = match self.take_unseal_secret() {
Ok(values) => values,
Err(TransportResponseError::StateTransitionFailed) => {
Err(Error::State) => {
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
return Ok(Response::BootstrapResult(Err(BootstrapError::InvalidKey)));
return Err(BootstrapError::InvalidKey);
}
Err(err) => return Err(err),
Err(err) => return Err(err.into()),
};
let seal_key_buffer = match Self::decrypt_client_key_material(
@@ -221,7 +215,7 @@ impl UserAgentSession {
Ok(buffer) => buffer,
Err(()) => {
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
return Ok(Response::BootstrapResult(Err(BootstrapError::InvalidKey)));
return Err(BootstrapError::InvalidKey);
}
};
@@ -237,87 +231,94 @@ impl UserAgentSession {
Ok(_) => {
info!("Successfully bootstrapped vault with client-provided key");
self.transition(UserAgentEvents::ReceivedValidKey)?;
Ok(Response::BootstrapResult(Ok(())))
Ok(())
}
Err(SendError::HandlerError(keyholder::Error::AlreadyBootstrapped)) => {
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Ok(Response::BootstrapResult(Err(
BootstrapError::AlreadyBootstrapped,
)))
Err(BootstrapError::AlreadyBootstrapped)
}
Err(SendError::HandlerError(err)) => {
error!(?err, "Keyholder failed to bootstrap vault");
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Ok(Response::BootstrapResult(Err(BootstrapError::InvalidKey)))
Err(BootstrapError::InvalidKey)
}
Err(err) => {
error!(?err, "Failed to send bootstrap request to keyholder");
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Err(TransportResponseError::KeyHolderActorUnreachable)
Err(BootstrapError::General(Error::internal(
"Vault actor error",
)))
}
}
}
}
#[messages]
impl UserAgentSession {
async fn handle_query_vault_state(&mut self) -> Output {
use crate::actors::keyholder::{GetState, StateDiscriminants};
#[message]
pub(crate) async fn handle_query_vault_state(&mut self) -> Result<KeyHolderState, Error> {
use crate::actors::keyholder::GetState;
let vault_state = match self.props.actors.key_holder.ask(GetState {}).await {
Ok(StateDiscriminants::Unbootstrapped) => VaultState::Unbootstrapped,
Ok(StateDiscriminants::Sealed) => VaultState::Sealed,
Ok(StateDiscriminants::Unsealed) => VaultState::Unsealed,
Ok(state) => state,
Err(err) => {
error!(?err, actor = "useragent", "keyholder.query.failed");
return Err(TransportResponseError::KeyHolderActorUnreachable);
return Err(Error::internal("Vault is in broken state").into());
}
};
Ok(Response::VaultState(vault_state))
Ok(vault_state)
}
}
#[messages]
impl UserAgentSession {
async fn handle_evm_wallet_create(&mut self) -> Output {
let result = match self.props.actors.evm.ask(Generate {}).await {
Ok(_address) => return Ok(Response::EvmWalletCreate(Ok(()))),
Err(SendError::HandlerError(err)) => Err(err),
#[message]
pub(crate) async fn handle_evm_wallet_create(&mut self) -> Result<Address, Error> {
match self.props.actors.evm.ask(Generate {}).await {
Ok(address) => return Ok(address),
Err(SendError::HandlerError(err)) => Err(Error::internal(format!(
"EVM wallet generation failed: {err}"
))),
Err(err) => {
error!(?err, "EVM actor unreachable during wallet create");
return Err(TransportResponseError::KeyHolderActorUnreachable);
return Err(Error::internal("EVM actor unreachable"));
}
};
Ok(Response::EvmWalletCreate(result))
}
}
async fn handle_evm_wallet_list(&mut self) -> Output {
#[message]
pub(crate) async fn handle_evm_wallet_list(&mut self) -> Result<Vec<Address>, Error> {
match self.props.actors.evm.ask(ListWallets {}).await {
Ok(wallets) => Ok(Response::EvmWalletList(wallets)),
Ok(wallets) => Ok(wallets),
Err(err) => {
error!(?err, "EVM wallet list failed");
Err(TransportResponseError::KeyHolderActorUnreachable)
Err(Error::internal("Failed to list EVM wallets"))
}
}
}
}
#[messages]
impl UserAgentSession {
async fn handle_grant_list(&mut self) -> Output {
#[message]
pub(crate) async fn handle_grant_list(&mut self) -> Result<Vec<Grant<SpecificGrant>>, Error> {
match self.props.actors.evm.ask(UseragentListGrants {}).await {
Ok(grants) => Ok(Response::ListGrants(grants)),
Ok(grants) => Ok(grants),
Err(err) => {
error!(?err, "EVM grant list failed");
Err(TransportResponseError::KeyHolderActorUnreachable)
Err(Error::internal("Failed to list EVM grants"))
}
}
}
async fn handle_grant_create(
#[message]
pub(crate) async fn handle_grant_create(
&mut self,
client_id: i32,
basic: crate::evm::policies::SharedGrantSettings,
grant: crate::evm::policies::SpecificGrant,
) -> Output {
) -> Result<i32, Error> {
match self
.props
.actors
@@ -329,15 +330,16 @@ impl UserAgentSession {
})
.await
{
Ok(grant_id) => Ok(Response::EvmGrantCreate(Ok(grant_id))),
Ok(grant_id) => Ok(grant_id),
Err(err) => {
error!(?err, "EVM grant create failed");
Err(TransportResponseError::KeyHolderActorUnreachable)
Err(Error::internal("Failed to create EVM grant"))
}
}
}
async fn handle_grant_delete(&mut self, grant_id: i32) -> Output {
#[message]
pub(crate) async fn handle_grant_delete(&mut self, grant_id: i32) -> Result<(), Error> {
match self
.props
.actors
@@ -345,10 +347,10 @@ impl UserAgentSession {
.ask(UseragentDeleteGrant { grant_id })
.await
{
Ok(()) => Ok(Response::EvmGrantDelete(Ok(()))),
Ok(()) => Ok(()),
Err(err) => {
error!(?err, "EVM grant delete failed");
Err(TransportResponseError::KeyHolderActorUnreachable)
Err(Error::internal("Failed to delete EVM grant"))
}
}
}

View File

@@ -44,6 +44,14 @@ pub enum DatabaseSetupError {
Pool(#[from] PoolInitError),
}
#[derive(Error, Debug)]
pub enum DatabaseError {
#[error("Database connection error")]
Pool(#[from] PoolError),
#[error("Database query error")]
Connection(#[from] diesel::result::Error),
}
#[tracing::instrument(level = "info")]
fn database_path() -> Result<std::path::PathBuf, DatabaseSetupError> {
let arbiter_home = arbiter_proto::home_path().map_err(DatabaseSetupError::HomeDir)?;

View File

@@ -1,14 +1,13 @@
use arbiter_proto::{
proto::client::{
AuthChallenge as ProtoAuthChallenge,
AuthChallengeRequest as ProtoAuthChallengeRequest,
AuthChallenge as ProtoAuthChallenge, AuthChallengeRequest as ProtoAuthChallengeRequest,
AuthChallengeSolution as ProtoAuthChallengeSolution, AuthOk as ProtoAuthOk,
ClientConnectError, ClientRequest, ClientResponse,
client_connect_error::Code as ProtoClientConnectErrorCode,
client_request::Payload as ClientRequestPayload,
client_response::Payload as ClientResponsePayload,
},
transport::{Bi, Error as TransportError},
transport::{Bi, Error as TransportError, Sender},
};
use async_trait::async_trait;
use futures::StreamExt as _;
@@ -37,9 +36,9 @@ impl GrpcTransport {
Some(ClientRequestPayload::AuthChallengeRequest(ProtoAuthChallengeRequest {
pubkey,
})) => Ok(DomainRequest::AuthChallengeRequest { pubkey }),
Some(ClientRequestPayload::AuthChallengeSolution(
ProtoAuthChallengeSolution { signature },
)) => Ok(DomainRequest::AuthChallengeSolution { signature }),
Some(ClientRequestPayload::AuthChallengeSolution(ProtoAuthChallengeSolution {
signature,
})) => Ok(DomainRequest::AuthChallengeSolution { signature }),
None => Err(Status::invalid_argument("Missing client request payload")),
}
}
@@ -86,8 +85,11 @@ impl GrpcTransport {
}
#[async_trait]
impl Bi<DomainRequest, Result<DomainResponse, ClientError>> for GrpcTransport {
async fn send(&mut self, item: Result<DomainResponse, ClientError>) -> Result<(), TransportError> {
impl Sender<Result<DomainResponse, ClientError>> for GrpcTransport {
async fn send(
&mut self,
item: Result<DomainResponse, ClientError>,
) -> Result<(), TransportError> {
let outbound = match item {
Ok(message) => Ok(Self::response_to_proto(message)),
Err(err) => Err(Self::error_to_status(err)),
@@ -98,7 +100,10 @@ impl Bi<DomainRequest, Result<DomainResponse, ClientError>> for GrpcTransport {
.await
.map_err(|_| TransportError::ChannelClosed)
}
}
#[async_trait]
impl Bi<DomainRequest, Result<DomainResponse, ClientError>> for GrpcTransport {
async fn recv(&mut self) -> Option<DomainRequest> {
match self.receiver.next().await {
Some(Ok(item)) => match Self::request_to_domain(item) {

View File

@@ -1,7 +1,9 @@
use arbiter_proto::proto::{
client::{ClientRequest, ClientResponse},
user_agent::{UserAgentRequest, UserAgentResponse},
use arbiter_proto::{
proto::{
client::{ClientRequest, ClientResponse},
user_agent::{UserAgentRequest, UserAgentResponse},
},
transport::grpc::GrpcBi,
};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
@@ -10,7 +12,11 @@ use tracing::info;
use crate::{
DEFAULT_CHANNEL_SIZE,
actors::{client::{ClientConnection, connect_client}, user_agent::{UserAgentConnection, connect_user_agent}},
actors::{
client::{ClientConnection, connect_client},
user_agent::UserAgentConnection,
},
grpc::{self, user_agent::start},
};
pub mod client;
@@ -48,18 +54,19 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for super::Ser
request: Request<tonic::Streaming<UserAgentRequest>>,
) -> Result<Response<Self::UserAgentStream>, Status> {
let req_stream = request.into_inner();
let (tx, rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);
let transport = user_agent::GrpcTransport::new(tx, req_stream);
let props = UserAgentConnection::new(
self.context.db.clone(),
self.context.actors.clone(),
Box::new(transport),
);
tokio::spawn(connect_user_agent(props));
let (bi, rx) = GrpcBi::from_bi_stream(req_stream);
tokio::spawn(start(
UserAgentConnection {
db: self.context.db.clone(),
actors: self.context.actors.clone(),
},
bi,
));
info!(event = "connection established", "grpc.user_agent");
Ok(Response::new(ReceiverStream::new(rx)))
Ok(Response::new(rx))
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,151 @@
use arbiter_proto::{
proto::{
self,
evm::{
EtherTransferSettings as ProtoEtherTransferSettings, EvmError as ProtoEvmError,
EvmGrantCreateRequest, EvmGrantCreateResponse, EvmGrantDeleteRequest,
EvmGrantDeleteResponse, EvmGrantList, EvmGrantListResponse, GrantEntry,
SharedSettings as ProtoSharedSettings, SpecificGrant as ProtoSpecificGrant,
TokenTransferSettings as ProtoTokenTransferSettings,
VolumeRateLimit as ProtoVolumeRateLimit, WalletCreateResponse, WalletEntry, WalletList,
WalletListResponse, evm_grant_create_response::Result as EvmGrantCreateResult,
evm_grant_delete_response::Result as EvmGrantDeleteResult,
evm_grant_list_response::Result as EvmGrantListResult,
specific_grant::Grant as ProtoSpecificGrantType,
wallet_create_response::Result as WalletCreateResult,
wallet_list_response::Result as WalletListResult,
},
user_agent::{
AuthChallenge as ProtoAuthChallenge, AuthChallengeRequest as ProtoAuthChallengeRequest,
AuthChallengeSolution as ProtoAuthChallengeSolution, AuthResult as ProtoAuthResult,
BootstrapEncryptedKey as ProtoBootstrapEncryptedKey,
BootstrapResult as ProtoBootstrapResult, ClientConnectionCancel,
ClientConnectionRequest, ClientConnectionResponse, KeyType as ProtoKeyType,
UnsealEncryptedKey as ProtoUnsealEncryptedKey, UnsealResult as ProtoUnsealResult,
UnsealStart, UnsealStartResponse, UserAgentRequest, UserAgentResponse,
VaultState as ProtoVaultState, user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload,
},
},
transport::{Bi, Error as TransportError, Receiver, Sender, grpc::GrpcBi},
};
use async_trait::async_trait;
use tonic::{Status, Streaming};
use tracing::{info, warn};
use crate::{
actors::user_agent::{
self, AuthPublicKey, OutOfBand as DomainResponse, UserAgentConnection, auth,
},
db::models::KeyType,
evm::policies::{
Grant, SharedGrantSettings, SpecificGrant, TransactionRateLimit, VolumeRateLimit,
ether_transfer, token_transfers,
},
};
use alloy::primitives::{Address, U256};
use chrono::{DateTime, TimeZone, Utc};
pub struct AuthTransportAdapter<'a>(&'a mut GrpcBi<UserAgentRequest, UserAgentResponse>);
#[async_trait]
impl Sender<Result<auth::Outbound, auth::Error>> for AuthTransportAdapter<'_> {
async fn send(
&mut self,
item: Result<auth::Outbound, auth::Error>,
) -> Result<(), TransportError> {
use auth::{Error, Outbound};
let response = match item {
Ok(Outbound::AuthChallenge { nonce }) => Ok(UserAgentResponsePayload::AuthChallenge(
ProtoAuthChallenge { nonce },
)),
Ok(Outbound::AuthSuccess) => Ok(UserAgentResponsePayload::AuthResult(
ProtoAuthResult::Success.into(),
)),
Err(Error::UnregisteredPublicKey) => Ok(UserAgentResponsePayload::AuthResult(
ProtoAuthResult::InvalidKey.into(),
)),
Err(Error::InvalidChallengeSolution) => Ok(UserAgentResponsePayload::AuthResult(
ProtoAuthResult::InvalidSignature.into(),
)),
Err(Error::InvalidBootstrapToken) => Ok(UserAgentResponsePayload::BootstrapResult(
ProtoAuthResult::TokenInvalid.into(),
)),
Err(Error::Internal { details }) => Err(Status::internal(details)),
Err(Error::Transport) => Err(Status::unavailable("transport error")),
};
self.0
.send(response.map(|r| UserAgentResponse { payload: Some(r) }))
.await
}
}
#[async_trait]
impl Receiver<auth::Inbound> for AuthTransportAdapter<'_> {
async fn recv(&mut self) -> Option<auth::Inbound> {
let Ok(UserAgentRequest {
payload: Some(payload),
}) = self.0.recv().await?
else {
warn!(
event = "received request with empty payload",
"grpc.useragent.auth_adapter"
);
return None;
};
match payload {
UserAgentRequestPayload::AuthChallengeRequest(ProtoAuthChallengeRequest {
pubkey,
bootstrap_token,
key_type,
}) => {
let Ok(key_type) = ProtoKeyType::try_from(key_type) else {
warn!(
event = "received request with invalid key type",
"grpc.useragent.auth_adapter"
);
return None;
};
let key_type = match key_type {
ProtoKeyType::Ed25519 => KeyType::Ed25519,
ProtoKeyType::EcdsaSecp256k1 => KeyType::EcdsaSecp256k1,
ProtoKeyType::Rsa => KeyType::Rsa,
ProtoKeyType::Unspecified => {
warn!(
event = "received request with unspecified key type",
"grpc.useragent.auth_adapter"
);
return None;
}
};
let Ok(pubkey) = AuthPublicKey::try_from((key_type, pubkey)) else {
warn!(
event = "received request with invalid public key",
"grpc.useragent.auth_adapter"
);
return None;
};
Some(auth::Inbound::AuthChallengeRequest {
pubkey,
bootstrap_token,
})
}
UserAgentRequestPayload::AuthChallengeSolution(ProtoAuthChallengeSolution {
signature,
}) => Some(auth::Inbound::AuthChallengeSolution { signature }),
_ => None, // Ignore other request types for this adapter
}
}
}
impl Bi<auth::Inbound, Result<auth::Outbound, auth::Error>> for AuthTransportAdapter<'_> {}
pub async fn start(
conn: &mut UserAgentConnection,
bi: &mut GrpcBi<UserAgentRequest, UserAgentResponse>,
) -> Result<AuthPublicKey, auth::Error> {
let mut transport = AuthTransportAdapter(bi);
auth::authenticate(conn, transport).await
}

View File

@@ -13,6 +13,7 @@ pub mod db;
pub mod evm;
pub mod grpc;
pub mod safe_cell;
pub mod utils;
const DEFAULT_CHANNEL_SIZE: usize = 1000;

View File

@@ -0,0 +1,16 @@
struct DeferClosure<F: FnOnce()> {
f: Option<F>,
}
impl<F: FnOnce()> Drop for DeferClosure<F> {
fn drop(&mut self) {
if let Some(f) = self.f.take() {
f();
}
}
}
// Run some code when a scope is exited, similar to Go's defer statement
pub fn defer<F: FnOnce()>(f: F) -> impl Drop + Sized {
DeferClosure { f: Some(f) }
}

View File

@@ -3,7 +3,7 @@ use arbiter_server::{
actors::{
GlobalActors,
bootstrap::GetToken,
user_agent::{AuthPublicKey, Request, Response, UserAgentConnection, connect_user_agent},
user_agent::{AuthPublicKey, Request, OutOfBand, UserAgentConnection, connect_user_agent},
},
db::{self, schema},
};
@@ -118,7 +118,7 @@ pub async fn test_challenge_auth() {
.expect("should receive challenge");
let challenge = match response {
Ok(resp) => match resp {
Response::AuthChallenge { nonce } => nonce,
OutOfBand::AuthChallenge { nonce } => nonce,
other => panic!("Expected AuthChallenge, got {other:?}"),
},
Err(err) => panic!("Expected Ok response, got Err({err:?})"),

View File

@@ -2,7 +2,7 @@ use arbiter_server::{
actors::{
GlobalActors,
keyholder::{Bootstrap, Seal},
user_agent::{Request, Response, UnsealError, session::UserAgentSession},
user_agent::{Request, OutOfBand, UnsealError, session::UserAgentSession},
},
db,
safe_cell::{SafeCell, SafeCellHandle as _},
@@ -40,7 +40,7 @@ async fn client_dh_encrypt(user_agent: &mut UserAgentSession, key_to_send: &[u8]
.unwrap();
let server_pubkey = match response {
Response::UnsealStartResponse { server_pubkey } => server_pubkey,
OutOfBand::UnsealStartResponse { server_pubkey } => server_pubkey,
other => panic!("Expected UnsealStartResponse, got {other:?}"),
};
@@ -73,7 +73,7 @@ pub async fn test_unseal_success() {
.await
.unwrap();
assert!(matches!(response, Response::UnsealResult(Ok(()))));
assert!(matches!(response, OutOfBand::UnsealResult(Ok(()))));
}
#[tokio::test]
@@ -90,7 +90,7 @@ pub async fn test_unseal_wrong_seal_key() {
assert!(matches!(
response,
Response::UnsealResult(Err(UnsealError::InvalidKey))
OutOfBand::UnsealResult(Err(UnsealError::InvalidKey))
));
}
@@ -120,7 +120,7 @@ pub async fn test_unseal_corrupted_ciphertext() {
assert!(matches!(
response,
Response::UnsealResult(Err(UnsealError::InvalidKey))
OutOfBand::UnsealResult(Err(UnsealError::InvalidKey))
));
}
@@ -140,7 +140,7 @@ pub async fn test_unseal_retry_after_invalid_key() {
assert!(matches!(
response,
Response::UnsealResult(Err(UnsealError::InvalidKey))
OutOfBand::UnsealResult(Err(UnsealError::InvalidKey))
));
}
@@ -152,6 +152,6 @@ pub async fn test_unseal_retry_after_invalid_key() {
.await
.unwrap();
assert!(matches!(response, Response::UnsealResult(Ok(()))));
assert!(matches!(response, OutOfBand::UnsealResult(Ok(()))));
}
}