5 Commits

Author SHA1 Message Date
hdbg
90f2476f3d ci(server): introduce tests pipeline
Some checks failed
ci/woodpecker/push/server-test Pipeline failed
2026-02-14 18:39:57 +01:00
hdbg
81a55d28f0 test(db): add create_test_pool and use in tests 2026-02-14 18:33:33 +01:00
hdbg
69dd8f57ca tests(server): UserAgent bootstrap token auth flow test 2026-02-14 18:16:19 +01:00
hdbg
345a967c13 refactor(server): separated UserAgentActor gRPC transport related things into separate module 2026-02-14 17:58:25 +01:00
hdbg
069a997691 feat(server): UserAgent auth flow implemented 2026-02-14 17:53:58 +01:00
17 changed files with 576 additions and 189 deletions

View File

@@ -0,0 +1,26 @@
when:
- event: pull_request
path:
include: ['.woodpecker/server-*.yaml', 'server/**']
- event: push
branch: main
path:
include: ['.woodpecker/server-*.yaml', 'server/**']
steps:
- name: test
image: mise:latest
directory: server
environment:
CARGO_TERM_COLOR: always
CARGO_TARGET_DIR: /usr/local/cargo/target
CARGO_HOME: /usr/local/cargo/registry
volumes:
- cargo-target:/usr/local/cargo/target
- cargo-registry:/usr/local/cargo/registry
commands:
- apt-get update && apt-get install -y pkg-config
# Install only the necessary Rust toolchain and test runner to speed up the CI
- mise install rust
- mise install cargo:cargo-nextest
- mise exec cargo:cargo-nextest -- cargo nextest run --no-fail-fast

View File

@@ -10,6 +10,10 @@ backend = "cargo:cargo-features"
version = "0.11.1"
backend = "cargo:cargo-features-manager"
[[tools."cargo:cargo-nextest"]]
version = "0.9.126"
backend = "cargo:cargo-nextest"
[[tools."cargo:cargo-vet"]]
version = "0.10.2"
backend = "cargo:cargo-vet"

View File

@@ -7,3 +7,4 @@ flutter = "3.38.9-stable"
protoc = "29.6"
rust = "1.93.0"
"cargo:cargo-features-manager" = "0.11.1"
"cargo:cargo-nextest" = "0.9.126"

View File

@@ -11,13 +11,11 @@ message AuthChallengeRequest {
message AuthChallenge {
bytes pubkey = 1;
bytes nonce = 2;
google.protobuf.Timestamp minted = 3;
int32 nonce = 2;
}
message AuthChallengeSolution {
AuthChallenge challenge = 1;
bytes signature = 2;
bytes signature = 1;
}
message AuthOk {}

74
server/Cargo.lock generated
View File

@@ -61,6 +61,7 @@ version = "0.1.0"
dependencies = [
"bytes",
"futures",
"hex",
"kameo",
"prost",
"prost-build",
@@ -93,6 +94,7 @@ dependencies = [
"kameo",
"memsafe",
"miette",
"prost-types",
"rand",
"rcgen",
"restructed",
@@ -101,6 +103,8 @@ dependencies = [
"secrecy",
"smlang",
"statig",
"tempfile",
"test-log",
"thiserror",
"tokio",
"tokio-stream",
@@ -741,6 +745,7 @@ checksum = "053618a4c3d3bc24f188aa660ae75a46eeab74ef07fb415c61431e5e7cd4749b"
dependencies = [
"curve25519-dalek",
"ed25519",
"rand_core 0.10.0",
"sha2",
"subtle",
"zeroize",
@@ -1017,6 +1022,12 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hex"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
[[package]]
name = "http"
version = "1.4.0"
@@ -1301,6 +1312,15 @@ version = "0.4.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897"
[[package]]
name = "matchers"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9"
dependencies = [
"regex-automata",
]
[[package]]
name = "matchit"
version = "0.8.4"
@@ -2144,6 +2164,15 @@ dependencies = [
"digest",
]
[[package]]
name = "sharded-slab"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6"
dependencies = [
"lazy_static",
]
[[package]]
name = "shlex"
version = "1.3.0"
@@ -2365,6 +2394,27 @@ dependencies = [
"windows-sys 0.60.2",
]
[[package]]
name = "test-log"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37d53ac171c92a39e4769491c4b4dde7022c60042254b5fc044ae409d34a24d4"
dependencies = [
"test-log-macros",
"tracing-subscriber",
]
[[package]]
name = "test-log-macros"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be35209fd0781c5401458ab66e4f98accf63553e8fae7425503e92fdd319783b"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.114",
]
[[package]]
name = "textwrap"
version = "0.16.2"
@@ -2395,6 +2445,15 @@ dependencies = [
"syn 2.0.114",
]
[[package]]
name = "thread_local"
version = "1.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185"
dependencies = [
"cfg-if",
]
[[package]]
name = "time"
version = "0.3.47"
@@ -2669,6 +2728,21 @@ dependencies = [
"once_cell",
]
[[package]]
name = "tracing-subscriber"
version = "0.3.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e"
dependencies = [
"matchers",
"once_cell",
"regex-automata",
"sharded-slab",
"thread_local",
"tracing",
"tracing-core",
]
[[package]]
name = "try-lock"
version = "0.2.5"

View File

@@ -14,7 +14,7 @@ tonic = { version = "0.14.3", features = ["deflate", "gzip", "tls-connect-info",
tracing = "0.1.44"
tokio = { version = "1.49.0", features = ["full"] }
ed25519 = "3.0.0-rc.4"
ed25519-dalek = "3.0.0-pre.6"
ed25519-dalek = { version = "3.0.0-pre.6", features = ["rand_core"] }
chrono = { version = "0.4.43", features = ["serde"] }
rand = "0.10.0"
uuid = "1.20.0"
@@ -25,4 +25,5 @@ thiserror = "2.0.18"
async-trait = "0.1.89"
futures = "0.3.31"
tokio-stream = { version = "0.1.18", features = ["full"] }
kameo = "0.19.2"
kameo = "0.19.2"
prost-types = { version = "0.14.3", features = ["chrono"] }

View File

@@ -9,16 +9,17 @@ tonic.workspace = true
prost.workspace = true
bytes = "1.11.1"
prost-derive = "0.14.3"
prost-types = { version = "0.14.3", features = ["chrono"] }
prost-types.workspace = true
tonic-prost = "0.14.3"
rkyv = "0.8.15"
tokio.workspace = true
futures.workspace = true
kameo.workspace = true
hex = "0.4.3"
[build-dependencies]
prost-build = "0.14.3"
serde_json = "1"
tonic-prost-build = "0.14.3"

View File

@@ -1,3 +1,5 @@
use crate::proto::auth::AuthChallenge;
pub mod proto {
tonic::include_proto!("arbiter");
@@ -22,3 +24,8 @@ pub fn home_path() -> Result<std::path::PathBuf, std::io::Error> {
Ok(arbiter_home)
}
pub fn format_challenge(challenge: &AuthChallenge) -> Vec<u8> {
let concat_form = format!("{}:{}", challenge.nonce, hex::encode(&challenge.pubkey));
concat_form.into_bytes().to_vec()
}

View File

@@ -5,8 +5,19 @@ edition = "2024"
repository = "https://git.markettakers.org/MarketTakers/arbiter"
[dependencies]
diesel = { version = "2.3.6", features = ["sqlite", "uuid", "time", "chrono", "serde_json"] }
diesel-async = { version = "0.7.4", features = ["bb8", "migrations", "sqlite", "tokio"] }
diesel = { version = "2.3.6", features = [
"sqlite",
"uuid",
"time",
"chrono",
"serde_json",
] }
diesel-async = { version = "0.7.4", features = [
"bb8",
"migrations",
"sqlite",
"tokio",
] }
ed25519.workspace = true
ed25519-dalek.workspace = true
arbiter-proto.path = "../arbiter-proto"
@@ -25,8 +36,17 @@ futures.workspace = true
tokio-stream.workspace = true
dashmap = "6.1.0"
rand.workspace = true
rcgen = { version = "0.14.7", features = ["aws_lc_rs", "pem", "x509-parser", "zeroize"], default-features = false }
rkyv = { version = "0.8.15", features = ["aligned", "little_endian", "pointer_width_64"] }
rcgen = { version = "0.14.7", features = [
"aws_lc_rs",
"pem",
"x509-parser",
"zeroize",
], default-features = false }
rkyv = { version = "0.8.15", features = [
"aligned",
"little_endian",
"pointer_width_64",
] }
restructed = "0.2.2"
chrono.workspace = true
bytes = "1.11.1"
@@ -34,4 +54,8 @@ memsafe = "0.4.0"
chacha20poly1305 = { version = "0.10.1", features = ["std"] }
zeroize = { version = "1.8.2", features = ["std", "simd"] }
kameo.workspace = true
prost-types.workspace = true
[dev-dependencies]
test-log = { version = "0.2", default-features = false, features = ["trace"] }
tempfile = "3.25.0"

View File

@@ -1,104 +1,139 @@
use std::sync::Arc;
use arbiter_proto::{
proto::{
UserAgentRequest, UserAgentResponse,
auth::{
self, AuthChallengeRequest, ClientMessage, ServerMessage as AuthServerMessage,
client_message::Payload as ClientAuthPayload,
server_message::Payload as ServerAuthPayload,
},
user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload,
use arbiter_proto::proto::{
UserAgentRequest, UserAgentResponse,
auth::{
self, AuthChallenge, AuthChallengeRequest, AuthOk, ClientMessage,
ServerMessage as AuthServerMessage, client_message::Payload as ClientAuthPayload,
server_message::Payload as ServerAuthPayload,
},
transport::Bi,
user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload,
};
use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl};
use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, dsl::update};
use diesel_async::{AsyncConnection, RunQueryDsl};
use ed25519_dalek::{SigningKey, VerifyingKey};
use ed25519_dalek::VerifyingKey;
use futures::StreamExt;
use kameo::{
Actor,
actor::{ActorRef, Spawn},
error::SendError,
message::StreamMessage,
messages,
prelude::Context,
};
use secrecy::{ExposeSecret, SecretBox};
use tokio::sync::mpsc;
use tokio::sync::mpsc::Sender;
use tonic::{Status, transport::Server};
use tracing::error;
use tonic::Status;
use tracing::{error, info};
use crate::{ServerContext, context::bootstrap::ConsumeToken, db::schema};
use crate::{
ServerContext,
context::bootstrap::{BootstrapActor, ConsumeToken},
db::{self, schema},
errors::GrpcStatusExt,
};
#[derive(Debug)]
/// Context for state machine with validated key and sent challenge
/// Challenge is then transformed to bytes using shared function and verified
#[derive(Clone, Debug)]
pub struct ChallengeContext {
challenge: auth::AuthChallenge,
key: SigningKey,
challenge: AuthChallenge,
key: VerifyingKey,
}
// Request context with deserialized public key for state machine.
// This intermediate struct is needed because the state machine branches depending on presence of bootstrap token,
// but we want to have the deserialized key in both branches.
#[derive(Clone, Debug)]
pub struct AuthRequestContext {
pubkey: VerifyingKey,
bootstrap_token: Option<String>,
}
smlang::statemachine!(
name: UserAgent,
derive_states: [Debug],
custom_error: false,
transitions: {
*Init + ReceivedRequest(ed25519_dalek::VerifyingKey) [async check_key_existence] / provide_challenge = WaitingForChallengeSolution(ChallengeContext),
Init + ReceivedBootstrapToken(String) = Authenticated,
*Init + AuthRequest(AuthRequestContext) / auth_request_context = ReceivedAuthRequest(AuthRequestContext),
ReceivedAuthRequest(AuthRequestContext) + ReceivedBootstrapToken = Authenticated,
ReceivedAuthRequest(AuthRequestContext) + SentChallenge(ChallengeContext) / move_challenge = WaitingForChallengeSolution(ChallengeContext),
WaitingForChallengeSolution(ChallengeContext) + ReceivedGoodSolution = Authenticated,
WaitingForChallengeSolution(ChallengeContext) + ReceivedBadSolution = Error,
WaitingForChallengeSolution(ChallengeContext) + ReceivedBadSolution = AuthError, // block further transitions, but connection should close anyway
}
);
impl UserAgentStateMachineContext for ServerContext {
pub struct DummyContext;
impl UserAgentStateMachineContext for DummyContext {
#[allow(missing_docs)]
#[allow(clippy::unused_unit)]
fn provide_challenge(
fn move_challenge(
&mut self,
event_data: ed25519_dalek::VerifyingKey,
state_data: &AuthRequestContext,
event_data: ChallengeContext,
) -> Result<ChallengeContext, ()> {
todo!()
Ok(event_data)
}
#[allow(missing_docs)]
#[allow(clippy::result_unit_err)]
async fn check_key_existence(
&self,
event_data: &ed25519_dalek::VerifyingKey,
) -> Result<bool, ()> {
todo!()
#[allow(clippy::unused_unit)]
fn auth_request_context(
&mut self,
event_data: AuthRequestContext,
) -> Result<AuthRequestContext, ()> {
Ok(event_data)
}
}
#[derive(Actor)]
pub struct UserAgentActor {
context: ServerContext,
state: UserAgentStateMachine<ServerContext>,
rx: Sender<Result<UserAgentResponse, Status>>,
db: db::DatabasePool,
bootstapper: ActorRef<BootstrapActor>,
state: UserAgentStateMachine<DummyContext>,
tx: Sender<Result<UserAgentResponse, Status>>,
}
impl UserAgentActor {
pub(crate) fn new(
context: ServerContext,
rx: Sender<Result<UserAgentResponse, Status>>,
tx: Sender<Result<UserAgentResponse, Status>>,
) -> Self {
Self {
context: context.clone(),
state: UserAgentStateMachine::new(context),
rx,
db: context.db.clone(),
bootstapper: context.bootstrapper.clone(),
state: UserAgentStateMachine::new(DummyContext),
tx,
}
}
pub(crate) fn new_manual(
db: db::DatabasePool,
bootstapper: ActorRef<BootstrapActor>,
tx: Sender<Result<UserAgentResponse, Status>>,
) -> Self {
Self {
db,
bootstapper,
state: UserAgentStateMachine::new(DummyContext),
tx,
}
}
fn transition(&mut self, event: UserAgentEvents) -> Result<(), Status> {
self.state.process_event(event).map_err(|e| {
error!(?e, "State transition failed");
Status::internal("State machine error")
})?;
Ok(())
}
async fn auth_with_bootstrap_token(
&mut self,
pubkey: ed25519_dalek::VerifyingKey,
token: String,
) -> Result<UserAgentResponse, Status> {
let token_ok: bool = self
.context
.bootstrapper
.bootstapper
.ask(ConsumeToken { token })
.await
.map_err(|e| {
@@ -106,11 +141,13 @@ impl UserAgentActor {
Status::internal("Bootstrap token consumption failed")
})?;
if token_ok {
let mut conn = self.context.db.get().await.map_err(|e| {
error!(?pubkey, "Failed to get DB connection: {e}");
Status::internal("Database connection error")
})?;
if !token_ok {
error!(?pubkey, "Invalid bootstrap token provided");
return Err(Status::invalid_argument("Invalid bootstrap token"));
}
{
let mut conn = self.db.get().await.to_status()?;
diesel::insert_into(schema::useragent_client::table)
.values((
@@ -119,24 +156,105 @@ impl UserAgentActor {
))
.execute(&mut conn)
.await
.map_err(|e| {
error!(?pubkey, "Failed to insert new user agent client: {e}");
Status::internal("Database error")
})?;
return Ok(UserAgentResponse {
payload: Some(UserAgentResponsePayload::AuthMessage(AuthServerMessage {
payload: Some(ServerAuthPayload::Auth),
})),
});
.to_status()?;
}
todo!()
self.transition(UserAgentEvents::ReceivedBootstrapToken)?;
Ok(auth_response(ServerAuthPayload::AuthOk(AuthOk {})))
}
async fn auth_with_challenge(&mut self, pubkey: VerifyingKey, pubkey_bytes: Vec<u8>) -> Output {
let nonce: Option<i32> = {
let mut db_conn = self.db.get().await.to_status()?;
db_conn
.transaction(|conn| {
Box::pin(async move {
let current_nonce = schema::useragent_client::table
.filter(
schema::useragent_client::public_key.eq(pubkey.as_bytes().to_vec()),
)
.select(schema::useragent_client::nonce)
.first::<i32>(conn)
.await?;
update(schema::useragent_client::table)
.filter(
schema::useragent_client::public_key.eq(pubkey.as_bytes().to_vec()),
)
.set(schema::useragent_client::nonce.eq(current_nonce + 1))
.execute(conn)
.await?;
Result::<_, diesel::result::Error>::Ok(current_nonce)
})
})
.await
.optional()
.to_status()?
};
let Some(nonce) = nonce else {
error!(?pubkey, "Public key not found in database");
return Err(Status::unauthenticated("Public key not registered"));
};
let challenge = auth::AuthChallenge {
pubkey: pubkey_bytes,
nonce: nonce,
};
self.transition(UserAgentEvents::SentChallenge(ChallengeContext {
challenge: challenge.clone(),
key: pubkey,
}))?;
info!(
?pubkey,
?challenge,
"Sent authentication challenge to client"
);
Ok(auth_response(ServerAuthPayload::AuthChallenge(challenge)))
}
fn verify_challenge_solution(
&self,
solution: &auth::AuthChallengeSolution,
) -> Result<(bool, &ChallengeContext), Status> {
let UserAgentStates::WaitingForChallengeSolution(challenge_context) = self.state.state()
else {
error!("Received challenge solution in invalid state");
return Err(Status::invalid_argument(
"Invalid state for challenge solution",
));
};
let formatted_challenge = arbiter_proto::format_challenge(&challenge_context.challenge);
let signature = solution.signature.as_slice().try_into().map_err(|_| {
error!(?solution, "Invalid signature length");
Status::invalid_argument("Invalid signature length")
})?;
let valid = challenge_context
.key
.verify_strict(&formatted_challenge, &signature)
.is_ok();
Ok((valid, challenge_context))
}
}
type Output = Result<UserAgentResponse, Status>;
fn auth_response(payload: ServerAuthPayload) -> UserAgentResponse {
UserAgentResponse {
payload: Some(UserAgentResponsePayload::AuthMessage(AuthServerMessage {
payload: Some(payload),
})),
}
}
#[messages]
impl UserAgentActor {
#[message(ctx)]
@@ -153,32 +271,15 @@ impl UserAgentActor {
Status::invalid_argument("Failed to convert pubkey to VerifyingKey")
})?;
if let Some(token) = req.bootstrap_token {
return self.auth_with_bootstrap_token(pubkey, token).await;
self.transition(UserAgentEvents::AuthRequest(AuthRequestContext {
pubkey,
bootstrap_token: req.bootstrap_token.clone(),
}))?;
match req.bootstrap_token {
Some(token) => self.auth_with_bootstrap_token(pubkey, token).await,
None => self.auth_with_challenge(pubkey, req.pubkey).await,
}
let mut db_conn = self.context.db.get().await.map_err(|err| {
error!(?pubkey, "Failed to get DB connection: {err}");
Status::internal("Database connection error")
})?;
let nonce = db_conn
.transaction(|mut conn| {
Box::pin(async move {
let current_nonce = schema::useragent_client::table
.filter(schema::useragent_client::public_key.eq(pubkey.as_bytes().to_vec()))
.select(schema::useragent_client::nonce)
.first::<i32>(&mut db_conn)
.await?;
Ok(())
})
})
.await;
// let nonce = match last_used_nonce
todo!()
}
#[message(ctx)]
@@ -187,79 +288,82 @@ impl UserAgentActor {
solution: auth::AuthChallengeSolution,
ctx: &mut Context<Self, Output>,
) -> Output {
todo!()
}
}
let (valid, challenge_context) = self.verify_challenge_solution(&solution)?;
pub(crate) async fn handle_user_agent(
context: ServerContext,
mut req_stream: tonic::Streaming<UserAgentRequest>,
tx: mpsc::Sender<Result<UserAgentResponse, Status>>,
) {
let actor = UserAgentActor::spawn(UserAgentActor::new(context, tx.clone()));
while let Some(Ok(req)) = req_stream.next().await
&& actor.is_alive()
{
match process_message(&actor, req).await {
Ok(resp) => {
if tx.send(Ok(resp)).await.is_err() {
error!(actor = "useragent", "Failed to send response to client");
break;
}
}
Err(status) => {
let _ = tx.send(Err(status)).await;
break;
}
if valid {
info!(
?challenge_context,
"Client provided valid solution to authentication challenge"
);
self.transition(UserAgentEvents::ReceivedGoodSolution)?;
Ok(auth_response(ServerAuthPayload::AuthOk(AuthOk {})))
} else {
error!("Client provided invalid solution to authentication challenge");
self.transition(UserAgentEvents::ReceivedBadSolution)?;
Err(Status::unauthenticated("Invalid challenge solution"))
}
}
actor.kill();
}
async fn process_message(
actor: &ActorRef<UserAgentActor>,
req: UserAgentRequest,
) -> Result<UserAgentResponse, Status> {
let msg = req.payload.ok_or_else(|| {
error!(actor = "useragent", "Received message with no payload");
Status::invalid_argument("Expected message with payload")
})?;
#[cfg(test)]
mod tests {
use arbiter_proto::proto::{
UserAgentResponse, auth::{AuthChallengeRequest, AuthOk},
user_agent_response::Payload as UserAgentResponsePayload,
};
use kameo::actor::Spawn;
let UserAgentRequestPayload::AuthMessage(ClientMessage {
payload: Some(client_message),
}) = msg
else {
error!(
actor = "useragent",
"Received unexpected message type during authentication"
use crate::{
actors::user_agent::HandleAuthChallengeRequest, context::bootstrap::BootstrapActor, db,
};
use super::UserAgentActor;
#[tokio::test]
#[test_log::test]
pub async fn test_bootstrap_token_auth() {
let db = db::create_test_pool().await;
// explicitly not installing any user_agent pubkeys
let bootstrapper = BootstrapActor::new(&db).await.unwrap(); // this will create bootstrap token
let token = bootstrapper.get_token().unwrap();
let bootstrapper_ref = BootstrapActor::spawn(bootstrapper);
let user_agent = UserAgentActor::new_manual(
db.clone(),
bootstrapper_ref,
tokio::sync::mpsc::channel(1).0, // dummy channel, we won't actually send responses in this test
);
return Err(Status::invalid_argument(
"Expected AuthMessage with ClientMessage payload",
));
};
let user_agent_ref = UserAgentActor::spawn(user_agent);
let result = match client_message {
ClientAuthPayload::AuthChallengeRequest(req) => actor
.ask(HandleAuthChallengeRequest { req })
// simulate client sending auth request with bootstrap token
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec();
let result = user_agent_ref
.ask(HandleAuthChallengeRequest {
req: AuthChallengeRequest {
pubkey: pubkey_bytes,
bootstrap_token: Some(token),
},
})
.await
.map_err(into_status),
ClientAuthPayload::AuthChallengeSolution(solution) => actor
.ask(HandleAuthChallengeSolution { solution })
.await
.map_err(into_status),
};
result
}
fn into_status<M>(e: SendError<M, Status>) -> Status {
match e {
SendError::HandlerError(status) => status,
_ => {
error!(actor = "useragent", "Failed to send message to actor");
Status::internal("session failure")
}
.expect("Shouldn't fail to send message");
// auth succeeded
assert_eq!(
result,
UserAgentResponse {
payload: Some(UserAgentResponsePayload::AuthMessage(
arbiter_proto::proto::auth::ServerMessage {
payload: Some(arbiter_proto::proto::auth::server_message::Payload::AuthOk(
AuthOk {},
)),
},
)),
}
);
}
}
mod transport;
pub(crate) use transport::handle_user_agent;

View File

@@ -0,0 +1,95 @@
use super::UserAgentActor;
use arbiter_proto::proto::{
UserAgentRequest, UserAgentResponse,
auth::{
self, AuthChallenge, AuthChallengeRequest, AuthOk, ClientMessage,
ServerMessage as AuthServerMessage, client_message::Payload as ClientAuthPayload,
server_message::Payload as ServerAuthPayload,
},
user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload,
};
use futures::StreamExt;
use kameo::{
actor::{ActorRef, Spawn as _},
error::SendError,
};
use tokio::sync::mpsc;
use tonic::Status;
use tracing::error;
use crate::{
actors::user_agent::{HandleAuthChallengeRequest, HandleAuthChallengeSolution},
context::ServerContext,
};
pub(crate) async fn handle_user_agent(
context: ServerContext,
mut req_stream: tonic::Streaming<UserAgentRequest>,
tx: mpsc::Sender<Result<UserAgentResponse, Status>>,
) {
let actor = UserAgentActor::spawn(UserAgentActor::new(context, tx.clone()));
while let Some(Ok(req)) = req_stream.next().await
&& actor.is_alive()
{
match process_message(&actor, req).await {
Ok(resp) => {
if tx.send(Ok(resp)).await.is_err() {
error!(actor = "useragent", "Failed to send response to client");
break;
}
}
Err(status) => {
let _ = tx.send(Err(status)).await;
break;
}
}
}
actor.kill();
}
async fn process_message(
actor: &ActorRef<UserAgentActor>,
req: UserAgentRequest,
) -> Result<UserAgentResponse, Status> {
let msg = req.payload.ok_or_else(|| {
error!(actor = "useragent", "Received message with no payload");
Status::invalid_argument("Expected message with payload")
})?;
let UserAgentRequestPayload::AuthMessage(ClientMessage {
payload: Some(client_message),
}) = msg
else {
error!(
actor = "useragent",
"Received unexpected message type during authentication"
);
return Err(Status::invalid_argument(
"Expected AuthMessage with ClientMessage payload",
));
};
match client_message {
ClientAuthPayload::AuthChallengeRequest(req) => actor
.ask(HandleAuthChallengeRequest { req })
.await
.map_err(into_status),
ClientAuthPayload::AuthChallengeSolution(solution) => actor
.ask(HandleAuthChallengeSolution { solution })
.await
.map_err(into_status),
}
}
fn into_status<M>(e: SendError<M, Status>) -> Status {
match e {
SendError::HandlerError(status) => status,
_ => {
error!(actor = "useragent", "Failed to send message to actor");
Status::internal("session failure")
}
}
}

View File

@@ -1,9 +1,5 @@
use arbiter_proto::{BOOTSTRAP_TOKEN_PATH, home_path};
use diesel::{
ExpressionMethods, QueryDsl,
dsl::{count, exists},
select,
};
use diesel::{ExpressionMethods, QueryDsl};
use diesel_async::RunQueryDsl;
use kameo::{Actor, messages};
use memsafe::MemSafe;
@@ -61,16 +57,14 @@ impl BootstrapActor {
pub async fn new(db: &DatabasePool) -> Result<Self, BootstrapError> {
let mut conn = db.get().await?;
let needs_token: bool = select(exists(
schema::useragent_client::table
.filter(schema::useragent_client::id.eq(schema::useragent_client::id)), // Just check if the table is empty
))
.first(&mut conn)
.await?;
let row_count: i64 = schema::useragent_client::table
.count()
.get_result(&mut conn)
.await?;
drop(conn);
let token = if needs_token {
let token = if row_count == 0 {
let token = generate_token().await?;
info!(%token, "Generated bootstrap token");
tokio::fs::write(home_path()?.join(BOOTSTRAP_TOKEN_PATH), token.as_str()).await?;
@@ -81,6 +75,11 @@ impl BootstrapActor {
Ok(Self { token })
}
#[cfg(test)]
pub fn get_token(&self) -> Option<String> {
self.token.clone()
}
}
#[messages]

View File

@@ -1,14 +1,18 @@
use std::sync::Arc;
use diesel::{Connection as _, SqliteConnection, connection::SimpleConnection as _};
use diesel::{
Connection as _, SqliteConnection,
connection::{SimpleConnection as _, TransactionManager},
};
use diesel_async::{
AsyncConnection, SimpleAsyncConnection as _,
AsyncConnection, SimpleAsyncConnection,
pooled_connection::{AsyncDieselConnectionManager, ManagerConfig, RecyclingMethod},
sync_connection_wrapper::SyncConnectionWrapper,
};
use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations};
use miette::Diagnostic;
use thiserror::Error;
use tracing::info;
pub mod models;
pub mod schema;
@@ -20,7 +24,7 @@ pub type PoolError = diesel_async::pooled_connection::bb8::RunError;
static DB_FILE: &'static str = "arbiter.sqlite";
const MIGRATIONS: EmbeddedMigrations = embed_migrations!();
const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations");
#[derive(Error, Diagnostic, Debug)]
pub enum DatabaseSetupError {
@@ -45,6 +49,7 @@ pub enum DatabaseSetupError {
Pool(#[from] PoolInitError),
}
#[tracing::instrument(level = "info")]
fn database_path() -> Result<std::path::PathBuf, DatabaseSetupError> {
let arbiter_home = arbiter_proto::home_path().map_err(DatabaseSetupError::HomeDir)?;
@@ -53,7 +58,7 @@ fn database_path() -> Result<std::path::PathBuf, DatabaseSetupError> {
Ok(db_path)
}
#[tracing::instrument(level = "info", skip(conn))]
fn db_config(conn: &mut SqliteConnection) -> Result<(), diesel::result::Error> {
// fsync only in critical moments
conn.batch_execute("PRAGMA synchronous = NORMAL;")?;
@@ -75,6 +80,7 @@ fn db_config(conn: &mut SqliteConnection) -> Result<(), diesel::result::Error> {
Ok(())
}
#[tracing::instrument(level = "info", skip(url))]
fn initialize_database(url: &str) -> Result<(), DatabaseSetupError> {
let mut conn = SqliteConnection::establish(url).map_err(DatabaseSetupError::Connection)?;
@@ -83,16 +89,19 @@ fn initialize_database(url: &str) -> Result<(), DatabaseSetupError> {
conn.run_pending_migrations(MIGRATIONS)
.map_err(DatabaseSetupError::Migration)?;
info!(%url, "Database initialized successfully");
Ok(())
}
pub async fn create_pool() -> Result<DatabasePool, DatabaseSetupError> {
let database_url = format!(
#[tracing::instrument(level = "info")]
pub async fn create_pool(url: Option<&str>) -> Result<DatabasePool, DatabaseSetupError> {
let database_url = url.map(String::from).unwrap_or(format!(
"{}?mode=rwc",
database_path()?
(database_path()?
.to_str()
.expect("database path is not valid UTF-8")
);
.expect("database path is not valid UTF-8"))
));
initialize_database(&database_url)?;
@@ -115,10 +124,29 @@ pub async fn create_pool() -> Result<DatabasePool, DatabaseSetupError> {
})
});
let pool = DatabasePool::builder().build(AsyncDieselConnectionManager::new_with_config(
database_url,
config,
)).await?;
let pool = DatabasePool::builder()
.build(AsyncDieselConnectionManager::new_with_config(
database_url,
config,
))
.await?;
Ok(pool)
}
#[cfg(test)]
pub async fn create_test_pool() -> DatabasePool {
use rand::distr::{Alphanumeric, SampleString as _};
let tempfile_name = Alphanumeric.sample_string(&mut rand::rng(), 16);
let file = std::env::temp_dir().join(tempfile_name);
let url = format!(
"{}?mode=rwc",
file.to_str().expect("temp file path is not valid UTF-8")
);
create_pool(Some(&url))
.await
.expect("Failed to create test database pool")
}

View File

@@ -0,0 +1,24 @@
use tonic::Status;
use tracing::error;
pub trait GrpcStatusExt<T> {
fn to_status(self) -> Result<T, Status>;
}
impl<T> GrpcStatusExt<T> for Result<T, diesel::result::Error> {
fn to_status(self) -> Result<T, Status> {
self.map_err(|e| {
error!(error = ?e, "Database error");
Status::internal("Database error")
})
}
}
impl<T> GrpcStatusExt<T> for Result<T, crate::db::PoolError> {
fn to_status(self) -> Result<T, Status> {
self.map_err(|e| {
error!(error = ?e, "Database pool error");
Status::internal("Database pool error")
})
}
}

View File

@@ -20,6 +20,7 @@ use crate::{
pub mod actors;
mod context;
mod db;
mod errors;
const DEFAULT_CHANNEL_SIZE: usize = 1000;