diff --git a/protobufs/client.proto b/protobufs/client.proto index 62761c3..1f70371 100644 --- a/protobufs/client.proto +++ b/protobufs/client.proto @@ -23,6 +23,7 @@ message ClientRequest { oneof payload { AuthChallengeRequest auth_challenge_request = 1; AuthChallengeSolution auth_challenge_solution = 2; + arbiter.evm.EvmSignTransactionRequest evm_sign_transaction = 3; } } diff --git a/protobufs/user_agent.proto b/protobufs/user_agent.proto index fcf508d..2a7e3c0 100644 --- a/protobufs/user_agent.proto +++ b/protobufs/user_agent.proto @@ -12,6 +12,55 @@ enum KeyType { KEY_TYPE_RSA = 3; } +// --- SDK client management --- + +enum SdkClientError { + SDK_CLIENT_ERROR_UNSPECIFIED = 0; + SDK_CLIENT_ERROR_ALREADY_EXISTS = 1; + SDK_CLIENT_ERROR_NOT_FOUND = 2; + SDK_CLIENT_ERROR_HAS_RELATED_DATA = 3; // hard-delete blocked by FK (client has grants or transaction logs) + SDK_CLIENT_ERROR_INTERNAL = 4; +} + +message SdkClientApproveRequest { + bytes pubkey = 1; // 32-byte ed25519 public key +} + +message SdkClientRevokeRequest { + int32 client_id = 1; +} + +message SdkClientEntry { + int32 id = 1; + bytes pubkey = 2; + int32 created_at = 3; +} + +message SdkClientList { + repeated SdkClientEntry clients = 1; +} + +message SdkClientApproveResponse { + oneof result { + SdkClientEntry client = 1; + SdkClientError error = 2; + } +} + +message SdkClientRevokeResponse { + oneof result { + google.protobuf.Empty ok = 1; + SdkClientError error = 2; + } +} + +message SdkClientListResponse { + oneof result { + SdkClientList clients = 1; + SdkClientError error = 2; + } +} + message AuthChallengeRequest { bytes pubkey = 1; optional string bootstrap_token = 2; @@ -57,16 +106,6 @@ enum VaultState { VAULT_STATE_ERROR = 4; } -message ClientConnectionRequest { - bytes pubkey = 1; -} - -message ClientConnectionResponse { - bool approved = 1; -} - -message ClientConnectionCancel {} - message UserAgentRequest { oneof payload { AuthChallengeRequest auth_challenge_request = 1; @@ -79,7 +118,10 @@ message UserAgentRequest { arbiter.evm.EvmGrantCreateRequest evm_grant_create = 8; arbiter.evm.EvmGrantDeleteRequest evm_grant_delete = 9; arbiter.evm.EvmGrantListRequest evm_grant_list = 10; - ClientConnectionResponse client_connection_response = 11; + // field 11 reserved: was client_connection_response (online approval removed) + SdkClientApproveRequest sdk_client_approve = 12; + SdkClientRevokeRequest sdk_client_revoke = 13; + google.protobuf.Empty sdk_client_list = 14; } } message UserAgentResponse { @@ -94,7 +136,9 @@ message UserAgentResponse { arbiter.evm.EvmGrantCreateResponse evm_grant_create = 8; arbiter.evm.EvmGrantDeleteResponse evm_grant_delete = 9; arbiter.evm.EvmGrantListResponse evm_grant_list = 10; - ClientConnectionRequest client_connection_request = 11; - ClientConnectionCancel client_connection_cancel = 12; + // fields 11, 12 reserved: were client_connection_request, client_connection_cancel (online approval removed) + SdkClientApproveResponse sdk_client_approve = 13; + SdkClientRevokeResponse sdk_client_revoke = 14; + SdkClientListResponse sdk_client_list = 15; } } diff --git a/server/Cargo.lock b/server/Cargo.lock index 1586320..2d0b912 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -678,6 +678,18 @@ checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" [[package]] name = "arbiter-client" version = "0.1.0" +dependencies = [ + "alloy", + "arbiter-proto", + "async-trait", + "ed25519-dalek", + "http", + "rustls-webpki", + "thiserror", + "tokio", + "tokio-stream", + "tonic", +] [[package]] name = "arbiter-proto" diff --git a/server/crates/arbiter-client/Cargo.toml b/server/crates/arbiter-client/Cargo.toml index e71c9e7..597a26e 100644 --- a/server/crates/arbiter-client/Cargo.toml +++ b/server/crates/arbiter-client/Cargo.toml @@ -5,4 +5,18 @@ edition = "2024" repository = "https://git.markettakers.org/MarketTakers/arbiter" license = "Apache-2.0" +[lints] +workspace = true + [dependencies] +arbiter-proto.path = "../arbiter-proto" +alloy.workspace = true +tonic.workspace = true +tonic.features = ["tls-aws-lc"] +tokio.workspace = true +tokio-stream.workspace = true +ed25519-dalek.workspace = true +thiserror.workspace = true +http = "1.4.0" +rustls-webpki = { version = "0.103.9", features = ["aws-lc-rs"] } +async-trait.workspace = true diff --git a/server/crates/arbiter-client/src/lib.rs b/server/crates/arbiter-client/src/lib.rs index b93cf3f..322a9bb 100644 --- a/server/crates/arbiter-client/src/lib.rs +++ b/server/crates/arbiter-client/src/lib.rs @@ -1,14 +1,272 @@ -pub fn add(left: u64, right: u64) -> u64 { - left + right +use alloy::{ + consensus::SignableTransaction, + network::TxSigner, + primitives::{Address, B256, ChainId, Signature}, + signers::{Error, Result, Signer}, +}; +use arbiter_proto::{ + format_challenge, + proto::{ + arbiter_service_client::ArbiterServiceClient, + client::{ + AuthChallengeRequest, AuthChallengeSolution, ClientRequest, ClientResponse, + client_connect_error, client_request::Payload as ClientRequestPayload, + client_response::Payload as ClientResponsePayload, + }, + evm::{ + EvmSignTransactionRequest, evm_sign_transaction_response::Result as SignResponseResult, + }, + }, + url::ArbiterUrl, +}; +use async_trait::async_trait; +use ed25519_dalek::Signer as _; +use tokio::sync::{Mutex, mpsc}; +use tokio_stream::wrappers::ReceiverStream; +use tonic::transport::ClientTlsConfig; + +#[derive(Debug, thiserror::Error)] +pub enum ConnectError { + #[error("Could not establish connection")] + Connection(#[from] tonic::transport::Error), + + #[error("Invalid server URI")] + InvalidUri(#[from] http::uri::InvalidUri), + + #[error("Invalid CA certificate")] + InvalidCaCert(#[from] webpki::Error), + + #[error("gRPC error")] + Grpc(#[from] tonic::Status), + + #[error("Auth challenge was not returned by server")] + MissingAuthChallenge, + + #[error("Client approval denied by User Agent")] + ApprovalDenied, + + #[error("No User Agents online to approve client")] + NoUserAgentsOnline, + + #[error("Unexpected auth response payload")] + UnexpectedAuthResponse, } -#[cfg(test)] -mod tests { - use super::*; +#[derive(Debug, thiserror::Error)] +enum ClientSignError { + #[error("Transport channel closed")] + ChannelClosed, - #[test] - fn it_works() { - let result = add(2, 2); - assert_eq!(result, 4); + #[error("Connection closed by server")] + ConnectionClosed, + + #[error("Invalid response payload")] + InvalidResponse, + + #[error("Remote signing was rejected")] + Rejected, +} + +struct ClientTransport { + sender: mpsc::Sender, + receiver: tonic::Streaming, +} + +impl ClientTransport { + async fn send(&mut self, request: ClientRequest) -> std::result::Result<(), ClientSignError> { + self.sender + .send(request) + .await + .map_err(|_| ClientSignError::ChannelClosed) + } + + async fn recv(&mut self) -> std::result::Result { + match self.receiver.message().await { + Ok(Some(resp)) => Ok(resp), + Ok(None) => Err(ClientSignError::ConnectionClosed), + Err(_) => Err(ClientSignError::ConnectionClosed), + } + } +} + +pub struct ArbiterSigner { + transport: Mutex, + address: Address, + chain_id: Option, +} + +impl ArbiterSigner { + pub async fn connect_grpc( + url: ArbiterUrl, + key: ed25519_dalek::SigningKey, + address: Address, + ) -> std::result::Result { + let anchor = webpki::anchor_from_trusted_cert(&url.ca_cert)?.to_owned(); + let tls = ClientTlsConfig::new().trust_anchor(anchor); + + // NOTE: We intentionally keep the same URL construction strategy as the user-agent crate + // to avoid behavior drift between the two clients. + let channel = tonic::transport::Channel::from_shared(format!("{}:{}", url.host, url.port))? + .tls_config(tls)? + .connect() + .await?; + + let mut client = ArbiterServiceClient::new(channel); + let (tx, rx) = mpsc::channel(16); + let response_stream = client.client(ReceiverStream::new(rx)).await?.into_inner(); + + let mut transport = ClientTransport { + sender: tx, + receiver: response_stream, + }; + + authenticate(&mut transport, key).await?; + + Ok(Self { + transport: Mutex::new(transport), + address, + chain_id: None, + }) + } + + async fn sign_transaction_via_arbiter( + &self, + tx: &mut dyn SignableTransaction, + ) -> Result { + if let Some(chain_id) = self.chain_id + && !tx.set_chain_id_checked(chain_id) + { + return Err(Error::TransactionChainIdMismatch { + signer: chain_id, + tx: tx.chain_id().unwrap(), + }); + } + + let mut rlp_transaction = Vec::new(); + tx.encode_for_signing(&mut rlp_transaction); + + let request = ClientRequest { + payload: Some(ClientRequestPayload::EvmSignTransaction( + EvmSignTransactionRequest { + wallet_address: self.address.as_slice().to_vec(), + rlp_transaction, + }, + )), + }; + + let mut transport = self.transport.lock().await; + transport.send(request).await.map_err(Error::other)?; + let response = transport.recv().await.map_err(Error::other)?; + + let payload = response + .payload + .ok_or_else(|| Error::other(ClientSignError::InvalidResponse))?; + + let ClientResponsePayload::EvmSignTransaction(sign_response) = payload else { + return Err(Error::other(ClientSignError::InvalidResponse)); + }; + + let Some(result) = sign_response.result else { + return Err(Error::other(ClientSignError::InvalidResponse)); + }; + + match result { + SignResponseResult::Signature(bytes) => { + Signature::try_from(bytes.as_slice()).map_err(Error::other) + } + SignResponseResult::EvalError(_) | SignResponseResult::Error(_) => { + Err(Error::other(ClientSignError::Rejected)) + } + } + } +} + +async fn authenticate( + transport: &mut ClientTransport, + key: ed25519_dalek::SigningKey, +) -> std::result::Result<(), ConnectError> { + transport + .send(ClientRequest { + payload: Some(ClientRequestPayload::AuthChallengeRequest( + AuthChallengeRequest { + pubkey: key.verifying_key().to_bytes().to_vec(), + }, + )), + }) + .await + .map_err(|_| ConnectError::UnexpectedAuthResponse)?; + + let response = transport + .recv() + .await + .map_err(|_| ConnectError::MissingAuthChallenge)?; + + let payload = response.payload.ok_or(ConnectError::MissingAuthChallenge)?; + match payload { + ClientResponsePayload::AuthChallenge(challenge) => { + let challenge_payload = format_challenge(challenge.nonce, &challenge.pubkey); + let signature = key.sign(&challenge_payload).to_bytes().to_vec(); + + transport + .send(ClientRequest { + payload: Some(ClientRequestPayload::AuthChallengeSolution( + AuthChallengeSolution { signature }, + )), + }) + .await + .map_err(|_| ConnectError::UnexpectedAuthResponse)?; + + // Current server flow does not emit `AuthOk` for SDK clients, so we proceed after + // sending the solution. If authentication fails, the first business request will return + // a `ClientConnectError` or the stream will close. + Ok(()) + } + ClientResponsePayload::ClientConnectError(err) => { + match client_connect_error::Code::try_from(err.code) + .unwrap_or(client_connect_error::Code::Unknown) + { + client_connect_error::Code::ApprovalDenied => Err(ConnectError::ApprovalDenied), + client_connect_error::Code::NoUserAgentsOnline => { + Err(ConnectError::NoUserAgentsOnline) + } + client_connect_error::Code::Unknown => Err(ConnectError::UnexpectedAuthResponse), + } + } + _ => Err(ConnectError::UnexpectedAuthResponse), + } +} + +#[async_trait] +impl Signer for ArbiterSigner { + async fn sign_hash(&self, _hash: &B256) -> Result { + Err(Error::other( + "hash-only signing is not supported for ArbiterSigner; use transaction signing", + )) + } + + fn address(&self) -> Address { + self.address + } + + fn chain_id(&self) -> Option { + self.chain_id + } + + fn set_chain_id(&mut self, chain_id: Option) { + self.chain_id = chain_id; + } +} + +#[async_trait] +impl TxSigner for ArbiterSigner { + fn address(&self) -> Address { + self.address + } + + async fn sign_transaction( + &self, + tx: &mut dyn SignableTransaction, + ) -> Result { + self.sign_transaction_via_arbiter(tx).await } } diff --git a/server/crates/arbiter-server/migrations/2026-03-15-103112-0000_add_program_client_pubkey_unique/down.sql b/server/crates/arbiter-server/migrations/2026-03-15-103112-0000_add_program_client_pubkey_unique/down.sql new file mode 100644 index 0000000..aeda4ed --- /dev/null +++ b/server/crates/arbiter-server/migrations/2026-03-15-103112-0000_add_program_client_pubkey_unique/down.sql @@ -0,0 +1 @@ +DROP INDEX IF EXISTS program_client_public_key_unique; diff --git a/server/crates/arbiter-server/migrations/2026-03-15-103112-0000_add_program_client_pubkey_unique/up.sql b/server/crates/arbiter-server/migrations/2026-03-15-103112-0000_add_program_client_pubkey_unique/up.sql new file mode 100644 index 0000000..5d1d4a2 --- /dev/null +++ b/server/crates/arbiter-server/migrations/2026-03-15-103112-0000_add_program_client_pubkey_unique/up.sql @@ -0,0 +1,2 @@ +CREATE UNIQUE INDEX program_client_public_key_unique + ON program_client (public_key); diff --git a/server/crates/arbiter-server/src/actors/client/auth.rs b/server/crates/arbiter-server/src/actors/client/auth.rs index cb11d9a..649b987 100644 --- a/server/crates/arbiter-server/src/actors/client/auth.rs +++ b/server/crates/arbiter-server/src/actors/client/auth.rs @@ -8,19 +8,13 @@ use arbiter_proto::{ }, transport::expect_message, }; -use diesel::{ - ExpressionMethods as _, OptionalExtension as _, QueryDsl as _, dsl::insert_into, update, -}; +use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl as _, update}; use diesel_async::RunQueryDsl as _; use ed25519_dalek::VerifyingKey; -use kameo::error::SendError; use tracing::error; use crate::{ - actors::{ - client::ClientConnection, - router::{self, RequestClientApproval}, - }, + actors::client::ClientConnection, db::{self, schema::program_client}, }; @@ -40,27 +34,20 @@ pub enum Error { DatabaseOperationFailed, #[error("Invalid challenge solution")] InvalidChallengeSolution, - #[error("Client approval request failed")] - ApproveError(#[from] ApproveError), + #[error("Client not registered")] + NotRegistered, #[error("Internal error")] InternalError, #[error("Transport error")] Transport, } -#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)] -pub enum ApproveError { - #[error("Internal error")] - Internal, - #[error("Client connection denied by user agents")] - Denied, - #[error("Upstream error: {0}")] - Upstream(router::ApprovalError), -} - /// Atomically reads and increments the nonce for a known client. /// Returns `None` if the pubkey is not registered. -async fn get_nonce(db: &db::DatabasePool, pubkey: &VerifyingKey) -> Result, Error> { +async fn get_nonce( + db: &db::DatabasePool, + pubkey: &VerifyingKey, +) -> Result, Error> { let pubkey_bytes = pubkey.as_bytes().to_vec(); let mut conn = db.get().await.map_err(|e| { @@ -71,10 +58,10 @@ async fn get_nonce(db: &db::DatabasePool, pubkey: &VerifyingKey) -> Result(conn) + .select((program_client::id, program_client::nonce)) + .first::<(i32, i32)>(conn) .await .optional()? else { @@ -87,7 +74,7 @@ async fn get_nonce(db: &db::DatabasePool, pubkey: &VerifyingKey) -> Result Result Result<(), Error> { - let result = actors - .router - .ask(RequestClientApproval { - client_pubkey: pubkey, - }) - .await; - - match result { - Ok(true) => Ok(()), - Ok(false) => Err(Error::ApproveError(ApproveError::Denied)), - Err(SendError::HandlerError(e)) => { - error!(error = ?e, "Approval upstream error"); - Err(Error::ApproveError(ApproveError::Upstream(e))) - } - Err(e) => { - error!(error = ?e, "Approval request to router failed"); - Err(Error::ApproveError(ApproveError::Internal)) - } - } -} - -async fn insert_client(db: &db::DatabasePool, pubkey: &VerifyingKey) -> Result<(), Error> { - let now = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_secs() as i32; - - let mut conn = db.get().await.map_err(|e| { - error!(error = ?e, "Database pool error"); - Error::DatabasePoolUnavailable - })?; - - insert_into(program_client::table) - .values(( - program_client::public_key.eq(pubkey.as_bytes().to_vec()), - program_client::nonce.eq(1), // pre-incremented; challenge uses 0 - program_client::created_at.eq(now), - program_client::updated_at.eq(now), - )) - .execute(&mut conn) - .await - .map_err(|e| { - error!(error = ?e, "Failed to insert new client"); - Error::DatabaseOperationFailed - })?; - - Ok(()) -} - async fn challenge_client( props: &mut ClientConnection, pubkey: VerifyingKey, @@ -200,15 +134,12 @@ async fn challenge_client( fn connect_error_code(err: &Error) -> ConnectErrorCode { match err { - Error::ApproveError(ApproveError::Denied) => ConnectErrorCode::ApprovalDenied, - Error::ApproveError(ApproveError::Upstream( - router::ApprovalError::NoUserAgentsConnected, - )) => ConnectErrorCode::NoUserAgentsOnline, + Error::NotRegistered => ConnectErrorCode::ApprovalDenied, _ => ConnectErrorCode::Unknown, } } -async fn authenticate(props: &mut ClientConnection) -> Result { +async fn authenticate(props: &mut ClientConnection) -> Result<(VerifyingKey, i32), Error> { let Some(ClientRequest { payload: Some(ClientRequestPayload::AuthChallengeRequest(challenge)), }) = props.transport.recv().await @@ -223,23 +154,19 @@ async fn authenticate(props: &mut ClientConnection) -> Result nonce, - None => { - approve_new_client(&props.actors, pubkey).await?; - insert_client(&props.db, &pubkey).await?; - 0 - } + let (client_id, nonce) = match get_nonce(&props.db, &pubkey).await? { + Some((client_id, nonce)) => (client_id, nonce), + None => return Err(Error::NotRegistered), }; challenge_client(props, pubkey, nonce).await?; - Ok(pubkey) + Ok((pubkey, client_id)) } pub async fn authenticate_and_create(mut props: ClientConnection) -> Result { match authenticate(&mut props).await { - Ok(_pubkey) => Ok(ClientSession::new(props)), + Ok((_pubkey, client_id)) => Ok(ClientSession::new(props, client_id)), Err(err) => { let code = connect_error_code(&err); let _ = props diff --git a/server/crates/arbiter-server/src/actors/client/session.rs b/server/crates/arbiter-server/src/actors/client/session.rs index a2ae4a4..63d19b6 100644 --- a/server/crates/arbiter-server/src/actors/client/session.rs +++ b/server/crates/arbiter-server/src/actors/client/session.rs @@ -1,19 +1,35 @@ -use arbiter_proto::proto::client::{ClientRequest, ClientResponse}; +use alloy::{consensus::TxEip1559, primitives::Address, rlp::Decodable}; +use arbiter_proto::proto::{ + client::{ + ClientRequest, ClientResponse, client_request::Payload as ClientRequestPayload, + client_response::Payload as ClientResponsePayload, + }, + evm::{ + EvmError, EvmSignTransactionResponse, evm_sign_transaction_response::Result as SignResult, + }, +}; use kameo::Actor; use tokio::select; use tracing::{error, info}; -use crate::{actors::{ - GlobalActors, client::{ClientError, ClientConnection}, router::RegisterClient -}, db}; +use crate::{ + actors::{ + GlobalActors, + client::{ClientConnection, ClientError}, + evm::ClientSignTransaction, + router::RegisterClient, + }, + db, +}; pub struct ClientSession { props: ClientConnection, + client_id: i32, } impl ClientSession { - pub(crate) fn new(props: ClientConnection) -> Self { - Self { props } + pub(crate) fn new(props: ClientConnection, client_id: i32) -> Self { + Self { props, client_id } } pub async fn process_transport_inbound(&mut self, req: ClientRequest) -> Output { @@ -22,8 +38,46 @@ impl ClientSession { ClientError::MissingRequestPayload })?; - let _ = msg; - Err(ClientError::UnexpectedRequestPayload) + match msg { + ClientRequestPayload::EvmSignTransaction(sign_req) => { + let wallet_address: [u8; 20] = sign_req + .wallet_address + .try_into() + .map_err(|_| ClientError::UnexpectedRequestPayload)?; + + let mut rlp_bytes: &[u8] = &sign_req.rlp_transaction; + let tx = TxEip1559::decode(&mut rlp_bytes) + .map_err(|_| ClientError::UnexpectedRequestPayload)?; + + let result = self + .props + .actors + .evm + .ask(ClientSignTransaction { + client_id: self.client_id, + wallet_address: Address::from_slice(&wallet_address), + transaction: tx, + }) + .await; + + let response_result = match result { + Ok(signature) => SignResult::Signature(signature.as_bytes().to_vec()), + Err(err) => { + error!(?err, "client sign transaction failed"); + SignResult::Error(EvmError::Internal.into()) + } + }; + + Ok(ClientResponse { + payload: Some(ClientResponsePayload::EvmSignTransaction( + EvmSignTransactionResponse { + result: Some(response_result), + }, + )), + }) + } + _ => Err(ClientError::UnexpectedRequestPayload), + } } } @@ -89,6 +143,9 @@ impl ClientSession { use arbiter_proto::transport::DummyTransport; let transport: super::Transport = Box::new(DummyTransport::new()); let props = ClientConnection::new(db, transport, actors); - Self { props } + Self { + props, + client_id: 0, + } } } diff --git a/server/crates/arbiter-server/src/actors/router/mod.rs b/server/crates/arbiter-server/src/actors/router/mod.rs index a0a75b8..8d06152 100644 --- a/server/crates/arbiter-server/src/actors/router/mod.rs +++ b/server/crates/arbiter-server/src/actors/router/mod.rs @@ -1,20 +1,14 @@ use std::{collections::HashMap, ops::ControlFlow}; -use ed25519_dalek::VerifyingKey; use kameo::{ Actor, actor::{ActorId, ActorRef}, messages, prelude::{ActorStopReason, Context, WeakActorRef}, - reply::DelegatedReply, }; -use tokio::{sync::watch, task::JoinSet}; -use tracing::{info, warn}; +use tracing::info; -use crate::actors::{ - client::session::ClientSession, - user_agent::session::{RequestNewClientApproval, UserAgentSession}, -}; +use crate::actors::{client::session::ClientSession, user_agent::session::UserAgentSession}; #[derive(Default)] pub struct MessageRouter { @@ -56,73 +50,6 @@ impl Actor for MessageRouter { } } -#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq, Hash)] -pub enum ApprovalError { - #[error("No user agents connected")] - NoUserAgentsConnected, -} - -async fn request_client_approval( - user_agents: &[WeakActorRef], - client_pubkey: VerifyingKey, -) -> Result { - if user_agents.is_empty() { - return Err(ApprovalError::NoUserAgentsConnected); - } - - let mut pool = JoinSet::new(); - let (cancel_tx, cancel_rx) = watch::channel(()); - - for weak_ref in user_agents { - match weak_ref.upgrade() { - Some(agent) => { - let cancel_rx = cancel_rx.clone(); - pool.spawn(async move { - agent - .ask(RequestNewClientApproval { - client_pubkey, - cancel_flag: cancel_rx.clone(), - }) - .await - }); - } - None => { - warn!( - id = weak_ref.id().to_string(), - actor = "MessageRouter", - event = "useragent.disconnected_before_approval" - ); - } - } - } - - while let Some(result) = pool.join_next().await { - match result { - Ok(Ok(approved)) => { - // cancel other pending requests - let _ = cancel_tx.send(()); - return Ok(approved); - } - Ok(Err(err)) => { - warn!( - ?err, - actor = "MessageRouter", - event = "useragent.approval_error" - ); - } - Err(err) => { - warn!( - ?err, - actor = "MessageRouter", - event = "useragent.approval_task_failed" - ); - } - } - } - - Err(ApprovalError::NoUserAgentsConnected) -} - #[messages] impl MessageRouter { #[message(ctx)] @@ -146,29 +73,4 @@ impl MessageRouter { ctx.actor_ref().link(&actor).await; self.clients.insert(actor.id(), actor); } - - #[message(ctx)] - pub async fn request_client_approval( - &mut self, - client_pubkey: VerifyingKey, - ctx: &mut Context>>, - ) -> DelegatedReply> { - let (reply, Some(reply_sender)) = ctx.reply_sender() else { - panic!("Exptected `request_client_approval` to have callback channel"); - }; - - let weak_refs = self - .user_agents - .values() - .map(|agent| agent.downgrade()) - .collect::>(); - - // handle in subtask to not to lock the actor - tokio::task::spawn(async move { - let result = request_client_approval(&weak_refs, client_pubkey).await; - reply_sender.send(result); - }); - - reply - } } diff --git a/server/crates/arbiter-server/src/actors/user_agent/session.rs b/server/crates/arbiter-server/src/actors/user_agent/session.rs index b686796..a19e85b 100644 --- a/server/crates/arbiter-server/src/actors/user_agent/session.rs +++ b/server/crates/arbiter-server/src/actors/user_agent/session.rs @@ -3,25 +3,32 @@ use std::{ops::DerefMut, sync::Mutex}; use arbiter_proto::proto::{ evm as evm_proto, user_agent::{ - ClientConnectionCancel, ClientConnectionRequest, UnsealEncryptedKey, UnsealResult, + SdkClientApproveRequest, SdkClientApproveResponse, SdkClientEntry, + SdkClientError as ProtoSdkClientError, SdkClientList, SdkClientListResponse, + SdkClientRevokeRequest, SdkClientRevokeResponse, UnsealEncryptedKey, UnsealResult, UnsealStart, UnsealStartResponse, UserAgentRequest, UserAgentResponse, + sdk_client_approve_response, sdk_client_list_response, sdk_client_revoke_response, user_agent_request::Payload as UserAgentRequestPayload, user_agent_response::Payload as UserAgentResponsePayload, }, }; use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit}; -use ed25519_dalek::VerifyingKey; -use kameo::{Actor, error::SendError, messages, prelude::Context}; +use diesel::{ExpressionMethods as _, QueryDsl as _, dsl::insert_into}; +use diesel_async::RunQueryDsl as _; +use kameo::{Actor, error::SendError, prelude::Context}; use memsafe::MemSafe; -use tokio::{select, sync::watch}; +use tokio::select; use tracing::{error, info}; use x25519_dalek::{EphemeralSecret, PublicKey}; -use crate::actors::{ - evm::{Generate, ListWallets}, - keyholder::{self, TryUnseal}, - router::RegisterUserAgent, - user_agent::{TransportResponseError, UserAgentConnection}, +use crate::{ + actors::{ + evm::{Generate, ListWallets}, + keyholder::{self, TryUnseal}, + router::RegisterUserAgent, + user_agent::{TransportResponseError, UserAgentConnection}, + }, + db::schema::program_client, }; mod state; @@ -108,52 +115,6 @@ impl UserAgentSession { } } -#[messages] -impl UserAgentSession { - // TODO: Think about refactoring it to state-machine based flow, as we already have one - #[message(ctx)] - pub async fn request_new_client_approval( - &mut self, - client_pubkey: VerifyingKey, - mut cancel_flag: watch::Receiver<()>, - ctx: &mut Context>, - ) -> Result { - self.send_msg( - UserAgentResponsePayload::ClientConnectionRequest(ClientConnectionRequest { - pubkey: client_pubkey.as_bytes().to_vec(), - }), - ctx, - ) - .await?; - - let extractor = |msg| { - if let UserAgentRequestPayload::ClientConnectionResponse(client_connection_response) = - msg - { - Some(client_connection_response) - } else { - None - } - }; - - tokio::select! { - _ = cancel_flag.changed() => { - info!(actor = "useragent", "client connection approval cancelled"); - self.send_msg( - UserAgentResponsePayload::ClientConnectionCancel(ClientConnectionCancel {}), - ctx, - ).await?; - Ok(false) - } - result = self.expect_msg(extractor, ctx) => { - let result = result?; - info!(actor = "useragent", "received client connection approval result: approved={}", result.approved); - Ok(result.approved) - } - } - } -} - impl UserAgentSession { pub async fn process_transport_inbound(&mut self, req: UserAgentRequest) -> Output { let msg = req.payload.ok_or_else(|| { @@ -170,6 +131,13 @@ impl UserAgentSession { } UserAgentRequestPayload::EvmWalletCreate(_) => self.handle_evm_wallet_create().await, UserAgentRequestPayload::EvmWalletList(_) => self.handle_evm_wallet_list().await, + UserAgentRequestPayload::SdkClientApprove(req) => { + self.handle_sdk_client_approve(req).await + } + UserAgentRequestPayload::SdkClientRevoke(req) => { + self.handle_sdk_client_revoke(req).await + } + UserAgentRequestPayload::SdkClientList(_) => self.handle_sdk_client_list().await, _ => Err(TransportResponseError::UnexpectedRequestPayload), } } @@ -331,6 +299,204 @@ impl UserAgentSession { } } +impl UserAgentSession { + async fn handle_sdk_client_approve(&mut self, req: SdkClientApproveRequest) -> Output { + use sdk_client_approve_response::Result as ApproveResult; + + if req.pubkey.len() != 32 { + return Ok(response(UserAgentResponsePayload::SdkClientApprove( + SdkClientApproveResponse { + result: Some(ApproveResult::Error(ProtoSdkClientError::Internal.into())), + }, + ))); + } + + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() as i32; + + let mut conn = match self.props.db.get().await { + Ok(c) => c, + Err(e) => { + error!(?e, "Failed to get DB connection for sdk_client_approve"); + return Ok(response(UserAgentResponsePayload::SdkClientApprove( + SdkClientApproveResponse { + result: Some(ApproveResult::Error(ProtoSdkClientError::Internal.into())), + }, + ))); + } + }; + + let pubkey_bytes = req.pubkey.clone(); + let insert_result = insert_into(program_client::table) + .values(( + program_client::public_key.eq(&pubkey_bytes), + program_client::nonce.eq(1), // pre-incremented; challenge will use nonce=0 + program_client::created_at.eq(now), + program_client::updated_at.eq(now), + )) + .execute(&mut conn) + .await; + + match insert_result { + Ok(_) => { + match program_client::table + .filter(program_client::public_key.eq(&pubkey_bytes)) + .order(program_client::id.desc()) + .select(( + program_client::id, + program_client::public_key, + program_client::created_at, + )) + .first::<(i32, Vec, i32)>(&mut conn) + .await + { + Ok((id, pubkey, created_at)) => Ok(response( + UserAgentResponsePayload::SdkClientApprove(SdkClientApproveResponse { + result: Some(ApproveResult::Client(SdkClientEntry { + id, + pubkey, + created_at, + })), + }), + )), + Err(e) => { + error!(?e, "Failed to fetch inserted SDK client"); + Ok(response(UserAgentResponsePayload::SdkClientApprove( + SdkClientApproveResponse { + result: Some(ApproveResult::Error( + ProtoSdkClientError::Internal.into(), + )), + }, + ))) + } + } + } + Err(diesel::result::Error::DatabaseError( + diesel::result::DatabaseErrorKind::UniqueViolation, + _, + )) => Ok(response(UserAgentResponsePayload::SdkClientApprove( + SdkClientApproveResponse { + result: Some(ApproveResult::Error( + ProtoSdkClientError::AlreadyExists.into(), + )), + }, + ))), + Err(e) => { + error!(?e, "Failed to insert SDK client"); + Ok(response(UserAgentResponsePayload::SdkClientApprove( + SdkClientApproveResponse { + result: Some(ApproveResult::Error(ProtoSdkClientError::Internal.into())), + }, + ))) + } + } + } + + async fn handle_sdk_client_list(&mut self) -> Output { + let mut conn = match self.props.db.get().await { + Ok(c) => c, + Err(e) => { + error!(?e, "Failed to get DB connection for sdk_client_list"); + return Ok(response(UserAgentResponsePayload::SdkClientList( + SdkClientListResponse { + result: Some(sdk_client_list_response::Result::Error( + ProtoSdkClientError::Internal.into(), + )), + }, + ))); + } + }; + + match program_client::table + .select(( + program_client::id, + program_client::public_key, + program_client::created_at, + )) + .load::<(i32, Vec, i32)>(&mut conn) + .await + { + Ok(rows) => Ok(response(UserAgentResponsePayload::SdkClientList( + SdkClientListResponse { + result: Some(sdk_client_list_response::Result::Clients(SdkClientList { + clients: rows + .into_iter() + .map(|(id, pubkey, created_at)| SdkClientEntry { + id, + pubkey, + created_at, + }) + .collect(), + })), + }, + ))), + Err(e) => { + error!(?e, "Failed to list SDK clients"); + Ok(response(UserAgentResponsePayload::SdkClientList( + SdkClientListResponse { + result: Some(sdk_client_list_response::Result::Error( + ProtoSdkClientError::Internal.into(), + )), + }, + ))) + } + } + } + + async fn handle_sdk_client_revoke(&mut self, req: SdkClientRevokeRequest) -> Output { + use sdk_client_revoke_response::Result as RevokeResult; + + let mut conn = match self.props.db.get().await { + Ok(c) => c, + Err(e) => { + error!(?e, "Failed to get DB connection for sdk_client_revoke"); + return Ok(response(UserAgentResponsePayload::SdkClientRevoke( + SdkClientRevokeResponse { + result: Some(RevokeResult::Error(ProtoSdkClientError::Internal.into())), + }, + ))); + } + }; + + match diesel::delete(program_client::table) + .filter(program_client::id.eq(req.client_id)) + .execute(&mut conn) + .await + { + Ok(0) => Ok(response(UserAgentResponsePayload::SdkClientRevoke( + SdkClientRevokeResponse { + result: Some(RevokeResult::Error(ProtoSdkClientError::NotFound.into())), + }, + ))), + Ok(_) => Ok(response(UserAgentResponsePayload::SdkClientRevoke( + SdkClientRevokeResponse { + result: Some(RevokeResult::Ok(())), + }, + ))), + Err(diesel::result::Error::DatabaseError( + diesel::result::DatabaseErrorKind::ForeignKeyViolation, + _, + )) => Ok(response(UserAgentResponsePayload::SdkClientRevoke( + SdkClientRevokeResponse { + result: Some(RevokeResult::Error( + ProtoSdkClientError::HasRelatedData.into(), + )), + }, + ))), + Err(e) => { + error!(?e, "Failed to delete SDK client"); + Ok(response(UserAgentResponsePayload::SdkClientRevoke( + SdkClientRevokeResponse { + result: Some(RevokeResult::Error(ProtoSdkClientError::Internal.into())), + }, + ))) + } + } + } +} + fn map_evm_error(op: &str, err: SendError) -> evm_proto::EvmError { use crate::actors::{evm::Error as EvmError, keyholder::Error as KhError}; match err { diff --git a/server/crates/arbiter-server/src/lib.rs b/server/crates/arbiter-server/src/lib.rs index d712992..abb51a5 100644 --- a/server/crates/arbiter-server/src/lib.rs +++ b/server/crates/arbiter-server/src/lib.rs @@ -79,7 +79,7 @@ fn client_auth_error_status(value: &client::auth::Error) -> Status { Status::invalid_argument("Failed to convert pubkey to VerifyingKey") } Error::InvalidChallengeSolution => Status::unauthenticated(value.to_string()), - Error::ApproveError(_) => Status::permission_denied(value.to_string()), + Error::NotRegistered => Status::permission_denied(value.to_string()), Error::Transport => Status::internal("Transport error"), Error::DatabasePoolUnavailable => Status::internal("Database pool error"), Error::DatabaseOperationFailed => Status::internal("Database error"), diff --git a/server/crates/arbiter-server/tests/client/auth.rs b/server/crates/arbiter-server/tests/client/auth.rs index 6228a58..5d82423 100644 --- a/server/crates/arbiter-server/tests/client/auth.rs +++ b/server/crates/arbiter-server/tests/client/auth.rs @@ -1,7 +1,15 @@ -use arbiter_proto::proto::client::{ - AuthChallengeRequest, AuthChallengeSolution, ClientRequest, - client_request::Payload as ClientRequestPayload, - client_response::Payload as ClientResponsePayload, +use alloy::{ + consensus::TxEip1559, + primitives::{Address, Bytes, TxKind, U256}, + rlp::Encodable, +}; +use arbiter_proto::proto::{ + client::{ + AuthChallengeRequest, AuthChallengeSolution, ClientRequest, + client_request::Payload as ClientRequestPayload, + client_response::Payload as ClientResponsePayload, + }, + evm::EvmSignTransactionRequest, }; use arbiter_proto::transport::Bi; use arbiter_server::actors::GlobalActors; @@ -109,3 +117,106 @@ pub async fn test_challenge_auth() { // Auth completes, session spawned task.await.unwrap(); } + +#[tokio::test] +#[test_log::test] +pub async fn test_evm_sign_request_payload_is_handled() { + let db = db::create_test_pool().await; + + let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); + let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); + + { + let mut conn = db.get().await.unwrap(); + insert_into(schema::program_client::table) + .values(schema::program_client::public_key.eq(pubkey_bytes.clone())) + .execute(&mut conn) + .await + .unwrap(); + } + + let (server_transport, mut test_transport) = ChannelTransport::new(); + let actors = GlobalActors::spawn(db.clone()).await.unwrap(); + + let props = ClientConnection::new(db.clone(), Box::new(server_transport), actors); + let task = tokio::spawn(connect_client(props)); + + test_transport + .send(ClientRequest { + payload: Some(ClientRequestPayload::AuthChallengeRequest( + AuthChallengeRequest { + pubkey: pubkey_bytes, + }, + )), + }) + .await + .unwrap(); + + let response = test_transport + .recv() + .await + .expect("should receive challenge"); + let challenge = match response { + Ok(resp) => match resp.payload { + Some(ClientResponsePayload::AuthChallenge(c)) => c, + other => panic!("Expected AuthChallenge, got {other:?}"), + }, + Err(err) => panic!("Expected Ok response, got Err({err:?})"), + }; + + let formatted_challenge = arbiter_proto::format_challenge(challenge.nonce, &challenge.pubkey); + let signature = new_key.sign(&formatted_challenge); + + test_transport + .send(ClientRequest { + payload: Some(ClientRequestPayload::AuthChallengeSolution( + AuthChallengeSolution { + signature: signature.to_bytes().to_vec(), + }, + )), + }) + .await + .unwrap(); + + task.await.unwrap(); + + let tx = TxEip1559 { + chain_id: 1, + nonce: 0, + gas_limit: 21_000, + max_fee_per_gas: 1, + max_priority_fee_per_gas: 1, + to: TxKind::Call(Address::from_slice(&[0x11; 20])), + value: U256::ZERO, + input: Bytes::new(), + access_list: Default::default(), + }; + + let mut rlp_transaction = Vec::new(); + tx.encode(&mut rlp_transaction); + + test_transport + .send(ClientRequest { + payload: Some(ClientRequestPayload::EvmSignTransaction( + EvmSignTransactionRequest { + wallet_address: [0x22; 20].to_vec(), + rlp_transaction, + }, + )), + }) + .await + .unwrap(); + + let response = test_transport + .recv() + .await + .expect("should receive sign response"); + + match response { + Ok(resp) => match resp.payload { + Some(ClientResponsePayload::EvmSignTransaction(_)) => {} + other => panic!("Expected EvmSignTransaction response, got {other:?}"), + }, + Err(err) => panic!("Expected Ok response, got Err({err:?})"), + } +} diff --git a/server/crates/arbiter-server/tests/user_agent.rs b/server/crates/arbiter-server/tests/user_agent.rs index dcd9789..355a721 100644 --- a/server/crates/arbiter-server/tests/user_agent.rs +++ b/server/crates/arbiter-server/tests/user_agent.rs @@ -2,5 +2,7 @@ mod common; #[path = "user_agent/auth.rs"] mod auth; +#[path = "user_agent/sdk_client.rs"] +mod sdk_client; #[path = "user_agent/unseal.rs"] mod unseal; diff --git a/server/crates/arbiter-server/tests/user_agent/sdk_client.rs b/server/crates/arbiter-server/tests/user_agent/sdk_client.rs new file mode 100644 index 0000000..3e2734a --- /dev/null +++ b/server/crates/arbiter-server/tests/user_agent/sdk_client.rs @@ -0,0 +1,270 @@ +use arbiter_proto::proto::user_agent::{ + SdkClientApproveRequest, SdkClientError as ProtoSdkClientError, SdkClientRevokeRequest, + UserAgentRequest, sdk_client_approve_response, sdk_client_list_response, + sdk_client_revoke_response, user_agent_request::Payload as UserAgentRequestPayload, + user_agent_response::Payload as UserAgentResponsePayload, +}; +use arbiter_server::{ + actors::{GlobalActors, user_agent::session::UserAgentSession}, + db, +}; + +/// Shared helper: create a session and register a client pubkey via sdk_client_approve. +async fn make_session(db: &db::DatabasePool) -> UserAgentSession { + let actors = GlobalActors::spawn(db.clone()).await.unwrap(); + UserAgentSession::new_test(db.clone(), actors) +} + +#[tokio::test] +#[test_log::test] +async fn test_sdk_client_approve_registers_client() { + let db = db::create_test_pool().await; + let mut session = make_session(&db).await; + + let pubkey = [0x42u8; 32]; + + let response = session + .process_transport_inbound(UserAgentRequest { + payload: Some(UserAgentRequestPayload::SdkClientApprove( + SdkClientApproveRequest { + pubkey: pubkey.to_vec(), + }, + )), + }) + .await + .expect("handler should succeed"); + + let entry = match response.payload.unwrap() { + UserAgentResponsePayload::SdkClientApprove(resp) => match resp.result.unwrap() { + sdk_client_approve_response::Result::Client(e) => e, + sdk_client_approve_response::Result::Error(e) => { + panic!("Expected Client, got error {:?}", e) + } + }, + other => panic!("Expected SdkClientApprove, got {other:?}"), + }; + + assert_eq!(entry.pubkey, pubkey.to_vec()); + assert!(entry.id > 0); +} + +#[tokio::test] +#[test_log::test] +async fn test_sdk_client_approve_duplicate_returns_already_exists() { + let db = db::create_test_pool().await; + let mut session = make_session(&db).await; + + let pubkey = [0x11u8; 32]; + let req = UserAgentRequest { + payload: Some(UserAgentRequestPayload::SdkClientApprove( + SdkClientApproveRequest { + pubkey: pubkey.to_vec(), + }, + )), + }; + + session + .process_transport_inbound(req.clone()) + .await + .unwrap(); + + let response = session + .process_transport_inbound(req) + .await + .expect("second insert should not panic"); + + match response.payload.unwrap() { + UserAgentResponsePayload::SdkClientApprove(resp) => match resp.result.unwrap() { + sdk_client_approve_response::Result::Error(code) => { + assert_eq!(code, ProtoSdkClientError::AlreadyExists as i32); + } + sdk_client_approve_response::Result::Client(_) => { + panic!("Expected AlreadyExists error for duplicate pubkey") + } + }, + other => panic!("Expected SdkClientApprove, got {other:?}"), + } +} + +#[tokio::test] +#[test_log::test] +async fn test_sdk_client_list_shows_registered_clients() { + let db = db::create_test_pool().await; + let mut session = make_session(&db).await; + + let pubkey_a = [0x0Au8; 32]; + let pubkey_b = [0x0Bu8; 32]; + + for pubkey in [pubkey_a, pubkey_b] { + session + .process_transport_inbound(UserAgentRequest { + payload: Some(UserAgentRequestPayload::SdkClientApprove( + SdkClientApproveRequest { + pubkey: pubkey.to_vec(), + }, + )), + }) + .await + .unwrap(); + } + + let response = session + .process_transport_inbound(UserAgentRequest { + payload: Some(UserAgentRequestPayload::SdkClientList(())), + }) + .await + .expect("list should succeed"); + + let clients = match response.payload.unwrap() { + UserAgentResponsePayload::SdkClientList(resp) => match resp.result.unwrap() { + sdk_client_list_response::Result::Clients(list) => list.clients, + sdk_client_list_response::Result::Error(e) => { + panic!("Expected Clients, got error {:?}", e) + } + }, + other => panic!("Expected SdkClientList, got {other:?}"), + }; + + assert_eq!(clients.len(), 2); + let pubkeys: Vec> = clients.into_iter().map(|e| e.pubkey).collect(); + assert!(pubkeys.contains(&pubkey_a.to_vec())); + assert!(pubkeys.contains(&pubkey_b.to_vec())); +} + +#[tokio::test] +#[test_log::test] +async fn test_sdk_client_revoke_removes_client() { + let db = db::create_test_pool().await; + let mut session = make_session(&db).await; + + let pubkey = [0xBBu8; 32]; + + // Register a client and get its id + let approve_response = session + .process_transport_inbound(UserAgentRequest { + payload: Some(UserAgentRequestPayload::SdkClientApprove( + SdkClientApproveRequest { + pubkey: pubkey.to_vec(), + }, + )), + }) + .await + .unwrap(); + + let client_id = match approve_response.payload.unwrap() { + UserAgentResponsePayload::SdkClientApprove(resp) => match resp.result.unwrap() { + sdk_client_approve_response::Result::Client(e) => e.id, + sdk_client_approve_response::Result::Error(e) => panic!("approve failed: {:?}", e), + }, + other => panic!("{other:?}"), + }; + + // Revoke the client + let revoke_response = session + .process_transport_inbound(UserAgentRequest { + payload: Some(UserAgentRequestPayload::SdkClientRevoke( + SdkClientRevokeRequest { client_id }, + )), + }) + .await + .expect("revoke should succeed"); + + match revoke_response.payload.unwrap() { + UserAgentResponsePayload::SdkClientRevoke(resp) => match resp.result.unwrap() { + sdk_client_revoke_response::Result::Ok(_) => {} + sdk_client_revoke_response::Result::Error(e) => { + panic!("Expected Ok, got error {:?}", e) + } + }, + other => panic!("Expected SdkClientRevoke, got {other:?}"), + } + + // List should now be empty + let list_response = session + .process_transport_inbound(UserAgentRequest { + payload: Some(UserAgentRequestPayload::SdkClientList(())), + }) + .await + .unwrap(); + + let clients = match list_response.payload.unwrap() { + UserAgentResponsePayload::SdkClientList(resp) => match resp.result.unwrap() { + sdk_client_list_response::Result::Clients(list) => list.clients, + sdk_client_list_response::Result::Error(e) => panic!("list error: {:?}", e), + }, + other => panic!("{other:?}"), + }; + assert!(clients.is_empty(), "client should be removed after revoke"); +} + +#[tokio::test] +#[test_log::test] +async fn test_sdk_client_revoke_not_found_returns_error() { + let db = db::create_test_pool().await; + let mut session = make_session(&db).await; + + let response = session + .process_transport_inbound(UserAgentRequest { + payload: Some(UserAgentRequestPayload::SdkClientRevoke( + SdkClientRevokeRequest { client_id: 9999 }, + )), + }) + .await + .unwrap(); + + match response.payload.unwrap() { + UserAgentResponsePayload::SdkClientRevoke(resp) => match resp.result.unwrap() { + sdk_client_revoke_response::Result::Error(code) => { + assert_eq!(code, ProtoSdkClientError::NotFound as i32); + } + sdk_client_revoke_response::Result::Ok(_) => { + panic!("Expected NotFound error for missing client_id") + } + }, + other => panic!("Expected SdkClientRevoke, got {other:?}"), + } +} + +#[tokio::test] +#[test_log::test] +async fn test_sdk_client_approve_rejected_client_cannot_auth() { + // Verify the core flow: only pre-approved clients can authenticate + use arbiter_proto::proto::client::{ + AuthChallengeRequest, ClientRequest, client_request::Payload as ClientRequestPayload, + client_response::Payload as ClientResponsePayload, + }; + use arbiter_proto::transport::Bi as _; + use arbiter_server::actors::client::{ClientConnection, connect_client}; + + let db = db::create_test_pool().await; + let actors = GlobalActors::spawn(db.clone()).await.unwrap(); + + let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); + let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); + + let (server_transport, mut test_transport) = super::common::ChannelTransport::<_, _>::new(); + let props = ClientConnection::new(db.clone(), Box::new(server_transport), actors.clone()); + let task = tokio::spawn(connect_client(props)); + + test_transport + .send(ClientRequest { + payload: Some(ClientRequestPayload::AuthChallengeRequest( + AuthChallengeRequest { + pubkey: pubkey_bytes.clone(), + }, + )), + }) + .await + .unwrap(); + + let response = test_transport.recv().await.unwrap().unwrap(); + assert!( + matches!( + response.payload.unwrap(), + ClientResponsePayload::ClientConnectError(_) + ), + "unregistered client should be rejected" + ); + + task.await.unwrap(); +} diff --git a/server/crates/arbiter-terrors-poc/Cargo.toml b/server/crates/arbiter-terrors-poc/Cargo.toml new file mode 100644 index 0000000..127adb1 --- /dev/null +++ b/server/crates/arbiter-terrors-poc/Cargo.toml @@ -0,0 +1,7 @@ +[package] +name = "arbiter-terrors-poc" +version = "0.1.0" +edition = "2024" + +[dependencies] +terrors = "0.3" diff --git a/server/crates/arbiter-terrors-poc/src/errors.rs b/server/crates/arbiter-terrors-poc/src/errors.rs new file mode 100644 index 0000000..0ef0061 --- /dev/null +++ b/server/crates/arbiter-terrors-poc/src/errors.rs @@ -0,0 +1,84 @@ +use terrors::OneOf; + +// Wire boundary type — what would go into a proto response +#[derive(Debug)] +pub enum ProtoError { + NotRegistered, + InvalidSignature, + Internal(String), +} + +// Internal terrors types +pub struct NotRegistered; +pub struct InvalidSignature; +pub struct Internal(pub String); + +impl From for ProtoError { + fn from(_: NotRegistered) -> Self { + ProtoError::NotRegistered + } +} + +impl From for ProtoError { + fn from(_: InvalidSignature) -> Self { + ProtoError::InvalidSignature + } +} + +impl From for ProtoError { + fn from(e: Internal) -> Self { + ProtoError::Internal(e.0) + } +} + +// Converts the narrowed remainder after handling NotRegistered +impl From> for ProtoError { + fn from(e: OneOf<(InvalidSignature, Internal)>) -> Self { + match e.narrow::() { + Ok(_) => ProtoError::InvalidSignature, + Err(e) => { + let Internal(msg) = e.take(); + ProtoError::Internal(msg) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn not_registered_converts_to_proto() { + let e: ProtoError = NotRegistered.into(); + assert!(matches!(e, ProtoError::NotRegistered)); + } + + #[test] + fn invalid_signature_converts_to_proto() { + let e: ProtoError = InvalidSignature.into(); + assert!(matches!(e, ProtoError::InvalidSignature)); + } + + #[test] + fn internal_converts_to_proto() { + let e: ProtoError = Internal("boom".into()).into(); + assert!(matches!(e, ProtoError::Internal(msg) if msg == "boom")); + } + + #[test] + fn one_of_remainder_converts_to_proto_invalid_signature() { + use terrors::OneOf; + let e: OneOf<(InvalidSignature, Internal)> = OneOf::new(InvalidSignature); + let proto = ProtoError::from(e); + assert!(matches!(proto, ProtoError::InvalidSignature)); + } + + #[test] + fn one_of_remainder_converts_to_proto_internal() { + use terrors::OneOf; + let e: OneOf<(InvalidSignature, Internal)> = OneOf::new(Internal("db fail".into())); + let proto = ProtoError::from(e); + assert!(matches!(proto, ProtoError::Internal(msg) if msg == "db fail")); + } +} diff --git a/server/crates/arbiter-terrors-poc/src/main.rs b/server/crates/arbiter-terrors-poc/src/main.rs new file mode 100644 index 0000000..f18efb5 --- /dev/null +++ b/server/crates/arbiter-terrors-poc/src/main.rs @@ -0,0 +1,3 @@ +mod errors; + +fn main() {}