refactor(transport): convert Bi trait to use async_trait

This commit is contained in:
hdbg
2026-03-01 13:11:15 +01:00
parent 4b4a8f4489
commit 657f47e32f
9 changed files with 40 additions and 58 deletions

2
server/Cargo.lock generated
View File

@@ -59,6 +59,7 @@ version = "0.1.0"
name = "arbiter-proto"
version = "0.1.0"
dependencies = [
"async-trait",
"base64",
"futures",
"hex",
@@ -122,6 +123,7 @@ name = "arbiter-useragent"
version = "0.1.0"
dependencies = [
"arbiter-proto",
"async-trait",
"ed25519-dalek",
"http",
"kameo",

View File

@@ -20,6 +20,7 @@ rustls-pki-types.workspace = true
base64 = "0.22.1"
prost-types.workspace = true
tracing.workspace = true
async-trait.workspace = true
[build-dependencies]
tonic-prost-build = "0.14.3"

View File

@@ -76,6 +76,8 @@
use std::marker::PhantomData;
use async_trait::async_trait;
/// Errors returned by transport adapters implementing [`Bi`].
pub enum Error {
/// The outbound side of the transport is no longer accepting messages.
@@ -87,13 +89,11 @@ pub enum Error {
/// `Bi<Inbound, Outbound>` models a duplex channel with:
/// - inbound items of type `Inbound` read via [`Bi::recv`]
/// - outbound items of type `Outbound` written via [`Bi::send`]
#[async_trait]
pub trait Bi<Inbound, Outbound>: Send + Sync + 'static {
fn send(
&mut self,
item: Outbound,
) -> impl std::future::Future<Output = Result<(), Error>> + Send;
async fn send(&mut self, item: Outbound) -> Result<(), Error>;
fn recv(&mut self) -> impl std::future::Future<Output = Option<Inbound>> + Send;
async fn recv(&mut self) -> Option<Inbound>;
}
/// Converts transport-facing inbound items into protocol-facing inbound items.
@@ -176,6 +176,7 @@ where
/// gRPC-specific transport adapters and helpers.
pub mod grpc {
use async_trait::async_trait;
use futures::StreamExt;
use tokio::sync::mpsc;
use tonic::Streaming;
@@ -199,7 +200,6 @@ pub mod grpc {
outbound_converter: OutboundConverter,
}
impl<InboundTransport, Inbound, InboundConverter, OutboundConverter>
GrpcAdapter<InboundConverter, OutboundConverter>
where
@@ -221,7 +221,7 @@ pub mod grpc {
}
}
#[async_trait]
impl<InboundConverter, OutboundConverter> Bi<InboundConverter::Output, OutboundConverter::Input>
for GrpcAdapter<InboundConverter, OutboundConverter>
where
@@ -275,6 +275,7 @@ impl<Inbound, Outbound> Default for DummyTransport<Inbound, Outbound> {
}
}
#[async_trait]
impl<Inbound, Outbound> Bi<Inbound, Outbound> for DummyTransport<Inbound, Outbound>
where
Inbound: Send + Sync + 'static,
@@ -284,10 +285,8 @@ where
Ok(())
}
fn recv(&mut self) -> impl std::future::Future<Output = Option<Inbound>> + Send {
async {
async fn recv(&mut self) -> Option<Inbound> {
std::future::pending::<()>().await;
None
}
}
}

View File

@@ -1,8 +1,7 @@
use arbiter_proto::{
proto::client::{
AuthChallenge, AuthChallengeRequest, AuthChallengeSolution, AuthOk, ClientRequest,
ClientResponse,
client_request::Payload as ClientRequestPayload,
ClientResponse, client_request::Payload as ClientRequestPayload,
client_response::Payload as ClientResponsePayload,
},
transport::{Bi, DummyTransport},
@@ -50,19 +49,15 @@ pub enum ClientError {
DatabaseOperationFailed,
}
pub struct ClientActor<Transport>
where
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
{
pub type Transport = Box<dyn Bi<ClientRequest, Result<ClientResponse, ClientError>> + Send>;
pub struct ClientActor {
db: db::DatabasePool,
state: ClientStateMachine<DummyContext>,
transport: Transport,
}
impl<Transport> ClientActor<Transport>
where
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
{
impl ClientActor {
pub(crate) fn new(context: ServerContext, transport: Transport) -> Self {
Self {
db: context.db.clone(),
@@ -197,10 +192,7 @@ where
Ok((valid, challenge_context))
}
async fn handle_auth_challenge_solution(
&mut self,
solution: AuthChallengeSolution,
) -> Output {
async fn handle_auth_challenge_solution(&mut self, solution: AuthChallengeSolution) -> Output {
let (valid, challenge_context) = self.verify_challenge_solution(&solution)?;
if valid {
@@ -226,10 +218,7 @@ fn response(payload: ClientResponsePayload) -> ClientResponse {
}
}
impl<Transport> Actor for ClientActor<Transport>
where
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
{
impl Actor for ClientActor {
type Args = Self;
type Error = ();
@@ -278,12 +267,12 @@ where
}
}
impl ClientActor<DummyTransport<ClientRequest, Result<ClientResponse, ClientError>>> {
impl ClientActor {
pub fn new_manual(db: db::DatabasePool) -> Self {
Self {
db,
state: ClientStateMachine::new(DummyContext),
transport: DummyTransport::new(),
transport: Box::new(DummyTransport::new()),
}
}
}

View File

@@ -71,9 +71,9 @@ pub enum UserAgentError {
DatabaseOperationFailed,
}
pub struct UserAgentActor<Transport>
where
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
pub type Transport = Box<dyn Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>> + Send>;
pub struct UserAgentActor
{
db: db::DatabasePool,
actors: GlobalActors,
@@ -81,10 +81,7 @@ where
transport: Transport,
}
impl<Transport> UserAgentActor<Transport>
where
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
{
impl UserAgentActor {
pub(crate) fn new(context: ServerContext, transport: Transport) -> Self {
Self {
db: context.db.clone(),
@@ -265,10 +262,7 @@ fn response(payload: UserAgentResponsePayload) -> UserAgentResponse {
}
}
impl<Transport> UserAgentActor<Transport>
where
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
{
impl UserAgentActor {
async fn handle_unseal_request(&mut self, req: UnsealStart) -> Output {
let secret = EphemeralSecret::random();
let public_key = PublicKey::from(&secret);
@@ -413,10 +407,7 @@ where
}
impl<Transport> Actor for UserAgentActor<Transport>
where
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
{
impl Actor for UserAgentActor {
type Args = Self;
type Error = ();
@@ -466,13 +457,13 @@ where
}
impl UserAgentActor<DummyTransport<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>> {
impl UserAgentActor {
pub fn new_manual(db: db::DatabasePool, actors: GlobalActors) -> Self {
Self {
db,
actors,
state: UserAgentStateMachine::new(DummyContext),
transport: DummyTransport::new(),
transport: Box::new(DummyTransport::new()),
}
}
}

View File

@@ -170,7 +170,7 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server {
IdentityRecvConverter::<ClientRequest>::new(),
ClientGrpcSender,
);
ClientActor::spawn(ClientActor::new(self.context.clone(), transport));
ClientActor::spawn(ClientActor::new(self.context.clone(), Box::new(transport)));
info!(event = "connection established", "grpc.client");
@@ -191,7 +191,7 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server {
IdentityRecvConverter::<UserAgentRequest>::new(),
UserAgentGrpcSender,
);
UserAgentActor::spawn(UserAgentActor::new(self.context.clone(), transport));
UserAgentActor::spawn(UserAgentActor::new(self.context.clone(), Box::new(transport)));
info!(event = "connection established", "grpc.user_agent");

View File

@@ -1,10 +1,9 @@
use arbiter_proto::proto::user_agent::{
AuthChallengeRequest, UnsealEncryptedKey, UnsealResult, UnsealStart,
UserAgentRequest, UserAgentResponse,
UserAgentRequest,
user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload,
};
use arbiter_proto::transport::DummyTransport;
use arbiter_server::{
actors::{
GlobalActors,
@@ -18,14 +17,12 @@ use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit};
use memsafe::MemSafe;
use x25519_dalek::{EphemeralSecret, PublicKey};
type TestUserAgent =
UserAgentActor<DummyTransport<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>>;
async fn setup_authenticated_user_agent(
seal_key: &[u8],
) -> (
arbiter_server::db::DatabasePool,
TestUserAgent,
UserAgentActor,
) {
let db = db::create_test_pool().await;
@@ -59,7 +56,7 @@ async fn setup_authenticated_user_agent(
}
async fn client_dh_encrypt(
user_agent: &mut TestUserAgent,
user_agent: &mut UserAgentActor,
key_to_send: &[u8],
) -> UnsealEncryptedKey {
let client_secret = EphemeralSecret::random();

View File

@@ -18,3 +18,4 @@ thiserror.workspace = true
tokio-stream.workspace = true
http = "1.4.0"
rustls-webpki = { version = "0.103.9", features = ["aws-lc-rs"] }
async-trait.workspace = true

View File

@@ -13,12 +13,14 @@ use ed25519_dalek::SigningKey;
use kameo::actor::Spawn;
use tokio::sync::mpsc;
use tokio::time::{Duration, timeout};
use async_trait::async_trait;
struct TestTransport {
inbound_rx: mpsc::Receiver<UserAgentResponse>,
outbound_tx: mpsc::Sender<UserAgentRequest>,
}
#[async_trait]
impl Bi<UserAgentResponse, UserAgentRequest> for TestTransport {
async fn send(&mut self, item: UserAgentRequest) -> Result<(), arbiter_proto::transport::Error> {
self.outbound_tx