From 81a55d28f042cb98add6185f73198528021f2cfd Mon Sep 17 00:00:00 2001 From: hdbg Date: Sat, 14 Feb 2026 18:17:48 +0100 Subject: [PATCH] test(db): add create_test_pool and use in tests --- server/Cargo.lock | 1 + server/crates/arbiter-proto/Cargo.toml | 1 + server/crates/arbiter-server/Cargo.toml | 1 + .../arbiter-server/src/actors/user_agent.rs | 4 +--- .../arbiter-server/src/context/bootstrap.rs | 18 ++++++------------ server/crates/arbiter-server/src/db.rs | 17 +++++++++++++++++ 6 files changed, 27 insertions(+), 15 deletions(-) diff --git a/server/Cargo.lock b/server/Cargo.lock index 9b2cc3e..17affd3 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -103,6 +103,7 @@ dependencies = [ "secrecy", "smlang", "statig", + "tempfile", "test-log", "thiserror", "tokio", diff --git a/server/crates/arbiter-proto/Cargo.toml b/server/crates/arbiter-proto/Cargo.toml index 2f828fc..85faddd 100644 --- a/server/crates/arbiter-proto/Cargo.toml +++ b/server/crates/arbiter-proto/Cargo.toml @@ -22,3 +22,4 @@ prost-build = "0.14.3" serde_json = "1" tonic-prost-build = "0.14.3" + diff --git a/server/crates/arbiter-server/Cargo.toml b/server/crates/arbiter-server/Cargo.toml index be01b16..afa9713 100644 --- a/server/crates/arbiter-server/Cargo.toml +++ b/server/crates/arbiter-server/Cargo.toml @@ -58,3 +58,4 @@ prost-types.workspace = true [dev-dependencies] test-log = { version = "0.2", default-features = false, features = ["trace"] } +tempfile = "3.25.0" \ No newline at end of file diff --git a/server/crates/arbiter-server/src/actors/user_agent.rs b/server/crates/arbiter-server/src/actors/user_agent.rs index d720b87..0b98e5b 100644 --- a/server/crates/arbiter-server/src/actors/user_agent.rs +++ b/server/crates/arbiter-server/src/actors/user_agent.rs @@ -322,9 +322,7 @@ mod tests { #[tokio::test] #[test_log::test] pub async fn test_bootstrap_token_auth() { - let db = db::create_pool(Some("sqlite://:memory:")) - .await - .expect("Failed to create database pool"); + let db = db::create_test_pool().await; // explicitly not installing any user_agent pubkeys let bootstrapper = BootstrapActor::new(&db).await.unwrap(); // this will create bootstrap token let token = bootstrapper.get_token().unwrap(); diff --git a/server/crates/arbiter-server/src/context/bootstrap.rs b/server/crates/arbiter-server/src/context/bootstrap.rs index 9b73e31..211344c 100644 --- a/server/crates/arbiter-server/src/context/bootstrap.rs +++ b/server/crates/arbiter-server/src/context/bootstrap.rs @@ -1,9 +1,5 @@ use arbiter_proto::{BOOTSTRAP_TOKEN_PATH, home_path}; -use diesel::{ - ExpressionMethods, QueryDsl, - dsl::{count, exists}, - select, -}; +use diesel::{ExpressionMethods, QueryDsl}; use diesel_async::RunQueryDsl; use kameo::{Actor, messages}; use memsafe::MemSafe; @@ -61,16 +57,14 @@ impl BootstrapActor { pub async fn new(db: &DatabasePool) -> Result { let mut conn = db.get().await?; - let needs_token: bool = select(exists( - schema::useragent_client::table - .filter(schema::useragent_client::id.eq(schema::useragent_client::id)), // Just check if the table is empty - )) - .first(&mut conn) - .await?; + let row_count: i64 = schema::useragent_client::table + .count() + .get_result(&mut conn) + .await?; drop(conn); - let token = if needs_token { + let token = if row_count == 0 { let token = generate_token().await?; info!(%token, "Generated bootstrap token"); tokio::fs::write(home_path()?.join(BOOTSTRAP_TOKEN_PATH), token.as_str()).await?; diff --git a/server/crates/arbiter-server/src/db.rs b/server/crates/arbiter-server/src/db.rs index 87f930d..c44d489 100644 --- a/server/crates/arbiter-server/src/db.rs +++ b/server/crates/arbiter-server/src/db.rs @@ -133,3 +133,20 @@ pub async fn create_pool(url: Option<&str>) -> Result DatabasePool { + use rand::distr::{Alphanumeric, SampleString as _}; + + let tempfile_name = Alphanumeric.sample_string(&mut rand::rng(), 16); + + let file = std::env::temp_dir().join(tempfile_name); + let url = format!( + "{}?mode=rwc", + file.to_str().expect("temp file path is not valid UTF-8") + ); + + create_pool(Some(&url)) + .await + .expect("Failed to create test database pool") +}