refactor: consolidate auth messages into client and user_agent packages
This commit is contained in:
@@ -2,7 +2,6 @@ syntax = "proto3";
|
||||
|
||||
package arbiter;
|
||||
|
||||
import "auth.proto";
|
||||
import "client.proto";
|
||||
import "user_agent.proto";
|
||||
|
||||
@@ -12,6 +11,6 @@ message ServerInfo {
|
||||
}
|
||||
|
||||
service ArbiterService {
|
||||
rpc Client(stream ClientRequest) returns (stream ClientResponse);
|
||||
rpc UserAgent(stream UserAgentRequest) returns (stream UserAgentResponse);
|
||||
rpc Client(stream arbiter.client.ClientRequest) returns (stream arbiter.client.ClientResponse);
|
||||
rpc UserAgent(stream arbiter.user_agent.UserAgentRequest) returns (stream arbiter.user_agent.UserAgentResponse);
|
||||
}
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package arbiter.auth;
|
||||
|
||||
import "google/protobuf/timestamp.proto";
|
||||
|
||||
message AuthChallengeRequest {
|
||||
bytes pubkey = 1;
|
||||
optional string bootstrap_token = 2;
|
||||
}
|
||||
|
||||
message AuthChallenge {
|
||||
bytes pubkey = 1;
|
||||
int32 nonce = 2;
|
||||
}
|
||||
|
||||
message AuthChallengeSolution {
|
||||
bytes signature = 1;
|
||||
}
|
||||
|
||||
message AuthOk {}
|
||||
|
||||
message ClientMessage {
|
||||
oneof payload {
|
||||
AuthChallengeRequest auth_challenge_request = 1;
|
||||
AuthChallengeSolution auth_challenge_solution = 2;
|
||||
}
|
||||
}
|
||||
|
||||
message ServerMessage {
|
||||
oneof payload {
|
||||
AuthChallenge auth_challenge = 1;
|
||||
AuthOk auth_ok = 2;
|
||||
}
|
||||
}
|
||||
@@ -1,17 +1,32 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package arbiter;
|
||||
package arbiter.client;
|
||||
|
||||
import "auth.proto";
|
||||
message AuthChallengeRequest {
|
||||
bytes pubkey = 1;
|
||||
}
|
||||
|
||||
message AuthChallenge {
|
||||
bytes pubkey = 1;
|
||||
int32 nonce = 2;
|
||||
}
|
||||
|
||||
message AuthChallengeSolution {
|
||||
bytes signature = 1;
|
||||
}
|
||||
|
||||
message AuthOk {}
|
||||
|
||||
message ClientRequest {
|
||||
oneof payload {
|
||||
arbiter.auth.ClientMessage auth_message = 1;
|
||||
AuthChallengeRequest auth_challenge_request = 1;
|
||||
AuthChallengeSolution auth_challenge_solution = 2;
|
||||
}
|
||||
}
|
||||
|
||||
message ClientResponse {
|
||||
oneof payload {
|
||||
arbiter.auth.ServerMessage auth_message = 1;
|
||||
AuthChallenge auth_challenge = 1;
|
||||
AuthOk auth_ok = 2;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,25 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package arbiter;
|
||||
package arbiter.user_agent;
|
||||
|
||||
import "auth.proto";
|
||||
import "google/protobuf/empty.proto";
|
||||
|
||||
message AuthChallengeRequest {
|
||||
bytes pubkey = 1;
|
||||
optional string bootstrap_token = 2;
|
||||
}
|
||||
|
||||
message AuthChallenge {
|
||||
bytes pubkey = 1;
|
||||
int32 nonce = 2;
|
||||
}
|
||||
|
||||
message AuthChallengeSolution {
|
||||
bytes signature = 1;
|
||||
}
|
||||
|
||||
message AuthOk {}
|
||||
|
||||
message UnsealStart {
|
||||
bytes client_pubkey = 1;
|
||||
}
|
||||
@@ -35,17 +50,19 @@ enum VaultState {
|
||||
|
||||
message UserAgentRequest {
|
||||
oneof payload {
|
||||
arbiter.auth.ClientMessage auth_message = 1;
|
||||
UnsealStart unseal_start = 2;
|
||||
UnsealEncryptedKey unseal_encrypted_key = 3;
|
||||
google.protobuf.Empty query_vault_state = 4;
|
||||
AuthChallengeRequest auth_challenge_request = 1;
|
||||
AuthChallengeSolution auth_challenge_solution = 2;
|
||||
UnsealStart unseal_start = 3;
|
||||
UnsealEncryptedKey unseal_encrypted_key = 4;
|
||||
google.protobuf.Empty query_vault_state = 5;
|
||||
}
|
||||
}
|
||||
message UserAgentResponse {
|
||||
oneof payload {
|
||||
arbiter.auth.ServerMessage auth_message = 1;
|
||||
UnsealStartResponse unseal_start_response = 2;
|
||||
UnsealResult unseal_result = 3;
|
||||
VaultState vault_state = 4;
|
||||
AuthChallenge auth_challenge = 1;
|
||||
AuthOk auth_ok = 2;
|
||||
UnsealStartResponse unseal_start_response = 3;
|
||||
UnsealResult unseal_result = 4;
|
||||
VaultState vault_state = 5;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
.compile_protos(
|
||||
&[
|
||||
format!("{}/arbiter.proto", PROTOBUF_DIR),
|
||||
format!("{}/auth.proto", PROTOBUF_DIR),
|
||||
format!("{}/user_agent.proto", PROTOBUF_DIR),
|
||||
format!("{}/client.proto", PROTOBUF_DIR),
|
||||
],
|
||||
&[PROTOBUF_DIR.to_string()],
|
||||
)
|
||||
|
||||
@@ -3,13 +3,15 @@ pub mod url;
|
||||
|
||||
use base64::{Engine, prelude::BASE64_STANDARD};
|
||||
|
||||
use crate::proto::auth::AuthChallenge;
|
||||
|
||||
pub mod proto {
|
||||
tonic::include_proto!("arbiter");
|
||||
|
||||
pub mod auth {
|
||||
tonic::include_proto!("arbiter.auth");
|
||||
pub mod user_agent {
|
||||
tonic::include_proto!("arbiter.user_agent");
|
||||
}
|
||||
|
||||
pub mod client {
|
||||
tonic::include_proto!("arbiter.client");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,7 +30,7 @@ pub fn home_path() -> Result<std::path::PathBuf, std::io::Error> {
|
||||
Ok(arbiter_home)
|
||||
}
|
||||
|
||||
pub fn format_challenge(challenge: &AuthChallenge) -> Vec<u8> {
|
||||
let concat_form = format!("{}:{}", challenge.nonce, BASE64_STANDARD.encode(&challenge.pubkey));
|
||||
concat_form.into_bytes().to_vec()
|
||||
pub fn format_challenge(nonce: i32, pubkey: &[u8]) -> Vec<u8> {
|
||||
let concat_form = format!("{}:{}", nonce, BASE64_STANDARD.encode(pubkey));
|
||||
concat_form.into_bytes()
|
||||
}
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
use arbiter_proto::{
|
||||
proto::{ClientRequest, ClientResponse},
|
||||
transport::Bi,
|
||||
};
|
||||
|
||||
use crate::ServerContext;
|
||||
|
||||
pub(crate) async fn handle_client(
|
||||
_context: ServerContext,
|
||||
_bistream: impl Bi<ClientRequest, ClientResponse>,
|
||||
) {
|
||||
}
|
||||
289
server/crates/arbiter-server/src/actors/client/mod.rs
Normal file
289
server/crates/arbiter-server/src/actors/client/mod.rs
Normal file
@@ -0,0 +1,289 @@
|
||||
use arbiter_proto::{
|
||||
proto::client::{
|
||||
AuthChallenge, AuthChallengeRequest, AuthChallengeSolution, AuthOk, ClientRequest,
|
||||
ClientResponse,
|
||||
client_request::Payload as ClientRequestPayload,
|
||||
client_response::Payload as ClientResponsePayload,
|
||||
},
|
||||
transport::{Bi, DummyTransport},
|
||||
};
|
||||
use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, dsl::update};
|
||||
use diesel_async::RunQueryDsl;
|
||||
use ed25519_dalek::VerifyingKey;
|
||||
use kameo::Actor;
|
||||
use tokio::select;
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::{
|
||||
ServerContext,
|
||||
actors::client::state::{
|
||||
ChallengeContext, ClientEvents, ClientStateMachine, ClientStates, DummyContext,
|
||||
},
|
||||
db::{self, schema},
|
||||
};
|
||||
|
||||
mod state;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
|
||||
pub enum ClientError {
|
||||
#[error("Expected message with payload")]
|
||||
MissingRequestPayload,
|
||||
#[error("Unexpected request payload")]
|
||||
UnexpectedRequestPayload,
|
||||
#[error("Invalid state for challenge solution")]
|
||||
InvalidStateForChallengeSolution,
|
||||
#[error("Expected pubkey to have specific length")]
|
||||
InvalidAuthPubkeyLength,
|
||||
#[error("Failed to convert pubkey to VerifyingKey")]
|
||||
InvalidAuthPubkeyEncoding,
|
||||
#[error("Invalid signature length")]
|
||||
InvalidSignatureLength,
|
||||
#[error("Public key not registered")]
|
||||
PublicKeyNotRegistered,
|
||||
#[error("Invalid challenge solution")]
|
||||
InvalidChallengeSolution,
|
||||
#[error("State machine error")]
|
||||
StateTransitionFailed,
|
||||
#[error("Database pool error")]
|
||||
DatabasePoolUnavailable,
|
||||
#[error("Database error")]
|
||||
DatabaseOperationFailed,
|
||||
}
|
||||
|
||||
pub struct ClientActor<Transport>
|
||||
where
|
||||
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
|
||||
{
|
||||
db: db::DatabasePool,
|
||||
state: ClientStateMachine<DummyContext>,
|
||||
transport: Transport,
|
||||
}
|
||||
|
||||
impl<Transport> ClientActor<Transport>
|
||||
where
|
||||
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
|
||||
{
|
||||
pub(crate) fn new(context: ServerContext, transport: Transport) -> Self {
|
||||
Self {
|
||||
db: context.db.clone(),
|
||||
state: ClientStateMachine::new(DummyContext),
|
||||
transport,
|
||||
}
|
||||
}
|
||||
|
||||
fn transition(&mut self, event: ClientEvents) -> Result<(), ClientError> {
|
||||
self.state.process_event(event).map_err(|e| {
|
||||
error!(?e, "State transition failed");
|
||||
ClientError::StateTransitionFailed
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn process_transport_inbound(&mut self, req: ClientRequest) -> Output {
|
||||
let msg = req.payload.ok_or_else(|| {
|
||||
error!(actor = "client", "Received message with no payload");
|
||||
ClientError::MissingRequestPayload
|
||||
})?;
|
||||
|
||||
match msg {
|
||||
ClientRequestPayload::AuthChallengeRequest(req) => {
|
||||
self.handle_auth_challenge_request(req).await
|
||||
}
|
||||
ClientRequestPayload::AuthChallengeSolution(solution) => {
|
||||
self.handle_auth_challenge_solution(solution).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_auth_challenge_request(&mut self, req: AuthChallengeRequest) -> Output {
|
||||
let pubkey = req
|
||||
.pubkey
|
||||
.as_array()
|
||||
.ok_or(ClientError::InvalidAuthPubkeyLength)?;
|
||||
let pubkey = VerifyingKey::from_bytes(pubkey).map_err(|_err| {
|
||||
error!(?pubkey, "Failed to convert to VerifyingKey");
|
||||
ClientError::InvalidAuthPubkeyEncoding
|
||||
})?;
|
||||
|
||||
self.transition(ClientEvents::AuthRequest)?;
|
||||
|
||||
self.auth_with_challenge(pubkey, req.pubkey).await
|
||||
}
|
||||
|
||||
async fn auth_with_challenge(&mut self, pubkey: VerifyingKey, pubkey_bytes: Vec<u8>) -> Output {
|
||||
let nonce: Option<i32> = {
|
||||
let mut db_conn = self.db.get().await.map_err(|e| {
|
||||
error!(error = ?e, "Database pool error");
|
||||
ClientError::DatabasePoolUnavailable
|
||||
})?;
|
||||
db_conn
|
||||
.exclusive_transaction(|conn| {
|
||||
Box::pin(async move {
|
||||
let current_nonce = schema::program_client::table
|
||||
.filter(
|
||||
schema::program_client::public_key.eq(pubkey.as_bytes().to_vec()),
|
||||
)
|
||||
.select(schema::program_client::nonce)
|
||||
.first::<i32>(conn)
|
||||
.await?;
|
||||
|
||||
update(schema::program_client::table)
|
||||
.filter(
|
||||
schema::program_client::public_key.eq(pubkey.as_bytes().to_vec()),
|
||||
)
|
||||
.set(schema::program_client::nonce.eq(current_nonce + 1))
|
||||
.execute(conn)
|
||||
.await?;
|
||||
|
||||
Result::<_, diesel::result::Error>::Ok(current_nonce)
|
||||
})
|
||||
})
|
||||
.await
|
||||
.optional()
|
||||
.map_err(|e| {
|
||||
error!(error = ?e, "Database error");
|
||||
ClientError::DatabaseOperationFailed
|
||||
})?
|
||||
};
|
||||
|
||||
let Some(nonce) = nonce else {
|
||||
error!(?pubkey, "Public key not found in database");
|
||||
return Err(ClientError::PublicKeyNotRegistered);
|
||||
};
|
||||
|
||||
let challenge = AuthChallenge {
|
||||
pubkey: pubkey_bytes,
|
||||
nonce,
|
||||
};
|
||||
|
||||
self.transition(ClientEvents::SentChallenge(ChallengeContext {
|
||||
challenge: challenge.clone(),
|
||||
key: pubkey,
|
||||
}))?;
|
||||
|
||||
info!(
|
||||
?pubkey,
|
||||
?challenge,
|
||||
"Sent authentication challenge to client"
|
||||
);
|
||||
|
||||
Ok(response(ClientResponsePayload::AuthChallenge(challenge)))
|
||||
}
|
||||
|
||||
fn verify_challenge_solution(
|
||||
&self,
|
||||
solution: &AuthChallengeSolution,
|
||||
) -> Result<(bool, &ChallengeContext), ClientError> {
|
||||
let ClientStates::WaitingForChallengeSolution(challenge_context) = self.state.state()
|
||||
else {
|
||||
error!("Received challenge solution in invalid state");
|
||||
return Err(ClientError::InvalidStateForChallengeSolution);
|
||||
};
|
||||
let formatted_challenge = arbiter_proto::format_challenge(
|
||||
challenge_context.challenge.nonce,
|
||||
&challenge_context.challenge.pubkey,
|
||||
);
|
||||
|
||||
let signature = solution.signature.as_slice().try_into().map_err(|_| {
|
||||
error!(?solution, "Invalid signature length");
|
||||
ClientError::InvalidSignatureLength
|
||||
})?;
|
||||
|
||||
let valid = challenge_context
|
||||
.key
|
||||
.verify_strict(&formatted_challenge, &signature)
|
||||
.is_ok();
|
||||
|
||||
Ok((valid, challenge_context))
|
||||
}
|
||||
|
||||
async fn handle_auth_challenge_solution(
|
||||
&mut self,
|
||||
solution: AuthChallengeSolution,
|
||||
) -> Output {
|
||||
let (valid, challenge_context) = self.verify_challenge_solution(&solution)?;
|
||||
|
||||
if valid {
|
||||
info!(
|
||||
?challenge_context,
|
||||
"Client provided valid solution to authentication challenge"
|
||||
);
|
||||
self.transition(ClientEvents::ReceivedGoodSolution)?;
|
||||
Ok(response(ClientResponsePayload::AuthOk(AuthOk {})))
|
||||
} else {
|
||||
error!("Client provided invalid solution to authentication challenge");
|
||||
self.transition(ClientEvents::ReceivedBadSolution)?;
|
||||
Err(ClientError::InvalidChallengeSolution)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type Output = Result<ClientResponse, ClientError>;
|
||||
|
||||
fn response(payload: ClientResponsePayload) -> ClientResponse {
|
||||
ClientResponse {
|
||||
payload: Some(payload),
|
||||
}
|
||||
}
|
||||
|
||||
impl<Transport> Actor for ClientActor<Transport>
|
||||
where
|
||||
Transport: Bi<ClientRequest, Result<ClientResponse, ClientError>>,
|
||||
{
|
||||
type Args = Self;
|
||||
|
||||
type Error = ();
|
||||
|
||||
async fn on_start(
|
||||
args: Self::Args,
|
||||
_: kameo::prelude::ActorRef<Self>,
|
||||
) -> Result<Self, Self::Error> {
|
||||
Ok(args)
|
||||
}
|
||||
|
||||
async fn next(
|
||||
&mut self,
|
||||
_actor_ref: kameo::prelude::WeakActorRef<Self>,
|
||||
mailbox_rx: &mut kameo::prelude::MailboxReceiver<Self>,
|
||||
) -> Option<kameo::mailbox::Signal<Self>> {
|
||||
loop {
|
||||
select! {
|
||||
signal = mailbox_rx.recv() => {
|
||||
return signal;
|
||||
}
|
||||
msg = self.transport.recv() => {
|
||||
match msg {
|
||||
Some(request) => {
|
||||
match self.process_transport_inbound(request).await {
|
||||
Ok(resp) => {
|
||||
if self.transport.send(Ok(resp)).await.is_err() {
|
||||
error!(actor = "client", reason = "channel closed", "send.failed");
|
||||
return Some(kameo::mailbox::Signal::Stop);
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = self.transport.send(Err(err)).await;
|
||||
return Some(kameo::mailbox::Signal::Stop);
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
info!(actor = "client", "transport.closed");
|
||||
return Some(kameo::mailbox::Signal::Stop);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientActor<DummyTransport<ClientRequest, Result<ClientResponse, ClientError>>> {
|
||||
pub fn new_manual(db: db::DatabasePool) -> Self {
|
||||
Self {
|
||||
db,
|
||||
state: ClientStateMachine::new(DummyContext),
|
||||
transport: DummyTransport::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
31
server/crates/arbiter-server/src/actors/client/state.rs
Normal file
31
server/crates/arbiter-server/src/actors/client/state.rs
Normal file
@@ -0,0 +1,31 @@
|
||||
use arbiter_proto::proto::client::AuthChallenge;
|
||||
use ed25519_dalek::VerifyingKey;
|
||||
|
||||
/// Context for state machine with validated key and sent challenge
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ChallengeContext {
|
||||
pub challenge: AuthChallenge,
|
||||
pub key: VerifyingKey,
|
||||
}
|
||||
|
||||
smlang::statemachine!(
|
||||
name: Client,
|
||||
custom_error: false,
|
||||
transitions: {
|
||||
*Init + AuthRequest = ReceivedAuthRequest,
|
||||
|
||||
ReceivedAuthRequest + SentChallenge(ChallengeContext) / move_challenge = WaitingForChallengeSolution(ChallengeContext),
|
||||
|
||||
WaitingForChallengeSolution(ChallengeContext) + ReceivedGoodSolution = Idle,
|
||||
WaitingForChallengeSolution(ChallengeContext) + ReceivedBadSolution = AuthError,
|
||||
}
|
||||
);
|
||||
|
||||
pub struct DummyContext;
|
||||
impl ClientStateMachineContext for DummyContext {
|
||||
#[allow(missing_docs)]
|
||||
#[allow(clippy::unused_unit)]
|
||||
fn move_challenge(&mut self, event_data: ChallengeContext) -> Result<ChallengeContext, ()> {
|
||||
Ok(event_data)
|
||||
}
|
||||
}
|
||||
@@ -1,14 +1,9 @@
|
||||
use std::{ops::DerefMut, sync::Mutex};
|
||||
|
||||
use arbiter_proto::{
|
||||
proto::{
|
||||
UnsealEncryptedKey, UnsealResult, UnsealStart, UnsealStartResponse, UserAgentRequest,
|
||||
UserAgentResponse,
|
||||
auth::{
|
||||
self, AuthChallengeRequest, AuthOk, ClientMessage as ClientAuthMessage,
|
||||
ServerMessage as AuthServerMessage, client_message::Payload as ClientAuthPayload,
|
||||
server_message::Payload as ServerAuthPayload,
|
||||
},
|
||||
proto::user_agent::{
|
||||
AuthChallenge, AuthChallengeRequest, AuthChallengeSolution, AuthOk, UnsealEncryptedKey,
|
||||
UnsealResult, UnsealStart, UnsealStartResponse, UserAgentRequest, UserAgentResponse,
|
||||
user_agent_request::Payload as UserAgentRequestPayload,
|
||||
user_agent_response::Payload as UserAgentResponsePayload,
|
||||
},
|
||||
@@ -114,12 +109,12 @@ where
|
||||
})?;
|
||||
|
||||
match msg {
|
||||
UserAgentRequestPayload::AuthMessage(ClientAuthMessage {
|
||||
payload: Some(ClientAuthPayload::AuthChallengeRequest(req)),
|
||||
}) => self.handle_auth_challenge_request(req).await,
|
||||
UserAgentRequestPayload::AuthMessage(ClientAuthMessage {
|
||||
payload: Some(ClientAuthPayload::AuthChallengeSolution(solution)),
|
||||
}) => self.handle_auth_challenge_solution(solution).await,
|
||||
UserAgentRequestPayload::AuthChallengeRequest(req) => {
|
||||
self.handle_auth_challenge_request(req).await
|
||||
}
|
||||
UserAgentRequestPayload::AuthChallengeSolution(solution) => {
|
||||
self.handle_auth_challenge_solution(solution).await
|
||||
}
|
||||
UserAgentRequestPayload::UnsealStart(unseal_start) => {
|
||||
self.handle_unseal_request(unseal_start).await
|
||||
}
|
||||
@@ -171,7 +166,7 @@ where
|
||||
|
||||
self.transition(UserAgentEvents::ReceivedBootstrapToken)?;
|
||||
|
||||
Ok(auth_response(ServerAuthPayload::AuthOk(AuthOk {})))
|
||||
Ok(response(UserAgentResponsePayload::AuthOk(AuthOk {})))
|
||||
}
|
||||
|
||||
async fn auth_with_challenge(&mut self, pubkey: VerifyingKey, pubkey_bytes: Vec<u8>) -> Output {
|
||||
@@ -215,7 +210,7 @@ where
|
||||
return Err(UserAgentError::PublicKeyNotRegistered);
|
||||
};
|
||||
|
||||
let challenge = auth::AuthChallenge {
|
||||
let challenge = AuthChallenge {
|
||||
pubkey: pubkey_bytes,
|
||||
nonce,
|
||||
};
|
||||
@@ -231,19 +226,22 @@ where
|
||||
"Sent authentication challenge to client"
|
||||
);
|
||||
|
||||
Ok(auth_response(ServerAuthPayload::AuthChallenge(challenge)))
|
||||
Ok(response(UserAgentResponsePayload::AuthChallenge(challenge)))
|
||||
}
|
||||
|
||||
fn verify_challenge_solution(
|
||||
&self,
|
||||
solution: &auth::AuthChallengeSolution,
|
||||
solution: &AuthChallengeSolution,
|
||||
) -> Result<(bool, &ChallengeContext), UserAgentError> {
|
||||
let UserAgentStates::WaitingForChallengeSolution(challenge_context) = self.state.state()
|
||||
else {
|
||||
error!("Received challenge solution in invalid state");
|
||||
return Err(UserAgentError::InvalidStateForChallengeSolution);
|
||||
};
|
||||
let formatted_challenge = arbiter_proto::format_challenge(&challenge_context.challenge);
|
||||
let formatted_challenge = arbiter_proto::format_challenge(
|
||||
challenge_context.challenge.nonce,
|
||||
&challenge_context.challenge.pubkey,
|
||||
);
|
||||
|
||||
let signature = solution.signature.as_slice().try_into().map_err(|_| {
|
||||
error!(?solution, "Invalid signature length");
|
||||
@@ -261,15 +259,7 @@ where
|
||||
|
||||
type Output = Result<UserAgentResponse, UserAgentError>;
|
||||
|
||||
fn auth_response(payload: ServerAuthPayload) -> UserAgentResponse {
|
||||
UserAgentResponse {
|
||||
payload: Some(UserAgentResponsePayload::AuthMessage(AuthServerMessage {
|
||||
payload: Some(payload),
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
fn unseal_response(payload: UserAgentResponsePayload) -> UserAgentResponse {
|
||||
fn response(payload: UserAgentResponsePayload) -> UserAgentResponse {
|
||||
UserAgentResponse {
|
||||
payload: Some(payload),
|
||||
}
|
||||
@@ -295,7 +285,7 @@ where
|
||||
client_public_key,
|
||||
}))?;
|
||||
|
||||
Ok(unseal_response(
|
||||
Ok(response(
|
||||
UserAgentResponsePayload::UnsealStartResponse(UnsealStartResponse {
|
||||
server_pubkey: public_key.as_bytes().to_vec(),
|
||||
}),
|
||||
@@ -316,7 +306,7 @@ where
|
||||
drop(secret_lock);
|
||||
error!("Ephemeral secret already taken");
|
||||
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
|
||||
return Ok(unseal_response(UserAgentResponsePayload::UnsealResult(
|
||||
return Ok(response(UserAgentResponsePayload::UnsealResult(
|
||||
UnsealResult::InvalidKey.into(),
|
||||
)));
|
||||
}
|
||||
@@ -349,20 +339,20 @@ where
|
||||
Ok(_) => {
|
||||
info!("Successfully unsealed key with client-provided key");
|
||||
self.transition(UserAgentEvents::ReceivedValidKey)?;
|
||||
Ok(unseal_response(UserAgentResponsePayload::UnsealResult(
|
||||
Ok(response(UserAgentResponsePayload::UnsealResult(
|
||||
UnsealResult::Success.into(),
|
||||
)))
|
||||
}
|
||||
Err(SendError::HandlerError(keyholder::Error::InvalidKey)) => {
|
||||
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
|
||||
Ok(unseal_response(UserAgentResponsePayload::UnsealResult(
|
||||
Ok(response(UserAgentResponsePayload::UnsealResult(
|
||||
UnsealResult::InvalidKey.into(),
|
||||
)))
|
||||
}
|
||||
Err(SendError::HandlerError(err)) => {
|
||||
error!(?err, "Keyholder failed to unseal key");
|
||||
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
|
||||
Ok(unseal_response(UserAgentResponsePayload::UnsealResult(
|
||||
Ok(response(UserAgentResponsePayload::UnsealResult(
|
||||
UnsealResult::InvalidKey.into(),
|
||||
)))
|
||||
}
|
||||
@@ -376,7 +366,7 @@ where
|
||||
Err(err) => {
|
||||
error!(?err, "Failed to decrypt unseal key");
|
||||
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
|
||||
Ok(unseal_response(UserAgentResponsePayload::UnsealResult(
|
||||
Ok(response(UserAgentResponsePayload::UnsealResult(
|
||||
UnsealResult::InvalidKey.into(),
|
||||
)))
|
||||
}
|
||||
@@ -403,7 +393,7 @@ where
|
||||
|
||||
async fn handle_auth_challenge_solution(
|
||||
&mut self,
|
||||
solution: auth::AuthChallengeSolution,
|
||||
solution: AuthChallengeSolution,
|
||||
) -> Output {
|
||||
let (valid, challenge_context) = self.verify_challenge_solution(&solution)?;
|
||||
|
||||
@@ -413,7 +403,7 @@ where
|
||||
"Client provided valid solution to authentication challenge"
|
||||
);
|
||||
self.transition(UserAgentEvents::ReceivedGoodSolution)?;
|
||||
Ok(auth_response(ServerAuthPayload::AuthOk(AuthOk {})))
|
||||
Ok(response(UserAgentResponsePayload::AuthOk(AuthOk {})))
|
||||
} else {
|
||||
error!("Client provided invalid solution to authentication challenge");
|
||||
self.transition(UserAgentEvents::ReceivedBadSolution)?;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::sync::Mutex;
|
||||
|
||||
use arbiter_proto::proto::auth::AuthChallenge;
|
||||
use arbiter_proto::proto::user_agent::AuthChallenge;
|
||||
use ed25519_dalek::VerifyingKey;
|
||||
use x25519_dalek::{EphemeralSecret, PublicKey};
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
#![forbid(unsafe_code)]
|
||||
use arbiter_proto::{
|
||||
proto::{ClientRequest, ClientResponse, UserAgentRequest, UserAgentResponse},
|
||||
proto::{
|
||||
client::{ClientRequest, ClientResponse},
|
||||
user_agent::{UserAgentRequest, UserAgentResponse},
|
||||
},
|
||||
transport::{IdentityRecvConverter, SendConverter, grpc},
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
@@ -12,7 +15,10 @@ use tonic::{Request, Response, Status};
|
||||
use tracing::info;
|
||||
|
||||
use crate::{
|
||||
actors::user_agent::{UserAgentActor, UserAgentError},
|
||||
actors::{
|
||||
client::{ClientActor, ClientError},
|
||||
user_agent::{UserAgentActor, UserAgentError},
|
||||
},
|
||||
context::ServerContext,
|
||||
};
|
||||
|
||||
@@ -41,6 +47,56 @@ impl SendConverter for UserAgentGrpcSender {
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts Client domain outbounds into the tonic stream item emitted by the
|
||||
/// server.
|
||||
///
|
||||
/// The conversion is defined at the server boundary so the actor module remains
|
||||
/// focused on domain semantics and does not depend on tonic status encoding.
|
||||
struct ClientGrpcSender;
|
||||
|
||||
impl SendConverter for ClientGrpcSender {
|
||||
type Input = Result<ClientResponse, ClientError>;
|
||||
type Output = Result<ClientResponse, Status>;
|
||||
|
||||
fn convert(&self, item: Self::Input) -> Self::Output {
|
||||
match item {
|
||||
Ok(message) => Ok(message),
|
||||
Err(err) => Err(client_error_status(err)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Maps Client domain errors to public gRPC transport errors for the `client`
|
||||
/// streaming endpoint.
|
||||
fn client_error_status(value: ClientError) -> Status {
|
||||
match value {
|
||||
ClientError::MissingRequestPayload | ClientError::UnexpectedRequestPayload => {
|
||||
Status::invalid_argument("Expected message with payload")
|
||||
}
|
||||
ClientError::InvalidStateForChallengeSolution => {
|
||||
Status::invalid_argument("Invalid state for challenge solution")
|
||||
}
|
||||
ClientError::InvalidAuthPubkeyLength => {
|
||||
Status::invalid_argument("Expected pubkey to have specific length")
|
||||
}
|
||||
ClientError::InvalidAuthPubkeyEncoding => {
|
||||
Status::invalid_argument("Failed to convert pubkey to VerifyingKey")
|
||||
}
|
||||
ClientError::InvalidSignatureLength => {
|
||||
Status::invalid_argument("Invalid signature length")
|
||||
}
|
||||
ClientError::PublicKeyNotRegistered => {
|
||||
Status::unauthenticated("Public key not registered")
|
||||
}
|
||||
ClientError::InvalidChallengeSolution => {
|
||||
Status::unauthenticated("Invalid challenge solution")
|
||||
}
|
||||
ClientError::StateTransitionFailed => Status::internal("State machine error"),
|
||||
ClientError::DatabasePoolUnavailable => Status::internal("Database pool error"),
|
||||
ClientError::DatabaseOperationFailed => Status::internal("Database error"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Maps User Agent domain errors to public gRPC transport errors for the
|
||||
/// `user_agent` streaming endpoint.
|
||||
fn user_agent_error_status(value: UserAgentError) -> Status {
|
||||
@@ -100,11 +156,25 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server {
|
||||
type UserAgentStream = ReceiverStream<Result<UserAgentResponse, Status>>;
|
||||
type ClientStream = ReceiverStream<Result<ClientResponse, Status>>;
|
||||
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
async fn client(
|
||||
&self,
|
||||
_request: Request<tonic::Streaming<ClientRequest>>,
|
||||
request: Request<tonic::Streaming<ClientRequest>>,
|
||||
) -> Result<Response<Self::ClientStream>, Status> {
|
||||
todo!()
|
||||
let req_stream = request.into_inner();
|
||||
let (tx, rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);
|
||||
|
||||
let transport = grpc::GrpcAdapter::new(
|
||||
tx,
|
||||
req_stream,
|
||||
IdentityRecvConverter::<ClientRequest>::new(),
|
||||
ClientGrpcSender,
|
||||
);
|
||||
ClientActor::spawn(ClientActor::new(self.context.clone(), transport));
|
||||
|
||||
info!(event = "connection established", "grpc.client");
|
||||
|
||||
Ok(Response::new(ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
|
||||
2
server/crates/arbiter-server/tests/client.rs
Normal file
2
server/crates/arbiter-server/tests/client.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
#[path = "client/auth.rs"]
|
||||
mod auth;
|
||||
102
server/crates/arbiter-server/tests/client/auth.rs
Normal file
102
server/crates/arbiter-server/tests/client/auth.rs
Normal file
@@ -0,0 +1,102 @@
|
||||
use arbiter_proto::proto::client::{
|
||||
AuthChallengeRequest, AuthChallengeSolution, AuthOk, ClientRequest, ClientResponse,
|
||||
client_request::Payload as ClientRequestPayload,
|
||||
client_response::Payload as ClientResponsePayload,
|
||||
};
|
||||
use arbiter_server::{
|
||||
actors::client::{ClientActor, ClientError},
|
||||
db::{self, schema},
|
||||
};
|
||||
use diesel::{ExpressionMethods as _, insert_into};
|
||||
use diesel_async::RunQueryDsl;
|
||||
use ed25519_dalek::Signer as _;
|
||||
|
||||
#[tokio::test]
|
||||
#[test_log::test]
|
||||
pub async fn test_unregistered_pubkey_rejected() {
|
||||
let db = db::create_test_pool().await;
|
||||
|
||||
let mut client = ClientActor::new_manual(db.clone());
|
||||
|
||||
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
|
||||
let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec();
|
||||
|
||||
let result = client
|
||||
.process_transport_inbound(ClientRequest {
|
||||
payload: Some(ClientRequestPayload::AuthChallengeRequest(
|
||||
AuthChallengeRequest {
|
||||
pubkey: pubkey_bytes,
|
||||
},
|
||||
)),
|
||||
})
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Err(err) => {
|
||||
assert_eq!(err, ClientError::PublicKeyNotRegistered);
|
||||
}
|
||||
Ok(_) => {
|
||||
panic!("Expected error due to unregistered pubkey, but got success");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test_log::test]
|
||||
pub async fn test_challenge_auth() {
|
||||
let db = db::create_test_pool().await;
|
||||
|
||||
let mut client = ClientActor::new_manual(db.clone());
|
||||
|
||||
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
|
||||
let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec();
|
||||
|
||||
{
|
||||
let mut conn = db.get().await.unwrap();
|
||||
insert_into(schema::program_client::table)
|
||||
.values(schema::program_client::public_key.eq(pubkey_bytes.clone()))
|
||||
.execute(&mut conn)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let result = client
|
||||
.process_transport_inbound(ClientRequest {
|
||||
payload: Some(ClientRequestPayload::AuthChallengeRequest(
|
||||
AuthChallengeRequest {
|
||||
pubkey: pubkey_bytes,
|
||||
},
|
||||
)),
|
||||
})
|
||||
.await
|
||||
.expect("Shouldn't fail to process message");
|
||||
|
||||
let ClientResponse {
|
||||
payload: Some(ClientResponsePayload::AuthChallenge(challenge)),
|
||||
} = result
|
||||
else {
|
||||
panic!("Expected auth challenge response, got {result:?}");
|
||||
};
|
||||
|
||||
let formatted_challenge = arbiter_proto::format_challenge(challenge.nonce, &challenge.pubkey);
|
||||
let signature = new_key.sign(&formatted_challenge);
|
||||
let serialized_signature = signature.to_bytes().to_vec();
|
||||
|
||||
let result = client
|
||||
.process_transport_inbound(ClientRequest {
|
||||
payload: Some(ClientRequestPayload::AuthChallengeSolution(
|
||||
AuthChallengeSolution {
|
||||
signature: serialized_signature,
|
||||
},
|
||||
)),
|
||||
})
|
||||
.await
|
||||
.expect("Shouldn't fail to process message");
|
||||
|
||||
assert_eq!(
|
||||
result,
|
||||
ClientResponse {
|
||||
payload: Some(ClientResponsePayload::AuthOk(AuthOk {})),
|
||||
}
|
||||
);
|
||||
}
|
||||
@@ -1,7 +1,5 @@
|
||||
use arbiter_proto::proto::{
|
||||
UserAgentResponse,
|
||||
UserAgentRequest,
|
||||
auth::{self, AuthChallengeRequest, AuthOk, ClientMessage, client_message::Payload as ClientAuthPayload},
|
||||
use arbiter_proto::proto::user_agent::{
|
||||
AuthChallengeRequest, AuthChallengeSolution, AuthOk, UserAgentRequest, UserAgentResponse,
|
||||
user_agent_request::Payload as UserAgentRequestPayload,
|
||||
user_agent_response::Payload as UserAgentResponsePayload,
|
||||
};
|
||||
@@ -17,18 +15,10 @@ use diesel::{ExpressionMethods as _, QueryDsl, insert_into};
|
||||
use diesel_async::RunQueryDsl;
|
||||
use ed25519_dalek::Signer as _;
|
||||
|
||||
fn auth_request(payload: ClientAuthPayload) -> UserAgentRequest {
|
||||
UserAgentRequest {
|
||||
payload: Some(UserAgentRequestPayload::AuthMessage(ClientMessage {
|
||||
payload: Some(payload),
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test_log::test]
|
||||
pub async fn test_bootstrap_token_auth() {
|
||||
let db =db::create_test_pool().await;
|
||||
let db = db::create_test_pool().await;
|
||||
|
||||
let actors = GlobalActors::spawn(db.clone()).await.unwrap();
|
||||
let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap();
|
||||
@@ -38,25 +28,21 @@ pub async fn test_bootstrap_token_auth() {
|
||||
let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec();
|
||||
|
||||
let result = user_agent
|
||||
.process_transport_inbound(auth_request(ClientAuthPayload::AuthChallengeRequest(
|
||||
AuthChallengeRequest {
|
||||
pubkey: pubkey_bytes,
|
||||
bootstrap_token: Some(token),
|
||||
},
|
||||
)))
|
||||
.process_transport_inbound(UserAgentRequest {
|
||||
payload: Some(UserAgentRequestPayload::AuthChallengeRequest(
|
||||
AuthChallengeRequest {
|
||||
pubkey: pubkey_bytes,
|
||||
bootstrap_token: Some(token),
|
||||
},
|
||||
)),
|
||||
})
|
||||
.await
|
||||
.expect("Shouldn't fail to process message");
|
||||
|
||||
assert_eq!(
|
||||
result,
|
||||
UserAgentResponse {
|
||||
payload: Some(UserAgentResponsePayload::AuthMessage(
|
||||
arbiter_proto::proto::auth::ServerMessage {
|
||||
payload: Some(arbiter_proto::proto::auth::server_message::Payload::AuthOk(
|
||||
AuthOk {},
|
||||
)),
|
||||
},
|
||||
)),
|
||||
payload: Some(UserAgentResponsePayload::AuthOk(AuthOk {})),
|
||||
}
|
||||
);
|
||||
|
||||
@@ -81,12 +67,14 @@ pub async fn test_bootstrap_invalid_token_auth() {
|
||||
let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec();
|
||||
|
||||
let result = user_agent
|
||||
.process_transport_inbound(auth_request(ClientAuthPayload::AuthChallengeRequest(
|
||||
AuthChallengeRequest {
|
||||
pubkey: pubkey_bytes,
|
||||
bootstrap_token: Some("invalid_token".to_string()),
|
||||
},
|
||||
)))
|
||||
.process_transport_inbound(UserAgentRequest {
|
||||
payload: Some(UserAgentRequestPayload::AuthChallengeRequest(
|
||||
AuthChallengeRequest {
|
||||
pubkey: pubkey_bytes,
|
||||
bootstrap_token: Some("invalid_token".to_string()),
|
||||
},
|
||||
)),
|
||||
})
|
||||
.await;
|
||||
|
||||
match result {
|
||||
@@ -120,49 +108,43 @@ pub async fn test_challenge_auth() {
|
||||
}
|
||||
|
||||
let result = user_agent
|
||||
.process_transport_inbound(auth_request(ClientAuthPayload::AuthChallengeRequest(
|
||||
AuthChallengeRequest {
|
||||
pubkey: pubkey_bytes,
|
||||
bootstrap_token: None,
|
||||
},
|
||||
)))
|
||||
.process_transport_inbound(UserAgentRequest {
|
||||
payload: Some(UserAgentRequestPayload::AuthChallengeRequest(
|
||||
AuthChallengeRequest {
|
||||
pubkey: pubkey_bytes,
|
||||
bootstrap_token: None,
|
||||
},
|
||||
)),
|
||||
})
|
||||
.await
|
||||
.expect("Shouldn't fail to process message");
|
||||
|
||||
let UserAgentResponse {
|
||||
payload:
|
||||
Some(UserAgentResponsePayload::AuthMessage(arbiter_proto::proto::auth::ServerMessage {
|
||||
payload:
|
||||
Some(arbiter_proto::proto::auth::server_message::Payload::AuthChallenge(challenge)),
|
||||
})),
|
||||
payload: Some(UserAgentResponsePayload::AuthChallenge(challenge)),
|
||||
} = result
|
||||
else {
|
||||
panic!("Expected auth challenge response, got {result:?}");
|
||||
};
|
||||
|
||||
let formatted_challenge = arbiter_proto::format_challenge(&challenge);
|
||||
let formatted_challenge = arbiter_proto::format_challenge(challenge.nonce, &challenge.pubkey);
|
||||
let signature = new_key.sign(&formatted_challenge);
|
||||
let serialized_signature = signature.to_bytes().to_vec();
|
||||
|
||||
let result = user_agent
|
||||
.process_transport_inbound(auth_request(ClientAuthPayload::AuthChallengeSolution(
|
||||
auth::AuthChallengeSolution {
|
||||
signature: serialized_signature,
|
||||
},
|
||||
)))
|
||||
.process_transport_inbound(UserAgentRequest {
|
||||
payload: Some(UserAgentRequestPayload::AuthChallengeSolution(
|
||||
AuthChallengeSolution {
|
||||
signature: serialized_signature,
|
||||
},
|
||||
)),
|
||||
})
|
||||
.await
|
||||
.expect("Shouldn't fail to process message");
|
||||
|
||||
assert_eq!(
|
||||
result,
|
||||
UserAgentResponse {
|
||||
payload: Some(UserAgentResponsePayload::AuthMessage(
|
||||
arbiter_proto::proto::auth::ServerMessage {
|
||||
payload: Some(arbiter_proto::proto::auth::server_message::Payload::AuthOk(
|
||||
AuthOk {},
|
||||
)),
|
||||
},
|
||||
)),
|
||||
payload: Some(UserAgentResponsePayload::AuthOk(AuthOk {})),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use arbiter_proto::proto::{
|
||||
UnsealEncryptedKey, UnsealResult, UnsealStart, UserAgentRequest, UserAgentResponse,
|
||||
auth::{AuthChallengeRequest, ClientMessage, client_message::Payload as ClientAuthPayload},
|
||||
use arbiter_proto::proto::user_agent::{
|
||||
AuthChallengeRequest, UnsealEncryptedKey, UnsealResult, UnsealStart,
|
||||
UserAgentRequest, UserAgentResponse,
|
||||
user_agent_request::Payload as UserAgentRequestPayload,
|
||||
user_agent_response::Payload as UserAgentResponsePayload,
|
||||
};
|
||||
@@ -21,26 +21,6 @@ use x25519_dalek::{EphemeralSecret, PublicKey};
|
||||
type TestUserAgent =
|
||||
UserAgentActor<DummyTransport<UserAgentRequest, Result<UserAgentResponse, UserAgentError>>>;
|
||||
|
||||
fn auth_request(payload: ClientAuthPayload) -> UserAgentRequest {
|
||||
UserAgentRequest {
|
||||
payload: Some(UserAgentRequestPayload::AuthMessage(ClientMessage {
|
||||
payload: Some(payload),
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
fn unseal_start_request(req: UnsealStart) -> UserAgentRequest {
|
||||
UserAgentRequest {
|
||||
payload: Some(UserAgentRequestPayload::UnsealStart(req)),
|
||||
}
|
||||
}
|
||||
|
||||
fn unseal_key_request(req: UnsealEncryptedKey) -> UserAgentRequest {
|
||||
UserAgentRequest {
|
||||
payload: Some(UserAgentRequestPayload::UnsealEncryptedKey(req)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn setup_authenticated_user_agent(
|
||||
seal_key: &[u8],
|
||||
) -> (
|
||||
@@ -64,12 +44,14 @@ async fn setup_authenticated_user_agent(
|
||||
let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap();
|
||||
let auth_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
|
||||
user_agent
|
||||
.process_transport_inbound(auth_request(ClientAuthPayload::AuthChallengeRequest(
|
||||
AuthChallengeRequest {
|
||||
pubkey: auth_key.verifying_key().to_bytes().to_vec(),
|
||||
bootstrap_token: Some(token),
|
||||
},
|
||||
)))
|
||||
.process_transport_inbound(UserAgentRequest {
|
||||
payload: Some(UserAgentRequestPayload::AuthChallengeRequest(
|
||||
AuthChallengeRequest {
|
||||
pubkey: auth_key.verifying_key().to_bytes().to_vec(),
|
||||
bootstrap_token: Some(token),
|
||||
},
|
||||
)),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -84,9 +66,11 @@ async fn client_dh_encrypt(
|
||||
let client_public = PublicKey::from(&client_secret);
|
||||
|
||||
let response = user_agent
|
||||
.process_transport_inbound(unseal_start_request(UnsealStart {
|
||||
client_pubkey: client_public.as_bytes().to_vec(),
|
||||
}))
|
||||
.process_transport_inbound(UserAgentRequest {
|
||||
payload: Some(UserAgentRequestPayload::UnsealStart(UnsealStart {
|
||||
client_pubkey: client_public.as_bytes().to_vec(),
|
||||
})),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -112,6 +96,12 @@ async fn client_dh_encrypt(
|
||||
}
|
||||
}
|
||||
|
||||
fn unseal_key_request(req: UnsealEncryptedKey) -> UserAgentRequest {
|
||||
UserAgentRequest {
|
||||
payload: Some(UserAgentRequestPayload::UnsealEncryptedKey(req)),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[test_log::test]
|
||||
pub async fn test_unseal_success() {
|
||||
@@ -158,9 +148,11 @@ pub async fn test_unseal_corrupted_ciphertext() {
|
||||
let client_public = PublicKey::from(&client_secret);
|
||||
|
||||
user_agent
|
||||
.process_transport_inbound(unseal_start_request(UnsealStart {
|
||||
client_pubkey: client_public.as_bytes().to_vec(),
|
||||
}))
|
||||
.process_transport_inbound(UserAgentRequest {
|
||||
payload: Some(UserAgentRequestPayload::UnsealStart(UnsealStart {
|
||||
client_pubkey: client_public.as_bytes().to_vec(),
|
||||
})),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -191,9 +183,11 @@ pub async fn test_unseal_start_without_auth_fails() {
|
||||
let client_public = PublicKey::from(&client_secret);
|
||||
|
||||
let result = user_agent
|
||||
.process_transport_inbound(unseal_start_request(UnsealStart {
|
||||
client_pubkey: client_public.as_bytes().to_vec(),
|
||||
}))
|
||||
.process_transport_inbound(UserAgentRequest {
|
||||
payload: Some(UserAgentRequestPayload::UnsealStart(UnsealStart {
|
||||
client_pubkey: client_public.as_bytes().to_vec(),
|
||||
})),
|
||||
})
|
||||
.await;
|
||||
|
||||
match result {
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
use arbiter_proto::{
|
||||
proto::{
|
||||
UserAgentRequest, UserAgentResponse, arbiter_service_client::ArbiterServiceClient,
|
||||
user_agent::{UserAgentRequest, UserAgentResponse},
|
||||
arbiter_service_client::ArbiterServiceClient,
|
||||
},
|
||||
transport::{IdentityRecvConverter, IdentitySendConverter, RecvConverter, grpc},
|
||||
transport::{IdentityRecvConverter, IdentitySendConverter, grpc},
|
||||
url::ArbiterUrl,
|
||||
};
|
||||
use ed25519_dalek::SigningKey;
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
use arbiter_proto::{
|
||||
format_challenge,
|
||||
proto::{
|
||||
proto::user_agent::{
|
||||
AuthChallengeRequest, AuthChallengeSolution, AuthOk,
|
||||
UserAgentRequest, UserAgentResponse,
|
||||
auth::{
|
||||
self, AuthChallengeRequest, AuthOk, ClientMessage as AuthClientMessage,
|
||||
ServerMessage as AuthServerMessage, client_message::Payload as ClientAuthPayload,
|
||||
server_message::Payload as ServerAuthPayload,
|
||||
},
|
||||
user_agent_request::Payload as UserAgentRequestPayload,
|
||||
user_agent_response::Payload as UserAgentResponsePayload,
|
||||
},
|
||||
@@ -81,14 +77,6 @@ where
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn auth_request(payload: ClientAuthPayload) -> UserAgentRequest {
|
||||
UserAgentRequest {
|
||||
payload: Some(UserAgentRequestPayload::AuthMessage(AuthClientMessage {
|
||||
payload: Some(payload),
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_auth_challenge_request(&mut self) -> Result<(), InboundError> {
|
||||
let req = AuthChallengeRequest {
|
||||
pubkey: self.key.verifying_key().to_bytes().to_vec(),
|
||||
@@ -98,9 +86,9 @@ where
|
||||
self.transition(UserAgentEvents::SentAuthChallengeRequest)?;
|
||||
|
||||
self.transport
|
||||
.send(Self::auth_request(ClientAuthPayload::AuthChallengeRequest(
|
||||
req,
|
||||
)))
|
||||
.send(UserAgentRequest {
|
||||
payload: Some(UserAgentRequestPayload::AuthChallengeRequest(req)),
|
||||
})
|
||||
.await
|
||||
.map_err(|_| InboundError::TransportSendFailed)?;
|
||||
|
||||
@@ -110,20 +98,20 @@ where
|
||||
|
||||
async fn handle_auth_challenge(
|
||||
&mut self,
|
||||
challenge: auth::AuthChallenge,
|
||||
challenge: arbiter_proto::proto::user_agent::AuthChallenge,
|
||||
) -> Result<(), InboundError> {
|
||||
self.transition(UserAgentEvents::ReceivedAuthChallenge)?;
|
||||
|
||||
let formatted = format_challenge(&challenge);
|
||||
let formatted = format_challenge(challenge.nonce, &challenge.pubkey);
|
||||
let signature = self.key.sign(&formatted);
|
||||
let solution = auth::AuthChallengeSolution {
|
||||
let solution = AuthChallengeSolution {
|
||||
signature: signature.to_bytes().to_vec(),
|
||||
};
|
||||
|
||||
self.transport
|
||||
.send(Self::auth_request(
|
||||
ClientAuthPayload::AuthChallengeSolution(solution),
|
||||
))
|
||||
.send(UserAgentRequest {
|
||||
payload: Some(UserAgentRequestPayload::AuthChallengeSolution(solution)),
|
||||
})
|
||||
.await
|
||||
.map_err(|_| InboundError::TransportSendFailed)?;
|
||||
|
||||
@@ -146,12 +134,10 @@ where
|
||||
.ok_or(InboundError::MissingResponsePayload)?;
|
||||
|
||||
match payload {
|
||||
UserAgentResponsePayload::AuthMessage(AuthServerMessage {
|
||||
payload: Some(ServerAuthPayload::AuthChallenge(challenge)),
|
||||
}) => self.handle_auth_challenge(challenge).await,
|
||||
UserAgentResponsePayload::AuthMessage(AuthServerMessage {
|
||||
payload: Some(ServerAuthPayload::AuthOk(ok)),
|
||||
}) => self.handle_auth_ok(ok),
|
||||
UserAgentResponsePayload::AuthChallenge(challenge) => {
|
||||
self.handle_auth_challenge(challenge).await
|
||||
}
|
||||
UserAgentResponsePayload::AuthOk(ok) => self.handle_auth_ok(ok),
|
||||
_ => Err(InboundError::UnexpectedResponsePayload),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,18 +1,14 @@
|
||||
use arbiter_proto::{
|
||||
format_challenge,
|
||||
proto::{
|
||||
proto::user_agent::{
|
||||
AuthChallenge, AuthOk,
|
||||
UserAgentRequest, UserAgentResponse,
|
||||
auth::{
|
||||
AuthChallenge, AuthOk, ClientMessage as AuthClientMessage,
|
||||
ServerMessage as AuthServerMessage, client_message::Payload as ClientAuthPayload,
|
||||
server_message::Payload as ServerAuthPayload,
|
||||
},
|
||||
user_agent_request::Payload as UserAgentRequestPayload,
|
||||
user_agent_response::Payload as UserAgentResponsePayload,
|
||||
},
|
||||
transport::Bi,
|
||||
};
|
||||
use arbiter_useragent::{InboundError, UserAgentActor};
|
||||
use arbiter_useragent::UserAgentActor;
|
||||
use ed25519_dalek::SigningKey;
|
||||
use kameo::actor::Spawn;
|
||||
use tokio::sync::mpsc;
|
||||
@@ -57,14 +53,6 @@ fn test_key() -> SigningKey {
|
||||
SigningKey::from_bytes(&[7u8; 32])
|
||||
}
|
||||
|
||||
fn auth_response(payload: ServerAuthPayload) -> UserAgentResponse {
|
||||
UserAgentResponse {
|
||||
payload: Some(UserAgentResponsePayload::AuthMessage(AuthServerMessage {
|
||||
payload: Some(payload),
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sends_auth_request_on_start_with_bootstrap_token() {
|
||||
let key = test_key();
|
||||
@@ -80,9 +68,7 @@ async fn sends_auth_request_on_start_with_bootstrap_token() {
|
||||
.expect("channel closed before auth request");
|
||||
|
||||
let UserAgentRequest {
|
||||
payload: Some(UserAgentRequestPayload::AuthMessage(AuthClientMessage {
|
||||
payload: Some(ClientAuthPayload::AuthChallengeRequest(req)),
|
||||
})),
|
||||
payload: Some(UserAgentRequestPayload::AuthChallengeRequest(req)),
|
||||
} = outbound
|
||||
else {
|
||||
panic!("expected auth challenge request");
|
||||
@@ -113,7 +99,9 @@ async fn challenge_flow_sends_solution_from_transport_inbound() {
|
||||
nonce: 42,
|
||||
};
|
||||
inbound_tx
|
||||
.send(auth_response(ServerAuthPayload::AuthChallenge(challenge.clone())))
|
||||
.send(UserAgentResponse {
|
||||
payload: Some(UserAgentResponsePayload::AuthChallenge(challenge.clone())),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -123,15 +111,13 @@ async fn challenge_flow_sends_solution_from_transport_inbound() {
|
||||
.expect("missing challenge solution");
|
||||
|
||||
let UserAgentRequest {
|
||||
payload: Some(UserAgentRequestPayload::AuthMessage(AuthClientMessage {
|
||||
payload: Some(ClientAuthPayload::AuthChallengeSolution(solution)),
|
||||
})),
|
||||
payload: Some(UserAgentRequestPayload::AuthChallengeSolution(solution)),
|
||||
} = outbound
|
||||
else {
|
||||
panic!("expected auth challenge solution");
|
||||
};
|
||||
|
||||
let formatted = format_challenge(&challenge);
|
||||
let formatted = format_challenge(challenge.nonce, &challenge.pubkey);
|
||||
let sig: ed25519_dalek::Signature = solution
|
||||
.signature
|
||||
.as_slice()
|
||||
@@ -142,7 +128,9 @@ async fn challenge_flow_sends_solution_from_transport_inbound() {
|
||||
.expect("solution signature should verify");
|
||||
|
||||
inbound_tx
|
||||
.send(auth_response(ServerAuthPayload::AuthOk(AuthOk {})))
|
||||
.send(UserAgentResponse {
|
||||
payload: Some(UserAgentResponsePayload::AuthOk(AuthOk {})),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user