feat(auth): simplify auth model and implement bootstrap flow

Remove key_identity indirection table, storing public keys and nonces
directly on client tables. Replace AuthResponse with AuthOk, add a
BootstrapActor to manage token lifecycle, and move user agent stream
handling into the actor module.
This commit is contained in:
hdbg
2026-02-13 17:55:56 +01:00
parent 8fb7a04102
commit ffa60c90b1
8 changed files with 256 additions and 134 deletions

View File

@@ -20,12 +20,7 @@ message AuthChallengeSolution {
bytes signature = 2;
}
message AuthResponse {
string token = 1;
string refresh_token = 2;
google.protobuf.Timestamp expires_at = 3;
google.protobuf.Timestamp refresh_expires_at = 4;
}
message AuthOk {}
message ClientMessage {
oneof payload {
@@ -37,6 +32,6 @@ message ClientMessage {
message ServerMessage {
oneof payload {
AuthChallenge auth_challenge = 1;
AuthResponse auth_response = 2;
AuthOk auth_ok = 2;
}
}

View File

@@ -14,24 +14,18 @@ create table if not exists arbiter_settings (
cert blob not null
) STRICT;
create table if not exists key_identity (
id integer not null primary key,
name text not null,
public_key text not null,
created_at integer not null default(unixepoch ('now')),
updated_at integer not null default(unixepoch ('now'))
) STRICT;
create table if not exists useragent_client (
id integer not null primary key,
key_identity_id integer not null references key_identity (id) on delete cascade,
nonce integer not null default (1), -- used for auth challenge
public_key blob not null,
created_at integer not null default(unixepoch ('now')),
updated_at integer not null default(unixepoch ('now'))
) STRICT;
create table if not exists program_client (
id integer not null primary key,
key_identity_id integer not null references key_identity (id) on delete cascade,
nonce integer not null default (1), -- used for auth challenge
public_key blob not null,
created_at integer not null default(unixepoch ('now')),
updated_at integer not null default(unixepoch ('now'))
) STRICT;

View File

@@ -4,27 +4,40 @@ use arbiter_proto::{
proto::{
UserAgentRequest, UserAgentResponse,
auth::{
self, AuthChallengeRequest, ClientMessage, client_message::Payload as ClientAuthPayload,
self, AuthChallengeRequest, ClientMessage, ServerMessage as AuthServerMessage,
client_message::Payload as ClientAuthPayload,
server_message::Payload as ServerAuthPayload,
},
user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload,
},
transport::Bi,
};
use ed25519_dalek::VerifyingKey;
use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl};
use diesel_async::{AsyncConnection, RunQueryDsl};
use ed25519_dalek::{SigningKey, VerifyingKey};
use futures::StreamExt;
use kameo::{Actor, message::StreamMessage, messages, prelude::Context};
use kameo::{
Actor,
actor::{ActorRef, Spawn},
error::SendError,
message::StreamMessage,
messages,
prelude::Context,
};
use secrecy::{ExposeSecret, SecretBox};
use tokio::sync::mpsc;
use tokio::sync::mpsc::Sender;
use tonic::{Status, transport::Server};
use tracing::error;
use crate::ServerContext;
use crate::{ServerContext, context::bootstrap::ConsumeToken, db::schema};
#[derive(Debug)]
pub struct ChallengeContext {
challenge: auth::AuthChallenge,
key: ed25519_dalek::SigningKey,
key: SigningKey,
bootstrap_token: Option<String>,
}
smlang::statemachine!(
@@ -83,6 +96,41 @@ impl UserAgentActor {
pubkey: ed25519_dalek::VerifyingKey,
token: String,
) -> Result<UserAgentResponse, Status> {
let token_ok: bool = self
.context
.bootstrapper
.ask(ConsumeToken { token })
.await
.map_err(|e| {
error!(?pubkey, "Failed to consume bootstrap token: {e}");
Status::internal("Bootstrap token consumption failed")
})?;
if token_ok {
let mut conn = self.context.db.get().await.map_err(|e| {
error!(?pubkey, "Failed to get DB connection: {e}");
Status::internal("Database connection error")
})?;
diesel::insert_into(schema::useragent_client::table)
.values((
schema::useragent_client::public_key.eq(pubkey.as_bytes().to_vec()),
schema::useragent_client::nonce.eq(1),
))
.execute(&mut conn)
.await
.map_err(|e| {
error!(?pubkey, "Failed to insert new user agent client: {e}");
Status::internal("Database error")
})?;
return Ok(UserAgentResponse {
payload: Some(UserAgentResponsePayload::AuthMessage(AuthServerMessage {
payload: Some(ServerAuthPayload::Auth),
})),
});
}
todo!()
}
}
@@ -92,7 +140,7 @@ type Output = Result<UserAgentResponse, Status>;
#[messages]
impl UserAgentActor {
#[message(ctx)]
async fn handle_auth_challenge_request(
pub async fn handle_auth_challenge_request(
&mut self,
req: AuthChallengeRequest,
ctx: &mut Context<Self, Output>,
@@ -106,21 +154,112 @@ impl UserAgentActor {
})?;
if let Some(token) = req.bootstrap_token {
return self
.auth_with_bootstrap_token(pubkey, token)
.await
.map_err(|_| Status::internal("Failed to authenticate with bootstrap token"));
return self.auth_with_bootstrap_token(pubkey, token).await;
}
let mut db_conn = self.context.db.get().await.map_err(|err| {
error!(?pubkey, "Failed to get DB connection: {err}");
Status::internal("Database connection error")
})?;
let nonce = db_conn
.transaction(|mut conn| {
Box::pin(async move {
let current_nonce = schema::useragent_client::table
.filter(schema::useragent_client::public_key.eq(pubkey.as_bytes().to_vec()))
.select(schema::useragent_client::nonce)
.first::<i32>(&mut db_conn)
.await?;
Ok(())
})
})
.await;
// let nonce = match last_used_nonce
todo!()
}
#[message(ctx)]
async fn handle_auth_challenge_solution(
pub async fn handle_auth_challenge_solution(
&mut self,
_solution: auth::AuthChallengeSolution,
solution: auth::AuthChallengeSolution,
ctx: &mut Context<Self, Output>,
) -> Output {
todo!()
}
}
pub(crate) async fn handle_user_agent(
context: ServerContext,
mut req_stream: tonic::Streaming<UserAgentRequest>,
tx: mpsc::Sender<Result<UserAgentResponse, Status>>,
) {
let actor = UserAgentActor::spawn(UserAgentActor::new(context, tx.clone()));
while let Some(Ok(req)) = req_stream.next().await
&& actor.is_alive()
{
match process_message(&actor, req).await {
Ok(resp) => {
if tx.send(Ok(resp)).await.is_err() {
error!(actor = "useragent", "Failed to send response to client");
break;
}
}
Err(status) => {
let _ = tx.send(Err(status)).await;
break;
}
}
}
actor.kill();
}
async fn process_message(
actor: &ActorRef<UserAgentActor>,
req: UserAgentRequest,
) -> Result<UserAgentResponse, Status> {
let msg = req.payload.ok_or_else(|| {
error!(actor = "useragent", "Received message with no payload");
Status::invalid_argument("Expected message with payload")
})?;
let UserAgentRequestPayload::AuthMessage(ClientMessage {
payload: Some(client_message),
}) = msg
else {
error!(
actor = "useragent",
"Received unexpected message type during authentication"
);
return Err(Status::invalid_argument(
"Expected AuthMessage with ClientMessage payload",
));
};
let result = match client_message {
ClientAuthPayload::AuthChallengeRequest(req) => actor
.ask(HandleAuthChallengeRequest { req })
.await
.map_err(into_status),
ClientAuthPayload::AuthChallengeSolution(solution) => actor
.ask(HandleAuthChallengeSolution { solution })
.await
.map_err(into_status),
};
result
}
fn into_status<M>(e: SendError<M, Status>) -> Status {
match e {
SendError::HandlerError(status) => status,
_ => {
error!(actor = "useragent", "Failed to send message to actor");
Status::internal("session failure")
}
}
}

View File

@@ -3,6 +3,7 @@ use std::sync::Arc;
use diesel::OptionalExtension as _;
use diesel_async::RunQueryDsl as _;
use ed25519_dalek::VerifyingKey;
use kameo::actor::{ActorRef, Spawn};
use miette::Diagnostic;
use rand::rngs::StdRng;
use smlang::statemachine;
@@ -11,7 +12,7 @@ use tokio::sync::RwLock;
use crate::{
context::{
bootstrap::generate_token,
bootstrap::{BootstrapActor, generate_token},
lease::LeaseHandler,
tls::{TlsDataRaw, TlsManager},
},
@@ -44,6 +45,10 @@ pub enum InitError {
#[diagnostic(code(arbiter_server::init::tls_init))]
Tls(#[from] tls::TlsInitError),
#[error("Bootstrap token generation failed: {0}")]
#[diagnostic(code(arbiter_server::init::bootstrap_token))]
BootstrapToken(#[from] bootstrap::BootstrapError),
#[error("I/O Error: {0}")]
#[diagnostic(code(arbiter_server::init::io))]
Io(#[from] std::io::Error),
@@ -55,7 +60,7 @@ pub struct KeyStorage;
statemachine! {
name: Server,
transitions: {
*NotBootstrapped(String) + Bootstrapped = Sealed,
*NotBootstrapped + Bootstrapped = Sealed,
Sealed + Unsealed(KeyStorage) / move_key = Ready(KeyStorage),
Ready(KeyStorage) + Sealed / dispose_key = Sealed,
}
@@ -78,8 +83,7 @@ pub(crate) struct _ServerContextInner {
pub state: RwLock<ServerStateMachine<_Context>>,
pub rng: StdRng,
pub tls: TlsManager,
pub user_agent_leases: LeaseHandler<VerifyingKey>,
pub client_leases: LeaseHandler<VerifyingKey>,
pub bootstrapper: ActorRef<BootstrapActor>,
}
#[derive(Clone)]
pub(crate) struct ServerContext(Arc<_ServerContextInner>);
@@ -138,9 +142,7 @@ impl ServerContext {
drop(conn);
let bootstrap_token = generate_token().await?;
let mut state = ServerStateMachine::new(_Context, bootstrap_token);
let mut state = ServerStateMachine::new(_Context);
if let Some(settings) = &settings
&& settings.root_key_id.is_some()
@@ -150,12 +152,11 @@ impl ServerContext {
}
Ok(Self(Arc::new(_ServerContextInner {
bootstrapper: BootstrapActor::spawn(BootstrapActor::new(&db).await?),
db,
rng,
tls,
state: RwLock::new(state),
user_agent_leases: Default::default(),
client_leases: Default::default(),
})))
}
}

View File

@@ -1,6 +1,11 @@
use arbiter_proto::{BOOTSTRAP_TOKEN_PATH, home_path};
use diesel::{QueryDsl, dsl::exists, select};
use diesel::{
ExpressionMethods, QueryDsl,
dsl::{count, exists},
select,
};
use diesel_async::RunQueryDsl;
use kameo::{Actor, messages};
use memsafe::MemSafe;
use miette::Diagnostic;
use rand::{RngExt, distr::StandardUniform, make_rng, rngs::StdRng};
@@ -9,7 +14,10 @@ use thiserror::Error;
use tracing::info;
use zeroize::{Zeroize, Zeroizing};
use crate::db::{self, schema};
use crate::{
context::{self, ServerContext},
db::{self, DatabasePool, schema},
};
const TOKEN_LENGTH: usize = 64;
@@ -28,3 +36,70 @@ pub async fn generate_token() -> Result<String, std::io::Error> {
Ok(token)
}
#[derive(Error, Debug, Diagnostic)]
pub enum BootstrapError {
#[error("Database error: {0}")]
#[diagnostic(code(arbiter_server::bootstrap::database))]
Database(#[from] db::PoolError),
#[error("Database query error: {0}")]
#[diagnostic(code(arbiter_server::bootstrap::database_query))]
Query(#[from] diesel::result::Error),
#[error("I/O error: {0}")]
#[diagnostic(code(arbiter_server::bootstrap::io))]
Io(#[from] std::io::Error),
}
#[derive(Actor)]
pub struct BootstrapActor {
token: Option<String>,
}
impl BootstrapActor {
pub async fn new(db: &DatabasePool) -> Result<Self, BootstrapError> {
let mut conn = db.get().await?;
let needs_token: bool = select(exists(
schema::useragent_client::table
.filter(schema::useragent_client::id.eq(schema::useragent_client::id)), // Just check if the table is empty
))
.first(&mut conn)
.await?;
drop(conn);
let token = if needs_token {
let token = generate_token().await?;
info!(%token, "Generated bootstrap token");
tokio::fs::write(home_path()?.join(BOOTSTRAP_TOKEN_PATH), token.as_str()).await?;
Some(token)
} else {
None
};
Ok(Self { token })
}
}
#[messages]
impl BootstrapActor {
#[message]
pub fn is_correct_token(&self, token: String) -> bool {
match &self.token {
Some(expected) => *expected == token,
None => false,
}
}
#[message]
pub fn consume_token(&mut self, token: String) -> bool {
if self.is_correct_token(token) {
self.token = None;
true
} else {
false
}
}
}

View File

@@ -28,21 +28,12 @@ pub struct ArbiterSetting {
pub cert: Vec<u8>,
}
#[derive(Queryable, Debug)]
#[diesel(table_name = schema::key_identity, check_for_backend(Sqlite))]
pub struct KeyIdentity {
pub id: i32,
pub name: String,
pub public_key: String,
pub created_at: i32,
pub updated_at: i32,
}
#[derive(Queryable, Debug)]
#[diesel(table_name = schema::program_client, check_for_backend(Sqlite))]
pub struct ProgramClient {
pub id: i32,
pub key_identity_id: i32,
pub public_key: Vec<u8>,
pub nonce: i32,
pub created_at: i32,
pub updated_at: i32,
}
@@ -51,7 +42,8 @@ pub struct ProgramClient {
#[diesel(table_name = schema::useragent_client, check_for_backend(Sqlite))]
pub struct UseragentClient {
pub id: i32,
pub key_identity_id: i32,
pub public_key: Vec<u8>,
pub nonce: i32,
pub created_at: i32,
pub updated_at: i32,
}

View File

@@ -19,20 +19,11 @@ diesel::table! {
}
}
diesel::table! {
key_identity (id) {
id -> Integer,
name -> Text,
public_key -> Text,
created_at -> Integer,
updated_at -> Integer,
}
}
diesel::table! {
program_client (id) {
id -> Integer,
key_identity_id -> Integer,
nonce -> Integer,
public_key -> Binary,
created_at -> Integer,
updated_at -> Integer,
}
@@ -41,20 +32,18 @@ diesel::table! {
diesel::table! {
useragent_client (id) {
id -> Integer,
key_identity_id -> Integer,
nonce -> Integer,
public_key -> Binary,
created_at -> Integer,
updated_at -> Integer,
}
}
diesel::joinable!(arbiter_settings -> aead_encrypted (root_key_id));
diesel::joinable!(program_client -> key_identity (key_identity_id));
diesel::joinable!(useragent_client -> key_identity (key_identity_id));
diesel::allow_tables_to_appear_in_same_query!(
aead_encrypted,
arbiter_settings,
key_identity,
program_client,
useragent_client,
);

View File

@@ -2,32 +2,18 @@
use std::sync::Arc;
use tracing::error;
use arbiter_proto::{
proto::{
ClientRequest, ClientResponse, UserAgentRequest, UserAgentResponse,
auth::{
self, AuthChallengeRequest, ClientMessage, client_message::Payload as ClientAuthPayload,
},
user_agent_request::Payload as UserAgentRequestPayload,
user_agent_request::*,
},
proto::{ClientRequest, ClientResponse, UserAgentRequest, UserAgentResponse},
transport::BiStream,
};
use async_trait::async_trait;
use futures::StreamExt;
use kameo::actor::Spawn;
use tokio_stream::wrappers::ReceiverStream;
use tokio::sync::mpsc;
use tonic::{Request, Response, Status};
use crate::{
actors::{
client::handle_client,
user_agent::{self, UserAgentActor},
},
actors::{client::handle_client, user_agent::handle_user_agent},
context::ServerContext,
};
@@ -67,58 +53,9 @@ impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server {
&self,
request: Request<tonic::Streaming<UserAgentRequest>>,
) -> Result<Response<Self::UserAgentStream>, Status> {
let mut req_stream = request.into_inner();
let req_stream = request.into_inner();
let (tx, rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);
let actor = UserAgentActor::spawn(UserAgentActor::new(self.context.clone(), tx.clone()));
tokio::task::spawn(async move {
while let Some(Ok(req)) = req_stream.next().await
&& actor.is_alive()
{
let Some(msg) = req.payload else {
error!(actor = "useragent", "Received message with no payload");
actor.kill();
tx.send(Err(Status::invalid_argument(
"Expected message with payload",
)))
.await;
return;
};
let UserAgentRequestPayload::AuthMessage(ClientMessage {
payload: Some(client_message),
}) = msg
else {
error!(
actor = "useragent",
"Received unexpected message type during authentication"
);
actor.kill();
tx.send(Err(Status::invalid_argument(
"Expected AuthMessage with ClientMessage payload",
)))
.await;
return;
};
match client_message {
ClientAuthPayload::AuthChallengeRequest(req) => {}
ClientAuthPayload::AuthChallengeSolution(_auth_challenge_solution) => todo!(),
_ => {
error!(actor = "useragent", "Received unexpected message type");
actor.kill();
tx.send(Err(Status::invalid_argument(
"Expected AuthMessage with ClientMessage payload",
)))
.await;
return;
}
}
todo!()
}
});
tokio::spawn(handle_user_agent(self.context.clone(), req_stream, tx));
Ok(Response::new(ReceiverStream::new(rx)))
}
}