Compare commits
4 Commits
PoC-terror
...
77c3babec7
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
77c3babec7 | ||
|
|
6f03ce4d1d | ||
|
|
c90af9c196 | ||
|
|
a5a9bc73b0 |
@@ -106,6 +106,16 @@ enum VaultState {
|
|||||||
VAULT_STATE_ERROR = 4;
|
VAULT_STATE_ERROR = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message ClientConnectionRequest {
|
||||||
|
bytes pubkey = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ClientConnectionResponse {
|
||||||
|
bool approved = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ClientConnectionCancel {}
|
||||||
|
|
||||||
message UserAgentRequest {
|
message UserAgentRequest {
|
||||||
oneof payload {
|
oneof payload {
|
||||||
AuthChallengeRequest auth_challenge_request = 1;
|
AuthChallengeRequest auth_challenge_request = 1;
|
||||||
@@ -118,7 +128,7 @@ message UserAgentRequest {
|
|||||||
arbiter.evm.EvmGrantCreateRequest evm_grant_create = 8;
|
arbiter.evm.EvmGrantCreateRequest evm_grant_create = 8;
|
||||||
arbiter.evm.EvmGrantDeleteRequest evm_grant_delete = 9;
|
arbiter.evm.EvmGrantDeleteRequest evm_grant_delete = 9;
|
||||||
arbiter.evm.EvmGrantListRequest evm_grant_list = 10;
|
arbiter.evm.EvmGrantListRequest evm_grant_list = 10;
|
||||||
// field 11 reserved: was client_connection_response (online approval removed)
|
ClientConnectionResponse client_connection_response = 11;
|
||||||
SdkClientApproveRequest sdk_client_approve = 12;
|
SdkClientApproveRequest sdk_client_approve = 12;
|
||||||
SdkClientRevokeRequest sdk_client_revoke = 13;
|
SdkClientRevokeRequest sdk_client_revoke = 13;
|
||||||
google.protobuf.Empty sdk_client_list = 14;
|
google.protobuf.Empty sdk_client_list = 14;
|
||||||
@@ -136,7 +146,8 @@ message UserAgentResponse {
|
|||||||
arbiter.evm.EvmGrantCreateResponse evm_grant_create = 8;
|
arbiter.evm.EvmGrantCreateResponse evm_grant_create = 8;
|
||||||
arbiter.evm.EvmGrantDeleteResponse evm_grant_delete = 9;
|
arbiter.evm.EvmGrantDeleteResponse evm_grant_delete = 9;
|
||||||
arbiter.evm.EvmGrantListResponse evm_grant_list = 10;
|
arbiter.evm.EvmGrantListResponse evm_grant_list = 10;
|
||||||
// fields 11, 12 reserved: were client_connection_request, client_connection_cancel (online approval removed)
|
ClientConnectionRequest client_connection_request = 11;
|
||||||
|
ClientConnectionCancel client_connection_cancel = 12;
|
||||||
SdkClientApproveResponse sdk_client_approve = 13;
|
SdkClientApproveResponse sdk_client_approve = 13;
|
||||||
SdkClientRevokeResponse sdk_client_revoke = 14;
|
SdkClientRevokeResponse sdk_client_revoke = 14;
|
||||||
SdkClientListResponse sdk_client_list = 15;
|
SdkClientListResponse sdk_client_list = 15;
|
||||||
|
|||||||
@@ -157,3 +157,5 @@ create table if not exists evm_ether_transfer_grant_target (
|
|||||||
|
|
||||||
create unique index if not exists uniq_ether_transfer_target on evm_ether_transfer_grant_target(grant_id, address);
|
create unique index if not exists uniq_ether_transfer_target on evm_ether_transfer_grant_target(grant_id, address);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX program_client_public_key_unique
|
||||||
|
ON program_client (public_key);
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
DROP INDEX IF EXISTS program_client_public_key_unique;
|
|
||||||
@@ -1,2 +0,0 @@
|
|||||||
CREATE UNIQUE INDEX program_client_public_key_unique
|
|
||||||
ON program_client (public_key);
|
|
||||||
@@ -8,13 +8,19 @@ use arbiter_proto::{
|
|||||||
},
|
},
|
||||||
transport::expect_message,
|
transport::expect_message,
|
||||||
};
|
};
|
||||||
use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl as _, update};
|
use diesel::{
|
||||||
|
ExpressionMethods as _, OptionalExtension as _, QueryDsl as _, dsl::insert_into, update,
|
||||||
|
};
|
||||||
use diesel_async::RunQueryDsl as _;
|
use diesel_async::RunQueryDsl as _;
|
||||||
use ed25519_dalek::VerifyingKey;
|
use ed25519_dalek::VerifyingKey;
|
||||||
|
use kameo::error::SendError;
|
||||||
use tracing::error;
|
use tracing::error;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
actors::client::ClientConnection,
|
actors::{
|
||||||
|
client::ClientConnection,
|
||||||
|
router::{self, RequestClientApproval},
|
||||||
|
},
|
||||||
db::{self, schema::program_client},
|
db::{self, schema::program_client},
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -34,14 +40,24 @@ pub enum Error {
|
|||||||
DatabaseOperationFailed,
|
DatabaseOperationFailed,
|
||||||
#[error("Invalid challenge solution")]
|
#[error("Invalid challenge solution")]
|
||||||
InvalidChallengeSolution,
|
InvalidChallengeSolution,
|
||||||
#[error("Client not registered")]
|
#[error("Client approval request failed")]
|
||||||
NotRegistered,
|
ApproveError(#[from] ApproveError),
|
||||||
#[error("Internal error")]
|
#[error("Internal error")]
|
||||||
InternalError,
|
InternalError,
|
||||||
#[error("Transport error")]
|
#[error("Transport error")]
|
||||||
Transport,
|
Transport,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum ApproveError {
|
||||||
|
#[error("Internal error")]
|
||||||
|
Internal,
|
||||||
|
#[error("Client connection denied by user agents")]
|
||||||
|
Denied,
|
||||||
|
#[error("Upstream error: {0}")]
|
||||||
|
Upstream(router::ApprovalError),
|
||||||
|
}
|
||||||
|
|
||||||
/// Atomically reads and increments the nonce for a known client.
|
/// Atomically reads and increments the nonce for a known client.
|
||||||
/// Returns `None` if the pubkey is not registered.
|
/// Returns `None` if the pubkey is not registered.
|
||||||
async fn get_nonce(
|
async fn get_nonce(
|
||||||
@@ -84,6 +100,85 @@ async fn get_nonce(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn approve_new_client(
|
||||||
|
actors: &crate::actors::GlobalActors,
|
||||||
|
pubkey: VerifyingKey,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
let result = actors
|
||||||
|
.router
|
||||||
|
.ask(RequestClientApproval {
|
||||||
|
client_pubkey: pubkey,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(true) => Ok(()),
|
||||||
|
Ok(false) => Err(Error::ApproveError(ApproveError::Denied)),
|
||||||
|
Err(SendError::HandlerError(e)) => {
|
||||||
|
error!(error = ?e, "Approval upstream error");
|
||||||
|
Err(Error::ApproveError(ApproveError::Upstream(e)))
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!(error = ?e, "Approval request to router failed");
|
||||||
|
Err(Error::ApproveError(ApproveError::Internal))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
enum InsertClientResult {
|
||||||
|
Inserted(i32),
|
||||||
|
AlreadyExists,
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn insert_client(
|
||||||
|
db: &db::DatabasePool,
|
||||||
|
pubkey: &VerifyingKey,
|
||||||
|
) -> Result<InsertClientResult, Error> {
|
||||||
|
let now = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs() as i32;
|
||||||
|
|
||||||
|
let mut conn = db.get().await.map_err(|e| {
|
||||||
|
error!(error = ?e, "Database pool error");
|
||||||
|
Error::DatabasePoolUnavailable
|
||||||
|
})?;
|
||||||
|
|
||||||
|
match insert_into(program_client::table)
|
||||||
|
.values((
|
||||||
|
program_client::public_key.eq(pubkey.as_bytes().to_vec()),
|
||||||
|
program_client::nonce.eq(1), // pre-incremented; challenge uses 0
|
||||||
|
program_client::created_at.eq(now),
|
||||||
|
program_client::updated_at.eq(now),
|
||||||
|
))
|
||||||
|
.execute(&mut conn)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(diesel::result::Error::DatabaseError(
|
||||||
|
diesel::result::DatabaseErrorKind::UniqueViolation,
|
||||||
|
_,
|
||||||
|
)) => return Ok(InsertClientResult::AlreadyExists),
|
||||||
|
Err(e) => {
|
||||||
|
error!(error = ?e, "Failed to insert new client");
|
||||||
|
return Err(Error::DatabaseOperationFailed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let client_id = program_client::table
|
||||||
|
.filter(program_client::public_key.eq(pubkey.as_bytes().to_vec()))
|
||||||
|
.order(program_client::id.desc())
|
||||||
|
.select(program_client::id)
|
||||||
|
.first::<i32>(&mut conn)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
error!(error = ?e, "Failed to load inserted client id");
|
||||||
|
Error::DatabaseOperationFailed
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(InsertClientResult::Inserted(client_id))
|
||||||
|
}
|
||||||
|
|
||||||
async fn challenge_client(
|
async fn challenge_client(
|
||||||
props: &mut ClientConnection,
|
props: &mut ClientConnection,
|
||||||
pubkey: VerifyingKey,
|
pubkey: VerifyingKey,
|
||||||
@@ -134,7 +229,10 @@ async fn challenge_client(
|
|||||||
|
|
||||||
fn connect_error_code(err: &Error) -> ConnectErrorCode {
|
fn connect_error_code(err: &Error) -> ConnectErrorCode {
|
||||||
match err {
|
match err {
|
||||||
Error::NotRegistered => ConnectErrorCode::ApprovalDenied,
|
Error::ApproveError(ApproveError::Denied) => ConnectErrorCode::ApprovalDenied,
|
||||||
|
Error::ApproveError(ApproveError::Upstream(
|
||||||
|
router::ApprovalError::NoUserAgentsConnected,
|
||||||
|
)) => ConnectErrorCode::NoUserAgentsOnline,
|
||||||
_ => ConnectErrorCode::Unknown,
|
_ => ConnectErrorCode::Unknown,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -156,7 +254,16 @@ async fn authenticate(props: &mut ClientConnection) -> Result<(VerifyingKey, i32
|
|||||||
|
|
||||||
let (client_id, nonce) = match get_nonce(&props.db, &pubkey).await? {
|
let (client_id, nonce) = match get_nonce(&props.db, &pubkey).await? {
|
||||||
Some((client_id, nonce)) => (client_id, nonce),
|
Some((client_id, nonce)) => (client_id, nonce),
|
||||||
None => return Err(Error::NotRegistered),
|
None => {
|
||||||
|
approve_new_client(&props.actors, pubkey).await?;
|
||||||
|
match insert_client(&props.db, &pubkey).await? {
|
||||||
|
InsertClientResult::Inserted(client_id) => (client_id, 0),
|
||||||
|
InsertClientResult::AlreadyExists => match get_nonce(&props.db, &pubkey).await? {
|
||||||
|
Some((client_id, nonce)) => (client_id, nonce),
|
||||||
|
None => return Err(Error::InternalError),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
challenge_client(props, pubkey, nonce).await?;
|
challenge_client(props, pubkey, nonce).await?;
|
||||||
|
|||||||
@@ -1,14 +1,20 @@
|
|||||||
use std::{collections::HashMap, ops::ControlFlow};
|
use std::{collections::HashMap, ops::ControlFlow};
|
||||||
|
|
||||||
|
use ed25519_dalek::VerifyingKey;
|
||||||
use kameo::{
|
use kameo::{
|
||||||
Actor,
|
Actor,
|
||||||
actor::{ActorId, ActorRef},
|
actor::{ActorId, ActorRef},
|
||||||
messages,
|
messages,
|
||||||
prelude::{ActorStopReason, Context, WeakActorRef},
|
prelude::{ActorStopReason, Context, WeakActorRef},
|
||||||
|
reply::DelegatedReply,
|
||||||
};
|
};
|
||||||
use tracing::info;
|
use tokio::{sync::watch, task::JoinSet};
|
||||||
|
use tracing::{info, warn};
|
||||||
|
|
||||||
use crate::actors::{client::session::ClientSession, user_agent::session::UserAgentSession};
|
use crate::actors::{
|
||||||
|
client::session::ClientSession,
|
||||||
|
user_agent::session::{RequestNewClientApproval, UserAgentSession},
|
||||||
|
};
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
pub struct MessageRouter {
|
pub struct MessageRouter {
|
||||||
@@ -50,6 +56,72 @@ impl Actor for MessageRouter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq, Hash)]
|
||||||
|
pub enum ApprovalError {
|
||||||
|
#[error("No user agents connected")]
|
||||||
|
NoUserAgentsConnected,
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn request_client_approval(
|
||||||
|
user_agents: &[WeakActorRef<UserAgentSession>],
|
||||||
|
client_pubkey: VerifyingKey,
|
||||||
|
) -> Result<bool, ApprovalError> {
|
||||||
|
if user_agents.is_empty() {
|
||||||
|
return Err(ApprovalError::NoUserAgentsConnected);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut pool = JoinSet::new();
|
||||||
|
let (cancel_tx, cancel_rx) = watch::channel(());
|
||||||
|
|
||||||
|
for weak_ref in user_agents {
|
||||||
|
match weak_ref.upgrade() {
|
||||||
|
Some(agent) => {
|
||||||
|
let cancel_rx = cancel_rx.clone();
|
||||||
|
pool.spawn(async move {
|
||||||
|
agent
|
||||||
|
.ask(RequestNewClientApproval {
|
||||||
|
client_pubkey,
|
||||||
|
cancel_flag: cancel_rx.clone(),
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
});
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
warn!(
|
||||||
|
id = weak_ref.id().to_string(),
|
||||||
|
actor = "MessageRouter",
|
||||||
|
event = "useragent.disconnected_before_approval"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
while let Some(result) = pool.join_next().await {
|
||||||
|
match result {
|
||||||
|
Ok(Ok(approved)) => {
|
||||||
|
let _ = cancel_tx.send(());
|
||||||
|
return Ok(approved);
|
||||||
|
}
|
||||||
|
Ok(Err(err)) => {
|
||||||
|
warn!(
|
||||||
|
?err,
|
||||||
|
actor = "MessageRouter",
|
||||||
|
event = "useragent.approval_error"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
warn!(
|
||||||
|
?err,
|
||||||
|
actor = "MessageRouter",
|
||||||
|
event = "useragent.approval_task_failed"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(ApprovalError::NoUserAgentsConnected)
|
||||||
|
}
|
||||||
|
|
||||||
#[messages]
|
#[messages]
|
||||||
impl MessageRouter {
|
impl MessageRouter {
|
||||||
#[message(ctx)]
|
#[message(ctx)]
|
||||||
@@ -73,4 +145,28 @@ impl MessageRouter {
|
|||||||
ctx.actor_ref().link(&actor).await;
|
ctx.actor_ref().link(&actor).await;
|
||||||
self.clients.insert(actor.id(), actor);
|
self.clients.insert(actor.id(), actor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[message(ctx)]
|
||||||
|
pub async fn request_client_approval(
|
||||||
|
&mut self,
|
||||||
|
client_pubkey: VerifyingKey,
|
||||||
|
ctx: &mut Context<Self, DelegatedReply<Result<bool, ApprovalError>>>,
|
||||||
|
) -> DelegatedReply<Result<bool, ApprovalError>> {
|
||||||
|
let (reply, Some(reply_sender)) = ctx.reply_sender() else {
|
||||||
|
panic!("Exptected `request_client_approval` to have callback channel");
|
||||||
|
};
|
||||||
|
|
||||||
|
let weak_refs = self
|
||||||
|
.user_agents
|
||||||
|
.values()
|
||||||
|
.map(|agent| agent.downgrade())
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
tokio::task::spawn(async move {
|
||||||
|
let result = request_client_approval(&weak_refs, client_pubkey).await;
|
||||||
|
reply_sender.send(result);
|
||||||
|
});
|
||||||
|
|
||||||
|
reply
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
use arbiter_proto::{
|
use arbiter_proto::{
|
||||||
proto::user_agent::{UserAgentRequest, UserAgentResponse},
|
proto::user_agent::{
|
||||||
|
SdkClientError as ProtoSdkClientError, UserAgentRequest, UserAgentResponse,
|
||||||
|
},
|
||||||
transport::Bi,
|
transport::Bi,
|
||||||
};
|
};
|
||||||
use kameo::actor::Spawn as _;
|
use kameo::actor::Spawn as _;
|
||||||
@@ -24,12 +26,27 @@ pub enum TransportResponseError {
|
|||||||
StateTransitionFailed,
|
StateTransitionFailed,
|
||||||
#[error("Vault is not available")]
|
#[error("Vault is not available")]
|
||||||
KeyHolderActorUnreachable,
|
KeyHolderActorUnreachable,
|
||||||
|
#[error("SDK client approve failed: {0:?}")]
|
||||||
|
SdkClientApprove(ProtoSdkClientError),
|
||||||
|
#[error("SDK client list failed: {0:?}")]
|
||||||
|
SdkClientList(ProtoSdkClientError),
|
||||||
|
#[error("SDK client revoke failed: {0:?}")]
|
||||||
|
SdkClientRevoke(ProtoSdkClientError),
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Auth(#[from] auth::Error),
|
Auth(#[from] auth::Error),
|
||||||
#[error("Failed registering connection")]
|
#[error("Failed registering connection")]
|
||||||
ConnectionRegistrationFailed,
|
ConnectionRegistrationFailed,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl TransportResponseError {
|
||||||
|
pub fn is_terminal(&self) -> bool {
|
||||||
|
!matches!(
|
||||||
|
self,
|
||||||
|
Self::SdkClientApprove(_) | Self::SdkClientList(_) | Self::SdkClientRevoke(_)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub type Transport =
|
pub type Transport =
|
||||||
Box<dyn Bi<UserAgentRequest, Result<UserAgentResponse, TransportResponseError>> + Send>;
|
Box<dyn Bi<UserAgentRequest, Result<UserAgentResponse, TransportResponseError>> + Send>;
|
||||||
|
|
||||||
|
|||||||
@@ -3,21 +3,22 @@ use std::{ops::DerefMut, sync::Mutex};
|
|||||||
use arbiter_proto::proto::{
|
use arbiter_proto::proto::{
|
||||||
evm as evm_proto,
|
evm as evm_proto,
|
||||||
user_agent::{
|
user_agent::{
|
||||||
SdkClientApproveRequest, SdkClientApproveResponse, SdkClientEntry,
|
ClientConnectionCancel, ClientConnectionRequest, SdkClientApproveRequest,
|
||||||
SdkClientError as ProtoSdkClientError, SdkClientList, SdkClientListResponse,
|
SdkClientApproveResponse, SdkClientEntry, SdkClientError as ProtoSdkClientError,
|
||||||
SdkClientRevokeRequest, SdkClientRevokeResponse, UnsealEncryptedKey, UnsealResult,
|
SdkClientList, SdkClientListResponse, SdkClientRevokeRequest, SdkClientRevokeResponse,
|
||||||
UnsealStart, UnsealStartResponse, UserAgentRequest, UserAgentResponse,
|
UnsealEncryptedKey, UnsealResult, UnsealStart, UnsealStartResponse, UserAgentRequest,
|
||||||
sdk_client_approve_response, sdk_client_list_response, sdk_client_revoke_response,
|
UserAgentResponse, sdk_client_approve_response, sdk_client_list_response,
|
||||||
user_agent_request::Payload as UserAgentRequestPayload,
|
sdk_client_revoke_response, user_agent_request::Payload as UserAgentRequestPayload,
|
||||||
user_agent_response::Payload as UserAgentResponsePayload,
|
user_agent_response::Payload as UserAgentResponsePayload,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit};
|
use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit};
|
||||||
use diesel::{ExpressionMethods as _, QueryDsl as _, dsl::insert_into};
|
use diesel::{ExpressionMethods as _, QueryDsl as _, dsl::insert_into};
|
||||||
use diesel_async::RunQueryDsl as _;
|
use diesel_async::RunQueryDsl as _;
|
||||||
use kameo::{Actor, error::SendError, prelude::Context};
|
use ed25519_dalek::VerifyingKey;
|
||||||
|
use kameo::{Actor, error::SendError, messages, prelude::Context};
|
||||||
use memsafe::MemSafe;
|
use memsafe::MemSafe;
|
||||||
use tokio::select;
|
use tokio::{select, sync::watch};
|
||||||
use tracing::{error, info};
|
use tracing::{error, info};
|
||||||
use x25519_dalek::{EphemeralSecret, PublicKey};
|
use x25519_dalek::{EphemeralSecret, PublicKey};
|
||||||
|
|
||||||
@@ -115,6 +116,52 @@ impl UserAgentSession {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[messages]
|
||||||
|
impl UserAgentSession {
|
||||||
|
// TODO: Think about refactoring it to state-machine based flow, as we already have one
|
||||||
|
#[message(ctx)]
|
||||||
|
pub async fn request_new_client_approval(
|
||||||
|
&mut self,
|
||||||
|
client_pubkey: VerifyingKey,
|
||||||
|
mut cancel_flag: watch::Receiver<()>,
|
||||||
|
ctx: &mut Context<Self, Result<bool, Error>>,
|
||||||
|
) -> Result<bool, Error> {
|
||||||
|
self.send_msg(
|
||||||
|
UserAgentResponsePayload::ClientConnectionRequest(ClientConnectionRequest {
|
||||||
|
pubkey: client_pubkey.as_bytes().to_vec(),
|
||||||
|
}),
|
||||||
|
ctx,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let extractor = |msg| {
|
||||||
|
if let UserAgentRequestPayload::ClientConnectionResponse(client_connection_response) =
|
||||||
|
msg
|
||||||
|
{
|
||||||
|
Some(client_connection_response)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
tokio::select! {
|
||||||
|
_ = cancel_flag.changed() => {
|
||||||
|
info!(actor = "useragent", "client connection approval cancelled");
|
||||||
|
self.send_msg(
|
||||||
|
UserAgentResponsePayload::ClientConnectionCancel(ClientConnectionCancel {}),
|
||||||
|
ctx,
|
||||||
|
).await?;
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
result = self.expect_msg(extractor, ctx) => {
|
||||||
|
let result = result?;
|
||||||
|
info!(actor = "useragent", "received client connection approval result: approved={}", result.approved);
|
||||||
|
Ok(result.approved)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl UserAgentSession {
|
impl UserAgentSession {
|
||||||
pub async fn process_transport_inbound(&mut self, req: UserAgentRequest) -> Output {
|
pub async fn process_transport_inbound(&mut self, req: UserAgentRequest) -> Output {
|
||||||
let msg = req.payload.ok_or_else(|| {
|
let msg = req.payload.ok_or_else(|| {
|
||||||
@@ -304,11 +351,9 @@ impl UserAgentSession {
|
|||||||
use sdk_client_approve_response::Result as ApproveResult;
|
use sdk_client_approve_response::Result as ApproveResult;
|
||||||
|
|
||||||
if req.pubkey.len() != 32 {
|
if req.pubkey.len() != 32 {
|
||||||
return Ok(response(UserAgentResponsePayload::SdkClientApprove(
|
return Err(TransportResponseError::SdkClientApprove(
|
||||||
SdkClientApproveResponse {
|
ProtoSdkClientError::Internal,
|
||||||
result: Some(ApproveResult::Error(ProtoSdkClientError::Internal.into())),
|
));
|
||||||
},
|
|
||||||
)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let now = std::time::SystemTime::now()
|
let now = std::time::SystemTime::now()
|
||||||
@@ -320,11 +365,9 @@ impl UserAgentSession {
|
|||||||
Ok(c) => c,
|
Ok(c) => c,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!(?e, "Failed to get DB connection for sdk_client_approve");
|
error!(?e, "Failed to get DB connection for sdk_client_approve");
|
||||||
return Ok(response(UserAgentResponsePayload::SdkClientApprove(
|
return Err(TransportResponseError::SdkClientApprove(
|
||||||
SdkClientApproveResponse {
|
ProtoSdkClientError::Internal,
|
||||||
result: Some(ApproveResult::Error(ProtoSdkClientError::Internal.into())),
|
));
|
||||||
},
|
|
||||||
)));
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -363,33 +406,23 @@ impl UserAgentSession {
|
|||||||
)),
|
)),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!(?e, "Failed to fetch inserted SDK client");
|
error!(?e, "Failed to fetch inserted SDK client");
|
||||||
Ok(response(UserAgentResponsePayload::SdkClientApprove(
|
Err(TransportResponseError::SdkClientApprove(
|
||||||
SdkClientApproveResponse {
|
ProtoSdkClientError::Internal,
|
||||||
result: Some(ApproveResult::Error(
|
))
|
||||||
ProtoSdkClientError::Internal.into(),
|
|
||||||
)),
|
|
||||||
},
|
|
||||||
)))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(diesel::result::Error::DatabaseError(
|
Err(diesel::result::Error::DatabaseError(
|
||||||
diesel::result::DatabaseErrorKind::UniqueViolation,
|
diesel::result::DatabaseErrorKind::UniqueViolation,
|
||||||
_,
|
_,
|
||||||
)) => Ok(response(UserAgentResponsePayload::SdkClientApprove(
|
)) => Err(TransportResponseError::SdkClientApprove(
|
||||||
SdkClientApproveResponse {
|
ProtoSdkClientError::AlreadyExists,
|
||||||
result: Some(ApproveResult::Error(
|
)),
|
||||||
ProtoSdkClientError::AlreadyExists.into(),
|
|
||||||
)),
|
|
||||||
},
|
|
||||||
))),
|
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!(?e, "Failed to insert SDK client");
|
error!(?e, "Failed to insert SDK client");
|
||||||
Ok(response(UserAgentResponsePayload::SdkClientApprove(
|
Err(TransportResponseError::SdkClientApprove(
|
||||||
SdkClientApproveResponse {
|
ProtoSdkClientError::Internal,
|
||||||
result: Some(ApproveResult::Error(ProtoSdkClientError::Internal.into())),
|
))
|
||||||
},
|
|
||||||
)))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -399,13 +432,9 @@ impl UserAgentSession {
|
|||||||
Ok(c) => c,
|
Ok(c) => c,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!(?e, "Failed to get DB connection for sdk_client_list");
|
error!(?e, "Failed to get DB connection for sdk_client_list");
|
||||||
return Ok(response(UserAgentResponsePayload::SdkClientList(
|
return Err(TransportResponseError::SdkClientList(
|
||||||
SdkClientListResponse {
|
ProtoSdkClientError::Internal,
|
||||||
result: Some(sdk_client_list_response::Result::Error(
|
));
|
||||||
ProtoSdkClientError::Internal.into(),
|
|
||||||
)),
|
|
||||||
},
|
|
||||||
)));
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -434,13 +463,9 @@ impl UserAgentSession {
|
|||||||
))),
|
))),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!(?e, "Failed to list SDK clients");
|
error!(?e, "Failed to list SDK clients");
|
||||||
Ok(response(UserAgentResponsePayload::SdkClientList(
|
Err(TransportResponseError::SdkClientList(
|
||||||
SdkClientListResponse {
|
ProtoSdkClientError::Internal,
|
||||||
result: Some(sdk_client_list_response::Result::Error(
|
))
|
||||||
ProtoSdkClientError::Internal.into(),
|
|
||||||
)),
|
|
||||||
},
|
|
||||||
)))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -452,11 +477,9 @@ impl UserAgentSession {
|
|||||||
Ok(c) => c,
|
Ok(c) => c,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!(?e, "Failed to get DB connection for sdk_client_revoke");
|
error!(?e, "Failed to get DB connection for sdk_client_revoke");
|
||||||
return Ok(response(UserAgentResponsePayload::SdkClientRevoke(
|
return Err(TransportResponseError::SdkClientRevoke(
|
||||||
SdkClientRevokeResponse {
|
ProtoSdkClientError::Internal,
|
||||||
result: Some(RevokeResult::Error(ProtoSdkClientError::Internal.into())),
|
));
|
||||||
},
|
|
||||||
)));
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -465,11 +488,9 @@ impl UserAgentSession {
|
|||||||
.execute(&mut conn)
|
.execute(&mut conn)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(0) => Ok(response(UserAgentResponsePayload::SdkClientRevoke(
|
Ok(0) => Err(TransportResponseError::SdkClientRevoke(
|
||||||
SdkClientRevokeResponse {
|
ProtoSdkClientError::NotFound,
|
||||||
result: Some(RevokeResult::Error(ProtoSdkClientError::NotFound.into())),
|
)),
|
||||||
},
|
|
||||||
))),
|
|
||||||
Ok(_) => Ok(response(UserAgentResponsePayload::SdkClientRevoke(
|
Ok(_) => Ok(response(UserAgentResponsePayload::SdkClientRevoke(
|
||||||
SdkClientRevokeResponse {
|
SdkClientRevokeResponse {
|
||||||
result: Some(RevokeResult::Ok(())),
|
result: Some(RevokeResult::Ok(())),
|
||||||
@@ -478,20 +499,14 @@ impl UserAgentSession {
|
|||||||
Err(diesel::result::Error::DatabaseError(
|
Err(diesel::result::Error::DatabaseError(
|
||||||
diesel::result::DatabaseErrorKind::ForeignKeyViolation,
|
diesel::result::DatabaseErrorKind::ForeignKeyViolation,
|
||||||
_,
|
_,
|
||||||
)) => Ok(response(UserAgentResponsePayload::SdkClientRevoke(
|
)) => Err(TransportResponseError::SdkClientRevoke(
|
||||||
SdkClientRevokeResponse {
|
ProtoSdkClientError::HasRelatedData,
|
||||||
result: Some(RevokeResult::Error(
|
)),
|
||||||
ProtoSdkClientError::HasRelatedData.into(),
|
|
||||||
)),
|
|
||||||
},
|
|
||||||
))),
|
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!(?e, "Failed to delete SDK client");
|
error!(?e, "Failed to delete SDK client");
|
||||||
Ok(response(UserAgentResponsePayload::SdkClientRevoke(
|
Err(TransportResponseError::SdkClientRevoke(
|
||||||
SdkClientRevokeResponse {
|
ProtoSdkClientError::Internal,
|
||||||
result: Some(RevokeResult::Error(ProtoSdkClientError::Internal.into())),
|
))
|
||||||
},
|
|
||||||
)))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -558,8 +573,15 @@ impl Actor for UserAgentSession {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
let _ = self.props.transport.send(Err(err)).await;
|
let should_stop = err.is_terminal();
|
||||||
return Some(kameo::mailbox::Signal::Stop);
|
if self.props.transport.send(Err(err)).await.is_err() {
|
||||||
|
error!(actor = "useragent", reason = "channel closed", "send.failed");
|
||||||
|
return Some(kameo::mailbox::Signal::Stop);
|
||||||
|
}
|
||||||
|
|
||||||
|
if should_stop {
|
||||||
|
return Some(kameo::mailbox::Signal::Stop);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,12 @@
|
|||||||
use arbiter_proto::{
|
use arbiter_proto::{
|
||||||
proto::{
|
proto::{
|
||||||
client::{ClientRequest, ClientResponse},
|
client::{ClientRequest, ClientResponse},
|
||||||
user_agent::{UserAgentRequest, UserAgentResponse},
|
user_agent::{
|
||||||
|
SdkClientApproveResponse, SdkClientListResponse, SdkClientRevokeResponse,
|
||||||
|
UserAgentRequest, UserAgentResponse, sdk_client_approve_response,
|
||||||
|
sdk_client_list_response, sdk_client_revoke_response,
|
||||||
|
user_agent_response::Payload as UserAgentResponsePayload,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
transport::{IdentityRecvConverter, SendConverter, grpc},
|
transport::{IdentityRecvConverter, SendConverter, grpc},
|
||||||
};
|
};
|
||||||
@@ -37,6 +42,27 @@ impl SendConverter for UserAgentGrpcSender {
|
|||||||
fn convert(&self, item: Self::Input) -> Self::Output {
|
fn convert(&self, item: Self::Input) -> Self::Output {
|
||||||
match item {
|
match item {
|
||||||
Ok(message) => Ok(message),
|
Ok(message) => Ok(message),
|
||||||
|
Err(TransportResponseError::SdkClientApprove(code)) => Ok(UserAgentResponse {
|
||||||
|
payload: Some(UserAgentResponsePayload::SdkClientApprove(
|
||||||
|
SdkClientApproveResponse {
|
||||||
|
result: Some(sdk_client_approve_response::Result::Error(code.into())),
|
||||||
|
},
|
||||||
|
)),
|
||||||
|
}),
|
||||||
|
Err(TransportResponseError::SdkClientList(code)) => Ok(UserAgentResponse {
|
||||||
|
payload: Some(UserAgentResponsePayload::SdkClientList(
|
||||||
|
SdkClientListResponse {
|
||||||
|
result: Some(sdk_client_list_response::Result::Error(code.into())),
|
||||||
|
},
|
||||||
|
)),
|
||||||
|
}),
|
||||||
|
Err(TransportResponseError::SdkClientRevoke(code)) => Ok(UserAgentResponse {
|
||||||
|
payload: Some(UserAgentResponsePayload::SdkClientRevoke(
|
||||||
|
SdkClientRevokeResponse {
|
||||||
|
result: Some(sdk_client_revoke_response::Result::Error(code.into())),
|
||||||
|
},
|
||||||
|
)),
|
||||||
|
}),
|
||||||
Err(err) => Err(user_agent_error_status(err)),
|
Err(err) => Err(user_agent_error_status(err)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -79,7 +105,7 @@ fn client_auth_error_status(value: &client::auth::Error) -> Status {
|
|||||||
Status::invalid_argument("Failed to convert pubkey to VerifyingKey")
|
Status::invalid_argument("Failed to convert pubkey to VerifyingKey")
|
||||||
}
|
}
|
||||||
Error::InvalidChallengeSolution => Status::unauthenticated(value.to_string()),
|
Error::InvalidChallengeSolution => Status::unauthenticated(value.to_string()),
|
||||||
Error::NotRegistered => Status::permission_denied(value.to_string()),
|
Error::ApproveError(_) => Status::permission_denied(value.to_string()),
|
||||||
Error::Transport => Status::internal("Transport error"),
|
Error::Transport => Status::internal("Transport error"),
|
||||||
Error::DatabasePoolUnavailable => Status::internal("Database pool error"),
|
Error::DatabasePoolUnavailable => Status::internal("Database pool error"),
|
||||||
Error::DatabaseOperationFailed => Status::internal("Database error"),
|
Error::DatabaseOperationFailed => Status::internal("Database error"),
|
||||||
@@ -103,6 +129,11 @@ fn user_agent_error_status(value: TransportResponseError) -> Status {
|
|||||||
TransportResponseError::KeyHolderActorUnreachable => {
|
TransportResponseError::KeyHolderActorUnreachable => {
|
||||||
Status::internal("Vault is not available")
|
Status::internal("Vault is not available")
|
||||||
}
|
}
|
||||||
|
TransportResponseError::SdkClientApprove(_)
|
||||||
|
| TransportResponseError::SdkClientList(_)
|
||||||
|
| TransportResponseError::SdkClientRevoke(_) => {
|
||||||
|
Status::internal("SDK client operation failed")
|
||||||
|
}
|
||||||
TransportResponseError::Auth(ref err) => auth_error_status(err),
|
TransportResponseError::Auth(ref err) => auth_error_status(err),
|
||||||
TransportResponseError::ConnectionRegistrationFailed => {
|
TransportResponseError::ConnectionRegistrationFailed => {
|
||||||
Status::internal("Failed registering connection")
|
Status::internal("Failed registering connection")
|
||||||
|
|||||||
@@ -5,7 +5,10 @@ use arbiter_proto::proto::user_agent::{
|
|||||||
user_agent_response::Payload as UserAgentResponsePayload,
|
user_agent_response::Payload as UserAgentResponsePayload,
|
||||||
};
|
};
|
||||||
use arbiter_server::{
|
use arbiter_server::{
|
||||||
actors::{GlobalActors, user_agent::session::UserAgentSession},
|
actors::{
|
||||||
|
GlobalActors,
|
||||||
|
user_agent::{TransportResponseError, session::UserAgentSession},
|
||||||
|
},
|
||||||
db,
|
db,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -68,22 +71,15 @@ async fn test_sdk_client_approve_duplicate_returns_already_exists() {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let response = session
|
let err = session
|
||||||
.process_transport_inbound(req)
|
.process_transport_inbound(req)
|
||||||
.await
|
.await
|
||||||
.expect("second insert should not panic");
|
.expect_err("second insert should return typed TransportResponseError");
|
||||||
|
|
||||||
match response.payload.unwrap() {
|
assert_eq!(
|
||||||
UserAgentResponsePayload::SdkClientApprove(resp) => match resp.result.unwrap() {
|
err,
|
||||||
sdk_client_approve_response::Result::Error(code) => {
|
TransportResponseError::SdkClientApprove(ProtoSdkClientError::AlreadyExists)
|
||||||
assert_eq!(code, ProtoSdkClientError::AlreadyExists as i32);
|
);
|
||||||
}
|
|
||||||
sdk_client_approve_response::Result::Client(_) => {
|
|
||||||
panic!("Expected AlreadyExists error for duplicate pubkey")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
other => panic!("Expected SdkClientApprove, got {other:?}"),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -203,26 +199,19 @@ async fn test_sdk_client_revoke_not_found_returns_error() {
|
|||||||
let db = db::create_test_pool().await;
|
let db = db::create_test_pool().await;
|
||||||
let mut session = make_session(&db).await;
|
let mut session = make_session(&db).await;
|
||||||
|
|
||||||
let response = session
|
let err = session
|
||||||
.process_transport_inbound(UserAgentRequest {
|
.process_transport_inbound(UserAgentRequest {
|
||||||
payload: Some(UserAgentRequestPayload::SdkClientRevoke(
|
payload: Some(UserAgentRequestPayload::SdkClientRevoke(
|
||||||
SdkClientRevokeRequest { client_id: 9999 },
|
SdkClientRevokeRequest { client_id: 9999 },
|
||||||
)),
|
)),
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.expect_err("missing client should return typed TransportResponseError");
|
||||||
|
|
||||||
match response.payload.unwrap() {
|
assert_eq!(
|
||||||
UserAgentResponsePayload::SdkClientRevoke(resp) => match resp.result.unwrap() {
|
err,
|
||||||
sdk_client_revoke_response::Result::Error(code) => {
|
TransportResponseError::SdkClientRevoke(ProtoSdkClientError::NotFound)
|
||||||
assert_eq!(code, ProtoSdkClientError::NotFound as i32);
|
);
|
||||||
}
|
|
||||||
sdk_client_revoke_response::Result::Ok(_) => {
|
|
||||||
panic!("Expected NotFound error for missing client_id")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
other => panic!("Expected SdkClientRevoke, got {other:?}"),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|||||||
@@ -1,7 +0,0 @@
|
|||||||
[package]
|
|
||||||
name = "arbiter-terrors-poc"
|
|
||||||
version = "0.1.0"
|
|
||||||
edition = "2024"
|
|
||||||
|
|
||||||
[dependencies]
|
|
||||||
terrors = "0.3"
|
|
||||||
@@ -1,139 +0,0 @@
|
|||||||
use crate::errors::{InternalError1, InternalError2, InvalidSignature, NotRegistered};
|
|
||||||
use terrors::OneOf;
|
|
||||||
|
|
||||||
use crate::errors::ProtoError;
|
|
||||||
|
|
||||||
// Each sub-call's error type already implements DrainInto<ProtoError>, so we convert
|
|
||||||
// directly to ProtoError without broaden — no turbofish needed anywhere.
|
|
||||||
//
|
|
||||||
// Call chain:
|
|
||||||
// load_config() → OneOf<(InternalError2,)> → ProtoError::from
|
|
||||||
// get_nonce() → OneOf<(InternalError1, InternalError2)> → ProtoError::from
|
|
||||||
// verify_sig() → OneOf<(InvalidSignature,)> → ProtoError::from
|
|
||||||
pub fn process_request(id: u32, sig: &str) -> Result<String, ProtoError> {
|
|
||||||
if id == 0 {
|
|
||||||
return Err(ProtoError::NotRegistered);
|
|
||||||
}
|
|
||||||
|
|
||||||
let config = load_config(id).map_err(ProtoError::from)?;
|
|
||||||
let nonce = crate::db::get_nonce(id).map_err(ProtoError::from)?;
|
|
||||||
verify_signature(nonce, sig).map_err(ProtoError::from)?;
|
|
||||||
|
|
||||||
Ok(format!("config={config} nonce={nonce} sig={sig}"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Simulates loading a config value.
|
|
||||||
// id=97 triggers InternalError2 ("config read failed").
|
|
||||||
fn load_config(id: u32) -> Result<String, OneOf<(InternalError2,)>> {
|
|
||||||
if id == 97 {
|
|
||||||
return Err(OneOf::new(InternalError2("config read failed".to_owned())));
|
|
||||||
}
|
|
||||||
Ok(format!("cfg-{id}"))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn verify_signature(_nonce: u32, sig: &str) -> Result<(), OneOf<(InvalidSignature,)>> {
|
|
||||||
if sig != "ok" {
|
|
||||||
return Err(OneOf::new(InvalidSignature));
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
type AuthError = OneOf<(
|
|
||||||
NotRegistered,
|
|
||||||
InvalidSignature,
|
|
||||||
InternalError1,
|
|
||||||
InternalError2,
|
|
||||||
)>;
|
|
||||||
|
|
||||||
pub fn authenticate(id: u32, sig: &str) -> Result<u32, AuthError> {
|
|
||||||
if id == 0 {
|
|
||||||
return Err(OneOf::new(NotRegistered));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return type AuthError lets the compiler infer the broaden target.
|
|
||||||
let nonce = crate::db::get_nonce(id).map_err(OneOf::broaden)?;
|
|
||||||
verify_signature(nonce, sig).map_err(OneOf::broaden)?;
|
|
||||||
|
|
||||||
Ok(nonce)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn verify_signature_ok() {
|
|
||||||
assert!(verify_signature(42, "ok").is_ok());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn verify_signature_bad() {
|
|
||||||
let err = verify_signature(42, "bad").unwrap_err();
|
|
||||||
assert!(err.narrow::<crate::errors::InvalidSignature, _>().is_ok());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn authenticate_success() {
|
|
||||||
assert_eq!(authenticate(1, "ok").unwrap(), 42);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn authenticate_not_registered() {
|
|
||||||
let err = authenticate(0, "ok").unwrap_err();
|
|
||||||
assert!(err.narrow::<crate::errors::NotRegistered, _>().is_ok());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn authenticate_invalid_signature() {
|
|
||||||
let err = authenticate(1, "bad").unwrap_err();
|
|
||||||
assert!(err.narrow::<crate::errors::InvalidSignature, _>().is_ok());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn authenticate_internal_error1() {
|
|
||||||
let err = authenticate(99, "ok").unwrap_err();
|
|
||||||
assert!(err.narrow::<crate::errors::InternalError1, _>().is_ok());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn authenticate_internal_error2() {
|
|
||||||
let err = authenticate(98, "ok").unwrap_err();
|
|
||||||
assert!(err.narrow::<crate::errors::InternalError2, _>().is_ok());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn process_request_success() {
|
|
||||||
let result = process_request(1, "ok").unwrap();
|
|
||||||
assert!(result.contains("nonce=42"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn process_request_not_registered() {
|
|
||||||
let err = process_request(0, "ok").unwrap_err();
|
|
||||||
assert!(matches!(err, crate::errors::ProtoError::NotRegistered));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn process_request_invalid_signature() {
|
|
||||||
let err = process_request(1, "bad").unwrap_err();
|
|
||||||
assert!(matches!(err, crate::errors::ProtoError::InvalidSignature));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn process_request_internal_from_config() {
|
|
||||||
// id=97 → load_config returns InternalError2
|
|
||||||
let err = process_request(97, "ok").unwrap_err();
|
|
||||||
assert!(
|
|
||||||
matches!(err, crate::errors::ProtoError::Internal(ref msg) if msg == "config read failed")
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn process_request_internal_from_db() {
|
|
||||||
// id=99 → get_nonce returns InternalError1
|
|
||||||
let err = process_request(99, "ok").unwrap_err();
|
|
||||||
assert!(
|
|
||||||
matches!(err, crate::errors::ProtoError::Internal(ref msg) if msg == "db pool unavailable")
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,38 +0,0 @@
|
|||||||
use crate::errors::{InternalError1, InternalError2};
|
|
||||||
use terrors::OneOf;
|
|
||||||
|
|
||||||
// Simulates fetching a nonce from a database.
|
|
||||||
// id=99 → InternalError1 (pool unavailable)
|
|
||||||
// id=98 → InternalError2 (query timeout)
|
|
||||||
pub fn get_nonce(id: u32) -> Result<u32, OneOf<(InternalError1, InternalError2)>> {
|
|
||||||
match id {
|
|
||||||
99 => Err(OneOf::new(InternalError1("db pool unavailable".to_owned()))),
|
|
||||||
98 => Err(OneOf::new(InternalError2("query timeout".to_owned()))),
|
|
||||||
_ => Ok(42),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn get_nonce_returns_nonce_for_valid_id() {
|
|
||||||
assert_eq!(get_nonce(1).unwrap(), 42);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn get_nonce_returns_internal_error1_for_sentinel() {
|
|
||||||
let err = get_nonce(99).unwrap_err();
|
|
||||||
let internal = err.narrow::<crate::errors::InternalError1, _>().unwrap();
|
|
||||||
assert_eq!(internal.0, "db pool unavailable");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn get_nonce_returns_internal_error2_for_sentinel() {
|
|
||||||
let err = get_nonce(98).unwrap_err();
|
|
||||||
let e = err.narrow::<crate::errors::InternalError1, _>().unwrap_err();
|
|
||||||
let internal = e.take::<crate::errors::InternalError2>();
|
|
||||||
assert_eq!(internal.0, "query timeout");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,130 +0,0 @@
|
|||||||
use terrors::OneOf;
|
|
||||||
|
|
||||||
// Wire boundary type — what would go into a proto response
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum ProtoError {
|
|
||||||
NotRegistered,
|
|
||||||
InvalidSignature,
|
|
||||||
Internal(String), // Or Box<dyn Error>, who cares?
|
|
||||||
}
|
|
||||||
|
|
||||||
// Internal terrors types
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct NotRegistered;
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct InvalidSignature;
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct InternalError1(pub String);
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct InternalError2(pub String);
|
|
||||||
|
|
||||||
// Errors can be scattered across the codebase as long as they implement Into<ProtoError>
|
|
||||||
impl From<NotRegistered> for ProtoError {
|
|
||||||
fn from(_: NotRegistered) -> Self {
|
|
||||||
ProtoError::NotRegistered
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<InvalidSignature> for ProtoError {
|
|
||||||
fn from(_: InvalidSignature) -> Self {
|
|
||||||
ProtoError::InvalidSignature
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<InternalError1> for ProtoError {
|
|
||||||
fn from(e: InternalError1) -> Self {
|
|
||||||
ProtoError::Internal(e.0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
impl From<InternalError2> for ProtoError {
|
|
||||||
fn from(e: InternalError2) -> Self {
|
|
||||||
ProtoError::Internal(e.0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Private helper trait for converting from OneOf<T...> where each T can be converted
|
|
||||||
/// into the target type `O` by recursively narrowing until a match is found.
|
|
||||||
///
|
|
||||||
/// IDK why this isn't already in terrors.
|
|
||||||
trait DrainInto<O>: terrors::TypeSet + Sized {
|
|
||||||
fn drain(e: OneOf<Self>) -> O;
|
|
||||||
}
|
|
||||||
|
|
||||||
macro_rules! impl_drain_into {
|
|
||||||
($head:ident) => {
|
|
||||||
impl<$head, O> DrainInto<O> for ($head,)
|
|
||||||
where
|
|
||||||
$head: Into<O> + 'static,
|
|
||||||
{
|
|
||||||
fn drain(e: OneOf<($head,)>) -> O {
|
|
||||||
e.take().into()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
($head:ident, $($tail:ident),+) => {
|
|
||||||
impl<$head, $($tail),+, O> DrainInto<O> for ($head, $($tail),+)
|
|
||||||
where
|
|
||||||
$head: Into<O> + 'static,
|
|
||||||
($($tail,)+): DrainInto<O>,
|
|
||||||
{
|
|
||||||
fn drain(e: OneOf<($head, $($tail),+)>) -> O {
|
|
||||||
match e.narrow::<$head, _>() {
|
|
||||||
Ok(h) => h.into(),
|
|
||||||
Err(rest) => <($($tail,)+)>::drain(rest),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
impl_drain_into!($($tail),+);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generates impls for all tuple sizes from 1 up to 7 (restricted by terrors internal impl).
|
|
||||||
// Each invocation produces one impl then recurses on the tail.
|
|
||||||
impl_drain_into!(A, B, C, D, E, F, G, H, I);
|
|
||||||
|
|
||||||
// Blanket From impl: body delegates to the recursive drain.
|
|
||||||
impl<E: DrainInto<ProtoError>> From<OneOf<E>> for ProtoError {
|
|
||||||
fn from(e: OneOf<E>) -> Self {
|
|
||||||
E::drain(e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn not_registered_converts_to_proto() {
|
|
||||||
let e: ProtoError = NotRegistered.into();
|
|
||||||
assert!(matches!(e, ProtoError::NotRegistered));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn invalid_signature_converts_to_proto() {
|
|
||||||
let e: ProtoError = InvalidSignature.into();
|
|
||||||
assert!(matches!(e, ProtoError::InvalidSignature));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn internal_converts_to_proto() {
|
|
||||||
let e: ProtoError = InternalError1("boom".into()).into();
|
|
||||||
assert!(matches!(e, ProtoError::Internal(msg) if msg == "boom"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn one_of_remainder_converts_to_proto_invalid_signature() {
|
|
||||||
use terrors::OneOf;
|
|
||||||
let e: OneOf<(InvalidSignature, InternalError1)> = OneOf::new(InvalidSignature);
|
|
||||||
let proto = ProtoError::from(e);
|
|
||||||
assert!(matches!(proto, ProtoError::InvalidSignature));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn one_of_remainder_converts_to_proto_internal() {
|
|
||||||
use terrors::OneOf;
|
|
||||||
let e: OneOf<(InvalidSignature, InternalError1)> =
|
|
||||||
OneOf::new(InternalError1("db fail".into()));
|
|
||||||
let proto = ProtoError::from(e);
|
|
||||||
assert!(matches!(proto, ProtoError::Internal(msg) if msg == "db fail"));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
mod auth;
|
|
||||||
mod db;
|
|
||||||
mod errors;
|
|
||||||
|
|
||||||
use errors::ProtoError;
|
|
||||||
|
|
||||||
fn run(id: u32, sig: &str) {
|
|
||||||
print!("authenticate(id={id}, sig={sig:?}) => ");
|
|
||||||
match auth::authenticate(id, sig) {
|
|
||||||
Ok(nonce) => println!("Ok(nonce={nonce})"),
|
|
||||||
Err(e) => match e.narrow::<errors::NotRegistered, _>() {
|
|
||||||
Ok(_) => println!("Err(NotRegistered) — handled locally"),
|
|
||||||
Err(remaining) => {
|
|
||||||
let proto = ProtoError::from(remaining);
|
|
||||||
println!("Err(ProtoError::{proto:?}) — forwarded to wire");
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run_process(id: u32, sig: &str) {
|
|
||||||
print!("process_request(id={id}, sig={sig:?}) => ");
|
|
||||||
match auth::process_request(id, sig) {
|
|
||||||
Ok(s) => println!("Ok({s})"),
|
|
||||||
Err(e) => println!("Err(ProtoError::{e:?})"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() {
|
|
||||||
println!("=== authenticate ===");
|
|
||||||
run(0, "ok"); // NotRegistered
|
|
||||||
run(1, "bad"); // InvalidSignature
|
|
||||||
run(99, "ok"); // InternalError1
|
|
||||||
run(98, "ok"); // InternalError2
|
|
||||||
run(1, "ok"); // success
|
|
||||||
|
|
||||||
println!("\n=== process_request (Try chain) ===");
|
|
||||||
run_process(0, "ok"); // NotRegistered (guard, no I/O)
|
|
||||||
run_process(97, "ok"); // InternalError2 from load_config
|
|
||||||
run_process(99, "ok"); // InternalError1 from get_nonce
|
|
||||||
run_process(1, "bad"); // InvalidSignature from verify_signature
|
|
||||||
run_process(1, "ok"); // success
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user