refactor: consolidate auth messages into client and user_agent packages

This commit is contained in:
hdbg
2026-03-01 11:44:34 +01:00
parent 06f4d628db
commit 4b4a8f4489
19 changed files with 686 additions and 264 deletions

View File

@@ -2,7 +2,6 @@ syntax = "proto3";
package arbiter; package arbiter;
import "auth.proto";
import "client.proto"; import "client.proto";
import "user_agent.proto"; import "user_agent.proto";
@@ -12,6 +11,6 @@ message ServerInfo {
} }
service ArbiterService { service ArbiterService {
rpc Client(stream ClientRequest) returns (stream ClientResponse); rpc Client(stream arbiter.client.ClientRequest) returns (stream arbiter.client.ClientResponse);
rpc UserAgent(stream UserAgentRequest) returns (stream UserAgentResponse); rpc UserAgent(stream arbiter.user_agent.UserAgentRequest) returns (stream arbiter.user_agent.UserAgentResponse);
} }

View File

@@ -1,35 +0,0 @@
syntax = "proto3";
package arbiter.auth;
import "google/protobuf/timestamp.proto";
message AuthChallengeRequest {
bytes pubkey = 1;
optional string bootstrap_token = 2;
}
message AuthChallenge {
bytes pubkey = 1;
int32 nonce = 2;
}
message AuthChallengeSolution {
bytes signature = 1;
}
message AuthOk {}
message ClientMessage {
oneof payload {
AuthChallengeRequest auth_challenge_request = 1;
AuthChallengeSolution auth_challenge_solution = 2;
}
}
message ServerMessage {
oneof payload {
AuthChallenge auth_challenge = 1;
AuthOk auth_ok = 2;
}
}

View File

@@ -1,17 +1,32 @@
syntax = "proto3"; syntax = "proto3";
package arbiter; package arbiter.client;
import "auth.proto"; message AuthChallengeRequest {
bytes pubkey = 1;
}
message AuthChallenge {
bytes pubkey = 1;
int32 nonce = 2;
}
message AuthChallengeSolution {
bytes signature = 1;
}
message AuthOk {}
message ClientRequest { message ClientRequest {
oneof payload { oneof payload {
arbiter.auth.ClientMessage auth_message = 1; AuthChallengeRequest auth_challenge_request = 1;
AuthChallengeSolution auth_challenge_solution = 2;
} }
} }
message ClientResponse { message ClientResponse {
oneof payload { oneof payload {
arbiter.auth.ServerMessage auth_message = 1; AuthChallenge auth_challenge = 1;
AuthOk auth_ok = 2;
} }
} }

View File

@@ -1,10 +1,25 @@
syntax = "proto3"; syntax = "proto3";
package arbiter; package arbiter.user_agent;
import "auth.proto";
import "google/protobuf/empty.proto"; import "google/protobuf/empty.proto";
message AuthChallengeRequest {
bytes pubkey = 1;
optional string bootstrap_token = 2;
}
message AuthChallenge {
bytes pubkey = 1;
int32 nonce = 2;
}
message AuthChallengeSolution {
bytes signature = 1;
}
message AuthOk {}
message UnsealStart { message UnsealStart {
bytes client_pubkey = 1; bytes client_pubkey = 1;
} }
@@ -35,17 +50,19 @@ enum VaultState {
message UserAgentRequest { message UserAgentRequest {
oneof payload { oneof payload {
arbiter.auth.ClientMessage auth_message = 1; AuthChallengeRequest auth_challenge_request = 1;
UnsealStart unseal_start = 2; AuthChallengeSolution auth_challenge_solution = 2;
UnsealEncryptedKey unseal_encrypted_key = 3; UnsealStart unseal_start = 3;
google.protobuf.Empty query_vault_state = 4; UnsealEncryptedKey unseal_encrypted_key = 4;
google.protobuf.Empty query_vault_state = 5;
} }
} }
message UserAgentResponse { message UserAgentResponse {
oneof payload { oneof payload {
arbiter.auth.ServerMessage auth_message = 1; AuthChallenge auth_challenge = 1;
UnsealStartResponse unseal_start_response = 2; AuthOk auth_ok = 2;
UnsealResult unseal_result = 3; UnsealStartResponse unseal_start_response = 3;
VaultState vault_state = 4; UnsealResult unseal_result = 4;
VaultState vault_state = 5;
} }
} }

View File

@@ -11,7 +11,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.compile_protos( .compile_protos(
&[ &[
format!("{}/arbiter.proto", PROTOBUF_DIR), format!("{}/arbiter.proto", PROTOBUF_DIR),
format!("{}/auth.proto", PROTOBUF_DIR), format!("{}/user_agent.proto", PROTOBUF_DIR),
format!("{}/client.proto", PROTOBUF_DIR),
], ],
&[PROTOBUF_DIR.to_string()], &[PROTOBUF_DIR.to_string()],
) )

View File

@@ -3,13 +3,15 @@ pub mod url;
use base64::{Engine, prelude::BASE64_STANDARD}; use base64::{Engine, prelude::BASE64_STANDARD};
use crate::proto::auth::AuthChallenge;
pub mod proto { pub mod proto {
tonic::include_proto!("arbiter"); tonic::include_proto!("arbiter");
pub mod auth { pub mod user_agent {
tonic::include_proto!("arbiter.auth"); tonic::include_proto!("arbiter.user_agent");
}
pub mod client {
tonic::include_proto!("arbiter.client");
} }
} }
@@ -28,7 +30,7 @@ pub fn home_path() -> Result<std::path::PathBuf, std::io::Error> {
Ok(arbiter_home) Ok(arbiter_home)
} }
pub fn format_challenge(challenge: &AuthChallenge) -> Vec<u8> { pub fn format_challenge(nonce: i32, pubkey: &[u8]) -> Vec<u8> {
let concat_form = format!("{}:{}", challenge.nonce, BASE64_STANDARD.encode(&challenge.pubkey)); let concat_form = format!("{}:{}", nonce, BASE64_STANDARD.encode(pubkey));
concat_form.into_bytes().to_vec() concat_form.into_bytes()
} }

View File

@@ -1,12 +0,0 @@
use arbiter_proto::{
proto::{ClientRequest, ClientResponse},
transport::Bi,
};
use crate::ServerContext;
pub(crate) async fn handle_client(
_context: ServerContext,
_bistream: impl Bi<ClientRequest, ClientResponse>,
) {
}

View File

@@ -0,0 +1,289 @@
use arbiter_proto::{
proto::client::{
AuthChallenge, AuthChallengeRequest, AuthChallengeSolution, AuthOk, ClientRequest,
ClientResponse,
client_request::Payload as ClientRequestPayload,
client_response::Payload as ClientResponsePayload,
},
transport::{Bi, DummyTransport},
};
use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, dsl::update};
use diesel_async::RunQueryDsl;
use ed25519_dalek::VerifyingKey;
use kameo::Actor;
use tokio::select;
use tracing::{error, info};
use crate::{
ServerContext,
actors::client::state::{
ChallengeContext, ClientEvents, ClientStateMachine, ClientStates, DummyContext,
},
db::{self, schema},
};
mod state;
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum ClientError {
#[error("Expected message with payload")]
MissingRequestPayload,
#[error("Unexpected request payload")]
UnexpectedRequestPayload,
#[error("Invalid state for challenge solution")]
InvalidStateForChallengeSolution,
#[error("Expected pubkey to have specific length")]
InvalidAuthPubkeyLength,
#[error("Failed to convert pubkey to VerifyingKey")]
InvalidAuthPubkeyEncoding,
#[error("Invalid signature length")]
InvalidSignatureLength,
#[error("Public key not registered")]
PublicKeyNotRegistered,
#[error("Invalid challenge solution")]
InvalidChallengeSolution,
#[error("State machine error")]
StateTransitionFailed,
#[error("Database pool error")]
DatabasePoolUnavailable,
#[error("Database error")]
DatabaseOperationFailed,
}
pub struct ClientActor<Transport>
where
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
{
db: db::DatabasePool,
state: ClientStateMachine<DummyContext>,
transport: Transport,
}
impl<Transport> ClientActor<Transport>
where
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
{
pub(crate) fn new(context: ServerContext, transport: Transport) -> Self {
Self {
db: context.db.clone(),
state: ClientStateMachine::new(DummyContext),
transport,
}
}
fn transition(&mut self, event: ClientEvents) -> Result<(), ClientError> {
self.state.process_event(event).map_err(|e| {
error!(?e, "State transition failed");
ClientError::StateTransitionFailed
})?;
Ok(())
}
pub async fn process_transport_inbound(&mut self, req: ClientRequest) -> Output {
let msg = req.payload.ok_or_else(|| {
error!(actor = "client", "Received message with no payload");
ClientError::MissingRequestPayload
})?;
match msg {
ClientRequestPayload::AuthChallengeRequest(req) => {
self.handle_auth_challenge_request(req).await
}
ClientRequestPayload::AuthChallengeSolution(solution) => {
self.handle_auth_challenge_solution(solution).await
}
}
}
async fn handle_auth_challenge_request(&mut self, req: AuthChallengeRequest) -> Output {
let pubkey = req
.pubkey
.as_array()
.ok_or(ClientError::InvalidAuthPubkeyLength)?;
let pubkey = VerifyingKey::from_bytes(pubkey).map_err(|_err| {
error!(?pubkey, "Failed to convert to VerifyingKey");
ClientError::InvalidAuthPubkeyEncoding
})?;
self.transition(ClientEvents::AuthRequest)?;
self.auth_with_challenge(pubkey, req.pubkey).await
}
async fn auth_with_challenge(&mut self, pubkey: VerifyingKey, pubkey_bytes: Vec<u8>) -> Output {
let nonce: Option<i32> = {
let mut db_conn = self.db.get().await.map_err(|e| {
error!(error = ?e, "Database pool error");
ClientError::DatabasePoolUnavailable
})?;
db_conn
.exclusive_transaction(|conn| {
Box::pin(async move {
let current_nonce = schema::program_client::table
.filter(
schema::program_client::public_key.eq(pubkey.as_bytes().to_vec()),
)
.select(schema::program_client::nonce)
.first::<i32>(conn)
.await?;
update(schema::program_client::table)
.filter(
schema::program_client::public_key.eq(pubkey.as_bytes().to_vec()),
)
.set(schema::program_client::nonce.eq(current_nonce + 1))
.execute(conn)
.await?;
Result::<_, diesel::result::Error>::Ok(current_nonce)
})
})
.await
.optional()
.map_err(|e| {
error!(error = ?e, "Database error");
ClientError::DatabaseOperationFailed
})?
};
let Some(nonce) = nonce else {
error!(?pubkey, "Public key not found in database");
return Err(ClientError::PublicKeyNotRegistered);
};
let challenge = AuthChallenge {
pubkey: pubkey_bytes,
nonce,
};
self.transition(ClientEvents::SentChallenge(ChallengeContext {
challenge: challenge.clone(),
key: pubkey,
}))?;
info!(
?pubkey,
?challenge,
"Sent authentication challenge to client"
);
Ok(response(ClientResponsePayload::AuthChallenge(challenge)))
}
fn verify_challenge_solution(
&self,
solution: &AuthChallengeSolution,
) -> Result<(bool, &ChallengeContext), ClientError> {
let ClientStates::WaitingForChallengeSolution(challenge_context) = self.state.state()
else {
error!("Received challenge solution in invalid state");
return Err(ClientError::InvalidStateForChallengeSolution);
};
let formatted_challenge = arbiter_proto::format_challenge(
challenge_context.challenge.nonce,
&challenge_context.challenge.pubkey,
);
let signature = solution.signature.as_slice().try_into().map_err(|_| {
error!(?solution, "Invalid signature length");
ClientError::InvalidSignatureLength
})?;
let valid = challenge_context
.key
.verify_strict(&formatted_challenge, &signature)
.is_ok();
Ok((valid, challenge_context))
}
async fn handle_auth_challenge_solution(
&mut self,
solution: AuthChallengeSolution,
) -> Output {
let (valid, challenge_context) = self.verify_challenge_solution(&solution)?;
if valid {
info!(
?challenge_context,
"Client provided valid solution to authentication challenge"
);
self.transition(ClientEvents::ReceivedGoodSolution)?;
Ok(response(ClientResponsePayload::AuthOk(AuthOk {})))
} else {
error!("Client provided invalid solution to authentication challenge");
self.transition(ClientEvents::ReceivedBadSolution)?;
Err(ClientError::InvalidChallengeSolution)
}
}
}
type Output = Result<ClientResponse, ClientError>;
fn response(payload: ClientResponsePayload) -> ClientResponse {
ClientResponse {
payload: Some(payload),
}
}
impl<Transport> Actor for ClientActor<Transport>
where
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
{
type Args = Self;
type Error = ();
async fn on_start(
args: Self::Args,
_: kameo::prelude::ActorRef<Self>,
) -> Result<Self, Self::Error> {
Ok(args)
}
async fn next(
&mut self,
_actor_ref: kameo::prelude::WeakActorRef<Self>,
mailbox_rx: &mut kameo::prelude::MailboxReceiver<Self>,
) -> Option<kameo::mailbox::Signal<Self>> {
loop {
select! {
signal = mailbox_rx.recv() => {
return signal;
}
msg = self.transport.recv() => {
match msg {
Some(request) => {
match self.process_transport_inbound(request).await {
Ok(resp) => {
if self.transport.send(Ok(resp)).await.is_err() {
error!(actor = "client", reason = "channel closed", "send.failed");
return Some(kameo::mailbox::Signal::Stop);
}
}
Err(err) => {
let _ = self.transport.send(Err(err)).await;
return Some(kameo::mailbox::Signal::Stop);
}
}
}
None => {
info!(actor = "client", "transport.closed");
return Some(kameo::mailbox::Signal::Stop);
}
}
}
}
}
}
}
impl ClientActor<DummyTransport<ClientRequest, Result<ClientResponse, ClientError>>> {
pub fn new_manual(db: db::DatabasePool) -> Self {
Self {
db,
state: ClientStateMachine::new(DummyContext),
transport: DummyTransport::new(),
}
}
}

View File

@@ -0,0 +1,31 @@
use arbiter_proto::proto::client::AuthChallenge;
use ed25519_dalek::VerifyingKey;
/// Context for state machine with validated key and sent challenge
#[derive(Clone, Debug)]
pub struct ChallengeContext {
pub challenge: AuthChallenge,
pub key: VerifyingKey,
}
smlang::statemachine!(
name: Client,
custom_error: false,
transitions: {
*Init + AuthRequest = ReceivedAuthRequest,
ReceivedAuthRequest + SentChallenge(ChallengeContext) / move_challenge = WaitingForChallengeSolution(ChallengeContext),
WaitingForChallengeSolution(ChallengeContext) + ReceivedGoodSolution = Idle,
WaitingForChallengeSolution(ChallengeContext) + ReceivedBadSolution = AuthError,
}
);
pub struct DummyContext;
impl ClientStateMachineContext for DummyContext {
#[allow(missing_docs)]
#[allow(clippy::unused_unit)]
fn move_challenge(&mut self, event_data: ChallengeContext) -> Result<ChallengeContext, ()> {
Ok(event_data)
}
}

View File

@@ -1,14 +1,9 @@
use std::{ops::DerefMut, sync::Mutex}; use std::{ops::DerefMut, sync::Mutex};
use arbiter_proto::{ use arbiter_proto::{
proto::{ proto::user_agent::{
UnsealEncryptedKey, UnsealResult, UnsealStart, UnsealStartResponse, UserAgentRequest, AuthChallenge, AuthChallengeRequest, AuthChallengeSolution, AuthOk, UnsealEncryptedKey,
UserAgentResponse, UnsealResult, UnsealStart, UnsealStartResponse, UserAgentRequest, UserAgentResponse,
auth::{
self, AuthChallengeRequest, AuthOk, ClientMessage as ClientAuthMessage,
ServerMessage as AuthServerMessage, client_message::Payload as ClientAuthPayload,
server_message::Payload as ServerAuthPayload,
},
user_agent_request::Payload as UserAgentRequestPayload, user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload, user_agent_response::Payload as UserAgentResponsePayload,
}, },
@@ -114,12 +109,12 @@ where
})?; })?;
match msg { match msg {
UserAgentRequestPayload::AuthMessage(ClientAuthMessage { UserAgentRequestPayload::AuthChallengeRequest(req) => {
payload: Some(ClientAuthPayload::AuthChallengeRequest(req)), self.handle_auth_challenge_request(req).await
}) => self.handle_auth_challenge_request(req).await, }
UserAgentRequestPayload::AuthMessage(ClientAuthMessage { UserAgentRequestPayload::AuthChallengeSolution(solution) => {
payload: Some(ClientAuthPayload::AuthChallengeSolution(solution)), self.handle_auth_challenge_solution(solution).await
}) => self.handle_auth_challenge_solution(solution).await, }
UserAgentRequestPayload::UnsealStart(unseal_start) => { UserAgentRequestPayload::UnsealStart(unseal_start) => {
self.handle_unseal_request(unseal_start).await self.handle_unseal_request(unseal_start).await
} }
@@ -171,7 +166,7 @@ where
self.transition(UserAgentEvents::ReceivedBootstrapToken)?; self.transition(UserAgentEvents::ReceivedBootstrapToken)?;
Ok(auth_response(ServerAuthPayload::AuthOk(AuthOk {}))) Ok(response(UserAgentResponsePayload::AuthOk(AuthOk {})))
} }
async fn auth_with_challenge(&mut self, pubkey: VerifyingKey, pubkey_bytes: Vec<u8>) -> Output { async fn auth_with_challenge(&mut self, pubkey: VerifyingKey, pubkey_bytes: Vec<u8>) -> Output {
@@ -215,7 +210,7 @@ where
return Err(UserAgentError::PublicKeyNotRegistered); return Err(UserAgentError::PublicKeyNotRegistered);
}; };
let challenge = auth::AuthChallenge { let challenge = AuthChallenge {
pubkey: pubkey_bytes, pubkey: pubkey_bytes,
nonce, nonce,
}; };
@@ -231,19 +226,22 @@ where
"Sent authentication challenge to client" "Sent authentication challenge to client"
); );
Ok(auth_response(ServerAuthPayload::AuthChallenge(challenge))) Ok(response(UserAgentResponsePayload::AuthChallenge(challenge)))
} }
fn verify_challenge_solution( fn verify_challenge_solution(
&self, &self,
solution: &auth::AuthChallengeSolution, solution: &AuthChallengeSolution,
) -> Result<(bool, &ChallengeContext), UserAgentError> { ) -> Result<(bool, &ChallengeContext), UserAgentError> {
let UserAgentStates::WaitingForChallengeSolution(challenge_context) = self.state.state() let UserAgentStates::WaitingForChallengeSolution(challenge_context) = self.state.state()
else { else {
error!("Received challenge solution in invalid state"); error!("Received challenge solution in invalid state");
return Err(UserAgentError::InvalidStateForChallengeSolution); return Err(UserAgentError::InvalidStateForChallengeSolution);
}; };
let formatted_challenge = arbiter_proto::format_challenge(&challenge_context.challenge); let formatted_challenge = arbiter_proto::format_challenge(
challenge_context.challenge.nonce,
&challenge_context.challenge.pubkey,
);
let signature = solution.signature.as_slice().try_into().map_err(|_| { let signature = solution.signature.as_slice().try_into().map_err(|_| {
error!(?solution, "Invalid signature length"); error!(?solution, "Invalid signature length");
@@ -261,15 +259,7 @@ where
type Output = Result<UserAgentResponse, UserAgentError>; type Output = Result<UserAgentResponse, UserAgentError>;
fn auth_response(payload: ServerAuthPayload) -> UserAgentResponse { fn response(payload: UserAgentResponsePayload) -> UserAgentResponse {
UserAgentResponse {
payload: Some(UserAgentResponsePayload::AuthMessage(AuthServerMessage {
payload: Some(payload),
})),
}
}
fn unseal_response(payload: UserAgentResponsePayload) -> UserAgentResponse {
UserAgentResponse { UserAgentResponse {
payload: Some(payload), payload: Some(payload),
} }
@@ -295,7 +285,7 @@ where
client_public_key, client_public_key,
}))?; }))?;
Ok(unseal_response( Ok(response(
UserAgentResponsePayload::UnsealStartResponse(UnsealStartResponse { UserAgentResponsePayload::UnsealStartResponse(UnsealStartResponse {
server_pubkey: public_key.as_bytes().to_vec(), server_pubkey: public_key.as_bytes().to_vec(),
}), }),
@@ -316,7 +306,7 @@ where
drop(secret_lock); drop(secret_lock);
error!("Ephemeral secret already taken"); error!("Ephemeral secret already taken");
self.transition(UserAgentEvents::ReceivedInvalidKey)?; self.transition(UserAgentEvents::ReceivedInvalidKey)?;
return Ok(unseal_response(UserAgentResponsePayload::UnsealResult( return Ok(response(UserAgentResponsePayload::UnsealResult(
UnsealResult::InvalidKey.into(), UnsealResult::InvalidKey.into(),
))); )));
} }
@@ -349,20 +339,20 @@ where
Ok(_) => { Ok(_) => {
info!("Successfully unsealed key with client-provided key"); info!("Successfully unsealed key with client-provided key");
self.transition(UserAgentEvents::ReceivedValidKey)?; self.transition(UserAgentEvents::ReceivedValidKey)?;
Ok(unseal_response(UserAgentResponsePayload::UnsealResult( Ok(response(UserAgentResponsePayload::UnsealResult(
UnsealResult::Success.into(), UnsealResult::Success.into(),
))) )))
} }
Err(SendError::HandlerError(keyholder::Error::InvalidKey)) => { Err(SendError::HandlerError(keyholder::Error::InvalidKey)) => {
self.transition(UserAgentEvents::ReceivedInvalidKey)?; self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Ok(unseal_response(UserAgentResponsePayload::UnsealResult( Ok(response(UserAgentResponsePayload::UnsealResult(
UnsealResult::InvalidKey.into(), UnsealResult::InvalidKey.into(),
))) )))
} }
Err(SendError::HandlerError(err)) => { Err(SendError::HandlerError(err)) => {
error!(?err, "Keyholder failed to unseal key"); error!(?err, "Keyholder failed to unseal key");
self.transition(UserAgentEvents::ReceivedInvalidKey)?; self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Ok(unseal_response(UserAgentResponsePayload::UnsealResult( Ok(response(UserAgentResponsePayload::UnsealResult(
UnsealResult::InvalidKey.into(), UnsealResult::InvalidKey.into(),
))) )))
} }
@@ -376,7 +366,7 @@ where
Err(err) => { Err(err) => {
error!(?err, "Failed to decrypt unseal key"); error!(?err, "Failed to decrypt unseal key");
self.transition(UserAgentEvents::ReceivedInvalidKey)?; self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Ok(unseal_response(UserAgentResponsePayload::UnsealResult( Ok(response(UserAgentResponsePayload::UnsealResult(
UnsealResult::InvalidKey.into(), UnsealResult::InvalidKey.into(),
))) )))
} }
@@ -403,7 +393,7 @@ where
async fn handle_auth_challenge_solution( async fn handle_auth_challenge_solution(
&mut self, &mut self,
solution: auth::AuthChallengeSolution, solution: AuthChallengeSolution,
) -> Output { ) -> Output {
let (valid, challenge_context) = self.verify_challenge_solution(&solution)?; let (valid, challenge_context) = self.verify_challenge_solution(&solution)?;
@@ -413,7 +403,7 @@ where
"Client provided valid solution to authentication challenge" "Client provided valid solution to authentication challenge"
); );
self.transition(UserAgentEvents::ReceivedGoodSolution)?; self.transition(UserAgentEvents::ReceivedGoodSolution)?;
Ok(auth_response(ServerAuthPayload::AuthOk(AuthOk {}))) Ok(response(UserAgentResponsePayload::AuthOk(AuthOk {})))
} else { } else {
error!("Client provided invalid solution to authentication challenge"); error!("Client provided invalid solution to authentication challenge");
self.transition(UserAgentEvents::ReceivedBadSolution)?; self.transition(UserAgentEvents::ReceivedBadSolution)?;

View File

@@ -1,6 +1,6 @@
use std::sync::Mutex; use std::sync::Mutex;
use arbiter_proto::proto::auth::AuthChallenge; use arbiter_proto::proto::user_agent::AuthChallenge;
use ed25519_dalek::VerifyingKey; use ed25519_dalek::VerifyingKey;
use x25519_dalek::{EphemeralSecret, PublicKey}; use x25519_dalek::{EphemeralSecret, PublicKey};

View File

@@ -1,6 +1,9 @@
#![forbid(unsafe_code)] #![forbid(unsafe_code)]
use arbiter_proto::{ use arbiter_proto::{
proto::{ClientRequest, ClientResponse, UserAgentRequest, UserAgentResponse}, proto::{
client::{ClientRequest, ClientResponse},
user_agent::{UserAgentRequest, UserAgentResponse},
},
transport::{IdentityRecvConverter, SendConverter, grpc}, transport::{IdentityRecvConverter, SendConverter, grpc},
}; };
use async_trait::async_trait; use async_trait::async_trait;
@@ -12,7 +15,10 @@ use tonic::{Request, Response, Status};
use tracing::info; use tracing::info;
use crate::{ use crate::{
actors::user_agent::{UserAgentActor, UserAgentError}, actors::{
client::{ClientActor, ClientError},
user_agent::{UserAgentActor, UserAgentError},
},
context::ServerContext, context::ServerContext,
}; };
@@ -41,6 +47,56 @@ impl SendConverter for UserAgentGrpcSender {
} }
} }
/// Converts Client domain outbounds into the tonic stream item emitted by the
/// server.
///
/// The conversion is defined at the server boundary so the actor module remains
/// focused on domain semantics and does not depend on tonic status encoding.
struct ClientGrpcSender;
impl SendConverter for ClientGrpcSender {
type Input = Result<ClientResponse, ClientError>;
type Output = Result<ClientResponse, Status>;
fn convert(&self, item: Self::Input) -> Self::Output {
match item {
Ok(message) => Ok(message),
Err(err) => Err(client_error_status(err)),
}
}
}
/// Maps Client domain errors to public gRPC transport errors for the `client`
/// streaming endpoint.
fn client_error_status(value: ClientError) -> Status {
match value {
ClientError::MissingRequestPayload | ClientError::UnexpectedRequestPayload => {
Status::invalid_argument("Expected message with payload")
}
ClientError::InvalidStateForChallengeSolution => {
Status::invalid_argument("Invalid state for challenge solution")
}
ClientError::InvalidAuthPubkeyLength => {
Status::invalid_argument("Expected pubkey to have specific length")
}
ClientError::InvalidAuthPubkeyEncoding => {
Status::invalid_argument("Failed to convert pubkey to VerifyingKey")
}
ClientError::InvalidSignatureLength => {
Status::invalid_argument("Invalid signature length")
}
ClientError::PublicKeyNotRegistered => {
Status::unauthenticated("Public key not registered")
}
ClientError::InvalidChallengeSolution => {
Status::unauthenticated("Invalid challenge solution")
}
ClientError::StateTransitionFailed => Status::internal("State machine error"),
ClientError::DatabasePoolUnavailable => Status::internal("Database pool error"),
ClientError::DatabaseOperationFailed => Status::internal("Database error"),
}
}
/// Maps User Agent domain errors to public gRPC transport errors for the /// Maps User Agent domain errors to public gRPC transport errors for the
/// `user_agent` streaming endpoint. /// `user_agent` streaming endpoint.
fn user_agent_error_status(value: UserAgentError) -> Status { fn user_agent_error_status(value: UserAgentError) -> Status {
@@ -100,11 +156,25 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server {
type UserAgentStream = ReceiverStream<Result<UserAgentResponse, Status>>; type UserAgentStream = ReceiverStream<Result<UserAgentResponse, Status>>;
type ClientStream = ReceiverStream<Result<ClientResponse, Status>>; type ClientStream = ReceiverStream<Result<ClientResponse, Status>>;
#[tracing::instrument(level = "debug", skip(self))]
async fn client( async fn client(
&self, &self,
_request: Request<tonic::Streaming<ClientRequest>>, request: Request<tonic::Streaming<ClientRequest>>,
) -> Result<Response<Self::ClientStream>, Status> { ) -> Result<Response<Self::ClientStream>, Status> {
todo!() let req_stream = request.into_inner();
let (tx, rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);
let transport = grpc::GrpcAdapter::new(
tx,
req_stream,
IdentityRecvConverter::<ClientRequest>::new(),
ClientGrpcSender,
);
ClientActor::spawn(ClientActor::new(self.context.clone(), transport));
info!(event = "connection established", "grpc.client");
Ok(Response::new(ReceiverStream::new(rx)))
} }
#[tracing::instrument(level = "debug", skip(self))] #[tracing::instrument(level = "debug", skip(self))]

View File

@@ -0,0 +1,2 @@
#[path = "client/auth.rs"]
mod auth;

View File

@@ -0,0 +1,102 @@
use arbiter_proto::proto::client::{
AuthChallengeRequest, AuthChallengeSolution, AuthOk, ClientRequest, ClientResponse,
client_request::Payload as ClientRequestPayload,
client_response::Payload as ClientResponsePayload,
};
use arbiter_server::{
actors::client::{ClientActor, ClientError},
db::{self, schema},
};
use diesel::{ExpressionMethods as _, insert_into};
use diesel_async::RunQueryDsl;
use ed25519_dalek::Signer as _;
#[tokio::test]
#[test_log::test]
pub async fn test_unregistered_pubkey_rejected() {
let db = db::create_test_pool().await;
let mut client = ClientActor::new_manual(db.clone());
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec();
let result = client
.process_transport_inbound(ClientRequest {
payload: Some(ClientRequestPayload::AuthChallengeRequest(
AuthChallengeRequest {
pubkey: pubkey_bytes,
},
)),
})
.await;
match result {
Err(err) => {
assert_eq!(err, ClientError::PublicKeyNotRegistered);
}
Ok(_) => {
panic!("Expected error due to unregistered pubkey, but got success");
}
}
}
#[tokio::test]
#[test_log::test]
pub async fn test_challenge_auth() {
let db = db::create_test_pool().await;
let mut client = ClientActor::new_manual(db.clone());
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec();
{
let mut conn = db.get().await.unwrap();
insert_into(schema::program_client::table)
.values(schema::program_client::public_key.eq(pubkey_bytes.clone()))
.execute(&mut conn)
.await
.unwrap();
}
let result = client
.process_transport_inbound(ClientRequest {
payload: Some(ClientRequestPayload::AuthChallengeRequest(
AuthChallengeRequest {
pubkey: pubkey_bytes,
},
)),
})
.await
.expect("Shouldn't fail to process message");
let ClientResponse {
payload: Some(ClientResponsePayload::AuthChallenge(challenge)),
} = result
else {
panic!("Expected auth challenge response, got {result:?}");
};
let formatted_challenge = arbiter_proto::format_challenge(challenge.nonce, &challenge.pubkey);
let signature = new_key.sign(&formatted_challenge);
let serialized_signature = signature.to_bytes().to_vec();
let result = client
.process_transport_inbound(ClientRequest {
payload: Some(ClientRequestPayload::AuthChallengeSolution(
AuthChallengeSolution {
signature: serialized_signature,
},
)),
})
.await
.expect("Shouldn't fail to process message");
assert_eq!(
result,
ClientResponse {
payload: Some(ClientResponsePayload::AuthOk(AuthOk {})),
}
);
}

View File

@@ -1,7 +1,5 @@
use arbiter_proto::proto::{ use arbiter_proto::proto::user_agent::{
UserAgentResponse, AuthChallengeRequest, AuthChallengeSolution, AuthOk, UserAgentRequest, UserAgentResponse,
UserAgentRequest,
auth::{self, AuthChallengeRequest, AuthOk, ClientMessage, client_message::Payload as ClientAuthPayload},
user_agent_request::Payload as UserAgentRequestPayload, user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload, user_agent_response::Payload as UserAgentResponsePayload,
}; };
@@ -17,18 +15,10 @@ use diesel::{ExpressionMethods as _, QueryDsl, insert_into};
use diesel_async::RunQueryDsl; use diesel_async::RunQueryDsl;
use ed25519_dalek::Signer as _; use ed25519_dalek::Signer as _;
fn auth_request(payload: ClientAuthPayload) -> UserAgentRequest {
UserAgentRequest {
payload: Some(UserAgentRequestPayload::AuthMessage(ClientMessage {
payload: Some(payload),
})),
}
}
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
pub async fn test_bootstrap_token_auth() { pub async fn test_bootstrap_token_auth() {
let db =db::create_test_pool().await; let db = db::create_test_pool().await;
let actors = GlobalActors::spawn(db.clone()).await.unwrap(); let actors = GlobalActors::spawn(db.clone()).await.unwrap();
let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap(); let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap();
@@ -38,25 +28,21 @@ pub async fn test_bootstrap_token_auth() {
let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec();
let result = user_agent let result = user_agent
.process_transport_inbound(auth_request(ClientAuthPayload::AuthChallengeRequest( .process_transport_inbound(UserAgentRequest {
AuthChallengeRequest { payload: Some(UserAgentRequestPayload::AuthChallengeRequest(
pubkey: pubkey_bytes, AuthChallengeRequest {
bootstrap_token: Some(token), pubkey: pubkey_bytes,
}, bootstrap_token: Some(token),
))) },
)),
})
.await .await
.expect("Shouldn't fail to process message"); .expect("Shouldn't fail to process message");
assert_eq!( assert_eq!(
result, result,
UserAgentResponse { UserAgentResponse {
payload: Some(UserAgentResponsePayload::AuthMessage( payload: Some(UserAgentResponsePayload::AuthOk(AuthOk {})),
arbiter_proto::proto::auth::ServerMessage {
payload: Some(arbiter_proto::proto::auth::server_message::Payload::AuthOk(
AuthOk {},
)),
},
)),
} }
); );
@@ -81,12 +67,14 @@ pub async fn test_bootstrap_invalid_token_auth() {
let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec();
let result = user_agent let result = user_agent
.process_transport_inbound(auth_request(ClientAuthPayload::AuthChallengeRequest( .process_transport_inbound(UserAgentRequest {
AuthChallengeRequest { payload: Some(UserAgentRequestPayload::AuthChallengeRequest(
pubkey: pubkey_bytes, AuthChallengeRequest {
bootstrap_token: Some("invalid_token".to_string()), pubkey: pubkey_bytes,
}, bootstrap_token: Some("invalid_token".to_string()),
))) },
)),
})
.await; .await;
match result { match result {
@@ -120,49 +108,43 @@ pub async fn test_challenge_auth() {
} }
let result = user_agent let result = user_agent
.process_transport_inbound(auth_request(ClientAuthPayload::AuthChallengeRequest( .process_transport_inbound(UserAgentRequest {
AuthChallengeRequest { payload: Some(UserAgentRequestPayload::AuthChallengeRequest(
pubkey: pubkey_bytes, AuthChallengeRequest {
bootstrap_token: None, pubkey: pubkey_bytes,
}, bootstrap_token: None,
))) },
)),
})
.await .await
.expect("Shouldn't fail to process message"); .expect("Shouldn't fail to process message");
let UserAgentResponse { let UserAgentResponse {
payload: payload: Some(UserAgentResponsePayload::AuthChallenge(challenge)),
Some(UserAgentResponsePayload::AuthMessage(arbiter_proto::proto::auth::ServerMessage {
payload:
Some(arbiter_proto::proto::auth::server_message::Payload::AuthChallenge(challenge)),
})),
} = result } = result
else { else {
panic!("Expected auth challenge response, got {result:?}"); panic!("Expected auth challenge response, got {result:?}");
}; };
let formatted_challenge = arbiter_proto::format_challenge(&challenge); let formatted_challenge = arbiter_proto::format_challenge(challenge.nonce, &challenge.pubkey);
let signature = new_key.sign(&formatted_challenge); let signature = new_key.sign(&formatted_challenge);
let serialized_signature = signature.to_bytes().to_vec(); let serialized_signature = signature.to_bytes().to_vec();
let result = user_agent let result = user_agent
.process_transport_inbound(auth_request(ClientAuthPayload::AuthChallengeSolution( .process_transport_inbound(UserAgentRequest {
auth::AuthChallengeSolution { payload: Some(UserAgentRequestPayload::AuthChallengeSolution(
signature: serialized_signature, AuthChallengeSolution {
}, signature: serialized_signature,
))) },
)),
})
.await .await
.expect("Shouldn't fail to process message"); .expect("Shouldn't fail to process message");
assert_eq!( assert_eq!(
result, result,
UserAgentResponse { UserAgentResponse {
payload: Some(UserAgentResponsePayload::AuthMessage( payload: Some(UserAgentResponsePayload::AuthOk(AuthOk {})),
arbiter_proto::proto::auth::ServerMessage {
payload: Some(arbiter_proto::proto::auth::server_message::Payload::AuthOk(
AuthOk {},
)),
},
)),
} }
); );
} }

View File

@@ -1,6 +1,6 @@
use arbiter_proto::proto::{ use arbiter_proto::proto::user_agent::{
UnsealEncryptedKey, UnsealResult, UnsealStart, UserAgentRequest, UserAgentResponse, AuthChallengeRequest, UnsealEncryptedKey, UnsealResult, UnsealStart,
auth::{AuthChallengeRequest, ClientMessage, client_message::Payload as ClientAuthPayload}, UserAgentRequest, UserAgentResponse,
user_agent_request::Payload as UserAgentRequestPayload, user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload, user_agent_response::Payload as UserAgentResponsePayload,
}; };
@@ -21,26 +21,6 @@ use x25519_dalek::{EphemeralSecret, PublicKey};
type TestUserAgent = type TestUserAgent =
UserAgentActor<DummyTransport<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>>; UserAgentActor<DummyTransport<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>>;
fn auth_request(payload: ClientAuthPayload) -> UserAgentRequest {
UserAgentRequest {
payload: Some(UserAgentRequestPayload::AuthMessage(ClientMessage {
payload: Some(payload),
})),
}
}
fn unseal_start_request(req: UnsealStart) -> UserAgentRequest {
UserAgentRequest {
payload: Some(UserAgentRequestPayload::UnsealStart(req)),
}
}
fn unseal_key_request(req: UnsealEncryptedKey) -> UserAgentRequest {
UserAgentRequest {
payload: Some(UserAgentRequestPayload::UnsealEncryptedKey(req)),
}
}
async fn setup_authenticated_user_agent( async fn setup_authenticated_user_agent(
seal_key: &[u8], seal_key: &[u8],
) -> ( ) -> (
@@ -64,12 +44,14 @@ async fn setup_authenticated_user_agent(
let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap(); let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap();
let auth_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); let auth_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
user_agent user_agent
.process_transport_inbound(auth_request(ClientAuthPayload::AuthChallengeRequest( .process_transport_inbound(UserAgentRequest {
AuthChallengeRequest { payload: Some(UserAgentRequestPayload::AuthChallengeRequest(
pubkey: auth_key.verifying_key().to_bytes().to_vec(), AuthChallengeRequest {
bootstrap_token: Some(token), pubkey: auth_key.verifying_key().to_bytes().to_vec(),
}, bootstrap_token: Some(token),
))) },
)),
})
.await .await
.unwrap(); .unwrap();
@@ -84,9 +66,11 @@ async fn client_dh_encrypt(
let client_public = PublicKey::from(&client_secret); let client_public = PublicKey::from(&client_secret);
let response = user_agent let response = user_agent
.process_transport_inbound(unseal_start_request(UnsealStart { .process_transport_inbound(UserAgentRequest {
client_pubkey: client_public.as_bytes().to_vec(), payload: Some(UserAgentRequestPayload::UnsealStart(UnsealStart {
})) client_pubkey: client_public.as_bytes().to_vec(),
})),
})
.await .await
.unwrap(); .unwrap();
@@ -112,6 +96,12 @@ async fn client_dh_encrypt(
} }
} }
fn unseal_key_request(req: UnsealEncryptedKey) -> UserAgentRequest {
UserAgentRequest {
payload: Some(UserAgentRequestPayload::UnsealEncryptedKey(req)),
}
}
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
pub async fn test_unseal_success() { pub async fn test_unseal_success() {
@@ -158,9 +148,11 @@ pub async fn test_unseal_corrupted_ciphertext() {
let client_public = PublicKey::from(&client_secret); let client_public = PublicKey::from(&client_secret);
user_agent user_agent
.process_transport_inbound(unseal_start_request(UnsealStart { .process_transport_inbound(UserAgentRequest {
client_pubkey: client_public.as_bytes().to_vec(), payload: Some(UserAgentRequestPayload::UnsealStart(UnsealStart {
})) client_pubkey: client_public.as_bytes().to_vec(),
})),
})
.await .await
.unwrap(); .unwrap();
@@ -191,9 +183,11 @@ pub async fn test_unseal_start_without_auth_fails() {
let client_public = PublicKey::from(&client_secret); let client_public = PublicKey::from(&client_secret);
let result = user_agent let result = user_agent
.process_transport_inbound(unseal_start_request(UnsealStart { .process_transport_inbound(UserAgentRequest {
client_pubkey: client_public.as_bytes().to_vec(), payload: Some(UserAgentRequestPayload::UnsealStart(UnsealStart {
})) client_pubkey: client_public.as_bytes().to_vec(),
})),
})
.await; .await;
match result { match result {

View File

@@ -1,8 +1,9 @@
use arbiter_proto::{ use arbiter_proto::{
proto::{ proto::{
UserAgentRequest, UserAgentResponse, arbiter_service_client::ArbiterServiceClient, user_agent::{UserAgentRequest, UserAgentResponse},
arbiter_service_client::ArbiterServiceClient,
}, },
transport::{IdentityRecvConverter, IdentitySendConverter, RecvConverter, grpc}, transport::{IdentityRecvConverter, IdentitySendConverter, grpc},
url::ArbiterUrl, url::ArbiterUrl,
}; };
use ed25519_dalek::SigningKey; use ed25519_dalek::SigningKey;

View File

@@ -1,12 +1,8 @@
use arbiter_proto::{ use arbiter_proto::{
format_challenge, format_challenge,
proto::{ proto::user_agent::{
AuthChallengeRequest, AuthChallengeSolution, AuthOk,
UserAgentRequest, UserAgentResponse, UserAgentRequest, UserAgentResponse,
auth::{
self, AuthChallengeRequest, AuthOk, ClientMessage as AuthClientMessage,
ServerMessage as AuthServerMessage, client_message::Payload as ClientAuthPayload,
server_message::Payload as ServerAuthPayload,
},
user_agent_request::Payload as UserAgentRequestPayload, user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload, user_agent_response::Payload as UserAgentResponsePayload,
}, },
@@ -81,14 +77,6 @@ where
Ok(()) Ok(())
} }
fn auth_request(payload: ClientAuthPayload) -> UserAgentRequest {
UserAgentRequest {
payload: Some(UserAgentRequestPayload::AuthMessage(AuthClientMessage {
payload: Some(payload),
})),
}
}
async fn send_auth_challenge_request(&mut self) -> Result<(), InboundError> { async fn send_auth_challenge_request(&mut self) -> Result<(), InboundError> {
let req = AuthChallengeRequest { let req = AuthChallengeRequest {
pubkey: self.key.verifying_key().to_bytes().to_vec(), pubkey: self.key.verifying_key().to_bytes().to_vec(),
@@ -98,9 +86,9 @@ where
self.transition(UserAgentEvents::SentAuthChallengeRequest)?; self.transition(UserAgentEvents::SentAuthChallengeRequest)?;
self.transport self.transport
.send(Self::auth_request(ClientAuthPayload::AuthChallengeRequest( .send(UserAgentRequest {
req, payload: Some(UserAgentRequestPayload::AuthChallengeRequest(req)),
))) })
.await .await
.map_err(|_| InboundError::TransportSendFailed)?; .map_err(|_| InboundError::TransportSendFailed)?;
@@ -110,20 +98,20 @@ where
async fn handle_auth_challenge( async fn handle_auth_challenge(
&mut self, &mut self,
challenge: auth::AuthChallenge, challenge: arbiter_proto::proto::user_agent::AuthChallenge,
) -> Result<(), InboundError> { ) -> Result<(), InboundError> {
self.transition(UserAgentEvents::ReceivedAuthChallenge)?; self.transition(UserAgentEvents::ReceivedAuthChallenge)?;
let formatted = format_challenge(&challenge); let formatted = format_challenge(challenge.nonce, &challenge.pubkey);
let signature = self.key.sign(&formatted); let signature = self.key.sign(&formatted);
let solution = auth::AuthChallengeSolution { let solution = AuthChallengeSolution {
signature: signature.to_bytes().to_vec(), signature: signature.to_bytes().to_vec(),
}; };
self.transport self.transport
.send(Self::auth_request( .send(UserAgentRequest {
ClientAuthPayload::AuthChallengeSolution(solution), payload: Some(UserAgentRequestPayload::AuthChallengeSolution(solution)),
)) })
.await .await
.map_err(|_| InboundError::TransportSendFailed)?; .map_err(|_| InboundError::TransportSendFailed)?;
@@ -146,12 +134,10 @@ where
.ok_or(InboundError::MissingResponsePayload)?; .ok_or(InboundError::MissingResponsePayload)?;
match payload { match payload {
UserAgentResponsePayload::AuthMessage(AuthServerMessage { UserAgentResponsePayload::AuthChallenge(challenge) => {
payload: Some(ServerAuthPayload::AuthChallenge(challenge)), self.handle_auth_challenge(challenge).await
}) => self.handle_auth_challenge(challenge).await, }
UserAgentResponsePayload::AuthMessage(AuthServerMessage { UserAgentResponsePayload::AuthOk(ok) => self.handle_auth_ok(ok),
payload: Some(ServerAuthPayload::AuthOk(ok)),
}) => self.handle_auth_ok(ok),
_ => Err(InboundError::UnexpectedResponsePayload), _ => Err(InboundError::UnexpectedResponsePayload),
} }
} }

View File

@@ -1,18 +1,14 @@
use arbiter_proto::{ use arbiter_proto::{
format_challenge, format_challenge,
proto::{ proto::user_agent::{
AuthChallenge, AuthOk,
UserAgentRequest, UserAgentResponse, UserAgentRequest, UserAgentResponse,
auth::{
AuthChallenge, AuthOk, ClientMessage as AuthClientMessage,
ServerMessage as AuthServerMessage, client_message::Payload as ClientAuthPayload,
server_message::Payload as ServerAuthPayload,
},
user_agent_request::Payload as UserAgentRequestPayload, user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload, user_agent_response::Payload as UserAgentResponsePayload,
}, },
transport::Bi, transport::Bi,
}; };
use arbiter_useragent::{InboundError, UserAgentActor}; use arbiter_useragent::UserAgentActor;
use ed25519_dalek::SigningKey; use ed25519_dalek::SigningKey;
use kameo::actor::Spawn; use kameo::actor::Spawn;
use tokio::sync::mpsc; use tokio::sync::mpsc;
@@ -57,14 +53,6 @@ fn test_key() -> SigningKey {
SigningKey::from_bytes(&[7u8; 32]) SigningKey::from_bytes(&[7u8; 32])
} }
fn auth_response(payload: ServerAuthPayload) -> UserAgentResponse {
UserAgentResponse {
payload: Some(UserAgentResponsePayload::AuthMessage(AuthServerMessage {
payload: Some(payload),
})),
}
}
#[tokio::test] #[tokio::test]
async fn sends_auth_request_on_start_with_bootstrap_token() { async fn sends_auth_request_on_start_with_bootstrap_token() {
let key = test_key(); let key = test_key();
@@ -80,9 +68,7 @@ async fn sends_auth_request_on_start_with_bootstrap_token() {
.expect("channel closed before auth request"); .expect("channel closed before auth request");
let UserAgentRequest { let UserAgentRequest {
payload: Some(UserAgentRequestPayload::AuthMessage(AuthClientMessage { payload: Some(UserAgentRequestPayload::AuthChallengeRequest(req)),
payload: Some(ClientAuthPayload::AuthChallengeRequest(req)),
})),
} = outbound } = outbound
else { else {
panic!("expected auth challenge request"); panic!("expected auth challenge request");
@@ -113,7 +99,9 @@ async fn challenge_flow_sends_solution_from_transport_inbound() {
nonce: 42, nonce: 42,
}; };
inbound_tx inbound_tx
.send(auth_response(ServerAuthPayload::AuthChallenge(challenge.clone()))) .send(UserAgentResponse {
payload: Some(UserAgentResponsePayload::AuthChallenge(challenge.clone())),
})
.await .await
.unwrap(); .unwrap();
@@ -123,15 +111,13 @@ async fn challenge_flow_sends_solution_from_transport_inbound() {
.expect("missing challenge solution"); .expect("missing challenge solution");
let UserAgentRequest { let UserAgentRequest {
payload: Some(UserAgentRequestPayload::AuthMessage(AuthClientMessage { payload: Some(UserAgentRequestPayload::AuthChallengeSolution(solution)),
payload: Some(ClientAuthPayload::AuthChallengeSolution(solution)),
})),
} = outbound } = outbound
else { else {
panic!("expected auth challenge solution"); panic!("expected auth challenge solution");
}; };
let formatted = format_challenge(&challenge); let formatted = format_challenge(challenge.nonce, &challenge.pubkey);
let sig: ed25519_dalek::Signature = solution let sig: ed25519_dalek::Signature = solution
.signature .signature
.as_slice() .as_slice()
@@ -142,7 +128,9 @@ async fn challenge_flow_sends_solution_from_transport_inbound() {
.expect("solution signature should verify"); .expect("solution signature should verify");
inbound_tx inbound_tx
.send(auth_response(ServerAuthPayload::AuthOk(AuthOk {}))) .send(UserAgentResponse {
payload: Some(UserAgentResponsePayload::AuthOk(AuthOk {})),
})
.await .await
.unwrap(); .unwrap();