diff --git a/server/crates/arbiter-client/src/lib.rs b/server/crates/arbiter-client/src/lib.rs index 2c7ec5b..cd65220 100644 --- a/server/crates/arbiter-client/src/lib.rs +++ b/server/crates/arbiter-client/src/lib.rs @@ -367,7 +367,9 @@ async fn receive_auth_confirmation( .await .map_err(|_| ConnectError::UnexpectedAuthResponse)?; - let payload = response.payload.ok_or(ConnectError::UnexpectedAuthResponse)?; + let payload = response + .payload + .ok_or(ConnectError::UnexpectedAuthResponse)?; match payload { ClientResponsePayload::AuthOk(_) => Ok(()), ClientResponsePayload::ClientConnectError(err) => Err(map_connect_error(err.code)), diff --git a/server/crates/arbiter-server/src/actors/client/auth.rs b/server/crates/arbiter-server/src/actors/client/auth.rs index 55acb4c..18386d8 100644 --- a/server/crates/arbiter-server/src/actors/client/auth.rs +++ b/server/crates/arbiter-server/src/actors/client/auth.rs @@ -1,8 +1,8 @@ use arbiter_proto::{ format_challenge, proto::client::{ - AuthChallenge, AuthChallengeSolution, ClientConnectError, ClientRequest, ClientResponse, - client_connect_error::Code as ConnectErrorCode, + AuthChallenge, AuthChallengeSolution, AuthOk, ClientConnectError, ClientRequest, + ClientResponse, client_connect_error::Code as ConnectErrorCode, client_request::Payload as ClientRequestPayload, client_response::Payload as ClientResponsePayload, }, @@ -26,6 +26,25 @@ use crate::{ use super::session::ClientSession; +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct ClientId(i32); + +impl ClientId { + pub fn new(raw: i32) -> Self { + Self(raw) + } + + pub fn as_i32(self) -> i32 { + self.0 + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct ClientNonceState { + client_id: ClientId, + nonce: i32, +} + #[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)] pub enum Error { #[error("Unexpected message payload")] @@ -63,7 +82,7 @@ pub enum ApproveError { async fn get_nonce( db: &db::DatabasePool, pubkey: &VerifyingKey, -) -> Result, Error> { +) -> Result, Error> { let pubkey_bytes = pubkey.as_bytes().to_vec(); let mut conn = db.get().await.map_err(|e| { @@ -90,7 +109,10 @@ async fn get_nonce( .execute(conn) .await?; - Ok(Some((client_id, current_nonce))) + Ok(Some(ClientNonceState { + client_id: ClientId::new(client_id), + nonce: current_nonce, + })) }) }) .await @@ -126,7 +148,7 @@ async fn approve_new_client( } enum InsertClientResult { - Inserted(i32), + Inserted(ClientId), AlreadyExists, } @@ -176,7 +198,7 @@ async fn insert_client( Error::DatabaseOperationFailed })?; - Ok(InsertClientResult::Inserted(client_id)) + Ok(InsertClientResult::Inserted(ClientId::new(client_id))) } async fn challenge_client( @@ -224,6 +246,17 @@ async fn challenge_client( Error::InvalidChallengeSolution })?; + props + .transport + .send(Ok(ClientResponse { + payload: Some(ClientResponsePayload::AuthOk(AuthOk {})), + })) + .await + .map_err(|e| { + error!(error = ?e, "Failed to send auth ok"); + Error::Transport + })?; + Ok(()) } @@ -237,7 +270,7 @@ fn connect_error_code(err: &Error) -> ConnectErrorCode { } } -async fn authenticate(props: &mut ClientConnection) -> Result<(VerifyingKey, i32), Error> { +async fn authenticate(props: &mut ClientConnection) -> Result<(VerifyingKey, ClientId), Error> { let Some(ClientRequest { payload: Some(ClientRequestPayload::AuthChallengeRequest(challenge)), }) = props.transport.recv().await @@ -253,13 +286,13 @@ async fn authenticate(props: &mut ClientConnection) -> Result<(VerifyingKey, i32 VerifyingKey::from_bytes(pubkey_bytes).map_err(|_| Error::InvalidAuthPubkeyEncoding)?; let (client_id, nonce) = match get_nonce(&props.db, &pubkey).await? { - Some((client_id, nonce)) => (client_id, nonce), + Some(state) => (state.client_id, state.nonce), None => { approve_new_client(&props.actors, pubkey).await?; match insert_client(&props.db, &pubkey).await? { InsertClientResult::Inserted(client_id) => (client_id, 0), InsertClientResult::AlreadyExists => match get_nonce(&props.db, &pubkey).await? { - Some((client_id, nonce)) => (client_id, nonce), + Some(state) => (state.client_id, state.nonce), None => return Err(Error::InternalError), }, } diff --git a/server/crates/arbiter-server/tests/client/auth.rs b/server/crates/arbiter-server/tests/client/auth.rs index 5d82423..45b7335 100644 --- a/server/crates/arbiter-server/tests/client/auth.rs +++ b/server/crates/arbiter-server/tests/client/auth.rs @@ -114,6 +114,15 @@ pub async fn test_challenge_auth() { .await .unwrap(); + let response = test_transport.recv().await.expect("should receive auth ok"); + match response { + Ok(resp) => match resp.payload { + Some(ClientResponsePayload::AuthOk(_)) => {} + other => panic!("Expected AuthOk, got {other:?}"), + }, + Err(err) => panic!("Expected Ok response, got Err({err:?})"), + } + // Auth completes, session spawned task.await.unwrap(); } @@ -178,6 +187,15 @@ pub async fn test_evm_sign_request_payload_is_handled() { .await .unwrap(); + let response = test_transport.recv().await.expect("should receive auth ok"); + match response { + Ok(resp) => match resp.payload { + Some(ClientResponsePayload::AuthOk(_)) => {} + other => panic!("Expected AuthOk, got {other:?}"), + }, + Err(err) => panic!("Expected Ok response, got Err({err:?})"), + } + task.await.unwrap(); let tx = TxEip1559 {