diff --git a/server/Cargo.lock b/server/Cargo.lock index f6ab9b4..8f72a04 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -59,6 +59,7 @@ version = "0.1.0" name = "arbiter-proto" version = "0.1.0" dependencies = [ + "async-trait", "base64", "futures", "hex", @@ -122,6 +123,7 @@ name = "arbiter-useragent" version = "0.1.0" dependencies = [ "arbiter-proto", + "async-trait", "ed25519-dalek", "http", "kameo", diff --git a/server/crates/arbiter-proto/Cargo.toml b/server/crates/arbiter-proto/Cargo.toml index 006e406..0673f8a 100644 --- a/server/crates/arbiter-proto/Cargo.toml +++ b/server/crates/arbiter-proto/Cargo.toml @@ -20,6 +20,7 @@ rustls-pki-types.workspace = true base64 = "0.22.1" prost-types.workspace = true tracing.workspace = true +async-trait.workspace = true [build-dependencies] tonic-prost-build = "0.14.3" diff --git a/server/crates/arbiter-proto/src/transport.rs b/server/crates/arbiter-proto/src/transport.rs index 48bb9a3..02ae72c 100644 --- a/server/crates/arbiter-proto/src/transport.rs +++ b/server/crates/arbiter-proto/src/transport.rs @@ -76,6 +76,8 @@ use std::marker::PhantomData; +use async_trait::async_trait; + /// Errors returned by transport adapters implementing [`Bi`]. pub enum Error { /// The outbound side of the transport is no longer accepting messages. @@ -87,13 +89,11 @@ pub enum Error { /// `Bi` models a duplex channel with: /// - inbound items of type `Inbound` read via [`Bi::recv`] /// - outbound items of type `Outbound` written via [`Bi::send`] +#[async_trait] pub trait Bi: Send + Sync + 'static { - fn send( - &mut self, - item: Outbound, - ) -> impl std::future::Future> + Send; + async fn send(&mut self, item: Outbound) -> Result<(), Error>; - fn recv(&mut self) -> impl std::future::Future> + Send; + async fn recv(&mut self) -> Option; } /// Converts transport-facing inbound items into protocol-facing inbound items. @@ -176,6 +176,7 @@ where /// gRPC-specific transport adapters and helpers. pub mod grpc { + use async_trait::async_trait; use futures::StreamExt; use tokio::sync::mpsc; use tonic::Streaming; @@ -199,7 +200,6 @@ pub mod grpc { outbound_converter: OutboundConverter, } - impl GrpcAdapter where @@ -221,8 +221,8 @@ pub mod grpc { } } - - impl< InboundConverter, OutboundConverter> Bi + #[async_trait] + impl Bi for GrpcAdapter where InboundConverter: RecvConverter, @@ -275,6 +275,7 @@ impl Default for DummyTransport { } } +#[async_trait] impl Bi for DummyTransport where Inbound: Send + Sync + 'static, @@ -284,10 +285,8 @@ where Ok(()) } - fn recv(&mut self) -> impl std::future::Future> + Send { - async { - std::future::pending::<()>().await; - None - } + async fn recv(&mut self) -> Option { + std::future::pending::<()>().await; + None } } diff --git a/server/crates/arbiter-server/src/actors/client/mod.rs b/server/crates/arbiter-server/src/actors/client/mod.rs index 8698abb..405bd54 100644 --- a/server/crates/arbiter-server/src/actors/client/mod.rs +++ b/server/crates/arbiter-server/src/actors/client/mod.rs @@ -1,8 +1,7 @@ use arbiter_proto::{ proto::client::{ AuthChallenge, AuthChallengeRequest, AuthChallengeSolution, AuthOk, ClientRequest, - ClientResponse, - client_request::Payload as ClientRequestPayload, + ClientResponse, client_request::Payload as ClientRequestPayload, client_response::Payload as ClientResponsePayload, }, transport::{Bi, DummyTransport}, @@ -50,19 +49,15 @@ pub enum ClientError { DatabaseOperationFailed, } -pub struct ClientActor -where - Transport: Bi>, -{ +pub type Transport = Box> + Send>; + +pub struct ClientActor { db: db::DatabasePool, state: ClientStateMachine, transport: Transport, } -impl ClientActor -where - Transport: Bi>, -{ +impl ClientActor { pub(crate) fn new(context: ServerContext, transport: Transport) -> Self { Self { db: context.db.clone(), @@ -197,10 +192,7 @@ where Ok((valid, challenge_context)) } - async fn handle_auth_challenge_solution( - &mut self, - solution: AuthChallengeSolution, - ) -> Output { + async fn handle_auth_challenge_solution(&mut self, solution: AuthChallengeSolution) -> Output { let (valid, challenge_context) = self.verify_challenge_solution(&solution)?; if valid { @@ -226,10 +218,7 @@ fn response(payload: ClientResponsePayload) -> ClientResponse { } } -impl Actor for ClientActor -where - Transport: Bi>, -{ +impl Actor for ClientActor { type Args = Self; type Error = (); @@ -278,12 +267,12 @@ where } } -impl ClientActor>> { +impl ClientActor { pub fn new_manual(db: db::DatabasePool) -> Self { Self { db, state: ClientStateMachine::new(DummyContext), - transport: DummyTransport::new(), + transport: Box::new(DummyTransport::new()), } } } diff --git a/server/crates/arbiter-server/src/actors/user_agent/mod.rs b/server/crates/arbiter-server/src/actors/user_agent/mod.rs index ba801ee..762ae6d 100644 --- a/server/crates/arbiter-server/src/actors/user_agent/mod.rs +++ b/server/crates/arbiter-server/src/actors/user_agent/mod.rs @@ -71,9 +71,9 @@ pub enum UserAgentError { DatabaseOperationFailed, } -pub struct UserAgentActor -where - Transport: Bi>, +pub type Transport = Box> + Send>; + +pub struct UserAgentActor { db: db::DatabasePool, actors: GlobalActors, @@ -81,10 +81,7 @@ where transport: Transport, } -impl UserAgentActor -where - Transport: Bi>, -{ +impl UserAgentActor { pub(crate) fn new(context: ServerContext, transport: Transport) -> Self { Self { db: context.db.clone(), @@ -265,10 +262,7 @@ fn response(payload: UserAgentResponsePayload) -> UserAgentResponse { } } -impl UserAgentActor -where - Transport: Bi>, -{ +impl UserAgentActor { async fn handle_unseal_request(&mut self, req: UnsealStart) -> Output { let secret = EphemeralSecret::random(); let public_key = PublicKey::from(&secret); @@ -413,10 +407,7 @@ where } -impl Actor for UserAgentActor -where - Transport: Bi>, -{ +impl Actor for UserAgentActor { type Args = Self; type Error = (); @@ -466,13 +457,13 @@ where } -impl UserAgentActor>> { +impl UserAgentActor { pub fn new_manual(db: db::DatabasePool, actors: GlobalActors) -> Self { Self { db, actors, state: UserAgentStateMachine::new(DummyContext), - transport: DummyTransport::new(), + transport: Box::new(DummyTransport::new()), } } } diff --git a/server/crates/arbiter-server/src/lib.rs b/server/crates/arbiter-server/src/lib.rs index e6cb5c5..a7b5ebe 100644 --- a/server/crates/arbiter-server/src/lib.rs +++ b/server/crates/arbiter-server/src/lib.rs @@ -170,7 +170,7 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server { IdentityRecvConverter::::new(), ClientGrpcSender, ); - ClientActor::spawn(ClientActor::new(self.context.clone(), transport)); + ClientActor::spawn(ClientActor::new(self.context.clone(), Box::new(transport))); info!(event = "connection established", "grpc.client"); @@ -191,7 +191,7 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server { IdentityRecvConverter::::new(), UserAgentGrpcSender, ); - UserAgentActor::spawn(UserAgentActor::new(self.context.clone(), transport)); + UserAgentActor::spawn(UserAgentActor::new(self.context.clone(), Box::new(transport))); info!(event = "connection established", "grpc.user_agent"); diff --git a/server/crates/arbiter-server/tests/user_agent/unseal.rs b/server/crates/arbiter-server/tests/user_agent/unseal.rs index 9128a6c..b0f5d1c 100644 --- a/server/crates/arbiter-server/tests/user_agent/unseal.rs +++ b/server/crates/arbiter-server/tests/user_agent/unseal.rs @@ -1,10 +1,9 @@ use arbiter_proto::proto::user_agent::{ AuthChallengeRequest, UnsealEncryptedKey, UnsealResult, UnsealStart, - UserAgentRequest, UserAgentResponse, + UserAgentRequest, user_agent_request::Payload as UserAgentRequestPayload, user_agent_response::Payload as UserAgentResponsePayload, }; -use arbiter_proto::transport::DummyTransport; use arbiter_server::{ actors::{ GlobalActors, @@ -18,14 +17,12 @@ use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit}; use memsafe::MemSafe; use x25519_dalek::{EphemeralSecret, PublicKey}; -type TestUserAgent = - UserAgentActor>>; async fn setup_authenticated_user_agent( seal_key: &[u8], ) -> ( arbiter_server::db::DatabasePool, - TestUserAgent, + UserAgentActor, ) { let db = db::create_test_pool().await; @@ -59,7 +56,7 @@ async fn setup_authenticated_user_agent( } async fn client_dh_encrypt( - user_agent: &mut TestUserAgent, + user_agent: &mut UserAgentActor, key_to_send: &[u8], ) -> UnsealEncryptedKey { let client_secret = EphemeralSecret::random(); diff --git a/server/crates/arbiter-useragent/Cargo.toml b/server/crates/arbiter-useragent/Cargo.toml index de46f67..8b6b85b 100644 --- a/server/crates/arbiter-useragent/Cargo.toml +++ b/server/crates/arbiter-useragent/Cargo.toml @@ -18,3 +18,4 @@ thiserror.workspace = true tokio-stream.workspace = true http = "1.4.0" rustls-webpki = { version = "0.103.9", features = ["aws-lc-rs"] } +async-trait.workspace = true diff --git a/server/crates/arbiter-useragent/tests/auth.rs b/server/crates/arbiter-useragent/tests/auth.rs index cd9dac1..8d79bbe 100644 --- a/server/crates/arbiter-useragent/tests/auth.rs +++ b/server/crates/arbiter-useragent/tests/auth.rs @@ -13,12 +13,14 @@ use ed25519_dalek::SigningKey; use kameo::actor::Spawn; use tokio::sync::mpsc; use tokio::time::{Duration, timeout}; +use async_trait::async_trait; struct TestTransport { inbound_rx: mpsc::Receiver, outbound_tx: mpsc::Sender, } +#[async_trait] impl Bi for TestTransport { async fn send(&mut self, item: UserAgentRequest) -> Result<(), arbiter_proto::transport::Error> { self.outbound_tx