Compare commits
13 Commits
PoC-terror
...
784261f4d8
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
784261f4d8 | ||
|
|
971db0e919 | ||
|
|
e1a8553142 | ||
|
|
ec70561c93 | ||
|
|
3993d3a8cc | ||
|
|
c87456ae2f | ||
|
|
e89983de3a | ||
|
|
f56668d9f6 | ||
|
|
434738bae5 | ||
|
|
77c3babec7 | ||
|
|
6f03ce4d1d | ||
|
|
c90af9c196 | ||
|
|
a5a9bc73b0 |
@@ -106,6 +106,16 @@ enum VaultState {
|
||||
VAULT_STATE_ERROR = 4;
|
||||
}
|
||||
|
||||
message SdkClientConnectionRequest {
|
||||
bytes pubkey = 1;
|
||||
}
|
||||
|
||||
message SdkClientConnectionResponse {
|
||||
bool approved = 1;
|
||||
}
|
||||
|
||||
message SdkClientConnectionCancel {}
|
||||
|
||||
message UserAgentRequest {
|
||||
oneof payload {
|
||||
AuthChallengeRequest auth_challenge_request = 1;
|
||||
@@ -118,7 +128,7 @@ message UserAgentRequest {
|
||||
arbiter.evm.EvmGrantCreateRequest evm_grant_create = 8;
|
||||
arbiter.evm.EvmGrantDeleteRequest evm_grant_delete = 9;
|
||||
arbiter.evm.EvmGrantListRequest evm_grant_list = 10;
|
||||
// field 11 reserved: was client_connection_response (online approval removed)
|
||||
SdkClientConnectionResponse sdk_client_connection_response = 11;
|
||||
SdkClientApproveRequest sdk_client_approve = 12;
|
||||
SdkClientRevokeRequest sdk_client_revoke = 13;
|
||||
google.protobuf.Empty sdk_client_list = 14;
|
||||
@@ -136,7 +146,8 @@ message UserAgentResponse {
|
||||
arbiter.evm.EvmGrantCreateResponse evm_grant_create = 8;
|
||||
arbiter.evm.EvmGrantDeleteResponse evm_grant_delete = 9;
|
||||
arbiter.evm.EvmGrantListResponse evm_grant_list = 10;
|
||||
// fields 11, 12 reserved: were client_connection_request, client_connection_cancel (online approval removed)
|
||||
SdkClientConnectionRequest sdk_client_connection_request = 11;
|
||||
SdkClientConnectionCancel sdk_client_connection_cancel = 12;
|
||||
SdkClientApproveResponse sdk_client_approve = 13;
|
||||
SdkClientRevokeResponse sdk_client_revoke = 14;
|
||||
SdkClientListResponse sdk_client_list = 15;
|
||||
|
||||
139
server/Cargo.lock
generated
139
server/Cargo.lock
generated
@@ -100,7 +100,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_with",
|
||||
"thiserror",
|
||||
"thiserror 2.0.18",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -136,7 +136,7 @@ dependencies = [
|
||||
"futures",
|
||||
"futures-util",
|
||||
"serde_json",
|
||||
"thiserror",
|
||||
"thiserror 2.0.18",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -178,7 +178,7 @@ dependencies = [
|
||||
"alloy-rlp",
|
||||
"crc",
|
||||
"serde",
|
||||
"thiserror",
|
||||
"thiserror 2.0.18",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -203,7 +203,7 @@ dependencies = [
|
||||
"alloy-rlp",
|
||||
"borsh",
|
||||
"serde",
|
||||
"thiserror",
|
||||
"thiserror 2.0.18",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -239,7 +239,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_with",
|
||||
"sha2 0.10.9",
|
||||
"thiserror",
|
||||
"thiserror 2.0.18",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -280,7 +280,7 @@ dependencies = [
|
||||
"http",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror",
|
||||
"thiserror 2.0.18",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
@@ -307,7 +307,7 @@ dependencies = [
|
||||
"futures-utils-wasm",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror",
|
||||
"thiserror 2.0.18",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -382,7 +382,7 @@ dependencies = [
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"url",
|
||||
@@ -475,7 +475,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_with",
|
||||
"thiserror",
|
||||
"thiserror 2.0.18",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -501,7 +501,7 @@ dependencies = [
|
||||
"either",
|
||||
"elliptic-curve",
|
||||
"k256",
|
||||
"thiserror",
|
||||
"thiserror 2.0.18",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -517,7 +517,7 @@ dependencies = [
|
||||
"async-trait",
|
||||
"k256",
|
||||
"rand 0.8.5",
|
||||
"thiserror",
|
||||
"thiserror 2.0.18",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -608,7 +608,7 @@ dependencies = [
|
||||
"parking_lot",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"tower",
|
||||
"tracing",
|
||||
@@ -644,7 +644,7 @@ dependencies = [
|
||||
"nybbles",
|
||||
"serde",
|
||||
"smallvec",
|
||||
"thiserror",
|
||||
"thiserror 2.0.18",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
@@ -684,8 +684,9 @@ dependencies = [
|
||||
"async-trait",
|
||||
"ed25519-dalek",
|
||||
"http",
|
||||
"rand 0.10.0",
|
||||
"rustls-webpki",
|
||||
"thiserror",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tonic",
|
||||
@@ -708,7 +709,7 @@ dependencies = [
|
||||
"rcgen",
|
||||
"rstest",
|
||||
"rustls-pki-types",
|
||||
"thiserror",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"tonic",
|
||||
"tonic-prost",
|
||||
@@ -733,6 +734,7 @@ dependencies = [
|
||||
"diesel-async",
|
||||
"diesel_migrations",
|
||||
"ed25519-dalek",
|
||||
"fatality",
|
||||
"futures",
|
||||
"insta",
|
||||
"k256",
|
||||
@@ -751,7 +753,7 @@ dependencies = [
|
||||
"spki",
|
||||
"strum",
|
||||
"test-log",
|
||||
"thiserror",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tonic",
|
||||
@@ -761,13 +763,6 @@ dependencies = [
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "arbiter-terrors-poc"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"terrors",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "arbiter-tokens-registry"
|
||||
version = "0.1.0"
|
||||
@@ -791,7 +786,7 @@ dependencies = [
|
||||
"sha2 0.10.9",
|
||||
"smlang",
|
||||
"spki",
|
||||
"thiserror",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tonic",
|
||||
@@ -1019,7 +1014,7 @@ dependencies = [
|
||||
"nom",
|
||||
"num-traits",
|
||||
"rusticata-macros",
|
||||
"thiserror",
|
||||
"thiserror 2.0.18",
|
||||
"time",
|
||||
]
|
||||
|
||||
@@ -2079,6 +2074,21 @@ dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "expander"
|
||||
version = "2.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e2c470c71d91ecbd179935b24170459e926382eaaa86b590b78814e180d8a8e2"
|
||||
dependencies = [
|
||||
"blake2",
|
||||
"file-guard",
|
||||
"fs-err",
|
||||
"prettyplease",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fastrand"
|
||||
version = "2.3.0"
|
||||
@@ -2107,6 +2117,30 @@ dependencies = [
|
||||
"bytes",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fatality"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec6f82451ff7f0568c6181287189126d492b5654e30a788add08027b6363d019"
|
||||
dependencies = [
|
||||
"fatality-proc-macro",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fatality-proc-macro"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eb42427514b063d97ce21d5199f36c0c307d981434a6be32582bc79fe5bd2303"
|
||||
dependencies = [
|
||||
"expander",
|
||||
"indexmap 2.13.0",
|
||||
"proc-macro-crate",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ff"
|
||||
version = "0.13.1"
|
||||
@@ -2129,6 +2163,16 @@ version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "64cd1e32ddd350061ae6edb1b082d7c54915b5c672c389143b9a63403a109f24"
|
||||
|
||||
[[package]]
|
||||
name = "file-guard"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "21ef72acf95ec3d7dbf61275be556299490a245f017cf084bd23b4f68cf9407c"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "find-msvc-tools"
|
||||
version = "0.1.9"
|
||||
@@ -2190,6 +2234,15 @@ dependencies = [
|
||||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fs-err"
|
||||
version = "2.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "88a41f105fe1d5b6b34b2055e3dc59bb79b46b48b2040b9e6c7b4b5de097aa41"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fs_extra"
|
||||
version = "1.3.0"
|
||||
@@ -3784,7 +3837,7 @@ dependencies = [
|
||||
"rustc-hash",
|
||||
"rustls",
|
||||
"socket2",
|
||||
"thiserror",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"web-time",
|
||||
@@ -3805,7 +3858,7 @@ dependencies = [
|
||||
"rustls",
|
||||
"rustls-pki-types",
|
||||
"slab",
|
||||
"thiserror",
|
||||
"thiserror 2.0.18",
|
||||
"tinyvec",
|
||||
"tracing",
|
||||
"web-time",
|
||||
@@ -4140,7 +4193,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a8a1f2315036ef6b1fbacd1972e8ee7688030b0a2121edfc2a6550febd41574d"
|
||||
dependencies = [
|
||||
"hashbrown 0.16.1",
|
||||
"thiserror",
|
||||
"thiserror 2.0.18",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4863,12 +4916,6 @@ dependencies = [
|
||||
"windows-sys 0.60.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "terrors"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "987fd8c678ca950df2a18b2c6e9da6ca511d449278fab3565efe0d49c0c07a5d"
|
||||
|
||||
[[package]]
|
||||
name = "test-log"
|
||||
version = "0.2.19"
|
||||
@@ -4900,13 +4947,33 @@ dependencies = [
|
||||
"unicode-width 0.2.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.69"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52"
|
||||
dependencies = [
|
||||
"thiserror-impl 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "2.0.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4"
|
||||
dependencies = [
|
||||
"thiserror-impl",
|
||||
"thiserror-impl 2.0.18",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "1.0.69"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -6000,7 +6067,7 @@ dependencies = [
|
||||
"nom",
|
||||
"oid-registry",
|
||||
"rusticata-macros",
|
||||
"thiserror",
|
||||
"thiserror 2.0.18",
|
||||
"time",
|
||||
]
|
||||
|
||||
|
||||
@@ -20,3 +20,4 @@ thiserror.workspace = true
|
||||
http = "1.4.0"
|
||||
rustls-webpki = { version = "0.103.9", features = ["aws-lc-rs"] }
|
||||
async-trait.workspace = true
|
||||
rand.workspace = true
|
||||
|
||||
@@ -5,7 +5,7 @@ use alloy::{
|
||||
signers::{Error, Result, Signer},
|
||||
};
|
||||
use arbiter_proto::{
|
||||
format_challenge,
|
||||
format_challenge, home_path,
|
||||
proto::{
|
||||
arbiter_service_client::ArbiterServiceClient,
|
||||
client::{
|
||||
@@ -21,10 +21,13 @@ use arbiter_proto::{
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use ed25519_dalek::Signer as _;
|
||||
use std::path::{Path, PathBuf};
|
||||
use tokio::sync::{Mutex, mpsc};
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tonic::transport::ClientTlsConfig;
|
||||
|
||||
const BUFFER_LENGTH: usize = 16;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ConnectError {
|
||||
#[error("Could not establish connection")]
|
||||
@@ -50,6 +53,83 @@ pub enum ConnectError {
|
||||
|
||||
#[error("Unexpected auth response payload")]
|
||||
UnexpectedAuthResponse,
|
||||
|
||||
#[error("Signing key storage error")]
|
||||
Storage(#[from] StorageError),
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum StorageError {
|
||||
#[error("I/O error")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
#[error("Invalid signing key length in storage: expected {expected} bytes, got {actual} bytes")]
|
||||
InvalidKeyLength { expected: usize, actual: usize },
|
||||
}
|
||||
|
||||
pub trait SigningKeyStorage {
|
||||
fn load_or_create(&self) -> std::result::Result<ed25519_dalek::SigningKey, StorageError>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FileSigningKeyStorage {
|
||||
path: PathBuf,
|
||||
}
|
||||
|
||||
impl FileSigningKeyStorage {
|
||||
pub const DEFAULT_FILE_NAME: &str = "sdk_client_ed25519.key";
|
||||
|
||||
pub fn new(path: impl Into<PathBuf>) -> Self {
|
||||
Self { path: path.into() }
|
||||
}
|
||||
|
||||
pub fn from_default_location() -> std::result::Result<Self, StorageError> {
|
||||
Ok(Self::new(home_path()?.join(Self::DEFAULT_FILE_NAME)))
|
||||
}
|
||||
|
||||
fn read_key(path: &Path) -> std::result::Result<ed25519_dalek::SigningKey, StorageError> {
|
||||
let bytes = std::fs::read(path)?;
|
||||
let raw: [u8; 32] =
|
||||
bytes
|
||||
.try_into()
|
||||
.map_err(|v: Vec<u8>| StorageError::InvalidKeyLength {
|
||||
expected: 32,
|
||||
actual: v.len(),
|
||||
})?;
|
||||
Ok(ed25519_dalek::SigningKey::from_bytes(&raw))
|
||||
}
|
||||
}
|
||||
|
||||
impl SigningKeyStorage for FileSigningKeyStorage {
|
||||
fn load_or_create(&self) -> std::result::Result<ed25519_dalek::SigningKey, StorageError> {
|
||||
if let Some(parent) = self.path.parent() {
|
||||
std::fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
if self.path.exists() {
|
||||
return Self::read_key(&self.path);
|
||||
}
|
||||
|
||||
let key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
|
||||
let raw_key = key.to_bytes();
|
||||
|
||||
// Use create_new to prevent accidental overwrite if another process creates the key first.
|
||||
match std::fs::OpenOptions::new()
|
||||
.create_new(true)
|
||||
.write(true)
|
||||
.open(&self.path)
|
||||
{
|
||||
Ok(mut file) => {
|
||||
use std::io::Write as _;
|
||||
file.write_all(&raw_key)?;
|
||||
Ok(key)
|
||||
}
|
||||
Err(err) if err.kind() == std::io::ErrorKind::AlreadyExists => {
|
||||
Self::read_key(&self.path)
|
||||
}
|
||||
Err(err) => Err(StorageError::Io(err)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
@@ -65,6 +145,9 @@ enum ClientSignError {
|
||||
|
||||
#[error("Remote signing was rejected")]
|
||||
Rejected,
|
||||
|
||||
#[error("Wallet address is not configured")]
|
||||
WalletAddressNotConfigured,
|
||||
}
|
||||
|
||||
struct ClientTransport {
|
||||
@@ -91,15 +174,27 @@ impl ClientTransport {
|
||||
|
||||
pub struct ArbiterSigner {
|
||||
transport: Mutex<ClientTransport>,
|
||||
address: Address,
|
||||
address: Option<Address>,
|
||||
chain_id: Option<ChainId>,
|
||||
}
|
||||
|
||||
impl ArbiterSigner {
|
||||
pub async fn connect_grpc(
|
||||
pub async fn connect_grpc(url: ArbiterUrl) -> std::result::Result<Self, ConnectError> {
|
||||
let storage = FileSigningKeyStorage::from_default_location()?;
|
||||
Self::connect_grpc_with_storage(url, &storage).await
|
||||
}
|
||||
|
||||
pub async fn connect_grpc_with_storage<S: SigningKeyStorage>(
|
||||
url: ArbiterUrl,
|
||||
storage: &S,
|
||||
) -> std::result::Result<Self, ConnectError> {
|
||||
let key = storage.load_or_create()?;
|
||||
Self::connect_grpc_with_key(url, key).await
|
||||
}
|
||||
|
||||
pub async fn connect_grpc_with_key(
|
||||
url: ArbiterUrl,
|
||||
key: ed25519_dalek::SigningKey,
|
||||
address: Address,
|
||||
) -> std::result::Result<Self, ConnectError> {
|
||||
let anchor = webpki::anchor_from_trusted_cert(&url.ca_cert)?.to_owned();
|
||||
let tls = ClientTlsConfig::new().trust_anchor(anchor);
|
||||
@@ -112,7 +207,7 @@ impl ArbiterSigner {
|
||||
.await?;
|
||||
|
||||
let mut client = ArbiterServiceClient::new(channel);
|
||||
let (tx, rx) = mpsc::channel(16);
|
||||
let (tx, rx) = mpsc::channel(BUFFER_LENGTH);
|
||||
let response_stream = client.client(ReceiverStream::new(rx)).await?.into_inner();
|
||||
|
||||
let mut transport = ClientTransport {
|
||||
@@ -120,19 +215,37 @@ impl ArbiterSigner {
|
||||
receiver: response_stream,
|
||||
};
|
||||
|
||||
authenticate(&mut transport, key).await?;
|
||||
authenticate(&mut transport, &key).await?;
|
||||
|
||||
Ok(Self {
|
||||
transport: Mutex::new(transport),
|
||||
address,
|
||||
address: None,
|
||||
chain_id: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn sign_transaction_via_arbiter(
|
||||
pub fn wallet_address(&self) -> Option<Address> {
|
||||
self.address
|
||||
}
|
||||
|
||||
pub fn set_wallet_address(&mut self, address: Option<Address>) {
|
||||
self.address = address;
|
||||
}
|
||||
|
||||
pub fn with_wallet_address(mut self, address: Address) -> Self {
|
||||
self.address = Some(address);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_chain_id(mut self, chain_id: ChainId) -> Self {
|
||||
self.chain_id = Some(chain_id);
|
||||
self
|
||||
}
|
||||
|
||||
fn build_sign_transaction_request(
|
||||
&self,
|
||||
tx: &mut dyn SignableTransaction<Signature>,
|
||||
) -> Result<Signature> {
|
||||
) -> Result<ClientRequest> {
|
||||
if let Some(chain_id) = self.chain_id
|
||||
&& !tx.set_chain_id_checked(chain_id)
|
||||
{
|
||||
@@ -145,15 +258,21 @@ impl ArbiterSigner {
|
||||
let mut rlp_transaction = Vec::new();
|
||||
tx.encode_for_signing(&mut rlp_transaction);
|
||||
|
||||
let request = ClientRequest {
|
||||
let wallet_address = self
|
||||
.address
|
||||
.ok_or_else(|| Error::other(ClientSignError::WalletAddressNotConfigured))?;
|
||||
|
||||
Ok(ClientRequest {
|
||||
payload: Some(ClientRequestPayload::EvmSignTransaction(
|
||||
EvmSignTransactionRequest {
|
||||
wallet_address: self.address.as_slice().to_vec(),
|
||||
wallet_address: wallet_address.as_slice().to_vec(),
|
||||
rlp_transaction,
|
||||
},
|
||||
)),
|
||||
};
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute_sign_transaction_request(&self, request: ClientRequest) -> Result<Signature> {
|
||||
let mut transport = self.transport.lock().await;
|
||||
transport.send(request).await.map_err(Error::other)?;
|
||||
let response = transport.recv().await.map_err(Error::other)?;
|
||||
@@ -181,9 +300,18 @@ impl ArbiterSigner {
|
||||
}
|
||||
}
|
||||
|
||||
async fn authenticate(
|
||||
fn map_connect_error(code: i32) -> ConnectError {
|
||||
match client_connect_error::Code::try_from(code).unwrap_or(client_connect_error::Code::Unknown)
|
||||
{
|
||||
client_connect_error::Code::ApprovalDenied => ConnectError::ApprovalDenied,
|
||||
client_connect_error::Code::NoUserAgentsOnline => ConnectError::NoUserAgentsOnline,
|
||||
client_connect_error::Code::Unknown => ConnectError::UnexpectedAuthResponse,
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_auth_challenge_request(
|
||||
transport: &mut ClientTransport,
|
||||
key: ed25519_dalek::SigningKey,
|
||||
key: &ed25519_dalek::SigningKey,
|
||||
) -> std::result::Result<(), ConnectError> {
|
||||
transport
|
||||
.send(ClientRequest {
|
||||
@@ -194,8 +322,12 @@ async fn authenticate(
|
||||
)),
|
||||
})
|
||||
.await
|
||||
.map_err(|_| ConnectError::UnexpectedAuthResponse)?;
|
||||
.map_err(|_| ConnectError::UnexpectedAuthResponse)
|
||||
}
|
||||
|
||||
async fn receive_auth_challenge(
|
||||
transport: &mut ClientTransport,
|
||||
) -> std::result::Result<arbiter_proto::proto::client::AuthChallenge, ConnectError> {
|
||||
let response = transport
|
||||
.recv()
|
||||
.await
|
||||
@@ -203,39 +335,58 @@ async fn authenticate(
|
||||
|
||||
let payload = response.payload.ok_or(ConnectError::MissingAuthChallenge)?;
|
||||
match payload {
|
||||
ClientResponsePayload::AuthChallenge(challenge) => {
|
||||
let challenge_payload = format_challenge(challenge.nonce, &challenge.pubkey);
|
||||
let signature = key.sign(&challenge_payload).to_bytes().to_vec();
|
||||
|
||||
transport
|
||||
.send(ClientRequest {
|
||||
payload: Some(ClientRequestPayload::AuthChallengeSolution(
|
||||
AuthChallengeSolution { signature },
|
||||
)),
|
||||
})
|
||||
.await
|
||||
.map_err(|_| ConnectError::UnexpectedAuthResponse)?;
|
||||
|
||||
// Current server flow does not emit `AuthOk` for SDK clients, so we proceed after
|
||||
// sending the solution. If authentication fails, the first business request will return
|
||||
// a `ClientConnectError` or the stream will close.
|
||||
Ok(())
|
||||
}
|
||||
ClientResponsePayload::ClientConnectError(err) => {
|
||||
match client_connect_error::Code::try_from(err.code)
|
||||
.unwrap_or(client_connect_error::Code::Unknown)
|
||||
{
|
||||
client_connect_error::Code::ApprovalDenied => Err(ConnectError::ApprovalDenied),
|
||||
client_connect_error::Code::NoUserAgentsOnline => {
|
||||
Err(ConnectError::NoUserAgentsOnline)
|
||||
}
|
||||
client_connect_error::Code::Unknown => Err(ConnectError::UnexpectedAuthResponse),
|
||||
}
|
||||
}
|
||||
ClientResponsePayload::AuthChallenge(challenge) => Ok(challenge),
|
||||
ClientResponsePayload::ClientConnectError(err) => Err(map_connect_error(err.code)),
|
||||
_ => Err(ConnectError::UnexpectedAuthResponse),
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_auth_challenge_solution(
|
||||
transport: &mut ClientTransport,
|
||||
key: &ed25519_dalek::SigningKey,
|
||||
challenge: arbiter_proto::proto::client::AuthChallenge,
|
||||
) -> std::result::Result<(), ConnectError> {
|
||||
let challenge_payload = format_challenge(challenge.nonce, &challenge.pubkey);
|
||||
let signature = key.sign(&challenge_payload).to_bytes().to_vec();
|
||||
|
||||
transport
|
||||
.send(ClientRequest {
|
||||
payload: Some(ClientRequestPayload::AuthChallengeSolution(
|
||||
AuthChallengeSolution { signature },
|
||||
)),
|
||||
})
|
||||
.await
|
||||
.map_err(|_| ConnectError::UnexpectedAuthResponse)
|
||||
}
|
||||
|
||||
async fn receive_auth_confirmation(
|
||||
transport: &mut ClientTransport,
|
||||
) -> std::result::Result<(), ConnectError> {
|
||||
let response = transport
|
||||
.recv()
|
||||
.await
|
||||
.map_err(|_| ConnectError::UnexpectedAuthResponse)?;
|
||||
|
||||
let payload = response
|
||||
.payload
|
||||
.ok_or(ConnectError::UnexpectedAuthResponse)?;
|
||||
match payload {
|
||||
ClientResponsePayload::AuthOk(_) => Ok(()),
|
||||
ClientResponsePayload::ClientConnectError(err) => Err(map_connect_error(err.code)),
|
||||
_ => Err(ConnectError::UnexpectedAuthResponse),
|
||||
}
|
||||
}
|
||||
|
||||
async fn authenticate(
|
||||
transport: &mut ClientTransport,
|
||||
key: &ed25519_dalek::SigningKey,
|
||||
) -> std::result::Result<(), ConnectError> {
|
||||
send_auth_challenge_request(transport, key).await?;
|
||||
let challenge = receive_auth_challenge(transport).await?;
|
||||
send_auth_challenge_solution(transport, key, challenge).await?;
|
||||
receive_auth_confirmation(transport).await
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Signer for ArbiterSigner {
|
||||
async fn sign_hash(&self, _hash: &B256) -> Result<Signature> {
|
||||
@@ -245,7 +396,7 @@ impl Signer for ArbiterSigner {
|
||||
}
|
||||
|
||||
fn address(&self) -> Address {
|
||||
self.address
|
||||
self.address.unwrap_or(Address::ZERO)
|
||||
}
|
||||
|
||||
fn chain_id(&self) -> Option<ChainId> {
|
||||
@@ -260,13 +411,70 @@ impl Signer for ArbiterSigner {
|
||||
#[async_trait]
|
||||
impl TxSigner<Signature> for ArbiterSigner {
|
||||
fn address(&self) -> Address {
|
||||
self.address
|
||||
self.address.unwrap_or(Address::ZERO)
|
||||
}
|
||||
|
||||
async fn sign_transaction(
|
||||
&self,
|
||||
tx: &mut dyn SignableTransaction<Signature>,
|
||||
) -> Result<Signature> {
|
||||
self.sign_transaction_via_arbiter(tx).await
|
||||
let request = self.build_sign_transaction_request(tx)?;
|
||||
self.execute_sign_transaction_request(request).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{FileSigningKeyStorage, SigningKeyStorage, StorageError};
|
||||
|
||||
fn unique_temp_key_path() -> std::path::PathBuf {
|
||||
let nanos = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.expect("clock should be after unix epoch")
|
||||
.as_nanos();
|
||||
std::env::temp_dir().join(format!(
|
||||
"arbiter-client-key-{}-{}.bin",
|
||||
std::process::id(),
|
||||
nanos
|
||||
))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn file_storage_creates_and_reuses_key() {
|
||||
let path = unique_temp_key_path();
|
||||
let storage = FileSigningKeyStorage::new(path.clone());
|
||||
|
||||
let key_a = storage
|
||||
.load_or_create()
|
||||
.expect("first load_or_create should create key");
|
||||
let key_b = storage
|
||||
.load_or_create()
|
||||
.expect("second load_or_create should read same key");
|
||||
|
||||
assert_eq!(key_a.to_bytes(), key_b.to_bytes());
|
||||
assert!(path.exists());
|
||||
|
||||
std::fs::remove_file(path).expect("temp key file should be removable");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn file_storage_rejects_invalid_key_length() {
|
||||
let path = unique_temp_key_path();
|
||||
std::fs::write(&path, [42u8; 31]).expect("should write invalid key file");
|
||||
let storage = FileSigningKeyStorage::new(path.clone());
|
||||
|
||||
let err = storage
|
||||
.load_or_create()
|
||||
.expect_err("storage should reject non-32-byte key file");
|
||||
|
||||
match err {
|
||||
StorageError::InvalidKeyLength { expected, actual } => {
|
||||
assert_eq!(expected, 32);
|
||||
assert_eq!(actual, 31);
|
||||
}
|
||||
other => panic!("unexpected error: {other:?}"),
|
||||
}
|
||||
|
||||
std::fs::remove_file(path).expect("temp key file should be removable");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ rustls.workspace = true
|
||||
smlang.workspace = true
|
||||
miette.workspace = true
|
||||
thiserror.workspace = true
|
||||
fatality = "0.1.1"
|
||||
diesel_migrations = { version = "2.3.1", features = ["sqlite"] }
|
||||
async-trait.workspace = true
|
||||
secrecy = "0.10.3"
|
||||
|
||||
@@ -157,3 +157,5 @@ create table if not exists evm_ether_transfer_grant_target (
|
||||
|
||||
create unique index if not exists uniq_ether_transfer_target on evm_ether_transfer_grant_target(grant_id, address);
|
||||
|
||||
CREATE UNIQUE INDEX program_client_public_key_unique
|
||||
ON program_client (public_key);
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
DROP INDEX IF EXISTS program_client_public_key_unique;
|
||||
@@ -1,2 +0,0 @@
|
||||
CREATE UNIQUE INDEX program_client_public_key_unique
|
||||
ON program_client (public_key);
|
||||
@@ -1,25 +1,50 @@
|
||||
use arbiter_proto::{
|
||||
format_challenge,
|
||||
proto::client::{
|
||||
AuthChallenge, AuthChallengeSolution, ClientConnectError, ClientRequest, ClientResponse,
|
||||
client_connect_error::Code as ConnectErrorCode,
|
||||
AuthChallenge, AuthChallengeSolution, AuthOk, ClientConnectError, ClientRequest,
|
||||
ClientResponse, client_connect_error::Code as ConnectErrorCode,
|
||||
client_request::Payload as ClientRequestPayload,
|
||||
client_response::Payload as ClientResponsePayload,
|
||||
},
|
||||
transport::expect_message,
|
||||
};
|
||||
use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl as _, update};
|
||||
use diesel::{
|
||||
ExpressionMethods as _, OptionalExtension as _, QueryDsl as _, dsl::insert_into, update,
|
||||
};
|
||||
use diesel_async::RunQueryDsl as _;
|
||||
use ed25519_dalek::VerifyingKey;
|
||||
use kameo::error::SendError;
|
||||
use tracing::error;
|
||||
|
||||
use crate::{
|
||||
actors::client::ClientConnection,
|
||||
actors::{
|
||||
client::ClientConnection,
|
||||
router::{self, RequestClientApproval},
|
||||
},
|
||||
db::{self, schema::program_client},
|
||||
};
|
||||
|
||||
use super::session::ClientSession;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub struct ClientId(i32);
|
||||
|
||||
impl ClientId {
|
||||
pub fn new(raw: i32) -> Self {
|
||||
Self(raw)
|
||||
}
|
||||
|
||||
pub fn as_i32(self) -> i32 {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
struct ClientNonceState {
|
||||
client_id: ClientId,
|
||||
nonce: i32,
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
|
||||
pub enum Error {
|
||||
#[error("Unexpected message payload")]
|
||||
@@ -34,20 +59,30 @@ pub enum Error {
|
||||
DatabaseOperationFailed,
|
||||
#[error("Invalid challenge solution")]
|
||||
InvalidChallengeSolution,
|
||||
#[error("Client not registered")]
|
||||
NotRegistered,
|
||||
#[error("Client approval request failed")]
|
||||
ApproveError(#[from] ApproveError),
|
||||
#[error("Internal error")]
|
||||
InternalError,
|
||||
#[error("Transport error")]
|
||||
Transport,
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
|
||||
pub enum ApproveError {
|
||||
#[error("Internal error")]
|
||||
Internal,
|
||||
#[error("Client connection denied by user agents")]
|
||||
Denied,
|
||||
#[error("Upstream error: {0}")]
|
||||
Upstream(router::ApprovalError),
|
||||
}
|
||||
|
||||
/// Atomically reads and increments the nonce for a known client.
|
||||
/// Returns `None` if the pubkey is not registered.
|
||||
async fn get_nonce(
|
||||
db: &db::DatabasePool,
|
||||
pubkey: &VerifyingKey,
|
||||
) -> Result<Option<(i32, i32)>, Error> {
|
||||
) -> Result<Option<ClientNonceState>, Error> {
|
||||
let pubkey_bytes = pubkey.as_bytes().to_vec();
|
||||
|
||||
let mut conn = db.get().await.map_err(|e| {
|
||||
@@ -74,7 +109,10 @@ async fn get_nonce(
|
||||
.execute(conn)
|
||||
.await?;
|
||||
|
||||
Ok(Some((client_id, current_nonce)))
|
||||
Ok(Some(ClientNonceState {
|
||||
client_id: ClientId::new(client_id),
|
||||
nonce: current_nonce,
|
||||
}))
|
||||
})
|
||||
})
|
||||
.await
|
||||
@@ -84,6 +122,85 @@ async fn get_nonce(
|
||||
})
|
||||
}
|
||||
|
||||
async fn approve_new_client(
|
||||
actors: &crate::actors::GlobalActors,
|
||||
pubkey: VerifyingKey,
|
||||
) -> Result<(), Error> {
|
||||
let result = actors
|
||||
.router
|
||||
.ask(RequestClientApproval {
|
||||
client_pubkey: pubkey,
|
||||
})
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(true) => Ok(()),
|
||||
Ok(false) => Err(Error::ApproveError(ApproveError::Denied)),
|
||||
Err(SendError::HandlerError(e)) => {
|
||||
error!(error = ?e, "Approval upstream error");
|
||||
Err(Error::ApproveError(ApproveError::Upstream(e)))
|
||||
}
|
||||
Err(e) => {
|
||||
error!(error = ?e, "Approval request to router failed");
|
||||
Err(Error::ApproveError(ApproveError::Internal))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum InsertClientResult {
|
||||
Inserted(ClientId),
|
||||
AlreadyExists,
|
||||
}
|
||||
|
||||
async fn insert_client(
|
||||
db: &db::DatabasePool,
|
||||
pubkey: &VerifyingKey,
|
||||
) -> Result<InsertClientResult, Error> {
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs() as i32;
|
||||
|
||||
let mut conn = db.get().await.map_err(|e| {
|
||||
error!(error = ?e, "Database pool error");
|
||||
Error::DatabasePoolUnavailable
|
||||
})?;
|
||||
|
||||
match insert_into(program_client::table)
|
||||
.values((
|
||||
program_client::public_key.eq(pubkey.as_bytes().to_vec()),
|
||||
program_client::nonce.eq(1), // pre-incremented; challenge uses 0
|
||||
program_client::created_at.eq(now),
|
||||
program_client::updated_at.eq(now),
|
||||
))
|
||||
.execute(&mut conn)
|
||||
.await
|
||||
{
|
||||
Ok(_) => {}
|
||||
Err(diesel::result::Error::DatabaseError(
|
||||
diesel::result::DatabaseErrorKind::UniqueViolation,
|
||||
_,
|
||||
)) => return Ok(InsertClientResult::AlreadyExists),
|
||||
Err(e) => {
|
||||
error!(error = ?e, "Failed to insert new client");
|
||||
return Err(Error::DatabaseOperationFailed);
|
||||
}
|
||||
}
|
||||
|
||||
let client_id = program_client::table
|
||||
.filter(program_client::public_key.eq(pubkey.as_bytes().to_vec()))
|
||||
.order(program_client::id.desc())
|
||||
.select(program_client::id)
|
||||
.first::<i32>(&mut conn)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!(error = ?e, "Failed to load inserted client id");
|
||||
Error::DatabaseOperationFailed
|
||||
})?;
|
||||
|
||||
Ok(InsertClientResult::Inserted(ClientId::new(client_id)))
|
||||
}
|
||||
|
||||
async fn challenge_client(
|
||||
props: &mut ClientConnection,
|
||||
pubkey: VerifyingKey,
|
||||
@@ -129,17 +246,31 @@ async fn challenge_client(
|
||||
Error::InvalidChallengeSolution
|
||||
})?;
|
||||
|
||||
props
|
||||
.transport
|
||||
.send(Ok(ClientResponse {
|
||||
payload: Some(ClientResponsePayload::AuthOk(AuthOk {})),
|
||||
}))
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!(error = ?e, "Failed to send auth ok");
|
||||
Error::Transport
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn connect_error_code(err: &Error) -> ConnectErrorCode {
|
||||
match err {
|
||||
Error::NotRegistered => ConnectErrorCode::ApprovalDenied,
|
||||
Error::ApproveError(ApproveError::Denied) => ConnectErrorCode::ApprovalDenied,
|
||||
Error::ApproveError(ApproveError::Upstream(
|
||||
router::ApprovalError::NoUserAgentsConnected,
|
||||
)) => ConnectErrorCode::NoUserAgentsOnline,
|
||||
_ => ConnectErrorCode::Unknown,
|
||||
}
|
||||
}
|
||||
|
||||
async fn authenticate(props: &mut ClientConnection) -> Result<(VerifyingKey, i32), Error> {
|
||||
async fn authenticate(props: &mut ClientConnection) -> Result<(VerifyingKey, ClientId), Error> {
|
||||
let Some(ClientRequest {
|
||||
payload: Some(ClientRequestPayload::AuthChallengeRequest(challenge)),
|
||||
}) = props.transport.recv().await
|
||||
@@ -155,8 +286,17 @@ async fn authenticate(props: &mut ClientConnection) -> Result<(VerifyingKey, i32
|
||||
VerifyingKey::from_bytes(pubkey_bytes).map_err(|_| Error::InvalidAuthPubkeyEncoding)?;
|
||||
|
||||
let (client_id, nonce) = match get_nonce(&props.db, &pubkey).await? {
|
||||
Some((client_id, nonce)) => (client_id, nonce),
|
||||
None => return Err(Error::NotRegistered),
|
||||
Some(state) => (state.client_id, state.nonce),
|
||||
None => {
|
||||
approve_new_client(&props.actors, pubkey).await?;
|
||||
match insert_client(&props.db, &pubkey).await? {
|
||||
InsertClientResult::Inserted(client_id) => (client_id, 0),
|
||||
InsertClientResult::AlreadyExists => match get_nonce(&props.db, &pubkey).await? {
|
||||
Some(state) => (state.client_id, state.nonce),
|
||||
None => return Err(Error::InternalError),
|
||||
},
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
challenge_client(props, pubkey, nonce).await?;
|
||||
|
||||
@@ -15,7 +15,7 @@ use tracing::{error, info};
|
||||
use crate::{
|
||||
actors::{
|
||||
GlobalActors,
|
||||
client::{ClientConnection, ClientError},
|
||||
client::{ClientConnection, ClientError, auth::ClientId},
|
||||
evm::ClientSignTransaction,
|
||||
router::RegisterClient,
|
||||
},
|
||||
@@ -24,11 +24,11 @@ use crate::{
|
||||
|
||||
pub struct ClientSession {
|
||||
props: ClientConnection,
|
||||
client_id: i32,
|
||||
client_id: ClientId,
|
||||
}
|
||||
|
||||
impl ClientSession {
|
||||
pub(crate) fn new(props: ClientConnection, client_id: i32) -> Self {
|
||||
pub(crate) fn new(props: ClientConnection, client_id: ClientId) -> Self {
|
||||
Self { props, client_id }
|
||||
}
|
||||
|
||||
@@ -54,7 +54,7 @@ impl ClientSession {
|
||||
.actors
|
||||
.evm
|
||||
.ask(ClientSignTransaction {
|
||||
client_id: self.client_id,
|
||||
client_id: self.client_id.as_i32(),
|
||||
wallet_address: Address::from_slice(&wallet_address),
|
||||
transaction: tx,
|
||||
})
|
||||
@@ -145,7 +145,7 @@ impl ClientSession {
|
||||
let props = ClientConnection::new(db, transport, actors);
|
||||
Self {
|
||||
props,
|
||||
client_id: 0,
|
||||
client_id: ClientId::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
use std::{collections::HashMap, ops::ControlFlow};
|
||||
|
||||
use ed25519_dalek::VerifyingKey;
|
||||
use kameo::{
|
||||
Actor,
|
||||
actor::{ActorId, ActorRef},
|
||||
messages,
|
||||
prelude::{ActorStopReason, Context, WeakActorRef},
|
||||
reply::DelegatedReply,
|
||||
};
|
||||
use tracing::info;
|
||||
use tokio::{sync::watch, task::JoinSet};
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::actors::{client::session::ClientSession, user_agent::session::UserAgentSession};
|
||||
use crate::actors::{
|
||||
client::session::ClientSession,
|
||||
user_agent::session::{RequestNewClientApproval, UserAgentSession},
|
||||
};
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct MessageRouter {
|
||||
@@ -50,6 +56,73 @@ impl Actor for MessageRouter {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum ApprovalError {
|
||||
#[error("No user agents connected")]
|
||||
NoUserAgentsConnected,
|
||||
}
|
||||
|
||||
async fn request_client_approval(
|
||||
user_agents: &[WeakActorRef<UserAgentSession>],
|
||||
client_pubkey: VerifyingKey,
|
||||
) -> Result<bool, ApprovalError> {
|
||||
if user_agents.is_empty() {
|
||||
return Err(ApprovalError::NoUserAgentsConnected);
|
||||
}
|
||||
|
||||
let mut pool = JoinSet::new();
|
||||
let (cancel_tx, cancel_rx) = watch::channel(());
|
||||
|
||||
for weak_ref in user_agents {
|
||||
match weak_ref.upgrade() {
|
||||
Some(agent) => {
|
||||
let cancel_rx = cancel_rx.clone();
|
||||
pool.spawn(async move {
|
||||
agent
|
||||
.ask(RequestNewClientApproval {
|
||||
client_pubkey,
|
||||
cancel_flag: cancel_rx.clone(),
|
||||
})
|
||||
.await
|
||||
});
|
||||
}
|
||||
None => {
|
||||
warn!(
|
||||
id = weak_ref.id().to_string(),
|
||||
actor = "MessageRouter",
|
||||
event = "useragent.disconnected_before_approval"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
while let Some(result) = pool.join_next().await {
|
||||
match result {
|
||||
Ok(Ok(approved)) => {
|
||||
// cancel other pending requests
|
||||
let _ = cancel_tx.send(());
|
||||
return Ok(approved);
|
||||
}
|
||||
Ok(Err(err)) => {
|
||||
warn!(
|
||||
?err,
|
||||
actor = "MessageRouter",
|
||||
event = "useragent.approval_error"
|
||||
);
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
?err,
|
||||
actor = "MessageRouter",
|
||||
event = "useragent.approval_task_failed"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(ApprovalError::NoUserAgentsConnected)
|
||||
}
|
||||
|
||||
#[messages]
|
||||
impl MessageRouter {
|
||||
#[message(ctx)]
|
||||
@@ -73,4 +146,28 @@ impl MessageRouter {
|
||||
ctx.actor_ref().link(&actor).await;
|
||||
self.clients.insert(actor.id(), actor);
|
||||
}
|
||||
|
||||
#[message(ctx)]
|
||||
pub async fn request_client_approval(
|
||||
&mut self,
|
||||
client_pubkey: VerifyingKey,
|
||||
ctx: &mut Context<Self, DelegatedReply<Result<bool, ApprovalError>>>,
|
||||
) -> DelegatedReply<Result<bool, ApprovalError>> {
|
||||
let (reply, Some(reply_sender)) = ctx.reply_sender() else {
|
||||
panic!("Expected `request_client_approval` to have callback channel");
|
||||
};
|
||||
|
||||
let weak_refs = self
|
||||
.user_agents
|
||||
.values()
|
||||
.map(|agent| agent.downgrade())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
let result = request_client_approval(&weak_refs, client_pubkey).await;
|
||||
reply_sender.send(result);
|
||||
});
|
||||
|
||||
reply
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
use arbiter_proto::{
|
||||
proto::user_agent::{UserAgentRequest, UserAgentResponse},
|
||||
proto::user_agent::{
|
||||
SdkClientError as ProtoSdkClientError, UserAgentRequest, UserAgentResponse,
|
||||
},
|
||||
transport::Bi,
|
||||
};
|
||||
use fatality::Fatality;
|
||||
use kameo::actor::Spawn as _;
|
||||
use tracing::{error, info};
|
||||
|
||||
@@ -24,12 +27,27 @@ pub enum TransportResponseError {
|
||||
StateTransitionFailed,
|
||||
#[error("Vault is not available")]
|
||||
KeyHolderActorUnreachable,
|
||||
#[error("SDK client approve failed: {0:?}")]
|
||||
SdkClientApprove(ProtoSdkClientError),
|
||||
#[error("SDK client list failed: {0:?}")]
|
||||
SdkClientList(ProtoSdkClientError),
|
||||
#[error("SDK client revoke failed: {0:?}")]
|
||||
SdkClientRevoke(ProtoSdkClientError),
|
||||
#[error(transparent)]
|
||||
Auth(#[from] auth::Error),
|
||||
#[error("Failed registering connection")]
|
||||
ConnectionRegistrationFailed,
|
||||
}
|
||||
|
||||
impl Fatality for TransportResponseError {
|
||||
fn is_fatal(&self) -> bool {
|
||||
!matches!(
|
||||
self,
|
||||
Self::SdkClientApprove(_) | Self::SdkClientList(_) | Self::SdkClientRevoke(_)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub type Transport =
|
||||
Box<dyn Bi<UserAgentRequest, Result<UserAgentResponse, TransportResponseError>> + Send>;
|
||||
|
||||
|
||||
@@ -3,21 +3,23 @@ use std::{ops::DerefMut, sync::Mutex};
|
||||
use arbiter_proto::proto::{
|
||||
evm as evm_proto,
|
||||
user_agent::{
|
||||
SdkClientApproveRequest, SdkClientApproveResponse, SdkClientEntry,
|
||||
SdkClientError as ProtoSdkClientError, SdkClientList, SdkClientListResponse,
|
||||
SdkClientRevokeRequest, SdkClientRevokeResponse, UnsealEncryptedKey, UnsealResult,
|
||||
UnsealStart, UnsealStartResponse, UserAgentRequest, UserAgentResponse,
|
||||
sdk_client_approve_response, sdk_client_list_response, sdk_client_revoke_response,
|
||||
user_agent_request::Payload as UserAgentRequestPayload,
|
||||
SdkClientApproveRequest, SdkClientApproveResponse, SdkClientConnectionCancel,
|
||||
SdkClientConnectionRequest, SdkClientEntry, SdkClientError as ProtoSdkClientError,
|
||||
SdkClientList, SdkClientListResponse, SdkClientRevokeRequest, SdkClientRevokeResponse,
|
||||
UnsealEncryptedKey, UnsealResult, UnsealStart, UnsealStartResponse, UserAgentRequest,
|
||||
UserAgentResponse, sdk_client_approve_response, sdk_client_list_response,
|
||||
sdk_client_revoke_response, user_agent_request::Payload as UserAgentRequestPayload,
|
||||
user_agent_response::Payload as UserAgentResponsePayload,
|
||||
},
|
||||
};
|
||||
use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit};
|
||||
use diesel::{ExpressionMethods as _, QueryDsl as _, dsl::insert_into};
|
||||
use diesel_async::RunQueryDsl as _;
|
||||
use kameo::{Actor, error::SendError, prelude::Context};
|
||||
use ed25519_dalek::VerifyingKey;
|
||||
use fatality::Fatality;
|
||||
use kameo::{Actor, error::SendError, messages, prelude::Context};
|
||||
use memsafe::MemSafe;
|
||||
use tokio::select;
|
||||
use tokio::{select, sync::watch};
|
||||
use tracing::{error, info};
|
||||
use x25519_dalek::{EphemeralSecret, PublicKey};
|
||||
|
||||
@@ -115,6 +117,53 @@ impl UserAgentSession {
|
||||
}
|
||||
}
|
||||
|
||||
#[messages]
|
||||
impl UserAgentSession {
|
||||
// TODO: Think about refactoring it to state-machine based flow, as we already have one
|
||||
#[message(ctx)]
|
||||
pub async fn request_new_client_approval(
|
||||
&mut self,
|
||||
client_pubkey: VerifyingKey,
|
||||
mut cancel_flag: watch::Receiver<()>,
|
||||
ctx: &mut Context<Self, Result<bool, Error>>,
|
||||
) -> Result<bool, Error> {
|
||||
self.send_msg(
|
||||
UserAgentResponsePayload::SdkClientConnectionRequest(SdkClientConnectionRequest {
|
||||
pubkey: client_pubkey.as_bytes().to_vec(),
|
||||
}),
|
||||
ctx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let extractor = |msg| {
|
||||
if let UserAgentRequestPayload::SdkClientConnectionResponse(
|
||||
client_connection_response,
|
||||
) = msg
|
||||
{
|
||||
Some(client_connection_response)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
tokio::select! {
|
||||
_ = cancel_flag.changed() => {
|
||||
info!(actor = "useragent", "client connection approval cancelled");
|
||||
self.send_msg(
|
||||
UserAgentResponsePayload::SdkClientConnectionCancel(SdkClientConnectionCancel {}),
|
||||
ctx,
|
||||
).await?;
|
||||
Ok(false)
|
||||
}
|
||||
result = self.expect_msg(extractor, ctx) => {
|
||||
let result = result?;
|
||||
info!(actor = "useragent", "received client connection approval result: approved={}", result.approved);
|
||||
Ok(result.approved)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserAgentSession {
|
||||
pub async fn process_transport_inbound(&mut self, req: UserAgentRequest) -> Output {
|
||||
let msg = req.payload.ok_or_else(|| {
|
||||
@@ -304,11 +353,9 @@ impl UserAgentSession {
|
||||
use sdk_client_approve_response::Result as ApproveResult;
|
||||
|
||||
if req.pubkey.len() != 32 {
|
||||
return Ok(response(UserAgentResponsePayload::SdkClientApprove(
|
||||
SdkClientApproveResponse {
|
||||
result: Some(ApproveResult::Error(ProtoSdkClientError::Internal.into())),
|
||||
},
|
||||
)));
|
||||
return Err(TransportResponseError::SdkClientApprove(
|
||||
ProtoSdkClientError::Internal,
|
||||
));
|
||||
}
|
||||
|
||||
let now = std::time::SystemTime::now()
|
||||
@@ -320,11 +367,9 @@ impl UserAgentSession {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!(?e, "Failed to get DB connection for sdk_client_approve");
|
||||
return Ok(response(UserAgentResponsePayload::SdkClientApprove(
|
||||
SdkClientApproveResponse {
|
||||
result: Some(ApproveResult::Error(ProtoSdkClientError::Internal.into())),
|
||||
},
|
||||
)));
|
||||
return Err(TransportResponseError::SdkClientApprove(
|
||||
ProtoSdkClientError::Internal,
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -336,60 +381,35 @@ impl UserAgentSession {
|
||||
program_client::created_at.eq(now),
|
||||
program_client::updated_at.eq(now),
|
||||
))
|
||||
.execute(&mut conn)
|
||||
.returning((
|
||||
program_client::id,
|
||||
program_client::public_key,
|
||||
program_client::created_at,
|
||||
))
|
||||
.get_result::<(i32, Vec<u8>, i32)>(&mut conn)
|
||||
.await;
|
||||
|
||||
match insert_result {
|
||||
Ok(_) => {
|
||||
match program_client::table
|
||||
.filter(program_client::public_key.eq(&pubkey_bytes))
|
||||
.order(program_client::id.desc())
|
||||
.select((
|
||||
program_client::id,
|
||||
program_client::public_key,
|
||||
program_client::created_at,
|
||||
))
|
||||
.first::<(i32, Vec<u8>, i32)>(&mut conn)
|
||||
.await
|
||||
{
|
||||
Ok((id, pubkey, created_at)) => Ok(response(
|
||||
UserAgentResponsePayload::SdkClientApprove(SdkClientApproveResponse {
|
||||
result: Some(ApproveResult::Client(SdkClientEntry {
|
||||
id,
|
||||
pubkey,
|
||||
created_at,
|
||||
})),
|
||||
}),
|
||||
)),
|
||||
Err(e) => {
|
||||
error!(?e, "Failed to fetch inserted SDK client");
|
||||
Ok(response(UserAgentResponsePayload::SdkClientApprove(
|
||||
SdkClientApproveResponse {
|
||||
result: Some(ApproveResult::Error(
|
||||
ProtoSdkClientError::Internal.into(),
|
||||
)),
|
||||
},
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok((id, pubkey, created_at)) => Ok(response(
|
||||
UserAgentResponsePayload::SdkClientApprove(SdkClientApproveResponse {
|
||||
result: Some(ApproveResult::Client(SdkClientEntry {
|
||||
id,
|
||||
pubkey,
|
||||
created_at,
|
||||
})),
|
||||
}),
|
||||
)),
|
||||
Err(diesel::result::Error::DatabaseError(
|
||||
diesel::result::DatabaseErrorKind::UniqueViolation,
|
||||
_,
|
||||
)) => Ok(response(UserAgentResponsePayload::SdkClientApprove(
|
||||
SdkClientApproveResponse {
|
||||
result: Some(ApproveResult::Error(
|
||||
ProtoSdkClientError::AlreadyExists.into(),
|
||||
)),
|
||||
},
|
||||
))),
|
||||
)) => Err(TransportResponseError::SdkClientApprove(
|
||||
ProtoSdkClientError::AlreadyExists,
|
||||
)),
|
||||
Err(e) => {
|
||||
error!(?e, "Failed to insert SDK client");
|
||||
Ok(response(UserAgentResponsePayload::SdkClientApprove(
|
||||
SdkClientApproveResponse {
|
||||
result: Some(ApproveResult::Error(ProtoSdkClientError::Internal.into())),
|
||||
},
|
||||
)))
|
||||
Err(TransportResponseError::SdkClientApprove(
|
||||
ProtoSdkClientError::Internal,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -399,13 +419,9 @@ impl UserAgentSession {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!(?e, "Failed to get DB connection for sdk_client_list");
|
||||
return Ok(response(UserAgentResponsePayload::SdkClientList(
|
||||
SdkClientListResponse {
|
||||
result: Some(sdk_client_list_response::Result::Error(
|
||||
ProtoSdkClientError::Internal.into(),
|
||||
)),
|
||||
},
|
||||
)));
|
||||
return Err(TransportResponseError::SdkClientList(
|
||||
ProtoSdkClientError::Internal,
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -434,13 +450,9 @@ impl UserAgentSession {
|
||||
))),
|
||||
Err(e) => {
|
||||
error!(?e, "Failed to list SDK clients");
|
||||
Ok(response(UserAgentResponsePayload::SdkClientList(
|
||||
SdkClientListResponse {
|
||||
result: Some(sdk_client_list_response::Result::Error(
|
||||
ProtoSdkClientError::Internal.into(),
|
||||
)),
|
||||
},
|
||||
)))
|
||||
Err(TransportResponseError::SdkClientList(
|
||||
ProtoSdkClientError::Internal,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -452,11 +464,9 @@ impl UserAgentSession {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!(?e, "Failed to get DB connection for sdk_client_revoke");
|
||||
return Ok(response(UserAgentResponsePayload::SdkClientRevoke(
|
||||
SdkClientRevokeResponse {
|
||||
result: Some(RevokeResult::Error(ProtoSdkClientError::Internal.into())),
|
||||
},
|
||||
)));
|
||||
return Err(TransportResponseError::SdkClientRevoke(
|
||||
ProtoSdkClientError::Internal,
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -465,11 +475,9 @@ impl UserAgentSession {
|
||||
.execute(&mut conn)
|
||||
.await
|
||||
{
|
||||
Ok(0) => Ok(response(UserAgentResponsePayload::SdkClientRevoke(
|
||||
SdkClientRevokeResponse {
|
||||
result: Some(RevokeResult::Error(ProtoSdkClientError::NotFound.into())),
|
||||
},
|
||||
))),
|
||||
Ok(0) => Err(TransportResponseError::SdkClientRevoke(
|
||||
ProtoSdkClientError::NotFound,
|
||||
)),
|
||||
Ok(_) => Ok(response(UserAgentResponsePayload::SdkClientRevoke(
|
||||
SdkClientRevokeResponse {
|
||||
result: Some(RevokeResult::Ok(())),
|
||||
@@ -478,20 +486,14 @@ impl UserAgentSession {
|
||||
Err(diesel::result::Error::DatabaseError(
|
||||
diesel::result::DatabaseErrorKind::ForeignKeyViolation,
|
||||
_,
|
||||
)) => Ok(response(UserAgentResponsePayload::SdkClientRevoke(
|
||||
SdkClientRevokeResponse {
|
||||
result: Some(RevokeResult::Error(
|
||||
ProtoSdkClientError::HasRelatedData.into(),
|
||||
)),
|
||||
},
|
||||
))),
|
||||
)) => Err(TransportResponseError::SdkClientRevoke(
|
||||
ProtoSdkClientError::HasRelatedData,
|
||||
)),
|
||||
Err(e) => {
|
||||
error!(?e, "Failed to delete SDK client");
|
||||
Ok(response(UserAgentResponsePayload::SdkClientRevoke(
|
||||
SdkClientRevokeResponse {
|
||||
result: Some(RevokeResult::Error(ProtoSdkClientError::Internal.into())),
|
||||
},
|
||||
)))
|
||||
Err(TransportResponseError::SdkClientRevoke(
|
||||
ProtoSdkClientError::Internal,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -558,8 +560,15 @@ impl Actor for UserAgentSession {
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = self.props.transport.send(Err(err)).await;
|
||||
return Some(kameo::mailbox::Signal::Stop);
|
||||
let should_stop = err.is_fatal();
|
||||
if self.props.transport.send(Err(err)).await.is_err() {
|
||||
error!(actor = "useragent", reason = "channel closed", "send.failed");
|
||||
return Some(kameo::mailbox::Signal::Stop);
|
||||
}
|
||||
|
||||
if should_stop {
|
||||
return Some(kameo::mailbox::Signal::Stop);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,12 @@
|
||||
use arbiter_proto::{
|
||||
proto::{
|
||||
client::{ClientRequest, ClientResponse},
|
||||
user_agent::{UserAgentRequest, UserAgentResponse},
|
||||
user_agent::{
|
||||
SdkClientApproveResponse, SdkClientListResponse, SdkClientRevokeResponse,
|
||||
UserAgentRequest, UserAgentResponse, sdk_client_approve_response,
|
||||
sdk_client_list_response, sdk_client_revoke_response,
|
||||
user_agent_response::Payload as UserAgentResponsePayload,
|
||||
},
|
||||
},
|
||||
transport::{IdentityRecvConverter, SendConverter, grpc},
|
||||
};
|
||||
@@ -37,6 +42,27 @@ impl SendConverter for UserAgentGrpcSender {
|
||||
fn convert(&self, item: Self::Input) -> Self::Output {
|
||||
match item {
|
||||
Ok(message) => Ok(message),
|
||||
Err(TransportResponseError::SdkClientApprove(code)) => Ok(UserAgentResponse {
|
||||
payload: Some(UserAgentResponsePayload::SdkClientApprove(
|
||||
SdkClientApproveResponse {
|
||||
result: Some(sdk_client_approve_response::Result::Error(code.into())),
|
||||
},
|
||||
)),
|
||||
}),
|
||||
Err(TransportResponseError::SdkClientList(code)) => Ok(UserAgentResponse {
|
||||
payload: Some(UserAgentResponsePayload::SdkClientList(
|
||||
SdkClientListResponse {
|
||||
result: Some(sdk_client_list_response::Result::Error(code.into())),
|
||||
},
|
||||
)),
|
||||
}),
|
||||
Err(TransportResponseError::SdkClientRevoke(code)) => Ok(UserAgentResponse {
|
||||
payload: Some(UserAgentResponsePayload::SdkClientRevoke(
|
||||
SdkClientRevokeResponse {
|
||||
result: Some(sdk_client_revoke_response::Result::Error(code.into())),
|
||||
},
|
||||
)),
|
||||
}),
|
||||
Err(err) => Err(user_agent_error_status(err)),
|
||||
}
|
||||
}
|
||||
@@ -79,7 +105,7 @@ fn client_auth_error_status(value: &client::auth::Error) -> Status {
|
||||
Status::invalid_argument("Failed to convert pubkey to VerifyingKey")
|
||||
}
|
||||
Error::InvalidChallengeSolution => Status::unauthenticated(value.to_string()),
|
||||
Error::NotRegistered => Status::permission_denied(value.to_string()),
|
||||
Error::ApproveError(_) => Status::permission_denied(value.to_string()),
|
||||
Error::Transport => Status::internal("Transport error"),
|
||||
Error::DatabasePoolUnavailable => Status::internal("Database pool error"),
|
||||
Error::DatabaseOperationFailed => Status::internal("Database error"),
|
||||
@@ -103,6 +129,11 @@ fn user_agent_error_status(value: TransportResponseError) -> Status {
|
||||
TransportResponseError::KeyHolderActorUnreachable => {
|
||||
Status::internal("Vault is not available")
|
||||
}
|
||||
TransportResponseError::SdkClientApprove(_)
|
||||
| TransportResponseError::SdkClientList(_)
|
||||
| TransportResponseError::SdkClientRevoke(_) => {
|
||||
Status::internal("SDK client operation failed")
|
||||
}
|
||||
TransportResponseError::Auth(ref err) => auth_error_status(err),
|
||||
TransportResponseError::ConnectionRegistrationFailed => {
|
||||
Status::internal("Failed registering connection")
|
||||
|
||||
@@ -114,6 +114,15 @@ pub async fn test_challenge_auth() {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let response = test_transport.recv().await.expect("should receive auth ok");
|
||||
match response {
|
||||
Ok(resp) => match resp.payload {
|
||||
Some(ClientResponsePayload::AuthOk(_)) => {}
|
||||
other => panic!("Expected AuthOk, got {other:?}"),
|
||||
},
|
||||
Err(err) => panic!("Expected Ok response, got Err({err:?})"),
|
||||
}
|
||||
|
||||
// Auth completes, session spawned
|
||||
task.await.unwrap();
|
||||
}
|
||||
@@ -178,6 +187,15 @@ pub async fn test_evm_sign_request_payload_is_handled() {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let response = test_transport.recv().await.expect("should receive auth ok");
|
||||
match response {
|
||||
Ok(resp) => match resp.payload {
|
||||
Some(ClientResponsePayload::AuthOk(_)) => {}
|
||||
other => panic!("Expected AuthOk, got {other:?}"),
|
||||
},
|
||||
Err(err) => panic!("Expected Ok response, got Err({err:?})"),
|
||||
}
|
||||
|
||||
task.await.unwrap();
|
||||
|
||||
let tx = TxEip1559 {
|
||||
|
||||
@@ -5,7 +5,10 @@ use arbiter_proto::proto::user_agent::{
|
||||
user_agent_response::Payload as UserAgentResponsePayload,
|
||||
};
|
||||
use arbiter_server::{
|
||||
actors::{GlobalActors, user_agent::session::UserAgentSession},
|
||||
actors::{
|
||||
GlobalActors,
|
||||
user_agent::{TransportResponseError, session::UserAgentSession},
|
||||
},
|
||||
db,
|
||||
};
|
||||
|
||||
@@ -68,22 +71,15 @@ async fn test_sdk_client_approve_duplicate_returns_already_exists() {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let response = session
|
||||
let err = session
|
||||
.process_transport_inbound(req)
|
||||
.await
|
||||
.expect("second insert should not panic");
|
||||
.expect_err("second insert should return typed TransportResponseError");
|
||||
|
||||
match response.payload.unwrap() {
|
||||
UserAgentResponsePayload::SdkClientApprove(resp) => match resp.result.unwrap() {
|
||||
sdk_client_approve_response::Result::Error(code) => {
|
||||
assert_eq!(code, ProtoSdkClientError::AlreadyExists as i32);
|
||||
}
|
||||
sdk_client_approve_response::Result::Client(_) => {
|
||||
panic!("Expected AlreadyExists error for duplicate pubkey")
|
||||
}
|
||||
},
|
||||
other => panic!("Expected SdkClientApprove, got {other:?}"),
|
||||
}
|
||||
assert_eq!(
|
||||
err,
|
||||
TransportResponseError::SdkClientApprove(ProtoSdkClientError::AlreadyExists)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -203,26 +199,19 @@ async fn test_sdk_client_revoke_not_found_returns_error() {
|
||||
let db = db::create_test_pool().await;
|
||||
let mut session = make_session(&db).await;
|
||||
|
||||
let response = session
|
||||
let err = session
|
||||
.process_transport_inbound(UserAgentRequest {
|
||||
payload: Some(UserAgentRequestPayload::SdkClientRevoke(
|
||||
SdkClientRevokeRequest { client_id: 9999 },
|
||||
)),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
.expect_err("missing client should return typed TransportResponseError");
|
||||
|
||||
match response.payload.unwrap() {
|
||||
UserAgentResponsePayload::SdkClientRevoke(resp) => match resp.result.unwrap() {
|
||||
sdk_client_revoke_response::Result::Error(code) => {
|
||||
assert_eq!(code, ProtoSdkClientError::NotFound as i32);
|
||||
}
|
||||
sdk_client_revoke_response::Result::Ok(_) => {
|
||||
panic!("Expected NotFound error for missing client_id")
|
||||
}
|
||||
},
|
||||
other => panic!("Expected SdkClientRevoke, got {other:?}"),
|
||||
}
|
||||
assert_eq!(
|
||||
err,
|
||||
TransportResponseError::SdkClientRevoke(ProtoSdkClientError::NotFound)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
[package]
|
||||
name = "arbiter-terrors-poc"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
terrors = "0.3"
|
||||
@@ -1,139 +0,0 @@
|
||||
use crate::errors::{InternalError1, InternalError2, InvalidSignature, NotRegistered};
|
||||
use terrors::OneOf;
|
||||
|
||||
use crate::errors::ProtoError;
|
||||
|
||||
// Each sub-call's error type already implements DrainInto<ProtoError>, so we convert
|
||||
// directly to ProtoError without broaden — no turbofish needed anywhere.
|
||||
//
|
||||
// Call chain:
|
||||
// load_config() → OneOf<(InternalError2,)> → ProtoError::from
|
||||
// get_nonce() → OneOf<(InternalError1, InternalError2)> → ProtoError::from
|
||||
// verify_sig() → OneOf<(InvalidSignature,)> → ProtoError::from
|
||||
pub fn process_request(id: u32, sig: &str) -> Result<String, ProtoError> {
|
||||
if id == 0 {
|
||||
return Err(ProtoError::NotRegistered);
|
||||
}
|
||||
|
||||
let config = load_config(id).map_err(ProtoError::from)?;
|
||||
let nonce = crate::db::get_nonce(id).map_err(ProtoError::from)?;
|
||||
verify_signature(nonce, sig).map_err(ProtoError::from)?;
|
||||
|
||||
Ok(format!("config={config} nonce={nonce} sig={sig}"))
|
||||
}
|
||||
|
||||
// Simulates loading a config value.
|
||||
// id=97 triggers InternalError2 ("config read failed").
|
||||
fn load_config(id: u32) -> Result<String, OneOf<(InternalError2,)>> {
|
||||
if id == 97 {
|
||||
return Err(OneOf::new(InternalError2("config read failed".to_owned())));
|
||||
}
|
||||
Ok(format!("cfg-{id}"))
|
||||
}
|
||||
|
||||
pub fn verify_signature(_nonce: u32, sig: &str) -> Result<(), OneOf<(InvalidSignature,)>> {
|
||||
if sig != "ok" {
|
||||
return Err(OneOf::new(InvalidSignature));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
type AuthError = OneOf<(
|
||||
NotRegistered,
|
||||
InvalidSignature,
|
||||
InternalError1,
|
||||
InternalError2,
|
||||
)>;
|
||||
|
||||
pub fn authenticate(id: u32, sig: &str) -> Result<u32, AuthError> {
|
||||
if id == 0 {
|
||||
return Err(OneOf::new(NotRegistered));
|
||||
}
|
||||
|
||||
// Return type AuthError lets the compiler infer the broaden target.
|
||||
let nonce = crate::db::get_nonce(id).map_err(OneOf::broaden)?;
|
||||
verify_signature(nonce, sig).map_err(OneOf::broaden)?;
|
||||
|
||||
Ok(nonce)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn verify_signature_ok() {
|
||||
assert!(verify_signature(42, "ok").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn verify_signature_bad() {
|
||||
let err = verify_signature(42, "bad").unwrap_err();
|
||||
assert!(err.narrow::<crate::errors::InvalidSignature, _>().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn authenticate_success() {
|
||||
assert_eq!(authenticate(1, "ok").unwrap(), 42);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn authenticate_not_registered() {
|
||||
let err = authenticate(0, "ok").unwrap_err();
|
||||
assert!(err.narrow::<crate::errors::NotRegistered, _>().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn authenticate_invalid_signature() {
|
||||
let err = authenticate(1, "bad").unwrap_err();
|
||||
assert!(err.narrow::<crate::errors::InvalidSignature, _>().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn authenticate_internal_error1() {
|
||||
let err = authenticate(99, "ok").unwrap_err();
|
||||
assert!(err.narrow::<crate::errors::InternalError1, _>().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn authenticate_internal_error2() {
|
||||
let err = authenticate(98, "ok").unwrap_err();
|
||||
assert!(err.narrow::<crate::errors::InternalError2, _>().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn process_request_success() {
|
||||
let result = process_request(1, "ok").unwrap();
|
||||
assert!(result.contains("nonce=42"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn process_request_not_registered() {
|
||||
let err = process_request(0, "ok").unwrap_err();
|
||||
assert!(matches!(err, crate::errors::ProtoError::NotRegistered));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn process_request_invalid_signature() {
|
||||
let err = process_request(1, "bad").unwrap_err();
|
||||
assert!(matches!(err, crate::errors::ProtoError::InvalidSignature));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn process_request_internal_from_config() {
|
||||
// id=97 → load_config returns InternalError2
|
||||
let err = process_request(97, "ok").unwrap_err();
|
||||
assert!(
|
||||
matches!(err, crate::errors::ProtoError::Internal(ref msg) if msg == "config read failed")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn process_request_internal_from_db() {
|
||||
// id=99 → get_nonce returns InternalError1
|
||||
let err = process_request(99, "ok").unwrap_err();
|
||||
assert!(
|
||||
matches!(err, crate::errors::ProtoError::Internal(ref msg) if msg == "db pool unavailable")
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
use crate::errors::{InternalError1, InternalError2};
|
||||
use terrors::OneOf;
|
||||
|
||||
// Simulates fetching a nonce from a database.
|
||||
// id=99 → InternalError1 (pool unavailable)
|
||||
// id=98 → InternalError2 (query timeout)
|
||||
pub fn get_nonce(id: u32) -> Result<u32, OneOf<(InternalError1, InternalError2)>> {
|
||||
match id {
|
||||
99 => Err(OneOf::new(InternalError1("db pool unavailable".to_owned()))),
|
||||
98 => Err(OneOf::new(InternalError2("query timeout".to_owned()))),
|
||||
_ => Ok(42),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn get_nonce_returns_nonce_for_valid_id() {
|
||||
assert_eq!(get_nonce(1).unwrap(), 42);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_nonce_returns_internal_error1_for_sentinel() {
|
||||
let err = get_nonce(99).unwrap_err();
|
||||
let internal = err.narrow::<crate::errors::InternalError1, _>().unwrap();
|
||||
assert_eq!(internal.0, "db pool unavailable");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_nonce_returns_internal_error2_for_sentinel() {
|
||||
let err = get_nonce(98).unwrap_err();
|
||||
let e = err.narrow::<crate::errors::InternalError1, _>().unwrap_err();
|
||||
let internal = e.take::<crate::errors::InternalError2>();
|
||||
assert_eq!(internal.0, "query timeout");
|
||||
}
|
||||
}
|
||||
@@ -1,130 +0,0 @@
|
||||
use terrors::OneOf;
|
||||
|
||||
// Wire boundary type — what would go into a proto response
|
||||
#[derive(Debug)]
|
||||
pub enum ProtoError {
|
||||
NotRegistered,
|
||||
InvalidSignature,
|
||||
Internal(String), // Or Box<dyn Error>, who cares?
|
||||
}
|
||||
|
||||
// Internal terrors types
|
||||
#[derive(Debug)]
|
||||
pub struct NotRegistered;
|
||||
#[derive(Debug)]
|
||||
pub struct InvalidSignature;
|
||||
#[derive(Debug)]
|
||||
pub struct InternalError1(pub String);
|
||||
#[derive(Debug)]
|
||||
pub struct InternalError2(pub String);
|
||||
|
||||
// Errors can be scattered across the codebase as long as they implement Into<ProtoError>
|
||||
impl From<NotRegistered> for ProtoError {
|
||||
fn from(_: NotRegistered) -> Self {
|
||||
ProtoError::NotRegistered
|
||||
}
|
||||
}
|
||||
|
||||
impl From<InvalidSignature> for ProtoError {
|
||||
fn from(_: InvalidSignature) -> Self {
|
||||
ProtoError::InvalidSignature
|
||||
}
|
||||
}
|
||||
|
||||
impl From<InternalError1> for ProtoError {
|
||||
fn from(e: InternalError1) -> Self {
|
||||
ProtoError::Internal(e.0)
|
||||
}
|
||||
}
|
||||
impl From<InternalError2> for ProtoError {
|
||||
fn from(e: InternalError2) -> Self {
|
||||
ProtoError::Internal(e.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// Private helper trait for converting from OneOf<T...> where each T can be converted
|
||||
/// into the target type `O` by recursively narrowing until a match is found.
|
||||
///
|
||||
/// IDK why this isn't already in terrors.
|
||||
trait DrainInto<O>: terrors::TypeSet + Sized {
|
||||
fn drain(e: OneOf<Self>) -> O;
|
||||
}
|
||||
|
||||
macro_rules! impl_drain_into {
|
||||
($head:ident) => {
|
||||
impl<$head, O> DrainInto<O> for ($head,)
|
||||
where
|
||||
$head: Into<O> + 'static,
|
||||
{
|
||||
fn drain(e: OneOf<($head,)>) -> O {
|
||||
e.take().into()
|
||||
}
|
||||
}
|
||||
};
|
||||
($head:ident, $($tail:ident),+) => {
|
||||
impl<$head, $($tail),+, O> DrainInto<O> for ($head, $($tail),+)
|
||||
where
|
||||
$head: Into<O> + 'static,
|
||||
($($tail,)+): DrainInto<O>,
|
||||
{
|
||||
fn drain(e: OneOf<($head, $($tail),+)>) -> O {
|
||||
match e.narrow::<$head, _>() {
|
||||
Ok(h) => h.into(),
|
||||
Err(rest) => <($($tail,)+)>::drain(rest),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl_drain_into!($($tail),+);
|
||||
};
|
||||
}
|
||||
|
||||
// Generates impls for all tuple sizes from 1 up to 7 (restricted by terrors internal impl).
|
||||
// Each invocation produces one impl then recurses on the tail.
|
||||
impl_drain_into!(A, B, C, D, E, F, G, H, I);
|
||||
|
||||
// Blanket From impl: body delegates to the recursive drain.
|
||||
impl<E: DrainInto<ProtoError>> From<OneOf<E>> for ProtoError {
|
||||
fn from(e: OneOf<E>) -> Self {
|
||||
E::drain(e)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn not_registered_converts_to_proto() {
|
||||
let e: ProtoError = NotRegistered.into();
|
||||
assert!(matches!(e, ProtoError::NotRegistered));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_signature_converts_to_proto() {
|
||||
let e: ProtoError = InvalidSignature.into();
|
||||
assert!(matches!(e, ProtoError::InvalidSignature));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn internal_converts_to_proto() {
|
||||
let e: ProtoError = InternalError1("boom".into()).into();
|
||||
assert!(matches!(e, ProtoError::Internal(msg) if msg == "boom"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn one_of_remainder_converts_to_proto_invalid_signature() {
|
||||
use terrors::OneOf;
|
||||
let e: OneOf<(InvalidSignature, InternalError1)> = OneOf::new(InvalidSignature);
|
||||
let proto = ProtoError::from(e);
|
||||
assert!(matches!(proto, ProtoError::InvalidSignature));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn one_of_remainder_converts_to_proto_internal() {
|
||||
use terrors::OneOf;
|
||||
let e: OneOf<(InvalidSignature, InternalError1)> =
|
||||
OneOf::new(InternalError1("db fail".into()));
|
||||
let proto = ProtoError::from(e);
|
||||
assert!(matches!(proto, ProtoError::Internal(msg) if msg == "db fail"));
|
||||
}
|
||||
}
|
||||
@@ -1,43 +0,0 @@
|
||||
mod auth;
|
||||
mod db;
|
||||
mod errors;
|
||||
|
||||
use errors::ProtoError;
|
||||
|
||||
fn run(id: u32, sig: &str) {
|
||||
print!("authenticate(id={id}, sig={sig:?}) => ");
|
||||
match auth::authenticate(id, sig) {
|
||||
Ok(nonce) => println!("Ok(nonce={nonce})"),
|
||||
Err(e) => match e.narrow::<errors::NotRegistered, _>() {
|
||||
Ok(_) => println!("Err(NotRegistered) — handled locally"),
|
||||
Err(remaining) => {
|
||||
let proto = ProtoError::from(remaining);
|
||||
println!("Err(ProtoError::{proto:?}) — forwarded to wire");
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn run_process(id: u32, sig: &str) {
|
||||
print!("process_request(id={id}, sig={sig:?}) => ");
|
||||
match auth::process_request(id, sig) {
|
||||
Ok(s) => println!("Ok({s})"),
|
||||
Err(e) => println!("Err(ProtoError::{e:?})"),
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
println!("=== authenticate ===");
|
||||
run(0, "ok"); // NotRegistered
|
||||
run(1, "bad"); // InvalidSignature
|
||||
run(99, "ok"); // InternalError1
|
||||
run(98, "ok"); // InternalError2
|
||||
run(1, "ok"); // success
|
||||
|
||||
println!("\n=== process_request (Try chain) ===");
|
||||
run_process(0, "ok"); // NotRegistered (guard, no I/O)
|
||||
run_process(97, "ok"); // InternalError2 from load_config
|
||||
run_process(99, "ok"); // InternalError1 from get_nonce
|
||||
run_process(1, "bad"); // InvalidSignature from verify_signature
|
||||
run_process(1, "ok"); // success
|
||||
}
|
||||
Reference in New Issue
Block a user