test(db): add create_test_pool and use in tests

This commit is contained in:
hdbg
2026-02-14 18:17:48 +01:00
parent 69dd8f57ca
commit 81a55d28f0
6 changed files with 27 additions and 15 deletions

1
server/Cargo.lock generated
View File

@@ -103,6 +103,7 @@ dependencies = [
"secrecy",
"smlang",
"statig",
"tempfile",
"test-log",
"thiserror",
"tokio",

View File

@@ -22,3 +22,4 @@ prost-build = "0.14.3"
serde_json = "1"
tonic-prost-build = "0.14.3"

View File

@@ -58,3 +58,4 @@ prost-types.workspace = true
[dev-dependencies]
test-log = { version = "0.2", default-features = false, features = ["trace"] }
tempfile = "3.25.0"

View File

@@ -322,9 +322,7 @@ mod tests {
#[tokio::test]
#[test_log::test]
pub async fn test_bootstrap_token_auth() {
let db = db::create_pool(Some("sqlite://:memory:"))
.await
.expect("Failed to create database pool");
let db = db::create_test_pool().await;
// explicitly not installing any user_agent pubkeys
let bootstrapper = BootstrapActor::new(&db).await.unwrap(); // this will create bootstrap token
let token = bootstrapper.get_token().unwrap();

View File

@@ -1,9 +1,5 @@
use arbiter_proto::{BOOTSTRAP_TOKEN_PATH, home_path};
use diesel::{
ExpressionMethods, QueryDsl,
dsl::{count, exists},
select,
};
use diesel::{ExpressionMethods, QueryDsl};
use diesel_async::RunQueryDsl;
use kameo::{Actor, messages};
use memsafe::MemSafe;
@@ -61,16 +57,14 @@ impl BootstrapActor {
pub async fn new(db: &DatabasePool) -> Result<Self, BootstrapError> {
let mut conn = db.get().await?;
let needs_token: bool = select(exists(
schema::useragent_client::table
.filter(schema::useragent_client::id.eq(schema::useragent_client::id)), // Just check if the table is empty
))
.first(&mut conn)
let row_count: i64 = schema::useragent_client::table
.count()
.get_result(&mut conn)
.await?;
drop(conn);
let token = if needs_token {
let token = if row_count == 0 {
let token = generate_token().await?;
info!(%token, "Generated bootstrap token");
tokio::fs::write(home_path()?.join(BOOTSTRAP_TOKEN_PATH), token.as_str()).await?;

View File

@@ -133,3 +133,20 @@ pub async fn create_pool(url: Option<&str>) -> Result<DatabasePool, DatabaseSetu
Ok(pool)
}
#[cfg(test)]
pub async fn create_test_pool() -> DatabasePool {
use rand::distr::{Alphanumeric, SampleString as _};
let tempfile_name = Alphanumeric.sample_string(&mut rand::rng(), 16);
let file = std::env::temp_dir().join(tempfile_name);
let url = format!(
"{}?mode=rwc",
file.to_str().expect("temp file path is not valid UTF-8")
);
create_pool(Some(&url))
.await
.expect("Failed to create test database pool")
}