refactor(transport): convert Bi trait to use async_trait

This commit is contained in:
hdbg
2026-03-01 13:11:15 +01:00
parent 4b4a8f4489
commit 657f47e32f
9 changed files with 40 additions and 58 deletions

View File

@@ -1,8 +1,7 @@
use arbiter_proto::{
proto::client::{
AuthChallenge, AuthChallengeRequest, AuthChallengeSolution, AuthOk, ClientRequest,
ClientResponse,
client_request::Payload as ClientRequestPayload,
ClientResponse, client_request::Payload as ClientRequestPayload,
client_response::Payload as ClientResponsePayload,
},
transport::{Bi, DummyTransport},
@@ -50,19 +49,15 @@ pub enum ClientError {
DatabaseOperationFailed,
}
pub struct ClientActor<Transport>
where
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
{
pub type Transport = Box<dyn Bi<ClientRequest, Result<ClientResponse, ClientError>> + Send>;
pub struct ClientActor {
db: db::DatabasePool,
state: ClientStateMachine<DummyContext>,
transport: Transport,
}
impl<Transport> ClientActor<Transport>
where
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
{
impl ClientActor {
pub(crate) fn new(context: ServerContext, transport: Transport) -> Self {
Self {
db: context.db.clone(),
@@ -197,10 +192,7 @@ where
Ok((valid, challenge_context))
}
async fn handle_auth_challenge_solution(
&mut self,
solution: AuthChallengeSolution,
) -> Output {
async fn handle_auth_challenge_solution(&mut self, solution: AuthChallengeSolution) -> Output {
let (valid, challenge_context) = self.verify_challenge_solution(&solution)?;
if valid {
@@ -226,10 +218,7 @@ fn response(payload: ClientResponsePayload) -> ClientResponse {
}
}
impl<Transport> Actor for ClientActor<Transport>
where
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
{
impl Actor for ClientActor {
type Args = Self;
type Error = ();
@@ -278,12 +267,12 @@ where
}
}
impl ClientActor<DummyTransport<ClientRequest, Result<ClientResponse, ClientError>>> {
impl ClientActor {
pub fn new_manual(db: db::DatabasePool) -> Self {
Self {
db,
state: ClientStateMachine::new(DummyContext),
transport: DummyTransport::new(),
transport: Box::new(DummyTransport::new()),
}
}
}

View File

@@ -71,9 +71,9 @@ pub enum UserAgentError {
DatabaseOperationFailed,
}
pub struct UserAgentActor<Transport>
where
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
pub type Transport = Box<dyn Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>> + Send>;
pub struct UserAgentActor
{
db: db::DatabasePool,
actors: GlobalActors,
@@ -81,10 +81,7 @@ where
transport: Transport,
}
impl<Transport> UserAgentActor<Transport>
where
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
{
impl UserAgentActor {
pub(crate) fn new(context: ServerContext, transport: Transport) -> Self {
Self {
db: context.db.clone(),
@@ -265,10 +262,7 @@ fn response(payload: UserAgentResponsePayload) -> UserAgentResponse {
}
}
impl<Transport> UserAgentActor<Transport>
where
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
{
impl UserAgentActor {
async fn handle_unseal_request(&mut self, req: UnsealStart) -> Output {
let secret = EphemeralSecret::random();
let public_key = PublicKey::from(&secret);
@@ -413,10 +407,7 @@ where
}
impl<Transport> Actor for UserAgentActor<Transport>
where
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
{
impl Actor for UserAgentActor {
type Args = Self;
type Error = ();
@@ -466,13 +457,13 @@ where
}
impl UserAgentActor<DummyTransport<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>> {
impl UserAgentActor {
pub fn new_manual(db: db::DatabasePool, actors: GlobalActors) -> Self {
Self {
db,
actors,
state: UserAgentStateMachine::new(DummyContext),
transport: DummyTransport::new(),
transport: Box::new(DummyTransport::new()),
}
}
}

View File

@@ -170,7 +170,7 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server {
IdentityRecvConverter::<ClientRequest>::new(),
ClientGrpcSender,
);
ClientActor::spawn(ClientActor::new(self.context.clone(), transport));
ClientActor::spawn(ClientActor::new(self.context.clone(), Box::new(transport)));
info!(event = "connection established", "grpc.client");
@@ -191,7 +191,7 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server {
IdentityRecvConverter::<UserAgentRequest>::new(),
UserAgentGrpcSender,
);
UserAgentActor::spawn(UserAgentActor::new(self.context.clone(), transport));
UserAgentActor::spawn(UserAgentActor::new(self.context.clone(), Box::new(transport)));
info!(event = "connection established", "grpc.user_agent");

View File

@@ -1,10 +1,9 @@
use arbiter_proto::proto::user_agent::{
AuthChallengeRequest, UnsealEncryptedKey, UnsealResult, UnsealStart,
UserAgentRequest, UserAgentResponse,
UserAgentRequest,
user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload,
};
use arbiter_proto::transport::DummyTransport;
use arbiter_server::{
actors::{
GlobalActors,
@@ -18,14 +17,12 @@ use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit};
use memsafe::MemSafe;
use x25519_dalek::{EphemeralSecret, PublicKey};
type TestUserAgent =
UserAgentActor<DummyTransport<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>>;
async fn setup_authenticated_user_agent(
seal_key: &[u8],
) -> (
arbiter_server::db::DatabasePool,
TestUserAgent,
UserAgentActor,
) {
let db = db::create_test_pool().await;
@@ -59,7 +56,7 @@ async fn setup_authenticated_user_agent(
}
async fn client_dh_encrypt(
user_agent: &mut TestUserAgent,
user_agent: &mut UserAgentActor,
key_to_send: &[u8],
) -> UnsealEncryptedKey {
let client_secret = EphemeralSecret::random();