diff --git a/mise.lock b/mise.lock index d83a929..2d2300e 100644 --- a/mise.lock +++ b/mise.lock @@ -10,6 +10,10 @@ backend = "cargo:cargo-features" version = "0.11.1" backend = "cargo:cargo-features-manager" +[[tools."cargo:cargo-nextest"]] +version = "0.9.126" +backend = "cargo:cargo-nextest" + [[tools."cargo:cargo-vet"]] version = "0.10.2" backend = "cargo:cargo-vet" diff --git a/mise.toml b/mise.toml index 8623f45..b9682c7 100644 --- a/mise.toml +++ b/mise.toml @@ -7,3 +7,4 @@ flutter = "3.38.9-stable" protoc = "29.6" rust = "1.93.0" "cargo:cargo-features-manager" = "0.11.1" +"cargo:cargo-nextest" = "0.9.126" diff --git a/protobufs/auth.proto b/protobufs/auth.proto index 34eb330..33a5e50 100644 --- a/protobufs/auth.proto +++ b/protobufs/auth.proto @@ -11,13 +11,11 @@ message AuthChallengeRequest { message AuthChallenge { bytes pubkey = 1; - bytes nonce = 2; - google.protobuf.Timestamp minted = 3; + int32 nonce = 2; } message AuthChallengeSolution { - AuthChallenge challenge = 1; - bytes signature = 2; + bytes signature = 1; } message AuthOk {} diff --git a/server/Cargo.lock b/server/Cargo.lock index b550575..2b8d6e5 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -61,6 +61,7 @@ version = "0.1.0" dependencies = [ "bytes", "futures", + "hex", "kameo", "prost", "prost-build", @@ -93,6 +94,7 @@ dependencies = [ "kameo", "memsafe", "miette", + "prost-types", "rand", "rcgen", "restructed", @@ -1017,6 +1019,12 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + [[package]] name = "http" version = "1.4.0" diff --git a/server/Cargo.toml b/server/Cargo.toml index 3d439d2..f779e65 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -25,4 +25,5 @@ thiserror = "2.0.18" async-trait = "0.1.89" futures = "0.3.31" tokio-stream = { version = "0.1.18", features = ["full"] } -kameo = "0.19.2" \ No newline at end of file +kameo = "0.19.2" +prost-types = { version = "0.14.3", features = ["chrono"] } \ No newline at end of file diff --git a/server/crates/arbiter-proto/Cargo.toml b/server/crates/arbiter-proto/Cargo.toml index d3bd268..258ca48 100644 --- a/server/crates/arbiter-proto/Cargo.toml +++ b/server/crates/arbiter-proto/Cargo.toml @@ -9,12 +9,13 @@ tonic.workspace = true prost.workspace = true bytes = "1.11.1" prost-derive = "0.14.3" -prost-types = { version = "0.14.3", features = ["chrono"] } +prost-types.workspace = true tonic-prost = "0.14.3" rkyv = "0.8.15" tokio.workspace = true futures.workspace = true kameo.workspace = true +hex = "0.4.3" diff --git a/server/crates/arbiter-proto/src/lib.rs b/server/crates/arbiter-proto/src/lib.rs index db493d3..bce8e36 100644 --- a/server/crates/arbiter-proto/src/lib.rs +++ b/server/crates/arbiter-proto/src/lib.rs @@ -1,3 +1,5 @@ +use crate::proto::auth::AuthChallenge; + pub mod proto { tonic::include_proto!("arbiter"); @@ -22,3 +24,8 @@ pub fn home_path() -> Result { Ok(arbiter_home) } + +pub fn format_challenge(challenge: &AuthChallenge) -> Vec { + let concat_form = format!("{}:{}", challenge.nonce, hex::encode(&challenge.pubkey)); + concat_form.into_bytes().to_vec() +} diff --git a/server/crates/arbiter-server/Cargo.toml b/server/crates/arbiter-server/Cargo.toml index f0f3e94..af4e3bc 100644 --- a/server/crates/arbiter-server/Cargo.toml +++ b/server/crates/arbiter-server/Cargo.toml @@ -34,4 +34,4 @@ memsafe = "0.4.0" chacha20poly1305 = { version = "0.10.1", features = ["std"] } zeroize = { version = "1.8.2", features = ["std", "simd"] } kameo.workspace = true - +prost-types.workspace = true \ No newline at end of file diff --git a/server/crates/arbiter-server/src/actors/user_agent.rs b/server/crates/arbiter-server/src/actors/user_agent.rs index 6aff08f..d44b087 100644 --- a/server/crates/arbiter-server/src/actors/user_agent.rs +++ b/server/crates/arbiter-server/src/actors/user_agent.rs @@ -1,74 +1,81 @@ -use std::sync::Arc; - -use arbiter_proto::{ - proto::{ - UserAgentRequest, UserAgentResponse, - auth::{ - self, AuthChallengeRequest, ClientMessage, ServerMessage as AuthServerMessage, - client_message::Payload as ClientAuthPayload, - server_message::Payload as ServerAuthPayload, - }, - user_agent_request::Payload as UserAgentRequestPayload, - user_agent_response::Payload as UserAgentResponsePayload, +use arbiter_proto::proto::{ + UserAgentRequest, UserAgentResponse, + auth::{ + self, AuthChallenge, AuthChallengeRequest, AuthOk, ClientMessage, + ServerMessage as AuthServerMessage, client_message::Payload as ClientAuthPayload, + server_message::Payload as ServerAuthPayload, }, - transport::Bi, + user_agent_request::Payload as UserAgentRequestPayload, + user_agent_response::Payload as UserAgentResponsePayload, }; -use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl}; +use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, dsl::update}; use diesel_async::{AsyncConnection, RunQueryDsl}; -use ed25519_dalek::{SigningKey, VerifyingKey}; +use ed25519_dalek::VerifyingKey; use futures::StreamExt; use kameo::{ Actor, actor::{ActorRef, Spawn}, error::SendError, - message::StreamMessage, messages, prelude::Context, }; -use secrecy::{ExposeSecret, SecretBox}; use tokio::sync::mpsc; use tokio::sync::mpsc::Sender; -use tonic::{Status, transport::Server}; -use tracing::error; +use tonic::Status; +use tracing::{error, info}; -use crate::{ServerContext, context::bootstrap::ConsumeToken, db::schema}; +use crate::{ServerContext, context::bootstrap::ConsumeToken, db::schema, errors::GrpcStatusExt}; -#[derive(Debug)] +/// Context for state machine with validated key and sent challenge +/// Challenge is then transformed to bytes using shared function and verified +#[derive(Clone, Debug)] pub struct ChallengeContext { - challenge: auth::AuthChallenge, - key: SigningKey, + challenge: AuthChallenge, + key: VerifyingKey, +} + +// Request context with deserialized public key for state machine. +// This intermediate struct is needed because the state machine branches depending on presence of bootstrap token, +// but we want to have the deserialized key in both branches. +#[derive(Clone, Debug)] +pub struct AuthRequestContext { + pubkey: VerifyingKey, bootstrap_token: Option, } smlang::statemachine!( name: UserAgent, derive_states: [Debug], + custom_error: false, transitions: { - *Init + ReceivedRequest(ed25519_dalek::VerifyingKey) [async check_key_existence] / provide_challenge = WaitingForChallengeSolution(ChallengeContext), - Init + ReceivedBootstrapToken(String) = Authenticated, + *Init + AuthRequest(AuthRequestContext) / auth_request_context = ReceivedAuthRequest(AuthRequestContext), + ReceivedAuthRequest(AuthRequestContext) + ReceivedBootstrapToken = Authenticated, + + ReceivedAuthRequest(AuthRequestContext) + SentChallenge(ChallengeContext) / move_challenge = WaitingForChallengeSolution(ChallengeContext), WaitingForChallengeSolution(ChallengeContext) + ReceivedGoodSolution = Authenticated, - WaitingForChallengeSolution(ChallengeContext) + ReceivedBadSolution = Error, + WaitingForChallengeSolution(ChallengeContext) + ReceivedBadSolution = AuthError, // block further transitions, but connection should close anyway } ); impl UserAgentStateMachineContext for ServerContext { #[allow(missing_docs)] #[allow(clippy::unused_unit)] - fn provide_challenge( + fn move_challenge( &mut self, - event_data: ed25519_dalek::VerifyingKey, + state_data: &AuthRequestContext, + event_data: ChallengeContext, ) -> Result { - todo!() + Ok(event_data) } #[allow(missing_docs)] - #[allow(clippy::result_unit_err)] - async fn check_key_existence( - &self, - event_data: &ed25519_dalek::VerifyingKey, - ) -> Result { - todo!() + #[allow(clippy::unused_unit)] + fn auth_request_context( + &mut self, + event_data: AuthRequestContext, + ) -> Result { + Ok(event_data) } } @@ -76,21 +83,29 @@ impl UserAgentStateMachineContext for ServerContext { pub struct UserAgentActor { context: ServerContext, state: UserAgentStateMachine, - rx: Sender>, + tx: Sender>, } impl UserAgentActor { pub(crate) fn new( context: ServerContext, - rx: Sender>, + tx: Sender>, ) -> Self { Self { context: context.clone(), state: UserAgentStateMachine::new(context), - rx, + tx, } } + fn transition(&mut self, event: UserAgentEvents) -> Result<(), Status> { + self.state.process_event(event).map_err(|e| { + error!(?e, "State transition failed"); + Status::internal("State machine error") + })?; + Ok(()) + } + async fn auth_with_bootstrap_token( &mut self, pubkey: ed25519_dalek::VerifyingKey, @@ -106,11 +121,13 @@ impl UserAgentActor { Status::internal("Bootstrap token consumption failed") })?; - if token_ok { - let mut conn = self.context.db.get().await.map_err(|e| { - error!(?pubkey, "Failed to get DB connection: {e}"); - Status::internal("Database connection error") - })?; + if !token_ok { + error!(?pubkey, "Invalid bootstrap token provided"); + return Err(Status::invalid_argument("Invalid bootstrap token")); + } + + { + let mut conn = self.context.db.get().await.to_status()?; diesel::insert_into(schema::useragent_client::table) .values(( @@ -119,24 +136,105 @@ impl UserAgentActor { )) .execute(&mut conn) .await - .map_err(|e| { - error!(?pubkey, "Failed to insert new user agent client: {e}"); - Status::internal("Database error") - })?; - - return Ok(UserAgentResponse { - payload: Some(UserAgentResponsePayload::AuthMessage(AuthServerMessage { - payload: Some(ServerAuthPayload::Auth), - })), - }); + .to_status()?; } - todo!() + self.transition(UserAgentEvents::ReceivedBootstrapToken)?; + + Ok(auth_response(ServerAuthPayload::AuthOk(AuthOk {}))) + } + + async fn auth_with_challenge(&mut self, pubkey: VerifyingKey, pubkey_bytes: Vec) -> Output { + let nonce: Option = { + let mut db_conn = self.context.db.get().await.to_status()?; + db_conn + .transaction(|conn| { + Box::pin(async move { + let current_nonce = schema::useragent_client::table + .filter( + schema::useragent_client::public_key.eq(pubkey.as_bytes().to_vec()), + ) + .select(schema::useragent_client::nonce) + .first::(conn) + .await?; + + update(schema::useragent_client::table) + .filter( + schema::useragent_client::public_key.eq(pubkey.as_bytes().to_vec()), + ) + .set(schema::useragent_client::nonce.eq(current_nonce + 1)) + .execute(conn) + .await?; + + Result::<_, diesel::result::Error>::Ok(current_nonce) + }) + }) + .await + .optional() + .to_status()? + }; + + let Some(nonce) = nonce else { + error!(?pubkey, "Public key not found in database"); + return Err(Status::unauthenticated("Public key not registered")); + }; + + let challenge = auth::AuthChallenge { + pubkey: pubkey_bytes, + nonce: nonce, + }; + + self.transition(UserAgentEvents::SentChallenge(ChallengeContext { + challenge: challenge.clone(), + key: pubkey, + }))?; + + info!( + ?pubkey, + ?challenge, + "Sent authentication challenge to client" + ); + + Ok(auth_response(ServerAuthPayload::AuthChallenge(challenge))) + } + + fn verify_challenge_solution( + &self, + solution: &auth::AuthChallengeSolution, + ) -> Result<(bool, &ChallengeContext), Status> { + let UserAgentStates::WaitingForChallengeSolution(challenge_context) = self.state.state() + else { + error!("Received challenge solution in invalid state"); + return Err(Status::invalid_argument( + "Invalid state for challenge solution", + )); + }; + let formatted_challenge = arbiter_proto::format_challenge(&challenge_context.challenge); + + let signature = solution.signature.as_slice().try_into().map_err(|_| { + error!(?solution, "Invalid signature length"); + Status::invalid_argument("Invalid signature length") + })?; + + let valid = challenge_context + .key + .verify_strict(&formatted_challenge, &signature) + .is_ok(); + + Ok((valid, challenge_context)) } } type Output = Result; +fn auth_response(payload: ServerAuthPayload) -> UserAgentResponse { + UserAgentResponse { + payload: Some(UserAgentResponsePayload::AuthMessage(AuthServerMessage { + payload: Some(payload), + })), + } +} + #[messages] impl UserAgentActor { #[message(ctx)] @@ -153,32 +251,15 @@ impl UserAgentActor { Status::invalid_argument("Failed to convert pubkey to VerifyingKey") })?; - if let Some(token) = req.bootstrap_token { - return self.auth_with_bootstrap_token(pubkey, token).await; + self.transition(UserAgentEvents::AuthRequest(AuthRequestContext { + pubkey, + bootstrap_token: req.bootstrap_token.clone(), + }))?; + + match req.bootstrap_token { + Some(token) => self.auth_with_bootstrap_token(pubkey, token).await, + None => self.auth_with_challenge(pubkey, req.pubkey).await, } - - let mut db_conn = self.context.db.get().await.map_err(|err| { - error!(?pubkey, "Failed to get DB connection: {err}"); - Status::internal("Database connection error") - })?; - - let nonce = db_conn - .transaction(|mut conn| { - Box::pin(async move { - let current_nonce = schema::useragent_client::table - .filter(schema::useragent_client::public_key.eq(pubkey.as_bytes().to_vec())) - .select(schema::useragent_client::nonce) - .first::(&mut db_conn) - .await?; - - Ok(()) - }) - }) - .await; - - // let nonce = match last_used_nonce - - todo!() } #[message(ctx)] @@ -187,7 +268,20 @@ impl UserAgentActor { solution: auth::AuthChallengeSolution, ctx: &mut Context, ) -> Output { - todo!() + let (valid, challenge_context) = self.verify_challenge_solution(&solution)?; + + if valid { + info!( + ?challenge_context, + "Client provided valid solution to authentication challenge" + ); + self.transition(UserAgentEvents::ReceivedGoodSolution)?; + Ok(auth_response(ServerAuthPayload::AuthOk(AuthOk {}))) + } else { + error!("Client provided invalid solution to authentication challenge"); + self.transition(UserAgentEvents::ReceivedBadSolution)?; + Err(Status::unauthenticated("Invalid challenge solution")) + } } } @@ -240,7 +334,7 @@ async fn process_message( )); }; - let result = match client_message { + match client_message { ClientAuthPayload::AuthChallengeRequest(req) => actor .ask(HandleAuthChallengeRequest { req }) .await @@ -249,9 +343,7 @@ async fn process_message( .ask(HandleAuthChallengeSolution { solution }) .await .map_err(into_status), - }; - - result + } } fn into_status(e: SendError) -> Status { diff --git a/server/crates/arbiter-server/src/db.rs b/server/crates/arbiter-server/src/db.rs index 80fa260..129b9c1 100644 --- a/server/crates/arbiter-server/src/db.rs +++ b/server/crates/arbiter-server/src/db.rs @@ -1,8 +1,11 @@ use std::sync::Arc; -use diesel::{Connection as _, SqliteConnection, connection::SimpleConnection as _}; +use diesel::{ + Connection as _, SqliteConnection, + connection::{SimpleConnection as _, TransactionManager}, +}; use diesel_async::{ - AsyncConnection, SimpleAsyncConnection as _, + AsyncConnection, SimpleAsyncConnection, pooled_connection::{AsyncDieselConnectionManager, ManagerConfig, RecyclingMethod}, sync_connection_wrapper::SyncConnectionWrapper, }; @@ -53,7 +56,6 @@ fn database_path() -> Result { Ok(db_path) } - fn db_config(conn: &mut SqliteConnection) -> Result<(), diesel::result::Error> { // fsync only in critical moments conn.batch_execute("PRAGMA synchronous = NORMAL;")?; @@ -115,10 +117,12 @@ pub async fn create_pool() -> Result { }) }); - let pool = DatabasePool::builder().build(AsyncDieselConnectionManager::new_with_config( - database_url, - config, - )).await?; + let pool = DatabasePool::builder() + .build(AsyncDieselConnectionManager::new_with_config( + database_url, + config, + )) + .await?; Ok(pool) } diff --git a/server/crates/arbiter-server/src/errors.rs b/server/crates/arbiter-server/src/errors.rs new file mode 100644 index 0000000..4115f9c --- /dev/null +++ b/server/crates/arbiter-server/src/errors.rs @@ -0,0 +1,24 @@ +use tonic::Status; +use tracing::error; + +pub trait GrpcStatusExt { + fn to_status(self) -> Result; +} + +impl GrpcStatusExt for Result { + fn to_status(self) -> Result { + self.map_err(|e| { + error!(error = ?e, "Database error"); + Status::internal("Database error") + }) + } +} + +impl GrpcStatusExt for Result { + fn to_status(self) -> Result { + self.map_err(|e| { + error!(error = ?e, "Database pool error"); + Status::internal("Database pool error") + }) + } +} \ No newline at end of file diff --git a/server/crates/arbiter-server/src/lib.rs b/server/crates/arbiter-server/src/lib.rs index da05100..921bb8e 100644 --- a/server/crates/arbiter-server/src/lib.rs +++ b/server/crates/arbiter-server/src/lib.rs @@ -20,6 +20,7 @@ use crate::{ pub mod actors; mod context; mod db; +mod errors; const DEFAULT_CHANNEL_SIZE: usize = 1000;