feat(proto): add URL parsing and TLS certificate management
This commit is contained in:
600
server/Cargo.lock
generated
600
server/Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -24,3 +24,12 @@ futures = "0.3.31"
|
|||||||
tokio-stream = { version = "0.1.18", features = ["full"] }
|
tokio-stream = { version = "0.1.18", features = ["full"] }
|
||||||
kameo = "0.19.2"
|
kameo = "0.19.2"
|
||||||
prost-types = { version = "0.14.3", features = ["chrono"] }
|
prost-types = { version = "0.14.3", features = ["chrono"] }
|
||||||
|
x25519-dalek = { version = "2.0.1", features = ["getrandom"] }
|
||||||
|
rstest = "0.26.1"
|
||||||
|
rustls-pki-types = "1.14.0"
|
||||||
|
rcgen = { version = "0.14.7", features = [
|
||||||
|
"aws_lc_rs",
|
||||||
|
"pem",
|
||||||
|
"x509-parser",
|
||||||
|
"zeroize",
|
||||||
|
], default-features = false }
|
||||||
|
|||||||
@@ -13,9 +13,19 @@ hex = "0.4.3"
|
|||||||
tonic-prost = "0.14.3"
|
tonic-prost = "0.14.3"
|
||||||
prost = "0.14.3"
|
prost = "0.14.3"
|
||||||
kameo.workspace = true
|
kameo.workspace = true
|
||||||
|
url = "2.5.8"
|
||||||
|
miette.workspace = true
|
||||||
|
thiserror.workspace = true
|
||||||
|
rustls-pki-types.workspace = true
|
||||||
|
base64 = "0.22.1"
|
||||||
prost-types.workspace = true
|
prost-types.workspace = true
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
tonic-prost-build = "0.14.3"
|
tonic-prost-build = "0.14.3"
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
rstest.workspace = true
|
||||||
|
rand.workspace = true
|
||||||
|
rcgen.workspace = true
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,8 @@
|
|||||||
|
pub mod transport;
|
||||||
|
pub mod url;
|
||||||
|
|
||||||
|
use base64::{Engine, prelude::BASE64_STANDARD};
|
||||||
|
|
||||||
use crate::proto::auth::AuthChallenge;
|
use crate::proto::auth::AuthChallenge;
|
||||||
|
|
||||||
pub mod proto {
|
pub mod proto {
|
||||||
@@ -8,9 +13,7 @@ pub mod proto {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub mod transport;
|
pub static BOOTSTRAP_PATH: &str = "bootstrap_token";
|
||||||
|
|
||||||
pub static BOOTSTRAP_TOKEN_PATH: &str = "bootstrap_token";
|
|
||||||
|
|
||||||
pub fn home_path() -> Result<std::path::PathBuf, std::io::Error> {
|
pub fn home_path() -> Result<std::path::PathBuf, std::io::Error> {
|
||||||
static ARBITER_HOME: &str = ".arbiter";
|
static ARBITER_HOME: &str = ".arbiter";
|
||||||
@@ -26,6 +29,6 @@ pub fn home_path() -> Result<std::path::PathBuf, std::io::Error> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn format_challenge(challenge: &AuthChallenge) -> Vec<u8> {
|
pub fn format_challenge(challenge: &AuthChallenge) -> Vec<u8> {
|
||||||
let concat_form = format!("{}:{}", challenge.nonce, hex::encode(&challenge.pubkey));
|
let concat_form = format!("{}:{}", challenge.nonce, BASE64_STANDARD.encode(&challenge.pubkey));
|
||||||
concat_form.into_bytes().to_vec()
|
concat_form.into_bytes().to_vec()
|
||||||
}
|
}
|
||||||
|
|||||||
128
server/crates/arbiter-proto/src/url.rs
Normal file
128
server/crates/arbiter-proto/src/url.rs
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
use std::fmt::Display;
|
||||||
|
|
||||||
|
use base64::{Engine as _, prelude::BASE64_URL_SAFE};
|
||||||
|
use rustls_pki_types::CertificateDer;
|
||||||
|
|
||||||
|
const ARBITER_URL_SCHEME: &str = "arbiter";
|
||||||
|
const CERT_QUERY_KEY: &str = "cert";
|
||||||
|
const BOOTSTRAP_TOKEN_QUERY_KEY: &str = "bootstrap_token";
|
||||||
|
|
||||||
|
pub struct ArbiterUrl {
|
||||||
|
pub host: String,
|
||||||
|
pub port: u16,
|
||||||
|
pub ca_cert: CertificateDer<'static>,
|
||||||
|
pub bootstrap_token: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for ArbiterUrl {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
let mut base = format!(
|
||||||
|
"{ARBITER_URL_SCHEME}://{}:{}?{CERT_QUERY_KEY}={}",
|
||||||
|
self.host,
|
||||||
|
self.port,
|
||||||
|
BASE64_URL_SAFE.encode(self.ca_cert.to_vec())
|
||||||
|
);
|
||||||
|
if let Some(token) = &self.bootstrap_token {
|
||||||
|
base.push_str(&format!("&{BOOTSTRAP_TOKEN_QUERY_KEY}={}", token));
|
||||||
|
}
|
||||||
|
f.write_str(&base)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error, miette::Diagnostic)]
|
||||||
|
pub enum Error {
|
||||||
|
#[error("Invalid URL scheme, expected '{ARBITER_URL_SCHEME}://'")]
|
||||||
|
#[diagnostic(
|
||||||
|
code(arbiter::url::invalid_scheme),
|
||||||
|
help("The URL must start with '{ARBITER_URL_SCHEME}://'")
|
||||||
|
)]
|
||||||
|
InvalidScheme,
|
||||||
|
#[error("Missing host in URL")]
|
||||||
|
#[diagnostic(
|
||||||
|
code(arbiter::url::missing_host),
|
||||||
|
help("The URL must include a host, e.g., '{ARBITER_URL_SCHEME}://127.0.0.1:<port>'")
|
||||||
|
)]
|
||||||
|
MissingHost,
|
||||||
|
#[error("Missing port in URL")]
|
||||||
|
#[diagnostic(
|
||||||
|
code(arbiter::url::missing_port),
|
||||||
|
help("The URL must include a port, e.g., '{ARBITER_URL_SCHEME}://127.0.0.1:1234'")
|
||||||
|
)]
|
||||||
|
MissingPort,
|
||||||
|
#[error("Missing 'cert' query parameter in URL")]
|
||||||
|
#[diagnostic(
|
||||||
|
code(arbiter::url::missing_cert),
|
||||||
|
help("The URL must include a 'cert' query parameter")
|
||||||
|
)]
|
||||||
|
MissingCert,
|
||||||
|
#[error("Invalid base64 in 'cert' query parameter: {0}")]
|
||||||
|
#[diagnostic(code(arbiter::url::invalid_cert_base64))]
|
||||||
|
InvalidCertBase64(#[from] base64::DecodeError),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> TryFrom<&'a str> for ArbiterUrl {
|
||||||
|
type Error = Error;
|
||||||
|
|
||||||
|
fn try_from(value: &'a str) -> Result<Self, Self::Error> {
|
||||||
|
let url = url::Url::parse(value).map_err(|_| Error::InvalidScheme)?;
|
||||||
|
|
||||||
|
if url.scheme() != ARBITER_URL_SCHEME {
|
||||||
|
return Err(Error::InvalidScheme);
|
||||||
|
}
|
||||||
|
|
||||||
|
let host = url.host_str().ok_or(Error::MissingHost)?.to_string();
|
||||||
|
let port = url.port().ok_or(Error::MissingPort)?;
|
||||||
|
let cert_str = url
|
||||||
|
.query_pairs()
|
||||||
|
.find(|(k, _)| k == CERT_QUERY_KEY)
|
||||||
|
.ok_or(Error::MissingCert)?
|
||||||
|
.1;
|
||||||
|
|
||||||
|
let cert = BASE64_URL_SAFE.decode(cert_str.as_ref())?;
|
||||||
|
let cert = CertificateDer::from_slice(&cert).into_owned();
|
||||||
|
|
||||||
|
let bootstrap_token = url
|
||||||
|
.query_pairs()
|
||||||
|
.find(|(k, _)| k == BOOTSTRAP_TOKEN_QUERY_KEY)
|
||||||
|
.map(|(_, v)| v.to_string());
|
||||||
|
|
||||||
|
Ok(ArbiterUrl {
|
||||||
|
host,
|
||||||
|
port,
|
||||||
|
ca_cert: cert,
|
||||||
|
bootstrap_token,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use rcgen::generate_simple_self_signed;
|
||||||
|
use rstest::rstest;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[rstest]
|
||||||
|
|
||||||
|
fn test_parsing_correctness(
|
||||||
|
#[values("127.0.0.1", "localhost", "192.168.1.1", "some.domain.com")] host: &str,
|
||||||
|
|
||||||
|
#[values(None, Some("token123".to_string()))] bootstrap_token: Option<String>,
|
||||||
|
) {
|
||||||
|
let cert = generate_simple_self_signed(&["Arbiter CA".into()]).unwrap();
|
||||||
|
let cert = cert.cert.der();
|
||||||
|
|
||||||
|
let url = ArbiterUrl {
|
||||||
|
host: host.to_string(),
|
||||||
|
port: 1234,
|
||||||
|
ca_cert: cert.clone().into_owned(),
|
||||||
|
bootstrap_token,
|
||||||
|
};
|
||||||
|
let url_str = url.to_string();
|
||||||
|
let parsed_url = ArbiterUrl::try_from(url_str.as_str()).unwrap();
|
||||||
|
assert_eq!(url.host, parsed_url.host);
|
||||||
|
assert_eq!(url.port, parsed_url.port);
|
||||||
|
assert_eq!(url.ca_cert.to_vec(), parsed_url.ca_cert.to_vec());
|
||||||
|
assert_eq!(url.bootstrap_token, parsed_url.bootstrap_token);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -18,6 +18,7 @@ arbiter-proto.path = "../arbiter-proto"
|
|||||||
tracing.workspace = true
|
tracing.workspace = true
|
||||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
tonic.workspace = true
|
tonic.workspace = true
|
||||||
|
tonic.features = ["tls-aws-lc"]
|
||||||
tokio.workspace = true
|
tokio.workspace = true
|
||||||
rustls.workspace = true
|
rustls.workspace = true
|
||||||
smlang.workspace = true
|
smlang.workspace = true
|
||||||
@@ -30,21 +31,17 @@ futures.workspace = true
|
|||||||
tokio-stream.workspace = true
|
tokio-stream.workspace = true
|
||||||
dashmap = "6.1.0"
|
dashmap = "6.1.0"
|
||||||
rand.workspace = true
|
rand.workspace = true
|
||||||
rcgen = { version = "0.14.7", features = [
|
rcgen.workspace = true
|
||||||
"aws_lc_rs",
|
|
||||||
"pem",
|
|
||||||
"x509-parser",
|
|
||||||
"zeroize",
|
|
||||||
], default-features = false }
|
|
||||||
chrono.workspace = true
|
chrono.workspace = true
|
||||||
memsafe = "0.4.0"
|
memsafe = "0.4.0"
|
||||||
zeroize = { version = "1.8.2", features = ["std", "simd"] }
|
zeroize = { version = "1.8.2", features = ["std", "simd"] }
|
||||||
kameo.workspace = true
|
kameo.workspace = true
|
||||||
x25519-dalek = { version = "2.0.1", features = ["getrandom"] }
|
x25519-dalek.workspace = true
|
||||||
chacha20poly1305 = { version = "0.10.1", features = ["std"] }
|
chacha20poly1305 = { version = "0.10.1", features = ["std"] }
|
||||||
argon2 = { version = "0.5.3", features = ["zeroize"] }
|
argon2 = { version = "0.5.3", features = ["zeroize"] }
|
||||||
restructed = "0.2.2"
|
restructed = "0.2.2"
|
||||||
strum = { version = "0.27.2", features = ["derive"] }
|
strum = { version = "0.27.2", features = ["derive"] }
|
||||||
|
pem = "3.0.6"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
insta = "1.46.3"
|
insta = "1.46.3"
|
||||||
|
|||||||
@@ -24,14 +24,24 @@ create unique index if not exists uniq_nonce_per_root_key on aead_encrypted (
|
|||||||
associated_root_key_id
|
associated_root_key_id
|
||||||
);
|
);
|
||||||
|
|
||||||
|
create table if not exists tls_history (
|
||||||
|
id INTEGER not null PRIMARY KEY,
|
||||||
|
cert text not null,
|
||||||
|
cert_key text not null, -- PEM Encoded private key
|
||||||
|
ca_cert text not null,
|
||||||
|
ca_key text not null, -- PEM Encoded private key
|
||||||
|
created_at integer not null default(unixepoch ('now'))
|
||||||
|
) STRICT;
|
||||||
|
|
||||||
-- This is a singleton
|
-- This is a singleton
|
||||||
create table if not exists arbiter_settings (
|
create table if not exists arbiter_settings (
|
||||||
id INTEGER not null PRIMARY KEY CHECK (id = 1), -- singleton row, id must be 1
|
id INTEGER not null PRIMARY KEY CHECK (id = 1), -- singleton row, id must be 1
|
||||||
root_key_id integer references root_key_history (id) on delete RESTRICT, -- if null, means wasn't bootstrapped yet
|
root_key_id integer references root_key_history (id) on delete RESTRICT, -- if null, means wasn't bootstrapped yet
|
||||||
cert_key blob not null,
|
tls_id integer references tls_history (id) on delete RESTRICT
|
||||||
cert blob not null
|
|
||||||
) STRICT;
|
) STRICT;
|
||||||
|
|
||||||
|
insert into arbiter_settings (id) values (1) on conflict do nothing; -- ensure singleton row exists
|
||||||
|
|
||||||
create table if not exists useragent_client (
|
create table if not exists useragent_client (
|
||||||
id integer not null primary key,
|
id integer not null primary key,
|
||||||
nonce integer not null default(1), -- used for auth challenge
|
nonce integer not null default(1), -- used for auth challenge
|
||||||
|
|||||||
@@ -1,28 +1,31 @@
|
|||||||
use arbiter_proto::{BOOTSTRAP_TOKEN_PATH, home_path};
|
use arbiter_proto::{BOOTSTRAP_PATH, home_path};
|
||||||
use diesel::QueryDsl;
|
use diesel::QueryDsl;
|
||||||
use diesel_async::RunQueryDsl;
|
use diesel_async::RunQueryDsl;
|
||||||
use kameo::{Actor, messages};
|
use kameo::{Actor, messages};
|
||||||
use miette::Diagnostic;
|
use miette::Diagnostic;
|
||||||
use rand::{RngExt, distr::StandardUniform, make_rng, rngs::StdRng};
|
use rand::{
|
||||||
|
RngExt,
|
||||||
|
distr::{Alphanumeric},
|
||||||
|
make_rng,
|
||||||
|
rngs::StdRng,
|
||||||
|
};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tracing::info;
|
|
||||||
|
|
||||||
use crate::db::{self, DatabasePool, schema};
|
use crate::db::{self, DatabasePool, schema};
|
||||||
|
|
||||||
const TOKEN_LENGTH: usize = 64;
|
const TOKEN_LENGTH: usize = 64;
|
||||||
|
|
||||||
pub async fn generate_token() -> Result<String, std::io::Error> {
|
pub async fn generate_token() -> Result<String, std::io::Error> {
|
||||||
let rng: StdRng = make_rng();
|
let rng: StdRng = make_rng();
|
||||||
|
|
||||||
let token: String = rng
|
let token: String = rng.sample_iter(Alphanumeric).take(TOKEN_LENGTH).fold(
|
||||||
.sample_iter::<char, _>(StandardUniform)
|
Default::default(),
|
||||||
.take(TOKEN_LENGTH)
|
|mut accum, char| {
|
||||||
.fold(Default::default(), |mut accum, char| {
|
|
||||||
accum += char.to_string().as_str();
|
accum += char.to_string().as_str();
|
||||||
accum
|
accum
|
||||||
});
|
},
|
||||||
|
);
|
||||||
|
|
||||||
tokio::fs::write(home_path()?.join(BOOTSTRAP_TOKEN_PATH), token.as_str()).await?;
|
tokio::fs::write(home_path()?.join(BOOTSTRAP_PATH), token.as_str()).await?;
|
||||||
|
|
||||||
Ok(token)
|
Ok(token)
|
||||||
}
|
}
|
||||||
@@ -58,10 +61,9 @@ impl Bootstrapper {
|
|||||||
|
|
||||||
drop(conn);
|
drop(conn);
|
||||||
|
|
||||||
|
|
||||||
let token = if row_count == 0 {
|
let token = if row_count == 0 {
|
||||||
let token = generate_token().await?;
|
let token = generate_token().await?;
|
||||||
info!(%token, "Generated bootstrap token");
|
|
||||||
tokio::fs::write(home_path()?.join(BOOTSTRAP_TOKEN_PATH), token.as_str()).await?;
|
|
||||||
Some(token)
|
Some(token)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
|
|||||||
@@ -345,30 +345,15 @@ impl KeyHolder {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use diesel::SelectableHelper;
|
use diesel::SelectableHelper;
|
||||||
use diesel::dsl::insert_into;
|
|
||||||
use diesel_async::RunQueryDsl;
|
use diesel_async::RunQueryDsl;
|
||||||
use memsafe::MemSafe;
|
use memsafe::MemSafe;
|
||||||
|
|
||||||
use crate::db::{self, models::ArbiterSetting};
|
use crate::db::{self};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
async fn seed_settings(pool: &db::DatabasePool) {
|
|
||||||
let mut conn = pool.get().await.unwrap();
|
|
||||||
insert_into(schema::arbiter_settings::table)
|
|
||||||
.values(&ArbiterSetting {
|
|
||||||
id: 1,
|
|
||||||
root_key_id: None,
|
|
||||||
cert_key: vec![],
|
|
||||||
cert: vec![],
|
|
||||||
})
|
|
||||||
.execute(&mut conn)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn bootstrapped_actor(db: &db::DatabasePool) -> KeyHolder {
|
async fn bootstrapped_actor(db: &db::DatabasePool) -> KeyHolder {
|
||||||
seed_settings(db).await;
|
|
||||||
let mut actor = KeyHolder::new(db.clone()).await.unwrap();
|
let mut actor = KeyHolder::new(db.clone()).await.unwrap();
|
||||||
let seal_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap();
|
let seal_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap();
|
||||||
actor.bootstrap(seal_key).await.unwrap();
|
actor.bootstrap(seal_key).await.unwrap();
|
||||||
|
|||||||
@@ -1,14 +1,12 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use diesel::OptionalExtension as _;
|
|
||||||
use diesel_async::RunQueryDsl as _;
|
|
||||||
use miette::Diagnostic;
|
use miette::Diagnostic;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
actors::GlobalActors,
|
actors::GlobalActors,
|
||||||
context::tls::{TlsDataRaw, TlsManager},
|
context::tls::TlsManager,
|
||||||
db::{self, models::ArbiterSetting, schema::arbiter_settings},
|
db::{self},
|
||||||
};
|
};
|
||||||
|
|
||||||
pub mod tls;
|
pub mod tls;
|
||||||
@@ -29,7 +27,7 @@ pub enum InitError {
|
|||||||
|
|
||||||
#[error("TLS initialization failed: {0}")]
|
#[error("TLS initialization failed: {0}")]
|
||||||
#[diagnostic(code(arbiter_server::init::tls_init))]
|
#[diagnostic(code(arbiter_server::init::tls_init))]
|
||||||
Tls(#[from] tls::TlsInitError),
|
Tls(#[from] tls::InitError),
|
||||||
|
|
||||||
#[error("Actor spawn failed: {0}")]
|
#[error("Actor spawn failed: {0}")]
|
||||||
#[diagnostic(code(arbiter_server::init::actor_spawn))]
|
#[diagnostic(code(arbiter_server::init::actor_spawn))]
|
||||||
@@ -57,54 +55,11 @@ impl std::ops::Deref for ServerContext {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl ServerContext {
|
impl ServerContext {
|
||||||
async fn load_tls(
|
|
||||||
db: &mut db::DatabaseConnection,
|
|
||||||
settings: Option<&ArbiterSetting>,
|
|
||||||
) -> Result<TlsManager, InitError> {
|
|
||||||
match &settings {
|
|
||||||
Some(settings) => {
|
|
||||||
let tls_data_raw = TlsDataRaw {
|
|
||||||
cert: settings.cert.clone(),
|
|
||||||
key: settings.cert_key.clone(),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(TlsManager::new(Some(tls_data_raw)).await?)
|
|
||||||
}
|
|
||||||
None => {
|
|
||||||
let tls = TlsManager::new(None).await?;
|
|
||||||
let tls_data_raw = tls.bytes();
|
|
||||||
|
|
||||||
diesel::insert_into(arbiter_settings::table)
|
|
||||||
.values(&ArbiterSetting {
|
|
||||||
id: 1,
|
|
||||||
root_key_id: None,
|
|
||||||
cert_key: tls_data_raw.key,
|
|
||||||
cert: tls_data_raw.cert,
|
|
||||||
})
|
|
||||||
.execute(db)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(tls)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn new(db: db::DatabasePool) -> Result<Self, InitError> {
|
pub async fn new(db: db::DatabasePool) -> Result<Self, InitError> {
|
||||||
let mut conn = db.get().await?;
|
|
||||||
|
|
||||||
let settings = arbiter_settings::table
|
|
||||||
.first::<ArbiterSetting>(&mut conn)
|
|
||||||
.await
|
|
||||||
.optional()?;
|
|
||||||
|
|
||||||
let tls = Self::load_tls(&mut conn, settings.as_ref()).await?;
|
|
||||||
|
|
||||||
drop(conn);
|
|
||||||
|
|
||||||
Ok(Self(Arc::new(_ServerContextInner {
|
Ok(Self(Arc::new(_ServerContextInner {
|
||||||
actors: GlobalActors::spawn(db.clone()).await?,
|
actors: GlobalActors::spawn(db.clone()).await?,
|
||||||
|
tls: TlsManager::new(db.clone()).await?,
|
||||||
db,
|
db,
|
||||||
tls,
|
|
||||||
})))
|
})))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,36 @@
|
|||||||
use std::string::FromUtf8Error;
|
use std::string::FromUtf8Error;
|
||||||
|
|
||||||
|
use diesel::{ExpressionMethods as _, QueryDsl, SelectableHelper as _};
|
||||||
|
use diesel_async::{AsyncConnection, RunQueryDsl};
|
||||||
use miette::Diagnostic;
|
use miette::Diagnostic;
|
||||||
use rcgen::{Certificate, KeyPair};
|
use pem::Pem;
|
||||||
use rustls::pki_types::CertificateDer;
|
use rcgen::{
|
||||||
|
BasicConstraints, Certificate, CertificateParams, CertifiedIssuer, DistinguishedName, DnType,
|
||||||
|
IsCa, Issuer, KeyPair, KeyUsagePurpose,
|
||||||
|
};
|
||||||
|
use rustls::pki_types::{pem::PemObject};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
use tonic::transport::CertificateDer;
|
||||||
|
|
||||||
|
use crate::db::{
|
||||||
|
self,
|
||||||
|
models::{NewTlsHistory, TlsHistory},
|
||||||
|
schema::{
|
||||||
|
arbiter_settings,
|
||||||
|
tls_history::{self},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const ENCODE_CONFIG: pem::EncodeConfig = {
|
||||||
|
let line_ending = match cfg!(target_family = "windows") {
|
||||||
|
true => pem::LineEnding::CRLF,
|
||||||
|
false => pem::LineEnding::LF,
|
||||||
|
};
|
||||||
|
pem::EncodeConfig::new().set_line_ending(line_ending)
|
||||||
|
};
|
||||||
|
|
||||||
#[derive(Error, Debug, Diagnostic)]
|
#[derive(Error, Debug, Diagnostic)]
|
||||||
pub enum TlsInitError {
|
pub enum InitError {
|
||||||
#[error("Key generation error during TLS initialization: {0}")]
|
#[error("Key generation error during TLS initialization: {0}")]
|
||||||
#[diagnostic(code(arbiter_server::tls_init::key_generation))]
|
#[diagnostic(code(arbiter_server::tls_init::key_generation))]
|
||||||
KeyGeneration(#[from] rcgen::Error),
|
KeyGeneration(#[from] rcgen::Error),
|
||||||
@@ -18,68 +42,211 @@ pub enum TlsInitError {
|
|||||||
#[error("Key deserialization error: {0}")]
|
#[error("Key deserialization error: {0}")]
|
||||||
#[diagnostic(code(arbiter_server::tls_init::key_deserialization))]
|
#[diagnostic(code(arbiter_server::tls_init::key_deserialization))]
|
||||||
KeyDeserializationError(rcgen::Error),
|
KeyDeserializationError(rcgen::Error),
|
||||||
|
|
||||||
|
#[error("Database error during TLS initialization: {0}")]
|
||||||
|
#[diagnostic(code(arbiter_server::tls_init::database_error))]
|
||||||
|
DatabaseError(#[from] diesel::result::Error),
|
||||||
|
|
||||||
|
#[error("Pem deserialization error during TLS initialization: {0}")]
|
||||||
|
#[diagnostic(code(arbiter_server::tls_init::pem_deserialization))]
|
||||||
|
PemDeserializationError(#[from] rustls::pki_types::pem::Error),
|
||||||
|
|
||||||
|
#[error("Database pool acquire error during TLS initialization: {0}")]
|
||||||
|
#[diagnostic(code(arbiter_server::tls_init::database_pool_acquire))]
|
||||||
|
DatabasePoolAcquire(#[from] db::PoolError),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct TlsData {
|
pub type PemCert = String;
|
||||||
pub cert: CertificateDer<'static>,
|
|
||||||
pub keypair: KeyPair,
|
pub fn encode_cert_to_pem(cert: &CertificateDer) -> PemCert {
|
||||||
|
pem::encode_config(
|
||||||
|
&Pem::new("CERTIFICATE", cert.to_vec()),
|
||||||
|
ENCODE_CONFIG,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct TlsDataRaw {
|
#[allow(unused)]
|
||||||
pub cert: Vec<u8>,
|
struct SerializedTls {
|
||||||
pub key: Vec<u8>,
|
cert_pem: PemCert,
|
||||||
|
cert_key_pem: String,
|
||||||
}
|
}
|
||||||
impl TlsDataRaw {
|
|
||||||
pub fn serialize(cert: &TlsData) -> Self {
|
struct TlsCa {
|
||||||
Self {
|
issuer: Issuer<'static, KeyPair>,
|
||||||
cert: cert.cert.as_ref().to_vec(),
|
cert: CertificateDer<'static>,
|
||||||
key: cert.keypair.serialize_pem().as_bytes().to_vec(),
|
}
|
||||||
}
|
|
||||||
|
impl TlsCa {
|
||||||
|
fn generate() -> Result<Self, InitError> {
|
||||||
|
let keypair = KeyPair::generate()?;
|
||||||
|
let mut params = CertificateParams::new(["Arbiter Instance CA".into()])?;
|
||||||
|
params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
|
||||||
|
params.key_usages = vec![
|
||||||
|
KeyUsagePurpose::KeyCertSign,
|
||||||
|
KeyUsagePurpose::CrlSign,
|
||||||
|
KeyUsagePurpose::DigitalSignature,
|
||||||
|
];
|
||||||
|
|
||||||
|
let mut dn = DistinguishedName::new();
|
||||||
|
dn.push(DnType::CommonName, "Arbiter Instance CA");
|
||||||
|
params.distinguished_name = dn;
|
||||||
|
let certified_issuer = CertifiedIssuer::self_signed(params, keypair)?;
|
||||||
|
|
||||||
|
let cert_key_pem = certified_issuer.key().serialize_pem();
|
||||||
|
|
||||||
|
let issuer = Issuer::from_ca_cert_pem(
|
||||||
|
&certified_issuer.pem(),
|
||||||
|
KeyPair::from_pem(cert_key_pem.as_ref()).unwrap(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
issuer,
|
||||||
|
cert: certified_issuer.der().clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
fn generate_leaf(&self) -> Result<TlsCert, InitError> {
|
||||||
|
let cert_key = KeyPair::generate()?;
|
||||||
|
let mut params = CertificateParams::new(["Arbiter Instance Leaf".into()])?;
|
||||||
|
params.is_ca = IsCa::NoCa;
|
||||||
|
params.key_usages = vec![
|
||||||
|
KeyUsagePurpose::DigitalSignature,
|
||||||
|
KeyUsagePurpose::KeyEncipherment,
|
||||||
|
];
|
||||||
|
|
||||||
|
let mut dn = DistinguishedName::new();
|
||||||
|
dn.push(DnType::CommonName, "Arbiter Instance Leaf");
|
||||||
|
params.distinguished_name = dn;
|
||||||
|
|
||||||
|
let new_cert = params.signed_by(&cert_key, &self.issuer)?;
|
||||||
|
|
||||||
|
Ok(TlsCert {
|
||||||
|
cert: new_cert,
|
||||||
|
cert_key,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn deserialize(&self) -> Result<TlsData, TlsInitError> {
|
#[allow(unused)]
|
||||||
let cert = CertificateDer::from_slice(&self.cert).into_owned();
|
fn serialize(&self) -> Result<SerializedTls, InitError> {
|
||||||
|
let cert_key_pem = self.issuer.key().serialize_pem();
|
||||||
|
Ok(SerializedTls {
|
||||||
|
cert_pem: encode_cert_to_pem(&self.cert),
|
||||||
|
cert_key_pem,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
let key = String::from_utf8(self.key.clone()).map_err(TlsInitError::KeyInvalidFormat)?;
|
#[allow(unused)]
|
||||||
|
fn try_deserialize(cert_pem: &str, cert_key_pem: &str) -> Result<Self, InitError> {
|
||||||
let keypair = KeyPair::from_pem(&key).map_err(TlsInitError::KeyDeserializationError)?;
|
let keypair =
|
||||||
|
KeyPair::from_pem(cert_key_pem).map_err(InitError::KeyDeserializationError)?;
|
||||||
Ok(TlsData { cert, keypair })
|
let issuer = Issuer::from_ca_cert_pem(cert_pem, keypair)?;
|
||||||
|
Ok(Self {
|
||||||
|
issuer,
|
||||||
|
cert: CertificateDer::from_pem_slice(cert_pem.as_bytes())?,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn generate_cert(key: &KeyPair) -> Result<Certificate, rcgen::Error> {
|
struct TlsCert {
|
||||||
let params =
|
cert: Certificate,
|
||||||
rcgen::CertificateParams::new(vec!["arbiter.local".to_string(), "localhost".to_string()])?;
|
cert_key: KeyPair,
|
||||||
|
|
||||||
params.self_signed(key)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Implement cert rotation
|
// TODO: Implement cert rotation
|
||||||
pub struct TlsManager {
|
pub struct TlsManager {
|
||||||
data: TlsData,
|
cert: CertificateDer<'static>,
|
||||||
|
keypair: KeyPair,
|
||||||
|
ca_cert: CertificateDer<'static>,
|
||||||
|
_db: db::DatabasePool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TlsManager {
|
impl TlsManager {
|
||||||
pub async fn new(data: Option<TlsDataRaw>) -> Result<Self, TlsInitError> {
|
pub async fn generate_new(db: &db::DatabasePool) -> Result<Self, InitError> {
|
||||||
match data {
|
let ca = TlsCa::generate()?;
|
||||||
Some(raw) => {
|
let new_cert = ca.generate_leaf()?;
|
||||||
let tls_data = raw.deserialize()?;
|
|
||||||
Ok(Self { data: tls_data })
|
{
|
||||||
}
|
let mut conn = db.get().await?;
|
||||||
None => {
|
conn.transaction(|conn| {
|
||||||
let keypair = KeyPair::generate()?;
|
Box::pin(async {
|
||||||
let cert = generate_cert(&keypair)?;
|
let new_tls_history = NewTlsHistory {
|
||||||
let tls_data = TlsData {
|
cert: new_cert.cert.pem(),
|
||||||
cert: cert.der().clone(),
|
cert_key: new_cert.cert_key.serialize_pem(),
|
||||||
keypair,
|
ca_cert: encode_cert_to_pem(&ca.cert),
|
||||||
|
ca_key: ca.issuer.key().serialize_pem(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let inserted_tls_history: i32 = diesel::insert_into(tls_history::table)
|
||||||
|
.values(&new_tls_history)
|
||||||
|
.returning(tls_history::id)
|
||||||
|
.get_result(conn)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
diesel::update(arbiter_settings::table)
|
||||||
|
.set(arbiter_settings::tls_id.eq(inserted_tls_history))
|
||||||
|
.execute(conn)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Result::<_, diesel::result::Error>::Ok(())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
cert: new_cert.cert.der().clone(),
|
||||||
|
keypair: new_cert.cert_key,
|
||||||
|
ca_cert: ca.cert,
|
||||||
|
_db: db.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn new(db: db::DatabasePool) -> Result<Self, InitError> {
|
||||||
|
let cert_data: Option<TlsHistory> = {
|
||||||
|
let mut conn = db.get().await?;
|
||||||
|
arbiter_settings::table
|
||||||
|
.left_join(tls_history::table)
|
||||||
|
.select(Option::<TlsHistory>::as_select())
|
||||||
|
.first(&mut conn)
|
||||||
|
.await?
|
||||||
|
};
|
||||||
|
|
||||||
|
match cert_data {
|
||||||
|
Some(data) => {
|
||||||
|
let try_load = || -> Result<_, Box<dyn std::error::Error>> {
|
||||||
|
let keypair = KeyPair::from_pem(&data.cert_key)?;
|
||||||
|
let cert = CertificateDer::from_pem_slice(data.cert.as_bytes())?;
|
||||||
|
let ca_cert = CertificateDer::from_pem_slice(data.ca_cert.as_bytes())?;
|
||||||
|
Ok(Self {
|
||||||
|
cert,
|
||||||
|
keypair,
|
||||||
|
ca_cert,
|
||||||
|
_db: db.clone(),
|
||||||
|
})
|
||||||
};
|
};
|
||||||
Ok(Self { data: tls_data })
|
match try_load() {
|
||||||
|
Ok(manager) => Ok(manager),
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("Failed to load existing TLS certs: {e}. Generating new ones.");
|
||||||
|
Self::generate_new(&db).await
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
None => Self::generate_new(&db).await,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn bytes(&self) -> TlsDataRaw {
|
pub fn cert(&self) -> &CertificateDer<'static> {
|
||||||
TlsDataRaw::serialize(&self.data)
|
&self.cert
|
||||||
|
}
|
||||||
|
pub fn ca_cert(&self) -> &CertificateDer<'static> {
|
||||||
|
&self.ca_cert
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn cert_pem(&self) -> PemCert {
|
||||||
|
encode_cert_to_pem(&self.cert)
|
||||||
|
}
|
||||||
|
pub fn key_pem(&self) -> String {
|
||||||
|
self.keypair.serialize_pem()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,23 +24,23 @@ const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations");
|
|||||||
#[derive(Error, Diagnostic, Debug)]
|
#[derive(Error, Diagnostic, Debug)]
|
||||||
pub enum DatabaseSetupError {
|
pub enum DatabaseSetupError {
|
||||||
#[error("Failed to determine home directory")]
|
#[error("Failed to determine home directory")]
|
||||||
#[diagnostic(code(arbiter::db::home_dir_error))]
|
#[diagnostic(code(arbiter::db::home_dir))]
|
||||||
HomeDir(std::io::Error),
|
HomeDir(std::io::Error),
|
||||||
|
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
#[diagnostic(code(arbiter::db::connection_error))]
|
#[diagnostic(code(arbiter::db::connection))]
|
||||||
Connection(diesel::ConnectionError),
|
Connection(diesel::ConnectionError),
|
||||||
|
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
#[diagnostic(code(arbiter::db::concurrency_error))]
|
#[diagnostic(code(arbiter::db::concurrency))]
|
||||||
ConcurrencySetup(diesel::result::Error),
|
ConcurrencySetup(diesel::result::Error),
|
||||||
|
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
#[diagnostic(code(arbiter::db::migration_error))]
|
#[diagnostic(code(arbiter::db::migration))]
|
||||||
Migration(Box<dyn std::error::Error + Send + Sync>),
|
Migration(Box<dyn std::error::Error + Send + Sync>),
|
||||||
|
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
#[diagnostic(code(arbiter::db::pool_error))]
|
#[diagnostic(code(arbiter::db::pool))]
|
||||||
Pool(#[from] PoolInitError),
|
Pool(#[from] PoolInitError),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -91,12 +91,12 @@ fn initialize_database(url: &str) -> Result<(), DatabaseSetupError> {
|
|||||||
|
|
||||||
#[tracing::instrument(level = "info")]
|
#[tracing::instrument(level = "info")]
|
||||||
pub async fn create_pool(url: Option<&str>) -> Result<DatabasePool, DatabaseSetupError> {
|
pub async fn create_pool(url: Option<&str>) -> Result<DatabasePool, DatabaseSetupError> {
|
||||||
let database_url = url.map(String::from).unwrap_or(format!(
|
let database_url = url.map(String::from).unwrap_or(
|
||||||
"{}?mode=rwc",
|
database_path()?
|
||||||
(database_path()?
|
|
||||||
.to_str()
|
.to_str()
|
||||||
.expect("database path is not valid UTF-8"))
|
.expect("database path is not valid UTF-8")
|
||||||
));
|
.to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
initialize_database(&database_url)?;
|
initialize_database(&database_url)?;
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
#![allow(unused)]
|
#![allow(unused)]
|
||||||
#![allow(clippy::all)]
|
#![allow(clippy::all)]
|
||||||
|
|
||||||
use crate::db::schema::{self, aead_encrypted, arbiter_settings, root_key_history};
|
use crate::db::schema::{self, aead_encrypted, arbiter_settings, root_key_history, tls_history};
|
||||||
use diesel::{prelude::*, sqlite::Sqlite};
|
use diesel::{prelude::*, sqlite::Sqlite};
|
||||||
use restructed::Models;
|
use restructed::Models;
|
||||||
|
|
||||||
@@ -46,13 +46,29 @@ pub struct RootKeyHistory {
|
|||||||
pub salt: Vec<u8>,
|
pub salt: Vec<u8>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Queryable, Debug, Insertable)]
|
#[derive(Models, Queryable, Debug, Insertable, Selectable)]
|
||||||
|
#[diesel(table_name = tls_history, check_for_backend(Sqlite))]
|
||||||
|
#[view(
|
||||||
|
NewTlsHistory,
|
||||||
|
derive(Insertable),
|
||||||
|
omit(id, created_at),
|
||||||
|
attributes_with = "deriveless"
|
||||||
|
)]
|
||||||
|
pub struct TlsHistory {
|
||||||
|
pub id: i32,
|
||||||
|
pub cert: String,
|
||||||
|
pub cert_key: String, // PEM Encoded private key
|
||||||
|
pub ca_cert: String, // PEM Encoded certificate for cert signing
|
||||||
|
pub ca_key: String, // PEM Encoded public key for cert signing
|
||||||
|
pub created_at: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Queryable, Debug, Insertable, Selectable)]
|
||||||
#[diesel(table_name = arbiter_settings, check_for_backend(Sqlite))]
|
#[diesel(table_name = arbiter_settings, check_for_backend(Sqlite))]
|
||||||
pub struct ArbiterSetting {
|
pub struct ArbiterSettings {
|
||||||
pub id: i32,
|
pub id: i32,
|
||||||
pub root_key_id: Option<i32>, // references root_key_history.id
|
pub root_key_id: Option<i32>, // references root_key_history.id
|
||||||
pub cert_key: Vec<u8>,
|
pub tls_id: Option<i32>, // references tls_history.id
|
||||||
pub cert: Vec<u8>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Queryable, Debug)]
|
#[derive(Queryable, Debug)]
|
||||||
|
|||||||
@@ -16,8 +16,7 @@ diesel::table! {
|
|||||||
arbiter_settings (id) {
|
arbiter_settings (id) {
|
||||||
id -> Integer,
|
id -> Integer,
|
||||||
root_key_id -> Nullable<Integer>,
|
root_key_id -> Nullable<Integer>,
|
||||||
cert_key -> Binary,
|
tls_id -> Nullable<Integer>,
|
||||||
cert -> Binary,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -43,6 +42,17 @@ diesel::table! {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
diesel::table! {
|
||||||
|
tls_history (id) {
|
||||||
|
id -> Integer,
|
||||||
|
cert -> Text,
|
||||||
|
cert_key -> Text,
|
||||||
|
ca_cert -> Text,
|
||||||
|
ca_key -> Text,
|
||||||
|
created_at -> Integer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
diesel::table! {
|
diesel::table! {
|
||||||
useragent_client (id) {
|
useragent_client (id) {
|
||||||
id -> Integer,
|
id -> Integer,
|
||||||
@@ -55,11 +65,13 @@ diesel::table! {
|
|||||||
|
|
||||||
diesel::joinable!(aead_encrypted -> root_key_history (associated_root_key_id));
|
diesel::joinable!(aead_encrypted -> root_key_history (associated_root_key_id));
|
||||||
diesel::joinable!(arbiter_settings -> root_key_history (root_key_id));
|
diesel::joinable!(arbiter_settings -> root_key_history (root_key_id));
|
||||||
|
diesel::joinable!(arbiter_settings -> tls_history (tls_id));
|
||||||
|
|
||||||
diesel::allow_tables_to_appear_in_same_query!(
|
diesel::allow_tables_to_appear_in_same_query!(
|
||||||
aead_encrypted,
|
aead_encrypted,
|
||||||
arbiter_settings,
|
arbiter_settings,
|
||||||
program_client,
|
program_client,
|
||||||
root_key_history,
|
root_key_history,
|
||||||
|
tls_history,
|
||||||
useragent_client,
|
useragent_client,
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -1,7 +1,13 @@
|
|||||||
use arbiter_proto::proto::arbiter_service_server::ArbiterServiceServer;
|
use std::net::SocketAddr;
|
||||||
use arbiter_server::{Server, context::ServerContext, db};
|
|
||||||
|
use arbiter_proto::{proto::arbiter_service_server::ArbiterServiceServer, url::ArbiterUrl};
|
||||||
|
use arbiter_server::{Server, actors::bootstrap::GetToken, context::ServerContext, db};
|
||||||
|
use miette::miette;
|
||||||
|
use tonic::transport::{Identity, ServerTlsConfig};
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
|
const PORT: u16 = 50051;
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> miette::Result<()> {
|
async fn main() -> miette::Result<()> {
|
||||||
tracing_subscriber::fmt()
|
tracing_subscriber::fmt()
|
||||||
@@ -13,18 +19,31 @@ async fn main() -> miette::Result<()> {
|
|||||||
|
|
||||||
info!("Starting arbiter server");
|
info!("Starting arbiter server");
|
||||||
|
|
||||||
info!("Initializing database");
|
|
||||||
let db = db::create_pool(None).await?;
|
let db = db::create_pool(None).await?;
|
||||||
info!("Database ready");
|
info!("Database ready");
|
||||||
|
|
||||||
info!("Initializing server context");
|
|
||||||
let context = ServerContext::new(db).await?;
|
let context = ServerContext::new(db).await?;
|
||||||
info!("Server context ready");
|
|
||||||
|
|
||||||
let addr = "[::1]:50051".parse().expect("valid address");
|
let addr: SocketAddr = format!("127.0.0.1:{PORT}").parse().expect("valid address");
|
||||||
info!(%addr, "Starting gRPC server");
|
info!(%addr, "Starting gRPC server");
|
||||||
|
|
||||||
|
let url = ArbiterUrl {
|
||||||
|
host: addr.ip().to_string(),
|
||||||
|
port: addr.port(),
|
||||||
|
ca_cert: context.tls.ca_cert().clone().into_owned(),
|
||||||
|
bootstrap_token: context.actors.bootstrapper.ask(GetToken).await.unwrap(),
|
||||||
|
};
|
||||||
|
|
||||||
|
info!(%url, "Server URL");
|
||||||
|
|
||||||
|
let tls = ServerTlsConfig::new().identity(Identity::from_pem(
|
||||||
|
context.tls.cert_pem(),
|
||||||
|
context.tls.key_pem(),
|
||||||
|
));
|
||||||
|
|
||||||
tonic::transport::Server::builder()
|
tonic::transport::Server::builder()
|
||||||
|
.tls_config(tls)
|
||||||
|
.map_err(|err| miette!("Faild to setup TLS: {err}"))?
|
||||||
.add_service(ArbiterServiceServer::new(Server::new(context)))
|
.add_service(ArbiterServiceServer::new(Server::new(context)))
|
||||||
.serve(addr)
|
.serve(addr)
|
||||||
.await
|
.await
|
||||||
|
|||||||
@@ -1,28 +1,13 @@
|
|||||||
use arbiter_server::{
|
use arbiter_server::{
|
||||||
actors::keyholder::KeyHolder,
|
actors::keyholder::KeyHolder,
|
||||||
db::{self, models::ArbiterSetting, schema},
|
db::{self, schema},
|
||||||
};
|
};
|
||||||
use diesel::{QueryDsl, insert_into};
|
use diesel::QueryDsl;
|
||||||
use diesel_async::RunQueryDsl;
|
use diesel_async::RunQueryDsl;
|
||||||
use memsafe::MemSafe;
|
use memsafe::MemSafe;
|
||||||
|
|
||||||
pub async fn seed_settings(pool: &db::DatabasePool) {
|
|
||||||
let mut conn = pool.get().await.unwrap();
|
|
||||||
insert_into(schema::arbiter_settings::table)
|
|
||||||
.values(&ArbiterSetting {
|
|
||||||
id: 1,
|
|
||||||
root_key_id: None,
|
|
||||||
cert_key: vec![],
|
|
||||||
cert: vec![],
|
|
||||||
})
|
|
||||||
.execute(&mut conn)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
pub async fn bootstrapped_keyholder(db: &db::DatabasePool) -> KeyHolder {
|
pub async fn bootstrapped_keyholder(db: &db::DatabasePool) -> KeyHolder {
|
||||||
seed_settings(db).await;
|
|
||||||
let mut actor = KeyHolder::new(db.clone()).await.unwrap();
|
let mut actor = KeyHolder::new(db.clone()).await.unwrap();
|
||||||
actor
|
actor
|
||||||
.bootstrap(MemSafe::new(b"test-seal-key".to_vec()).unwrap())
|
.bootstrap(MemSafe::new(b"test-seal-key".to_vec()).unwrap())
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ use crate::common;
|
|||||||
#[test_log::test]
|
#[test_log::test]
|
||||||
async fn test_bootstrap() {
|
async fn test_bootstrap() {
|
||||||
let db = db::create_test_pool().await;
|
let db = db::create_test_pool().await;
|
||||||
common::seed_settings(&db).await;
|
|
||||||
let mut actor = KeyHolder::new(db.clone()).await.unwrap();
|
let mut actor = KeyHolder::new(db.clone()).await.unwrap();
|
||||||
|
|
||||||
let seal_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap();
|
let seal_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap();
|
||||||
@@ -53,7 +52,6 @@ async fn test_bootstrap_rejects_double() {
|
|||||||
#[test_log::test]
|
#[test_log::test]
|
||||||
async fn test_create_new_before_bootstrap_fails() {
|
async fn test_create_new_before_bootstrap_fails() {
|
||||||
let db = db::create_test_pool().await;
|
let db = db::create_test_pool().await;
|
||||||
common::seed_settings(&db).await;
|
|
||||||
let mut actor = KeyHolder::new(db).await.unwrap();
|
let mut actor = KeyHolder::new(db).await.unwrap();
|
||||||
|
|
||||||
let err = actor
|
let err = actor
|
||||||
@@ -67,7 +65,6 @@ async fn test_create_new_before_bootstrap_fails() {
|
|||||||
#[test_log::test]
|
#[test_log::test]
|
||||||
async fn test_decrypt_before_bootstrap_fails() {
|
async fn test_decrypt_before_bootstrap_fails() {
|
||||||
let db = db::create_test_pool().await;
|
let db = db::create_test_pool().await;
|
||||||
common::seed_settings(&db).await;
|
|
||||||
let mut actor = KeyHolder::new(db).await.unwrap();
|
let mut actor = KeyHolder::new(db).await.unwrap();
|
||||||
|
|
||||||
let err = actor.decrypt(1).await.unwrap_err();
|
let err = actor.decrypt(1).await.unwrap_err();
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ use kameo::actor::Spawn;
|
|||||||
#[test_log::test]
|
#[test_log::test]
|
||||||
pub async fn test_bootstrap_token_auth() {
|
pub async fn test_bootstrap_token_auth() {
|
||||||
let db =db::create_test_pool().await;
|
let db =db::create_test_pool().await;
|
||||||
crate::common::seed_settings(&db).await;
|
|
||||||
|
|
||||||
let actors = GlobalActors::spawn(db.clone()).await.unwrap();
|
let actors = GlobalActors::spawn(db.clone()).await.unwrap();
|
||||||
let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap();
|
let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap();
|
||||||
@@ -67,7 +66,6 @@ pub async fn test_bootstrap_token_auth() {
|
|||||||
#[test_log::test]
|
#[test_log::test]
|
||||||
pub async fn test_bootstrap_invalid_token_auth() {
|
pub async fn test_bootstrap_invalid_token_auth() {
|
||||||
let db = db::create_test_pool().await;
|
let db = db::create_test_pool().await;
|
||||||
crate::common::seed_settings(&db).await;
|
|
||||||
|
|
||||||
let actors = GlobalActors::spawn(db.clone()).await.unwrap();
|
let actors = GlobalActors::spawn(db.clone()).await.unwrap();
|
||||||
let user_agent =
|
let user_agent =
|
||||||
@@ -110,7 +108,6 @@ pub async fn test_bootstrap_invalid_token_auth() {
|
|||||||
#[test_log::test]
|
#[test_log::test]
|
||||||
pub async fn test_challenge_auth() {
|
pub async fn test_challenge_auth() {
|
||||||
let db = db::create_test_pool().await;
|
let db = db::create_test_pool().await;
|
||||||
crate::common::seed_settings(&db).await;
|
|
||||||
|
|
||||||
let actors = GlobalActors::spawn(db.clone()).await.unwrap();
|
let actors = GlobalActors::spawn(db.clone()).await.unwrap();
|
||||||
let user_agent =
|
let user_agent =
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ async fn setup_authenticated_user_agent(
|
|||||||
seal_key: &[u8],
|
seal_key: &[u8],
|
||||||
) -> (arbiter_server::db::DatabasePool, ActorRef<UserAgentActor>) {
|
) -> (arbiter_server::db::DatabasePool, ActorRef<UserAgentActor>) {
|
||||||
let db = db::create_test_pool().await;
|
let db = db::create_test_pool().await;
|
||||||
crate::common::seed_settings(&db).await;
|
|
||||||
|
|
||||||
let actors = GlobalActors::spawn(db.clone()).await.unwrap();
|
let actors = GlobalActors::spawn(db.clone()).await.unwrap();
|
||||||
actors
|
actors
|
||||||
@@ -167,7 +166,6 @@ pub async fn test_unseal_corrupted_ciphertext() {
|
|||||||
#[test_log::test]
|
#[test_log::test]
|
||||||
pub async fn test_unseal_start_without_auth_fails() {
|
pub async fn test_unseal_start_without_auth_fails() {
|
||||||
let db = db::create_test_pool().await;
|
let db = db::create_test_pool().await;
|
||||||
crate::common::seed_settings(&db).await;
|
|
||||||
|
|
||||||
let actors = GlobalActors::spawn(db.clone()).await.unwrap();
|
let actors = GlobalActors::spawn(db.clone()).await.unwrap();
|
||||||
let user_agent =
|
let user_agent =
|
||||||
|
|||||||
@@ -5,3 +5,11 @@ edition = "2024"
|
|||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
arbiter-proto.path = "../arbiter-proto"
|
||||||
|
kameo.workspace = true
|
||||||
|
tokio = {workspace = true, features = ["net"]}
|
||||||
|
tonic.workspace = true
|
||||||
|
tracing.workspace = true
|
||||||
|
ed25519-dalek.workspace = true
|
||||||
|
smlang.workspace = true
|
||||||
|
x25519-dalek.workspace = true
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
pub fn add(left: u64, right: u64) -> u64 {
|
|
||||||
left + right
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn it_works() {
|
|
||||||
let result = add(2, 2);
|
|
||||||
assert_eq!(result, 4);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user