refactor(transport): convert Bi trait to use async_trait

This commit is contained in:
hdbg
2026-03-01 13:11:15 +01:00
parent 4b4a8f4489
commit 657f47e32f
9 changed files with 40 additions and 58 deletions

2
server/Cargo.lock generated
View File

@@ -59,6 +59,7 @@ version = "0.1.0"
name = "arbiter-proto" name = "arbiter-proto"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"async-trait",
"base64", "base64",
"futures", "futures",
"hex", "hex",
@@ -122,6 +123,7 @@ name = "arbiter-useragent"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"arbiter-proto", "arbiter-proto",
"async-trait",
"ed25519-dalek", "ed25519-dalek",
"http", "http",
"kameo", "kameo",

View File

@@ -20,6 +20,7 @@ rustls-pki-types.workspace = true
base64 = "0.22.1" base64 = "0.22.1"
prost-types.workspace = true prost-types.workspace = true
tracing.workspace = true tracing.workspace = true
async-trait.workspace = true
[build-dependencies] [build-dependencies]
tonic-prost-build = "0.14.3" tonic-prost-build = "0.14.3"

View File

@@ -76,6 +76,8 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use async_trait::async_trait;
/// Errors returned by transport adapters implementing [`Bi`]. /// Errors returned by transport adapters implementing [`Bi`].
pub enum Error { pub enum Error {
/// The outbound side of the transport is no longer accepting messages. /// The outbound side of the transport is no longer accepting messages.
@@ -87,13 +89,11 @@ pub enum Error {
/// `Bi<Inbound, Outbound>` models a duplex channel with: /// `Bi<Inbound, Outbound>` models a duplex channel with:
/// - inbound items of type `Inbound` read via [`Bi::recv`] /// - inbound items of type `Inbound` read via [`Bi::recv`]
/// - outbound items of type `Outbound` written via [`Bi::send`] /// - outbound items of type `Outbound` written via [`Bi::send`]
#[async_trait]
pub trait Bi<Inbound, Outbound>: Send + Sync + 'static { pub trait Bi<Inbound, Outbound>: Send + Sync + 'static {
fn send( async fn send(&mut self, item: Outbound) -> Result<(), Error>;
&mut self,
item: Outbound,
) -> impl std::future::Future<Output = Result<(), Error>> + Send;
fn recv(&mut self) -> impl std::future::Future<Output = Option<Inbound>> + Send; async fn recv(&mut self) -> Option<Inbound>;
} }
/// Converts transport-facing inbound items into protocol-facing inbound items. /// Converts transport-facing inbound items into protocol-facing inbound items.
@@ -176,6 +176,7 @@ where
/// gRPC-specific transport adapters and helpers. /// gRPC-specific transport adapters and helpers.
pub mod grpc { pub mod grpc {
use async_trait::async_trait;
use futures::StreamExt; use futures::StreamExt;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tonic::Streaming; use tonic::Streaming;
@@ -199,7 +200,6 @@ pub mod grpc {
outbound_converter: OutboundConverter, outbound_converter: OutboundConverter,
} }
impl<InboundTransport, Inbound, InboundConverter, OutboundConverter> impl<InboundTransport, Inbound, InboundConverter, OutboundConverter>
GrpcAdapter<InboundConverter, OutboundConverter> GrpcAdapter<InboundConverter, OutboundConverter>
where where
@@ -221,8 +221,8 @@ pub mod grpc {
} }
} }
#[async_trait]
impl< InboundConverter, OutboundConverter> Bi<InboundConverter::Output, OutboundConverter::Input> impl<InboundConverter, OutboundConverter> Bi<InboundConverter::Output, OutboundConverter::Input>
for GrpcAdapter<InboundConverter, OutboundConverter> for GrpcAdapter<InboundConverter, OutboundConverter>
where where
InboundConverter: RecvConverter, InboundConverter: RecvConverter,
@@ -275,6 +275,7 @@ impl<Inbound, Outbound> Default for DummyTransport<Inbound, Outbound> {
} }
} }
#[async_trait]
impl<Inbound, Outbound> Bi<Inbound, Outbound> for DummyTransport<Inbound, Outbound> impl<Inbound, Outbound> Bi<Inbound, Outbound> for DummyTransport<Inbound, Outbound>
where where
Inbound: Send + Sync + 'static, Inbound: Send + Sync + 'static,
@@ -284,10 +285,8 @@ where
Ok(()) Ok(())
} }
fn recv(&mut self) -> impl std::future::Future<Output = Option<Inbound>> + Send { async fn recv(&mut self) -> Option<Inbound> {
async { std::future::pending::<()>().await;
std::future::pending::<()>().await; None
None
}
} }
} }

View File

@@ -1,8 +1,7 @@
use arbiter_proto::{ use arbiter_proto::{
proto::client::{ proto::client::{
AuthChallenge, AuthChallengeRequest, AuthChallengeSolution, AuthOk, ClientRequest, AuthChallenge, AuthChallengeRequest, AuthChallengeSolution, AuthOk, ClientRequest,
ClientResponse, ClientResponse, client_request::Payload as ClientRequestPayload,
client_request::Payload as ClientRequestPayload,
client_response::Payload as ClientResponsePayload, client_response::Payload as ClientResponsePayload,
}, },
transport::{Bi, DummyTransport}, transport::{Bi, DummyTransport},
@@ -50,19 +49,15 @@ pub enum ClientError {
DatabaseOperationFailed, DatabaseOperationFailed,
} }
pub struct ClientActor<Transport> pub type Transport = Box<dyn Bi<ClientRequest, Result<ClientResponse, ClientError>> + Send>;
where
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>, pub struct ClientActor {
{
db: db::DatabasePool, db: db::DatabasePool,
state: ClientStateMachine<DummyContext>, state: ClientStateMachine<DummyContext>,
transport: Transport, transport: Transport,
} }
impl<Transport> ClientActor<Transport> impl ClientActor {
where
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
{
pub(crate) fn new(context: ServerContext, transport: Transport) -> Self { pub(crate) fn new(context: ServerContext, transport: Transport) -> Self {
Self { Self {
db: context.db.clone(), db: context.db.clone(),
@@ -197,10 +192,7 @@ where
Ok((valid, challenge_context)) Ok((valid, challenge_context))
} }
async fn handle_auth_challenge_solution( async fn handle_auth_challenge_solution(&mut self, solution: AuthChallengeSolution) -> Output {
&mut self,
solution: AuthChallengeSolution,
) -> Output {
let (valid, challenge_context) = self.verify_challenge_solution(&solution)?; let (valid, challenge_context) = self.verify_challenge_solution(&solution)?;
if valid { if valid {
@@ -226,10 +218,7 @@ fn response(payload: ClientResponsePayload) -> ClientResponse {
} }
} }
impl<Transport> Actor for ClientActor<Transport> impl Actor for ClientActor {
where
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
{
type Args = Self; type Args = Self;
type Error = (); type Error = ();
@@ -278,12 +267,12 @@ where
} }
} }
impl ClientActor<DummyTransport<ClientRequest, Result<ClientResponse, ClientError>>> { impl ClientActor {
pub fn new_manual(db: db::DatabasePool) -> Self { pub fn new_manual(db: db::DatabasePool) -> Self {
Self { Self {
db, db,
state: ClientStateMachine::new(DummyContext), state: ClientStateMachine::new(DummyContext),
transport: DummyTransport::new(), transport: Box::new(DummyTransport::new()),
} }
} }
} }

View File

@@ -71,9 +71,9 @@ pub enum UserAgentError {
DatabaseOperationFailed, DatabaseOperationFailed,
} }
pub struct UserAgentActor<Transport> pub type Transport = Box<dyn Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>> + Send>;
where
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>, pub struct UserAgentActor
{ {
db: db::DatabasePool, db: db::DatabasePool,
actors: GlobalActors, actors: GlobalActors,
@@ -81,10 +81,7 @@ where
transport: Transport, transport: Transport,
} }
impl<Transport> UserAgentActor<Transport> impl UserAgentActor {
where
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
{
pub(crate) fn new(context: ServerContext, transport: Transport) -> Self { pub(crate) fn new(context: ServerContext, transport: Transport) -> Self {
Self { Self {
db: context.db.clone(), db: context.db.clone(),
@@ -265,10 +262,7 @@ fn response(payload: UserAgentResponsePayload) -> UserAgentResponse {
} }
} }
impl<Transport> UserAgentActor<Transport> impl UserAgentActor {
where
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
{
async fn handle_unseal_request(&mut self, req: UnsealStart) -> Output { async fn handle_unseal_request(&mut self, req: UnsealStart) -> Output {
let secret = EphemeralSecret::random(); let secret = EphemeralSecret::random();
let public_key = PublicKey::from(&secret); let public_key = PublicKey::from(&secret);
@@ -413,10 +407,7 @@ where
} }
impl<Transport> Actor for UserAgentActor<Transport> impl Actor for UserAgentActor {
where
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
{
type Args = Self; type Args = Self;
type Error = (); type Error = ();
@@ -466,13 +457,13 @@ where
} }
impl UserAgentActor<DummyTransport<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>> { impl UserAgentActor {
pub fn new_manual(db: db::DatabasePool, actors: GlobalActors) -> Self { pub fn new_manual(db: db::DatabasePool, actors: GlobalActors) -> Self {
Self { Self {
db, db,
actors, actors,
state: UserAgentStateMachine::new(DummyContext), state: UserAgentStateMachine::new(DummyContext),
transport: DummyTransport::new(), transport: Box::new(DummyTransport::new()),
} }
} }
} }

View File

@@ -170,7 +170,7 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server {
IdentityRecvConverter::<ClientRequest>::new(), IdentityRecvConverter::<ClientRequest>::new(),
ClientGrpcSender, ClientGrpcSender,
); );
ClientActor::spawn(ClientActor::new(self.context.clone(), transport)); ClientActor::spawn(ClientActor::new(self.context.clone(), Box::new(transport)));
info!(event = "connection established", "grpc.client"); info!(event = "connection established", "grpc.client");
@@ -191,7 +191,7 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server {
IdentityRecvConverter::<UserAgentRequest>::new(), IdentityRecvConverter::<UserAgentRequest>::new(),
UserAgentGrpcSender, UserAgentGrpcSender,
); );
UserAgentActor::spawn(UserAgentActor::new(self.context.clone(), transport)); UserAgentActor::spawn(UserAgentActor::new(self.context.clone(), Box::new(transport)));
info!(event = "connection established", "grpc.user_agent"); info!(event = "connection established", "grpc.user_agent");

View File

@@ -1,10 +1,9 @@
use arbiter_proto::proto::user_agent::{ use arbiter_proto::proto::user_agent::{
AuthChallengeRequest, UnsealEncryptedKey, UnsealResult, UnsealStart, AuthChallengeRequest, UnsealEncryptedKey, UnsealResult, UnsealStart,
UserAgentRequest, UserAgentResponse, UserAgentRequest,
user_agent_request::Payload as UserAgentRequestPayload, user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload, user_agent_response::Payload as UserAgentResponsePayload,
}; };
use arbiter_proto::transport::DummyTransport;
use arbiter_server::{ use arbiter_server::{
actors::{ actors::{
GlobalActors, GlobalActors,
@@ -18,14 +17,12 @@ use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit};
use memsafe::MemSafe; use memsafe::MemSafe;
use x25519_dalek::{EphemeralSecret, PublicKey}; use x25519_dalek::{EphemeralSecret, PublicKey};
type TestUserAgent =
UserAgentActor<DummyTransport<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>>;
async fn setup_authenticated_user_agent( async fn setup_authenticated_user_agent(
seal_key: &[u8], seal_key: &[u8],
) -> ( ) -> (
arbiter_server::db::DatabasePool, arbiter_server::db::DatabasePool,
TestUserAgent, UserAgentActor,
) { ) {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
@@ -59,7 +56,7 @@ async fn setup_authenticated_user_agent(
} }
async fn client_dh_encrypt( async fn client_dh_encrypt(
user_agent: &mut TestUserAgent, user_agent: &mut UserAgentActor,
key_to_send: &[u8], key_to_send: &[u8],
) -> UnsealEncryptedKey { ) -> UnsealEncryptedKey {
let client_secret = EphemeralSecret::random(); let client_secret = EphemeralSecret::random();

View File

@@ -18,3 +18,4 @@ thiserror.workspace = true
tokio-stream.workspace = true tokio-stream.workspace = true
http = "1.4.0" http = "1.4.0"
rustls-webpki = { version = "0.103.9", features = ["aws-lc-rs"] } rustls-webpki = { version = "0.103.9", features = ["aws-lc-rs"] }
async-trait.workspace = true

View File

@@ -13,12 +13,14 @@ use ed25519_dalek::SigningKey;
use kameo::actor::Spawn; use kameo::actor::Spawn;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio::time::{Duration, timeout}; use tokio::time::{Duration, timeout};
use async_trait::async_trait;
struct TestTransport { struct TestTransport {
inbound_rx: mpsc::Receiver<UserAgentResponse>, inbound_rx: mpsc::Receiver<UserAgentResponse>,
outbound_tx: mpsc::Sender<UserAgentRequest>, outbound_tx: mpsc::Sender<UserAgentRequest>,
} }
#[async_trait]
impl Bi<UserAgentResponse, UserAgentRequest> for TestTransport { impl Bi<UserAgentResponse, UserAgentRequest> for TestTransport {
async fn send(&mut self, item: UserAgentRequest) -> Result<(), arbiter_proto::transport::Error> { async fn send(&mut self, item: UserAgentRequest) -> Result<(), arbiter_proto::transport::Error> {
self.outbound_tx self.outbound_tx