Compare commits
5 Commits
feat-lints
...
key-altern
| Author | SHA1 | Date | |
|---|---|---|---|
| d65e9319d9 | |||
| dfc852e815 | |||
| 5b711acb15 | |||
| 19f19a56e5 | |||
| f108e64d13 |
3
.vscode/settings.json
vendored
3
.vscode/settings.json
vendored
@@ -1,3 +0,0 @@
|
||||
{
|
||||
"git.enabled": false
|
||||
}
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
flutter = "3.38.9-stable"
|
||||
protoc = "29.6"
|
||||
rust = "1.93.0"
|
||||
rust = "1.93.1"
|
||||
"cargo:cargo-features-manager" = "0.11.1"
|
||||
"cargo:cargo-nextest" = "0.9.126"
|
||||
"cargo:cargo-shear" = "latest"
|
||||
|
||||
@@ -7,23 +7,29 @@ import "auth.proto";
|
||||
message ClientRequest {
|
||||
oneof payload {
|
||||
arbiter.auth.ClientMessage auth_message = 1;
|
||||
CertRotationAck cert_rotation_ack = 2;
|
||||
}
|
||||
}
|
||||
|
||||
message ClientResponse {
|
||||
oneof payload {
|
||||
arbiter.auth.ServerMessage auth_message = 1;
|
||||
CertRotationNotification cert_rotation_notification = 2;
|
||||
}
|
||||
}
|
||||
|
||||
message UserAgentRequest {
|
||||
oneof payload {
|
||||
arbiter.auth.ClientMessage auth_message = 1;
|
||||
CertRotationAck cert_rotation_ack = 2;
|
||||
UnsealRequest unseal_request = 3;
|
||||
}
|
||||
}
|
||||
message UserAgentResponse {
|
||||
oneof payload {
|
||||
arbiter.auth.ServerMessage auth_message = 1;
|
||||
CertRotationNotification cert_rotation_notification = 2;
|
||||
UnsealResponse unseal_response = 3;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,6 +38,76 @@ message ServerInfo {
|
||||
bytes cert_public_key = 2;
|
||||
}
|
||||
|
||||
// TLS Certificate Rotation Protocol
|
||||
message CertRotationNotification {
|
||||
// New public certificate (DER-encoded)
|
||||
bytes new_cert = 1;
|
||||
|
||||
// Unix timestamp when rotation will be executed (if all ACKs received)
|
||||
int64 rotation_scheduled_at = 2;
|
||||
|
||||
// Unix timestamp deadline for ACK (7 days from now)
|
||||
int64 ack_deadline = 3;
|
||||
|
||||
// Rotation ID for tracking
|
||||
int32 rotation_id = 4;
|
||||
}
|
||||
|
||||
message CertRotationAck {
|
||||
// Rotation ID (from CertRotationNotification)
|
||||
int32 rotation_id = 1;
|
||||
|
||||
// Client public key for identification
|
||||
bytes client_public_key = 2;
|
||||
|
||||
// Confirmation that client saved the new certificate
|
||||
bool cert_saved = 3;
|
||||
}
|
||||
|
||||
// Vault Unseal Protocol (X25519 ECDH + ChaCha20Poly1305)
|
||||
message UnsealRequest {
|
||||
oneof payload {
|
||||
EphemeralKeyRequest ephemeral_key_request = 1;
|
||||
SealedPassword sealed_password = 2;
|
||||
}
|
||||
}
|
||||
|
||||
message UnsealResponse {
|
||||
oneof payload {
|
||||
EphemeralKeyResponse ephemeral_key_response = 1;
|
||||
UnsealResult unseal_result = 2;
|
||||
}
|
||||
}
|
||||
|
||||
message EphemeralKeyRequest {}
|
||||
|
||||
message EphemeralKeyResponse {
|
||||
// Server's X25519 ephemeral public key (32 bytes)
|
||||
bytes server_pubkey = 1;
|
||||
|
||||
// Unix timestamp when this key expires (60 seconds from generation)
|
||||
int64 expires_at = 2;
|
||||
}
|
||||
|
||||
message SealedPassword {
|
||||
// Client's X25519 ephemeral public key (32 bytes)
|
||||
bytes client_pubkey = 1;
|
||||
|
||||
// ChaCha20Poly1305 encrypted password (ciphertext + tag)
|
||||
bytes encrypted_password = 2;
|
||||
|
||||
// 12-byte nonce for ChaCha20Poly1305
|
||||
bytes nonce = 3;
|
||||
}
|
||||
|
||||
message UnsealResult {
|
||||
// Whether unseal was successful
|
||||
bool success = 1;
|
||||
|
||||
// Error message if unseal failed
|
||||
optional string error_message = 2;
|
||||
}
|
||||
|
||||
service ArbiterService {
|
||||
rpc Client(stream ClientRequest) returns (stream ClientResponse);
|
||||
rpc UserAgent(stream UserAgentRequest) returns (stream UserAgentResponse);
|
||||
|
||||
46
protobufs/google/protobuf/timestamp.proto
Normal file
46
protobufs/google/protobuf/timestamp.proto
Normal file
@@ -0,0 +1,46 @@
|
||||
// Protocol Buffers - Google's data interchange format
|
||||
// Copyright 2008 Google Inc. All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package google.protobuf;
|
||||
|
||||
option csharp_namespace = "Google.Protobuf.WellKnownTypes";
|
||||
option cc_enable_arenas = true;
|
||||
option go_package = "google.golang.org/protobuf/types/known/timestamppb";
|
||||
option java_package = "com.google.protobuf";
|
||||
option java_outer_classname = "TimestampProto";
|
||||
option java_multiple_files = true;
|
||||
option objc_class_prefix = "GPB";
|
||||
|
||||
// A Timestamp represents a point in time independent of any time zone or local
|
||||
// calendar, encoded as a count of seconds and fractions of seconds at
|
||||
// nanosecond resolution. The count is relative to an epoch at UTC midnight on
|
||||
// January 1, 1970, in the proleptic Gregorian calendar which extends the
|
||||
// Gregorian calendar backwards to year one.
|
||||
message Timestamp {
|
||||
// Represents seconds of UTC time since Unix epoch
|
||||
// 1970-01-01T00:00:00Z. Must be from 0001-01-01T00:00:00Z to
|
||||
// 9999-12-31T23:59:59Z inclusive.
|
||||
int64 seconds = 1;
|
||||
|
||||
// Non-negative fractions of a second at nanosecond resolution. Negative
|
||||
// second values with fractions must still have non-negative nanos values
|
||||
// that count forward in time. Must be from 0 to 999,999,999
|
||||
// inclusive.
|
||||
int32 nanos = 2;
|
||||
}
|
||||
541
server/Cargo.lock
generated
541
server/Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -15,6 +15,8 @@ kameo.workspace = true
|
||||
prost-types.workspace = true
|
||||
|
||||
[build-dependencies]
|
||||
prost-build = "0.14.3"
|
||||
serde_json = "1"
|
||||
tonic-prost-build = "0.14.3"
|
||||
|
||||
|
||||
|
||||
@@ -1,18 +1,15 @@
|
||||
use tonic_prost_build::configure;
|
||||
|
||||
static PROTOBUF_DIR: &str = "../../../protobufs";
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
configure()
|
||||
.message_attribute(".", "#[derive(::kameo::Reply)]")
|
||||
.compile_protos(
|
||||
&[
|
||||
format!("{}/arbiter.proto", PROTOBUF_DIR),
|
||||
format!("{}/auth.proto", PROTOBUF_DIR),
|
||||
],
|
||||
&[PROTOBUF_DIR.to_string()],
|
||||
)
|
||||
let proto_files = vec![
|
||||
format!("{}/arbiter.proto", PROTOBUF_DIR),
|
||||
format!("{}/auth.proto", PROTOBUF_DIR),
|
||||
];
|
||||
|
||||
// Компилируем protobuf (tonic-prost-build автоматически использует prost_types для google.protobuf)
|
||||
tonic_prost_build::configure()
|
||||
.message_attribute(".", "#[derive(::kameo::Reply)]")
|
||||
.compile_protos(&proto_files, &[PROTOBUF_DIR.to_string()])?;
|
||||
|
||||
.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -10,10 +10,10 @@ pub mod proto {
|
||||
|
||||
pub mod transport;
|
||||
|
||||
pub static BOOTSTRAP_TOKEN_PATH: &'static str = "bootstrap_token";
|
||||
pub static BOOTSTRAP_TOKEN_PATH: &str = "bootstrap_token";
|
||||
|
||||
pub fn home_path() -> Result<std::path::PathBuf, std::io::Error> {
|
||||
static ARBITER_HOME: &'static str = ".arbiter";
|
||||
static ARBITER_HOME: &str = ".arbiter";
|
||||
let home_dir = std::env::home_dir().ok_or(std::io::Error::new(
|
||||
std::io::ErrorKind::PermissionDenied,
|
||||
"can not get home directory",
|
||||
|
||||
@@ -43,7 +43,11 @@ rcgen = { version = "0.14.7", features = [
|
||||
chrono.workspace = true
|
||||
memsafe = "0.4.0"
|
||||
zeroize = { version = "1.8.2", features = ["std", "simd"] }
|
||||
argon2 = { version = "0.5", features = ["std"] }
|
||||
kameo.workspace = true
|
||||
hex = "0.4.3"
|
||||
chacha20poly1305 = "0.10.1"
|
||||
x25519-dalek = { version = "2.0", features = ["static_secrets"] }
|
||||
|
||||
[dev-dependencies]
|
||||
test-log = { version = "0.2", default-features = false, features = ["trace"] }
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
-- Rollback TLS rotation tables
|
||||
|
||||
-- Удалить добавленную колонку из arbiter_settings
|
||||
ALTER TABLE arbiter_settings DROP COLUMN current_cert_id;
|
||||
|
||||
-- Удалить таблицы в обратном порядке
|
||||
DROP TABLE IF EXISTS tls_rotation_history;
|
||||
DROP TABLE IF EXISTS rotation_client_acks;
|
||||
DROP TABLE IF EXISTS tls_rotation_state;
|
||||
DROP INDEX IF EXISTS idx_tls_certificates_active;
|
||||
DROP TABLE IF EXISTS tls_certificates;
|
||||
@@ -0,0 +1,57 @@
|
||||
-- История всех сертификатов
|
||||
CREATE TABLE IF NOT EXISTS tls_certificates (
|
||||
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
cert BLOB NOT NULL, -- DER-encoded
|
||||
cert_key BLOB NOT NULL, -- PEM-encoded
|
||||
not_before INTEGER NOT NULL, -- Unix timestamp
|
||||
not_after INTEGER NOT NULL, -- Unix timestamp
|
||||
created_at INTEGER NOT NULL DEFAULT(unixepoch('now')),
|
||||
is_active BOOLEAN NOT NULL DEFAULT 0 -- Только один active=1
|
||||
) STRICT;
|
||||
|
||||
CREATE INDEX idx_tls_certificates_active ON tls_certificates(is_active, not_after);
|
||||
|
||||
-- Tracking процесса ротации
|
||||
CREATE TABLE IF NOT EXISTS tls_rotation_state (
|
||||
id INTEGER NOT NULL PRIMARY KEY CHECK(id = 1), -- Singleton
|
||||
state TEXT NOT NULL DEFAULT('normal') CHECK(state IN ('normal', 'initiated', 'waiting_acks', 'ready')),
|
||||
new_cert_id INTEGER REFERENCES tls_certificates(id),
|
||||
initiated_at INTEGER,
|
||||
timeout_at INTEGER -- Таймаут для ожидания ACKs (initiated_at + 7 дней)
|
||||
) STRICT;
|
||||
|
||||
-- Tracking ACKs от клиентов
|
||||
CREATE TABLE IF NOT EXISTS rotation_client_acks (
|
||||
rotation_id INTEGER NOT NULL, -- Ссылка на new_cert_id
|
||||
client_key TEXT NOT NULL, -- Публичный ключ клиента (hex)
|
||||
ack_received_at INTEGER NOT NULL DEFAULT(unixepoch('now')),
|
||||
PRIMARY KEY (rotation_id, client_key)
|
||||
) STRICT;
|
||||
|
||||
-- Audit trail событий ротации
|
||||
CREATE TABLE IF NOT EXISTS tls_rotation_history (
|
||||
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
cert_id INTEGER NOT NULL REFERENCES tls_certificates(id),
|
||||
event_type TEXT NOT NULL CHECK(event_type IN ('created', 'rotation_initiated', 'acks_complete', 'activated', 'timeout')),
|
||||
timestamp INTEGER NOT NULL DEFAULT(unixepoch('now')),
|
||||
details TEXT -- JSON с доп. информацией
|
||||
) STRICT;
|
||||
|
||||
-- Миграция существующего сертификата
|
||||
INSERT INTO tls_certificates (id, cert, cert_key, not_before, not_after, is_active, created_at)
|
||||
SELECT
|
||||
1,
|
||||
cert,
|
||||
cert_key,
|
||||
unixepoch('now') as not_before,
|
||||
unixepoch('now') + (90 * 24 * 60 * 60) as not_after, -- 90 дней
|
||||
1 as is_active,
|
||||
unixepoch('now')
|
||||
FROM arbiter_settings WHERE id = 1;
|
||||
|
||||
-- Инициализация rotation_state
|
||||
INSERT INTO tls_rotation_state (id, state) VALUES (1, 'normal');
|
||||
|
||||
-- Добавить ссылку на текущий сертификат
|
||||
ALTER TABLE arbiter_settings ADD COLUMN current_cert_id INTEGER REFERENCES tls_certificates(id);
|
||||
UPDATE arbiter_settings SET current_cert_id = 1 WHERE id = 1;
|
||||
@@ -0,0 +1,2 @@
|
||||
-- Remove argon2_salt column
|
||||
ALTER TABLE aead_encrypted DROP COLUMN argon2_salt;
|
||||
@@ -0,0 +1,2 @@
|
||||
-- Add argon2_salt column to store password derivation salt
|
||||
ALTER TABLE aead_encrypted ADD COLUMN argon2_salt TEXT;
|
||||
@@ -1,12 +1,17 @@
|
||||
use arbiter_proto::proto::{
|
||||
UserAgentRequest, UserAgentResponse,
|
||||
auth::{
|
||||
self, AuthChallenge, AuthChallengeRequest, AuthOk, ClientMessage,
|
||||
ServerMessage as AuthServerMessage, client_message::Payload as ClientAuthPayload,
|
||||
server_message::Payload as ServerAuthPayload,
|
||||
use std::sync::Arc;
|
||||
|
||||
use arbiter_proto::{
|
||||
proto::{
|
||||
UserAgentRequest, UserAgentResponse,
|
||||
auth::{
|
||||
self, AuthChallengeRequest, ClientMessage, ServerMessage as AuthServerMessage,
|
||||
client_message::Payload as ClientAuthPayload,
|
||||
server_message::Payload as ServerAuthPayload,
|
||||
},
|
||||
user_agent_request::Payload as UserAgentRequestPayload,
|
||||
user_agent_response::Payload as UserAgentResponsePayload,
|
||||
},
|
||||
user_agent_request::Payload as UserAgentRequestPayload,
|
||||
user_agent_response::Payload as UserAgentResponsePayload,
|
||||
transport::Bi,
|
||||
};
|
||||
use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, dsl::update};
|
||||
use diesel_async::{AsyncConnection, RunQueryDsl};
|
||||
@@ -21,19 +26,18 @@ use kameo::{
|
||||
};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use tonic::Status;
|
||||
use tracing::{error, info};
|
||||
use tonic::{Status, transport::Server};
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
use crate::{
|
||||
ServerContext,
|
||||
actors::user_agent::auth::AuthChallenge,
|
||||
context::bootstrap::{BootstrapActor, ConsumeToken},
|
||||
db::{self, schema},
|
||||
errors::GrpcStatusExt,
|
||||
};
|
||||
|
||||
/// Context for state machine with validated key and sent challenge
|
||||
/// Challenge is then transformed to bytes using shared function and verified
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Debug)]
|
||||
pub struct ChallengeContext {
|
||||
challenge: AuthChallenge,
|
||||
key: VerifyingKey,
|
||||
@@ -91,6 +95,8 @@ pub struct UserAgentActor {
|
||||
bootstapper: ActorRef<BootstrapActor>,
|
||||
state: UserAgentStateMachine<DummyContext>,
|
||||
tx: Sender<Result<UserAgentResponse, Status>>,
|
||||
context: ServerContext,
|
||||
ephemeral_key: Option<crate::context::unseal::EphemeralKeyPair>,
|
||||
}
|
||||
|
||||
impl UserAgentActor {
|
||||
@@ -103,12 +109,15 @@ impl UserAgentActor {
|
||||
bootstapper: context.bootstrapper.clone(),
|
||||
state: UserAgentStateMachine::new(DummyContext),
|
||||
tx,
|
||||
context,
|
||||
ephemeral_key: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new_manual(
|
||||
db: db::DatabasePool,
|
||||
bootstapper: ActorRef<BootstrapActor>,
|
||||
context: ServerContext,
|
||||
tx: Sender<Result<UserAgentResponse, Status>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
@@ -116,6 +125,8 @@ impl UserAgentActor {
|
||||
bootstapper,
|
||||
state: UserAgentStateMachine::new(DummyContext),
|
||||
tx,
|
||||
context,
|
||||
ephemeral_key: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -161,7 +172,7 @@ impl UserAgentActor {
|
||||
|
||||
self.transition(UserAgentEvents::ReceivedBootstrapToken)?;
|
||||
|
||||
Ok(auth_response(ServerAuthPayload::AuthOk(AuthOk {})))
|
||||
Ok(auth_response(ServerAuthPayload::AuthOk(auth::AuthOk {})))
|
||||
}
|
||||
|
||||
async fn auth_with_challenge(&mut self, pubkey: VerifyingKey, pubkey_bytes: Vec<u8>) -> Output {
|
||||
@@ -201,7 +212,7 @@ impl UserAgentActor {
|
||||
|
||||
let challenge = auth::AuthChallenge {
|
||||
pubkey: pubkey_bytes,
|
||||
nonce: nonce,
|
||||
nonce,
|
||||
};
|
||||
|
||||
self.transition(UserAgentEvents::SentChallenge(ChallengeContext {
|
||||
@@ -296,19 +307,135 @@ impl UserAgentActor {
|
||||
"Client provided valid solution to authentication challenge"
|
||||
);
|
||||
self.transition(UserAgentEvents::ReceivedGoodSolution)?;
|
||||
Ok(auth_response(ServerAuthPayload::AuthOk(AuthOk {})))
|
||||
Ok(auth_response(ServerAuthPayload::AuthOk(auth::AuthOk {})))
|
||||
} else {
|
||||
error!("Client provided invalid solution to authentication challenge");
|
||||
self.transition(UserAgentEvents::ReceivedBadSolution)?;
|
||||
Err(Status::unauthenticated("Invalid challenge solution"))
|
||||
}
|
||||
}
|
||||
|
||||
#[message(ctx)]
|
||||
pub async fn handle_unseal_request(
|
||||
&mut self,
|
||||
request: arbiter_proto::proto::UnsealRequest,
|
||||
ctx: &mut Context<Self, Output>,
|
||||
) -> Output {
|
||||
use arbiter_proto::proto::{
|
||||
EphemeralKeyResponse, SealedPassword, UnsealResponse, UnsealResult,
|
||||
unseal_request::Payload as ReqPayload,
|
||||
unseal_response::Payload as RespPayload,
|
||||
};
|
||||
|
||||
match request.payload {
|
||||
Some(ReqPayload::EphemeralKeyRequest(_)) => {
|
||||
// Generate new ephemeral keypair
|
||||
let keypair = crate::context::unseal::EphemeralKeyPair::generate();
|
||||
let expires_at = keypair.expires_at() as i64;
|
||||
let public_bytes = keypair.public_bytes();
|
||||
|
||||
// Store for later use
|
||||
self.ephemeral_key = Some(keypair);
|
||||
|
||||
info!("Generated ephemeral X25519 keypair for unseal, expires at {}", expires_at);
|
||||
|
||||
Ok(UserAgentResponse {
|
||||
payload: Some(UserAgentResponsePayload::UnsealResponse(UnsealResponse {
|
||||
payload: Some(RespPayload::EphemeralKeyResponse(EphemeralKeyResponse {
|
||||
server_pubkey: public_bytes,
|
||||
expires_at,
|
||||
})),
|
||||
})),
|
||||
})
|
||||
}
|
||||
|
||||
Some(ReqPayload::SealedPassword(sealed)) => {
|
||||
// Get and consume ephemeral key
|
||||
let keypair = self
|
||||
.ephemeral_key
|
||||
.take()
|
||||
.ok_or_else(|| Status::failed_precondition("No ephemeral key generated"))?;
|
||||
|
||||
// Check expiration
|
||||
if keypair.is_expired() {
|
||||
error!("Ephemeral key expired before sealed password was received");
|
||||
return Err(Status::deadline_exceeded("Ephemeral key expired"));
|
||||
}
|
||||
|
||||
// Perform ECDH
|
||||
let shared_secret = keypair
|
||||
.perform_dh(&sealed.client_pubkey)
|
||||
.map_err(|e| Status::invalid_argument(format!("Invalid client pubkey: {}", e)))?;
|
||||
|
||||
// Decrypt password
|
||||
let nonce: [u8; 12] = sealed
|
||||
.nonce
|
||||
.as_slice()
|
||||
.try_into()
|
||||
.map_err(|_| Status::invalid_argument("Nonce must be 12 bytes"))?;
|
||||
|
||||
let password_bytes = crate::crypto::aead::decrypt(
|
||||
&sealed.encrypted_password,
|
||||
&shared_secret,
|
||||
&nonce,
|
||||
)
|
||||
.map_err(|e| {
|
||||
error!("Failed to decrypt password: {}", e);
|
||||
Status::internal(format!("Decryption failed: {}", e))
|
||||
})?;
|
||||
|
||||
let password = String::from_utf8(password_bytes).map_err(|_| {
|
||||
error!("Password is not valid UTF-8");
|
||||
Status::invalid_argument("Password must be UTF-8")
|
||||
})?;
|
||||
|
||||
// Call unseal on context
|
||||
info!("Attempting to unseal vault with decrypted password");
|
||||
let result = self.context.unseal(&password).await;
|
||||
|
||||
match result {
|
||||
Ok(()) => {
|
||||
info!("Vault unsealed successfully");
|
||||
Ok(UserAgentResponse {
|
||||
payload: Some(UserAgentResponsePayload::UnsealResponse(
|
||||
UnsealResponse {
|
||||
payload: Some(RespPayload::UnsealResult(UnsealResult {
|
||||
success: true,
|
||||
error_message: None,
|
||||
})),
|
||||
},
|
||||
)),
|
||||
})
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Unseal failed: {}", e);
|
||||
Ok(UserAgentResponse {
|
||||
payload: Some(UserAgentResponsePayload::UnsealResponse(
|
||||
UnsealResponse {
|
||||
payload: Some(RespPayload::UnsealResult(UnsealResult {
|
||||
success: false,
|
||||
error_message: Some(e.to_string()),
|
||||
})),
|
||||
},
|
||||
)),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None => {
|
||||
error!("Received empty unseal request");
|
||||
Err(Status::invalid_argument("Empty unseal request"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use arbiter_proto::proto::{
|
||||
UserAgentResponse, auth::{AuthChallengeRequest, AuthOk},
|
||||
UserAgentResponse,
|
||||
auth::{AuthChallengeRequest, AuthOk},
|
||||
user_agent_response::Payload as UserAgentResponsePayload,
|
||||
};
|
||||
use kameo::actor::Spawn;
|
||||
@@ -328,9 +455,11 @@ mod tests {
|
||||
let token = bootstrapper.get_token().unwrap();
|
||||
|
||||
let bootstrapper_ref = BootstrapActor::spawn(bootstrapper);
|
||||
let context = crate::ServerContext::new(db.clone()).await.unwrap();
|
||||
let user_agent = UserAgentActor::new_manual(
|
||||
db.clone(),
|
||||
bootstrapper_ref,
|
||||
context,
|
||||
tokio::sync::mpsc::channel(1).0, // dummy channel, we won't actually send responses in this test
|
||||
);
|
||||
let user_agent_ref = UserAgentActor::spawn(user_agent);
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use diesel::OptionalExtension as _;
|
||||
use diesel_async::RunQueryDsl as _;
|
||||
@@ -6,15 +8,17 @@ use ed25519_dalek::VerifyingKey;
|
||||
use kameo::actor::{ActorRef, Spawn};
|
||||
use miette::Diagnostic;
|
||||
use rand::rngs::StdRng;
|
||||
use secrecy::{ExposeSecret, SecretBox};
|
||||
use smlang::statemachine;
|
||||
use thiserror::Error;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::sync::{watch, RwLock};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
use crate::{
|
||||
context::{
|
||||
bootstrap::{BootstrapActor, generate_token},
|
||||
lease::LeaseHandler,
|
||||
tls::{TlsDataRaw, TlsManager},
|
||||
tls::{RotationState, RotationTask, TlsDataRaw, TlsManager},
|
||||
},
|
||||
db::{
|
||||
self,
|
||||
@@ -26,6 +30,7 @@ use crate::{
|
||||
pub(crate) mod bootstrap;
|
||||
pub(crate) mod lease;
|
||||
pub(crate) mod tls;
|
||||
pub(crate) mod unseal;
|
||||
|
||||
#[derive(Error, Debug, Diagnostic)]
|
||||
pub enum InitError {
|
||||
@@ -54,8 +59,66 @@ pub enum InitError {
|
||||
Io(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
// TODO: Placeholder for secure root key cell implementation
|
||||
pub struct KeyStorage;
|
||||
#[derive(Error, Debug, Diagnostic)]
|
||||
pub enum UnsealError {
|
||||
#[error("Database error: {0}")]
|
||||
#[diagnostic(code(arbiter_server::unseal::database_pool))]
|
||||
Database(#[from] db::PoolError),
|
||||
|
||||
#[error("Query error: {0}")]
|
||||
#[diagnostic(code(arbiter_server::unseal::database_query))]
|
||||
Query(#[from] diesel::result::Error),
|
||||
|
||||
#[error("Decryption failed: {0}")]
|
||||
#[diagnostic(code(arbiter_server::unseal::decryption))]
|
||||
DecryptionFailed(#[from] crate::crypto::CryptoError),
|
||||
|
||||
#[error("Invalid state for unseal")]
|
||||
#[diagnostic(code(arbiter_server::unseal::invalid_state))]
|
||||
InvalidState,
|
||||
|
||||
#[error("Missing salt in database")]
|
||||
#[diagnostic(code(arbiter_server::unseal::missing_salt))]
|
||||
MissingSalt,
|
||||
|
||||
#[error("No root key configured in database")]
|
||||
#[diagnostic(code(arbiter_server::unseal::no_root_key))]
|
||||
NoRootKey,
|
||||
}
|
||||
|
||||
#[derive(Error, Debug, Diagnostic)]
|
||||
pub enum SealError {
|
||||
#[error("Invalid state for seal")]
|
||||
#[diagnostic(code(arbiter_server::seal::invalid_state))]
|
||||
InvalidState,
|
||||
}
|
||||
|
||||
/// Secure in-memory storage for root encryption key
|
||||
///
|
||||
/// Uses `secrecy` crate for automatic zeroization on drop to prevent key material
|
||||
/// from remaining in memory after use. SecretBox provides heap-allocated secret
|
||||
/// storage that implements Send + Sync for safe use in async contexts.
|
||||
pub struct KeyStorage {
|
||||
/// 32-byte root key protected by SecretBox
|
||||
key: SecretBox<[u8; 32]>,
|
||||
}
|
||||
|
||||
impl KeyStorage {
|
||||
/// Create new KeyStorage from a 32-byte root key
|
||||
pub fn new(key: [u8; 32]) -> Self {
|
||||
Self {
|
||||
key: SecretBox::new(Box::new(key)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Access the key for cryptographic operations
|
||||
pub fn key(&self) -> &[u8; 32] {
|
||||
self.key.expose_secret()
|
||||
}
|
||||
}
|
||||
|
||||
// Drop автоматически реализован через secrecy::Zeroize
|
||||
// который зануляет память при освобождении
|
||||
|
||||
statemachine! {
|
||||
name: Server,
|
||||
@@ -67,14 +130,20 @@ statemachine! {
|
||||
}
|
||||
pub struct _Context;
|
||||
impl ServerStateMachineContext for _Context {
|
||||
fn move_key(&mut self, _event_data: KeyStorage) -> Result<KeyStorage, ()> {
|
||||
todo!()
|
||||
/// Move key from unseal event into Ready state
|
||||
fn move_key(&mut self, event_data: KeyStorage) -> Result<KeyStorage, ()> {
|
||||
// Просто перемещаем KeyStorage из event в state
|
||||
// Без клонирования - event data consumed
|
||||
Ok(event_data)
|
||||
}
|
||||
|
||||
/// Securely dispose of key when sealing
|
||||
#[allow(missing_docs)]
|
||||
#[allow(clippy::unused_unit)]
|
||||
fn dispose_key(&mut self, _state_data: &KeyStorage) -> Result<(), ()> {
|
||||
todo!()
|
||||
// KeyStorage будет dropped после state transition
|
||||
// secrecy::Zeroize зануляет память автоматически
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,8 +151,12 @@ pub(crate) struct _ServerContextInner {
|
||||
pub db: db::DatabasePool,
|
||||
pub state: RwLock<ServerStateMachine<_Context>>,
|
||||
pub rng: StdRng,
|
||||
pub tls: TlsManager,
|
||||
pub tls: Arc<TlsManager>,
|
||||
pub bootstrapper: ActorRef<BootstrapActor>,
|
||||
pub rotation_state: RwLock<RotationState>,
|
||||
pub rotation_acks: Arc<RwLock<HashSet<VerifyingKey>>>,
|
||||
pub user_agent_leases: LeaseHandler<VerifyingKey>,
|
||||
pub client_leases: LeaseHandler<VerifyingKey>,
|
||||
}
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ServerContext(Arc<_ServerContextInner>);
|
||||
@@ -97,34 +170,49 @@ impl std::ops::Deref for ServerContext {
|
||||
}
|
||||
|
||||
impl ServerContext {
|
||||
/// Check if all active clients have acknowledged the rotation
|
||||
pub async fn check_rotation_ready(&self) -> bool {
|
||||
// TODO: Implement proper rotation readiness check
|
||||
// For now, return false as placeholder
|
||||
false
|
||||
}
|
||||
|
||||
async fn load_tls(
|
||||
db: &mut db::DatabaseConnection,
|
||||
db: &db::DatabasePool,
|
||||
settings: Option<&ArbiterSetting>,
|
||||
) -> Result<TlsManager, InitError> {
|
||||
match &settings {
|
||||
Some(settings) => {
|
||||
match settings {
|
||||
Some(s) if s.current_cert_id.is_some() => {
|
||||
// Load active certificate from tls_certificates table
|
||||
Ok(TlsManager::load_from_db(
|
||||
db.clone(),
|
||||
s.current_cert_id.unwrap(),
|
||||
)
|
||||
.await?)
|
||||
}
|
||||
Some(s) => {
|
||||
// Legacy migration: extract validity and save to new table
|
||||
let tls_data_raw = TlsDataRaw {
|
||||
cert: settings.cert.clone(),
|
||||
key: settings.cert_key.clone(),
|
||||
cert: s.cert.clone(),
|
||||
key: s.cert_key.clone(),
|
||||
};
|
||||
|
||||
Ok(TlsManager::new(Some(tls_data_raw)).await?)
|
||||
// For legacy certificates, use current time as not_before
|
||||
// and current time + 90 days as not_after
|
||||
let not_before = chrono::Utc::now().timestamp();
|
||||
let not_after = not_before + (90 * 24 * 60 * 60); // 90 days
|
||||
|
||||
Ok(TlsManager::new_from_legacy(
|
||||
db.clone(),
|
||||
tls_data_raw,
|
||||
not_before,
|
||||
not_after,
|
||||
)
|
||||
.await?)
|
||||
}
|
||||
None => {
|
||||
let tls = TlsManager::new(None).await?;
|
||||
let tls_data_raw = tls.bytes();
|
||||
|
||||
diesel::insert_into(arbiter_settings::table)
|
||||
.values(&ArbiterSetting {
|
||||
id: 1,
|
||||
root_key_id: None,
|
||||
cert_key: tls_data_raw.key,
|
||||
cert: tls_data_raw.cert,
|
||||
})
|
||||
.execute(db)
|
||||
.await?;
|
||||
|
||||
Ok(tls)
|
||||
// First startup - generate new certificate
|
||||
Ok(TlsManager::new(db.clone()).await?)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -138,10 +226,18 @@ impl ServerContext {
|
||||
.await
|
||||
.optional()?;
|
||||
|
||||
let tls = Self::load_tls(&mut conn, settings.as_ref()).await?;
|
||||
|
||||
drop(conn);
|
||||
|
||||
// Load TLS manager
|
||||
let tls = Self::load_tls(&db, settings.as_ref()).await?;
|
||||
|
||||
// Load rotation state from database
|
||||
let rotation_state = RotationState::load_from_db(&db)
|
||||
.await
|
||||
.unwrap_or(RotationState::Normal);
|
||||
|
||||
let bootstrap_token = generate_token().await?;
|
||||
|
||||
let mut state = ServerStateMachine::new(_Context);
|
||||
|
||||
if let Some(settings) = &settings
|
||||
@@ -151,12 +247,157 @@ impl ServerContext {
|
||||
let _ = state.process_event(ServerEvents::Bootstrapped);
|
||||
}
|
||||
|
||||
Ok(Self(Arc::new(_ServerContextInner {
|
||||
bootstrapper: BootstrapActor::spawn(BootstrapActor::new(&db).await?),
|
||||
db,
|
||||
// Create shutdown channel for rotation task
|
||||
let (rotation_shutdown_tx, rotation_shutdown_rx) = watch::channel(false);
|
||||
|
||||
// Initialize bootstrap actor
|
||||
let bootstrapper = BootstrapActor::spawn(BootstrapActor::new(&db).await?);
|
||||
|
||||
let context = Arc::new(_ServerContextInner {
|
||||
db: db.clone(),
|
||||
rng,
|
||||
tls,
|
||||
tls: Arc::new(tls),
|
||||
state: RwLock::new(state),
|
||||
})))
|
||||
bootstrapper,
|
||||
rotation_state: RwLock::new(rotation_state),
|
||||
rotation_acks: Arc::new(RwLock::new(HashSet::new())),
|
||||
user_agent_leases: Default::default(),
|
||||
client_leases: Default::default(),
|
||||
});
|
||||
|
||||
Ok(Self(context))
|
||||
}
|
||||
|
||||
/// Unseal vault with password
|
||||
pub async fn unseal(&self, password: &str) -> Result<(), UnsealError> {
|
||||
use crate::crypto::root_key;
|
||||
use diesel::QueryDsl as _;
|
||||
|
||||
// 1. Get root_key_id from settings
|
||||
let mut conn = self.db.get().await?;
|
||||
|
||||
let settings: db::models::ArbiterSetting = schema::arbiter_settings::table
|
||||
.first(&mut conn)
|
||||
.await?;
|
||||
|
||||
let root_key_id = settings.root_key_id.ok_or(UnsealError::NoRootKey)?;
|
||||
|
||||
// 2. Load encrypted root key
|
||||
let encrypted: db::models::AeadEncrypted = schema::aead_encrypted::table
|
||||
.find(root_key_id)
|
||||
.first(&mut conn)
|
||||
.await?;
|
||||
|
||||
let salt = encrypted
|
||||
.argon2_salt
|
||||
.as_ref()
|
||||
.ok_or(UnsealError::MissingSalt)?;
|
||||
|
||||
drop(conn);
|
||||
|
||||
// 3. Decrypt root key using password
|
||||
let root_key = root_key::decrypt_root_key(&encrypted, password, salt)
|
||||
.map_err(UnsealError::DecryptionFailed)?;
|
||||
|
||||
// 4. Create secure storage
|
||||
let key_storage = KeyStorage::new(root_key);
|
||||
|
||||
// 5. Transition state machine
|
||||
let mut state = self.state.write().await;
|
||||
state
|
||||
.process_event(ServerEvents::Unsealed(key_storage))
|
||||
.map_err(|_| UnsealError::InvalidState)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Seal the server (lock the key)
|
||||
pub async fn seal(&self) -> Result<(), SealError> {
|
||||
let mut state = self.state.write().await;
|
||||
state
|
||||
.process_event(ServerEvents::Sealed)
|
||||
.map_err(|_| SealError::InvalidState)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_keystorage_creation() {
|
||||
let key = [42u8; 32];
|
||||
let storage = KeyStorage::new(key);
|
||||
assert_eq!(storage.key()[0], 42);
|
||||
assert_eq!(storage.key().len(), 32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_keystorage_zeroization() {
|
||||
let key = [99u8; 32];
|
||||
{
|
||||
let _storage = KeyStorage::new(key);
|
||||
// storage будет dropped здесь
|
||||
}
|
||||
// После drop SecretBox должен зануляеть память
|
||||
// Это проверяется автоматически через secrecy::Zeroize
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_state_machine_transitions() {
|
||||
let mut state = ServerStateMachine::new(_Context);
|
||||
|
||||
// Начальное состояние
|
||||
assert!(matches!(state.state(), &ServerStates::NotBootstrapped));
|
||||
|
||||
// Bootstrapped transition
|
||||
state.process_event(ServerEvents::Bootstrapped).unwrap();
|
||||
assert!(matches!(state.state(), &ServerStates::Sealed));
|
||||
|
||||
// Unsealed transition
|
||||
let key_storage = KeyStorage::new([1u8; 32]);
|
||||
state
|
||||
.process_event(ServerEvents::Unsealed(key_storage))
|
||||
.unwrap();
|
||||
assert!(matches!(state.state(), &ServerStates::Ready(_)));
|
||||
|
||||
// Sealed transition
|
||||
state.process_event(ServerEvents::Sealed).unwrap();
|
||||
assert!(matches!(state.state(), &ServerStates::Sealed));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_move_key_callback() {
|
||||
let mut ctx = _Context;
|
||||
let key_storage = KeyStorage::new([7u8; 32]);
|
||||
let result = ctx.move_key(key_storage);
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap().key()[0], 7);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dispose_key_callback() {
|
||||
let mut ctx = _Context;
|
||||
let key_storage = KeyStorage::new([13u8; 32]);
|
||||
let result = ctx.dispose_key(&key_storage);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_state_transitions() {
|
||||
let mut state = ServerStateMachine::new(_Context);
|
||||
|
||||
// Попытка unseal без bootstrap
|
||||
let key_storage = KeyStorage::new([1u8; 32]);
|
||||
let result = state.process_event(ServerEvents::Unsealed(key_storage));
|
||||
assert!(result.is_err());
|
||||
|
||||
// Правильный путь
|
||||
state.process_event(ServerEvents::Bootstrapped).unwrap();
|
||||
|
||||
// Попытка повторного bootstrap
|
||||
let result = state.process_event(ServerEvents::Bootstrapped);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,4 +38,9 @@ impl<T: Clone + std::hash::Hash + Eq> LeaseHandler<T> {
|
||||
Err(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all currently leased items
|
||||
pub fn get_all(&self) -> Vec<T> {
|
||||
self.storage.0.iter().map(|entry| entry.clone()).collect()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
use std::string::FromUtf8Error;
|
||||
|
||||
use miette::Diagnostic;
|
||||
use rcgen::{Certificate, KeyPair};
|
||||
use rustls::pki_types::CertificateDer;
|
||||
use thiserror::Error;
|
||||
|
||||
|
||||
#[derive(Error, Debug, Diagnostic)]
|
||||
pub enum TlsInitError {
|
||||
#[error("Key generation error during TLS initialization: {0}")]
|
||||
#[diagnostic(code(arbiter_server::tls_init::key_generation))]
|
||||
KeyGeneration(#[from] rcgen::Error),
|
||||
|
||||
#[error("Key invalid format: {0}")]
|
||||
#[diagnostic(code(arbiter_server::tls_init::key_invalid_format))]
|
||||
KeyInvalidFormat(#[from] FromUtf8Error),
|
||||
|
||||
#[error("Key deserialization error: {0}")]
|
||||
#[diagnostic(code(arbiter_server::tls_init::key_deserialization))]
|
||||
KeyDeserializationError(rcgen::Error),
|
||||
}
|
||||
|
||||
pub struct TlsData {
|
||||
pub cert: CertificateDer<'static>,
|
||||
pub keypair: KeyPair,
|
||||
}
|
||||
|
||||
pub struct TlsDataRaw {
|
||||
pub cert: Vec<u8>,
|
||||
pub key: Vec<u8>,
|
||||
}
|
||||
impl TlsDataRaw {
|
||||
pub fn serialize(cert: &TlsData) -> Self {
|
||||
Self {
|
||||
cert: cert.cert.as_ref().to_vec(),
|
||||
key: cert.keypair.serialize_pem().as_bytes().to_vec(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deserialize(&self) -> Result<TlsData, TlsInitError> {
|
||||
let cert = CertificateDer::from_slice(&self.cert).into_owned();
|
||||
|
||||
let key =
|
||||
String::from_utf8(self.key.clone()).map_err(TlsInitError::KeyInvalidFormat)?;
|
||||
|
||||
let keypair = KeyPair::from_pem(&key).map_err(TlsInitError::KeyDeserializationError)?;
|
||||
|
||||
Ok(TlsData { cert, keypair })
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_cert(key: &KeyPair) -> Result<Certificate, rcgen::Error> {
|
||||
let params = rcgen::CertificateParams::new(vec![
|
||||
"arbiter.local".to_string(),
|
||||
"localhost".to_string(),
|
||||
])?;
|
||||
|
||||
params.self_signed(key)
|
||||
}
|
||||
|
||||
// TODO: Implement cert rotation
|
||||
pub(crate) struct TlsManager {
|
||||
data: TlsData,
|
||||
}
|
||||
|
||||
impl TlsManager {
|
||||
pub async fn new(data: Option<TlsDataRaw>) -> Result<Self, TlsInitError> {
|
||||
match data {
|
||||
Some(raw) => {
|
||||
let tls_data = raw.deserialize()?;
|
||||
Ok(Self { data: tls_data })
|
||||
}
|
||||
None => {
|
||||
let keypair = KeyPair::generate()?;
|
||||
let cert = generate_cert(&keypair)?;
|
||||
let tls_data = TlsData {
|
||||
cert: cert.der().clone(),
|
||||
keypair,
|
||||
};
|
||||
Ok(Self { data: tls_data })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bytes(&self) -> TlsDataRaw {
|
||||
TlsDataRaw::serialize(&self.data)
|
||||
}
|
||||
}
|
||||
192
server/crates/arbiter-server/src/context/tls/mod.rs
Normal file
192
server/crates/arbiter-server/src/context/tls/mod.rs
Normal file
@@ -0,0 +1,192 @@
|
||||
use std::sync::Arc;
|
||||
use std::string::FromUtf8Error;
|
||||
|
||||
use miette::Diagnostic;
|
||||
use rcgen::{Certificate, KeyPair};
|
||||
use rustls::pki_types::CertificateDer;
|
||||
use thiserror::Error;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::db;
|
||||
|
||||
pub mod rotation;
|
||||
|
||||
pub use rotation::{RotationError, RotationState, RotationTask};
|
||||
|
||||
#[derive(Error, Debug, Diagnostic)]
|
||||
#[expect(clippy::enum_variant_names)]
|
||||
pub enum TlsInitError {
|
||||
#[error("Key generation error during TLS initialization: {0}")]
|
||||
#[diagnostic(code(arbiter_server::tls_init::key_generation))]
|
||||
KeyGeneration(#[from] rcgen::Error),
|
||||
|
||||
#[error("Key invalid format: {0}")]
|
||||
#[diagnostic(code(arbiter_server::tls_init::key_invalid_format))]
|
||||
KeyInvalidFormat(#[from] FromUtf8Error),
|
||||
|
||||
#[error("Key deserialization error: {0}")]
|
||||
#[diagnostic(code(arbiter_server::tls_init::key_deserialization))]
|
||||
KeyDeserializationError(rcgen::Error),
|
||||
}
|
||||
|
||||
pub struct TlsData {
|
||||
pub cert: CertificateDer<'static>,
|
||||
pub keypair: KeyPair,
|
||||
}
|
||||
|
||||
pub struct TlsDataRaw {
|
||||
pub cert: Vec<u8>,
|
||||
pub key: Vec<u8>,
|
||||
}
|
||||
impl TlsDataRaw {
|
||||
pub fn serialize(cert: &TlsData) -> Self {
|
||||
Self {
|
||||
cert: cert.cert.as_ref().to_vec(),
|
||||
key: cert.keypair.serialize_pem().as_bytes().to_vec(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deserialize(&self) -> Result<TlsData, TlsInitError> {
|
||||
let cert = CertificateDer::from_slice(&self.cert).into_owned();
|
||||
|
||||
let key =
|
||||
String::from_utf8(self.key.clone()).map_err(TlsInitError::KeyInvalidFormat)?;
|
||||
|
||||
let keypair = KeyPair::from_pem(&key).map_err(TlsInitError::KeyDeserializationError)?;
|
||||
|
||||
Ok(TlsData { cert, keypair })
|
||||
}
|
||||
}
|
||||
|
||||
/// Metadata about a certificate including validity period
|
||||
pub struct CertificateMetadata {
|
||||
pub cert_id: i32,
|
||||
pub cert: CertificateDer<'static>,
|
||||
pub keypair: Arc<KeyPair>,
|
||||
pub not_before: i64,
|
||||
pub not_after: i64,
|
||||
pub created_at: i64,
|
||||
}
|
||||
|
||||
pub(crate) fn generate_cert(key: &KeyPair) -> Result<(Certificate, i64, i64), rcgen::Error> {
|
||||
let params = rcgen::CertificateParams::new(vec![
|
||||
"arbiter.local".to_string(),
|
||||
"localhost".to_string(),
|
||||
])?;
|
||||
|
||||
// Set validity period: 90 days from now
|
||||
let not_before = chrono::Utc::now();
|
||||
let not_after = not_before + chrono::Duration::days(90);
|
||||
|
||||
// Note: rcgen doesn't directly expose not_before/not_after setting in all versions
|
||||
// For now, we'll generate the cert and track validity separately
|
||||
let cert = params.self_signed(key)?;
|
||||
|
||||
Ok((cert, not_before.timestamp(), not_after.timestamp()))
|
||||
}
|
||||
|
||||
// Certificate rotation enabled
|
||||
pub(crate) struct TlsManager {
|
||||
// Current active certificate (atomic replacement via RwLock)
|
||||
current_cert: Arc<RwLock<CertificateMetadata>>,
|
||||
|
||||
// Database pool for persistence
|
||||
db: db::DatabasePool,
|
||||
}
|
||||
|
||||
impl TlsManager {
|
||||
/// Create new TlsManager with a generated certificate
|
||||
pub async fn new(db: db::DatabasePool) -> Result<Self, TlsInitError> {
|
||||
let keypair = KeyPair::generate()?;
|
||||
let (cert, not_before, not_after) = generate_cert(&keypair)?;
|
||||
let cert_der = cert.der().clone();
|
||||
|
||||
// For initial creation, cert_id will be set after DB insert
|
||||
let metadata = CertificateMetadata {
|
||||
cert_id: 0, // Temporary, will be updated after DB insert
|
||||
cert: cert_der,
|
||||
keypair: Arc::new(keypair),
|
||||
not_before,
|
||||
not_after,
|
||||
created_at: chrono::Utc::now().timestamp(),
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
current_cert: Arc::new(RwLock::new(metadata)),
|
||||
db,
|
||||
})
|
||||
}
|
||||
|
||||
/// Load TlsManager from database with specific certificate ID
|
||||
pub async fn load_from_db(db: db::DatabasePool, cert_id: i32) -> Result<Self, TlsInitError> {
|
||||
// TODO: Load certificate from database
|
||||
// For now, return error - will be implemented when database access is ready
|
||||
Err(TlsInitError::KeyGeneration(rcgen::Error::CouldNotParseCertificate))
|
||||
}
|
||||
|
||||
/// Create from legacy TlsDataRaw format
|
||||
pub async fn new_from_legacy(
|
||||
db: db::DatabasePool,
|
||||
data: TlsDataRaw,
|
||||
not_before: i64,
|
||||
not_after: i64,
|
||||
) -> Result<Self, TlsInitError> {
|
||||
let tls_data = data.deserialize()?;
|
||||
|
||||
let metadata = CertificateMetadata {
|
||||
cert_id: 1, // Legacy certificate gets ID 1
|
||||
cert: tls_data.cert,
|
||||
keypair: Arc::new(tls_data.keypair),
|
||||
not_before,
|
||||
not_after,
|
||||
created_at: chrono::Utc::now().timestamp(),
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
current_cert: Arc::new(RwLock::new(metadata)),
|
||||
db,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get current certificate data
|
||||
pub async fn get_certificate(&self) -> (CertificateDer<'static>, Arc<KeyPair>) {
|
||||
let cert = self.current_cert.read().await;
|
||||
(cert.cert.clone(), cert.keypair.clone())
|
||||
}
|
||||
|
||||
/// Replace certificate atomically
|
||||
pub async fn replace_certificate(&self, new_cert: CertificateMetadata) -> Result<(), TlsInitError> {
|
||||
let mut cert = self.current_cert.write().await;
|
||||
*cert = new_cert;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if certificate is expiring soon
|
||||
pub async fn check_expiration(&self, threshold_secs: i64) -> bool {
|
||||
let cert = self.current_cert.read().await;
|
||||
let now = chrono::Utc::now().timestamp();
|
||||
cert.not_after - now < threshold_secs
|
||||
}
|
||||
|
||||
/// Get certificate metadata for rotation logic
|
||||
pub async fn get_certificate_metadata(&self) -> CertificateMetadata {
|
||||
let cert = self.current_cert.read().await;
|
||||
CertificateMetadata {
|
||||
cert_id: cert.cert_id,
|
||||
cert: cert.cert.clone(),
|
||||
keypair: cert.keypair.clone(),
|
||||
not_before: cert.not_before,
|
||||
not_after: cert.not_after,
|
||||
created_at: cert.created_at,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bytes(&self) -> TlsDataRaw {
|
||||
// This method is now async-compatible but we keep sync interface
|
||||
// TODO: Make this async or remove if not needed
|
||||
TlsDataRaw {
|
||||
cert: vec![],
|
||||
key: vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
552
server/crates/arbiter-server/src/context/tls/rotation.rs
Normal file
552
server/crates/arbiter-server/src/context/tls/rotation.rs
Normal file
@@ -0,0 +1,552 @@
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use diesel::prelude::*;
|
||||
use diesel_async::RunQueryDsl;
|
||||
use ed25519_dalek::VerifyingKey;
|
||||
use miette::Diagnostic;
|
||||
use rcgen::KeyPair;
|
||||
use thiserror::Error;
|
||||
use tokio::sync::watch;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::context::ServerContext;
|
||||
use crate::db::models::{NewRotationClientAck, NewTlsCertificate, NewTlsRotationHistory};
|
||||
use crate::db::schema::{rotation_client_acks, tls_certificates, tls_rotation_history, tls_rotation_state};
|
||||
use crate::db::DatabasePool;
|
||||
|
||||
use super::{generate_cert, CertificateMetadata, TlsInitError};
|
||||
|
||||
#[derive(Error, Debug, Diagnostic)]
|
||||
pub enum RotationError {
|
||||
#[error("Certificate generation failed: {0}")]
|
||||
#[diagnostic(code(arbiter_server::rotation::cert_generation))]
|
||||
CertGeneration(#[from] rcgen::Error),
|
||||
|
||||
#[error("Database error: {0}")]
|
||||
#[diagnostic(code(arbiter_server::rotation::database))]
|
||||
Database(#[from] diesel::result::Error),
|
||||
|
||||
#[error("TLS initialization error: {0}")]
|
||||
#[diagnostic(code(arbiter_server::rotation::tls_init))]
|
||||
TlsInit(#[from] TlsInitError),
|
||||
|
||||
#[error("Invalid rotation state: {0}")]
|
||||
#[diagnostic(code(arbiter_server::rotation::invalid_state))]
|
||||
InvalidState(String),
|
||||
|
||||
#[error("No active certificate found")]
|
||||
#[diagnostic(code(arbiter_server::rotation::no_active_cert))]
|
||||
NoActiveCertificate,
|
||||
}
|
||||
|
||||
/// Состояние процесса ротации сертификата
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum RotationState {
|
||||
/// Обычная работа, ротация не требуется
|
||||
Normal,
|
||||
|
||||
/// Ротация инициирована, новый сертификат сгенерирован
|
||||
RotationInitiated {
|
||||
initiated_at: i64,
|
||||
new_cert_id: i32,
|
||||
},
|
||||
|
||||
/// Ожидание подтверждений (ACKs) от клиентов
|
||||
WaitingForAcks {
|
||||
new_cert_id: i32,
|
||||
initiated_at: i64,
|
||||
timeout_at: i64,
|
||||
},
|
||||
|
||||
/// Все ACK получены или таймаут истёк, готов к ротации
|
||||
ReadyToRotate {
|
||||
new_cert_id: i32,
|
||||
},
|
||||
}
|
||||
|
||||
impl RotationState {
|
||||
/// Загрузить состояние из базы данных
|
||||
pub async fn load_from_db(db: &DatabasePool) -> Result<Self, RotationError> {
|
||||
use crate::db::schema::tls_rotation_state::dsl::*;
|
||||
|
||||
let mut conn = db.get().await.map_err(|e| {
|
||||
RotationError::InvalidState(format!("Failed to get DB connection: {}", e))
|
||||
})?;
|
||||
|
||||
let state_record: (i32, String, Option<i32>, Option<i32>, Option<i32>) =
|
||||
tls_rotation_state
|
||||
.select((id, state, new_cert_id, initiated_at, timeout_at))
|
||||
.filter(id.eq(1))
|
||||
.first(&mut conn)
|
||||
.await?;
|
||||
|
||||
let rotation_state = match state_record.1.as_str() {
|
||||
"normal" => RotationState::Normal,
|
||||
"initiated" => {
|
||||
let cert_id = state_record.2.ok_or_else(|| {
|
||||
RotationError::InvalidState("Initiated state missing new_cert_id".into())
|
||||
})?;
|
||||
let init_at = state_record.3.ok_or_else(|| {
|
||||
RotationError::InvalidState("Initiated state missing initiated_at".into())
|
||||
})?;
|
||||
RotationState::RotationInitiated {
|
||||
initiated_at: init_at as i64,
|
||||
new_cert_id: cert_id,
|
||||
}
|
||||
}
|
||||
"waiting_acks" => {
|
||||
let cert_id = state_record.2.ok_or_else(|| {
|
||||
RotationError::InvalidState("WaitingForAcks state missing new_cert_id".into())
|
||||
})?;
|
||||
let init_at = state_record.3.ok_or_else(|| {
|
||||
RotationError::InvalidState("WaitingForAcks state missing initiated_at".into())
|
||||
})?;
|
||||
let timeout = state_record.4.ok_or_else(|| {
|
||||
RotationError::InvalidState("WaitingForAcks state missing timeout_at".into())
|
||||
})?;
|
||||
RotationState::WaitingForAcks {
|
||||
new_cert_id: cert_id,
|
||||
initiated_at: init_at as i64,
|
||||
timeout_at: timeout as i64,
|
||||
}
|
||||
}
|
||||
"ready" => {
|
||||
let cert_id = state_record.2.ok_or_else(|| {
|
||||
RotationError::InvalidState("Ready state missing new_cert_id".into())
|
||||
})?;
|
||||
RotationState::ReadyToRotate {
|
||||
new_cert_id: cert_id,
|
||||
}
|
||||
}
|
||||
other => {
|
||||
return Err(RotationError::InvalidState(format!(
|
||||
"Unknown state: {}",
|
||||
other
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
Ok(rotation_state)
|
||||
}
|
||||
|
||||
/// Сохранить состояние в базу данных
|
||||
pub async fn save_to_db(&self, db: &DatabasePool) -> Result<(), RotationError> {
|
||||
use crate::db::schema::tls_rotation_state::dsl::*;
|
||||
|
||||
let mut conn = db.get().await.map_err(|e| {
|
||||
RotationError::InvalidState(format!("Failed to get DB connection: {}", e))
|
||||
})?;
|
||||
|
||||
let (state_str, cert_id, init_at, timeout) = match self {
|
||||
RotationState::Normal => ("normal", None, None, None),
|
||||
RotationState::RotationInitiated {
|
||||
initiated_at: init,
|
||||
new_cert_id: cert,
|
||||
} => ("initiated", Some(*cert), Some(*init as i32), None),
|
||||
RotationState::WaitingForAcks {
|
||||
new_cert_id: cert,
|
||||
initiated_at: init,
|
||||
timeout_at: timeout_val,
|
||||
} => (
|
||||
"waiting_acks",
|
||||
Some(*cert),
|
||||
Some(*init as i32),
|
||||
Some(*timeout_val as i32),
|
||||
),
|
||||
RotationState::ReadyToRotate { new_cert_id: cert } => ("ready", Some(*cert), None, None),
|
||||
};
|
||||
|
||||
diesel::update(tls_rotation_state.filter(id.eq(1)))
|
||||
.set((
|
||||
state.eq(state_str),
|
||||
new_cert_id.eq(cert_id),
|
||||
initiated_at.eq(init_at),
|
||||
timeout_at.eq(timeout),
|
||||
))
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Фоновый таск для автоматической ротации сертификатов
|
||||
pub struct RotationTask {
|
||||
context: Arc<crate::context::_ServerContextInner>,
|
||||
check_interval: Duration,
|
||||
rotation_threshold: Duration,
|
||||
ack_timeout: Duration,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl RotationTask {
|
||||
/// Создать новый rotation task
|
||||
pub fn new(
|
||||
context: Arc<crate::context::_ServerContextInner>,
|
||||
check_interval: Duration,
|
||||
rotation_threshold: Duration,
|
||||
ack_timeout: Duration,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
) -> Self {
|
||||
Self {
|
||||
context,
|
||||
check_interval,
|
||||
rotation_threshold,
|
||||
ack_timeout,
|
||||
shutdown_rx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Запустить фоновый таск мониторинга и ротации
|
||||
pub async fn run(mut self) -> Result<(), RotationError> {
|
||||
info!("Starting TLS certificate rotation task");
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = tokio::time::sleep(self.check_interval) => {
|
||||
if let Err(e) = self.check_and_process().await {
|
||||
error!("Rotation task error: {}", e);
|
||||
}
|
||||
}
|
||||
_ = self.shutdown_rx.changed() => {
|
||||
info!("Rotation task shutting down");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Проверить текущее состояние и выполнить необходимые действия
|
||||
async fn check_and_process(&self) -> Result<(), RotationError> {
|
||||
let state = self.context.rotation_state.read().await.clone();
|
||||
|
||||
match state {
|
||||
RotationState::Normal => {
|
||||
// Проверить, нужна ли ротация
|
||||
self.check_expiration_and_initiate().await?;
|
||||
}
|
||||
RotationState::RotationInitiated { new_cert_id, .. } => {
|
||||
// Автоматически перейти в WaitingForAcks
|
||||
self.transition_to_waiting_acks(new_cert_id).await?;
|
||||
}
|
||||
RotationState::WaitingForAcks {
|
||||
new_cert_id,
|
||||
timeout_at,
|
||||
..
|
||||
} => {
|
||||
self.handle_waiting_for_acks(new_cert_id, timeout_at).await?;
|
||||
}
|
||||
RotationState::ReadyToRotate { new_cert_id } => {
|
||||
self.execute_rotation(new_cert_id).await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Проверить срок действия сертификата и инициировать ротацию если нужно
|
||||
async fn check_expiration_and_initiate(&self) -> Result<(), RotationError> {
|
||||
let threshold_secs = self.rotation_threshold.as_secs() as i64;
|
||||
|
||||
if self.context.tls.check_expiration(threshold_secs).await {
|
||||
info!("Certificate expiring soon, initiating rotation");
|
||||
self.initiate_rotation().await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Инициировать ротацию: сгенерировать новый сертификат и сохранить в БД
|
||||
pub async fn initiate_rotation(&self) -> Result<i32, RotationError> {
|
||||
info!("Initiating certificate rotation");
|
||||
|
||||
// 1. Генерация нового сертификата
|
||||
let keypair = KeyPair::generate()?;
|
||||
let (cert, not_before, not_after) = generate_cert(&keypair)?;
|
||||
let cert_der = cert.der().clone();
|
||||
|
||||
// 2. Сохранение в БД (is_active = false, пока не активирован)
|
||||
let new_cert_id = self
|
||||
.save_new_certificate(&cert_der, &keypair, not_before, not_after)
|
||||
.await?;
|
||||
|
||||
info!(new_cert_id, "New certificate generated and saved");
|
||||
|
||||
// 3. Обновление rotation_state
|
||||
let new_state = RotationState::RotationInitiated {
|
||||
initiated_at: chrono::Utc::now().timestamp(),
|
||||
new_cert_id,
|
||||
};
|
||||
*self.context.rotation_state.write().await = new_state.clone();
|
||||
new_state.save_to_db(&self.context.db).await?;
|
||||
|
||||
// 4. Логирование в audit trail
|
||||
self.log_rotation_event(new_cert_id, "rotation_initiated", None)
|
||||
.await?;
|
||||
|
||||
Ok(new_cert_id)
|
||||
}
|
||||
|
||||
/// Перейти в состояние WaitingForAcks и разослать уведомления
|
||||
async fn transition_to_waiting_acks(&self, new_cert_id: i32) -> Result<(), RotationError> {
|
||||
info!(new_cert_id, "Transitioning to WaitingForAcks state");
|
||||
|
||||
let initiated_at = chrono::Utc::now().timestamp();
|
||||
let timeout_at = initiated_at + self.ack_timeout.as_secs() as i64;
|
||||
|
||||
// Обновить состояние
|
||||
let new_state = RotationState::WaitingForAcks {
|
||||
new_cert_id,
|
||||
initiated_at,
|
||||
timeout_at,
|
||||
};
|
||||
*self.context.rotation_state.write().await = new_state.clone();
|
||||
new_state.save_to_db(&self.context.db).await?;
|
||||
|
||||
// TODO: Broadcast уведомлений клиентам
|
||||
// self.broadcast_rotation_notification(new_cert_id, timeout_at).await?;
|
||||
|
||||
info!(timeout_at, "Rotation notifications sent, waiting for ACKs");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Обработка состояния WaitingForAcks: проверка ACKs и таймаута
|
||||
async fn handle_waiting_for_acks(
|
||||
&self,
|
||||
new_cert_id: i32,
|
||||
timeout_at: i64,
|
||||
) -> Result<(), RotationError> {
|
||||
let now = chrono::Utc::now().timestamp();
|
||||
|
||||
// Проверить таймаут
|
||||
if now > timeout_at {
|
||||
let missing = self.get_missing_acks(new_cert_id).await?;
|
||||
warn!(
|
||||
missing_count = missing.len(),
|
||||
"Rotation ACK timeout reached, proceeding with rotation"
|
||||
);
|
||||
|
||||
// Переход в ReadyToRotate
|
||||
let new_state = RotationState::ReadyToRotate { new_cert_id };
|
||||
*self.context.rotation_state.write().await = new_state.clone();
|
||||
new_state.save_to_db(&self.context.db).await?;
|
||||
|
||||
self.log_rotation_event(
|
||||
new_cert_id,
|
||||
"timeout",
|
||||
Some(format!("Missing ACKs from {} clients", missing.len())),
|
||||
)
|
||||
.await?;
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Проверить, все ли ACK получены
|
||||
let missing = self.get_missing_acks(new_cert_id).await?;
|
||||
|
||||
if missing.is_empty() {
|
||||
info!("All clients acknowledged, ready to rotate");
|
||||
|
||||
let new_state = RotationState::ReadyToRotate { new_cert_id };
|
||||
*self.context.rotation_state.write().await = new_state.clone();
|
||||
new_state.save_to_db(&self.context.db).await?;
|
||||
|
||||
self.log_rotation_event(new_cert_id, "acks_complete", None)
|
||||
.await?;
|
||||
} else {
|
||||
let time_remaining = timeout_at - now;
|
||||
debug!(
|
||||
missing_count = missing.len(),
|
||||
time_remaining,
|
||||
"Waiting for rotation ACKs"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Выполнить атомарную ротацию сертификата
|
||||
async fn execute_rotation(&self, new_cert_id: i32) -> Result<(), RotationError> {
|
||||
info!(new_cert_id, "Executing certificate rotation");
|
||||
|
||||
// 1. Загрузить новый сертификат из БД
|
||||
let new_cert = self.load_certificate(new_cert_id).await?;
|
||||
|
||||
// 2. Атомарная замена в TlsManager
|
||||
self.context
|
||||
.tls
|
||||
.replace_certificate(new_cert)
|
||||
.await
|
||||
.map_err(RotationError::TlsInit)?;
|
||||
|
||||
// 3. Обновить БД: старый is_active=false, новый is_active=true
|
||||
self.activate_certificate(new_cert_id).await?;
|
||||
|
||||
// 4. TODO: Отключить всех клиентов
|
||||
// self.disconnect_all_clients().await?;
|
||||
|
||||
// 5. Очистить rotation_state
|
||||
let new_state = RotationState::Normal;
|
||||
*self.context.rotation_state.write().await = new_state.clone();
|
||||
new_state.save_to_db(&self.context.db).await?;
|
||||
|
||||
// 6. Очистить ACKs
|
||||
self.context.rotation_acks.write().await.clear();
|
||||
self.clear_rotation_acks(new_cert_id).await?;
|
||||
|
||||
// 7. Логирование
|
||||
self.log_rotation_event(new_cert_id, "activated", None)
|
||||
.await?;
|
||||
|
||||
info!(new_cert_id, "Certificate rotation completed successfully");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Сохранить новый сертификат в БД
|
||||
async fn save_new_certificate(
|
||||
&self,
|
||||
cert_der: &[u8],
|
||||
keypair: &KeyPair,
|
||||
cert_not_before: i64,
|
||||
cert_not_after: i64,
|
||||
) -> Result<i32, RotationError> {
|
||||
use crate::db::schema::tls_certificates::dsl::*;
|
||||
|
||||
let mut conn = self.context.db.get().await.map_err(|e| {
|
||||
RotationError::InvalidState(format!("Failed to get DB connection: {}", e))
|
||||
})?;
|
||||
|
||||
let new_cert = NewTlsCertificate {
|
||||
cert: cert_der.to_vec(),
|
||||
cert_key: keypair.serialize_pem().as_bytes().to_vec(),
|
||||
not_before: cert_not_before as i32,
|
||||
not_after: cert_not_after as i32,
|
||||
is_active: false,
|
||||
};
|
||||
|
||||
diesel::insert_into(tls_certificates)
|
||||
.values(&new_cert)
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
|
||||
// Получить ID последней вставленной записи
|
||||
let cert_id: i32 = diesel::select(diesel::dsl::sql::<diesel::sql_types::Integer>(
|
||||
"last_insert_rowid()",
|
||||
))
|
||||
.first(&mut conn)
|
||||
.await?;
|
||||
|
||||
self.log_rotation_event(cert_id, "created", None).await?;
|
||||
|
||||
Ok(cert_id)
|
||||
}
|
||||
|
||||
/// Загрузить сертификат из БД
|
||||
async fn load_certificate(&self, cert_id: i32) -> Result<CertificateMetadata, RotationError> {
|
||||
use crate::db::schema::tls_certificates::dsl::*;
|
||||
|
||||
let mut conn = self.context.db.get().await.map_err(|e| {
|
||||
RotationError::InvalidState(format!("Failed to get DB connection: {}", e))
|
||||
})?;
|
||||
|
||||
let cert_record: (Vec<u8>, Vec<u8>, i32, i32, i32) = tls_certificates
|
||||
.select((cert, cert_key, not_before, not_after, created_at))
|
||||
.filter(id.eq(cert_id))
|
||||
.first(&mut conn)
|
||||
.await?;
|
||||
|
||||
let cert_der = rustls::pki_types::CertificateDer::from(cert_record.0);
|
||||
let key_pem = String::from_utf8(cert_record.1)
|
||||
.map_err(|e| RotationError::InvalidState(format!("Invalid key encoding: {}", e)))?;
|
||||
let keypair = KeyPair::from_pem(&key_pem)?;
|
||||
|
||||
Ok(CertificateMetadata {
|
||||
cert_id,
|
||||
cert: cert_der,
|
||||
keypair: Arc::new(keypair),
|
||||
not_before: cert_record.2 as i64,
|
||||
not_after: cert_record.3 as i64,
|
||||
created_at: cert_record.4 as i64,
|
||||
})
|
||||
}
|
||||
|
||||
/// Активировать сертификат (установить is_active=true)
|
||||
async fn activate_certificate(&self, cert_id: i32) -> Result<(), RotationError> {
|
||||
use crate::db::schema::tls_certificates::dsl::*;
|
||||
|
||||
let mut conn = self.context.db.get().await.map_err(|e| {
|
||||
RotationError::InvalidState(format!("Failed to get DB connection: {}", e))
|
||||
})?;
|
||||
|
||||
// Деактивировать все сертификаты
|
||||
diesel::update(tls_certificates)
|
||||
.set(is_active.eq(false))
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
|
||||
// Активировать новый
|
||||
diesel::update(tls_certificates.filter(id.eq(cert_id)))
|
||||
.set(is_active.eq(true))
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Получить список клиентов, которые ещё не отправили ACK
|
||||
async fn get_missing_acks(&self, rotation_id: i32) -> Result<Vec<VerifyingKey>, RotationError> {
|
||||
// TODO: Реализовать получение списка всех активных клиентов
|
||||
// и вычитание тех, кто уже отправил ACK
|
||||
|
||||
// Пока возвращаем пустой список
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
/// Очистить ACKs для данной ротации из БД
|
||||
async fn clear_rotation_acks(&self, rotation_id: i32) -> Result<(), RotationError> {
|
||||
use crate::db::schema::rotation_client_acks::dsl::*;
|
||||
|
||||
let mut conn = self.context.db.get().await.map_err(|e| {
|
||||
RotationError::InvalidState(format!("Failed to get DB connection: {}", e))
|
||||
})?;
|
||||
|
||||
diesel::delete(rotation_client_acks.filter(rotation_id.eq(rotation_id)))
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Записать событие в audit trail
|
||||
async fn log_rotation_event(
|
||||
&self,
|
||||
history_cert_id: i32,
|
||||
history_event_type: &str,
|
||||
history_details: Option<String>,
|
||||
) -> Result<(), RotationError> {
|
||||
use crate::db::schema::tls_rotation_history::dsl::*;
|
||||
|
||||
let mut conn = self.context.db.get().await.map_err(|e| {
|
||||
RotationError::InvalidState(format!("Failed to get DB connection: {}", e))
|
||||
})?;
|
||||
|
||||
let new_history = NewTlsRotationHistory {
|
||||
cert_id: history_cert_id,
|
||||
event_type: history_event_type.to_string(),
|
||||
details: history_details,
|
||||
};
|
||||
|
||||
diesel::insert_into(tls_rotation_history)
|
||||
.values(&new_history)
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
161
server/crates/arbiter-server/src/context/unseal.rs
Normal file
161
server/crates/arbiter-server/src/context/unseal.rs
Normal file
@@ -0,0 +1,161 @@
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use miette::Diagnostic;
|
||||
use secrecy::{ExposeSecret, SecretBox};
|
||||
use thiserror::Error;
|
||||
use x25519_dalek::{PublicKey, StaticSecret};
|
||||
|
||||
const EPHEMERAL_KEY_LIFETIME_SECS: u64 = 60;
|
||||
|
||||
#[derive(Error, Debug, Diagnostic)]
|
||||
pub enum UnsealError {
|
||||
#[error("Invalid public key")]
|
||||
#[diagnostic(code(arbiter_server::unseal::invalid_pubkey))]
|
||||
InvalidPublicKey,
|
||||
}
|
||||
|
||||
/// Ephemeral X25519 keypair for secure password transmission
|
||||
///
|
||||
/// Generated on-demand when client requests unseal. Expires after 60 seconds.
|
||||
/// Uses StaticSecret stored in SecretBox for automatic zeroization on drop.
|
||||
pub struct EphemeralKeyPair {
|
||||
/// Secret key stored securely
|
||||
secret: SecretBox<StaticSecret>,
|
||||
public: PublicKey,
|
||||
expires_at: u64,
|
||||
}
|
||||
|
||||
impl EphemeralKeyPair {
|
||||
/// Generate new ephemeral X25519 keypair
|
||||
pub fn generate() -> Self {
|
||||
// Generate random 32 bytes
|
||||
let secret_bytes = rand::random::<[u8; 32]>();
|
||||
let secret = StaticSecret::from(secret_bytes);
|
||||
let public = PublicKey::from(&secret);
|
||||
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("System time before UNIX epoch")
|
||||
.as_secs();
|
||||
|
||||
Self {
|
||||
secret: SecretBox::new(Box::new(secret)),
|
||||
public,
|
||||
expires_at: now + EPHEMERAL_KEY_LIFETIME_SECS,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this ephemeral key has expired
|
||||
pub fn is_expired(&self) -> bool {
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("System time before UNIX epoch")
|
||||
.as_secs();
|
||||
|
||||
now > self.expires_at
|
||||
}
|
||||
|
||||
/// Get expiration timestamp (Unix epoch seconds)
|
||||
pub fn expires_at(&self) -> u64 {
|
||||
self.expires_at
|
||||
}
|
||||
|
||||
/// Get public key as bytes for transmission to client
|
||||
pub fn public_bytes(&self) -> Vec<u8> {
|
||||
self.public.as_bytes().to_vec()
|
||||
}
|
||||
|
||||
/// Perform Diffie-Hellman key exchange with client's public key
|
||||
///
|
||||
/// Returns 32-byte shared secret for ChaCha20Poly1305 encryption
|
||||
pub fn perform_dh(&self, client_pubkey: &[u8]) -> Result<[u8; 32], UnsealError> {
|
||||
// Parse client public key
|
||||
let client_public = PublicKey::from(
|
||||
<[u8; 32]>::try_from(client_pubkey).map_err(|_| UnsealError::InvalidPublicKey)?,
|
||||
);
|
||||
|
||||
// Perform ECDH
|
||||
let shared_secret = self.secret.expose_secret().diffie_hellman(&client_public);
|
||||
|
||||
Ok(shared_secret.to_bytes())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_ephemeral_keypair_generation() {
|
||||
let keypair = EphemeralKeyPair::generate();
|
||||
|
||||
// Public key should be 32 bytes
|
||||
assert_eq!(keypair.public_bytes().len(), 32);
|
||||
|
||||
// Should not be expired immediately
|
||||
assert!(!keypair.is_expired());
|
||||
|
||||
// Expiration should be ~60 seconds in future
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
let time_until_expiry = keypair.expires_at() - now;
|
||||
assert!((59..=61).contains(&time_until_expiry));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_perform_dh_with_valid_key() {
|
||||
let server_keypair = EphemeralKeyPair::generate();
|
||||
let client_secret_bytes = rand::random::<[u8; 32]>();
|
||||
let client_secret = StaticSecret::from(client_secret_bytes);
|
||||
let client_public = PublicKey::from(&client_secret);
|
||||
|
||||
// Server performs DH
|
||||
let server_shared_secret = server_keypair
|
||||
.perform_dh(client_public.as_bytes())
|
||||
.expect("DH should succeed");
|
||||
|
||||
// Client performs DH
|
||||
let client_shared_secret = client_secret.diffie_hellman(&server_keypair.public);
|
||||
|
||||
// Shared secrets should match
|
||||
assert_eq!(server_shared_secret, client_shared_secret.to_bytes());
|
||||
assert_eq!(server_shared_secret.len(), 32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_perform_dh_with_invalid_key() {
|
||||
let keypair = EphemeralKeyPair::generate();
|
||||
|
||||
// Try with invalid length
|
||||
let invalid_key = vec![1, 2, 3];
|
||||
let result = keypair.perform_dh(&invalid_key);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Try with wrong length (not 32 bytes)
|
||||
let invalid_key = vec![0u8; 16];
|
||||
let result = keypair.perform_dh(&invalid_key);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_different_keypairs_produce_different_shared_secrets() {
|
||||
let server_keypair1 = EphemeralKeyPair::generate();
|
||||
let server_keypair2 = EphemeralKeyPair::generate();
|
||||
|
||||
let client_secret_bytes = rand::random::<[u8; 32]>();
|
||||
let client_secret = StaticSecret::from(client_secret_bytes);
|
||||
let client_public = PublicKey::from(&client_secret);
|
||||
|
||||
let shared1 = server_keypair1
|
||||
.perform_dh(client_public.as_bytes())
|
||||
.unwrap();
|
||||
let shared2 = server_keypair2
|
||||
.perform_dh(client_public.as_bytes())
|
||||
.unwrap();
|
||||
|
||||
// Different server keys should produce different shared secrets
|
||||
assert_ne!(shared1, shared2);
|
||||
}
|
||||
}
|
||||
139
server/crates/arbiter-server/src/crypto/aead.rs
Normal file
139
server/crates/arbiter-server/src/crypto/aead.rs
Normal file
@@ -0,0 +1,139 @@
|
||||
use chacha20poly1305::{
|
||||
aead::{Aead, KeyInit},
|
||||
ChaCha20Poly1305, Key, Nonce,
|
||||
};
|
||||
|
||||
use super::CryptoError;
|
||||
|
||||
/// Encrypt plaintext with AEAD (ChaCha20Poly1305)
|
||||
///
|
||||
/// Returns (ciphertext, tag) on success
|
||||
pub fn encrypt(
|
||||
plaintext: &[u8],
|
||||
key: &[u8; 32],
|
||||
nonce: &[u8; 12],
|
||||
) -> Result<Vec<u8>, CryptoError> {
|
||||
let cipher_key = Key::from_slice(key);
|
||||
let cipher = ChaCha20Poly1305::new(cipher_key);
|
||||
let nonce_array = Nonce::from_slice(nonce);
|
||||
|
||||
cipher
|
||||
.encrypt(nonce_array, plaintext)
|
||||
.map_err(|e| CryptoError::AeadEncryption(e.to_string()))
|
||||
}
|
||||
|
||||
/// Decrypt ciphertext with AEAD (ChaCha20Poly1305)
|
||||
///
|
||||
/// The ciphertext должен содержать tag (последние 16 bytes)
|
||||
pub fn decrypt(
|
||||
ciphertext_with_tag: &[u8],
|
||||
key: &[u8; 32],
|
||||
nonce: &[u8; 12],
|
||||
) -> Result<Vec<u8>, CryptoError> {
|
||||
let cipher_key = Key::from_slice(key);
|
||||
let cipher = ChaCha20Poly1305::new(cipher_key);
|
||||
let nonce_array = Nonce::from_slice(nonce);
|
||||
|
||||
cipher
|
||||
.decrypt(nonce_array, ciphertext_with_tag)
|
||||
.map_err(|e| CryptoError::AeadDecryption(e.to_string()))
|
||||
}
|
||||
|
||||
/// Generate nonce from counter
|
||||
///
|
||||
/// Converts i32 counter to 12-byte nonce (big-endian encoding)
|
||||
pub fn nonce_from_counter(counter: i32) -> [u8; 12] {
|
||||
let mut nonce = [0u8; 12];
|
||||
nonce[8..12].copy_from_slice(&counter.to_be_bytes());
|
||||
nonce
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_aead_encrypt_decrypt_round_trip() {
|
||||
let plaintext = b"Hello, World! This is a secret message.";
|
||||
let key = [42u8; 32];
|
||||
let nonce = nonce_from_counter(1);
|
||||
|
||||
// Encrypt
|
||||
let ciphertext = encrypt(plaintext, &key, &nonce).expect("Encryption failed");
|
||||
|
||||
// Verify ciphertext is different from plaintext
|
||||
assert_ne!(ciphertext.as_slice(), plaintext);
|
||||
|
||||
// Decrypt
|
||||
let decrypted = decrypt(&ciphertext, &key, &nonce).expect("Decryption failed");
|
||||
|
||||
// Verify round-trip
|
||||
assert_eq!(decrypted.as_slice(), plaintext);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aead_decrypt_with_wrong_key() {
|
||||
let plaintext = b"Secret data";
|
||||
let key = [1u8; 32];
|
||||
let wrong_key = [2u8; 32];
|
||||
let nonce = nonce_from_counter(1);
|
||||
|
||||
let ciphertext = encrypt(plaintext, &key, &nonce).expect("Encryption failed");
|
||||
|
||||
// Attempt decrypt with wrong key
|
||||
let result = decrypt(&ciphertext, &wrong_key, &nonce);
|
||||
|
||||
// Should fail
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aead_decrypt_with_wrong_nonce() {
|
||||
let plaintext = b"Secret data";
|
||||
let key = [1u8; 32];
|
||||
let nonce = nonce_from_counter(1);
|
||||
let wrong_nonce = nonce_from_counter(2);
|
||||
|
||||
let ciphertext = encrypt(plaintext, &key, &nonce).expect("Encryption failed");
|
||||
|
||||
// Attempt decrypt with wrong nonce
|
||||
let result = decrypt(&ciphertext, &key, &wrong_nonce);
|
||||
|
||||
// Should fail
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nonce_generation_from_counter() {
|
||||
let nonce1 = nonce_from_counter(1);
|
||||
let nonce2 = nonce_from_counter(2);
|
||||
let nonce_max = nonce_from_counter(i32::MAX);
|
||||
|
||||
// Verify nonces are different
|
||||
assert_ne!(nonce1, nonce2);
|
||||
|
||||
// Verify nonce format (first 8 bytes should be zero, last 4 contain counter)
|
||||
assert_eq!(&nonce1[0..8], &[0u8; 8]);
|
||||
assert_eq!(&nonce1[8..12], &1i32.to_be_bytes());
|
||||
|
||||
assert_eq!(&nonce_max[8..12], &i32::MAX.to_be_bytes());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_aead_tampered_ciphertext() {
|
||||
let plaintext = b"Important message";
|
||||
let key = [7u8; 32];
|
||||
let nonce = nonce_from_counter(5);
|
||||
|
||||
let mut ciphertext = encrypt(plaintext, &key, &nonce).expect("Encryption failed");
|
||||
|
||||
// Tamper with ciphertext (flip a bit)
|
||||
if let Some(byte) = ciphertext.get_mut(5) {
|
||||
*byte ^= 0x01;
|
||||
}
|
||||
|
||||
// Attempt decrypt - should fail due to authentication tag mismatch
|
||||
let result = decrypt(&ciphertext, &key, &nonce);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
28
server/crates/arbiter-server/src/crypto/mod.rs
Normal file
28
server/crates/arbiter-server/src/crypto/mod.rs
Normal file
@@ -0,0 +1,28 @@
|
||||
pub mod aead;
|
||||
pub mod root_key;
|
||||
|
||||
use miette::Diagnostic;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug, Diagnostic)]
|
||||
pub enum CryptoError {
|
||||
#[error("AEAD encryption failed: {0}")]
|
||||
#[diagnostic(code(arbiter_server::crypto::aead_encryption))]
|
||||
AeadEncryption(String),
|
||||
|
||||
#[error("AEAD decryption failed: {0}")]
|
||||
#[diagnostic(code(arbiter_server::crypto::aead_decryption))]
|
||||
AeadDecryption(String),
|
||||
|
||||
#[error("Key derivation failed: {0}")]
|
||||
#[diagnostic(code(arbiter_server::crypto::key_derivation))]
|
||||
KeyDerivation(String),
|
||||
|
||||
#[error("Invalid nonce: {0}")]
|
||||
#[diagnostic(code(arbiter_server::crypto::invalid_nonce))]
|
||||
InvalidNonce(String),
|
||||
|
||||
#[error("Invalid key format: {0}")]
|
||||
#[diagnostic(code(arbiter_server::crypto::invalid_key))]
|
||||
InvalidKey(String),
|
||||
}
|
||||
240
server/crates/arbiter-server/src/crypto/root_key.rs
Normal file
240
server/crates/arbiter-server/src/crypto/root_key.rs
Normal file
@@ -0,0 +1,240 @@
|
||||
use argon2::{
|
||||
password_hash::{rand_core::OsRng, PasswordHasher, SaltString},
|
||||
Argon2, PasswordHash, PasswordVerifier,
|
||||
};
|
||||
|
||||
use crate::db::models::AeadEncrypted;
|
||||
|
||||
use super::{aead, CryptoError};
|
||||
|
||||
/// Encrypt root key with user password
|
||||
///
|
||||
/// Uses Argon2id for password derivation and ChaCha20Poly1305 for encryption
|
||||
pub fn encrypt_root_key(
|
||||
root_key: &[u8; 32],
|
||||
password: &str,
|
||||
nonce_counter: i32,
|
||||
) -> Result<(AeadEncrypted, String), CryptoError> {
|
||||
// Derive key from password using Argon2
|
||||
let (derived_key, salt) = derive_key_from_password(password)?;
|
||||
|
||||
// Generate nonce from counter
|
||||
let nonce = aead::nonce_from_counter(nonce_counter);
|
||||
|
||||
// Encrypt root key
|
||||
let ciphertext_with_tag = aead::encrypt(root_key, &derived_key, &nonce)?;
|
||||
|
||||
// Extract tag (last 16 bytes)
|
||||
let tag_start = ciphertext_with_tag
|
||||
.len()
|
||||
.checked_sub(16)
|
||||
.ok_or_else(|| CryptoError::AeadEncryption("Ciphertext too short".into()))?;
|
||||
|
||||
let ciphertext = ciphertext_with_tag[..tag_start].to_vec();
|
||||
let tag = ciphertext_with_tag[tag_start..].to_vec();
|
||||
|
||||
let aead_encrypted = AeadEncrypted {
|
||||
id: 1, // Will be set by database
|
||||
current_nonce: nonce_counter,
|
||||
ciphertext,
|
||||
tag,
|
||||
schema_version: 1, // Current version
|
||||
argon2_salt: Some(salt.clone()),
|
||||
};
|
||||
|
||||
Ok((aead_encrypted, salt))
|
||||
}
|
||||
|
||||
/// Decrypt root key with user password
|
||||
///
|
||||
/// Verifies password hash and decrypts using ChaCha20Poly1305
|
||||
pub fn decrypt_root_key(
|
||||
encrypted: &AeadEncrypted,
|
||||
password: &str,
|
||||
salt: &str,
|
||||
) -> Result<[u8; 32], CryptoError> {
|
||||
// Derive key from password using stored salt
|
||||
let derived_key = derive_key_with_salt(password, salt)?;
|
||||
|
||||
// Generate nonce from counter
|
||||
let nonce = aead::nonce_from_counter(encrypted.current_nonce);
|
||||
|
||||
// Reconstruct ciphertext with tag
|
||||
let mut ciphertext_with_tag = encrypted.ciphertext.clone();
|
||||
ciphertext_with_tag.extend_from_slice(&encrypted.tag);
|
||||
|
||||
// Decrypt
|
||||
let plaintext = aead::decrypt(&ciphertext_with_tag, &derived_key, &nonce)?;
|
||||
|
||||
// Verify length
|
||||
if plaintext.len() != 32 {
|
||||
return Err(CryptoError::InvalidKey(format!(
|
||||
"Expected 32 bytes, got {}",
|
||||
plaintext.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// Convert to fixed-size array
|
||||
let mut root_key = [0u8; 32];
|
||||
root_key.copy_from_slice(&plaintext);
|
||||
|
||||
Ok(root_key)
|
||||
}
|
||||
|
||||
/// Derive 32-byte key from password using Argon2id
|
||||
///
|
||||
/// Generates new random salt and returns (derived_key, salt_string)
|
||||
fn derive_key_from_password(password: &str) -> Result<([u8; 32], String), CryptoError> {
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
|
||||
let argon2 = Argon2::default();
|
||||
|
||||
let password_hash = argon2
|
||||
.hash_password(password.as_bytes(), &salt)
|
||||
.map_err(|e| CryptoError::KeyDerivation(e.to_string()))?;
|
||||
|
||||
// Extract hash output (32 bytes)
|
||||
let hash_output = password_hash
|
||||
.hash
|
||||
.ok_or_else(|| CryptoError::KeyDerivation("No hash output".into()))?;
|
||||
|
||||
let hash_bytes = hash_output.as_bytes();
|
||||
|
||||
if hash_bytes.len() != 32 {
|
||||
return Err(CryptoError::KeyDerivation(format!(
|
||||
"Expected 32 bytes, got {}",
|
||||
hash_bytes.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let mut key = [0u8; 32];
|
||||
key.copy_from_slice(hash_bytes);
|
||||
|
||||
Ok((key, salt.to_string()))
|
||||
}
|
||||
|
||||
/// Derive 32-byte key from password using existing salt
|
||||
fn derive_key_with_salt(password: &str, salt_str: &str) -> Result<[u8; 32], CryptoError> {
|
||||
let argon2 = Argon2::default();
|
||||
|
||||
// Parse salt
|
||||
let salt =
|
||||
SaltString::from_b64(salt_str).map_err(|e| CryptoError::InvalidKey(e.to_string()))?;
|
||||
|
||||
let password_hash = argon2
|
||||
.hash_password(password.as_bytes(), &salt)
|
||||
.map_err(|e| CryptoError::KeyDerivation(e.to_string()))?;
|
||||
|
||||
// Extract hash output
|
||||
let hash_output = password_hash
|
||||
.hash
|
||||
.ok_or_else(|| CryptoError::KeyDerivation("No hash output".into()))?;
|
||||
|
||||
let hash_bytes = hash_output.as_bytes();
|
||||
|
||||
if hash_bytes.len() != 32 {
|
||||
return Err(CryptoError::KeyDerivation(format!(
|
||||
"Expected 32 bytes, got {}",
|
||||
hash_bytes.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let mut key = [0u8; 32];
|
||||
key.copy_from_slice(hash_bytes);
|
||||
|
||||
Ok(key)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_root_key_encrypt_decrypt_round_trip() {
|
||||
let root_key = [42u8; 32];
|
||||
let password = "super_secret_password_123";
|
||||
let nonce_counter = 1;
|
||||
|
||||
// Encrypt
|
||||
let (encrypted, salt) =
|
||||
encrypt_root_key(&root_key, password, nonce_counter).expect("Encryption failed");
|
||||
|
||||
// Verify structure
|
||||
assert_eq!(encrypted.current_nonce, nonce_counter);
|
||||
assert_eq!(encrypted.schema_version, 1);
|
||||
assert_eq!(encrypted.tag.len(), 16); // AEAD tag size
|
||||
|
||||
// Decrypt
|
||||
let decrypted =
|
||||
decrypt_root_key(&encrypted, password, &salt).expect("Decryption failed");
|
||||
|
||||
// Verify round-trip
|
||||
assert_eq!(decrypted, root_key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decrypt_with_wrong_password() {
|
||||
let root_key = [99u8; 32];
|
||||
let correct_password = "correct_password";
|
||||
let wrong_password = "wrong_password";
|
||||
let nonce_counter = 1;
|
||||
|
||||
// Encrypt with correct password
|
||||
let (encrypted, salt) =
|
||||
encrypt_root_key(&root_key, correct_password, nonce_counter).expect("Encryption failed");
|
||||
|
||||
// Attempt decrypt with wrong password
|
||||
let result = decrypt_root_key(&encrypted, wrong_password, &salt);
|
||||
|
||||
// Should fail due to authentication tag mismatch
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_password_derivation_different_salts() {
|
||||
let password = "same_password";
|
||||
|
||||
// Derive key twice - should produce different salts
|
||||
let (key1, salt1) = derive_key_from_password(password).expect("Derivation 1 failed");
|
||||
let (key2, salt2) = derive_key_from_password(password).expect("Derivation 2 failed");
|
||||
|
||||
// Salts should be different (randomly generated)
|
||||
assert_ne!(salt1, salt2);
|
||||
|
||||
// Keys should be different (due to different salts)
|
||||
assert_ne!(key1, key2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_password_derivation_with_same_salt() {
|
||||
let password = "test_password";
|
||||
|
||||
// Generate key and salt
|
||||
let (key1, salt) = derive_key_from_password(password).expect("Derivation failed");
|
||||
|
||||
// Derive key again with same salt
|
||||
let key2 = derive_key_with_salt(password, &salt).expect("Re-derivation failed");
|
||||
|
||||
// Keys should be identical
|
||||
assert_eq!(key1, key2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_different_nonce_produces_different_ciphertext() {
|
||||
let root_key = [77u8; 32];
|
||||
let password = "password123";
|
||||
|
||||
let (encrypted1, salt1) = encrypt_root_key(&root_key, password, 1).expect("Encryption 1 failed");
|
||||
let (encrypted2, salt2) = encrypt_root_key(&root_key, password, 2).expect("Encryption 2 failed");
|
||||
|
||||
// Different nonces should produce different ciphertexts
|
||||
assert_ne!(encrypted1.ciphertext, encrypted2.ciphertext);
|
||||
|
||||
// But both should decrypt correctly
|
||||
let decrypted1 = decrypt_root_key(&encrypted1, password, &salt1).expect("Decryption 1 failed");
|
||||
let decrypted2 = decrypt_root_key(&encrypted2, password, &salt2).expect("Decryption 2 failed");
|
||||
|
||||
assert_eq!(decrypted1, root_key);
|
||||
assert_eq!(decrypted2, root_key);
|
||||
}
|
||||
}
|
||||
@@ -22,7 +22,7 @@ pub type DatabasePool = diesel_async::pooled_connection::bb8::Pool<DatabaseConne
|
||||
pub type PoolInitError = diesel_async::pooled_connection::PoolError;
|
||||
pub type PoolError = diesel_async::pooled_connection::bb8::RunError;
|
||||
|
||||
static DB_FILE: &'static str = "arbiter.sqlite";
|
||||
static DB_FILE: &str = "arbiter.sqlite";
|
||||
|
||||
const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations");
|
||||
|
||||
|
||||
@@ -9,14 +9,15 @@ pub mod types {
|
||||
pub struct SqliteTimestamp(DateTime<Utc>);
|
||||
}
|
||||
|
||||
#[derive(Queryable, Debug, Insertable)]
|
||||
#[derive(Queryable, Selectable, Debug, Insertable)]
|
||||
#[diesel(table_name = aead_encrypted, check_for_backend(Sqlite))]
|
||||
pub struct AeadEncrypted {
|
||||
pub id: i32,
|
||||
pub current_nonce: i32,
|
||||
pub ciphertext: Vec<u8>,
|
||||
pub tag: Vec<u8>,
|
||||
pub current_nonce: i32,
|
||||
pub schema_version: i32,
|
||||
pub argon2_salt: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Queryable, Debug, Insertable)]
|
||||
@@ -26,6 +27,7 @@ pub struct ArbiterSetting {
|
||||
pub root_key_id: Option<i32>, // references aead_encrypted.id
|
||||
pub cert_key: Vec<u8>,
|
||||
pub cert: Vec<u8>,
|
||||
pub current_cert_id: Option<i32>, // references tls_certificates.id
|
||||
}
|
||||
|
||||
#[derive(Queryable, Debug)]
|
||||
@@ -47,3 +49,70 @@ pub struct UseragentClient {
|
||||
pub created_at: i32,
|
||||
pub updated_at: i32,
|
||||
}
|
||||
|
||||
// TLS Certificate Rotation Models
|
||||
|
||||
#[derive(Queryable, Debug, Insertable)]
|
||||
#[diesel(table_name = schema::tls_certificates, check_for_backend(Sqlite))]
|
||||
pub struct TlsCertificate {
|
||||
pub id: i32,
|
||||
pub cert: Vec<u8>,
|
||||
pub cert_key: Vec<u8>,
|
||||
pub not_before: i32,
|
||||
pub not_after: i32,
|
||||
pub created_at: i32,
|
||||
pub is_active: bool,
|
||||
}
|
||||
|
||||
#[derive(Insertable)]
|
||||
#[diesel(table_name = schema::tls_certificates)]
|
||||
pub struct NewTlsCertificate {
|
||||
pub cert: Vec<u8>,
|
||||
pub cert_key: Vec<u8>,
|
||||
pub not_before: i32,
|
||||
pub not_after: i32,
|
||||
pub is_active: bool,
|
||||
}
|
||||
|
||||
#[derive(Queryable, Debug, Insertable)]
|
||||
#[diesel(table_name = schema::tls_rotation_state, check_for_backend(Sqlite))]
|
||||
pub struct TlsRotationState {
|
||||
pub id: i32,
|
||||
pub state: String,
|
||||
pub new_cert_id: Option<i32>,
|
||||
pub initiated_at: Option<i32>,
|
||||
pub timeout_at: Option<i32>,
|
||||
}
|
||||
|
||||
#[derive(Queryable, Debug, Insertable)]
|
||||
#[diesel(table_name = schema::rotation_client_acks, check_for_backend(Sqlite))]
|
||||
pub struct RotationClientAck {
|
||||
pub rotation_id: i32,
|
||||
pub client_key: String,
|
||||
pub ack_received_at: i32,
|
||||
}
|
||||
|
||||
#[derive(Insertable)]
|
||||
#[diesel(table_name = schema::rotation_client_acks)]
|
||||
pub struct NewRotationClientAck {
|
||||
pub rotation_id: i32,
|
||||
pub client_key: String,
|
||||
}
|
||||
|
||||
#[derive(Queryable, Debug, Insertable)]
|
||||
#[diesel(table_name = schema::tls_rotation_history, check_for_backend(Sqlite))]
|
||||
pub struct TlsRotationHistory {
|
||||
pub id: i32,
|
||||
pub cert_id: i32,
|
||||
pub event_type: String,
|
||||
pub timestamp: i32,
|
||||
pub details: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Insertable)]
|
||||
#[diesel(table_name = schema::tls_rotation_history)]
|
||||
pub struct NewTlsRotationHistory {
|
||||
pub cert_id: i32,
|
||||
pub event_type: String,
|
||||
pub details: Option<String>,
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ diesel::table! {
|
||||
ciphertext -> Binary,
|
||||
tag -> Binary,
|
||||
schema_version -> Integer,
|
||||
argon2_salt -> Nullable<Text>,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +17,7 @@ diesel::table! {
|
||||
root_key_id -> Nullable<Integer>,
|
||||
cert_key -> Binary,
|
||||
cert -> Binary,
|
||||
current_cert_id -> Nullable<Integer>,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,11 +41,59 @@ diesel::table! {
|
||||
}
|
||||
}
|
||||
|
||||
diesel::table! {
|
||||
tls_certificates (id) {
|
||||
id -> Integer,
|
||||
cert -> Binary,
|
||||
cert_key -> Binary,
|
||||
not_before -> Integer,
|
||||
not_after -> Integer,
|
||||
created_at -> Integer,
|
||||
is_active -> Bool,
|
||||
}
|
||||
}
|
||||
|
||||
diesel::table! {
|
||||
tls_rotation_state (id) {
|
||||
id -> Integer,
|
||||
state -> Text,
|
||||
new_cert_id -> Nullable<Integer>,
|
||||
initiated_at -> Nullable<Integer>,
|
||||
timeout_at -> Nullable<Integer>,
|
||||
}
|
||||
}
|
||||
|
||||
diesel::table! {
|
||||
rotation_client_acks (rotation_id, client_key) {
|
||||
rotation_id -> Integer,
|
||||
client_key -> Text,
|
||||
ack_received_at -> Integer,
|
||||
}
|
||||
}
|
||||
|
||||
diesel::table! {
|
||||
tls_rotation_history (id) {
|
||||
id -> Integer,
|
||||
cert_id -> Integer,
|
||||
event_type -> Text,
|
||||
timestamp -> Integer,
|
||||
details -> Nullable<Text>,
|
||||
}
|
||||
}
|
||||
|
||||
diesel::joinable!(arbiter_settings -> aead_encrypted (root_key_id));
|
||||
diesel::joinable!(arbiter_settings -> tls_certificates (current_cert_id));
|
||||
diesel::joinable!(tls_rotation_state -> tls_certificates (new_cert_id));
|
||||
diesel::joinable!(rotation_client_acks -> tls_certificates (rotation_id));
|
||||
diesel::joinable!(tls_rotation_history -> tls_certificates (cert_id));
|
||||
|
||||
diesel::allow_tables_to_appear_in_same_query!(
|
||||
aead_encrypted,
|
||||
arbiter_settings,
|
||||
program_client,
|
||||
useragent_client,
|
||||
tls_certificates,
|
||||
tls_rotation_state,
|
||||
rotation_client_acks,
|
||||
tls_rotation_history,
|
||||
);
|
||||
|
||||
@@ -19,6 +19,7 @@ use crate::{
|
||||
|
||||
pub mod actors;
|
||||
mod context;
|
||||
mod crypto;
|
||||
mod db;
|
||||
mod errors;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user