refactor: consolidate auth messages into client and user_agent packages

This commit is contained in:
hdbg
2026-03-01 11:44:34 +01:00
parent 06f4d628db
commit 4b4a8f4489
19 changed files with 686 additions and 264 deletions

View File

@@ -2,7 +2,6 @@ syntax = "proto3";
package arbiter;
import "auth.proto";
import "client.proto";
import "user_agent.proto";
@@ -12,6 +11,6 @@ message ServerInfo {
}
service ArbiterService {
rpc Client(stream ClientRequest) returns (stream ClientResponse);
rpc UserAgent(stream UserAgentRequest) returns (stream UserAgentResponse);
rpc Client(stream arbiter.client.ClientRequest) returns (stream arbiter.client.ClientResponse);
rpc UserAgent(stream arbiter.user_agent.UserAgentRequest) returns (stream arbiter.user_agent.UserAgentResponse);
}

View File

@@ -1,35 +0,0 @@
syntax = "proto3";
package arbiter.auth;
import "google/protobuf/timestamp.proto";
message AuthChallengeRequest {
bytes pubkey = 1;
optional string bootstrap_token = 2;
}
message AuthChallenge {
bytes pubkey = 1;
int32 nonce = 2;
}
message AuthChallengeSolution {
bytes signature = 1;
}
message AuthOk {}
message ClientMessage {
oneof payload {
AuthChallengeRequest auth_challenge_request = 1;
AuthChallengeSolution auth_challenge_solution = 2;
}
}
message ServerMessage {
oneof payload {
AuthChallenge auth_challenge = 1;
AuthOk auth_ok = 2;
}
}

View File

@@ -1,17 +1,32 @@
syntax = "proto3";
package arbiter;
package arbiter.client;
import "auth.proto";
message AuthChallengeRequest {
bytes pubkey = 1;
}
message AuthChallenge {
bytes pubkey = 1;
int32 nonce = 2;
}
message AuthChallengeSolution {
bytes signature = 1;
}
message AuthOk {}
message ClientRequest {
oneof payload {
arbiter.auth.ClientMessage auth_message = 1;
AuthChallengeRequest auth_challenge_request = 1;
AuthChallengeSolution auth_challenge_solution = 2;
}
}
message ClientResponse {
oneof payload {
arbiter.auth.ServerMessage auth_message = 1;
AuthChallenge auth_challenge = 1;
AuthOk auth_ok = 2;
}
}

View File

@@ -1,10 +1,25 @@
syntax = "proto3";
package arbiter;
package arbiter.user_agent;
import "auth.proto";
import "google/protobuf/empty.proto";
message AuthChallengeRequest {
bytes pubkey = 1;
optional string bootstrap_token = 2;
}
message AuthChallenge {
bytes pubkey = 1;
int32 nonce = 2;
}
message AuthChallengeSolution {
bytes signature = 1;
}
message AuthOk {}
message UnsealStart {
bytes client_pubkey = 1;
}
@@ -35,17 +50,19 @@ enum VaultState {
message UserAgentRequest {
oneof payload {
arbiter.auth.ClientMessage auth_message = 1;
UnsealStart unseal_start = 2;
UnsealEncryptedKey unseal_encrypted_key = 3;
google.protobuf.Empty query_vault_state = 4;
AuthChallengeRequest auth_challenge_request = 1;
AuthChallengeSolution auth_challenge_solution = 2;
UnsealStart unseal_start = 3;
UnsealEncryptedKey unseal_encrypted_key = 4;
google.protobuf.Empty query_vault_state = 5;
}
}
message UserAgentResponse {
oneof payload {
arbiter.auth.ServerMessage auth_message = 1;
UnsealStartResponse unseal_start_response = 2;
UnsealResult unseal_result = 3;
VaultState vault_state = 4;
AuthChallenge auth_challenge = 1;
AuthOk auth_ok = 2;
UnsealStartResponse unseal_start_response = 3;
UnsealResult unseal_result = 4;
VaultState vault_state = 5;
}
}

View File

@@ -11,7 +11,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.compile_protos(
&[
format!("{}/arbiter.proto", PROTOBUF_DIR),
format!("{}/auth.proto", PROTOBUF_DIR),
format!("{}/user_agent.proto", PROTOBUF_DIR),
format!("{}/client.proto", PROTOBUF_DIR),
],
&[PROTOBUF_DIR.to_string()],
)

View File

@@ -3,13 +3,15 @@ pub mod url;
use base64::{Engine, prelude::BASE64_STANDARD};
use crate::proto::auth::AuthChallenge;
pub mod proto {
tonic::include_proto!("arbiter");
pub mod auth {
tonic::include_proto!("arbiter.auth");
pub mod user_agent {
tonic::include_proto!("arbiter.user_agent");
}
pub mod client {
tonic::include_proto!("arbiter.client");
}
}
@@ -28,7 +30,7 @@ pub fn home_path() -> Result<std::path::PathBuf, std::io::Error> {
Ok(arbiter_home)
}
pub fn format_challenge(challenge: &AuthChallenge) -> Vec<u8> {
let concat_form = format!("{}:{}", challenge.nonce, BASE64_STANDARD.encode(&challenge.pubkey));
concat_form.into_bytes().to_vec()
pub fn format_challenge(nonce: i32, pubkey: &[u8]) -> Vec<u8> {
let concat_form = format!("{}:{}", nonce, BASE64_STANDARD.encode(pubkey));
concat_form.into_bytes()
}

View File

@@ -1,12 +0,0 @@
use arbiter_proto::{
proto::{ClientRequest, ClientResponse},
transport::Bi,
};
use crate::ServerContext;
pub(crate) async fn handle_client(
_context: ServerContext,
_bistream: impl Bi<ClientRequest, ClientResponse>,
) {
}

View File

@@ -0,0 +1,289 @@
use arbiter_proto::{
proto::client::{
AuthChallenge, AuthChallengeRequest, AuthChallengeSolution, AuthOk, ClientRequest,
ClientResponse,
client_request::Payload as ClientRequestPayload,
client_response::Payload as ClientResponsePayload,
},
transport::{Bi, DummyTransport},
};
use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, dsl::update};
use diesel_async::RunQueryDsl;
use ed25519_dalek::VerifyingKey;
use kameo::Actor;
use tokio::select;
use tracing::{error, info};
use crate::{
ServerContext,
actors::client::state::{
ChallengeContext, ClientEvents, ClientStateMachine, ClientStates, DummyContext,
},
db::{self, schema},
};
mod state;
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum ClientError {
#[error("Expected message with payload")]
MissingRequestPayload,
#[error("Unexpected request payload")]
UnexpectedRequestPayload,
#[error("Invalid state for challenge solution")]
InvalidStateForChallengeSolution,
#[error("Expected pubkey to have specific length")]
InvalidAuthPubkeyLength,
#[error("Failed to convert pubkey to VerifyingKey")]
InvalidAuthPubkeyEncoding,
#[error("Invalid signature length")]
InvalidSignatureLength,
#[error("Public key not registered")]
PublicKeyNotRegistered,
#[error("Invalid challenge solution")]
InvalidChallengeSolution,
#[error("State machine error")]
StateTransitionFailed,
#[error("Database pool error")]
DatabasePoolUnavailable,
#[error("Database error")]
DatabaseOperationFailed,
}
pub struct ClientActor<Transport>
where
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
{
db: db::DatabasePool,
state: ClientStateMachine<DummyContext>,
transport: Transport,
}
impl<Transport> ClientActor<Transport>
where
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
{
pub(crate) fn new(context: ServerContext, transport: Transport) -> Self {
Self {
db: context.db.clone(),
state: ClientStateMachine::new(DummyContext),
transport,
}
}
fn transition(&mut self, event: ClientEvents) -> Result<(), ClientError> {
self.state.process_event(event).map_err(|e| {
error!(?e, "State transition failed");
ClientError::StateTransitionFailed
})?;
Ok(())
}
pub async fn process_transport_inbound(&mut self, req: ClientRequest) -> Output {
let msg = req.payload.ok_or_else(|| {
error!(actor = "client", "Received message with no payload");
ClientError::MissingRequestPayload
})?;
match msg {
ClientRequestPayload::AuthChallengeRequest(req) => {
self.handle_auth_challenge_request(req).await
}
ClientRequestPayload::AuthChallengeSolution(solution) => {
self.handle_auth_challenge_solution(solution).await
}
}
}
async fn handle_auth_challenge_request(&mut self, req: AuthChallengeRequest) -> Output {
let pubkey = req
.pubkey
.as_array()
.ok_or(ClientError::InvalidAuthPubkeyLength)?;
let pubkey = VerifyingKey::from_bytes(pubkey).map_err(|_err| {
error!(?pubkey, "Failed to convert to VerifyingKey");
ClientError::InvalidAuthPubkeyEncoding
})?;
self.transition(ClientEvents::AuthRequest)?;
self.auth_with_challenge(pubkey, req.pubkey).await
}
async fn auth_with_challenge(&mut self, pubkey: VerifyingKey, pubkey_bytes: Vec<u8>) -> Output {
let nonce: Option<i32> = {
let mut db_conn = self.db.get().await.map_err(|e| {
error!(error = ?e, "Database pool error");
ClientError::DatabasePoolUnavailable
})?;
db_conn
.exclusive_transaction(|conn| {
Box::pin(async move {
let current_nonce = schema::program_client::table
.filter(
schema::program_client::public_key.eq(pubkey.as_bytes().to_vec()),
)
.select(schema::program_client::nonce)
.first::<i32>(conn)
.await?;
update(schema::program_client::table)
.filter(
schema::program_client::public_key.eq(pubkey.as_bytes().to_vec()),
)
.set(schema::program_client::nonce.eq(current_nonce + 1))
.execute(conn)
.await?;
Result::<_, diesel::result::Error>::Ok(current_nonce)
})
})
.await
.optional()
.map_err(|e| {
error!(error = ?e, "Database error");
ClientError::DatabaseOperationFailed
})?
};
let Some(nonce) = nonce else {
error!(?pubkey, "Public key not found in database");
return Err(ClientError::PublicKeyNotRegistered);
};
let challenge = AuthChallenge {
pubkey: pubkey_bytes,
nonce,
};
self.transition(ClientEvents::SentChallenge(ChallengeContext {
challenge: challenge.clone(),
key: pubkey,
}))?;
info!(
?pubkey,
?challenge,
"Sent authentication challenge to client"
);
Ok(response(ClientResponsePayload::AuthChallenge(challenge)))
}
fn verify_challenge_solution(
&self,
solution: &AuthChallengeSolution,
) -> Result<(bool, &ChallengeContext), ClientError> {
let ClientStates::WaitingForChallengeSolution(challenge_context) = self.state.state()
else {
error!("Received challenge solution in invalid state");
return Err(ClientError::InvalidStateForChallengeSolution);
};
let formatted_challenge = arbiter_proto::format_challenge(
challenge_context.challenge.nonce,
&challenge_context.challenge.pubkey,
);
let signature = solution.signature.as_slice().try_into().map_err(|_| {
error!(?solution, "Invalid signature length");
ClientError::InvalidSignatureLength
})?;
let valid = challenge_context
.key
.verify_strict(&formatted_challenge, &signature)
.is_ok();
Ok((valid, challenge_context))
}
async fn handle_auth_challenge_solution(
&mut self,
solution: AuthChallengeSolution,
) -> Output {
let (valid, challenge_context) = self.verify_challenge_solution(&solution)?;
if valid {
info!(
?challenge_context,
"Client provided valid solution to authentication challenge"
);
self.transition(ClientEvents::ReceivedGoodSolution)?;
Ok(response(ClientResponsePayload::AuthOk(AuthOk {})))
} else {
error!("Client provided invalid solution to authentication challenge");
self.transition(ClientEvents::ReceivedBadSolution)?;
Err(ClientError::InvalidChallengeSolution)
}
}
}
type Output = Result<ClientResponse, ClientError>;
fn response(payload: ClientResponsePayload) -> ClientResponse {
ClientResponse {
payload: Some(payload),
}
}
impl<Transport> Actor for ClientActor<Transport>
where
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
{
type Args = Self;
type Error = ();
async fn on_start(
args: Self::Args,
_: kameo::prelude::ActorRef<Self>,
) -> Result<Self, Self::Error> {
Ok(args)
}
async fn next(
&mut self,
_actor_ref: kameo::prelude::WeakActorRef<Self>,
mailbox_rx: &mut kameo::prelude::MailboxReceiver<Self>,
) -> Option<kameo::mailbox::Signal<Self>> {
loop {
select! {
signal = mailbox_rx.recv() => {
return signal;
}
msg = self.transport.recv() => {
match msg {
Some(request) => {
match self.process_transport_inbound(request).await {
Ok(resp) => {
if self.transport.send(Ok(resp)).await.is_err() {
error!(actor = "client", reason = "channel closed", "send.failed");
return Some(kameo::mailbox::Signal::Stop);
}
}
Err(err) => {
let _ = self.transport.send(Err(err)).await;
return Some(kameo::mailbox::Signal::Stop);
}
}
}
None => {
info!(actor = "client", "transport.closed");
return Some(kameo::mailbox::Signal::Stop);
}
}
}
}
}
}
}
impl ClientActor<DummyTransport<ClientRequest, Result<ClientResponse, ClientError>>> {
pub fn new_manual(db: db::DatabasePool) -> Self {
Self {
db,
state: ClientStateMachine::new(DummyContext),
transport: DummyTransport::new(),
}
}
}

View File

@@ -0,0 +1,31 @@
use arbiter_proto::proto::client::AuthChallenge;
use ed25519_dalek::VerifyingKey;
/// Context for state machine with validated key and sent challenge
#[derive(Clone, Debug)]
pub struct ChallengeContext {
pub challenge: AuthChallenge,
pub key: VerifyingKey,
}
smlang::statemachine!(
name: Client,
custom_error: false,
transitions: {
*Init + AuthRequest = ReceivedAuthRequest,
ReceivedAuthRequest + SentChallenge(ChallengeContext) / move_challenge = WaitingForChallengeSolution(ChallengeContext),
WaitingForChallengeSolution(ChallengeContext) + ReceivedGoodSolution = Idle,
WaitingForChallengeSolution(ChallengeContext) + ReceivedBadSolution = AuthError,
}
);
pub struct DummyContext;
impl ClientStateMachineContext for DummyContext {
#[allow(missing_docs)]
#[allow(clippy::unused_unit)]
fn move_challenge(&mut self, event_data: ChallengeContext) -> Result<ChallengeContext, ()> {
Ok(event_data)
}
}

View File

@@ -1,14 +1,9 @@
use std::{ops::DerefMut, sync::Mutex};
use arbiter_proto::{
proto::{
UnsealEncryptedKey, UnsealResult, UnsealStart, UnsealStartResponse, UserAgentRequest,
UserAgentResponse,
auth::{
self, AuthChallengeRequest, AuthOk, ClientMessage as ClientAuthMessage,
ServerMessage as AuthServerMessage, client_message::Payload as ClientAuthPayload,
server_message::Payload as ServerAuthPayload,
},
proto::user_agent::{
AuthChallenge, AuthChallengeRequest, AuthChallengeSolution, AuthOk, UnsealEncryptedKey,
UnsealResult, UnsealStart, UnsealStartResponse, UserAgentRequest, UserAgentResponse,
user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload,
},
@@ -114,12 +109,12 @@ where
})?;
match msg {
UserAgentRequestPayload::AuthMessage(ClientAuthMessage {
payload: Some(ClientAuthPayload::AuthChallengeRequest(req)),
}) => self.handle_auth_challenge_request(req).await,
UserAgentRequestPayload::AuthMessage(ClientAuthMessage {
payload: Some(ClientAuthPayload::AuthChallengeSolution(solution)),
}) => self.handle_auth_challenge_solution(solution).await,
UserAgentRequestPayload::AuthChallengeRequest(req) => {
self.handle_auth_challenge_request(req).await
}
UserAgentRequestPayload::AuthChallengeSolution(solution) => {
self.handle_auth_challenge_solution(solution).await
}
UserAgentRequestPayload::UnsealStart(unseal_start) => {
self.handle_unseal_request(unseal_start).await
}
@@ -171,7 +166,7 @@ where
self.transition(UserAgentEvents::ReceivedBootstrapToken)?;
Ok(auth_response(ServerAuthPayload::AuthOk(AuthOk {})))
Ok(response(UserAgentResponsePayload::AuthOk(AuthOk {})))
}
async fn auth_with_challenge(&mut self, pubkey: VerifyingKey, pubkey_bytes: Vec<u8>) -> Output {
@@ -215,7 +210,7 @@ where
return Err(UserAgentError::PublicKeyNotRegistered);
};
let challenge = auth::AuthChallenge {
let challenge = AuthChallenge {
pubkey: pubkey_bytes,
nonce,
};
@@ -231,19 +226,22 @@ where
"Sent authentication challenge to client"
);
Ok(auth_response(ServerAuthPayload::AuthChallenge(challenge)))
Ok(response(UserAgentResponsePayload::AuthChallenge(challenge)))
}
fn verify_challenge_solution(
&self,
solution: &auth::AuthChallengeSolution,
solution: &AuthChallengeSolution,
) -> Result<(bool, &ChallengeContext), UserAgentError> {
let UserAgentStates::WaitingForChallengeSolution(challenge_context) = self.state.state()
else {
error!("Received challenge solution in invalid state");
return Err(UserAgentError::InvalidStateForChallengeSolution);
};
let formatted_challenge = arbiter_proto::format_challenge(&challenge_context.challenge);
let formatted_challenge = arbiter_proto::format_challenge(
challenge_context.challenge.nonce,
&challenge_context.challenge.pubkey,
);
let signature = solution.signature.as_slice().try_into().map_err(|_| {
error!(?solution, "Invalid signature length");
@@ -261,15 +259,7 @@ where
type Output = Result<UserAgentResponse, UserAgentError>;
fn auth_response(payload: ServerAuthPayload) -> UserAgentResponse {
UserAgentResponse {
payload: Some(UserAgentResponsePayload::AuthMessage(AuthServerMessage {
payload: Some(payload),
})),
}
}
fn unseal_response(payload: UserAgentResponsePayload) -> UserAgentResponse {
fn response(payload: UserAgentResponsePayload) -> UserAgentResponse {
UserAgentResponse {
payload: Some(payload),
}
@@ -295,7 +285,7 @@ where
client_public_key,
}))?;
Ok(unseal_response(
Ok(response(
UserAgentResponsePayload::UnsealStartResponse(UnsealStartResponse {
server_pubkey: public_key.as_bytes().to_vec(),
}),
@@ -316,7 +306,7 @@ where
drop(secret_lock);
error!("Ephemeral secret already taken");
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
return Ok(unseal_response(UserAgentResponsePayload::UnsealResult(
return Ok(response(UserAgentResponsePayload::UnsealResult(
UnsealResult::InvalidKey.into(),
)));
}
@@ -349,20 +339,20 @@ where
Ok(_) => {
info!("Successfully unsealed key with client-provided key");
self.transition(UserAgentEvents::ReceivedValidKey)?;
Ok(unseal_response(UserAgentResponsePayload::UnsealResult(
Ok(response(UserAgentResponsePayload::UnsealResult(
UnsealResult::Success.into(),
)))
}
Err(SendError::HandlerError(keyholder::Error::InvalidKey)) => {
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Ok(unseal_response(UserAgentResponsePayload::UnsealResult(
Ok(response(UserAgentResponsePayload::UnsealResult(
UnsealResult::InvalidKey.into(),
)))
}
Err(SendError::HandlerError(err)) => {
error!(?err, "Keyholder failed to unseal key");
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Ok(unseal_response(UserAgentResponsePayload::UnsealResult(
Ok(response(UserAgentResponsePayload::UnsealResult(
UnsealResult::InvalidKey.into(),
)))
}
@@ -376,7 +366,7 @@ where
Err(err) => {
error!(?err, "Failed to decrypt unseal key");
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Ok(unseal_response(UserAgentResponsePayload::UnsealResult(
Ok(response(UserAgentResponsePayload::UnsealResult(
UnsealResult::InvalidKey.into(),
)))
}
@@ -403,7 +393,7 @@ where
async fn handle_auth_challenge_solution(
&mut self,
solution: auth::AuthChallengeSolution,
solution: AuthChallengeSolution,
) -> Output {
let (valid, challenge_context) = self.verify_challenge_solution(&solution)?;
@@ -413,7 +403,7 @@ where
"Client provided valid solution to authentication challenge"
);
self.transition(UserAgentEvents::ReceivedGoodSolution)?;
Ok(auth_response(ServerAuthPayload::AuthOk(AuthOk {})))
Ok(response(UserAgentResponsePayload::AuthOk(AuthOk {})))
} else {
error!("Client provided invalid solution to authentication challenge");
self.transition(UserAgentEvents::ReceivedBadSolution)?;

View File

@@ -1,6 +1,6 @@
use std::sync::Mutex;
use arbiter_proto::proto::auth::AuthChallenge;
use arbiter_proto::proto::user_agent::AuthChallenge;
use ed25519_dalek::VerifyingKey;
use x25519_dalek::{EphemeralSecret, PublicKey};

View File

@@ -1,6 +1,9 @@
#![forbid(unsafe_code)]
use arbiter_proto::{
proto::{ClientRequest, ClientResponse, UserAgentRequest, UserAgentResponse},
proto::{
client::{ClientRequest, ClientResponse},
user_agent::{UserAgentRequest, UserAgentResponse},
},
transport::{IdentityRecvConverter, SendConverter, grpc},
};
use async_trait::async_trait;
@@ -12,7 +15,10 @@ use tonic::{Request, Response, Status};
use tracing::info;
use crate::{
actors::user_agent::{UserAgentActor, UserAgentError},
actors::{
client::{ClientActor, ClientError},
user_agent::{UserAgentActor, UserAgentError},
},
context::ServerContext,
};
@@ -41,6 +47,56 @@ impl SendConverter for UserAgentGrpcSender {
}
}
/// Converts Client domain outbounds into the tonic stream item emitted by the
/// server.
///
/// The conversion is defined at the server boundary so the actor module remains
/// focused on domain semantics and does not depend on tonic status encoding.
struct ClientGrpcSender;
impl SendConverter for ClientGrpcSender {
type Input = Result<ClientResponse, ClientError>;
type Output = Result<ClientResponse, Status>;
fn convert(&self, item: Self::Input) -> Self::Output {
match item {
Ok(message) => Ok(message),
Err(err) => Err(client_error_status(err)),
}
}
}
/// Maps Client domain errors to public gRPC transport errors for the `client`
/// streaming endpoint.
fn client_error_status(value: ClientError) -> Status {
match value {
ClientError::MissingRequestPayload | ClientError::UnexpectedRequestPayload => {
Status::invalid_argument("Expected message with payload")
}
ClientError::InvalidStateForChallengeSolution => {
Status::invalid_argument("Invalid state for challenge solution")
}
ClientError::InvalidAuthPubkeyLength => {
Status::invalid_argument("Expected pubkey to have specific length")
}
ClientError::InvalidAuthPubkeyEncoding => {
Status::invalid_argument("Failed to convert pubkey to VerifyingKey")
}
ClientError::InvalidSignatureLength => {
Status::invalid_argument("Invalid signature length")
}
ClientError::PublicKeyNotRegistered => {
Status::unauthenticated("Public key not registered")
}
ClientError::InvalidChallengeSolution => {
Status::unauthenticated("Invalid challenge solution")
}
ClientError::StateTransitionFailed => Status::internal("State machine error"),
ClientError::DatabasePoolUnavailable => Status::internal("Database pool error"),
ClientError::DatabaseOperationFailed => Status::internal("Database error"),
}
}
/// Maps User Agent domain errors to public gRPC transport errors for the
/// `user_agent` streaming endpoint.
fn user_agent_error_status(value: UserAgentError) -> Status {
@@ -100,11 +156,25 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server {
type UserAgentStream = ReceiverStream<Result<UserAgentResponse, Status>>;
type ClientStream = ReceiverStream<Result<ClientResponse, Status>>;
#[tracing::instrument(level = "debug", skip(self))]
async fn client(
&self,
_request: Request<tonic::Streaming<ClientRequest>>,
request: Request<tonic::Streaming<ClientRequest>>,
) -> Result<Response<Self::ClientStream>, Status> {
todo!()
let req_stream = request.into_inner();
let (tx, rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);
let transport = grpc::GrpcAdapter::new(
tx,
req_stream,
IdentityRecvConverter::<ClientRequest>::new(),
ClientGrpcSender,
);
ClientActor::spawn(ClientActor::new(self.context.clone(), transport));
info!(event = "connection established", "grpc.client");
Ok(Response::new(ReceiverStream::new(rx)))
}
#[tracing::instrument(level = "debug", skip(self))]

View File

@@ -0,0 +1,2 @@
#[path = "client/auth.rs"]
mod auth;

View File

@@ -0,0 +1,102 @@
use arbiter_proto::proto::client::{
AuthChallengeRequest, AuthChallengeSolution, AuthOk, ClientRequest, ClientResponse,
client_request::Payload as ClientRequestPayload,
client_response::Payload as ClientResponsePayload,
};
use arbiter_server::{
actors::client::{ClientActor, ClientError},
db::{self, schema},
};
use diesel::{ExpressionMethods as _, insert_into};
use diesel_async::RunQueryDsl;
use ed25519_dalek::Signer as _;
#[tokio::test]
#[test_log::test]
pub async fn test_unregistered_pubkey_rejected() {
let db = db::create_test_pool().await;
let mut client = ClientActor::new_manual(db.clone());
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec();
let result = client
.process_transport_inbound(ClientRequest {
payload: Some(ClientRequestPayload::AuthChallengeRequest(
AuthChallengeRequest {
pubkey: pubkey_bytes,
},
)),
})
.await;
match result {
Err(err) => {
assert_eq!(err, ClientError::PublicKeyNotRegistered);
}
Ok(_) => {
panic!("Expected error due to unregistered pubkey, but got success");
}
}
}
#[tokio::test]
#[test_log::test]
pub async fn test_challenge_auth() {
let db = db::create_test_pool().await;
let mut client = ClientActor::new_manual(db.clone());
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec();
{
let mut conn = db.get().await.unwrap();
insert_into(schema::program_client::table)
.values(schema::program_client::public_key.eq(pubkey_bytes.clone()))
.execute(&mut conn)
.await
.unwrap();
}
let result = client
.process_transport_inbound(ClientRequest {
payload: Some(ClientRequestPayload::AuthChallengeRequest(
AuthChallengeRequest {
pubkey: pubkey_bytes,
},
)),
})
.await
.expect("Shouldn't fail to process message");
let ClientResponse {
payload: Some(ClientResponsePayload::AuthChallenge(challenge)),
} = result
else {
panic!("Expected auth challenge response, got {result:?}");
};
let formatted_challenge = arbiter_proto::format_challenge(challenge.nonce, &challenge.pubkey);
let signature = new_key.sign(&formatted_challenge);
let serialized_signature = signature.to_bytes().to_vec();
let result = client
.process_transport_inbound(ClientRequest {
payload: Some(ClientRequestPayload::AuthChallengeSolution(
AuthChallengeSolution {
signature: serialized_signature,
},
)),
})
.await
.expect("Shouldn't fail to process message");
assert_eq!(
result,
ClientResponse {
payload: Some(ClientResponsePayload::AuthOk(AuthOk {})),
}
);
}

View File

@@ -1,7 +1,5 @@
use arbiter_proto::proto::{
UserAgentResponse,
UserAgentRequest,
auth::{self, AuthChallengeRequest, AuthOk, ClientMessage, client_message::Payload as ClientAuthPayload},
use arbiter_proto::proto::user_agent::{
AuthChallengeRequest, AuthChallengeSolution, AuthOk, UserAgentRequest, UserAgentResponse,
user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload,
};
@@ -17,18 +15,10 @@ use diesel::{ExpressionMethods as _, QueryDsl, insert_into};
use diesel_async::RunQueryDsl;
use ed25519_dalek::Signer as _;
fn auth_request(payload: ClientAuthPayload) -> UserAgentRequest {
UserAgentRequest {
payload: Some(UserAgentRequestPayload::AuthMessage(ClientMessage {
payload: Some(payload),
})),
}
}
#[tokio::test]
#[test_log::test]
pub async fn test_bootstrap_token_auth() {
let db =db::create_test_pool().await;
let db = db::create_test_pool().await;
let actors = GlobalActors::spawn(db.clone()).await.unwrap();
let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap();
@@ -38,25 +28,21 @@ pub async fn test_bootstrap_token_auth() {
let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec();
let result = user_agent
.process_transport_inbound(auth_request(ClientAuthPayload::AuthChallengeRequest(
AuthChallengeRequest {
pubkey: pubkey_bytes,
bootstrap_token: Some(token),
},
)))
.process_transport_inbound(UserAgentRequest {
payload: Some(UserAgentRequestPayload::AuthChallengeRequest(
AuthChallengeRequest {
pubkey: pubkey_bytes,
bootstrap_token: Some(token),
},
)),
})
.await
.expect("Shouldn't fail to process message");
assert_eq!(
result,
UserAgentResponse {
payload: Some(UserAgentResponsePayload::AuthMessage(
arbiter_proto::proto::auth::ServerMessage {
payload: Some(arbiter_proto::proto::auth::server_message::Payload::AuthOk(
AuthOk {},
)),
},
)),
payload: Some(UserAgentResponsePayload::AuthOk(AuthOk {})),
}
);
@@ -81,12 +67,14 @@ pub async fn test_bootstrap_invalid_token_auth() {
let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec();
let result = user_agent
.process_transport_inbound(auth_request(ClientAuthPayload::AuthChallengeRequest(
AuthChallengeRequest {
pubkey: pubkey_bytes,
bootstrap_token: Some("invalid_token".to_string()),
},
)))
.process_transport_inbound(UserAgentRequest {
payload: Some(UserAgentRequestPayload::AuthChallengeRequest(
AuthChallengeRequest {
pubkey: pubkey_bytes,
bootstrap_token: Some("invalid_token".to_string()),
},
)),
})
.await;
match result {
@@ -120,49 +108,43 @@ pub async fn test_challenge_auth() {
}
let result = user_agent
.process_transport_inbound(auth_request(ClientAuthPayload::AuthChallengeRequest(
AuthChallengeRequest {
pubkey: pubkey_bytes,
bootstrap_token: None,
},
)))
.process_transport_inbound(UserAgentRequest {
payload: Some(UserAgentRequestPayload::AuthChallengeRequest(
AuthChallengeRequest {
pubkey: pubkey_bytes,
bootstrap_token: None,
},
)),
})
.await
.expect("Shouldn't fail to process message");
let UserAgentResponse {
payload:
Some(UserAgentResponsePayload::AuthMessage(arbiter_proto::proto::auth::ServerMessage {
payload:
Some(arbiter_proto::proto::auth::server_message::Payload::AuthChallenge(challenge)),
})),
payload: Some(UserAgentResponsePayload::AuthChallenge(challenge)),
} = result
else {
panic!("Expected auth challenge response, got {result:?}");
};
let formatted_challenge = arbiter_proto::format_challenge(&challenge);
let formatted_challenge = arbiter_proto::format_challenge(challenge.nonce, &challenge.pubkey);
let signature = new_key.sign(&formatted_challenge);
let serialized_signature = signature.to_bytes().to_vec();
let result = user_agent
.process_transport_inbound(auth_request(ClientAuthPayload::AuthChallengeSolution(
auth::AuthChallengeSolution {
signature: serialized_signature,
},
)))
.process_transport_inbound(UserAgentRequest {
payload: Some(UserAgentRequestPayload::AuthChallengeSolution(
AuthChallengeSolution {
signature: serialized_signature,
},
)),
})
.await
.expect("Shouldn't fail to process message");
assert_eq!(
result,
UserAgentResponse {
payload: Some(UserAgentResponsePayload::AuthMessage(
arbiter_proto::proto::auth::ServerMessage {
payload: Some(arbiter_proto::proto::auth::server_message::Payload::AuthOk(
AuthOk {},
)),
},
)),
payload: Some(UserAgentResponsePayload::AuthOk(AuthOk {})),
}
);
}

View File

@@ -1,6 +1,6 @@
use arbiter_proto::proto::{
UnsealEncryptedKey, UnsealResult, UnsealStart, UserAgentRequest, UserAgentResponse,
auth::{AuthChallengeRequest, ClientMessage, client_message::Payload as ClientAuthPayload},
use arbiter_proto::proto::user_agent::{
AuthChallengeRequest, UnsealEncryptedKey, UnsealResult, UnsealStart,
UserAgentRequest, UserAgentResponse,
user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload,
};
@@ -21,26 +21,6 @@ use x25519_dalek::{EphemeralSecret, PublicKey};
type TestUserAgent =
UserAgentActor<DummyTransport<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>>;
fn auth_request(payload: ClientAuthPayload) -> UserAgentRequest {
UserAgentRequest {
payload: Some(UserAgentRequestPayload::AuthMessage(ClientMessage {
payload: Some(payload),
})),
}
}
fn unseal_start_request(req: UnsealStart) -> UserAgentRequest {
UserAgentRequest {
payload: Some(UserAgentRequestPayload::UnsealStart(req)),
}
}
fn unseal_key_request(req: UnsealEncryptedKey) -> UserAgentRequest {
UserAgentRequest {
payload: Some(UserAgentRequestPayload::UnsealEncryptedKey(req)),
}
}
async fn setup_authenticated_user_agent(
seal_key: &[u8],
) -> (
@@ -64,12 +44,14 @@ async fn setup_authenticated_user_agent(
let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap();
let auth_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
user_agent
.process_transport_inbound(auth_request(ClientAuthPayload::AuthChallengeRequest(
AuthChallengeRequest {
pubkey: auth_key.verifying_key().to_bytes().to_vec(),
bootstrap_token: Some(token),
},
)))
.process_transport_inbound(UserAgentRequest {
payload: Some(UserAgentRequestPayload::AuthChallengeRequest(
AuthChallengeRequest {
pubkey: auth_key.verifying_key().to_bytes().to_vec(),
bootstrap_token: Some(token),
},
)),
})
.await
.unwrap();
@@ -84,9 +66,11 @@ async fn client_dh_encrypt(
let client_public = PublicKey::from(&client_secret);
let response = user_agent
.process_transport_inbound(unseal_start_request(UnsealStart {
client_pubkey: client_public.as_bytes().to_vec(),
}))
.process_transport_inbound(UserAgentRequest {
payload: Some(UserAgentRequestPayload::UnsealStart(UnsealStart {
client_pubkey: client_public.as_bytes().to_vec(),
})),
})
.await
.unwrap();
@@ -112,6 +96,12 @@ async fn client_dh_encrypt(
}
}
fn unseal_key_request(req: UnsealEncryptedKey) -> UserAgentRequest {
UserAgentRequest {
payload: Some(UserAgentRequestPayload::UnsealEncryptedKey(req)),
}
}
#[tokio::test]
#[test_log::test]
pub async fn test_unseal_success() {
@@ -158,9 +148,11 @@ pub async fn test_unseal_corrupted_ciphertext() {
let client_public = PublicKey::from(&client_secret);
user_agent
.process_transport_inbound(unseal_start_request(UnsealStart {
client_pubkey: client_public.as_bytes().to_vec(),
}))
.process_transport_inbound(UserAgentRequest {
payload: Some(UserAgentRequestPayload::UnsealStart(UnsealStart {
client_pubkey: client_public.as_bytes().to_vec(),
})),
})
.await
.unwrap();
@@ -191,9 +183,11 @@ pub async fn test_unseal_start_without_auth_fails() {
let client_public = PublicKey::from(&client_secret);
let result = user_agent
.process_transport_inbound(unseal_start_request(UnsealStart {
client_pubkey: client_public.as_bytes().to_vec(),
}))
.process_transport_inbound(UserAgentRequest {
payload: Some(UserAgentRequestPayload::UnsealStart(UnsealStart {
client_pubkey: client_public.as_bytes().to_vec(),
})),
})
.await;
match result {

View File

@@ -1,8 +1,9 @@
use arbiter_proto::{
proto::{
UserAgentRequest, UserAgentResponse, arbiter_service_client::ArbiterServiceClient,
user_agent::{UserAgentRequest, UserAgentResponse},
arbiter_service_client::ArbiterServiceClient,
},
transport::{IdentityRecvConverter, IdentitySendConverter, RecvConverter, grpc},
transport::{IdentityRecvConverter, IdentitySendConverter, grpc},
url::ArbiterUrl,
};
use ed25519_dalek::SigningKey;

View File

@@ -1,12 +1,8 @@
use arbiter_proto::{
format_challenge,
proto::{
proto::user_agent::{
AuthChallengeRequest, AuthChallengeSolution, AuthOk,
UserAgentRequest, UserAgentResponse,
auth::{
self, AuthChallengeRequest, AuthOk, ClientMessage as AuthClientMessage,
ServerMessage as AuthServerMessage, client_message::Payload as ClientAuthPayload,
server_message::Payload as ServerAuthPayload,
},
user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload,
},
@@ -81,14 +77,6 @@ where
Ok(())
}
fn auth_request(payload: ClientAuthPayload) -> UserAgentRequest {
UserAgentRequest {
payload: Some(UserAgentRequestPayload::AuthMessage(AuthClientMessage {
payload: Some(payload),
})),
}
}
async fn send_auth_challenge_request(&mut self) -> Result<(), InboundError> {
let req = AuthChallengeRequest {
pubkey: self.key.verifying_key().to_bytes().to_vec(),
@@ -98,9 +86,9 @@ where
self.transition(UserAgentEvents::SentAuthChallengeRequest)?;
self.transport
.send(Self::auth_request(ClientAuthPayload::AuthChallengeRequest(
req,
)))
.send(UserAgentRequest {
payload: Some(UserAgentRequestPayload::AuthChallengeRequest(req)),
})
.await
.map_err(|_| InboundError::TransportSendFailed)?;
@@ -110,20 +98,20 @@ where
async fn handle_auth_challenge(
&mut self,
challenge: auth::AuthChallenge,
challenge: arbiter_proto::proto::user_agent::AuthChallenge,
) -> Result<(), InboundError> {
self.transition(UserAgentEvents::ReceivedAuthChallenge)?;
let formatted = format_challenge(&challenge);
let formatted = format_challenge(challenge.nonce, &challenge.pubkey);
let signature = self.key.sign(&formatted);
let solution = auth::AuthChallengeSolution {
let solution = AuthChallengeSolution {
signature: signature.to_bytes().to_vec(),
};
self.transport
.send(Self::auth_request(
ClientAuthPayload::AuthChallengeSolution(solution),
))
.send(UserAgentRequest {
payload: Some(UserAgentRequestPayload::AuthChallengeSolution(solution)),
})
.await
.map_err(|_| InboundError::TransportSendFailed)?;
@@ -141,17 +129,15 @@ where
&mut self,
inbound: UserAgentResponse
) -> Result<(), InboundError> {
let payload = inbound
let payload = inbound
.payload
.ok_or(InboundError::MissingResponsePayload)?;
match payload {
UserAgentResponsePayload::AuthMessage(AuthServerMessage {
payload: Some(ServerAuthPayload::AuthChallenge(challenge)),
}) => self.handle_auth_challenge(challenge).await,
UserAgentResponsePayload::AuthMessage(AuthServerMessage {
payload: Some(ServerAuthPayload::AuthOk(ok)),
}) => self.handle_auth_ok(ok),
UserAgentResponsePayload::AuthChallenge(challenge) => {
self.handle_auth_challenge(challenge).await
}
UserAgentResponsePayload::AuthOk(ok) => self.handle_auth_ok(ok),
_ => Err(InboundError::UnexpectedResponsePayload),
}
}
@@ -206,4 +192,4 @@ where
}
mod grpc;
pub use grpc::{connect_grpc, ConnectError};
pub use grpc::{connect_grpc, ConnectError};

View File

@@ -1,18 +1,14 @@
use arbiter_proto::{
format_challenge,
proto::{
proto::user_agent::{
AuthChallenge, AuthOk,
UserAgentRequest, UserAgentResponse,
auth::{
AuthChallenge, AuthOk, ClientMessage as AuthClientMessage,
ServerMessage as AuthServerMessage, client_message::Payload as ClientAuthPayload,
server_message::Payload as ServerAuthPayload,
},
user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload,
},
transport::Bi,
};
use arbiter_useragent::{InboundError, UserAgentActor};
use arbiter_useragent::UserAgentActor;
use ed25519_dalek::SigningKey;
use kameo::actor::Spawn;
use tokio::sync::mpsc;
@@ -57,14 +53,6 @@ fn test_key() -> SigningKey {
SigningKey::from_bytes(&[7u8; 32])
}
fn auth_response(payload: ServerAuthPayload) -> UserAgentResponse {
UserAgentResponse {
payload: Some(UserAgentResponsePayload::AuthMessage(AuthServerMessage {
payload: Some(payload),
})),
}
}
#[tokio::test]
async fn sends_auth_request_on_start_with_bootstrap_token() {
let key = test_key();
@@ -80,9 +68,7 @@ async fn sends_auth_request_on_start_with_bootstrap_token() {
.expect("channel closed before auth request");
let UserAgentRequest {
payload: Some(UserAgentRequestPayload::AuthMessage(AuthClientMessage {
payload: Some(ClientAuthPayload::AuthChallengeRequest(req)),
})),
payload: Some(UserAgentRequestPayload::AuthChallengeRequest(req)),
} = outbound
else {
panic!("expected auth challenge request");
@@ -113,7 +99,9 @@ async fn challenge_flow_sends_solution_from_transport_inbound() {
nonce: 42,
};
inbound_tx
.send(auth_response(ServerAuthPayload::AuthChallenge(challenge.clone())))
.send(UserAgentResponse {
payload: Some(UserAgentResponsePayload::AuthChallenge(challenge.clone())),
})
.await
.unwrap();
@@ -123,15 +111,13 @@ async fn challenge_flow_sends_solution_from_transport_inbound() {
.expect("missing challenge solution");
let UserAgentRequest {
payload: Some(UserAgentRequestPayload::AuthMessage(AuthClientMessage {
payload: Some(ClientAuthPayload::AuthChallengeSolution(solution)),
})),
payload: Some(UserAgentRequestPayload::AuthChallengeSolution(solution)),
} = outbound
else {
panic!("expected auth challenge solution");
};
let formatted = format_challenge(&challenge);
let formatted = format_challenge(challenge.nonce, &challenge.pubkey);
let sig: ed25519_dalek::Signature = solution
.signature
.as_slice()
@@ -142,7 +128,9 @@ async fn challenge_flow_sends_solution_from_transport_inbound() {
.expect("solution signature should verify");
inbound_tx
.send(auth_response(ServerAuthPayload::AuthOk(AuthOk {})))
.send(UserAgentResponse {
payload: Some(UserAgentResponsePayload::AuthOk(AuthOk {})),
})
.await
.unwrap();