diff --git a/protobufs/auth.proto b/protobufs/auth.proto index b0b3cd8..34eb330 100644 --- a/protobufs/auth.proto +++ b/protobufs/auth.proto @@ -20,12 +20,7 @@ message AuthChallengeSolution { bytes signature = 2; } -message AuthResponse { - string token = 1; - string refresh_token = 2; - google.protobuf.Timestamp expires_at = 3; - google.protobuf.Timestamp refresh_expires_at = 4; -} +message AuthOk {} message ClientMessage { oneof payload { @@ -37,6 +32,6 @@ message ClientMessage { message ServerMessage { oneof payload { AuthChallenge auth_challenge = 1; - AuthResponse auth_response = 2; + AuthOk auth_ok = 2; } } diff --git a/server/crates/arbiter-server/migrations/2026-02-09-143015-0000_init/up.sql b/server/crates/arbiter-server/migrations/2026-02-09-143015-0000_init/up.sql index 366cf76..1ca612d 100644 --- a/server/crates/arbiter-server/migrations/2026-02-09-143015-0000_init/up.sql +++ b/server/crates/arbiter-server/migrations/2026-02-09-143015-0000_init/up.sql @@ -14,24 +14,18 @@ create table if not exists arbiter_settings ( cert blob not null ) STRICT; -create table if not exists key_identity ( - id integer not null primary key, - name text not null, - public_key text not null, - created_at integer not null default(unixepoch ('now')), - updated_at integer not null default(unixepoch ('now')) -) STRICT; - create table if not exists useragent_client ( id integer not null primary key, - key_identity_id integer not null references key_identity (id) on delete cascade, + nonce integer not null default (1), -- used for auth challenge + public_key blob not null, created_at integer not null default(unixepoch ('now')), updated_at integer not null default(unixepoch ('now')) ) STRICT; create table if not exists program_client ( id integer not null primary key, - key_identity_id integer not null references key_identity (id) on delete cascade, + nonce integer not null default (1), -- used for auth challenge + public_key blob not null, created_at integer not null default(unixepoch ('now')), updated_at integer not null default(unixepoch ('now')) ) STRICT; \ No newline at end of file diff --git a/server/crates/arbiter-server/src/actors/user_agent.rs b/server/crates/arbiter-server/src/actors/user_agent.rs index 62fdd95..6aff08f 100644 --- a/server/crates/arbiter-server/src/actors/user_agent.rs +++ b/server/crates/arbiter-server/src/actors/user_agent.rs @@ -4,27 +4,40 @@ use arbiter_proto::{ proto::{ UserAgentRequest, UserAgentResponse, auth::{ - self, AuthChallengeRequest, ClientMessage, client_message::Payload as ClientAuthPayload, + self, AuthChallengeRequest, ClientMessage, ServerMessage as AuthServerMessage, + client_message::Payload as ClientAuthPayload, + server_message::Payload as ServerAuthPayload, }, user_agent_request::Payload as UserAgentRequestPayload, + user_agent_response::Payload as UserAgentResponsePayload, }, transport::Bi, }; -use ed25519_dalek::VerifyingKey; +use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl}; +use diesel_async::{AsyncConnection, RunQueryDsl}; +use ed25519_dalek::{SigningKey, VerifyingKey}; use futures::StreamExt; -use kameo::{Actor, message::StreamMessage, messages, prelude::Context}; +use kameo::{ + Actor, + actor::{ActorRef, Spawn}, + error::SendError, + message::StreamMessage, + messages, + prelude::Context, +}; use secrecy::{ExposeSecret, SecretBox}; use tokio::sync::mpsc; use tokio::sync::mpsc::Sender; use tonic::{Status, transport::Server}; use tracing::error; -use crate::ServerContext; +use crate::{ServerContext, context::bootstrap::ConsumeToken, db::schema}; #[derive(Debug)] pub struct ChallengeContext { challenge: auth::AuthChallenge, - key: ed25519_dalek::SigningKey, + key: SigningKey, + bootstrap_token: Option, } smlang::statemachine!( @@ -83,6 +96,41 @@ impl UserAgentActor { pubkey: ed25519_dalek::VerifyingKey, token: String, ) -> Result { + let token_ok: bool = self + .context + .bootstrapper + .ask(ConsumeToken { token }) + .await + .map_err(|e| { + error!(?pubkey, "Failed to consume bootstrap token: {e}"); + Status::internal("Bootstrap token consumption failed") + })?; + + if token_ok { + let mut conn = self.context.db.get().await.map_err(|e| { + error!(?pubkey, "Failed to get DB connection: {e}"); + Status::internal("Database connection error") + })?; + + diesel::insert_into(schema::useragent_client::table) + .values(( + schema::useragent_client::public_key.eq(pubkey.as_bytes().to_vec()), + schema::useragent_client::nonce.eq(1), + )) + .execute(&mut conn) + .await + .map_err(|e| { + error!(?pubkey, "Failed to insert new user agent client: {e}"); + Status::internal("Database error") + })?; + + return Ok(UserAgentResponse { + payload: Some(UserAgentResponsePayload::AuthMessage(AuthServerMessage { + payload: Some(ServerAuthPayload::Auth), + })), + }); + } + todo!() } } @@ -92,7 +140,7 @@ type Output = Result; #[messages] impl UserAgentActor { #[message(ctx)] - async fn handle_auth_challenge_request( + pub async fn handle_auth_challenge_request( &mut self, req: AuthChallengeRequest, ctx: &mut Context, @@ -106,21 +154,112 @@ impl UserAgentActor { })?; if let Some(token) = req.bootstrap_token { - return self - .auth_with_bootstrap_token(pubkey, token) - .await - .map_err(|_| Status::internal("Failed to authenticate with bootstrap token")); + return self.auth_with_bootstrap_token(pubkey, token).await; } + let mut db_conn = self.context.db.get().await.map_err(|err| { + error!(?pubkey, "Failed to get DB connection: {err}"); + Status::internal("Database connection error") + })?; + + let nonce = db_conn + .transaction(|mut conn| { + Box::pin(async move { + let current_nonce = schema::useragent_client::table + .filter(schema::useragent_client::public_key.eq(pubkey.as_bytes().to_vec())) + .select(schema::useragent_client::nonce) + .first::(&mut db_conn) + .await?; + + Ok(()) + }) + }) + .await; + + // let nonce = match last_used_nonce + todo!() } #[message(ctx)] - async fn handle_auth_challenge_solution( + pub async fn handle_auth_challenge_solution( &mut self, - _solution: auth::AuthChallengeSolution, + solution: auth::AuthChallengeSolution, ctx: &mut Context, ) -> Output { todo!() } } + +pub(crate) async fn handle_user_agent( + context: ServerContext, + mut req_stream: tonic::Streaming, + tx: mpsc::Sender>, +) { + let actor = UserAgentActor::spawn(UserAgentActor::new(context, tx.clone())); + + while let Some(Ok(req)) = req_stream.next().await + && actor.is_alive() + { + match process_message(&actor, req).await { + Ok(resp) => { + if tx.send(Ok(resp)).await.is_err() { + error!(actor = "useragent", "Failed to send response to client"); + break; + } + } + Err(status) => { + let _ = tx.send(Err(status)).await; + break; + } + } + } + + actor.kill(); +} + +async fn process_message( + actor: &ActorRef, + req: UserAgentRequest, +) -> Result { + let msg = req.payload.ok_or_else(|| { + error!(actor = "useragent", "Received message with no payload"); + Status::invalid_argument("Expected message with payload") + })?; + + let UserAgentRequestPayload::AuthMessage(ClientMessage { + payload: Some(client_message), + }) = msg + else { + error!( + actor = "useragent", + "Received unexpected message type during authentication" + ); + return Err(Status::invalid_argument( + "Expected AuthMessage with ClientMessage payload", + )); + }; + + let result = match client_message { + ClientAuthPayload::AuthChallengeRequest(req) => actor + .ask(HandleAuthChallengeRequest { req }) + .await + .map_err(into_status), + ClientAuthPayload::AuthChallengeSolution(solution) => actor + .ask(HandleAuthChallengeSolution { solution }) + .await + .map_err(into_status), + }; + + result +} + +fn into_status(e: SendError) -> Status { + match e { + SendError::HandlerError(status) => status, + _ => { + error!(actor = "useragent", "Failed to send message to actor"); + Status::internal("session failure") + } + } +} diff --git a/server/crates/arbiter-server/src/context.rs b/server/crates/arbiter-server/src/context.rs index b2d3051..6da6a80 100644 --- a/server/crates/arbiter-server/src/context.rs +++ b/server/crates/arbiter-server/src/context.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use diesel::OptionalExtension as _; use diesel_async::RunQueryDsl as _; use ed25519_dalek::VerifyingKey; +use kameo::actor::{ActorRef, Spawn}; use miette::Diagnostic; use rand::rngs::StdRng; use smlang::statemachine; @@ -11,7 +12,7 @@ use tokio::sync::RwLock; use crate::{ context::{ - bootstrap::generate_token, + bootstrap::{BootstrapActor, generate_token}, lease::LeaseHandler, tls::{TlsDataRaw, TlsManager}, }, @@ -44,6 +45,10 @@ pub enum InitError { #[diagnostic(code(arbiter_server::init::tls_init))] Tls(#[from] tls::TlsInitError), + #[error("Bootstrap token generation failed: {0}")] + #[diagnostic(code(arbiter_server::init::bootstrap_token))] + BootstrapToken(#[from] bootstrap::BootstrapError), + #[error("I/O Error: {0}")] #[diagnostic(code(arbiter_server::init::io))] Io(#[from] std::io::Error), @@ -55,7 +60,7 @@ pub struct KeyStorage; statemachine! { name: Server, transitions: { - *NotBootstrapped(String) + Bootstrapped = Sealed, + *NotBootstrapped + Bootstrapped = Sealed, Sealed + Unsealed(KeyStorage) / move_key = Ready(KeyStorage), Ready(KeyStorage) + Sealed / dispose_key = Sealed, } @@ -78,8 +83,7 @@ pub(crate) struct _ServerContextInner { pub state: RwLock>, pub rng: StdRng, pub tls: TlsManager, - pub user_agent_leases: LeaseHandler, - pub client_leases: LeaseHandler, + pub bootstrapper: ActorRef, } #[derive(Clone)] pub(crate) struct ServerContext(Arc<_ServerContextInner>); @@ -138,9 +142,7 @@ impl ServerContext { drop(conn); - let bootstrap_token = generate_token().await?; - - let mut state = ServerStateMachine::new(_Context, bootstrap_token); + let mut state = ServerStateMachine::new(_Context); if let Some(settings) = &settings && settings.root_key_id.is_some() @@ -150,12 +152,11 @@ impl ServerContext { } Ok(Self(Arc::new(_ServerContextInner { + bootstrapper: BootstrapActor::spawn(BootstrapActor::new(&db).await?), db, rng, tls, state: RwLock::new(state), - user_agent_leases: Default::default(), - client_leases: Default::default(), }))) } } diff --git a/server/crates/arbiter-server/src/context/bootstrap.rs b/server/crates/arbiter-server/src/context/bootstrap.rs index 2a00ed5..e765947 100644 --- a/server/crates/arbiter-server/src/context/bootstrap.rs +++ b/server/crates/arbiter-server/src/context/bootstrap.rs @@ -1,6 +1,11 @@ use arbiter_proto::{BOOTSTRAP_TOKEN_PATH, home_path}; -use diesel::{QueryDsl, dsl::exists, select}; +use diesel::{ + ExpressionMethods, QueryDsl, + dsl::{count, exists}, + select, +}; use diesel_async::RunQueryDsl; +use kameo::{Actor, messages}; use memsafe::MemSafe; use miette::Diagnostic; use rand::{RngExt, distr::StandardUniform, make_rng, rngs::StdRng}; @@ -9,7 +14,10 @@ use thiserror::Error; use tracing::info; use zeroize::{Zeroize, Zeroizing}; -use crate::db::{self, schema}; +use crate::{ + context::{self, ServerContext}, + db::{self, DatabasePool, schema}, +}; const TOKEN_LENGTH: usize = 64; @@ -28,3 +36,70 @@ pub async fn generate_token() -> Result { Ok(token) } + +#[derive(Error, Debug, Diagnostic)] +pub enum BootstrapError { + #[error("Database error: {0}")] + #[diagnostic(code(arbiter_server::bootstrap::database))] + Database(#[from] db::PoolError), + + #[error("Database query error: {0}")] + #[diagnostic(code(arbiter_server::bootstrap::database_query))] + Query(#[from] diesel::result::Error), + + #[error("I/O error: {0}")] + #[diagnostic(code(arbiter_server::bootstrap::io))] + Io(#[from] std::io::Error), +} + +#[derive(Actor)] +pub struct BootstrapActor { + token: Option, +} + +impl BootstrapActor { + pub async fn new(db: &DatabasePool) -> Result { + let mut conn = db.get().await?; + + let needs_token: bool = select(exists( + schema::useragent_client::table + .filter(schema::useragent_client::id.eq(schema::useragent_client::id)), // Just check if the table is empty + )) + .first(&mut conn) + .await?; + + drop(conn); + + let token = if needs_token { + let token = generate_token().await?; + info!(%token, "Generated bootstrap token"); + tokio::fs::write(home_path()?.join(BOOTSTRAP_TOKEN_PATH), token.as_str()).await?; + Some(token) + } else { + None + }; + + Ok(Self { token }) + } +} + +#[messages] +impl BootstrapActor { + #[message] + pub fn is_correct_token(&self, token: String) -> bool { + match &self.token { + Some(expected) => *expected == token, + None => false, + } + } + + #[message] + pub fn consume_token(&mut self, token: String) -> bool { + if self.is_correct_token(token) { + self.token = None; + true + } else { + false + } + } +} diff --git a/server/crates/arbiter-server/src/db/models.rs b/server/crates/arbiter-server/src/db/models.rs index 63810ec..7417875 100644 --- a/server/crates/arbiter-server/src/db/models.rs +++ b/server/crates/arbiter-server/src/db/models.rs @@ -28,21 +28,12 @@ pub struct ArbiterSetting { pub cert: Vec, } -#[derive(Queryable, Debug)] -#[diesel(table_name = schema::key_identity, check_for_backend(Sqlite))] -pub struct KeyIdentity { - pub id: i32, - pub name: String, - pub public_key: String, - pub created_at: i32, - pub updated_at: i32, -} - #[derive(Queryable, Debug)] #[diesel(table_name = schema::program_client, check_for_backend(Sqlite))] pub struct ProgramClient { pub id: i32, - pub key_identity_id: i32, + pub public_key: Vec, + pub nonce: i32, pub created_at: i32, pub updated_at: i32, } @@ -51,7 +42,8 @@ pub struct ProgramClient { #[diesel(table_name = schema::useragent_client, check_for_backend(Sqlite))] pub struct UseragentClient { pub id: i32, - pub key_identity_id: i32, + pub public_key: Vec, + pub nonce: i32, pub created_at: i32, pub updated_at: i32, } diff --git a/server/crates/arbiter-server/src/db/schema.rs b/server/crates/arbiter-server/src/db/schema.rs index f662849..38f8afe 100644 --- a/server/crates/arbiter-server/src/db/schema.rs +++ b/server/crates/arbiter-server/src/db/schema.rs @@ -19,20 +19,11 @@ diesel::table! { } } -diesel::table! { - key_identity (id) { - id -> Integer, - name -> Text, - public_key -> Text, - created_at -> Integer, - updated_at -> Integer, - } -} - diesel::table! { program_client (id) { id -> Integer, - key_identity_id -> Integer, + nonce -> Integer, + public_key -> Binary, created_at -> Integer, updated_at -> Integer, } @@ -41,20 +32,18 @@ diesel::table! { diesel::table! { useragent_client (id) { id -> Integer, - key_identity_id -> Integer, + nonce -> Integer, + public_key -> Binary, created_at -> Integer, updated_at -> Integer, } } diesel::joinable!(arbiter_settings -> aead_encrypted (root_key_id)); -diesel::joinable!(program_client -> key_identity (key_identity_id)); -diesel::joinable!(useragent_client -> key_identity (key_identity_id)); diesel::allow_tables_to_appear_in_same_query!( aead_encrypted, arbiter_settings, - key_identity, program_client, useragent_client, ); diff --git a/server/crates/arbiter-server/src/lib.rs b/server/crates/arbiter-server/src/lib.rs index afb7d56..da05100 100644 --- a/server/crates/arbiter-server/src/lib.rs +++ b/server/crates/arbiter-server/src/lib.rs @@ -2,32 +2,18 @@ use std::sync::Arc; -use tracing::error; - use arbiter_proto::{ - proto::{ - ClientRequest, ClientResponse, UserAgentRequest, UserAgentResponse, - auth::{ - self, AuthChallengeRequest, ClientMessage, client_message::Payload as ClientAuthPayload, - }, - user_agent_request::Payload as UserAgentRequestPayload, - user_agent_request::*, - }, + proto::{ClientRequest, ClientResponse, UserAgentRequest, UserAgentResponse}, transport::BiStream, }; use async_trait::async_trait; -use futures::StreamExt; -use kameo::actor::Spawn; use tokio_stream::wrappers::ReceiverStream; use tokio::sync::mpsc; use tonic::{Request, Response, Status}; use crate::{ - actors::{ - client::handle_client, - user_agent::{self, UserAgentActor}, - }, + actors::{client::handle_client, user_agent::handle_user_agent}, context::ServerContext, }; @@ -67,58 +53,9 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server { &self, request: Request>, ) -> Result, Status> { - let mut req_stream = request.into_inner(); + let req_stream = request.into_inner(); let (tx, rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE); - - let actor = UserAgentActor::spawn(UserAgentActor::new(self.context.clone(), tx.clone())); - - tokio::task::spawn(async move { - while let Some(Ok(req)) = req_stream.next().await - && actor.is_alive() - { - let Some(msg) = req.payload else { - error!(actor = "useragent", "Received message with no payload"); - actor.kill(); - tx.send(Err(Status::invalid_argument( - "Expected message with payload", - ))) - .await; - return; - }; - - let UserAgentRequestPayload::AuthMessage(ClientMessage { - payload: Some(client_message), - }) = msg - else { - error!( - actor = "useragent", - "Received unexpected message type during authentication" - ); - actor.kill(); - tx.send(Err(Status::invalid_argument( - "Expected AuthMessage with ClientMessage payload", - ))) - .await; - return; - }; - - match client_message { - ClientAuthPayload::AuthChallengeRequest(req) => {} - ClientAuthPayload::AuthChallengeSolution(_auth_challenge_solution) => todo!(), - _ => { - error!(actor = "useragent", "Received unexpected message type"); - actor.kill(); - tx.send(Err(Status::invalid_argument( - "Expected AuthMessage with ClientMessage payload", - ))) - .await; - return; - } - } - todo!() - } - }); - + tokio::spawn(handle_user_agent(self.context.clone(), req_stream, tx)); Ok(Response::new(ReceiverStream::new(rx))) } }