feat(auth): simplify auth model and implement bootstrap flow

Remove key_identity indirection table, storing public keys and nonces
directly on client tables. Replace AuthResponse with AuthOk, add a
BootstrapActor to manage token lifecycle, and move user agent stream
handling into the actor module.
This commit is contained in:
hdbg
2026-02-13 17:55:56 +01:00
parent 8fb7a04102
commit ffa60c90b1
8 changed files with 256 additions and 134 deletions

View File

@@ -20,12 +20,7 @@ message AuthChallengeSolution {
bytes signature = 2; bytes signature = 2;
} }
message AuthResponse { message AuthOk {}
string token = 1;
string refresh_token = 2;
google.protobuf.Timestamp expires_at = 3;
google.protobuf.Timestamp refresh_expires_at = 4;
}
message ClientMessage { message ClientMessage {
oneof payload { oneof payload {
@@ -37,6 +32,6 @@ message ClientMessage {
message ServerMessage { message ServerMessage {
oneof payload { oneof payload {
AuthChallenge auth_challenge = 1; AuthChallenge auth_challenge = 1;
AuthResponse auth_response = 2; AuthOk auth_ok = 2;
} }
} }

View File

@@ -14,24 +14,18 @@ create table if not exists arbiter_settings (
cert blob not null cert blob not null
) STRICT; ) STRICT;
create table if not exists key_identity (
id integer not null primary key,
name text not null,
public_key text not null,
created_at integer not null default(unixepoch ('now')),
updated_at integer not null default(unixepoch ('now'))
) STRICT;
create table if not exists useragent_client ( create table if not exists useragent_client (
id integer not null primary key, id integer not null primary key,
key_identity_id integer not null references key_identity (id) on delete cascade, nonce integer not null default (1), -- used for auth challenge
public_key blob not null,
created_at integer not null default(unixepoch ('now')), created_at integer not null default(unixepoch ('now')),
updated_at integer not null default(unixepoch ('now')) updated_at integer not null default(unixepoch ('now'))
) STRICT; ) STRICT;
create table if not exists program_client ( create table if not exists program_client (
id integer not null primary key, id integer not null primary key,
key_identity_id integer not null references key_identity (id) on delete cascade, nonce integer not null default (1), -- used for auth challenge
public_key blob not null,
created_at integer not null default(unixepoch ('now')), created_at integer not null default(unixepoch ('now')),
updated_at integer not null default(unixepoch ('now')) updated_at integer not null default(unixepoch ('now'))
) STRICT; ) STRICT;

View File

@@ -4,27 +4,40 @@ use arbiter_proto::{
proto::{ proto::{
UserAgentRequest, UserAgentResponse, UserAgentRequest, UserAgentResponse,
auth::{ auth::{
self, AuthChallengeRequest, ClientMessage, client_message::Payload as ClientAuthPayload, self, AuthChallengeRequest, ClientMessage, ServerMessage as AuthServerMessage,
client_message::Payload as ClientAuthPayload,
server_message::Payload as ServerAuthPayload,
}, },
user_agent_request::Payload as UserAgentRequestPayload, user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload,
}, },
transport::Bi, transport::Bi,
}; };
use ed25519_dalek::VerifyingKey; use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl};
use diesel_async::{AsyncConnection, RunQueryDsl};
use ed25519_dalek::{SigningKey, VerifyingKey};
use futures::StreamExt; use futures::StreamExt;
use kameo::{Actor, message::StreamMessage, messages, prelude::Context}; use kameo::{
Actor,
actor::{ActorRef, Spawn},
error::SendError,
message::StreamMessage,
messages,
prelude::Context,
};
use secrecy::{ExposeSecret, SecretBox}; use secrecy::{ExposeSecret, SecretBox};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio::sync::mpsc::Sender; use tokio::sync::mpsc::Sender;
use tonic::{Status, transport::Server}; use tonic::{Status, transport::Server};
use tracing::error; use tracing::error;
use crate::ServerContext; use crate::{ServerContext, context::bootstrap::ConsumeToken, db::schema};
#[derive(Debug)] #[derive(Debug)]
pub struct ChallengeContext { pub struct ChallengeContext {
challenge: auth::AuthChallenge, challenge: auth::AuthChallenge,
key: ed25519_dalek::SigningKey, key: SigningKey,
bootstrap_token: Option<String>,
} }
smlang::statemachine!( smlang::statemachine!(
@@ -83,6 +96,41 @@ impl UserAgentActor {
pubkey: ed25519_dalek::VerifyingKey, pubkey: ed25519_dalek::VerifyingKey,
token: String, token: String,
) -> Result<UserAgentResponse, Status> { ) -> Result<UserAgentResponse, Status> {
let token_ok: bool = self
.context
.bootstrapper
.ask(ConsumeToken { token })
.await
.map_err(|e| {
error!(?pubkey, "Failed to consume bootstrap token: {e}");
Status::internal("Bootstrap token consumption failed")
})?;
if token_ok {
let mut conn = self.context.db.get().await.map_err(|e| {
error!(?pubkey, "Failed to get DB connection: {e}");
Status::internal("Database connection error")
})?;
diesel::insert_into(schema::useragent_client::table)
.values((
schema::useragent_client::public_key.eq(pubkey.as_bytes().to_vec()),
schema::useragent_client::nonce.eq(1),
))
.execute(&mut conn)
.await
.map_err(|e| {
error!(?pubkey, "Failed to insert new user agent client: {e}");
Status::internal("Database error")
})?;
return Ok(UserAgentResponse {
payload: Some(UserAgentResponsePayload::AuthMessage(AuthServerMessage {
payload: Some(ServerAuthPayload::Auth),
})),
});
}
todo!() todo!()
} }
} }
@@ -92,7 +140,7 @@ type Output = Result<UserAgentResponse, Status>;
#[messages] #[messages]
impl UserAgentActor { impl UserAgentActor {
#[message(ctx)] #[message(ctx)]
async fn handle_auth_challenge_request( pub async fn handle_auth_challenge_request(
&mut self, &mut self,
req: AuthChallengeRequest, req: AuthChallengeRequest,
ctx: &mut Context<Self, Output>, ctx: &mut Context<Self, Output>,
@@ -106,21 +154,112 @@ impl UserAgentActor {
})?; })?;
if let Some(token) = req.bootstrap_token { if let Some(token) = req.bootstrap_token {
return self return self.auth_with_bootstrap_token(pubkey, token).await;
.auth_with_bootstrap_token(pubkey, token)
.await
.map_err(|_| Status::internal("Failed to authenticate with bootstrap token"));
} }
let mut db_conn = self.context.db.get().await.map_err(|err| {
error!(?pubkey, "Failed to get DB connection: {err}");
Status::internal("Database connection error")
})?;
let nonce = db_conn
.transaction(|mut conn| {
Box::pin(async move {
let current_nonce = schema::useragent_client::table
.filter(schema::useragent_client::public_key.eq(pubkey.as_bytes().to_vec()))
.select(schema::useragent_client::nonce)
.first::<i32>(&mut db_conn)
.await?;
Ok(())
})
})
.await;
// let nonce = match last_used_nonce
todo!() todo!()
} }
#[message(ctx)] #[message(ctx)]
async fn handle_auth_challenge_solution( pub async fn handle_auth_challenge_solution(
&mut self, &mut self,
_solution: auth::AuthChallengeSolution, solution: auth::AuthChallengeSolution,
ctx: &mut Context<Self, Output>, ctx: &mut Context<Self, Output>,
) -> Output { ) -> Output {
todo!() todo!()
} }
} }
pub(crate) async fn handle_user_agent(
context: ServerContext,
mut req_stream: tonic::Streaming<UserAgentRequest>,
tx: mpsc::Sender<Result<UserAgentResponse, Status>>,
) {
let actor = UserAgentActor::spawn(UserAgentActor::new(context, tx.clone()));
while let Some(Ok(req)) = req_stream.next().await
&& actor.is_alive()
{
match process_message(&actor, req).await {
Ok(resp) => {
if tx.send(Ok(resp)).await.is_err() {
error!(actor = "useragent", "Failed to send response to client");
break;
}
}
Err(status) => {
let _ = tx.send(Err(status)).await;
break;
}
}
}
actor.kill();
}
async fn process_message(
actor: &ActorRef<UserAgentActor>,
req: UserAgentRequest,
) -> Result<UserAgentResponse, Status> {
let msg = req.payload.ok_or_else(|| {
error!(actor = "useragent", "Received message with no payload");
Status::invalid_argument("Expected message with payload")
})?;
let UserAgentRequestPayload::AuthMessage(ClientMessage {
payload: Some(client_message),
}) = msg
else {
error!(
actor = "useragent",
"Received unexpected message type during authentication"
);
return Err(Status::invalid_argument(
"Expected AuthMessage with ClientMessage payload",
));
};
let result = match client_message {
ClientAuthPayload::AuthChallengeRequest(req) => actor
.ask(HandleAuthChallengeRequest { req })
.await
.map_err(into_status),
ClientAuthPayload::AuthChallengeSolution(solution) => actor
.ask(HandleAuthChallengeSolution { solution })
.await
.map_err(into_status),
};
result
}
fn into_status<M>(e: SendError<M, Status>) -> Status {
match e {
SendError::HandlerError(status) => status,
_ => {
error!(actor = "useragent", "Failed to send message to actor");
Status::internal("session failure")
}
}
}

View File

@@ -3,6 +3,7 @@ use std::sync::Arc;
use diesel::OptionalExtension as _; use diesel::OptionalExtension as _;
use diesel_async::RunQueryDsl as _; use diesel_async::RunQueryDsl as _;
use ed25519_dalek::VerifyingKey; use ed25519_dalek::VerifyingKey;
use kameo::actor::{ActorRef, Spawn};
use miette::Diagnostic; use miette::Diagnostic;
use rand::rngs::StdRng; use rand::rngs::StdRng;
use smlang::statemachine; use smlang::statemachine;
@@ -11,7 +12,7 @@ use tokio::sync::RwLock;
use crate::{ use crate::{
context::{ context::{
bootstrap::generate_token, bootstrap::{BootstrapActor, generate_token},
lease::LeaseHandler, lease::LeaseHandler,
tls::{TlsDataRaw, TlsManager}, tls::{TlsDataRaw, TlsManager},
}, },
@@ -44,6 +45,10 @@ pub enum InitError {
#[diagnostic(code(arbiter_server::init::tls_init))] #[diagnostic(code(arbiter_server::init::tls_init))]
Tls(#[from] tls::TlsInitError), Tls(#[from] tls::TlsInitError),
#[error("Bootstrap token generation failed: {0}")]
#[diagnostic(code(arbiter_server::init::bootstrap_token))]
BootstrapToken(#[from] bootstrap::BootstrapError),
#[error("I/O Error: {0}")] #[error("I/O Error: {0}")]
#[diagnostic(code(arbiter_server::init::io))] #[diagnostic(code(arbiter_server::init::io))]
Io(#[from] std::io::Error), Io(#[from] std::io::Error),
@@ -55,7 +60,7 @@ pub struct KeyStorage;
statemachine! { statemachine! {
name: Server, name: Server,
transitions: { transitions: {
*NotBootstrapped(String) + Bootstrapped = Sealed, *NotBootstrapped + Bootstrapped = Sealed,
Sealed + Unsealed(KeyStorage) / move_key = Ready(KeyStorage), Sealed + Unsealed(KeyStorage) / move_key = Ready(KeyStorage),
Ready(KeyStorage) + Sealed / dispose_key = Sealed, Ready(KeyStorage) + Sealed / dispose_key = Sealed,
} }
@@ -78,8 +83,7 @@ pub(crate) struct _ServerContextInner {
pub state: RwLock<ServerStateMachine<_Context>>, pub state: RwLock<ServerStateMachine<_Context>>,
pub rng: StdRng, pub rng: StdRng,
pub tls: TlsManager, pub tls: TlsManager,
pub user_agent_leases: LeaseHandler<VerifyingKey>, pub bootstrapper: ActorRef<BootstrapActor>,
pub client_leases: LeaseHandler<VerifyingKey>,
} }
#[derive(Clone)] #[derive(Clone)]
pub(crate) struct ServerContext(Arc<_ServerContextInner>); pub(crate) struct ServerContext(Arc<_ServerContextInner>);
@@ -138,9 +142,7 @@ impl ServerContext {
drop(conn); drop(conn);
let bootstrap_token = generate_token().await?; let mut state = ServerStateMachine::new(_Context);
let mut state = ServerStateMachine::new(_Context, bootstrap_token);
if let Some(settings) = &settings if let Some(settings) = &settings
&& settings.root_key_id.is_some() && settings.root_key_id.is_some()
@@ -150,12 +152,11 @@ impl ServerContext {
} }
Ok(Self(Arc::new(_ServerContextInner { Ok(Self(Arc::new(_ServerContextInner {
bootstrapper: BootstrapActor::spawn(BootstrapActor::new(&db).await?),
db, db,
rng, rng,
tls, tls,
state: RwLock::new(state), state: RwLock::new(state),
user_agent_leases: Default::default(),
client_leases: Default::default(),
}))) })))
} }
} }

View File

@@ -1,6 +1,11 @@
use arbiter_proto::{BOOTSTRAP_TOKEN_PATH, home_path}; use arbiter_proto::{BOOTSTRAP_TOKEN_PATH, home_path};
use diesel::{QueryDsl, dsl::exists, select}; use diesel::{
ExpressionMethods, QueryDsl,
dsl::{count, exists},
select,
};
use diesel_async::RunQueryDsl; use diesel_async::RunQueryDsl;
use kameo::{Actor, messages};
use memsafe::MemSafe; use memsafe::MemSafe;
use miette::Diagnostic; use miette::Diagnostic;
use rand::{RngExt, distr::StandardUniform, make_rng, rngs::StdRng}; use rand::{RngExt, distr::StandardUniform, make_rng, rngs::StdRng};
@@ -9,7 +14,10 @@ use thiserror::Error;
use tracing::info; use tracing::info;
use zeroize::{Zeroize, Zeroizing}; use zeroize::{Zeroize, Zeroizing};
use crate::db::{self, schema}; use crate::{
context::{self, ServerContext},
db::{self, DatabasePool, schema},
};
const TOKEN_LENGTH: usize = 64; const TOKEN_LENGTH: usize = 64;
@@ -28,3 +36,70 @@ pub async fn generate_token() -> Result<String, std::io::Error> {
Ok(token) Ok(token)
} }
#[derive(Error, Debug, Diagnostic)]
pub enum BootstrapError {
#[error("Database error: {0}")]
#[diagnostic(code(arbiter_server::bootstrap::database))]
Database(#[from] db::PoolError),
#[error("Database query error: {0}")]
#[diagnostic(code(arbiter_server::bootstrap::database_query))]
Query(#[from] diesel::result::Error),
#[error("I/O error: {0}")]
#[diagnostic(code(arbiter_server::bootstrap::io))]
Io(#[from] std::io::Error),
}
#[derive(Actor)]
pub struct BootstrapActor {
token: Option<String>,
}
impl BootstrapActor {
pub async fn new(db: &DatabasePool) -> Result<Self, BootstrapError> {
let mut conn = db.get().await?;
let needs_token: bool = select(exists(
schema::useragent_client::table
.filter(schema::useragent_client::id.eq(schema::useragent_client::id)), // Just check if the table is empty
))
.first(&mut conn)
.await?;
drop(conn);
let token = if needs_token {
let token = generate_token().await?;
info!(%token, "Generated bootstrap token");
tokio::fs::write(home_path()?.join(BOOTSTRAP_TOKEN_PATH), token.as_str()).await?;
Some(token)
} else {
None
};
Ok(Self { token })
}
}
#[messages]
impl BootstrapActor {
#[message]
pub fn is_correct_token(&self, token: String) -> bool {
match &self.token {
Some(expected) => *expected == token,
None => false,
}
}
#[message]
pub fn consume_token(&mut self, token: String) -> bool {
if self.is_correct_token(token) {
self.token = None;
true
} else {
false
}
}
}

View File

@@ -28,21 +28,12 @@ pub struct ArbiterSetting {
pub cert: Vec<u8>, pub cert: Vec<u8>,
} }
#[derive(Queryable, Debug)]
#[diesel(table_name = schema::key_identity, check_for_backend(Sqlite))]
pub struct KeyIdentity {
pub id: i32,
pub name: String,
pub public_key: String,
pub created_at: i32,
pub updated_at: i32,
}
#[derive(Queryable, Debug)] #[derive(Queryable, Debug)]
#[diesel(table_name = schema::program_client, check_for_backend(Sqlite))] #[diesel(table_name = schema::program_client, check_for_backend(Sqlite))]
pub struct ProgramClient { pub struct ProgramClient {
pub id: i32, pub id: i32,
pub key_identity_id: i32, pub public_key: Vec<u8>,
pub nonce: i32,
pub created_at: i32, pub created_at: i32,
pub updated_at: i32, pub updated_at: i32,
} }
@@ -51,7 +42,8 @@ pub struct ProgramClient {
#[diesel(table_name = schema::useragent_client, check_for_backend(Sqlite))] #[diesel(table_name = schema::useragent_client, check_for_backend(Sqlite))]
pub struct UseragentClient { pub struct UseragentClient {
pub id: i32, pub id: i32,
pub key_identity_id: i32, pub public_key: Vec<u8>,
pub nonce: i32,
pub created_at: i32, pub created_at: i32,
pub updated_at: i32, pub updated_at: i32,
} }

View File

@@ -19,20 +19,11 @@ diesel::table! {
} }
} }
diesel::table! {
key_identity (id) {
id -> Integer,
name -> Text,
public_key -> Text,
created_at -> Integer,
updated_at -> Integer,
}
}
diesel::table! { diesel::table! {
program_client (id) { program_client (id) {
id -> Integer, id -> Integer,
key_identity_id -> Integer, nonce -> Integer,
public_key -> Binary,
created_at -> Integer, created_at -> Integer,
updated_at -> Integer, updated_at -> Integer,
} }
@@ -41,20 +32,18 @@ diesel::table! {
diesel::table! { diesel::table! {
useragent_client (id) { useragent_client (id) {
id -> Integer, id -> Integer,
key_identity_id -> Integer, nonce -> Integer,
public_key -> Binary,
created_at -> Integer, created_at -> Integer,
updated_at -> Integer, updated_at -> Integer,
} }
} }
diesel::joinable!(arbiter_settings -> aead_encrypted (root_key_id)); diesel::joinable!(arbiter_settings -> aead_encrypted (root_key_id));
diesel::joinable!(program_client -> key_identity (key_identity_id));
diesel::joinable!(useragent_client -> key_identity (key_identity_id));
diesel::allow_tables_to_appear_in_same_query!( diesel::allow_tables_to_appear_in_same_query!(
aead_encrypted, aead_encrypted,
arbiter_settings, arbiter_settings,
key_identity,
program_client, program_client,
useragent_client, useragent_client,
); );

View File

@@ -2,32 +2,18 @@
use std::sync::Arc; use std::sync::Arc;
use tracing::error;
use arbiter_proto::{ use arbiter_proto::{
proto::{ proto::{ClientRequest, ClientResponse, UserAgentRequest, UserAgentResponse},
ClientRequest, ClientResponse, UserAgentRequest, UserAgentResponse,
auth::{
self, AuthChallengeRequest, ClientMessage, client_message::Payload as ClientAuthPayload,
},
user_agent_request::Payload as UserAgentRequestPayload,
user_agent_request::*,
},
transport::BiStream, transport::BiStream,
}; };
use async_trait::async_trait; use async_trait::async_trait;
use futures::StreamExt;
use kameo::actor::Spawn;
use tokio_stream::wrappers::ReceiverStream; use tokio_stream::wrappers::ReceiverStream;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tonic::{Request, Response, Status}; use tonic::{Request, Response, Status};
use crate::{ use crate::{
actors::{ actors::{client::handle_client, user_agent::handle_user_agent},
client::handle_client,
user_agent::{self, UserAgentActor},
},
context::ServerContext, context::ServerContext,
}; };
@@ -67,58 +53,9 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server {
&self, &self,
request: Request<tonic::Streaming<UserAgentRequest>>, request: Request<tonic::Streaming<UserAgentRequest>>,
) -> Result<Response<Self::UserAgentStream>, Status> { ) -> Result<Response<Self::UserAgentStream>, Status> {
let mut req_stream = request.into_inner(); let req_stream = request.into_inner();
let (tx, rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE); let (tx, rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);
tokio::spawn(handle_user_agent(self.context.clone(), req_stream, tx));
let actor = UserAgentActor::spawn(UserAgentActor::new(self.context.clone(), tx.clone()));
tokio::task::spawn(async move {
while let Some(Ok(req)) = req_stream.next().await
&& actor.is_alive()
{
let Some(msg) = req.payload else {
error!(actor = "useragent", "Received message with no payload");
actor.kill();
tx.send(Err(Status::invalid_argument(
"Expected message with payload",
)))
.await;
return;
};
let UserAgentRequestPayload::AuthMessage(ClientMessage {
payload: Some(client_message),
}) = msg
else {
error!(
actor = "useragent",
"Received unexpected message type during authentication"
);
actor.kill();
tx.send(Err(Status::invalid_argument(
"Expected AuthMessage with ClientMessage payload",
)))
.await;
return;
};
match client_message {
ClientAuthPayload::AuthChallengeRequest(req) => {}
ClientAuthPayload::AuthChallengeSolution(_auth_challenge_solution) => todo!(),
_ => {
error!(actor = "useragent", "Received unexpected message type");
actor.kill();
tx.send(Err(Status::invalid_argument(
"Expected AuthMessage with ClientMessage payload",
)))
.await;
return;
}
}
todo!()
}
});
Ok(Response::new(ReceiverStream::new(rx))) Ok(Response::new(ReceiverStream::new(rx)))
} }
} }