diff --git a/server/crates/arbiter-server/src/actors/client/auth/mod.rs b/server/crates/arbiter-server/src/actors/client/auth/mod.rs index 06b9d29..8ff0600 100644 --- a/server/crates/arbiter-server/src/actors/client/auth/mod.rs +++ b/server/crates/arbiter-server/src/actors/client/auth/mod.rs @@ -51,6 +51,7 @@ fn parse_auth_event(payload: ClientRequestPayload) -> Result solution: signature, })) } + _ => Err(Error::UnexpectedMessagePayload) , } } diff --git a/server/crates/arbiter-server/src/db/models.rs b/server/crates/arbiter-server/src/db/models.rs index e3ebddf..b38315d 100644 --- a/server/crates/arbiter-server/src/db/models.rs +++ b/server/crates/arbiter-server/src/db/models.rs @@ -56,7 +56,7 @@ pub mod types { fn from_sql( mut bytes: ::RawValue<'_>, ) -> diesel::deserialize::Result { - let Some(SqliteType::Integer) = bytes.value_type() else { + let Some(SqliteType::Long) = bytes.value_type() else { return Err(format!( "Expected Integer type for SqliteTimestamp, got {:?}", bytes.value_type() @@ -64,8 +64,8 @@ pub mod types { .into()); }; - let unix_timestamp = bytes.read_integer(); - let datetime = DateTime::from_timestamp(unix_timestamp as i64, 0) + let unix_timestamp = bytes.read_long(); + let datetime = DateTime::from_timestamp(unix_timestamp, 0) .ok_or("Timestamp is out of bounds")?; Ok(SqliteTimestamp(datetime)) diff --git a/server/crates/arbiter-server/src/evm/mod.rs b/server/crates/arbiter-server/src/evm/mod.rs index 7a9432b..9e00fc0 100644 --- a/server/crates/arbiter-server/src/evm/mod.rs +++ b/server/crates/arbiter-server/src/evm/mod.rs @@ -1,24 +1,24 @@ pub mod abi; pub mod safe_signer; -use alloy::{consensus::TxEip1559, primitives::{TxKind, U256}}; +use alloy::{ + consensus::TxEip1559, + primitives::{TxKind, U256}, +}; use chrono::Utc; -use diesel::{QueryResult, insert_into, sqlite::Sqlite}; +use diesel::{ExpressionMethods as _, QueryDsl, QueryResult, insert_into, sqlite::Sqlite}; use diesel_async::{AsyncConnection, RunQueryDsl}; use crate::{ db::{ self, - models::{ - EvmBasicGrant, NewEvmBasicGrant, NewEvmTransactionLog, - SqliteTimestamp, - }, + models::{EvmBasicGrant, NewEvmBasicGrant, NewEvmTransactionLog, SqliteTimestamp}, schema::{self, evm_transaction_log}, }, evm::policies::{ DatabaseID, EvalContext, EvalViolation, FullGrant, Grant, Policy, SharedGrantSettings, - SpecificGrant, SpecificMeaning, - ether_transfer::EtherTransfer, token_transfers::TokenTransfer, + SpecificGrant, SpecificMeaning, ether_transfer::EtherTransfer, + token_transfers::TokenTransfer, }, }; @@ -55,7 +55,6 @@ pub enum VetError { Evaluated(SpecificMeaning, #[source] PolicyError), } - #[derive(Debug, thiserror::Error, miette::Diagnostic)] pub enum SignError { #[error("Database connection pool error")] @@ -118,8 +117,7 @@ async fn check_shared_constraints( let now = Utc::now(); // Validity window - if shared.valid_from.map_or(false, |t| now < t) - || shared.valid_until.map_or(false, |t| now > t) + if shared.valid_from.map_or(false, |t| now < t) || shared.valid_until.map_or(false, |t| now > t) { violations.push(EvalViolation::InvalidTime); } @@ -128,9 +126,9 @@ async fn check_shared_constraints( let fee_exceeded = shared .max_gas_fee_per_gas .map_or(false, |cap| U256::from(context.max_fee_per_gas) > cap); - let priority_exceeded = shared - .max_priority_fee_per_gas - .map_or(false, |cap| U256::from(context.max_priority_fee_per_gas) > cap); + let priority_exceeded = shared.max_priority_fee_per_gas.map_or(false, |cap| { + U256::from(context.max_priority_fee_per_gas) > cap + }); if fee_exceeded || priority_exceeded { violations.push(EvalViolation::GasLimitExceeded { max_gas_fee_per_gas: shared.max_gas_fee_per_gas, @@ -274,13 +272,23 @@ impl Engine { EtherTransfer::find_all_grants(&mut conn) .await? .into_iter() - .map(Grant::from), + .map(|g| Grant { + id: g.id, + shared_grant_id: g.shared_grant_id, + shared: g.shared, + settings: SpecificGrant::EtherTransfer(g.settings), + }), ); grants.extend( TokenTransfer::find_all_grants(&mut conn) .await? .into_iter() - .map(Grant::from), + .map(|g| Grant { + id: g.id, + shared_grant_id: g.shared_grant_id, + shared: g.shared, + settings: SpecificGrant::TokenTransfer(g.settings), + }), ); Ok(grants) diff --git a/server/crates/arbiter-server/src/evm/policies.rs b/server/crates/arbiter-server/src/evm/policies.rs index 4e5524c..23c3444 100644 --- a/server/crates/arbiter-server/src/evm/policies.rs +++ b/server/crates/arbiter-server/src/evm/policies.rs @@ -17,6 +17,7 @@ use crate::{ pub mod ether_transfer; pub mod token_transfers; +#[derive(Debug, Clone)] pub struct EvalContext { // Which wallet is this transaction for pub client_id: i32, @@ -72,6 +73,7 @@ pub struct Grant { pub settings: PolicySettings, } + pub trait Policy: Sized { type Settings: Send + Sync + 'static + Into; type Meaning: Display + std::fmt::Debug + Send + Sync + 'static + Into; @@ -201,19 +203,6 @@ pub enum SpecificGrant { TokenTransfer(token_transfers::Settings), } -/// Blanket conversion from a typed `Grant` into `Grant`. -/// Lets the engine collect across all policies into one `Vec>`. -impl> From> for Grant { - fn from(g: Grant) -> Self { - Grant { - id: g.id, - shared_grant_id: g.shared_grant_id, - shared: g.shared, - settings: g.settings.into(), - } - } -} - pub struct FullGrant { pub basic: SharedGrantSettings, pub specific: PolicyGrant, diff --git a/server/crates/arbiter-server/src/evm/policies/ether_transfer.rs b/server/crates/arbiter-server/src/evm/policies/ether_transfer/mod.rs similarity index 90% rename from server/crates/arbiter-server/src/evm/policies/ether_transfer.rs rename to server/crates/arbiter-server/src/evm/policies/ether_transfer/mod.rs index dda665a..dfea8cb 100644 --- a/server/crates/arbiter-server/src/evm/policies/ether_transfer.rs +++ b/server/crates/arbiter-server/src/evm/policies/ether_transfer/mod.rs @@ -3,7 +3,7 @@ use std::fmt::Display; use alloy::primitives::{Address, U256}; use chrono::{DateTime, Duration, Utc}; -use diesel::dsl::insert_into; +use diesel::dsl::{auto_type, insert_into}; use diesel::sqlite::Sqlite; use diesel::{ExpressionMethods, JoinOnDsl, prelude::*}; use diesel_async::{AsyncConnection, RunQueryDsl}; @@ -24,11 +24,10 @@ use crate::{ evm::{policies::Policy, utils}, }; -#[diesel::auto_type] +#[auto_type] fn grant_join() -> _ { evm_ether_transfer_grant::table.inner_join( - evm_basic_grant::table - .on(evm_ether_transfer_grant::basic_grant_id.eq(evm_basic_grant::id)), + evm_basic_grant::table.on(evm_ether_transfer_grant::basic_grant_id.eq(evm_basic_grant::id)), ) } @@ -197,11 +196,16 @@ impl Policy for EtherTransfer { // Find a grant where: // 1. The basic grant's wallet_id and client_id match the context // 2. Any of the grant's targets match the context's `to` address - let grant: Option<(EvmBasicGrant, EvmEtherTransferGrant)> = grant_join() - .filter(evm_basic_grant::wallet_id.eq(context.wallet_id)) - .filter(evm_basic_grant::client_id.eq(context.client_id)) - .filter(evm_ether_transfer_grant_target::address.eq(&target_bytes)) - .filter(evm_basic_grant::revoked_at.is_null()) + let grant: Option<(EvmBasicGrant, EvmEtherTransferGrant)> = evm_ether_transfer_grant::table + .inner_join(evm_basic_grant::table) + .inner_join(evm_ether_transfer_grant_target::table) + .filter( + evm_basic_grant::wallet_id + .eq(context.wallet_id) + .and(evm_basic_grant::client_id.eq(context.client_id)) + .and(evm_basic_grant::revoked_at.is_null()) + .and(evm_ether_transfer_grant_target::address.eq(&target_bytes)), + ) .select(( EvmBasicGrant::as_select(), EvmEtherTransferGrant::as_select(), @@ -270,7 +274,10 @@ impl Policy for EtherTransfer { ) -> QueryResult>> { let grants: Vec<(EvmBasicGrant, EvmEtherTransferGrant)> = grant_join() .filter(evm_basic_grant::revoked_at.is_null()) - .select((EvmBasicGrant::as_select(), EvmEtherTransferGrant::as_select())) + .select(( + EvmBasicGrant::as_select(), + EvmEtherTransferGrant::as_select(), + )) .load(conn) .await?; @@ -295,7 +302,10 @@ impl Policy for EtherTransfer { let mut targets_by_grant: HashMap> = HashMap::new(); for target in all_targets { - targets_by_grant.entry(target.grant_id).or_default().push(target); + targets_by_grant + .entry(target.grant_id) + .or_default() + .push(target); } let limits_by_id: HashMap = @@ -326,8 +336,9 @@ impl Policy for EtherTransfer { settings: Settings { target: targets, limit: VolumeRateLimit { - max_volume: utils::try_bytes_to_u256(&limit.max_volume) - .map_err(|e| diesel::result::Error::DeserializationError(Box::new(e)))?, + max_volume: utils::try_bytes_to_u256(&limit.max_volume).map_err( + |e| diesel::result::Error::DeserializationError(Box::new(e)), + )?, window: Duration::seconds(limit.window_secs as i64), }, }, @@ -336,3 +347,6 @@ impl Policy for EtherTransfer { .collect() } } + +#[cfg(test)] +mod tests; diff --git a/server/crates/arbiter-server/src/evm/policies/ether_transfer/tests.rs b/server/crates/arbiter-server/src/evm/policies/ether_transfer/tests.rs new file mode 100644 index 0000000..ebbe185 --- /dev/null +++ b/server/crates/arbiter-server/src/evm/policies/ether_transfer/tests.rs @@ -0,0 +1,387 @@ +use alloy::primitives::{Address, Bytes, U256, address}; +use chrono::{Duration, Utc}; +use diesel::{ExpressionMethods, SelectableHelper, insert_into}; +use diesel_async::RunQueryDsl; + +use crate::db::{ + self, DatabaseConnection, + models::{EvmBasicGrant, NewEvmBasicGrant, NewEvmTransactionLog, SqliteTimestamp}, + schema::{evm_basic_grant, evm_transaction_log}, +}; +use crate::evm::{ + policies::{ + EvalContext, EvalViolation, Grant, Policy, SharedGrantSettings, VolumeRateLimit, + }, + utils, +}; + +use super::{EtherTransfer, Settings}; + +const WALLET_ID: i32 = 1; +const CLIENT_ID: i32 = 2; +const CHAIN_ID: u64 = 1; + +const ALLOWED: Address = address!("1111111111111111111111111111111111111111"); +const OTHER: Address = address!("2222222222222222222222222222222222222222"); + +fn ctx(to: Address, value: U256) -> EvalContext { + EvalContext { + wallet_id: WALLET_ID, + client_id: CLIENT_ID, + chain: CHAIN_ID, + to, + value, + calldata: Bytes::new(), + max_fee_per_gas: 0, + max_priority_fee_per_gas: 0, + } +} + +async fn insert_basic(conn: &mut DatabaseConnection, revoked: bool) -> EvmBasicGrant { + insert_into(evm_basic_grant::table) + .values(NewEvmBasicGrant { + wallet_id: WALLET_ID, + client_id: CLIENT_ID, + chain_id: CHAIN_ID as i32, + valid_from: None, + valid_until: None, + max_gas_fee_per_gas: None, + max_priority_fee_per_gas: None, + rate_limit_count: None, + rate_limit_window_secs: None, + revoked_at: revoked.then(|| SqliteTimestamp(Utc::now())), + }) + .returning(EvmBasicGrant::as_select()) + .get_result(conn) + .await + .unwrap() +} + +fn make_settings(targets: Vec
, max_volume: u64) -> Settings { + Settings { + target: targets, + limit: VolumeRateLimit { + max_volume: U256::from(max_volume), + window: Duration::hours(1), + }, + } +} + +fn shared() -> SharedGrantSettings { + SharedGrantSettings { + wallet_id: WALLET_ID, + chain: CHAIN_ID, + valid_from: None, + valid_until: None, + max_gas_fee_per_gas: None, + max_priority_fee_per_gas: None, + rate_limit: None, + } +} + +// ── analyze ───────────────────────────────────────────────────────────── + +#[test] +fn analyze_matches_empty_calldata() { + let m = EtherTransfer::analyze(&ctx(ALLOWED, U256::from(1_000u64))).unwrap(); + assert_eq!(m.to, ALLOWED); + assert_eq!(m.value, U256::from(1_000u64)); +} + +#[test] +fn analyze_rejects_nonempty_calldata() { + let context = EvalContext { + calldata: Bytes::from(vec![0xde, 0xad, 0xbe, 0xef]), + ..ctx(ALLOWED, U256::from(1u64)) + }; + assert!(EtherTransfer::analyze(&context).is_none()); +} + +// ── evaluate ──────────────────────────────────────────────────────────── + +#[tokio::test] +async fn evaluate_passes_for_allowed_target() { + let db = db::create_test_pool().await; + let mut conn = db.get().await.unwrap(); + + let grant = Grant { + id: 999, + shared_grant_id: 999, + shared: shared(), + settings: make_settings(vec![ALLOWED], 1_000_000), + }; + let context = ctx(ALLOWED, U256::from(100u64)); + let m = EtherTransfer::analyze(&context).unwrap(); + let v = EtherTransfer::evaluate(&context, &m, &grant, &mut *conn) + .await + .unwrap(); + assert!(v.is_empty()); +} + +#[tokio::test] +async fn evaluate_rejects_disallowed_target() { + let db = db::create_test_pool().await; + let mut conn = db.get().await.unwrap(); + + let grant = Grant { + id: 999, + shared_grant_id: 999, + shared: shared(), + settings: make_settings(vec![ALLOWED], 1_000_000), + }; + let context = ctx(OTHER, U256::from(100u64)); + let m = EtherTransfer::analyze(&context).unwrap(); + let v = EtherTransfer::evaluate(&context, &m, &grant, &mut *conn) + .await + .unwrap(); + assert!( + v.iter() + .any(|e| matches!(e, EvalViolation::InvalidTarget { .. })) + ); +} + +#[tokio::test] +async fn evaluate_passes_when_volume_within_limit() { + let db = db::create_test_pool().await; + let mut conn = db.get().await.unwrap(); + + let basic = insert_basic(&mut conn, false).await; + let settings = make_settings(vec![ALLOWED], 1_000); + let grant_id = EtherTransfer::create_grant(&basic, &settings, &mut *conn) + .await + .unwrap(); + + insert_into(evm_transaction_log::table) + .values(NewEvmTransactionLog { + grant_id, + client_id: CLIENT_ID, + wallet_id: WALLET_ID, + chain_id: CHAIN_ID as i32, + eth_value: utils::u256_to_bytes(U256::from(500u64)).to_vec(), + signed_at: SqliteTimestamp(Utc::now()), + }) + .execute(&mut *conn) + .await + .unwrap(); + + let grant = Grant { + id: grant_id, + shared_grant_id: basic.id, + shared: shared(), + settings, + }; + let context = ctx(ALLOWED, U256::from(100u64)); + let m = EtherTransfer::analyze(&context).unwrap(); + let v = EtherTransfer::evaluate(&context, &m, &grant, &mut *conn) + .await + .unwrap(); + assert!( + !v.iter() + .any(|e| matches!(e, EvalViolation::VolumetricLimitExceeded)) + ); +} + +#[tokio::test] +async fn evaluate_rejects_volume_over_limit() { + let db = db::create_test_pool().await; + let mut conn = db.get().await.unwrap(); + + let basic = insert_basic(&mut conn, false).await; + let settings = make_settings(vec![ALLOWED], 1_000); + let grant_id = EtherTransfer::create_grant(&basic, &settings, &mut *conn) + .await + .unwrap(); + + insert_into(evm_transaction_log::table) + .values(NewEvmTransactionLog { + grant_id, + client_id: CLIENT_ID, + wallet_id: WALLET_ID, + chain_id: CHAIN_ID as i32, + eth_value: utils::u256_to_bytes(U256::from(1_001u64)).to_vec(), + signed_at: SqliteTimestamp(Utc::now()), + }) + .execute(&mut *conn) + .await + .unwrap(); + + let grant = Grant { + id: grant_id, + shared_grant_id: basic.id, + shared: shared(), + settings, + }; + let context = ctx(ALLOWED, U256::from(100u64)); + let m = EtherTransfer::analyze(&context).unwrap(); + let v = EtherTransfer::evaluate(&context, &m, &grant, &mut *conn) + .await + .unwrap(); + assert!( + v.iter() + .any(|e| matches!(e, EvalViolation::VolumetricLimitExceeded)) + ); +} + +#[tokio::test] +async fn evaluate_passes_at_exactly_volume_limit() { + let db = db::create_test_pool().await; + let mut conn = db.get().await.unwrap(); + + let basic = insert_basic(&mut conn, false).await; + let settings = make_settings(vec![ALLOWED], 1_000); + let grant_id = EtherTransfer::create_grant(&basic, &settings, &mut *conn) + .await + .unwrap(); + + // Exactly at the limit — the check is `>`, so this should not violate + insert_into(evm_transaction_log::table) + .values(NewEvmTransactionLog { + grant_id, + client_id: CLIENT_ID, + wallet_id: WALLET_ID, + chain_id: CHAIN_ID as i32, + eth_value: utils::u256_to_bytes(U256::from(1_000u64)).to_vec(), + signed_at: SqliteTimestamp(Utc::now()), + }) + .execute(&mut *conn) + .await + .unwrap(); + + let grant = Grant { + id: grant_id, + shared_grant_id: basic.id, + shared: shared(), + settings, + }; + let context = ctx(ALLOWED, U256::from(100u64)); + let m = EtherTransfer::analyze(&context).unwrap(); + let v = EtherTransfer::evaluate(&context, &m, &grant, &mut *conn) + .await + .unwrap(); + assert!( + !v.iter() + .any(|e| matches!(e, EvalViolation::VolumetricLimitExceeded)) + ); +} + +// ── try_find_grant ─────────────────────────────────────────────────────── + +#[tokio::test] +async fn try_find_grant_roundtrip() { + let db = db::create_test_pool().await; + let mut conn = db.get().await.unwrap(); + + let basic = insert_basic(&mut conn, false).await; + let settings = make_settings(vec![ALLOWED], 1_000_000); + EtherTransfer::create_grant(&basic, &settings, &mut *conn) + .await + .unwrap(); + + let found = EtherTransfer::try_find_grant(&ctx(ALLOWED, U256::from(1u64)), &mut *conn) + .await + .unwrap(); + + assert!(found.is_some()); + let g = found.unwrap(); + assert_eq!(g.settings.target, vec![ALLOWED]); + assert_eq!(g.settings.limit.max_volume, U256::from(1_000_000u64)); +} + +#[tokio::test] +async fn try_find_grant_revoked_returns_none() { + let db = db::create_test_pool().await; + let mut conn = db.get().await.unwrap(); + + let basic = insert_basic(&mut conn, true).await; + let settings = make_settings(vec![ALLOWED], 1_000_000); + EtherTransfer::create_grant(&basic, &settings, &mut *conn) + .await + .unwrap(); + + let found = EtherTransfer::try_find_grant(&ctx(ALLOWED, U256::from(1u64)), &mut *conn) + .await + .unwrap(); + assert!(found.is_none()); +} + +#[tokio::test] +async fn try_find_grant_wrong_target_returns_none() { + let db = db::create_test_pool().await; + let mut conn = db.get().await.unwrap(); + + let basic = insert_basic(&mut conn, false).await; + let settings = make_settings(vec![ALLOWED], 1_000_000); + EtherTransfer::create_grant(&basic, &settings, &mut *conn) + .await + .unwrap(); + + let found = EtherTransfer::try_find_grant(&ctx(OTHER, U256::from(1u64)), &mut *conn) + .await + .unwrap(); + assert!(found.is_none()); +} + +// ── find_all_grants ────────────────────────────────────────────────────── + +#[tokio::test] +async fn find_all_grants_empty_db() { + let db = db::create_test_pool().await; + let mut conn = db.get().await.unwrap(); + let all = EtherTransfer::find_all_grants(&mut *conn).await.unwrap(); + assert!(all.is_empty()); +} + +#[tokio::test] +async fn find_all_grants_excludes_revoked() { + let db = db::create_test_pool().await; + let mut conn = db.get().await.unwrap(); + + let settings = make_settings(vec![ALLOWED], 1_000_000); + let active = insert_basic(&mut conn, false).await; + EtherTransfer::create_grant(&active, &settings, &mut *conn) + .await + .unwrap(); + let revoked = insert_basic(&mut conn, true).await; + EtherTransfer::create_grant(&revoked, &settings, &mut *conn) + .await + .unwrap(); + + let all = EtherTransfer::find_all_grants(&mut *conn).await.unwrap(); + assert_eq!(all.len(), 1); + assert_eq!(all[0].settings.target, vec![ALLOWED]); +} + +#[tokio::test] +async fn find_all_grants_multiple_targets() { + let db = db::create_test_pool().await; + let mut conn = db.get().await.unwrap(); + + let basic = insert_basic(&mut conn, false).await; + let settings = make_settings(vec![ALLOWED, OTHER], 1_000_000); + EtherTransfer::create_grant(&basic, &settings, &mut *conn) + .await + .unwrap(); + + let all = EtherTransfer::find_all_grants(&mut *conn).await.unwrap(); + assert_eq!(all.len(), 1); + assert_eq!(all[0].settings.target.len(), 2); + assert_eq!(all[0].settings.limit.max_volume, U256::from(1_000_000u64)); +} + +#[tokio::test] +async fn find_all_grants_multiple_grants() { + let db = db::create_test_pool().await; + let mut conn = db.get().await.unwrap(); + + let basic1 = insert_basic(&mut conn, false).await; + EtherTransfer::create_grant(&basic1, &make_settings(vec![ALLOWED], 500), &mut *conn) + .await + .unwrap(); + let basic2 = insert_basic(&mut conn, false).await; + EtherTransfer::create_grant(&basic2, &make_settings(vec![OTHER], 1_000), &mut *conn) + .await + .unwrap(); + + let all = EtherTransfer::find_all_grants(&mut *conn).await.unwrap(); + assert_eq!(all.len(), 2); +} diff --git a/server/crates/arbiter-server/src/evm/policies/token_transfers.rs b/server/crates/arbiter-server/src/evm/policies/token_transfers/mod.rs similarity index 99% rename from server/crates/arbiter-server/src/evm/policies/token_transfers.rs rename to server/crates/arbiter-server/src/evm/policies/token_transfers/mod.rs index 9991553..53d8679 100644 --- a/server/crates/arbiter-server/src/evm/policies/token_transfers.rs +++ b/server/crates/arbiter-server/src/evm/policies/token_transfers/mod.rs @@ -6,7 +6,7 @@ use alloy::{ }; use arbiter_tokens_registry::evm::nonfungible::{self, TokenInfo}; use chrono::{DateTime, Duration, Utc}; -use diesel::dsl::insert_into; +use diesel::dsl::{auto_type, insert_into}; use diesel::sqlite::Sqlite; use diesel::{ExpressionMethods, prelude::*}; use diesel_async::{AsyncConnection, RunQueryDsl}; @@ -29,7 +29,7 @@ use crate::evm::{ use super::{DatabaseID, EvalContext, EvalViolation}; -#[diesel::auto_type] +#[auto_type] fn grant_join() -> _ { evm_token_transfer_grant::table.inner_join( evm_basic_grant::table.on(evm_token_transfer_grant::basic_grant_id.eq(evm_basic_grant::id)), @@ -380,3 +380,6 @@ impl Policy for TokenTransfer { .collect() } } + +#[cfg(test)] +mod tests; diff --git a/server/crates/arbiter-server/src/evm/policies/token_transfers/tests.rs b/server/crates/arbiter-server/src/evm/policies/token_transfers/tests.rs new file mode 100644 index 0000000..290656b --- /dev/null +++ b/server/crates/arbiter-server/src/evm/policies/token_transfers/tests.rs @@ -0,0 +1,397 @@ +use alloy::primitives::{Address, Bytes, U256, address}; +use alloy::sol_types::SolCall; +use chrono::{Duration, Utc}; +use diesel::{ExpressionMethods, SelectableHelper, insert_into}; +use diesel_async::RunQueryDsl; + +use crate::db::{ + self, DatabaseConnection, + models::{EvmBasicGrant, NewEvmBasicGrant, SqliteTimestamp}, + schema::evm_basic_grant, +}; +use crate::evm::{ + abi::IERC20::transferCall, + policies::{EvalContext, EvalViolation, Grant, Policy, SharedGrantSettings, VolumeRateLimit}, + utils, +}; + +use super::{Settings, TokenTransfer}; + +// DAI on Ethereum mainnet — present in the static token registry +const CHAIN_ID: u64 = 1; +const DAI: Address = address!("6B175474E89094C44Da98b954EedeAC495271d0F"); + +const WALLET_ID: i32 = 1; +const CLIENT_ID: i32 = 2; + +const RECIPIENT: Address = address!("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); +const OTHER: Address = address!("bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"); +const UNKNOWN_TOKEN: Address = address!("cccccccccccccccccccccccccccccccccccccccc"); + +/// Encode `transfer(to, value)` raw params (no 4-byte selector). +/// `abi_decode_raw_validate` expects exactly this format. +fn transfer_calldata(to: Address, value: U256) -> Bytes { + let mut raw = Vec::new(); + transferCall { to, value }.abi_encode_raw(&mut raw); + Bytes::from(raw) +} + +fn ctx(to: Address, calldata: Bytes) -> EvalContext { + EvalContext { + wallet_id: WALLET_ID, + client_id: CLIENT_ID, + chain: CHAIN_ID, + to, + value: U256::ZERO, + calldata, + max_fee_per_gas: 0, + max_priority_fee_per_gas: 0, + } +} + +async fn insert_basic(conn: &mut DatabaseConnection, revoked: bool) -> EvmBasicGrant { + insert_into(evm_basic_grant::table) + .values(NewEvmBasicGrant { + wallet_id: WALLET_ID, + client_id: CLIENT_ID, + chain_id: CHAIN_ID as i32, + valid_from: None, + valid_until: None, + max_gas_fee_per_gas: None, + max_priority_fee_per_gas: None, + rate_limit_count: None, + rate_limit_window_secs: None, + revoked_at: revoked.then(|| SqliteTimestamp(Utc::now())), + }) + .returning(EvmBasicGrant::as_select()) + .get_result(conn) + .await + .unwrap() +} + +fn make_settings(target: Option
, max_volume: Option) -> Settings { + Settings { + token_contract: DAI, + target, + volume_limits: max_volume + .map(|v| { + vec![VolumeRateLimit { + max_volume: U256::from(v), + window: Duration::hours(1), + }] + }) + .unwrap_or_default(), + } +} + +fn shared() -> SharedGrantSettings { + SharedGrantSettings { + wallet_id: WALLET_ID, + chain: CHAIN_ID, + valid_from: None, + valid_until: None, + max_gas_fee_per_gas: None, + max_priority_fee_per_gas: None, + rate_limit: None, + } +} + +// ── analyze ───────────────────────────────────────────────────────────── + +#[test] +fn analyze_known_token_valid_calldata() { + let calldata = transfer_calldata(RECIPIENT, U256::from(100u64)); + let m = TokenTransfer::analyze(&ctx(DAI, calldata)).unwrap(); + assert_eq!(m.to, RECIPIENT); + assert_eq!(m.value, U256::from(100u64)); +} + +#[test] +fn analyze_unknown_token_returns_none() { + let calldata = transfer_calldata(RECIPIENT, U256::from(100u64)); + assert!(TokenTransfer::analyze(&ctx(UNKNOWN_TOKEN, calldata)).is_none()); +} + +#[test] +fn analyze_invalid_calldata_returns_none() { + let calldata = Bytes::from(vec![0xde, 0xad, 0xbe, 0xef]); + assert!(TokenTransfer::analyze(&ctx(DAI, calldata)).is_none()); +} + +#[test] +fn analyze_empty_calldata_returns_none() { + assert!(TokenTransfer::analyze(&ctx(DAI, Bytes::new())).is_none()); +} + +// ── evaluate ──────────────────────────────────────────────────────────── + +#[tokio::test] +async fn evaluate_rejects_nonzero_eth_value() { + let db = db::create_test_pool().await; + let mut conn = db.get().await.unwrap(); + + let grant = Grant { + id: 999, + shared_grant_id: 999, + shared: shared(), + settings: make_settings(None, None), + }; + let calldata = transfer_calldata(RECIPIENT, U256::from(100u64)); + let mut context = ctx(DAI, calldata); + context.value = U256::from(1u64); // ETH attached to an ERC-20 call + + let m = TokenTransfer::analyze(&EvalContext { value: U256::ZERO, ..context.clone() }) + .unwrap(); + let v = TokenTransfer::evaluate(&context, &m, &grant, &mut *conn).await.unwrap(); + assert!(v.iter().any(|e| matches!(e, EvalViolation::InvalidTransactionType))); +} + +#[tokio::test] +async fn evaluate_passes_any_recipient_when_no_restriction() { + let db = db::create_test_pool().await; + let mut conn = db.get().await.unwrap(); + + let grant = Grant { + id: 999, + shared_grant_id: 999, + shared: shared(), + settings: make_settings(None, None), + }; + let calldata = transfer_calldata(RECIPIENT, U256::from(100u64)); + let context = ctx(DAI, calldata); + let m = TokenTransfer::analyze(&context).unwrap(); + let v = TokenTransfer::evaluate(&context, &m, &grant, &mut *conn).await.unwrap(); + assert!(v.is_empty()); +} + +#[tokio::test] +async fn evaluate_passes_matching_restricted_recipient() { + let db = db::create_test_pool().await; + let mut conn = db.get().await.unwrap(); + + let grant = Grant { + id: 999, + shared_grant_id: 999, + shared: shared(), + settings: make_settings(Some(RECIPIENT), None), + }; + let calldata = transfer_calldata(RECIPIENT, U256::from(100u64)); + let context = ctx(DAI, calldata); + let m = TokenTransfer::analyze(&context).unwrap(); + let v = TokenTransfer::evaluate(&context, &m, &grant, &mut *conn).await.unwrap(); + assert!(v.is_empty()); +} + +#[tokio::test] +async fn evaluate_rejects_wrong_restricted_recipient() { + let db = db::create_test_pool().await; + let mut conn = db.get().await.unwrap(); + + let grant = Grant { + id: 999, + shared_grant_id: 999, + shared: shared(), + settings: make_settings(Some(RECIPIENT), None), + }; + let calldata = transfer_calldata(OTHER, U256::from(100u64)); + let context = ctx(DAI, calldata); + let m = TokenTransfer::analyze(&context).unwrap(); + let v = TokenTransfer::evaluate(&context, &m, &grant, &mut *conn).await.unwrap(); + assert!(v.iter().any(|e| matches!(e, EvalViolation::InvalidTarget { .. }))); +} + +#[tokio::test] +async fn evaluate_passes_volume_within_limit() { + let db = db::create_test_pool().await; + let mut conn = db.get().await.unwrap(); + + let basic = insert_basic(&mut conn, false).await; + let settings = make_settings(None, Some(1_000)); + let grant_id = TokenTransfer::create_grant(&basic, &settings, &mut *conn).await.unwrap(); + + // Record a past transfer of 500 (within 1000 limit) + use crate::db::{models::NewEvmTokenTransferLog, schema::evm_token_transfer_log}; + insert_into(evm_token_transfer_log::table) + .values(NewEvmTokenTransferLog { + grant_id, + log_id: 0, + chain_id: CHAIN_ID as i32, + token_contract: DAI.to_vec(), + recipient_address: RECIPIENT.to_vec(), + value: utils::u256_to_bytes(U256::from(500u64)).to_vec(), + }) + .execute(&mut *conn) + .await + .unwrap(); + + let grant = Grant { id: grant_id, shared_grant_id: basic.id, shared: shared(), settings }; + let calldata = transfer_calldata(RECIPIENT, U256::from(100u64)); + let context = ctx(DAI, calldata); + let m = TokenTransfer::analyze(&context).unwrap(); + let v = TokenTransfer::evaluate(&context, &m, &grant, &mut *conn).await.unwrap(); + assert!(!v.iter().any(|e| matches!(e, EvalViolation::VolumetricLimitExceeded))); +} + +#[tokio::test] +async fn evaluate_rejects_volume_over_limit() { + let db = db::create_test_pool().await; + let mut conn = db.get().await.unwrap(); + + let basic = insert_basic(&mut conn, false).await; + let settings = make_settings(None, Some(1_000)); + let grant_id = TokenTransfer::create_grant(&basic, &settings, &mut *conn).await.unwrap(); + + use crate::db::{models::NewEvmTokenTransferLog, schema::evm_token_transfer_log}; + insert_into(evm_token_transfer_log::table) + .values(NewEvmTokenTransferLog { + grant_id, + log_id: 0, + chain_id: CHAIN_ID as i32, + token_contract: DAI.to_vec(), + recipient_address: RECIPIENT.to_vec(), + value: utils::u256_to_bytes(U256::from(1_001u64)).to_vec(), + }) + .execute(&mut *conn) + .await + .unwrap(); + + let grant = Grant { id: grant_id, shared_grant_id: basic.id, shared: shared(), settings }; + let calldata = transfer_calldata(RECIPIENT, U256::from(100u64)); + let context = ctx(DAI, calldata); + let m = TokenTransfer::analyze(&context).unwrap(); + let v = TokenTransfer::evaluate(&context, &m, &grant, &mut *conn).await.unwrap(); + assert!(v.iter().any(|e| matches!(e, EvalViolation::VolumetricLimitExceeded))); +} + +#[tokio::test] +async fn evaluate_no_volume_limits_always_passes() { + let db = db::create_test_pool().await; + let mut conn = db.get().await.unwrap(); + + let grant = Grant { + id: 999, + shared_grant_id: 999, + shared: shared(), + settings: make_settings(None, None), // no volume limits + }; + let calldata = transfer_calldata(RECIPIENT, U256::from(u64::MAX)); + let context = ctx(DAI, calldata); + let m = TokenTransfer::analyze(&context).unwrap(); + let v = TokenTransfer::evaluate(&context, &m, &grant, &mut *conn).await.unwrap(); + assert!(!v.iter().any(|e| matches!(e, EvalViolation::VolumetricLimitExceeded))); +} + +// ── try_find_grant ─────────────────────────────────────────────────────── + +#[tokio::test] +async fn try_find_grant_roundtrip() { + let db = db::create_test_pool().await; + let mut conn = db.get().await.unwrap(); + + let basic = insert_basic(&mut conn, false).await; + let settings = make_settings(Some(RECIPIENT), Some(5_000)); + TokenTransfer::create_grant(&basic, &settings, &mut *conn).await.unwrap(); + + let calldata = transfer_calldata(RECIPIENT, U256::from(100u64)); + let found = TokenTransfer::try_find_grant(&ctx(DAI, calldata), &mut *conn) + .await + .unwrap(); + + assert!(found.is_some()); + let g = found.unwrap(); + assert_eq!(g.settings.token_contract, DAI); + assert_eq!(g.settings.target, Some(RECIPIENT)); + assert_eq!(g.settings.volume_limits.len(), 1); + assert_eq!(g.settings.volume_limits[0].max_volume, U256::from(5_000u64)); +} + +#[tokio::test] +async fn try_find_grant_revoked_returns_none() { + let db = db::create_test_pool().await; + let mut conn = db.get().await.unwrap(); + + let basic = insert_basic(&mut conn, true).await; + let settings = make_settings(None, None); + TokenTransfer::create_grant(&basic, &settings, &mut *conn).await.unwrap(); + + let calldata = transfer_calldata(RECIPIENT, U256::from(1u64)); + let found = TokenTransfer::try_find_grant(&ctx(DAI, calldata), &mut *conn) + .await + .unwrap(); + assert!(found.is_none()); +} + +#[tokio::test] +async fn try_find_grant_unknown_token_returns_none() { + let db = db::create_test_pool().await; + let mut conn = db.get().await.unwrap(); + + let basic = insert_basic(&mut conn, false).await; + let settings = make_settings(None, None); + TokenTransfer::create_grant(&basic, &settings, &mut *conn).await.unwrap(); + + // Query with a different token contract + let calldata = transfer_calldata(RECIPIENT, U256::from(1u64)); + let found = TokenTransfer::try_find_grant(&ctx(UNKNOWN_TOKEN, calldata), &mut *conn) + .await + .unwrap(); + assert!(found.is_none()); +} + +// ── find_all_grants ────────────────────────────────────────────────────── + +#[tokio::test] +async fn find_all_grants_empty_db() { + let db = db::create_test_pool().await; + let mut conn = db.get().await.unwrap(); + let all = TokenTransfer::find_all_grants(&mut *conn).await.unwrap(); + assert!(all.is_empty()); +} + +#[tokio::test] +async fn find_all_grants_excludes_revoked() { + let db = db::create_test_pool().await; + let mut conn = db.get().await.unwrap(); + + let settings = make_settings(None, Some(1_000)); + let active = insert_basic(&mut conn, false).await; + TokenTransfer::create_grant(&active, &settings, &mut *conn).await.unwrap(); + let revoked = insert_basic(&mut conn, true).await; + TokenTransfer::create_grant(&revoked, &settings, &mut *conn).await.unwrap(); + + let all = TokenTransfer::find_all_grants(&mut *conn).await.unwrap(); + assert_eq!(all.len(), 1); +} + +#[tokio::test] +async fn find_all_grants_loads_volume_limits() { + let db = db::create_test_pool().await; + let mut conn = db.get().await.unwrap(); + + let basic = insert_basic(&mut conn, false).await; + let settings = make_settings(None, Some(9_999)); + TokenTransfer::create_grant(&basic, &settings, &mut *conn).await.unwrap(); + + let all = TokenTransfer::find_all_grants(&mut *conn).await.unwrap(); + assert_eq!(all.len(), 1); + assert_eq!(all[0].settings.volume_limits.len(), 1); + assert_eq!(all[0].settings.volume_limits[0].max_volume, U256::from(9_999u64)); +} + +#[tokio::test] +async fn find_all_grants_multiple_grants_batch_loaded() { + let db = db::create_test_pool().await; + let mut conn = db.get().await.unwrap(); + + let b1 = insert_basic(&mut conn, false).await; + TokenTransfer::create_grant(&b1, &make_settings(None, Some(1_000)), &mut *conn) + .await + .unwrap(); + let b2 = insert_basic(&mut conn, false).await; + TokenTransfer::create_grant(&b2, &make_settings(Some(RECIPIENT), Some(2_000)), &mut *conn) + .await + .unwrap(); + + let all = TokenTransfer::find_all_grants(&mut *conn).await.unwrap(); + assert_eq!(all.len(), 2); +}