refactor(transport): convert Bi trait to use async_trait
This commit is contained in:
2
server/Cargo.lock
generated
2
server/Cargo.lock
generated
@@ -59,6 +59,7 @@ version = "0.1.0"
|
|||||||
name = "arbiter-proto"
|
name = "arbiter-proto"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"async-trait",
|
||||||
"base64",
|
"base64",
|
||||||
"futures",
|
"futures",
|
||||||
"hex",
|
"hex",
|
||||||
@@ -122,6 +123,7 @@ name = "arbiter-useragent"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arbiter-proto",
|
"arbiter-proto",
|
||||||
|
"async-trait",
|
||||||
"ed25519-dalek",
|
"ed25519-dalek",
|
||||||
"http",
|
"http",
|
||||||
"kameo",
|
"kameo",
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ rustls-pki-types.workspace = true
|
|||||||
base64 = "0.22.1"
|
base64 = "0.22.1"
|
||||||
prost-types.workspace = true
|
prost-types.workspace = true
|
||||||
tracing.workspace = true
|
tracing.workspace = true
|
||||||
|
async-trait.workspace = true
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
tonic-prost-build = "0.14.3"
|
tonic-prost-build = "0.14.3"
|
||||||
|
|||||||
@@ -76,6 +76,8 @@
|
|||||||
|
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
/// Errors returned by transport adapters implementing [`Bi`].
|
/// Errors returned by transport adapters implementing [`Bi`].
|
||||||
pub enum Error {
|
pub enum Error {
|
||||||
/// The outbound side of the transport is no longer accepting messages.
|
/// The outbound side of the transport is no longer accepting messages.
|
||||||
@@ -87,13 +89,11 @@ pub enum Error {
|
|||||||
/// `Bi<Inbound, Outbound>` models a duplex channel with:
|
/// `Bi<Inbound, Outbound>` models a duplex channel with:
|
||||||
/// - inbound items of type `Inbound` read via [`Bi::recv`]
|
/// - inbound items of type `Inbound` read via [`Bi::recv`]
|
||||||
/// - outbound items of type `Outbound` written via [`Bi::send`]
|
/// - outbound items of type `Outbound` written via [`Bi::send`]
|
||||||
|
#[async_trait]
|
||||||
pub trait Bi<Inbound, Outbound>: Send + Sync + 'static {
|
pub trait Bi<Inbound, Outbound>: Send + Sync + 'static {
|
||||||
fn send(
|
async fn send(&mut self, item: Outbound) -> Result<(), Error>;
|
||||||
&mut self,
|
|
||||||
item: Outbound,
|
|
||||||
) -> impl std::future::Future<Output = Result<(), Error>> + Send;
|
|
||||||
|
|
||||||
fn recv(&mut self) -> impl std::future::Future<Output = Option<Inbound>> + Send;
|
async fn recv(&mut self) -> Option<Inbound>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Converts transport-facing inbound items into protocol-facing inbound items.
|
/// Converts transport-facing inbound items into protocol-facing inbound items.
|
||||||
@@ -176,6 +176,7 @@ where
|
|||||||
|
|
||||||
/// gRPC-specific transport adapters and helpers.
|
/// gRPC-specific transport adapters and helpers.
|
||||||
pub mod grpc {
|
pub mod grpc {
|
||||||
|
use async_trait::async_trait;
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
use tonic::Streaming;
|
use tonic::Streaming;
|
||||||
@@ -199,7 +200,6 @@ pub mod grpc {
|
|||||||
outbound_converter: OutboundConverter,
|
outbound_converter: OutboundConverter,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
impl<InboundTransport, Inbound, InboundConverter, OutboundConverter>
|
impl<InboundTransport, Inbound, InboundConverter, OutboundConverter>
|
||||||
GrpcAdapter<InboundConverter, OutboundConverter>
|
GrpcAdapter<InboundConverter, OutboundConverter>
|
||||||
where
|
where
|
||||||
@@ -221,8 +221,8 @@ pub mod grpc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
impl< InboundConverter, OutboundConverter> Bi<InboundConverter::Output, OutboundConverter::Input>
|
impl<InboundConverter, OutboundConverter> Bi<InboundConverter::Output, OutboundConverter::Input>
|
||||||
for GrpcAdapter<InboundConverter, OutboundConverter>
|
for GrpcAdapter<InboundConverter, OutboundConverter>
|
||||||
where
|
where
|
||||||
InboundConverter: RecvConverter,
|
InboundConverter: RecvConverter,
|
||||||
@@ -275,6 +275,7 @@ impl<Inbound, Outbound> Default for DummyTransport<Inbound, Outbound> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
impl<Inbound, Outbound> Bi<Inbound, Outbound> for DummyTransport<Inbound, Outbound>
|
impl<Inbound, Outbound> Bi<Inbound, Outbound> for DummyTransport<Inbound, Outbound>
|
||||||
where
|
where
|
||||||
Inbound: Send + Sync + 'static,
|
Inbound: Send + Sync + 'static,
|
||||||
@@ -284,10 +285,8 @@ where
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn recv(&mut self) -> impl std::future::Future<Output = Option<Inbound>> + Send {
|
async fn recv(&mut self) -> Option<Inbound> {
|
||||||
async {
|
std::future::pending::<()>().await;
|
||||||
std::future::pending::<()>().await;
|
None
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
use arbiter_proto::{
|
use arbiter_proto::{
|
||||||
proto::client::{
|
proto::client::{
|
||||||
AuthChallenge, AuthChallengeRequest, AuthChallengeSolution, AuthOk, ClientRequest,
|
AuthChallenge, AuthChallengeRequest, AuthChallengeSolution, AuthOk, ClientRequest,
|
||||||
ClientResponse,
|
ClientResponse, client_request::Payload as ClientRequestPayload,
|
||||||
client_request::Payload as ClientRequestPayload,
|
|
||||||
client_response::Payload as ClientResponsePayload,
|
client_response::Payload as ClientResponsePayload,
|
||||||
},
|
},
|
||||||
transport::{Bi, DummyTransport},
|
transport::{Bi, DummyTransport},
|
||||||
@@ -50,19 +49,15 @@ pub enum ClientError {
|
|||||||
DatabaseOperationFailed,
|
DatabaseOperationFailed,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ClientActor<Transport>
|
pub type Transport = Box<dyn Bi<ClientRequest, Result<ClientResponse, ClientError>> + Send>;
|
||||||
where
|
|
||||||
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
|
pub struct ClientActor {
|
||||||
{
|
|
||||||
db: db::DatabasePool,
|
db: db::DatabasePool,
|
||||||
state: ClientStateMachine<DummyContext>,
|
state: ClientStateMachine<DummyContext>,
|
||||||
transport: Transport,
|
transport: Transport,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<Transport> ClientActor<Transport>
|
impl ClientActor {
|
||||||
where
|
|
||||||
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
|
|
||||||
{
|
|
||||||
pub(crate) fn new(context: ServerContext, transport: Transport) -> Self {
|
pub(crate) fn new(context: ServerContext, transport: Transport) -> Self {
|
||||||
Self {
|
Self {
|
||||||
db: context.db.clone(),
|
db: context.db.clone(),
|
||||||
@@ -197,10 +192,7 @@ where
|
|||||||
Ok((valid, challenge_context))
|
Ok((valid, challenge_context))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_auth_challenge_solution(
|
async fn handle_auth_challenge_solution(&mut self, solution: AuthChallengeSolution) -> Output {
|
||||||
&mut self,
|
|
||||||
solution: AuthChallengeSolution,
|
|
||||||
) -> Output {
|
|
||||||
let (valid, challenge_context) = self.verify_challenge_solution(&solution)?;
|
let (valid, challenge_context) = self.verify_challenge_solution(&solution)?;
|
||||||
|
|
||||||
if valid {
|
if valid {
|
||||||
@@ -226,10 +218,7 @@ fn response(payload: ClientResponsePayload) -> ClientResponse {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<Transport> Actor for ClientActor<Transport>
|
impl Actor for ClientActor {
|
||||||
where
|
|
||||||
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
|
|
||||||
{
|
|
||||||
type Args = Self;
|
type Args = Self;
|
||||||
|
|
||||||
type Error = ();
|
type Error = ();
|
||||||
@@ -278,12 +267,12 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ClientActor<DummyTransport<ClientRequest, Result<ClientResponse, ClientError>>> {
|
impl ClientActor {
|
||||||
pub fn new_manual(db: db::DatabasePool) -> Self {
|
pub fn new_manual(db: db::DatabasePool) -> Self {
|
||||||
Self {
|
Self {
|
||||||
db,
|
db,
|
||||||
state: ClientStateMachine::new(DummyContext),
|
state: ClientStateMachine::new(DummyContext),
|
||||||
transport: DummyTransport::new(),
|
transport: Box::new(DummyTransport::new()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -71,9 +71,9 @@ pub enum UserAgentError {
|
|||||||
DatabaseOperationFailed,
|
DatabaseOperationFailed,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct UserAgentActor<Transport>
|
pub type Transport = Box<dyn Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>> + Send>;
|
||||||
where
|
|
||||||
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
|
pub struct UserAgentActor
|
||||||
{
|
{
|
||||||
db: db::DatabasePool,
|
db: db::DatabasePool,
|
||||||
actors: GlobalActors,
|
actors: GlobalActors,
|
||||||
@@ -81,10 +81,7 @@ where
|
|||||||
transport: Transport,
|
transport: Transport,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<Transport> UserAgentActor<Transport>
|
impl UserAgentActor {
|
||||||
where
|
|
||||||
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
|
|
||||||
{
|
|
||||||
pub(crate) fn new(context: ServerContext, transport: Transport) -> Self {
|
pub(crate) fn new(context: ServerContext, transport: Transport) -> Self {
|
||||||
Self {
|
Self {
|
||||||
db: context.db.clone(),
|
db: context.db.clone(),
|
||||||
@@ -265,10 +262,7 @@ fn response(payload: UserAgentResponsePayload) -> UserAgentResponse {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<Transport> UserAgentActor<Transport>
|
impl UserAgentActor {
|
||||||
where
|
|
||||||
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
|
|
||||||
{
|
|
||||||
async fn handle_unseal_request(&mut self, req: UnsealStart) -> Output {
|
async fn handle_unseal_request(&mut self, req: UnsealStart) -> Output {
|
||||||
let secret = EphemeralSecret::random();
|
let secret = EphemeralSecret::random();
|
||||||
let public_key = PublicKey::from(&secret);
|
let public_key = PublicKey::from(&secret);
|
||||||
@@ -413,10 +407,7 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
impl<Transport> Actor for UserAgentActor<Transport>
|
impl Actor for UserAgentActor {
|
||||||
where
|
|
||||||
Transport: Bi<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>,
|
|
||||||
{
|
|
||||||
type Args = Self;
|
type Args = Self;
|
||||||
|
|
||||||
type Error = ();
|
type Error = ();
|
||||||
@@ -466,13 +457,13 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
impl UserAgentActor<DummyTransport<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>> {
|
impl UserAgentActor {
|
||||||
pub fn new_manual(db: db::DatabasePool, actors: GlobalActors) -> Self {
|
pub fn new_manual(db: db::DatabasePool, actors: GlobalActors) -> Self {
|
||||||
Self {
|
Self {
|
||||||
db,
|
db,
|
||||||
actors,
|
actors,
|
||||||
state: UserAgentStateMachine::new(DummyContext),
|
state: UserAgentStateMachine::new(DummyContext),
|
||||||
transport: DummyTransport::new(),
|
transport: Box::new(DummyTransport::new()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -170,7 +170,7 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server {
|
|||||||
IdentityRecvConverter::<ClientRequest>::new(),
|
IdentityRecvConverter::<ClientRequest>::new(),
|
||||||
ClientGrpcSender,
|
ClientGrpcSender,
|
||||||
);
|
);
|
||||||
ClientActor::spawn(ClientActor::new(self.context.clone(), transport));
|
ClientActor::spawn(ClientActor::new(self.context.clone(), Box::new(transport)));
|
||||||
|
|
||||||
info!(event = "connection established", "grpc.client");
|
info!(event = "connection established", "grpc.client");
|
||||||
|
|
||||||
@@ -191,7 +191,7 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server {
|
|||||||
IdentityRecvConverter::<UserAgentRequest>::new(),
|
IdentityRecvConverter::<UserAgentRequest>::new(),
|
||||||
UserAgentGrpcSender,
|
UserAgentGrpcSender,
|
||||||
);
|
);
|
||||||
UserAgentActor::spawn(UserAgentActor::new(self.context.clone(), transport));
|
UserAgentActor::spawn(UserAgentActor::new(self.context.clone(), Box::new(transport)));
|
||||||
|
|
||||||
info!(event = "connection established", "grpc.user_agent");
|
info!(event = "connection established", "grpc.user_agent");
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
use arbiter_proto::proto::user_agent::{
|
use arbiter_proto::proto::user_agent::{
|
||||||
AuthChallengeRequest, UnsealEncryptedKey, UnsealResult, UnsealStart,
|
AuthChallengeRequest, UnsealEncryptedKey, UnsealResult, UnsealStart,
|
||||||
UserAgentRequest, UserAgentResponse,
|
UserAgentRequest,
|
||||||
user_agent_request::Payload as UserAgentRequestPayload,
|
user_agent_request::Payload as UserAgentRequestPayload,
|
||||||
user_agent_response::Payload as UserAgentResponsePayload,
|
user_agent_response::Payload as UserAgentResponsePayload,
|
||||||
};
|
};
|
||||||
use arbiter_proto::transport::DummyTransport;
|
|
||||||
use arbiter_server::{
|
use arbiter_server::{
|
||||||
actors::{
|
actors::{
|
||||||
GlobalActors,
|
GlobalActors,
|
||||||
@@ -18,14 +17,12 @@ use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit};
|
|||||||
use memsafe::MemSafe;
|
use memsafe::MemSafe;
|
||||||
use x25519_dalek::{EphemeralSecret, PublicKey};
|
use x25519_dalek::{EphemeralSecret, PublicKey};
|
||||||
|
|
||||||
type TestUserAgent =
|
|
||||||
UserAgentActor<DummyTransport<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>>;
|
|
||||||
|
|
||||||
async fn setup_authenticated_user_agent(
|
async fn setup_authenticated_user_agent(
|
||||||
seal_key: &[u8],
|
seal_key: &[u8],
|
||||||
) -> (
|
) -> (
|
||||||
arbiter_server::db::DatabasePool,
|
arbiter_server::db::DatabasePool,
|
||||||
TestUserAgent,
|
UserAgentActor,
|
||||||
) {
|
) {
|
||||||
let db = db::create_test_pool().await;
|
let db = db::create_test_pool().await;
|
||||||
|
|
||||||
@@ -59,7 +56,7 @@ async fn setup_authenticated_user_agent(
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn client_dh_encrypt(
|
async fn client_dh_encrypt(
|
||||||
user_agent: &mut TestUserAgent,
|
user_agent: &mut UserAgentActor,
|
||||||
key_to_send: &[u8],
|
key_to_send: &[u8],
|
||||||
) -> UnsealEncryptedKey {
|
) -> UnsealEncryptedKey {
|
||||||
let client_secret = EphemeralSecret::random();
|
let client_secret = EphemeralSecret::random();
|
||||||
|
|||||||
@@ -18,3 +18,4 @@ thiserror.workspace = true
|
|||||||
tokio-stream.workspace = true
|
tokio-stream.workspace = true
|
||||||
http = "1.4.0"
|
http = "1.4.0"
|
||||||
rustls-webpki = { version = "0.103.9", features = ["aws-lc-rs"] }
|
rustls-webpki = { version = "0.103.9", features = ["aws-lc-rs"] }
|
||||||
|
async-trait.workspace = true
|
||||||
|
|||||||
@@ -13,12 +13,14 @@ use ed25519_dalek::SigningKey;
|
|||||||
use kameo::actor::Spawn;
|
use kameo::actor::Spawn;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
use tokio::time::{Duration, timeout};
|
use tokio::time::{Duration, timeout};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
struct TestTransport {
|
struct TestTransport {
|
||||||
inbound_rx: mpsc::Receiver<UserAgentResponse>,
|
inbound_rx: mpsc::Receiver<UserAgentResponse>,
|
||||||
outbound_tx: mpsc::Sender<UserAgentRequest>,
|
outbound_tx: mpsc::Sender<UserAgentRequest>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
impl Bi<UserAgentResponse, UserAgentRequest> for TestTransport {
|
impl Bi<UserAgentResponse, UserAgentRequest> for TestTransport {
|
||||||
async fn send(&mut self, item: UserAgentRequest) -> Result<(), arbiter_proto::transport::Error> {
|
async fn send(&mut self, item: UserAgentRequest) -> Result<(), arbiter_proto::transport::Error> {
|
||||||
self.outbound_tx
|
self.outbound_tx
|
||||||
|
|||||||
Reference in New Issue
Block a user