feat(unseal): add unseal protocol and crypto infrastructure

This commit is contained in:
hdbg
2026-02-11 13:31:49 +01:00
parent 8dd0276185
commit bbbb4feaa0
18 changed files with 1323 additions and 88 deletions

View File

@@ -1,6 +1,6 @@
syntax = "proto3";
package arbiter.auth;
package arbiter.unseal;
message UserAgentKeyRequest {}

667
server/Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -10,7 +10,7 @@ resolver = "3"
[workspace.dependencies]
prost = "0.14.3"
tonic = { version = "0.14.3", features = ["tls-connect-info"] }
tonic = { version = "0.14.3", features = ["deflate", "gzip", "tls-connect-info", "zstd"] }
tracing = "0.1.44"
tokio = { version = "1.49.0", features = ["full"] }
ed25519 = "3.0.0-rc.4"
@@ -22,3 +22,6 @@ rustls = "0.23.36"
smlang = "0.8.0"
miette = { version = "7.6.0", features = ["fancy", "serde"] }
thiserror = "2.0.18"
async-trait = "0.1.89"
futures = "0.3.31"
tokio-stream = { version = "0.1.18", features = ["full"] }

View File

@@ -12,6 +12,9 @@ prost-derive = "0.14.3"
prost-types = { version = "0.14.3", features = ["chrono"] }
tonic-prost = "0.14.3"
rkyv = "0.8.15"
tokio.workspace = true
futures.workspace = true
[build-dependencies]

View File

@@ -4,4 +4,21 @@ pub mod proto {
pub mod auth {
tonic::include_proto!("arbiter.auth");
}
}
}
pub mod transport;
pub static BOOTSTRAP_TOKEN_PATH: &'static str = "bootstrap_token";
/// Returns the arbiter home directory (`~/.arbiter`), creating it if it
/// does not exist yet.
///
/// # Errors
/// Returns an `io::Error` of kind `PermissionDenied` when the home directory
/// cannot be determined, or any error from creating the directory.
pub fn home_path() -> Result<std::path::PathBuf, std::io::Error> {
    // Statics already have a 'static lifetime; spelling it out is redundant.
    static ARBITER_HOME: &str = ".arbiter";
    // `ok_or_else` so the io::Error is only constructed on the failure path.
    let home_dir = std::env::home_dir().ok_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::PermissionDenied,
            "can not get home directory",
        )
    })?;
    let arbiter_home = home_dir.join(ARBITER_HOME);
    std::fs::create_dir_all(&arbiter_home)?;
    Ok(arbiter_home)
}

View File

@@ -0,0 +1,46 @@
use futures::{Stream, StreamExt};
use tokio::sync::mpsc::{self, error::SendError};
use tonic::{Status, Streaming};
// Sans-io-style abstraction over a bi-directional message channel: a
// `Stream` of inbound `T` frames plus a `send` method for outbound `U`
// frames, both fallible with a gRPC `Status`.
pub trait Bi<T, U>: Stream<Item = Result<T, Status>> + Send + Sync + 'static {
    /// Error produced when delivering an outbound item fails.
    type Error;

    /// Sends one outbound item (or an error status) toward the peer.
    fn send(
        &mut self,
        item: Result<U, Status>,
    ) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;
}
// Bi-directional stream abstraction for handling gRPC streaming requests and responses
pub struct BiStream<T, U> {
    /// Inbound gRPC request frames from the client.
    pub request_stream: Streaming<T>,
    /// Outbound half: responses pushed here are forwarded to the client
    /// (the receiving end is handed to tonic as the response stream).
    pub response_sender: mpsc::Sender<Result<U, Status>>,
}
impl<T, U> Stream for BiStream<T, U>
where
    T: Send + 'static,
    U: Send + 'static,
{
    type Item = Result<T, Status>;

    /// Delegates polling to the inner gRPC request stream.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // All fields of `BiStream` are `Unpin`, so the struct itself is
        // `Unpin` and we may unwrap the pin and re-pin the inner stream.
        let this = self.get_mut();
        std::pin::Pin::new(&mut this.request_stream).poll_next(cx)
    }
}
impl<T, U> Bi<T, U> for BiStream<T, U>
where
    T: Send + 'static,
    U: Send + 'static,
{
    // Sending fails only when the receiving half of the channel was dropped.
    type Error = SendError<Result<U, Status>>;

    async fn send(&mut self, item: Result<U, Status>) -> Result<(), Self::Error> {
        self.response_sender.send(item).await
    }
}

View File

@@ -6,7 +6,7 @@ repository = "https://git.markettakers.org/MarketTakers/arbiter"
[dependencies]
diesel = { version = "2.3.6", features = ["sqlite", "uuid", "time", "chrono", "serde_json"] }
diesel-async = { version = "0.7.4", features = ["sqlite", "tokio", "migrations", "pool", "deadpool"] }
diesel-async = { version = "0.7.4", features = ["bb8", "migrations", "sqlite", "tokio"] }
ed25519.workspace = true
ed25519-dalek.workspace = true
arbiter-proto.path = "../arbiter-proto"
@@ -18,3 +18,17 @@ smlang.workspace = true
miette.workspace = true
thiserror.workspace = true
diesel_migrations = { version = "2.3.1", features = ["sqlite"] }
async-trait.workspace = true
statig = { version = "0.4.1", features = ["async"] }
secrecy = "0.10.3"
futures.workspace = true
tokio-stream.workspace = true
dashmap = "6.1.0"
rand.workspace = true
rcgen = { version = "0.14.7", features = ["aws_lc_rs", "pem", "x509-parser", "zeroize"], default-features = false }
rkyv = { version = "0.8.15", features = ["aligned", "little_endian", "pointer_width_64"] }
restructed = "0.2.2"
chrono.workspace = true
bytes = "1.11.1"
memsafe = "0.4.0"
chacha20poly1305 = { version = "0.10.1", features = ["std"] }

View File

@@ -1,29 +1,37 @@
-- This is a singleton
create table if not exists aead_encrypted (
id INTEGER not null PRIMARY KEY,
current_nonce integer not null default(1), -- if re-encrypted, this should be incremented
ciphertext blob not null,
tag blob not null,
schema_version integer not null default(1) -- server would need to reencrypt, because this means that we have changed algorithm
) STRICT;
-- This is a singleton
create table if not exists arbiter_settings (
root_key_enc blob, -- if null, means wasn't bootstrapped yet
id INTEGER not null PRIMARY KEY CHECK (id = 1), -- singleton row, id must be 1
root_key_id integer references aead_encrypted (id) on delete RESTRICT, -- if null, means wasn't bootstrapped yet
cert_key blob not null,
cert blob not null
) STRICT;
create table if not exists key_identity(
id integer primary key,
create table if not exists key_identity (
id integer not null primary key,
name text not null,
public_key text not null,
created_at integer not null default (unixepoch('now')),
updated_at integer not null default (unixepoch('now'))
created_at integer not null default(unixepoch ('now')),
updated_at integer not null default(unixepoch ('now'))
) STRICT;
create table if not exists useragent_client (
id integer primary key,
key_identity_id integer not null references key_identity(id) on delete cascade,
created_at integer not null default (unixepoch('now')),
updated_at integer not null default (unixepoch('now'))
id integer not null primary key,
key_identity_id integer not null references key_identity (id) on delete cascade,
created_at integer not null default(unixepoch ('now')),
updated_at integer not null default(unixepoch ('now'))
) STRICT;
create table if not exists program_client(
id integer primary key,
key_identity_id integer not null references key_identity(id) on delete cascade,
created_at integer not null default (unixepoch('now')),
updated_at integer not null default (unixepoch('now'))
create table if not exists program_client (
id integer not null primary key,
key_identity_id integer not null references key_identity (id) on delete cascade,
created_at integer not null default(unixepoch ('now')),
updated_at integer not null default(unixepoch ('now'))
) STRICT;

View File

@@ -0,0 +1,157 @@
use std::sync::Arc;
use diesel::OptionalExtension as _;
use diesel_async::RunQueryDsl as _;
use ed25519_dalek::VerifyingKey;
use miette::Diagnostic;
use rand::rngs::StdRng;
use smlang::statemachine;
use thiserror::Error;
use tokio::sync::RwLock;
use crate::{
context::{
lease::LeaseHandler,
tls::{TlsDataRaw, TlsManager},
},
db::{
self,
models::ArbiterSetting,
schema::{self, arbiter_settings},
},
};
pub(crate) mod lease;
pub(crate) mod tls;
pub(crate) mod bootstrap {
}
/// Errors that can occur while initializing the server context.
#[derive(Error, Debug, Diagnostic)]
pub enum InitError {
    /// Creating or migrating the database failed.
    #[error("Database setup failed: {0}")]
    #[diagnostic(code(arbiter_server::init::database_setup))]
    DatabaseSetup(#[from] db::DatabaseSetupError),
    /// Checking a connection out of the pool failed.
    #[error("Connection acquire failed: {0}")]
    #[diagnostic(code(arbiter_server::init::database_pool))]
    DatabasePool(#[from] db::PoolError),
    /// A query run during initialization failed.
    #[error("Database query error: {0}")]
    #[diagnostic(code(arbiter_server::init::database_query))]
    DatabaseQuery(#[from] diesel::result::Error),
    /// Loading or generating TLS material failed.
    #[error("TLS initialization failed: {0}")]
    #[diagnostic(code(arbiter_server::init::tls_init))]
    Tls(#[from] tls::TlsInitError),
}
// TODO: Placeholder for secure root key cell implementation
pub struct KeyStorage;

// Server lifecycle: a fresh server must be bootstrapped before it can be
// unsealed; unsealing moves the root key into `Ready`, and re-sealing
// disposes of it again.
statemachine! {
    name: Server,
    transitions: {
        *NotBootstrapped + Bootstrapped = Sealed,
        Sealed + Unsealed(KeyStorage) / move_key = Ready(KeyStorage),
        Ready(KeyStorage) + Sealed / dispose_key = Sealed,
    }
}
/// Placeholder state-machine context; both actions are unimplemented stubs.
pub struct _Context;

impl ServerStateMachineContext for _Context {
    // Action for `Sealed + Unsealed`: hand the decrypted root key to the
    // `Ready` state. TODO: implement.
    fn move_key(&mut self, _event_data: KeyStorage) -> Result<KeyStorage, ()> {
        todo!()
    }

    #[allow(missing_docs)]
    #[allow(clippy::unused_unit)]
    // Action for `Ready + Sealed`: drop the root key when re-sealing.
    // TODO: implement.
    fn dispose_key(&mut self, _state_data: &KeyStorage) -> Result<(), ()> {
        todo!()
    }
}
/// Shared server state; wrapped in an `Arc` by [`ServerContext`].
pub(crate) struct _ServerContextInner {
    pub db: db::DatabasePool,
    /// Server lifecycle state machine (NotBootstrapped -> Sealed -> Ready).
    pub state: RwLock<ServerStateMachine<_Context>>,
    pub rng: StdRng,
    pub tls: TlsManager,
    // NOTE(review): leases are keyed by client public key — presumably to
    // prevent concurrent sessions for the same key; confirm against callers.
    pub user_agent_leases: LeaseHandler<VerifyingKey>,
    pub client_leases: LeaseHandler<VerifyingKey>,
}

/// Cheaply clonable handle to the shared server state.
#[derive(Clone)]
pub(crate) struct ServerContext(Arc<_ServerContextInner>);

impl std::ops::Deref for ServerContext {
    type Target = _ServerContextInner;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
impl ServerContext {
    /// Loads TLS material from `settings` when present; otherwise generates
    /// fresh material and persists it as the singleton `arbiter_settings`
    /// row (id = 1, `root_key_id = None` meaning "not yet bootstrapped").
    async fn load_tls(
        db: &mut db::DatabaseConnection,
        settings: Option<&ArbiterSetting>,
    ) -> Result<TlsManager, InitError> {
        match &settings {
            Some(settings) => {
                let tls_data_raw = TlsDataRaw {
                    cert: settings.cert.clone(),
                    key: settings.cert_key.clone(),
                };
                Ok(TlsManager::new(Some(tls_data_raw)).await?)
            }
            None => {
                let tls = TlsManager::new(None).await?;
                let tls_data_raw = tls.bytes();
                diesel::insert_into(arbiter_settings::table)
                    .values(&ArbiterSetting {
                        id: 1,
                        root_key_id: None,
                        cert_key: tls_data_raw.key,
                        cert: tls_data_raw.cert,
                    })
                    .execute(db)
                    .await?;
                Ok(tls)
            }
        }
    }

    /// Builds the shared server context: reads the settings row, loads (or
    /// creates) TLS material, and seeds the lifecycle state machine.
    pub async fn new(db: db::DatabasePool) -> Result<Self, InitError> {
        let mut conn = db.get().await?;
        let rng = rand::make_rng();
        // `.optional()` turns NotFound into None: a missing row means the
        // server has never been started before.
        let settings = arbiter_settings::table
            .first::<ArbiterSetting>(&mut conn)
            .await
            .optional()?;
        let tls = Self::load_tls(&mut conn, settings.as_ref()).await?;
        // Return the connection to the pool as soon as DB work is done.
        drop(conn);
        let mut state = ServerStateMachine::new(_Context);
        // A stored root key id means the server was bootstrapped previously,
        // so it starts in `Sealed` rather than `NotBootstrapped`.
        if let Some(settings) = &settings
            && settings.root_key_id.is_some()
        {
            // TODO: pass the encrypted root key to the state machine and let it handle decryption and transition to Sealed
            let _ = state.process_event(ServerEvents::Bootstrapped);
        }
        Ok(Self(Arc::new(_ServerContextInner {
            db,
            rng,
            tls,
            state: RwLock::new(state),
            user_agent_leases: Default::default(),
            client_leases: Default::default(),
        })))
    }
}

View File

@@ -0,0 +1,41 @@
use std::sync::Arc;
use dashmap::DashSet;
/// Shared set of currently-leased items; cloning shares the same set.
#[derive(Clone, Default)]
struct LeaseStorage<T: Eq + std::hash::Hash>(Arc<DashSet<T>>);

// A lease that automatically releases the item when dropped
pub struct Lease<T: Clone + std::hash::Hash + Eq> {
    item: T,
    storage: LeaseStorage<T>,
}

impl<T: Clone + std::hash::Hash + Eq> Drop for Lease<T> {
    fn drop(&mut self) {
        // Remove the item from the shared set so it can be leased again.
        self.storage.0.remove(&self.item);
    }
}
/// Hands out exclusive [`Lease`]s over items; at most one live lease may
/// exist per item at any time.
#[derive(Clone, Default)]
pub struct LeaseHandler<T: Clone + std::hash::Hash + Eq> {
    storage: LeaseStorage<T>,
}

impl<T: Clone + std::hash::Hash + Eq> LeaseHandler<T> {
    /// Creates an empty handler with no outstanding leases.
    pub fn new() -> Self {
        Self {
            storage: LeaseStorage(Arc::new(DashSet::new())),
        }
    }

    /// Attempts to lease `item`; returns `Err(())` when it is already leased.
    pub fn acquire(&self, item: T) -> Result<Lease<T>, ()> {
        // `DashSet::insert` returns false when the value was already present,
        // i.e. another lease for this item is still alive.
        if !self.storage.0.insert(item.clone()) {
            return Err(());
        }
        Ok(Lease {
            item,
            storage: self.storage.clone(),
        })
    }
}

View File

@@ -0,0 +1,89 @@
use std::string::FromUtf8Error;
use miette::Diagnostic;
use rcgen::{Certificate, KeyPair};
use rustls::pki_types::CertificateDer;
use thiserror::Error;
/// Errors that can occur while loading or generating TLS material.
#[derive(Error, Debug, Diagnostic)]
pub enum TlsInitError {
    /// Generating a fresh key pair or certificate failed.
    #[error("Key generation error during TLS initialization: {0}")]
    #[diagnostic(code(arbiter_server::tls_init::key_generation))]
    KeyGeneration(#[from] rcgen::Error),
    /// The stored key bytes were not valid UTF-8 (expected PEM text).
    #[error("Key invalid format: {0}")]
    #[diagnostic(code(arbiter_server::tls_init::key_invalid_format))]
    KeyInvalidFormat(#[from] FromUtf8Error),
    // No #[from]: rcgen::Error already converts into KeyGeneration, so this
    // variant is constructed explicitly on the deserialization path.
    #[error("Key deserialization error: {0}")]
    #[diagnostic(code(arbiter_server::tls_init::key_deserialization))]
    KeyDeserializationError(rcgen::Error),
}
/// In-memory TLS material: a DER-encoded certificate plus its key pair.
pub struct TlsData {
    pub cert: CertificateDer<'static>,
    pub keypair: KeyPair,
}

/// Raw TLS material as persisted in the database: DER certificate bytes
/// and the PEM-encoded private key (see `serialize`/`deserialize`).
pub struct TlsDataRaw {
    pub cert: Vec<u8>,
    pub key: Vec<u8>,
}
impl TlsDataRaw {
    /// Converts in-memory TLS material into raw byte buffers for storage:
    /// the certificate as DER, the key as PEM text.
    pub fn serialize(cert: &TlsData) -> Self {
        let cert_bytes = cert.cert.as_ref().to_vec();
        let key_bytes = cert.keypair.serialize_pem().into_bytes();
        Self {
            cert: cert_bytes,
            key: key_bytes,
        }
    }

    /// Reconstructs [`TlsData`] from the stored bytes.
    ///
    /// # Errors
    /// Fails when the key bytes are not UTF-8 or are not parseable PEM.
    pub fn deserialize(&self) -> Result<TlsData, TlsInitError> {
        let pem = String::from_utf8(self.key.clone()).map_err(TlsInitError::KeyInvalidFormat)?;
        let keypair = KeyPair::from_pem(&pem).map_err(TlsInitError::KeyDeserializationError)?;
        Ok(TlsData {
            cert: CertificateDer::from_slice(&self.cert).into_owned(),
            keypair,
        })
    }
}
/// Creates a self-signed certificate for the local server hostnames
/// (`arbiter.local` and `localhost`), signed with `key`.
fn generate_cert(key: &KeyPair) -> Result<Certificate, rcgen::Error> {
    let params = rcgen::CertificateParams::new(vec![
        "arbiter.local".to_string(),
        "localhost".to_string(),
    ])?;
    params.self_signed(key)
}
// TODO: Implement cert rotation
/// Owns the server's TLS material for the lifetime of the process.
pub(crate) struct TlsManager {
    data: TlsData,
}

impl TlsManager {
    /// Builds a manager from previously stored material, or — when `data`
    /// is `None` — generates a fresh key pair and self-signed certificate.
    pub async fn new(data: Option<TlsDataRaw>) -> Result<Self, TlsInitError> {
        let data = if let Some(raw) = data {
            raw.deserialize()?
        } else {
            let keypair = KeyPair::generate()?;
            let cert = generate_cert(&keypair)?;
            TlsData {
                cert: cert.der().clone(),
                keypair,
            }
        };
        Ok(Self { data })
    }

    /// Serializes the managed TLS material into raw bytes for persistence.
    pub fn bytes(&self) -> TlsDataRaw {
        TlsDataRaw::serialize(&self.data)
    }
}

View File

@@ -1,5 +1,11 @@
use std::sync::Arc;
use diesel::{Connection as _, SqliteConnection, connection::SimpleConnection as _};
use diesel_async::sync_connection_wrapper::SyncConnectionWrapper;
use diesel_async::{
AsyncConnection, SimpleAsyncConnection as _,
pooled_connection::{AsyncDieselConnectionManager, ManagerConfig, RecyclingMethod},
sync_connection_wrapper::SyncConnectionWrapper,
};
use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations};
use miette::Diagnostic;
use thiserror::Error;
@@ -7,10 +13,12 @@ use thiserror::Error;
pub mod models;
pub mod schema;
pub type Database = SyncConnectionWrapper<SqliteConnection>;
pub type DatabaseConnection = SyncConnectionWrapper<SqliteConnection>;
pub type DatabasePool = diesel_async::pooled_connection::bb8::Pool<DatabaseConnection>;
pub type PoolInitError = diesel_async::pooled_connection::PoolError;
pub type PoolError = diesel_async::pooled_connection::bb8::RunError;
static ARBITER_HOME: &'static str = ".arbiter";
static DB_FILE: &'static str = "db.sqlite";
static DB_FILE: &'static str = "arbiter.sqlite";
const MIGRATIONS: EmbeddedMigrations = embed_migrations!();
@@ -18,7 +26,7 @@ const MIGRATIONS: EmbeddedMigrations = embed_migrations!();
pub enum DatabaseSetupError {
#[error("Failed to determine home directory")]
#[diagnostic(code(arbiter::db::home_dir_error))]
HomeDir(Option<std::io::Error>),
HomeDir(std::io::Error),
#[error(transparent)]
#[diagnostic(code(arbiter::db::connection_error))]
@@ -31,27 +39,22 @@ pub enum DatabaseSetupError {
#[error(transparent)]
#[diagnostic(code(arbiter::db::migration_error))]
Migration(Box<dyn std::error::Error + Send + Sync>),
#[error(transparent)]
#[diagnostic(code(arbiter::db::pool_error))]
Pool(#[from] PoolInitError),
}
fn database_path() -> Result<std::path::PathBuf, DatabaseSetupError> {
let home_dir = std::env::home_dir().ok_or_else(|| DatabaseSetupError::HomeDir(None))?;
let arbiter_home = home_dir.join(ARBITER_HOME);
let arbiter_home = arbiter_proto::home_path().map_err(DatabaseSetupError::HomeDir)?;
let db_path = arbiter_home.join(DB_FILE);
std::fs::create_dir_all(arbiter_home)
.map_err(|err| DatabaseSetupError::HomeDir(Some(err)))?;
Ok(db_path)
}
fn setup_concurrency(conn: &mut SqliteConnection) -> Result<(), diesel::result::Error> {
// see https://fractaledmind.github.io/2023/09/07/enhancing-rails-sqlite-fine-tuning/
// sleep if the database is busy, this corresponds to up to 2 seconds sleeping time.
conn.batch_execute("PRAGMA busy_timeout = 2000;")?;
// better write-concurrency
conn.batch_execute("PRAGMA journal_mode = WAL;")?;
fn db_config(conn: &mut SqliteConnection) -> Result<(), diesel::result::Error> {
// fsync only in critical moments
conn.batch_execute("PRAGMA synchronous = NORMAL;")?;
// write WAL changes back every 1000 pages, for an in average 1MB WAL file.
@@ -60,24 +63,62 @@ fn setup_concurrency(conn: &mut SqliteConnection) -> Result<(), diesel::result::
// free some space by truncating possibly massive WAL files from the last run
conn.batch_execute("PRAGMA wal_checkpoint(TRUNCATE);")?;
// sqlite foreign keys are disabled by default, enable them for safety
conn.batch_execute("PRAGMA foreign_keys = ON;")?;
// better space reclamation
conn.batch_execute("PRAGMA auto_vacuum = FULL;")?;
// secure delete, overwrite deleted content with zeros to prevent recovery
conn.batch_execute("PRAGMA secure_delete = ON;")?;
Ok(())
}
#[tracing::instrument]
pub fn connect() -> Result<Database, DatabaseSetupError> {
let database_url = format!(
"{}?mode=rwc",
database_path()?
.to_str()
.ok_or_else(|| DatabaseSetupError::HomeDir(None))?
);
let mut conn =
SqliteConnection::establish(&database_url).map_err(DatabaseSetupError::Connection)?;
fn initialize_database(url: &str) -> Result<(), DatabaseSetupError> {
let mut conn = SqliteConnection::establish(url).map_err(DatabaseSetupError::Connection)?;
setup_concurrency(&mut conn).map_err(DatabaseSetupError::ConcurrencySetup)?;
db_config(&mut conn).map_err(DatabaseSetupError::ConcurrencySetup)?;
conn.run_pending_migrations(MIGRATIONS)
.map_err(DatabaseSetupError::Migration)?;
Ok(SyncConnectionWrapper::new(conn))
Ok(())
}
/// Creates the async connection pool.
///
/// Database-wide setup (migrations plus one-time PRAGMAs) runs first on a
/// throwaway synchronous connection; per-connection PRAGMAs are then applied
/// by the pool's custom setup hook every time a new connection is opened.
pub async fn create_pool() -> Result<DatabasePool, DatabaseSetupError> {
    // `mode=rwc`: open read-write, creating the database file if missing.
    let database_url = format!(
        "{}?mode=rwc",
        database_path()?
            .to_str()
            .expect("database path is not valid UTF-8")
    );
    initialize_database(&database_url)?;
    let mut config = ManagerConfig::default();
    config.custom_setup = Box::new(|url| {
        Box::pin(async move {
            let mut conn = DatabaseConnection::establish(url).await?;
            // see https://fractaledmind.github.io/2023/09/07/enhancing-rails-sqlite-fine-tuning/
            // sleep if the database is busy, this corresponds to up to 9 seconds sleeping time.
            conn.batch_execute("PRAGMA busy_timeout = 9000;")
                .await
                .map_err(diesel::ConnectionError::CouldntSetupConfiguration)?;
            // better write-concurrency
            conn.batch_execute("PRAGMA journal_mode = WAL;")
                .await
                .map_err(diesel::ConnectionError::CouldntSetupConfiguration)?;
            Ok(conn)
        })
    });
    let pool = DatabasePool::builder().build(AsyncDieselConnectionManager::new_with_config(
        database_url,
        config,
    )).await?;
    Ok(pool)
}

View File

@@ -0,0 +1,57 @@
#![allow(unused)]
#![allow(clippy::all)]
use crate::db::schema::{self, aead_encrypted, arbiter_settings};
use diesel::{prelude::*, sqlite::Sqlite};
pub mod types {
    use chrono::{DateTime, Utc};

    // NOTE(review): currently unused placeholder — presumably intended as a
    // custom diesel mapping for SQLite unix-epoch timestamps; confirm before
    // wiring it into the models or removing it.
    pub struct SqliteTimestamp(DateTime<Utc>);
}
/// AEAD-encrypted blob row (holds e.g. the encrypted root key).
///
/// NOTE: `#[derive(Queryable)]` maps columns **by position**, so the field
/// order below must match the column order declared in
/// `schema::aead_encrypted`: id, current_nonce, ciphertext, tag,
/// schema_version. The previous ordering (ciphertext/tag before
/// current_nonce) disagreed with the schema and would be rejected by
/// `check_for_backend` / misread rows.
#[derive(Queryable, Debug, Insertable)]
#[diesel(table_name = aead_encrypted, check_for_backend(Sqlite))]
pub struct AeadEncrypted {
    pub id: i32,
    // Incremented whenever the payload is re-encrypted.
    pub current_nonce: i32,
    pub ciphertext: Vec<u8>,
    pub tag: Vec<u8>,
    // Bumped when the encryption algorithm/layout changes.
    pub schema_version: i32,
}
/// Singleton settings row (id is constrained to 1 by the schema).
#[derive(Queryable, Debug, Insertable)]
#[diesel(table_name = arbiter_settings, check_for_backend(Sqlite))]
pub struct ArbiterSetting {
    pub id: i32,
    // `None` means the server has not been bootstrapped yet.
    pub root_key_id: Option<i32>, // references aead_encrypted.id
    pub cert_key: Vec<u8>,
    pub cert: Vec<u8>,
}
/// A named public key known to the server.
#[derive(Queryable, Debug)]
#[diesel(table_name = schema::key_identity, check_for_backend(Sqlite))]
pub struct KeyIdentity {
    pub id: i32,
    pub name: String,
    pub public_key: String,
    // Unix-epoch seconds (schema default: unixepoch('now')).
    pub created_at: i32,
    pub updated_at: i32,
}

/// A program client, linked to its [`KeyIdentity`].
#[derive(Queryable, Debug)]
#[diesel(table_name = schema::program_client, check_for_backend(Sqlite))]
pub struct ProgramClient {
    pub id: i32,
    pub key_identity_id: i32,
    // Unix-epoch seconds.
    pub created_at: i32,
    pub updated_at: i32,
}

/// A user-agent client, linked to its [`KeyIdentity`].
#[derive(Queryable, Debug)]
#[diesel(table_name = schema::useragent_client, check_for_backend(Sqlite))]
pub struct UseragentClient {
    pub id: i32,
    pub key_identity_id: i32,
    // Unix-epoch seconds.
    pub created_at: i32,
    pub updated_at: i32,
}

View File

@@ -1,9 +1,19 @@
// @generated automatically by Diesel CLI.
diesel::table! {
arbiter_settings (rowid) {
rowid -> Integer,
root_key_enc -> Nullable<Binary>,
aead_encrypted (id) {
id -> Integer,
current_nonce -> Integer,
ciphertext -> Binary,
tag -> Binary,
schema_version -> Integer,
}
}
diesel::table! {
arbiter_settings (id) {
id -> Integer,
root_key_id -> Nullable<Integer>,
cert_key -> Binary,
cert -> Binary,
}
@@ -11,7 +21,7 @@ diesel::table! {
diesel::table! {
key_identity (id) {
id -> Nullable<Integer>,
id -> Integer,
name -> Text,
public_key -> Text,
created_at -> Integer,
@@ -21,7 +31,7 @@ diesel::table! {
diesel::table! {
program_client (id) {
id -> Nullable<Integer>,
id -> Integer,
key_identity_id -> Integer,
created_at -> Integer,
updated_at -> Integer,
@@ -30,17 +40,19 @@ diesel::table! {
diesel::table! {
useragent_client (id) {
id -> Nullable<Integer>,
id -> Integer,
key_identity_id -> Integer,
created_at -> Integer,
updated_at -> Integer,
}
}
diesel::joinable!(arbiter_settings -> aead_encrypted (root_key_id));
diesel::joinable!(program_client -> key_identity (key_identity_id));
diesel::joinable!(useragent_client -> key_identity (key_identity_id));
diesel::allow_tables_to_appear_in_same_query!(
aead_encrypted,
arbiter_settings,
key_identity,
program_client,

View File

@@ -0,0 +1,2 @@
pub mod user_agent;
pub mod client;

View File

@@ -0,0 +1,12 @@
use arbiter_proto::{
proto::{ClientRequest, ClientResponse},
transport::Bi,
};
use crate::ServerContext;
/// Handles one bi-directional program-client stream.
///
/// TODO: not implemented yet — currently returns immediately, which drops
/// the response sender and closes the stream.
pub(crate) async fn handle_client(
    _context: ServerContext,
    _bistream: impl Bi<ClientRequest, ClientResponse>,
) {
}

View File

@@ -0,0 +1,69 @@
use arbiter_proto::{
proto::{
UserAgentRequest, UserAgentResponse,
auth::{
self, AuthChallengeRequest, ClientMessage, client_message::Payload as ClientAuthPayload
},
user_agent_request::Payload as UserAgentRequestPayload,
},
transport::Bi,
};
use futures::StreamExt;
use tracing::error;
use crate::ServerContext;
// Authentication handshake for user-agent connections: the client presents
// a public key, receives a challenge, and must return a valid solution.
smlang::statemachine!(
    name: UserAgentAuth,
    // PartialEq is needed so handlers can compare the generated states enum
    // for equality (e.g. against `UserAgentAuthStates::Authenticated`);
    // Debug alone does not permit `==`/`!=`.
    derive_states: [Debug, PartialEq],
    derive_events: [Clone, Debug],
    transitions: {
        *Init + ReceivedRequest(ed25519_dalek::VerifyingKey) / provide_challenge = WaitingForChallengeSolution(auth::AuthChallenge),
        WaitingForChallengeSolution(auth::AuthChallenge) + ReceivedGoodSolution = Authenticated,
        WaitingForChallengeSolution(auth::AuthChallenge) + ReceivedBadSolution = Error,
    }
);
impl UserAgentAuthStateMachineContext for ServerContext {
    #[allow(missing_docs)]
    #[allow(clippy::unused_unit)]
    // Action for `Init + ReceivedRequest`: build a challenge for the
    // presented public key. TODO: implement.
    fn provide_challenge(
        &mut self,
        _event_data: ed25519_dalek::VerifyingKey,
    ) -> Result<auth::AuthChallenge, ()> {
        todo!()
    }
}
/// Drives the user-agent authentication handshake over a bi-directional
/// stream.
///
/// Consumes inbound messages until the auth state machine reaches
/// `Authenticated`; any missing or unexpected payload aborts the handler
/// (dropping the stream closes the connection).
pub(crate) async fn handle_user_agent(
    context: ServerContext,
    mut bistream: impl Bi<UserAgentRequest, UserAgentResponse> + Unpin,
) {
    let auth_sm = UserAgentAuthStateMachine::new(context);
    // `matches!` instead of `!=`: works without requiring `PartialEq` on the
    // macro-generated states enum (only `Debug` is derived for it).
    while let Some(Ok(msg)) = bistream.next().await
        && !matches!(auth_sm.state(), UserAgentAuthStates::Authenticated)
    {
        let Some(msg) = msg.payload else {
            error!(handler = "useragent", "Received message with no payload");
            return;
        };
        // Until authenticated, only auth protocol messages are acceptable.
        let UserAgentRequestPayload::AuthMessage(ClientMessage {
            payload: Some(client_message),
        }) = msg
        else {
            error!(
                handler = "useragent",
                "Received unexpected message type during authentication"
            );
            return;
        };
        match client_message {
            ClientAuthPayload::AuthChallengeRequest(auth_challenge_request) => {
                // TODO: feed the key into the state machine and send back the
                // generated challenge.
                let AuthChallengeRequest { pubkey: _pubkey } = auth_challenge_request;
            }
            ClientAuthPayload::AuthChallengeSolution(_auth_challenge_solution) => todo!(),
        }
    }
}

View File

@@ -1,5 +1,66 @@
#![allow(unused)]
use arbiter_proto::{
proto::{ClientRequest, ClientResponse, UserAgentRequest, UserAgentResponse},
transport::BiStream,
};
use async_trait::async_trait;
use tokio_stream::wrappers::ReceiverStream;
use tokio::sync::mpsc;
use tonic::{Request, Response, Status};
use crate::{
handlers::{client::handle_client, user_agent::handle_user_agent},
context::ServerContext,
};
mod db;
pub mod handlers;
mod context;
const DEFAULT_CHANNEL_SIZE: usize = 1000;
pub struct Server {
pub db: db::Database,
context: ServerContext,
}
#[async_trait]
impl arbiter_proto::proto::arbiter_service_server::ArbiterService for Server {
    // Response halves are mpsc-backed streams returned to tonic.
    type UserAgentStream = ReceiverStream<Result<UserAgentResponse, Status>>;
    type ClientStream = ReceiverStream<Result<ClientResponse, Status>>;

    /// Opens the bi-directional `client` stream: spawns a detached handler
    /// task owning both directions and returns the response half to tonic.
    async fn client(
        &self,
        request: Request<tonic::Streaming<ClientRequest>>,
    ) -> Result<Response<Self::ClientStream>, Status> {
        let req_stream = request.into_inner();
        // Bounded channel: the handler blocks on `send` once
        // DEFAULT_CHANNEL_SIZE responses are waiting (backpressure).
        let (tx, rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);
        tokio::spawn(handle_client(
            self.context.clone(),
            BiStream {
                request_stream: req_stream,
                response_sender: tx,
            },
        ));
        Ok(Response::new(ReceiverStream::new(rx)))
    }

    /// Same pattern as `client`, for user-agent connections.
    async fn user_agent(
        &self,
        request: Request<tonic::Streaming<UserAgentRequest>>,
    ) -> Result<Response<Self::UserAgentStream>, Status> {
        let req_stream = request.into_inner();
        let (tx, rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE);
        tokio::spawn(handle_user_agent(
            self.context.clone(),
            BiStream {
                request_stream: req_stream,
                response_sender: tx,
            },
        ));
        Ok(Response::new(ReceiverStream::new(rx)))
    }
}