refactor(transport): convert Bi trait to use async_trait
This commit is contained in:
2
server/Cargo.lock
generated
2
server/Cargo.lock
generated
@@ -59,6 +59,7 @@ version = "0.1.0"
|
||||
name = "arbiter-proto"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"base64",
|
||||
"futures",
|
||||
"hex",
|
||||
@@ -122,6 +123,7 @@ name = "arbiter-useragent"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"arbiter-proto",
|
||||
"async-trait",
|
||||
"ed25519-dalek",
|
||||
"http",
|
||||
"kameo",
|
||||
|
||||
@@ -20,6 +20,7 @@ rustls-pki-types.workspace = true
|
||||
base64 = "0.22.1"
|
||||
prost-types.workspace = true
|
||||
tracing.workspace = true
|
||||
async-trait.workspace = true
|
||||
|
||||
[build-dependencies]
|
||||
tonic-prost-build = "0.14.3"
|
||||
|
||||
@@ -76,6 +76,8 @@
|
||||
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
/// Errors returned by transport adapters implementing [`Bi`].
|
||||
pub enum Error {
|
||||
/// The outbound side of the transport is no longer accepting messages.
|
||||
@@ -87,13 +89,11 @@ pub enum Error {
|
||||
/// `Bi<Inbound, Outbound>` models a duplex channel with:
|
||||
/// - inbound items of type `Inbound` read via [`Bi::recv`]
|
||||
/// - outbound items of type `Outbound` written via [`Bi::send`]
|
||||
#[async_trait]
|
||||
pub trait Bi<Inbound, Outbound>: Send + Sync + 'static {
|
||||
fn send(
|
||||
&mut self,
|
||||
item: Outbound,
|
||||
) -> impl std::future::Future<Output = Result<(), Error>> + Send;
|
||||
async fn send(&mut self, item: Outbound) -> Result<(), Error>;
|
||||
|
||||
fn recv(&mut self) -> impl std::future::Future<Output = Option<Inbound>> + Send;
|
||||
async fn recv(&mut self) -> Option<Inbound>;
|
||||
}
|
||||
|
||||
/// Converts transport-facing inbound items into protocol-facing inbound items.
|
||||
@@ -176,6 +176,7 @@ where
|
||||
|
||||
/// gRPC-specific transport adapters and helpers.
|
||||
pub mod grpc {
|
||||
use async_trait::async_trait;
|
||||
use futures::StreamExt;
|
||||
use tokio::sync::mpsc;
|
||||
use tonic::Streaming;
|
||||
@@ -199,7 +200,6 @@ pub mod grpc {
|
||||
outbound_converter: OutboundConverter,
|
||||
}
|
||||
|
||||
|
||||
impl<InboundTransport, Inbound, InboundConverter, OutboundConverter>
|
||||
GrpcAdapter<InboundConverter, OutboundConverter>
|
||||
where
|
||||
@@ -221,7 +221,7 @@ pub mod grpc {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[async_trait]
|
||||
impl<InboundConverter, OutboundConverter> Bi<InboundConverter::Output, OutboundConverter::Input>
|
||||
for GrpcAdapter<InboundConverter, OutboundConverter>
|
||||
where
|
||||
@@ -275,6 +275,7 @@ impl<Inbound, Outbound> Default for DummyTransport<Inbound, Outbound> {
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<Inbound, Outbound> Bi<Inbound, Outbound> for DummyTransport<Inbound, Outbound>
|
||||
where
|
||||
Inbound: Send + Sync + 'static,
|
||||
@@ -284,10 +285,8 @@ where
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn recv(&mut self) -> impl std::future::Future<Output = Option<Inbound>> + Send {
|
||||
async {
|
||||
async fn recv(&mut self) -> Option<Inbound> {
|
||||
std::future::pending::<()>().await;
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
use arbiter_proto::{
|
||||
proto::client::{
|
||||
AuthChallenge, AuthChallengeRequest, AuthChallengeSolution, AuthOk, ClientRequest,
|
||||
ClientResponse,
|
||||
client_request::Payload as ClientRequestPayload,
|
||||
ClientResponse, client_request::Payload as ClientRequestPayload,
|
||||
client_response::Payload as ClientResponsePayload,
|
||||
},
|
||||
transport::{Bi, DummyTransport},
|
||||
@@ -50,19 +49,15 @@ pub enum ClientError {
|
||||
DatabaseOperationFailed,
|
||||
}
|
||||
|
||||
pub struct ClientActor<Transport>
|
||||
where
|
||||
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
|
||||
{
|
||||
pub type Transport = Box<dyn Bi<ClientRequest, Result<ClientResponse, ClientError>> + Send>;
|
||||
|
||||
pub struct ClientActor {
|
||||
db: db::DatabasePool,
|
||||
state: ClientStateMachine<DummyContext>,
|
||||
transport: Transport,
|
||||
}
|
||||
|
||||
impl<Transport> ClientActor<Transport>
|
||||
where
|
||||
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
|
||||
{
|
||||
impl ClientActor {
|
||||
pub(crate) fn new(context: ServerContext, transport: Transport) -> Self {
|
||||
Self {
|
||||
db: context.db.clone(),
|
||||
@@ -197,10 +192,7 @@ where
|
||||
Ok((valid, challenge_context))
|
||||
}
|
||||
|
||||
async fn handle_auth_challenge_solution(
|
||||
&mut self,
|
||||
solution: AuthChallengeSolution,
|
||||
) -> Output {
|
||||
async fn handle_auth_challenge_solution(&mut self, solution: AuthChallengeSolution) -> Output {
|
||||
let (valid, challenge_context) = self.verify_challenge_solution(&solution)?;
|
||||
|
||||
if valid {
|
||||
@@ -226,10 +218,7 @@ fn response(payload: ClientResponsePayload) -> ClientResponse {
|
||||
}
|
||||
}
|
||||
|
||||
impl<Transport> Actor for ClientActor<Transport>
|
||||
where
|
||||
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
|
||||
{
|
||||
impl Actor for ClientActor {
|
||||
type Args = Self;
|
||||
|
||||
type Error = ();
|
||||
@@ -278,12 +267,12 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientActor<DummyTransport<ClientRequest, Result<ClientResponse, ClientError>>> {
|
||||
impl ClientActor {
|
||||
pub fn new_manual(db: db::DatabasePool) -> Self {
|
||||
Self {
|
||||
db,
|
||||
state: ClientStateMachine::new(DummyContext),
|
||||
transport: DummyTransport::new(),
|
||||
transport: Box::new(DummyTransport::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,9 +71,9 @@ pub enum UserAgentError {
|
||||
DatabaseOperationFailed,
|
||||
}
|
||||
|
||||
pub struct UserAgentActor<Transport>
|
||||
where
|
||||
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
|
||||
pub type Transport = Box<dyn Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>> + Send>;
|
||||
|
||||
pub struct UserAgentActor
|
||||
{
|
||||
db: db::DatabasePool,
|
||||
actors: GlobalActors,
|
||||
@@ -81,10 +81,7 @@ where
|
||||
transport: Transport,
|
||||
}
|
||||
|
||||
impl<Transport> UserAgentActor<Transport>
|
||||
where
|
||||
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
|
||||
{
|
||||
impl UserAgentActor {
|
||||
pub(crate) fn new(context: ServerContext, transport: Transport) -> Self {
|
||||
Self {
|
||||
db: context.db.clone(),
|
||||
@@ -265,10 +262,7 @@ fn response(payload: UserAgentResponsePayload) -> UserAgentResponse {
|
||||
}
|
||||
}
|
||||
|
||||
impl<Transport> UserAgentActor<Transport>
|
||||
where
|
||||
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
|
||||
{
|
||||
impl UserAgentActor {
|
||||
async fn handle_unseal_request(&mut self, req: UnsealStart) -> Output {
|
||||
let secret = EphemeralSecret::random();
|
||||
let public_key = PublicKey::from(&secret);
|
||||
@@ -413,10 +407,7 @@ where
|
||||
}
|
||||
|
||||
|
||||
impl<Transport> Actor for UserAgentActor<Transport>
|
||||
where
|
||||
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
|
||||
{
|
||||
impl Actor for UserAgentActor {
|
||||
type Args = Self;
|
||||
|
||||
type Error = ();
|
||||
@@ -466,13 +457,13 @@ where
|
||||
}
|
||||
|
||||
|
||||
impl UserAgentActor<DummyTransport<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>> {
|
||||
impl UserAgentActor {
|
||||
pub fn new_manual(db: db::DatabasePool, actors: GlobalActors) -> Self {
|
||||
Self {
|
||||
db,
|
||||
actors,
|
||||
state: UserAgentStateMachine::new(DummyContext),
|
||||
transport: DummyTransport::new(),
|
||||
transport: Box::new(DummyTransport::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -170,7 +170,7 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server {
|
||||
IdentityRecvConverter::<ClientRequest>::new(),
|
||||
ClientGrpcSender,
|
||||
);
|
||||
ClientActor::spawn(ClientActor::new(self.context.clone(), transport));
|
||||
ClientActor::spawn(ClientActor::new(self.context.clone(), Box::new(transport)));
|
||||
|
||||
info!(event = "connection established", "grpc.client");
|
||||
|
||||
@@ -191,7 +191,7 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server {
|
||||
IdentityRecvConverter::<UserAgentRequest>::new(),
|
||||
UserAgentGrpcSender,
|
||||
);
|
||||
UserAgentActor::spawn(UserAgentActor::new(self.context.clone(), transport));
|
||||
UserAgentActor::spawn(UserAgentActor::new(self.context.clone(), Box::new(transport)));
|
||||
|
||||
info!(event = "connection established", "grpc.user_agent");
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
use arbiter_proto::proto::user_agent::{
|
||||
AuthChallengeRequest, UnsealEncryptedKey, UnsealResult, UnsealStart,
|
||||
UserAgentRequest, UserAgentResponse,
|
||||
UserAgentRequest,
|
||||
user_agent_request::Payload as UserAgentRequestPayload,
|
||||
user_agent_response::Payload as UserAgentResponsePayload,
|
||||
};
|
||||
use arbiter_proto::transport::DummyTransport;
|
||||
use arbiter_server::{
|
||||
actors::{
|
||||
GlobalActors,
|
||||
@@ -18,14 +17,12 @@ use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit};
|
||||
use memsafe::MemSafe;
|
||||
use x25519_dalek::{EphemeralSecret, PublicKey};
|
||||
|
||||
type TestUserAgent =
|
||||
UserAgentActor<DummyTransport<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>>;
|
||||
|
||||
async fn setup_authenticated_user_agent(
|
||||
seal_key: &[u8],
|
||||
) -> (
|
||||
arbiter_server::db::DatabasePool,
|
||||
TestUserAgent,
|
||||
UserAgentActor,
|
||||
) {
|
||||
let db = db::create_test_pool().await;
|
||||
|
||||
@@ -59,7 +56,7 @@ async fn setup_authenticated_user_agent(
|
||||
}
|
||||
|
||||
async fn client_dh_encrypt(
|
||||
user_agent: &mut TestUserAgent,
|
||||
user_agent: &mut UserAgentActor,
|
||||
key_to_send: &[u8],
|
||||
) -> UnsealEncryptedKey {
|
||||
let client_secret = EphemeralSecret::random();
|
||||
|
||||
@@ -18,3 +18,4 @@ thiserror.workspace = true
|
||||
tokio-stream.workspace = true
|
||||
http = "1.4.0"
|
||||
rustls-webpki = { version = "0.103.9", features = ["aws-lc-rs"] }
|
||||
async-trait.workspace = true
|
||||
|
||||
@@ -13,12 +13,14 @@ use ed25519_dalek::SigningKey;
|
||||
use kameo::actor::Spawn;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::{Duration, timeout};
|
||||
use async_trait::async_trait;
|
||||
|
||||
struct TestTransport {
|
||||
inbound_rx: mpsc::Receiver<UserAgentResponse>,
|
||||
outbound_tx: mpsc::Sender<UserAgentRequest>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Bi<UserAgentResponse, UserAgentRequest> for TestTransport {
|
||||
async fn send(&mut self, item: UserAgentRequest) -> Result<(), arbiter_proto::transport::Error> {
|
||||
self.outbound_tx
|
||||
|
||||
Reference in New Issue
Block a user