refactor(server): added SafeCell abstraction for easier protected memory swap

This commit is contained in:
hdbg
2026-03-16 18:56:13 +01:00
parent 088fa6fe72
commit 9017ea4017
14 changed files with 178 additions and 105 deletions

View File

@@ -3,11 +3,11 @@ use std::collections::{HashMap, HashSet};
use arbiter_server::{
actors::keyholder::{CreateNew, Error, KeyHolder},
db::{self, models, schema},
safe_cell::{SafeCell, SafeCellHandle as _},
};
use diesel::{ExpressionMethods as _, QueryDsl, SelectableHelper, dsl::sql_query};
use diesel_async::RunQueryDsl;
use kameo::actor::{ActorRef, Spawn as _};
use memsafe::MemSafe;
use tokio::task::JoinSet;
use crate::common;
@@ -24,7 +24,7 @@ async fn write_concurrently(
let plaintext = format!("{prefix}-{i}").into_bytes();
let id = actor
.ask(CreateNew {
plaintext: MemSafe::new(plaintext.clone()).unwrap(),
plaintext: SafeCell::new(plaintext.clone()),
})
.await
.unwrap();
@@ -118,7 +118,7 @@ async fn insert_failure_does_not_create_partial_row() {
drop(conn);
let err = actor
.create_new(MemSafe::new(b"should fail".to_vec()).unwrap())
.create_new(SafeCell::new(b"should fail".to_vec()))
.await
.unwrap_err();
assert!(matches!(err, Error::DatabaseTransaction(_)));
@@ -162,12 +162,12 @@ async fn decrypt_roundtrip_after_high_concurrency() {
let mut decryptor = KeyHolder::new(db.clone()).await.unwrap();
decryptor
.try_unseal(MemSafe::new(b"test-seal-key".to_vec()).unwrap())
.try_unseal(SafeCell::new(b"test-seal-key".to_vec()))
.await
.unwrap();
for (id, plaintext) in expected {
let mut decrypted = decryptor.decrypt(id).await.unwrap();
assert_eq!(*decrypted.read().unwrap(), plaintext);
assert_eq!(*decrypted.read(), plaintext);
}
}

View File

@@ -1,10 +1,10 @@
use arbiter_server::{
actors::keyholder::{Error, KeyHolder},
db::{self, models, schema},
safe_cell::{SafeCell, SafeCellHandle as _},
};
use diesel::{QueryDsl, SelectableHelper};
use diesel_async::RunQueryDsl;
use memsafe::MemSafe;
use crate::common;
@@ -14,7 +14,7 @@ async fn test_bootstrap() {
let db = db::create_test_pool().await;
let mut actor = KeyHolder::new(db.clone()).await.unwrap();
let seal_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap();
let seal_key = SafeCell::new(b"test-seal-key".to_vec());
actor.bootstrap(seal_key).await.unwrap();
let mut conn = db.get().await.unwrap();
@@ -43,7 +43,7 @@ async fn test_bootstrap_rejects_double() {
let db = db::create_test_pool().await;
let mut actor = common::bootstrapped_keyholder(&db).await;
let seal_key2 = MemSafe::new(b"test-seal-key".to_vec()).unwrap();
let seal_key2 = SafeCell::new(b"test-seal-key".to_vec());
let err = actor.bootstrap(seal_key2).await.unwrap_err();
assert!(matches!(err, Error::AlreadyBootstrapped));
}
@@ -55,7 +55,7 @@ async fn test_create_new_before_bootstrap_fails() {
let mut actor = KeyHolder::new(db).await.unwrap();
let err = actor
.create_new(MemSafe::new(b"data".to_vec()).unwrap())
.create_new(SafeCell::new(b"data".to_vec()))
.await
.unwrap_err();
assert!(matches!(err, Error::NotBootstrapped));
@@ -91,17 +91,17 @@ async fn test_unseal_correct_password() {
let plaintext = b"survive a restart";
let aead_id = actor
.create_new(MemSafe::new(plaintext.to_vec()).unwrap())
.create_new(SafeCell::new(plaintext.to_vec()))
.await
.unwrap();
drop(actor);
let mut actor = KeyHolder::new(db.clone()).await.unwrap();
let seal_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap();
let seal_key = SafeCell::new(b"test-seal-key".to_vec());
actor.try_unseal(seal_key).await.unwrap();
let mut decrypted = actor.decrypt(aead_id).await.unwrap();
assert_eq!(*decrypted.read().unwrap(), plaintext);
assert_eq!(*decrypted.read(), plaintext);
}
#[tokio::test]
@@ -112,20 +112,20 @@ async fn test_unseal_wrong_then_correct_password() {
let plaintext = b"important data";
let aead_id = actor
.create_new(MemSafe::new(plaintext.to_vec()).unwrap())
.create_new(SafeCell::new(plaintext.to_vec()))
.await
.unwrap();
drop(actor);
let mut actor = KeyHolder::new(db.clone()).await.unwrap();
let bad_key = MemSafe::new(b"wrong-password".to_vec()).unwrap();
let bad_key = SafeCell::new(b"wrong-password".to_vec());
let err = actor.try_unseal(bad_key).await.unwrap_err();
assert!(matches!(err, Error::InvalidKey));
let good_key = MemSafe::new(b"test-seal-key".to_vec()).unwrap();
let good_key = SafeCell::new(b"test-seal-key".to_vec());
actor.try_unseal(good_key).await.unwrap();
let mut decrypted = actor.decrypt(aead_id).await.unwrap();
assert_eq!(*decrypted.read().unwrap(), plaintext);
assert_eq!(*decrypted.read(), plaintext);
}

View File

@@ -3,10 +3,10 @@ use std::collections::HashSet;
use arbiter_server::{
actors::keyholder::{Error, encryption::v1},
db::{self, models, schema},
safe_cell::{SafeCell, SafeCellHandle as _},
};
use diesel::{ExpressionMethods as _, QueryDsl, SelectableHelper, dsl::update};
use diesel_async::RunQueryDsl;
use memsafe::MemSafe;
use crate::common;
@@ -18,12 +18,12 @@ async fn test_create_decrypt_roundtrip() {
let plaintext = b"hello arbiter";
let aead_id = actor
.create_new(MemSafe::new(plaintext.to_vec()).unwrap())
.create_new(SafeCell::new(plaintext.to_vec()))
.await
.unwrap();
let mut decrypted = actor.decrypt(aead_id).await.unwrap();
assert_eq!(*decrypted.read().unwrap(), plaintext);
assert_eq!(*decrypted.read(), plaintext);
}
#[tokio::test]
@@ -44,11 +44,11 @@ async fn test_ciphertext_differs_across_entries() {
let plaintext = b"same content";
let id1 = actor
.create_new(MemSafe::new(plaintext.to_vec()).unwrap())
.create_new(SafeCell::new(plaintext.to_vec()))
.await
.unwrap();
let id2 = actor
.create_new(MemSafe::new(plaintext.to_vec()).unwrap())
.create_new(SafeCell::new(plaintext.to_vec()))
.await
.unwrap();
@@ -70,8 +70,8 @@ async fn test_ciphertext_differs_across_entries() {
let mut d1 = actor.decrypt(id1).await.unwrap();
let mut d2 = actor.decrypt(id2).await.unwrap();
assert_eq!(*d1.read().unwrap(), plaintext);
assert_eq!(*d2.read().unwrap(), plaintext);
assert_eq!(*d1.read(), plaintext);
assert_eq!(*d2.read(), plaintext);
}
#[tokio::test]
@@ -83,7 +83,7 @@ async fn test_nonce_never_reused() {
let n = 5;
for i in 0..n {
actor
.create_new(MemSafe::new(format!("secret {i}").into_bytes()).unwrap())
.create_new(SafeCell::new(format!("secret {i}").into_bytes()))
.await
.unwrap();
}
@@ -137,7 +137,7 @@ async fn broken_db_nonce_format_fails_closed() {
drop(conn);
let err = actor
.create_new(MemSafe::new(b"must fail".to_vec()).unwrap())
.create_new(SafeCell::new(b"must fail".to_vec()))
.await
.unwrap_err();
assert!(matches!(err, Error::BrokenDatabase));
@@ -145,7 +145,7 @@ async fn broken_db_nonce_format_fails_closed() {
let db = db::create_test_pool().await;
let mut actor = common::bootstrapped_keyholder(&db).await;
let id = actor
.create_new(MemSafe::new(b"decrypt target".to_vec()).unwrap())
.create_new(SafeCell::new(b"decrypt target".to_vec()))
.await
.unwrap();
let mut conn = db.get().await.unwrap();