26 Commits

Author SHA1 Message Date
CleverWild
41b3fc5d39 fix(lints): remove unstable ones
Some checks failed
ci/woodpecker/pr/server-audit Pipeline was successful
ci/woodpecker/pr/server-vet Pipeline failed
ci/woodpecker/pr/server-lint Pipeline was successful
ci/woodpecker/pr/server-test Pipeline was successful
2026-04-10 01:00:21 +02:00
CleverWild
f6a0c32b9d feat: rustc and clippy linting
Some checks failed
ci/woodpecker/pr/server-audit Pipeline was successful
ci/woodpecker/pr/server-vet Pipeline failed
ci/woodpecker/pr/server-lint Pipeline failed
ci/woodpecker/pr/server-test Pipeline was successful
2026-04-10 00:42:43 +02:00
62dff3f810 Merge pull request 'refactor(hashing): introduce Hashable derive macro and migrate server types' (#82) from hashing-proc-macro into main
Some checks failed
ci/woodpecker/push/server-audit Pipeline was successful
ci/woodpecker/push/server-lint Pipeline was successful
ci/woodpecker/push/server-vet Pipeline failed
ci/woodpecker/push/server-test Pipeline was successful
Reviewed-on: #82
Reviewed-by: Stas <business@jexter.tech>
2026-04-08 00:18:40 +00:00
CleverWild
6e22f368c9 refactor(hashing): introduce Hashable derive macro and migrate server types
Some checks failed
ci/woodpecker/pr/server-audit Pipeline was successful
ci/woodpecker/pr/server-vet Pipeline failed
ci/woodpecker/pr/server-lint Pipeline was successful
ci/woodpecker/pr/server-test Pipeline was successful
2026-04-08 01:32:59 +02:00
f3cf6a9438 Merge pull request 'Post-quantum crypto and better useragent security' (#80) from push-xrxykvkuxpsv into main
Some checks failed
ci/woodpecker/push/server-audit Pipeline was successful
ci/woodpecker/push/server-lint Pipeline failed
ci/woodpecker/push/server-vet Pipeline failed
ci/woodpecker/push/server-test Pipeline was successful
Reviewed-on: #80
2026-04-07 19:26:54 +00:00
hdbg
a9f9fc2a9d housekeeping(server): fixed clippy warns
Some checks failed
ci/woodpecker/pr/server-audit Pipeline was successful
ci/woodpecker/pr/server-lint Pipeline failed
ci/woodpecker/pr/server-vet Pipeline failed
ci/woodpecker/pr/server-test Pipeline was successful
2026-04-07 16:28:47 +02:00
hdbg
d22ab49e3d refactor(server): moved shared module crypto into arbiter-crypto 2026-04-07 16:24:51 +02:00
hdbg
a845181ef6 docs: ml-dsa scheme everywhere
Some checks failed
ci/woodpecker/pr/server-audit Pipeline was successful
ci/woodpecker/pr/server-vet Pipeline failed
ci/woodpecker/pr/server-lint Pipeline failed
ci/woodpecker/pr/server-test Pipeline was successful
2026-04-07 15:02:32 +02:00
hdbg
0d424f3afc refactor(server): migrated auth to ml-dsa 2026-04-07 14:55:31 +02:00
hdbg
1497884ce6 fix(server::bootsrapper): token compare is now constant-time
Some checks failed
ci/woodpecker/push/server-audit Pipeline was successful
ci/woodpecker/push/server-vet Pipeline failed
ci/woodpecker/push/server-lint Pipeline failed
ci/woodpecker/push/server-test Pipeline was successful
2026-04-06 18:33:47 +02:00
hdbg
b3464cf8a6 tests(server::client::auth): integrity envelope insertion for valid paths
Some checks failed
ci/woodpecker/pr/server-audit Pipeline was successful
ci/woodpecker/pr/server-vet Pipeline failed
ci/woodpecker/pr/server-lint Pipeline failed
ci/woodpecker/pr/server-test Pipeline was successful
ci/woodpecker/push/server-audit Pipeline was successful
ci/woodpecker/push/server-vet Pipeline failed
ci/woodpecker/push/server-lint Pipeline failed
ci/woodpecker/push/server-test Pipeline was successful
2026-04-06 18:24:13 +02:00
hdbg
46d1318b6f feat(server): add integrity verification for client keys 2026-04-06 18:13:11 +02:00
9c80d51d45 Merge pull request 'fix(server): replaced postcard-based integrity fingerprint with custom trait providing order-independent hashing' (#77) from push-opwuyuwxknyo into main
Some checks failed
ci/woodpecker/push/server-audit Pipeline was successful
ci/woodpecker/push/server-vet Pipeline failed
ci/woodpecker/push/server-lint Pipeline failed
ci/woodpecker/push/server-test Pipeline was successful
Reviewed-on: #77
2026-04-06 15:42:47 +00:00
hdbg
33456a644d tests(server): property-based testing for ordering independency for hash
Some checks failed
ci/woodpecker/pr/server-audit Pipeline was successful
ci/woodpecker/pr/server-lint Pipeline failed
ci/woodpecker/pr/server-vet Pipeline failed
ci/woodpecker/pr/server-test Pipeline was successful
2026-04-06 17:40:41 +02:00
hdbg
5bc0c42cc7 fix(server): replaced postcard-based integrity fingerprint with custom trait providing order-independent hashing 2026-04-06 16:25:32 +02:00
hdbg
f6b62ab884 fix(server): added chain_id check and covered check_shared_constraints with unit tests
Some checks failed
ci/woodpecker/pr/server-audit Pipeline was successful
ci/woodpecker/pr/server-lint Pipeline failed
ci/woodpecker/pr/server-vet Pipeline failed
ci/woodpecker/pr/server-test Pipeline was successful
ci/woodpecker/push/server-audit Pipeline was successful
ci/woodpecker/push/server-lint Pipeline failed
ci/woodpecker/push/server-vet Pipeline failed
ci/woodpecker/push/server-test Pipeline was successful
2026-04-06 12:57:18 +02:00
hdbg
2dd5a3f32f tests(server): initial cargo-mutants
Some checks failed
ci/woodpecker/push/server-audit Pipeline was successful
ci/woodpecker/push/server-vet Pipeline failed
ci/woodpecker/push/server-lint Pipeline failed
ci/woodpecker/push/server-test Pipeline was successful
2026-04-06 12:03:56 +02:00
hdbg
1aca9d4007 fix(server): simplify hash function for debug profile 2026-04-05 22:50:28 +02:00
5ee1b49c43 Merge pull request 'feat(server): integrity envelope engine for EVM grants with HMAC verification' (#51) from integrity-envelope into main
Some checks failed
ci/woodpecker/push/server-audit Pipeline was successful
ci/woodpecker/push/server-lint Pipeline failed
ci/woodpecker/push/server-vet Pipeline failed
ci/woodpecker/push/server-test Pipeline was successful
Reviewed-on: #51
2026-04-05 16:26:51 +00:00
hdbg
00745bb381 tests(server): fixed for new integrity checks
Some checks failed
ci/woodpecker/pr/server-audit Pipeline was successful
ci/woodpecker/pr/server-vet Pipeline failed
ci/woodpecker/pr/server-lint Pipeline failed
ci/woodpecker/pr/server-test Pipeline was successful
2026-04-05 14:49:02 +02:00
hdbg
b122aa464c refactor(server): rework envelopes and integrity check
Some checks failed
ci/woodpecker/pr/server-audit Pipeline was successful
ci/woodpecker/pr/server-vet Pipeline failed
ci/woodpecker/pr/server-lint Pipeline failed
ci/woodpecker/pr/server-test Pipeline failed
2026-04-05 14:17:00 +02:00
hdbg
9fab945a00 fix(server): remove stale mentions of miette
Some checks failed
ci/woodpecker/pr/server-audit Pipeline was successful
ci/woodpecker/pr/server-lint Pipeline failed
ci/woodpecker/pr/server-test Pipeline failed
ci/woodpecker/pr/server-vet Pipeline failed
2026-04-05 10:45:24 +02:00
CleverWild
aeed664e9a chore: inline integrity proto types
Some checks failed
ci/woodpecker/pr/server-lint Pipeline failed
ci/woodpecker/pr/server-audit Pipeline was successful
ci/woodpecker/pr/server-test Pipeline failed
ci/woodpecker/pr/server-vet Pipeline failed
2026-04-05 10:44:21 +02:00
CleverWild
4057c1fc12 feat(server): integrity envelope engine for EVM grants with HMAC verification 2026-04-05 10:44:21 +02:00
hdbg
f5eb51978d docs: add recovery operators and multi-operator details 2026-04-05 08:27:24 +00:00
hdbg
d997e0f843 docs: add multi-operator governance section 2026-04-05 08:27:24 +00:00
91 changed files with 4157 additions and 1912 deletions

View File

@@ -100,6 +100,27 @@ diesel migration generate <name> --migration-dir crates/arbiter-server/migration
diesel migration run --migration-dir crates/arbiter-server/migrations diesel migration run --migration-dir crates/arbiter-server/migrations
``` ```
### Code Conventions
**`#[must_use]` Attribute:**
Apply the `#[must_use]` attribute to return types of functions where the return value is critical and should not be accidentally ignored. This is commonly used for:
- Methods that return `bool` indicating success/failure or validation state
- Any function where ignoring the return value indicates a logic error
Do not apply `#[must_use]` redundantly to items (types or functions) that are already annotated with `#[must_use]`.
Example:
```rust
#[must_use]
pub fn verify(&self, nonce: i32, context: &[u8], signature: &Signature) -> bool {
// verification logic
}
```
This forces callers to either use the return value or explicitly ignore it with `let _ = ...;`, preventing silent failures.
## User Agent (Flutter + Rinf at `useragent/`) ## User Agent (Flutter + Rinf at `useragent/`)
The Flutter app uses [Rinf](https://rinf.cunarist.org) to call Rust code. The Rust logic lives in `useragent/native/hub/` as a separate crate that uses `arbiter-useragent` for the gRPC client. The Flutter app uses [Rinf](https://rinf.cunarist.org) to call Rust code. The Rust logic lives in `useragent/native/hub/` as a separate crate that uses `arbiter-useragent` for the gRPC client.

View File

@@ -11,6 +11,7 @@ Arbiter distinguishes two kinds of peers:
- **User Agent** — A client application used by the owner to manage the vault (create wallets, approve SDK clients, configure policies). - **User Agent** — A client application used by the owner to manage the vault (create wallets, approve SDK clients, configure policies).
- **SDK Client** — A consumer of signing capabilities, typically an automation tool. In the future, this could include a browser-based wallet. - **SDK Client** — A consumer of signing capabilities, typically an automation tool. In the future, this could include a browser-based wallet.
- **Recovery Operator** — A dormant recovery participant with narrowly scoped authority used only for custody recovery and operator replacement.
--- ---
@@ -42,7 +43,149 @@ There is no bootstrap mechanism for SDK clients. They must be explicitly approve
--- ---
## 3. Server Identity ## 3. Multi-Operator Governance
When more than one User Agent is registered, the vault is treated as having multiple operators. In that mode, sensitive actions are governed by voting rather than by a single operator decision.
### 3.1 Voting Rules
Voting is based on the total number of registered operators:
- **1 operator:** no vote is needed; the single operator decides directly.
- **2 operators:** full consensus is required; both operators must approve.
- **3 or more operators:** quorum is `floor(N / 2) + 1`.
For a decision to count, the operator's approval or rejection must be signed by that operator's associated key. Unsigned votes, or votes that fail signature verification, are ignored.
Examples:
- **3 operators:** 2 approvals required
- **4 operators:** 3 approvals required
### 3.2 Actions Requiring a Vote
In multi-operator mode, a successful vote is required for:
- approving new SDK clients
- granting an SDK client visibility to a wallet
- approving a one-off transaction
- approving creation of a persistent grant
- approving operator replacement
- approving server updates
- updating Shamir secret-sharing parameters
### 3.3 Special Rule for Key Rotation
Key rotation always requires full quorum, regardless of the normal voting threshold.
This is stricter than ordinary governance actions because rotating the root key requires every operator to participate in coordinated share refresh/update steps. The root key itself is not redistributed directly, but each operator's share material must be changed consistently.
### 3.4 Root Key Custody
When the vault has multiple operators, the vault root key is protected using Shamir secret sharing.
The vault root key is encrypted in a way that requires reconstruction from user-held shares rather than from a single shared password.
For ordinary operators, the Shamir threshold matches the ordinary governance quorum. For example:
- **2 operators:** `2-of-2`
- **3 operators:** `2-of-3`
- **4 operators:** `3-of-4`
In practice, the Shamir share set also includes Recovery Operator shares. This means the effective Shamir parameters are computed over the combined share pool while keeping the same threshold. For example:
- **3 ordinary operators + 2 recovery shares:** `2-of-5`
This ensures that the normal custody threshold follows the ordinary operator quorum, while still allowing dormant recovery shares to exist for break-glass recovery flows.
### 3.5 Recovery Operators
Recovery Operators are a separate peer type from ordinary vault operators.
Their role is intentionally narrow. They can only:
- participate in unsealing the vault
- vote for operator replacement
Recovery Operators do not participate in routine governance such as approving SDK clients, granting wallet visibility, approving transactions, creating grants, approving server updates, or changing Shamir parameters.
### 3.6 Sleeping and Waking Recovery Operators
By default, Recovery Operators are **sleeping** and do not participate in any active flow.
Any ordinary operator may request that Recovery Operators **wake up**.
Any ordinary operator may also cancel a pending wake-up request.
This creates a dispute window before recovery powers become active. The default wake-up delay is **14 days**.
Recovery Operators are therefore part of the break-glass recovery path rather than the normal operating quorum.
The high-level recovery flow is:
```mermaid
sequenceDiagram
autonumber
actor Op as Ordinary Operator
participant Server
actor Other as Other Operator
actor Rec as Recovery Operator
Op->>Server: Request recovery wake-up
Server-->>Op: Wake-up pending
Note over Server: Default dispute window: 14 days
alt Wake-up cancelled during dispute window
Other->>Server: Cancel wake-up
Server-->>Op: Recovery cancelled
Server-->>Rec: Stay sleeping
else No cancellation for 14 days
Server-->>Rec: Wake up
Rec->>Server: Join recovery flow
critical Recovery authority
Rec->>Server: Participate in unseal
Rec->>Server: Vote on operator replacement
end
Server-->>Op: Recovery mode active
end
```
### 3.7 Committee Formation
There are two ways to form a multi-operator committee:
- convert an existing single-operator vault by adding new operators
- bootstrap an unbootstrapped vault directly into multi-operator mode
In both cases, committee formation is a coordinated process. Arbiter does not allow multi-operator custody to emerge implicitly from unrelated registrations.
### 3.8 Bootstrapping an Unbootstrapped Vault into Multi-Operator Mode
When an unbootstrapped vault is initialized as a multi-operator vault, the setup proceeds as follows:
1. An operator connects to the unbootstrapped vault using a User Agent and the bootstrap token.
2. During bootstrap setup, that operator declares:
- the total number of ordinary operators
- the total number of Recovery Operators
3. The vault enters **multi-bootstrap mode**.
4. While in multi-bootstrap mode:
- every ordinary operator must connect with a User Agent using the bootstrap token
- every Recovery Operator must also connect using the bootstrap token
- each participant is registered individually
- each participant's share is created and protected with that participant's credentials
5. The vault is considered fully bootstrapped only after all declared operator and recovery-share registrations have completed successfully.
This means the operator and recovery set is fixed at bootstrap completion time, based on the counts declared when multi-bootstrap mode was entered.
### 3.9 Special Bootstrap Constraint for Two-Operator Vaults
If a vault is declared with exactly **2 ordinary operators**, Arbiter requires at least **1 Recovery Operator** to be configured during bootstrap.
This prevents the worst-case custody failure in which a `2-of-2` operator set becomes permanently unrecoverable after loss of a single operator.
---
## 4. Server Identity
The server proves its identity using TLS with a self-signed certificate. The TLS private key is generated on first run and is long-term; no rotation mechanism exists yet due to the complexity of multi-peer coordination. The server proves its identity using TLS with a self-signed certificate. The TLS private key is generated on first run and is long-term; no rotation mechanism exists yet due to the complexity of multi-peer coordination.
@@ -55,9 +198,9 @@ Peers verify the server by its **public key fingerprint**:
--- ---
## 4. Key Management ## 5. Key Management
### 4.1 Key Hierarchy ### 5.1 Key Hierarchy
There are three layers of keys: There are three layers of keys:
@@ -72,19 +215,19 @@ This layered design enables:
- **Password rotation** without re-encrypting every wallet key (only the root key is re-encrypted). - **Password rotation** without re-encrypting every wallet key (only the root key is re-encrypted).
- **Root key rotation** without requiring the user to change their password. - **Root key rotation** without requiring the user to change their password.
### 4.2 Encryption at Rest ### 5.2 Encryption at Rest
The database stores everything in encrypted form using symmetric AEAD. The encryption scheme is versioned to support transparent migration — when the vault unseals, Arbiter automatically re-encrypts any entries that are behind the current scheme version. See [IMPLEMENTATION.md](IMPLEMENTATION.md) for the specific scheme and versioning mechanism. The database stores everything in encrypted form using symmetric AEAD. The encryption scheme is versioned to support transparent migration — when the vault unseals, Arbiter automatically re-encrypts any entries that are behind the current scheme version. See [IMPLEMENTATION.md](IMPLEMENTATION.md) for the specific scheme and versioning mechanism.
--- ---
## 5. Vault Lifecycle ## 6. Vault Lifecycle
### 5.1 Sealed State ### 6.1 Sealed State
On boot, the root key is encrypted and the server cannot perform any signing operations. This state is called **Sealed**. On boot, the root key is encrypted and the server cannot perform any signing operations. This state is called **Sealed**.
### 5.2 Unseal Flow ### 6.2 Unseal Flow
To transition to the **Unsealed** state, a User Agent must provide the password: To transition to the **Unsealed** state, a User Agent must provide the password:
@@ -95,7 +238,7 @@ To transition to the **Unsealed** state, a User Agent must provide the password:
- **Success:** The root key is decrypted and placed into a hardened memory cell. The server transitions to `Unsealed`. Any entries pending encryption scheme migration are re-encrypted. - **Success:** The root key is decrypted and placed into a hardened memory cell. The server transitions to `Unsealed`. Any entries pending encryption scheme migration are re-encrypted.
- **Failure:** The server returns an error indicating the password is incorrect. - **Failure:** The server returns an error indicating the password is incorrect.
### 5.3 Memory Protection ### 6.3 Memory Protection
Once unsealed, the root key must be protected in memory against: Once unsealed, the root key must be protected in memory against:
@@ -107,9 +250,9 @@ See [IMPLEMENTATION.md](IMPLEMENTATION.md) for the current and planned memory pr
--- ---
## 6. Permission Engine ## 7. Permission Engine
### 6.1 Fundamental Rules ### 7.1 Fundamental Rules
- SDK clients have **no access by default**. - SDK clients have **no access by default**.
- Access is granted **explicitly** by a User Agent. - Access is granted **explicitly** by a User Agent.
@@ -119,11 +262,45 @@ Each blockchain requires its own policy system due to differences in static tran
Arbiter is also responsible for ensuring that **transaction nonces are never reused**. Arbiter is also responsible for ensuring that **transaction nonces are never reused**.
### 6.2 EVM Policies ### 7.2 EVM Policies
Every EVM grant is scoped to a specific **wallet** and **chain ID**. Every EVM grant is scoped to a specific **wallet** and **chain ID**.
#### 6.2.1 Transaction Sub-Grants #### 7.2.0 Transaction Signing Sequence
The high-level interaction order is:
```mermaid
sequenceDiagram
autonumber
actor SDK as SDK Client
participant Server
participant UA as User Agent
SDK->>Server: SignTransactionRequest
Server->>Server: Resolve wallet and wallet visibility
alt Visibility approval required
Server->>UA: Ask for wallet visibility approval
UA-->>Server: Vote result
end
Server->>Server: Evaluate transaction
Server->>Server: Load grant and limits context
alt Grant approval required
Server->>UA: Ask for execution / grant approval
UA-->>Server: Vote result
opt Create persistent grant
Server->>Server: Create and store grant
end
Server->>Server: Retry evaluation
end
critical Final authorization path
Server->>Server: Check limits and record execution
Server-->>Server: Signature or evaluation error
end
Server-->>SDK: Signature or error
```
#### 7.2.1 Transaction Sub-Grants
Arbiter maintains an ever-expanding database of known contracts and their ABIs. Based on contract knowledge, transaction requests fall into three categories: Arbiter maintains an ever-expanding database of known contracts and their ABIs. Based on contract knowledge, transaction requests fall into three categories:
@@ -147,7 +324,7 @@ Available restrictions:
These transactions have no `calldata` and therefore cannot interact with contracts. They can be subject to the same volume and rate restrictions as above. These transactions have no `calldata` and therefore cannot interact with contracts. They can be subject to the same volume and rate restrictions as above.
#### 6.2.2 Global Limits #### 7.2.2 Global Limits
In addition to sub-grant-specific restrictions, the following limits can be applied across all grant types: In addition to sub-grant-specific restrictions, the following limits can be applied across all grant types:

View File

@@ -100,6 +100,27 @@ diesel migration generate <name> --migration-dir crates/arbiter-server/migration
diesel migration run --migration-dir crates/arbiter-server/migrations diesel migration run --migration-dir crates/arbiter-server/migrations
``` ```
### Code Conventions
**`#[must_use]` Attribute:**
Apply the `#[must_use]` attribute to return types of functions where the return value is critical and should not be accidentally ignored. This is commonly used for:
- Methods that return `bool` indicating success/failure or validation state
- Any function where ignoring the return value indicates a logic error
Do not apply `#[must_use]` redundantly to items (types or functions) that are already annotated with `#[must_use]`.
Example:
```rust
#[must_use]
pub fn verify(&self, nonce: i32, context: &[u8], signature: &Signature) -> bool {
// verification logic
}
```
This forces callers to either use the return value or explicitly ignore it with `let _ = ...;`, preventing silent failures.
## User Agent (Flutter + Rinf at `useragent/`) ## User Agent (Flutter + Rinf at `useragent/`)
The Flutter app uses [Rinf](https://rinf.cunarist.org) to call Rust code. The Rust logic lives in `useragent/native/hub/` as a separate crate that uses `arbiter-useragent` for the gRPC client. The Flutter app uses [Rinf](https://rinf.cunarist.org) to call Rust code. The Rust logic lives in `useragent/native/hub/` as a separate crate that uses `arbiter-useragent` for the gRPC client.

View File

@@ -67,18 +67,14 @@ The `program_client.nonce` column stores the **next usable nonce** — i.e. it i
## Cryptography ## Cryptography
### Authentication ### Authentication
- **Client protocol:** ed25519 - **Client protocol:** ML-DSA
### User-Agent Authentication ### User-Agent Authentication
User-agent authentication supports multiple signature schemes because platform-provided "hardware-bound" keys do not expose a uniform algorithm across operating systems and hardware. User-agent authentication supports multiple signature schemes because platform-provided "hardware-bound" keys do not expose a uniform algorithm across operating systems and hardware.
- **Supported schemes:** RSA, Ed25519, ECDSA (secp256k1) - **Supported schemes:** ML-DSA
- **Why:** the user agent authenticates with keys backed by platform facilities, and those facilities differ by platform - **Why:** Secure Enclave (MacOS) support them natively, on other platforms we could emulate while they roll-out
- **Apple Silicon Secure Enclave / Secure Element:** ECDSA-only in practice
- **Windows Hello / TPM 2.0:** currently RSA-backed in our integration
This is why the user-agent auth protocol carries an explicit `KeyType`, while the SDK client protocol remains fixed to ed25519.
### Encryption at Rest ### Encryption at Rest
- **Scheme:** Symmetric AEAD — currently **XChaCha20-Poly1305** - **Scheme:** Symmetric AEAD — currently **XChaCha20-Poly1305**
@@ -128,6 +124,52 @@ The central abstraction is the `Policy` trait. Each implementation handles one s
4. **Evaluate**`Policy::evaluate` checks the decoded meaning against the grant's policy-specific constraints and returns any violations. 4. **Evaluate**`Policy::evaluate` checks the decoded meaning against the grant's policy-specific constraints and returns any violations.
5. **Record** — If `RunKind::Execution` and there are no violations, the engine writes to `evm_transaction_log` and calls `Policy::record_transaction` for any policy-specific logging (e.g., token transfer volume). 5. **Record** — If `RunKind::Execution` and there are no violations, the engine writes to `evm_transaction_log` and calls `Policy::record_transaction` for any policy-specific logging (e.g., token transfer volume).
The detailed branch structure is shown below:
```mermaid
flowchart TD
A[SDK Client sends sign transaction request] --> B[Server resolves wallet]
B --> C{Wallet exists?}
C -- No --> Z1[Return wallet not found error]
C -- Yes --> D[Check SDK client wallet visibility]
D --> E{Wallet visible to SDK client?}
E -- No --> F[Start wallet visibility voting flow]
F --> G{Vote approved?}
G -- No --> Z2[Return wallet access denied error]
G -- Yes --> H[Persist wallet visibility]
E -- Yes --> I[Classify transaction meaning]
H --> I
I --> J{Meaning supported?}
J -- No --> Z3[Return unsupported transaction error]
J -- Yes --> K[Find matching grant]
K --> L{Grant exists?}
L -- Yes --> M[Check grant limits]
L -- No --> N[Start execution or grant voting flow]
N --> O{User-agent decision}
O -- Reject --> Z4[Return no matching grant error]
O -- Allow once --> M
O -- Create grant --> P[Create grant with user-selected limits]
P --> Q[Persist grant]
Q --> M
M --> R{Limits exceeded?}
R -- Yes --> Z5[Return evaluation error]
R -- No --> S[Record transaction in logs]
S --> T[Produce signature]
T --> U[Return signature to SDK client]
note1[Limit checks include volume, count, and gas constraints.]
note2[Grant lookup depends on classified meaning, such as ether transfer or token transfer.]
K -. uses .-> note2
M -. checks .-> note1
```
### Policy Trait ### Policy Trait
| Method | Purpose | | Method | Purpose |

View File

@@ -48,6 +48,10 @@ backend = "cargo:cargo-features-manager"
version = "1.46.3" version = "1.46.3"
backend = "cargo:cargo-insta" backend = "cargo:cargo-insta"
[[tools."cargo:cargo-mutants"]]
version = "27.0.0"
backend = "cargo:cargo-mutants"
[[tools."cargo:cargo-nextest"]] [[tools."cargo:cargo-nextest"]]
version = "0.9.126" version = "0.9.126"
backend = "cargo:cargo-nextest" backend = "cargo:cargo-nextest"
@@ -111,30 +115,37 @@ backend = "core:python"
[tools.python."platforms.linux-arm64"] [tools.python."platforms.linux-arm64"]
checksum = "sha256:53700338695e402a1a1fe22be4a41fbdacc70e22bb308a48eca8ed67cb7992be" checksum = "sha256:53700338695e402a1a1fe22be4a41fbdacc70e22bb308a48eca8ed67cb7992be"
url = "https://github.com/astral-sh/python-build-standalone/releases/download/20260324/cpython-3.14.3+20260324-aarch64-unknown-linux-gnu-install_only_stripped.tar.gz" url = "https://github.com/astral-sh/python-build-standalone/releases/download/20260324/cpython-3.14.3+20260324-aarch64-unknown-linux-gnu-install_only_stripped.tar.gz"
provenance = "github-attestations"
[tools.python."platforms.linux-arm64-musl"] [tools.python."platforms.linux-arm64-musl"]
checksum = "sha256:53700338695e402a1a1fe22be4a41fbdacc70e22bb308a48eca8ed67cb7992be" checksum = "sha256:53700338695e402a1a1fe22be4a41fbdacc70e22bb308a48eca8ed67cb7992be"
url = "https://github.com/astral-sh/python-build-standalone/releases/download/20260324/cpython-3.14.3+20260324-aarch64-unknown-linux-gnu-install_only_stripped.tar.gz" url = "https://github.com/astral-sh/python-build-standalone/releases/download/20260324/cpython-3.14.3+20260324-aarch64-unknown-linux-gnu-install_only_stripped.tar.gz"
provenance = "github-attestations"
[tools.python."platforms.linux-x64"] [tools.python."platforms.linux-x64"]
checksum = "sha256:d7a9f970914bb4c88756fe3bdcc186d4feb90e9500e54f1db47dae4dc9687e39" checksum = "sha256:d7a9f970914bb4c88756fe3bdcc186d4feb90e9500e54f1db47dae4dc9687e39"
url = "https://github.com/astral-sh/python-build-standalone/releases/download/20260324/cpython-3.14.3+20260324-x86_64-unknown-linux-gnu-install_only_stripped.tar.gz" url = "https://github.com/astral-sh/python-build-standalone/releases/download/20260324/cpython-3.14.3+20260324-x86_64-unknown-linux-gnu-install_only_stripped.tar.gz"
provenance = "github-attestations"
[tools.python."platforms.linux-x64-musl"] [tools.python."platforms.linux-x64-musl"]
checksum = "sha256:d7a9f970914bb4c88756fe3bdcc186d4feb90e9500e54f1db47dae4dc9687e39" checksum = "sha256:d7a9f970914bb4c88756fe3bdcc186d4feb90e9500e54f1db47dae4dc9687e39"
url = "https://github.com/astral-sh/python-build-standalone/releases/download/20260324/cpython-3.14.3+20260324-x86_64-unknown-linux-gnu-install_only_stripped.tar.gz" url = "https://github.com/astral-sh/python-build-standalone/releases/download/20260324/cpython-3.14.3+20260324-x86_64-unknown-linux-gnu-install_only_stripped.tar.gz"
provenance = "github-attestations"
[tools.python."platforms.macos-arm64"] [tools.python."platforms.macos-arm64"]
checksum = "sha256:c43aecde4a663aebff99b9b83da0efec506479f1c3f98331442f33d2c43501f9" checksum = "sha256:c43aecde4a663aebff99b9b83da0efec506479f1c3f98331442f33d2c43501f9"
url = "https://github.com/astral-sh/python-build-standalone/releases/download/20260324/cpython-3.14.3+20260324-aarch64-apple-darwin-install_only_stripped.tar.gz" url = "https://github.com/astral-sh/python-build-standalone/releases/download/20260324/cpython-3.14.3+20260324-aarch64-apple-darwin-install_only_stripped.tar.gz"
provenance = "github-attestations"
[tools.python."platforms.macos-x64"] [tools.python."platforms.macos-x64"]
checksum = "sha256:9ab41dbc2f100a2a45d1833b9c11165f51051c558b5213eda9a9731d5948a0c0" checksum = "sha256:9ab41dbc2f100a2a45d1833b9c11165f51051c558b5213eda9a9731d5948a0c0"
url = "https://github.com/astral-sh/python-build-standalone/releases/download/20260324/cpython-3.14.3+20260324-x86_64-apple-darwin-install_only_stripped.tar.gz" url = "https://github.com/astral-sh/python-build-standalone/releases/download/20260324/cpython-3.14.3+20260324-x86_64-apple-darwin-install_only_stripped.tar.gz"
provenance = "github-attestations"
[tools.python."platforms.windows-x64"] [tools.python."platforms.windows-x64"]
checksum = "sha256:bbe19034b35b0267176a7442575ae7dc6343480fd4d35598cb7700173d431e09" checksum = "sha256:bbe19034b35b0267176a7442575ae7dc6343480fd4d35598cb7700173d431e09"
url = "https://github.com/astral-sh/python-build-standalone/releases/download/20260324/cpython-3.14.3+20260324-x86_64-pc-windows-msvc-install_only_stripped.tar.gz" url = "https://github.com/astral-sh/python-build-standalone/releases/download/20260324/cpython-3.14.3+20260324-x86_64-pc-windows-msvc-install_only_stripped.tar.gz"
provenance = "github-attestations"
[[tools.rust]] [[tools.rust]]
version = "1.93.0" version = "1.93.0"

View File

@@ -12,6 +12,7 @@ protoc = "29.6"
python = "3.14.3" python = "3.14.3"
ast-grep = "0.42.0" ast-grep = "0.42.0"
"cargo:cargo-edit" = "0.13.9" "cargo:cargo-edit" = "0.13.9"
"cargo:cargo-mutants" = "27.0.0"
[tasks.codegen] [tasks.codegen]
sources = ['protobufs/*.proto', 'protobufs/**/*.proto'] sources = ['protobufs/*.proto', 'protobufs/**/*.proto']

View File

@@ -36,6 +36,10 @@ message GasLimitExceededViolation {
} }
message EvalViolation { message EvalViolation {
message ChainIdMismatch {
uint64 expected = 1;
uint64 actual = 2;
}
oneof kind { oneof kind {
bytes invalid_target = 1; // 20-byte Ethereum address bytes invalid_target = 1; // 20-byte Ethereum address
GasLimitExceededViolation gas_limit_exceeded = 2; GasLimitExceededViolation gas_limit_exceeded = 2;
@@ -43,6 +47,8 @@ message EvalViolation {
google.protobuf.Empty volumetric_limit_exceeded = 4; google.protobuf.Empty volumetric_limit_exceeded = 4;
google.protobuf.Empty invalid_time = 5; google.protobuf.Empty invalid_time = 5;
google.protobuf.Empty invalid_transaction_type = 6; google.protobuf.Empty invalid_transaction_type = 6;
ChainIdMismatch chain_id_mismatch = 7;
} }
} }

View File

@@ -0,0 +1 @@
test_tool = "nextest"

2
server/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
mutants.out/
mutants.out.old/

946
server/Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -4,42 +4,169 @@ members = [
] ]
resolver = "3" resolver = "3"
[workspace.lints.clippy]
disallowed-methods = "deny"
[workspace.dependencies] [workspace.dependencies]
tonic = { version = "0.14.5", features = [
"deflate",
"gzip",
"tls-connect-info",
"zstd",
] }
tracing = "0.1.44"
tokio = { version = "1.50.0", features = ["full"] }
ed25519-dalek = { version = "3.0.0-pre.6", features = ["rand_core"] }
chrono = { version = "0.4.44", features = ["serde"] }
rand = "0.10.0"
rustls = { version = "0.23.37", features = ["aws-lc-rs"] }
smlang = "0.8.0"
thiserror = "2.0.18"
async-trait = "0.1.89"
futures = "0.3.32"
tokio-stream = { version = "0.1.18", features = ["full"] }
kameo = "0.19.2"
prost-types = { version = "0.14.3", features = ["chrono"] }
x25519-dalek = { version = "2.0.1", features = ["getrandom"] }
rstest = "0.26.1"
rustls-pki-types = "1.14.0"
alloy = "1.7.3" alloy = "1.7.3"
rcgen = { version = "0.14.7", features = [ async-trait = "0.1.89"
"aws_lc_rs", base64 = "0.22.1"
"pem", chrono = { version = "0.4.44", features = ["serde"] }
"x509-parser", ed25519-dalek = { version = "3.0.0-pre.6", features = ["rand_core"] }
"zeroize", futures = "0.3.32"
], default-features = false } hmac = "0.12.1"
k256 = { version = "0.13.4", features = ["ecdsa", "pkcs8"] } k256 = { version = "0.13.4", features = ["ecdsa", "pkcs8"] }
rsa = { version = "0.9", features = ["sha2"] } kameo = { version = "0.20.0", git = "https://github.com/tqwewe/kameo" } # hold this until new patch version is released
sha2 = "0.10"
spki = "0.7"
miette = { version = "7.6.0", features = ["fancy", "serde"] } miette = { version = "7.6.0", features = ["fancy", "serde"] }
ml-dsa = { version = "0.1.0-rc.8", features = ["zeroize"] }
mutants = "0.0.4"
prost = "0.14.3"
prost-types = { version = "0.14.3", features = ["chrono"] }
rand = "0.10.0"
rcgen = { version = "0.14.7", features = [ "aws_lc_rs", "pem", "x509-parser", "zeroize" ], default-features = false }
rsa = { version = "0.9", features = ["sha2"] }
rstest = "0.26.1"
rustls = { version = "0.23.37", features = ["aws-lc-rs", "logging", "prefer-post-quantum", "std"], default-features = false }
rustls-pki-types = "1.14.0"
sha2 = "0.10"
smlang = "0.8.0"
spki = "0.7"
thiserror = "2.0.18"
tokio = { version = "1.50.0", features = ["full"] }
tokio-stream = { version = "0.1.18", features = ["full"] }
tonic = { version = "0.14.5", features = [ "deflate", "gzip", "tls-connect-info", "zstd" ] }
tracing = "0.1.44"
x25519-dalek = { version = "2.0.1", features = ["getrandom"] }
[workspace.lints.rust]
missing_unsafe_on_extern = "deny"
unsafe_attr_outside_unsafe = "deny"
unsafe_op_in_unsafe_fn = "deny"
unstable_features = "deny"
deprecated_safe_2024 = "warn"
ffi_unwind_calls = "warn"
linker_messages = "warn"
elided_lifetimes_in_paths = "warn"
explicit_outlives_requirements = "warn"
impl-trait-overcaptures = "warn"
impl-trait-redundant-captures = "warn"
redundant_lifetimes = "warn"
single_use_lifetimes = "warn"
unused_lifetimes = "warn"
macro_use_extern_crate = "warn"
redundant_imports = "warn"
unused_import_braces = "warn"
unused_macro_rules = "warn"
unused_qualifications = "warn"
unit_bindings = "warn"
# missing_docs = "warn" # ENABLE BY THE FIRST MAJOR VERSION!!
unnameable_types = "warn"
variant_size_differences = "warn"
[workspace.lints.clippy]
derive_partial_eq_without_eq = "allow"
future_not_send = "allow"
inconsistent_struct_constructor = "allow"
inline_always = "allow"
missing_errors_doc = "allow"
missing_fields_in_debug = "allow"
missing_panics_doc = "allow"
must_use_candidate = "allow"
needless_pass_by_ref_mut = "allow"
pub_underscore_fields = "allow"
redundant_pub_crate = "allow"
uninhabited_references = "allow" # safe with unsafe_code = "forbid" and standard uninhabited pattern (match *self {})
# restriction lints
alloc_instead_of_core = "warn"
allow_attributes_without_reason = "warn"
as_conversions = "warn"
assertions_on_result_states = "warn"
cfg_not_test = "warn"
clone_on_ref_ptr = "warn"
cognitive_complexity = "warn"
create_dir = "warn"
dbg_macro = "warn"
decimal_literal_representation = "warn"
default_union_representation = "warn"
deref_by_slicing = "warn"
disallowed_script_idents = "warn"
doc_include_without_cfg = "warn"
empty_drop = "warn"
empty_enum_variants_with_brackets = "warn"
empty_structs_with_brackets = "warn"
error_impl_error = "warn"
exit = "warn"
filetype_is_file = "warn"
float_arithmetic = "warn"
float_cmp_const = "warn"
fn_to_numeric_cast_any = "warn"
get_unwrap = "warn"
if_then_some_else_none = "warn"
indexing_slicing = "warn"
infinite_loop = "warn"
inline_asm_x86_att_syntax = "warn"
inline_asm_x86_intel_syntax = "warn"
integer_division = "warn"
large_include_file = "warn"
lossy_float_literal = "warn"
map_with_unused_argument_over_ranges = "warn"
mem_forget = "warn"
missing_assert_message = "warn"
mixed_read_write_in_expression = "warn"
modulo_arithmetic = "warn"
multiple_unsafe_ops_per_block = "warn"
mutex_atomic = "warn"
mutex_integer = "warn"
needless_raw_strings = "warn"
non_ascii_literal = "warn"
non_zero_suggestions = "warn"
pathbuf_init_then_push = "warn"
pointer_format = "warn"
precedence_bits = "warn"
pub_without_shorthand = "warn"
rc_buffer = "warn"
rc_mutex = "warn"
redundant_test_prefix = "warn"
redundant_type_annotations = "warn"
ref_patterns = "warn"
renamed_function_params = "warn"
rest_pat_in_fully_bound_structs = "warn"
return_and_then = "warn"
semicolon_inside_block = "warn"
str_to_string = "warn"
string_add = "warn"
string_lit_chars_any = "warn"
string_slice = "warn"
suspicious_xor_used_as_pow = "warn"
try_err = "warn"
undocumented_unsafe_blocks = "warn"
uninlined_format_args = "warn"
unnecessary_safety_comment = "warn"
unnecessary_safety_doc = "warn"
unnecessary_self_imports = "warn"
unneeded_field_pattern = "warn"
unused_result_ok = "warn"
verbose_file_reads = "warn"
# cargo lints
negative_feature_names = "warn"
redundant_feature_names = "warn"
wildcard_dependencies = "warn"
# ENABLE BY THE FIRST MAJOR VERSION!!
# todo = "warn"
# unimplemented = "warn"
# panic = "warn"
# panic_in_result_fn = "warn"
#
# cargo_common_metadata = "warn"
# multiple_crate_versions = "warn" # a controversial option since it's really difficult to maintain
disallowed_methods = "deny"
nursery = { level = "warn", priority = -1 }
pedantic = { level = "warn", priority = -1 }

View File

@@ -7,3 +7,22 @@ disallowed-methods = [
{ path = "rsa::traits::Decryptor::decrypt", reason = "RSA decryption is forbidden (RUSTSEC-2023-0071 Marvin Attack). This blocks decrypt() on rsa::{pkcs1v15,oaep}::DecryptingKey." }, { path = "rsa::traits::Decryptor::decrypt", reason = "RSA decryption is forbidden (RUSTSEC-2023-0071 Marvin Attack). This blocks decrypt() on rsa::{pkcs1v15,oaep}::DecryptingKey." },
{ path = "rsa::traits::RandomizedDecryptor::decrypt_with_rng", reason = "RSA decryption is forbidden (RUSTSEC-2023-0071 Marvin Attack). This blocks decrypt_with_rng() on rsa::{pkcs1v15,oaep}::DecryptingKey." }, { path = "rsa::traits::RandomizedDecryptor::decrypt_with_rng", reason = "RSA decryption is forbidden (RUSTSEC-2023-0071 Marvin Attack). This blocks decrypt_with_rng() on rsa::{pkcs1v15,oaep}::DecryptingKey." },
] ]
allow-indexing-slicing-in-tests = true
allow-panic-in-tests = true
check-inconsistent-struct-field-initializers = true
suppress-restriction-lint-in-const = true
allow-renamed-params-for = [
"core::convert::From",
"core::convert::TryFrom",
"core::str::FromStr",
"kameo::actor::Actor",
]
module-items-ordered-within-groupings = ["UPPER_SNAKE_CASE"]
source-item-ordering = ["enum"]
trait-assoc-item-kinds-order = [
"const",
"type",
"fn",
] # community tested standard

View File

@@ -13,12 +13,12 @@ evm = ["dep:alloy"]
[dependencies] [dependencies]
arbiter-proto.path = "../arbiter-proto" arbiter-proto.path = "../arbiter-proto"
arbiter-crypto.path = "../arbiter-crypto"
alloy = { workspace = true, optional = true } alloy = { workspace = true, optional = true }
tonic.workspace = true tonic.workspace = true
tonic.features = ["tls-aws-lc"] tonic.features = ["tls-aws-lc"]
tokio.workspace = true tokio.workspace = true
tokio-stream.workspace = true tokio-stream.workspace = true
ed25519-dalek.workspace = true
thiserror.workspace = true thiserror.workspace = true
http = "1.4.0" http = "1.4.0"
rustls-webpki = { version = "0.103.10", features = ["aws-lc-rs"] } rustls-webpki = { version = "0.103.10", features = ["aws-lc-rs"] }

View File

@@ -1,5 +1,6 @@
use arbiter_crypto::authn::{CLIENT_CONTEXT, SigningKey, format_challenge};
use arbiter_proto::{ use arbiter_proto::{
ClientMetadata, format_challenge, ClientMetadata,
proto::{ proto::{
client::{ client::{
ClientRequest, ClientRequest,
@@ -14,7 +15,6 @@ use arbiter_proto::{
shared::ClientInfo as ProtoClientInfo, shared::ClientInfo as ProtoClientInfo,
}, },
}; };
use ed25519_dalek::Signer as _;
use crate::{ use crate::{
storage::StorageError, storage::StorageError,
@@ -23,20 +23,20 @@ use crate::{
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum AuthError { pub enum AuthError {
#[error("Auth challenge was not returned by server")]
MissingAuthChallenge,
#[error("Client approval denied by User Agent")] #[error("Client approval denied by User Agent")]
ApprovalDenied, ApprovalDenied,
#[error("Auth challenge was not returned by server")]
MissingAuthChallenge,
#[error("No User Agents online to approve client")] #[error("No User Agents online to approve client")]
NoUserAgentsOnline, NoUserAgentsOnline,
#[error("Unexpected auth response payload")]
UnexpectedAuthResponse,
#[error("Signing key storage error")] #[error("Signing key storage error")]
Storage(#[from] StorageError), Storage(#[from] StorageError),
#[error("Unexpected auth response payload")]
UnexpectedAuthResponse,
} }
fn map_auth_result(code: i32) -> AuthError { fn map_auth_result(code: i32) -> AuthError {
@@ -54,14 +54,14 @@ fn map_auth_result(code: i32) -> AuthError {
async fn send_auth_challenge_request( async fn send_auth_challenge_request(
transport: &mut ClientTransport, transport: &mut ClientTransport,
metadata: ClientMetadata, metadata: ClientMetadata,
key: &ed25519_dalek::SigningKey, key: &SigningKey,
) -> std::result::Result<(), AuthError> { ) -> Result<(), AuthError> {
transport transport
.send(ClientRequest { .send(ClientRequest {
request_id: next_request_id(), request_id: next_request_id(),
payload: Some(ClientRequestPayload::Auth(proto_auth::Request { payload: Some(ClientRequestPayload::Auth(proto_auth::Request {
payload: Some(AuthRequestPayload::ChallengeRequest(AuthChallengeRequest { payload: Some(AuthRequestPayload::ChallengeRequest(AuthChallengeRequest {
pubkey: key.verifying_key().to_bytes().to_vec(), pubkey: key.public_key().to_bytes(),
client_info: Some(ProtoClientInfo { client_info: Some(ProtoClientInfo {
name: metadata.name, name: metadata.name,
description: metadata.description, description: metadata.description,
@@ -76,7 +76,7 @@ async fn send_auth_challenge_request(
async fn receive_auth_challenge( async fn receive_auth_challenge(
transport: &mut ClientTransport, transport: &mut ClientTransport,
) -> std::result::Result<AuthChallenge, AuthError> { ) -> Result<AuthChallenge, AuthError> {
let response = transport let response = transport
.recv() .recv()
.await .await
@@ -95,11 +95,14 @@ async fn receive_auth_challenge(
async fn send_auth_challenge_solution( async fn send_auth_challenge_solution(
transport: &mut ClientTransport, transport: &mut ClientTransport,
key: &ed25519_dalek::SigningKey, key: &SigningKey,
challenge: AuthChallenge, challenge: AuthChallenge,
) -> std::result::Result<(), AuthError> { ) -> Result<(), AuthError> {
let challenge_payload = format_challenge(challenge.nonce, &challenge.pubkey); let challenge_payload = format_challenge(challenge.nonce, &challenge.pubkey);
let signature = key.sign(&challenge_payload).to_bytes().to_vec(); let signature = key
.sign_message(&challenge_payload, CLIENT_CONTEXT)
.map_err(|_| AuthError::UnexpectedAuthResponse)?
.to_bytes();
transport transport
.send(ClientRequest { .send(ClientRequest {
@@ -114,9 +117,7 @@ async fn send_auth_challenge_solution(
.map_err(|_| AuthError::UnexpectedAuthResponse) .map_err(|_| AuthError::UnexpectedAuthResponse)
} }
async fn receive_auth_confirmation( async fn receive_auth_confirmation(transport: &mut ClientTransport) -> Result<(), AuthError> {
transport: &mut ClientTransport,
) -> std::result::Result<(), AuthError> {
let response = transport let response = transport
.recv() .recv()
.await .await
@@ -137,11 +138,11 @@ async fn receive_auth_confirmation(
} }
} }
pub(crate) async fn authenticate( pub async fn authenticate(
transport: &mut ClientTransport, transport: &mut ClientTransport,
metadata: ClientMetadata, metadata: ClientMetadata,
key: &ed25519_dalek::SigningKey, key: &SigningKey,
) -> std::result::Result<(), AuthError> { ) -> Result<(), AuthError> {
send_auth_challenge_request(transport, metadata, key).await?; send_auth_challenge_request(transport, metadata, key).await?;
let challenge = receive_auth_challenge(transport).await?; let challenge = receive_auth_challenge(transport).await?;
send_auth_challenge_solution(transport, key, challenge).await?; send_auth_challenge_solution(transport, key, challenge).await?;

View File

@@ -29,16 +29,16 @@ async fn main() {
} }
}; };
println!("{:#?}", url); println!("{url:#?}");
let metadata = ClientMetadata { let metadata = ClientMetadata {
name: "arbiter-client test_connect".to_string(), name: "arbiter-client test_connect".to_owned(),
description: Some("Manual connection smoke test".to_string()), description: Some("Manual connection smoke test".to_owned()),
version: Some(env!("CARGO_PKG_VERSION").to_string()), version: Some(env!("CARGO_PKG_VERSION").to_owned()),
}; };
match ArbiterClient::connect(url, metadata).await { match ArbiterClient::connect(url, metadata).await {
Ok(_) => println!("Connected and authenticated successfully."), Ok(_) => println!("Connected and authenticated successfully."),
Err(err) => eprintln!("Failed to connect: {:#?}", err), Err(err) => eprintln!("Failed to connect: {err:#?}"),
} }
} }

View File

@@ -1,3 +1,4 @@
use arbiter_crypto::authn::SigningKey;
use arbiter_proto::{ use arbiter_proto::{
ClientMetadata, proto::arbiter_service_client::ArbiterServiceClient, url::ArbiterUrl, ClientMetadata, proto::arbiter_service_client::ArbiterServiceClient, url::ArbiterUrl,
}; };
@@ -17,33 +18,39 @@ use crate::{
use crate::wallets::evm::ArbiterEvmWallet; use crate::wallets::evm::ArbiterEvmWallet;
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum Error { pub enum ArbiterClientError {
#[error("gRPC error")] #[error("Authentication error")]
Grpc(#[from] tonic::Status), Authentication(#[from] AuthError),
#[error("Could not establish connection")] #[error("Could not establish connection")]
Connection(#[from] tonic::transport::Error), Connection(#[from] tonic::transport::Error),
#[error("Invalid server URI")] #[error("gRPC error")]
InvalidUri(#[from] http::uri::InvalidUri), Grpc(#[from] tonic::Status),
#[error("Invalid CA certificate")] #[error("Invalid CA certificate")]
InvalidCaCert(#[from] webpki::Error), InvalidCaCert(#[from] webpki::Error),
#[error("Authentication error")] #[error("Invalid server URI")]
Authentication(#[from] AuthError), InvalidUri(#[from] http::uri::InvalidUri),
#[error("Storage error")] #[error("Storage error")]
Storage(#[from] StorageError), Storage(#[from] StorageError),
} }
pub struct ArbiterClient { pub struct ArbiterClient {
#[allow(dead_code)] #[expect(
dead_code,
reason = "transport will be used in future methods for sending requests and receiving responses"
)]
transport: Arc<Mutex<ClientTransport>>, transport: Arc<Mutex<ClientTransport>>,
} }
impl ArbiterClient { impl ArbiterClient {
pub async fn connect(url: ArbiterUrl, metadata: ClientMetadata) -> Result<Self, Error> { pub async fn connect(
url: ArbiterUrl,
metadata: ClientMetadata,
) -> Result<Self, ArbiterClientError> {
let storage = FileSigningKeyStorage::from_default_location()?; let storage = FileSigningKeyStorage::from_default_location()?;
Self::connect_with_storage(url, metadata, &storage).await Self::connect_with_storage(url, metadata, &storage).await
} }
@@ -52,7 +59,7 @@ impl ArbiterClient {
url: ArbiterUrl, url: ArbiterUrl,
metadata: ClientMetadata, metadata: ClientMetadata,
storage: &S, storage: &S,
) -> Result<Self, Error> { ) -> Result<Self, ArbiterClientError> {
let key = storage.load_or_create()?; let key = storage.load_or_create()?;
Self::connect_with_key(url, metadata, key).await Self::connect_with_key(url, metadata, key).await
} }
@@ -60,8 +67,8 @@ impl ArbiterClient {
pub async fn connect_with_key( pub async fn connect_with_key(
url: ArbiterUrl, url: ArbiterUrl,
metadata: ClientMetadata, metadata: ClientMetadata,
key: ed25519_dalek::SigningKey, key: SigningKey,
) -> Result<Self, Error> { ) -> Result<Self, ArbiterClientError> {
let anchor = webpki::anchor_from_trusted_cert(&url.ca_cert)?.to_owned(); let anchor = webpki::anchor_from_trusted_cert(&url.ca_cert)?.to_owned();
let tls = ClientTlsConfig::new().trust_anchor(anchor); let tls = ClientTlsConfig::new().trust_anchor(anchor);
@@ -88,7 +95,8 @@ impl ArbiterClient {
} }
#[cfg(feature = "evm")] #[cfg(feature = "evm")]
pub async fn evm_wallets(&self) -> Result<Vec<ArbiterEvmWallet>, Error> { #[expect(clippy::unused_async, reason = "false positive")]
pub async fn evm_wallets(&self) -> Result<Vec<ArbiterEvmWallet>, ArbiterClientError> {
todo!("fetch EVM wallet list from server") todo!("fetch EVM wallet list from server")
} }
} }

View File

@@ -5,7 +5,7 @@ mod transport;
pub mod wallets; pub mod wallets;
pub use auth::AuthError; pub use auth::AuthError;
pub use client::{ArbiterClient, Error}; pub use client::{ArbiterClient, ArbiterClientError};
pub use storage::{FileSigningKeyStorage, SigningKeyStorage, StorageError}; pub use storage::{FileSigningKeyStorage, SigningKeyStorage, StorageError};
#[cfg(feature = "evm")] #[cfg(feature = "evm")]

View File

@@ -1,17 +1,18 @@
use arbiter_crypto::authn::SigningKey;
use arbiter_proto::home_path; use arbiter_proto::home_path;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum StorageError { pub enum StorageError {
#[error("I/O error")]
Io(#[from] std::io::Error),
#[error("Invalid signing key length in storage: expected {expected} bytes, got {actual} bytes")] #[error("Invalid signing key length in storage: expected {expected} bytes, got {actual} bytes")]
InvalidKeyLength { expected: usize, actual: usize }, InvalidKeyLength { expected: usize, actual: usize },
#[error("I/O error")]
Io(#[from] std::io::Error),
} }
pub trait SigningKeyStorage { pub trait SigningKeyStorage {
fn load_or_create(&self) -> std::result::Result<ed25519_dalek::SigningKey, StorageError>; fn load_or_create(&self) -> Result<SigningKey, StorageError>;
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -20,17 +21,17 @@ pub struct FileSigningKeyStorage {
} }
impl FileSigningKeyStorage { impl FileSigningKeyStorage {
pub const DEFAULT_FILE_NAME: &str = "sdk_client_ed25519.key"; pub const DEFAULT_FILE_NAME: &str = "sdk_client_ml_dsa.key";
pub fn new(path: impl Into<PathBuf>) -> Self { pub fn new(path: impl Into<PathBuf>) -> Self {
Self { path: path.into() } Self { path: path.into() }
} }
pub fn from_default_location() -> std::result::Result<Self, StorageError> { pub fn from_default_location() -> Result<Self, StorageError> {
Ok(Self::new(home_path()?.join(Self::DEFAULT_FILE_NAME))) Ok(Self::new(home_path()?.join(Self::DEFAULT_FILE_NAME)))
} }
fn read_key(path: &Path) -> std::result::Result<ed25519_dalek::SigningKey, StorageError> { fn read_key(path: &Path) -> Result<SigningKey, StorageError> {
let bytes = std::fs::read(path)?; let bytes = std::fs::read(path)?;
let raw: [u8; 32] = let raw: [u8; 32] =
bytes bytes
@@ -39,12 +40,12 @@ impl FileSigningKeyStorage {
expected: 32, expected: 32,
actual: v.len(), actual: v.len(),
})?; })?;
Ok(ed25519_dalek::SigningKey::from_bytes(&raw)) Ok(SigningKey::from_seed(raw))
} }
} }
impl SigningKeyStorage for FileSigningKeyStorage { impl SigningKeyStorage for FileSigningKeyStorage {
fn load_or_create(&self) -> std::result::Result<ed25519_dalek::SigningKey, StorageError> { fn load_or_create(&self) -> Result<SigningKey, StorageError> {
if let Some(parent) = self.path.parent() { if let Some(parent) = self.path.parent() {
std::fs::create_dir_all(parent)?; std::fs::create_dir_all(parent)?;
} }
@@ -53,8 +54,8 @@ impl SigningKeyStorage for FileSigningKeyStorage {
return Self::read_key(&self.path); return Self::read_key(&self.path);
} }
let key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); let key = SigningKey::generate();
let raw_key = key.to_bytes(); let raw_key = key.to_seed();
// Use create_new to prevent accidental overwrite if another process creates the key first. // Use create_new to prevent accidental overwrite if another process creates the key first.
match std::fs::OpenOptions::new() match std::fs::OpenOptions::new()
@@ -103,7 +104,7 @@ mod tests {
.load_or_create() .load_or_create()
.expect("second load_or_create should read same key"); .expect("second load_or_create should read same key");
assert_eq!(key_a.to_bytes(), key_b.to_bytes()); assert_eq!(key_a.to_seed(), key_b.to_seed());
assert!(path.exists()); assert!(path.exists());
std::fs::remove_file(path).expect("temp key file should be removable"); std::fs::remove_file(path).expect("temp key file should be removable");
@@ -124,7 +125,7 @@ mod tests {
assert_eq!(expected, 32); assert_eq!(expected, 32);
assert_eq!(actual, 31); assert_eq!(actual, 31);
} }
other => panic!("unexpected error: {other:?}"), other @ StorageError::Io(_) => panic!("unexpected error: {other:?}"),
} }
std::fs::remove_file(path).expect("temp key file should be removable"); std::fs::remove_file(path).expect("temp key file should be removable");

View File

@@ -2,15 +2,15 @@ use arbiter_proto::proto::client::{ClientRequest, ClientResponse};
use std::sync::atomic::{AtomicI32, Ordering}; use std::sync::atomic::{AtomicI32, Ordering};
use tokio::sync::mpsc; use tokio::sync::mpsc;
pub(crate) const BUFFER_LENGTH: usize = 16; pub const BUFFER_LENGTH: usize = 16;
static NEXT_REQUEST_ID: AtomicI32 = AtomicI32::new(1); static NEXT_REQUEST_ID: AtomicI32 = AtomicI32::new(1);
pub(crate) fn next_request_id() -> i32 { pub fn next_request_id() -> i32 {
NEXT_REQUEST_ID.fetch_add(1, Ordering::Relaxed) NEXT_REQUEST_ID.fetch_add(1, Ordering::Relaxed)
} }
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub(crate) enum ClientSignError { pub enum ClientSignError {
#[error("Transport channel closed")] #[error("Transport channel closed")]
ChannelClosed, ChannelClosed,
@@ -18,7 +18,7 @@ pub(crate) enum ClientSignError {
ConnectionClosed, ConnectionClosed,
} }
pub(crate) struct ClientTransport { pub struct ClientTransport {
pub(crate) sender: mpsc::Sender<ClientRequest>, pub(crate) sender: mpsc::Sender<ClientRequest>,
pub(crate) receiver: tonic::Streaming<ClientResponse>, pub(crate) receiver: tonic::Streaming<ClientResponse>,
} }
@@ -27,18 +27,17 @@ impl ClientTransport {
pub(crate) async fn send( pub(crate) async fn send(
&mut self, &mut self,
request: ClientRequest, request: ClientRequest,
) -> std::result::Result<(), ClientSignError> { ) -> Result<(), ClientSignError> {
self.sender self.sender
.send(request) .send(request)
.await .await
.map_err(|_| ClientSignError::ChannelClosed) .map_err(|_| ClientSignError::ChannelClosed)
} }
pub(crate) async fn recv(&mut self) -> std::result::Result<ClientResponse, ClientSignError> { pub(crate) async fn recv(&mut self) -> Result<ClientResponse, ClientSignError> {
match self.receiver.message().await { match self.receiver.message().await {
Ok(Some(resp)) => Ok(resp), Ok(Some(resp)) => Ok(resp),
Ok(None) => Err(ClientSignError::ConnectionClosed), Ok(None) | Err(_) => Err(ClientSignError::ConnectionClosed),
Err(_) => Err(ClientSignError::ConnectionClosed),
} }
} }
} }

View File

@@ -59,7 +59,11 @@ pub struct ArbiterEvmWallet {
} }
impl ArbiterEvmWallet { impl ArbiterEvmWallet {
pub(crate) fn new(transport: Arc<Mutex<ClientTransport>>, address: Address) -> Self { #[expect(
dead_code,
reason = "new will be used in future methods for creating wallets with different parameters"
)]
pub(crate) const fn new(transport: Arc<Mutex<ClientTransport>>, address: Address) -> Self {
Self { Self {
transport, transport,
address, address,
@@ -67,11 +71,12 @@ impl ArbiterEvmWallet {
} }
} }
pub fn address(&self) -> Address { pub const fn address(&self) -> Address {
self.address self.address
} }
pub fn with_chain_id(mut self, chain_id: ChainId) -> Self { #[must_use]
pub const fn with_chain_id(mut self, chain_id: ChainId) -> Self {
self.chain_id = Some(chain_id); self.chain_id = Some(chain_id);
self self
} }
@@ -146,6 +151,7 @@ impl TxSigner<Signature> for ArbiterEvmWallet {
.recv() .recv()
.await .await
.map_err(|_| Error::other("failed to receive evm sign transaction response"))?; .map_err(|_| Error::other("failed to receive evm sign transaction response"))?;
drop(transport);
if response.request_id != Some(request_id) { if response.request_id != Some(request_id) {
return Err(Error::other( return Err(Error::other(

View File

@@ -0,0 +1 @@
/target

View File

@@ -0,0 +1,21 @@
[package]
name = "arbiter-crypto"
version = "0.1.0"
edition = "2024"
[dependencies]
ml-dsa = {workspace = true, optional = true }
rand = {workspace = true, optional = true}
base64 = {workspace = true, optional = true }
memsafe = {version = "0.4.0", optional = true}
hmac.workspace = true
alloy.workspace = true
chrono.workspace = true
[lints]
workspace = true
[features]
default = ["authn", "safecell"]
authn = ["dep:ml-dsa", "dep:rand", "dep:base64"]
safecell = ["dep:memsafe"]

View File

@@ -0,0 +1,2 @@
pub mod v1;
pub use v1::*;

View File

@@ -0,0 +1,194 @@
use base64::{Engine as _, prelude::BASE64_STANDARD};
use hmac::digest::Digest;
use ml_dsa::{
EncodedVerifyingKey, Error, KeyGen, MlDsa87, Seed, Signature as MlDsaSignature,
SigningKey as MlDsaSigningKey, VerifyingKey as MlDsaVerifyingKey, signature::Keypair as _,
};
pub static CLIENT_CONTEXT: &[u8] = b"arbiter_client";
pub static USERAGENT_CONTEXT: &[u8] = b"arbiter_user_agent";
pub fn format_challenge(nonce: i32, pubkey: &[u8]) -> Vec<u8> {
let concat_form = format!("{}:{}", nonce, BASE64_STANDARD.encode(pubkey));
concat_form.into_bytes()
}
pub type KeyParams = MlDsa87;
#[derive(Clone, Debug, PartialEq)]
pub struct PublicKey(Box<MlDsaVerifyingKey<KeyParams>>);
impl crate::hashing::Hashable for PublicKey {
fn hash<H: Digest>(&self, hasher: &mut H) {
hasher.update(self.to_bytes());
}
}
#[derive(Clone, Debug, PartialEq)]
pub struct Signature(Box<MlDsaSignature<KeyParams>>);
#[derive(Debug)]
pub struct SigningKey(Box<MlDsaSigningKey<KeyParams>>);
impl PublicKey {
pub fn to_bytes(&self) -> Vec<u8> {
self.0.encode().0.to_vec()
}
#[must_use]
pub fn verify(&self, nonce: i32, context: &[u8], signature: &Signature) -> bool {
self.0.verify_with_context(
&format_challenge(nonce, &self.to_bytes()),
context,
&signature.0,
)
}
}
impl Signature {
pub fn to_bytes(&self) -> Vec<u8> {
self.0.encode().0.to_vec()
}
}
impl SigningKey {
pub fn generate() -> Self {
Self(Box::new(KeyParams::key_gen(&mut rand::rng())))
}
pub fn from_seed(seed: [u8; 32]) -> Self {
Self(Box::new(KeyParams::from_seed(&Seed::from(seed))))
}
pub fn to_seed(&self) -> [u8; 32] {
self.0.to_seed().into()
}
pub fn public_key(&self) -> PublicKey {
self.0.verifying_key().into()
}
pub fn sign_message(&self, message: &[u8], context: &[u8]) -> Result<Signature, Error> {
self.0
.signing_key()
.sign_deterministic(message, context)
.map(Into::into)
}
pub fn sign_challenge(&self, nonce: i32, context: &[u8]) -> Result<Signature, Error> {
self.sign_message(
&format_challenge(nonce, &self.public_key().to_bytes()),
context,
)
}
}
impl From<MlDsaVerifyingKey<KeyParams>> for PublicKey {
fn from(value: MlDsaVerifyingKey<KeyParams>) -> Self {
Self(Box::new(value))
}
}
impl From<MlDsaSignature<KeyParams>> for Signature {
fn from(value: MlDsaSignature<KeyParams>) -> Self {
Self(Box::new(value))
}
}
impl From<MlDsaSigningKey<KeyParams>> for SigningKey {
fn from(value: MlDsaSigningKey<KeyParams>) -> Self {
Self(Box::new(value))
}
}
impl TryFrom<Vec<u8>> for PublicKey {
type Error = ();
fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
Self::try_from(value.as_slice())
}
}
impl TryFrom<&'_ [u8]> for PublicKey {
type Error = ();
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
let encoded = EncodedVerifyingKey::<KeyParams>::try_from(value).map_err(|_| ())?;
Ok(Self(Box::new(MlDsaVerifyingKey::decode(&encoded))))
}
}
impl TryFrom<Vec<u8>> for Signature {
type Error = ();
fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
Self::try_from(value.as_slice())
}
}
impl TryFrom<&'_ [u8]> for Signature {
type Error = ();
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
MlDsaSignature::try_from(value)
.map(|sig| Self(Box::new(sig)))
.map_err(|_| ())
}
}
#[cfg(test)]
mod tests {
use ml_dsa::{KeyGen, MlDsa87, signature::Keypair as _};
use super::{CLIENT_CONTEXT, PublicKey, Signature, SigningKey, USERAGENT_CONTEXT};
#[test]
fn public_key_round_trip_decodes() {
let key = MlDsa87::key_gen(&mut rand::rng());
let encoded = PublicKey::from(key.verifying_key()).to_bytes();
let decoded = PublicKey::try_from(encoded.as_slice()).expect("public key should decode");
assert_eq!(decoded, PublicKey::from(key.verifying_key()));
}
#[test]
fn signature_round_trip_decodes() {
let key = SigningKey::generate();
let signature = key
.sign_message(b"challenge", CLIENT_CONTEXT)
.expect("signature should be created");
let decoded =
Signature::try_from(signature.to_bytes().as_slice()).expect("signature should decode");
assert_eq!(decoded, signature);
}
#[test]
fn challenge_verification_uses_context_and_canonical_key_bytes() {
let key = SigningKey::generate();
let public_key = key.public_key();
let nonce = 17;
let signature = key
.sign_challenge(nonce, CLIENT_CONTEXT)
.expect("signature should be created");
assert!(public_key.verify(nonce, CLIENT_CONTEXT, &signature));
assert!(!public_key.verify(nonce, USERAGENT_CONTEXT, &signature));
}
#[test]
fn signing_key_round_trip_seed_preserves_public_key_and_signing() {
let original = SigningKey::generate();
let restored = SigningKey::from_seed(original.to_seed());
assert_eq!(restored.public_key(), original.public_key());
let signature = restored
.sign_challenge(9, CLIENT_CONTEXT)
.expect("signature should be created");
assert!(restored.public_key().verify(9, CLIENT_CONTEXT, &signature));
}
}

View File

@@ -0,0 +1,111 @@
pub use hmac::digest::Digest;
use std::collections::HashSet;
/// Deterministically hash a value by feeding its fields into the hasher in a consistent order.
#[diagnostic::on_unimplemented(
note = "for local types consider adding `#[derive(arbiter_macros::Hashable)]` to your `{Self}` type",
note = "for types from other crates check whether the crate offers a `Hashable` implementation"
)]
pub trait Hashable {
fn hash<H: Digest>(&self, hasher: &mut H);
}
macro_rules! impl_numeric {
($($t:ty),*) => {
$(
impl Hashable for $t {
fn hash<H: Digest>(&self, hasher: &mut H) {
hasher.update(&self.to_be_bytes());
}
}
)*
};
}
impl_numeric!(u8, u16, u32, u64, i8, i16, i32, i64);
impl Hashable for &[u8] {
fn hash<H: Digest>(&self, hasher: &mut H) {
hasher.update(self);
}
}
impl Hashable for String {
fn hash<H: Digest>(&self, hasher: &mut H) {
hasher.update(self.as_bytes());
}
}
impl<T: Hashable + PartialOrd> Hashable for Vec<T> {
fn hash<H: Digest>(&self, hasher: &mut H) {
let ref_sorted = {
let mut sorted = self.iter().collect::<Vec<_>>();
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
sorted
};
for item in ref_sorted {
item.hash(hasher);
}
}
}
impl<T: Hashable + PartialOrd, S: std::hash::BuildHasher> Hashable for HashSet<T, S> {
fn hash<H: Digest>(&self, hasher: &mut H) {
let ref_sorted = {
let mut sorted = self.iter().collect::<Vec<_>>();
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
sorted
};
for item in ref_sorted {
item.hash(hasher);
}
}
}
impl<T: Hashable> Hashable for Option<T> {
fn hash<H: Digest>(&self, hasher: &mut H) {
match self {
Some(value) => {
hasher.update([1]);
value.hash(hasher);
}
None => hasher.update([0]),
}
}
}
impl<T: Hashable> Hashable for Box<T> {
fn hash<H: Digest>(&self, hasher: &mut H) {
self.as_ref().hash(hasher);
}
}
impl<T: Hashable> Hashable for &T {
fn hash<H: Digest>(&self, hasher: &mut H) {
(*self).hash(hasher);
}
}
impl Hashable for alloy::primitives::Address {
fn hash<H: Digest>(&self, hasher: &mut H) {
hasher.update(self.as_slice());
}
}
impl Hashable for alloy::primitives::U256 {
fn hash<H: Digest>(&self, hasher: &mut H) {
hasher.update(self.to_be_bytes::<32>());
}
}
impl Hashable for chrono::Duration {
fn hash<H: Digest>(&self, hasher: &mut H) {
hasher.update(self.num_seconds().to_be_bytes());
}
}
impl Hashable for chrono::DateTime<chrono::Utc> {
fn hash<H: Digest>(&self, hasher: &mut H) {
hasher.update(self.timestamp_millis().to_be_bytes());
}
}

View File

@@ -0,0 +1,5 @@
#[cfg(feature = "authn")]
pub mod authn;
pub mod hashing;
#[cfg(feature = "safecell")]
pub mod safecell;

View File

@@ -29,7 +29,7 @@ pub trait SafeCellHandle<T> {
let mut cell = Self::new(T::default()); let mut cell = Self::new(T::default());
{ {
let mut handle = cell.write(); let mut handle = cell.write();
f(handle.deref_mut()); f(&mut *handle);
} }
cell cell
} }
@@ -105,6 +105,11 @@ impl<T> SafeCellHandle<T> for MemSafeCell<T> {
fn abort_memory_breach(action: &str, err: &memsafe::error::MemoryError) -> ! { fn abort_memory_breach(action: &str, err: &memsafe::error::MemoryError) -> ! {
eprintln!("fatal {action}: {err}"); eprintln!("fatal {action}: {err}");
// SAFETY: Intentionally cause a segmentation fault to prevent further execution in a compromised state.
unsafe {
let unsafe_pointer = std::ptr::null_mut::<u8>();
std::ptr::write_volatile(unsafe_pointer, 0);
}
std::process::abort(); std::process::abort();
} }

View File

@@ -0,0 +1,18 @@
[package]
name = "arbiter-macros"
version = "0.1.0"
edition = "2024"
[lib]
proc-macro = true
[dependencies]
proc-macro2 = "1.0"
quote = "1.0"
syn = { version = "2.0", features = ["derive", "fold", "full", "visit-mut"] }
[dev-dependencies]
arbiter-crypto = { path = "../arbiter-crypto" }
[lints]
workspace = true

View File

@@ -0,0 +1,133 @@
use proc_macro2::{Span, TokenStream, TokenTree};
use quote::quote;
use syn::parse_quote;
use syn::spanned::Spanned;
use syn::{DataStruct, DeriveInput, Fields, Generics, Index};
use crate::utils::{HASHABLE_TRAIT_PATH, HMAC_DIGEST_PATH};
pub(crate) fn derive(input: &DeriveInput) -> TokenStream {
match &input.data {
syn::Data::Struct(struct_data) => hashable_struct(input, struct_data),
syn::Data::Enum(_) => {
syn::Error::new_spanned(input, "Hashable can currently be derived only for structs")
.to_compile_error()
}
syn::Data::Union(_) => {
syn::Error::new_spanned(input, "Hashable cannot be derived for unions")
.to_compile_error()
}
}
}
fn hashable_struct(input: &DeriveInput, struct_data: &DataStruct) -> TokenStream {
let ident = &input.ident;
let hashable_trait = HASHABLE_TRAIT_PATH.to_path();
let hmac_digest = HMAC_DIGEST_PATH.to_path();
let generics = add_hashable_bounds(input.generics.clone(), &hashable_trait);
let field_accesses = collect_field_accesses(struct_data);
let hash_calls = build_hash_calls(&field_accesses, &hashable_trait);
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
quote! {
#[automatically_derived]
impl #impl_generics #hashable_trait for #ident #ty_generics #where_clause {
fn hash<H: #hmac_digest>(&self, hasher: &mut H) {
#(#hash_calls)*
}
}
}
}
fn add_hashable_bounds(mut generics: Generics, hashable_trait: &syn::Path) -> Generics {
for type_param in generics.type_params_mut() {
type_param.bounds.push(parse_quote!(#hashable_trait));
}
generics
}
struct FieldAccess {
access: TokenStream,
span: Span,
}
fn collect_field_accesses(struct_data: &DataStruct) -> Vec<FieldAccess> {
match &struct_data.fields {
Fields::Named(fields) => {
// Keep deterministic alphabetical order for named fields.
// Do not remove this sort, because it keeps hash output stable regardless of source order.
let mut named_fields = fields
.named
.iter()
.map(|field| {
let name = field
.ident
.as_ref()
.expect("Fields::Named(fields) must have names")
.clone();
(name.to_string(), name)
})
.collect::<Vec<_>>();
named_fields.sort_by(|a, b| a.0.cmp(&b.0));
named_fields
.into_iter()
.map(|(_, name)| FieldAccess {
access: quote! { #name },
span: name.span(),
})
.collect()
}
Fields::Unnamed(fields) => fields
.unnamed
.iter()
.enumerate()
.map(|(i, field)| FieldAccess {
access: {
let index = Index::from(i);
quote! { #index }
},
span: field.ty.span(),
})
.collect(),
Fields::Unit => Vec::new(),
}
}
fn build_hash_calls(
field_accesses: &[FieldAccess],
hashable_trait: &syn::Path,
) -> Vec<TokenStream> {
field_accesses
.iter()
.map(|field| {
let access = &field.access;
let call = quote! {
#hashable_trait::hash(&self.#access, hasher);
};
respan(call, field.span)
})
.collect()
}
/// Recursively set span on all tokens, including interpolated ones.
fn respan(tokens: TokenStream, span: Span) -> TokenStream {
tokens
.into_iter()
.map(|tt| match tt {
TokenTree::Group(g) => {
let mut new = proc_macro2::Group::new(g.delimiter(), respan(g.stream(), span));
new.set_span(span);
TokenTree::Group(new)
}
mut other => {
other.set_span(span);
other
}
})
.collect()
}

View File

@@ -0,0 +1,10 @@
use syn::{DeriveInput, parse_macro_input};
mod hashable;
mod utils;
#[proc_macro_derive(Hashable)]
pub fn derive_hashable(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(input as DeriveInput);
hashable::derive(&input).into()
}

View File

@@ -0,0 +1,24 @@
pub(crate) struct ToPath(pub &'static str);
impl ToPath {
pub(crate) fn to_path(&self) -> syn::Path {
syn::parse_str(self.0).expect("Invalid path")
}
}
macro_rules! ensure_path {
($path:path as $name:ident) => {
const _: () = {
#[cfg(test)]
#[expect(
unused_imports,
reason = "Ensures the path is valid and will cause a compile error if not"
)]
use $path as _;
};
pub(crate) const $name: ToPath = ToPath(stringify!($path));
};
}
ensure_path!(::arbiter_crypto::hashing::Hashable as HASHABLE_TRAIT_PATH);
ensure_path!(::arbiter_crypto::hashing::Digest as HMAC_DIGEST_PATH);

View File

@@ -11,13 +11,13 @@ tokio.workspace = true
futures.workspace = true futures.workspace = true
hex = "0.4.3" hex = "0.4.3"
tonic-prost = "0.14.5" tonic-prost = "0.14.5"
prost = "0.14.3" prost.workspace = true
kameo.workspace = true kameo.workspace = true
url = "2.5.8" url = "2.5.8"
miette.workspace = true miette.workspace = true
thiserror.workspace = true thiserror.workspace = true
rustls-pki-types.workspace = true rustls-pki-types.workspace = true
base64 = "0.22.1" base64.workspace = true
prost-types.workspace = true prost-types.workspace = true
tracing.workspace = true tracing.workspace = true
async-trait.workspace = true async-trait.workspace = true

View File

@@ -1,8 +1,6 @@
pub mod transport; pub mod transport;
pub mod url; pub mod url;
use base64::{Engine, prelude::BASE64_STANDARD};
pub mod proto { pub mod proto {
tonic::include_proto!("arbiter"); tonic::include_proto!("arbiter");
@@ -84,8 +82,3 @@ pub fn home_path() -> Result<std::path::PathBuf, std::io::Error> {
Ok(arbiter_home) Ok(arbiter_home)
} }
pub fn format_challenge(nonce: i32, pubkey: &[u8]) -> Vec<u8> {
let concat_form = format!("{}:{}", nonce, BASE64_STANDARD.encode(pubkey));
concat_form.into_bytes()
}

View File

@@ -105,7 +105,7 @@ mod tests {
#[rstest] #[rstest]
fn test_parsing_correctness( fn parsing_correctness(
#[values("127.0.0.1", "localhost", "192.168.1.1", "some.domain.com")] host: &str, #[values("127.0.0.1", "localhost", "192.168.1.1", "some.domain.com")] host: &str,
#[values(None, Some("token123".to_string()))] bootstrap_token: Option<String>, #[values(None, Some("token123".to_string()))] bootstrap_token: Option<String>,

View File

@@ -16,8 +16,9 @@ diesel-async = { version = "0.8.0", features = [
"sqlite", "sqlite",
"tokio", "tokio",
] } ] }
ed25519-dalek.workspace = true
arbiter-proto.path = "../arbiter-proto" arbiter-proto.path = "../arbiter-proto"
arbiter-crypto.path = "../arbiter-crypto"
arbiter-macros.path = "../arbiter-macros"
tracing.workspace = true tracing.workspace = true
tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tonic.workspace = true tonic.workspace = true
@@ -36,25 +37,31 @@ dashmap = "6.1.0"
rand.workspace = true rand.workspace = true
rcgen.workspace = true rcgen.workspace = true
chrono.workspace = true chrono.workspace = true
memsafe = "0.4.0"
zeroize = { version = "1.8.2", features = ["std", "simd"] } zeroize = { version = "1.8.2", features = ["std", "simd"] }
kameo.workspace = true kameo.workspace = true
x25519-dalek.workspace = true
chacha20poly1305 = { version = "0.10.1", features = ["std"] } chacha20poly1305 = { version = "0.10.1", features = ["std"] }
argon2 = { version = "0.5.3", features = ["zeroize"] } argon2 = { version = "0.5.3", features = ["zeroize"] }
restructed = "0.2.2" restructed = "0.2.2"
strum = { version = "0.28.0", features = ["derive"] } strum = { version = "0.28.0", features = ["derive"] }
pem = "3.0.6" pem = "3.0.6"
k256.workspace = true
rsa.workspace = true
sha2.workspace = true sha2.workspace = true
hmac = "0.12" hmac.workspace = true
spki.workspace = true spki.workspace = true
alloy.workspace = true alloy.workspace = true
prost-types.workspace = true prost-types.workspace = true
prost.workspace = true
arbiter-tokens-registry.path = "../arbiter-tokens-registry" arbiter-tokens-registry.path = "../arbiter-tokens-registry"
anyhow = "1.0.102" anyhow = "1.0.102"
serde_with = "3.18.0"
mutants.workspace = true
subtle = "2.6.1"
ml-dsa.workspace = true
ed25519-dalek.workspace = true
x25519-dalek.workspace = true
k256.workspace = true
[dev-dependencies] [dev-dependencies]
insta = "1.46.3" insta = "1.46.3"
proptest = "1.11.0"
rstest.workspace = true
test-log = { version = "0.2", default-features = false, features = ["trace"] } test-log = { version = "0.2", default-features = false, features = ["trace"] }

View File

@@ -47,8 +47,7 @@ create table if not exists useragent_client (
id integer not null primary key, id integer not null primary key,
nonce integer not null default(1), -- used for auth challenge nonce integer not null default(1), -- used for auth challenge
public_key blob not null, public_key blob not null,
pubkey_integrity_tag blob, key_type integer not null default(1),
key_type integer not null default(1), -- 1=Ed25519, 2=ECDSA(secp256k1)
created_at integer not null default(unixepoch ('now')), created_at integer not null default(unixepoch ('now')),
updated_at integer not null default(unixepoch ('now')) updated_at integer not null default(unixepoch ('now'))
) STRICT; ) STRICT;
@@ -192,3 +191,19 @@ create table if not exists evm_ether_transfer_grant_target (
) STRICT; ) STRICT;
create unique index if not exists uniq_ether_transfer_target on evm_ether_transfer_grant_target (grant_id, address); create unique index if not exists uniq_ether_transfer_target on evm_ether_transfer_grant_target (grant_id, address);
-- ===============================
-- Integrity Envelopes
-- ===============================
create table if not exists integrity_envelope (
id integer not null primary key,
entity_kind text not null,
entity_id blob not null,
payload_version integer not null,
key_version integer not null,
mac blob not null, -- 20-byte recipient address
signed_at integer not null default(unixepoch ('now')),
created_at integer not null default(unixepoch ('now'))
) STRICT;
create unique index if not exists uniq_integrity_envelope_entity on integrity_envelope (entity_kind, entity_id);

View File

@@ -4,6 +4,7 @@ use diesel_async::RunQueryDsl;
use kameo::{Actor, messages}; use kameo::{Actor, messages};
use rand::{RngExt, distr::Alphanumeric, make_rng, rngs::StdRng}; use rand::{RngExt, distr::Alphanumeric, make_rng, rngs::StdRng};
use subtle::ConstantTimeEq as _;
use thiserror::Error; use thiserror::Error;
use crate::db::{self, DatabasePool, schema}; use crate::db::{self, DatabasePool, schema};
@@ -12,8 +13,8 @@ const TOKEN_LENGTH: usize = 64;
pub async fn generate_token() -> Result<String, std::io::Error> { pub async fn generate_token() -> Result<String, std::io::Error> {
let rng: StdRng = make_rng(); let rng: StdRng = make_rng();
let token: String = rng.sample_iter(Alphanumeric).take(TOKEN_LENGTH).fold( let token = rng.sample_iter(Alphanumeric).take(TOKEN_LENGTH).fold(
Default::default(), String::default(),
|mut accum, char| { |mut accum, char| {
accum += char.to_string().as_str(); accum += char.to_string().as_str();
accum accum
@@ -26,15 +27,15 @@ pub async fn generate_token() -> Result<String, std::io::Error> {
} }
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum Error { pub enum BootstrappError {
#[error("Database error: {0}")] #[error("Database error: {0}")]
Database(#[from] db::PoolError), Database(#[from] db::PoolError),
#[error("Database query error: {0}")]
Query(#[from] diesel::result::Error),
#[error("I/O error: {0}")] #[error("I/O error: {0}")]
Io(#[from] std::io::Error), Io(#[from] std::io::Error),
#[error("Database query error: {0}")]
Query(#[from] diesel::result::Error),
} }
#[derive(Actor)] #[derive(Actor)]
@@ -43,15 +44,15 @@ pub struct Bootstrapper {
} }
impl Bootstrapper { impl Bootstrapper {
pub async fn new(db: &DatabasePool) -> Result<Self, Error> { pub async fn new(db: &DatabasePool) -> Result<Self, BootstrappError> {
let mut conn = db.get().await?; let row_count: i64 = {
let mut conn = db.get().await?;
let row_count: i64 = schema::useragent_client::table schema::useragent_client::table
.count() .count()
.get_result(&mut conn) .get_result(&mut conn)
.await?; .await?
};
drop(conn);
let token = if row_count == 0 { let token = if row_count == 0 {
let token = generate_token().await?; let token = generate_token().await?;
@@ -68,10 +69,13 @@ impl Bootstrapper {
impl Bootstrapper { impl Bootstrapper {
#[message] #[message]
pub fn is_correct_token(&self, token: String) -> bool { pub fn is_correct_token(&self, token: String) -> bool {
match &self.token { self.token.as_ref().is_some_and(|expected| {
Some(expected) => *expected == token, let expected_bytes = expected.as_bytes();
None => false, let token_bytes = token.as_bytes();
}
let choice = expected_bytes.ct_eq(token_bytes);
bool::from(choice)
})
} }
#[message] #[message]

View File

@@ -1,5 +1,7 @@
use arbiter_crypto::authn::{self, CLIENT_CONTEXT};
use arbiter_proto::{ use arbiter_proto::{
ClientMetadata, format_challenge, ClientMetadata,
proto::client::auth::{AuthChallenge as ProtoAuthChallenge, AuthResult as ProtoAuthResult},
transport::{Bi, expect_message}, transport::{Bi, expect_message},
}; };
use chrono::Utc; use chrono::Utc;
@@ -8,15 +10,16 @@ use diesel::{
dsl::insert_into, update, dsl::insert_into, update,
}; };
use diesel_async::RunQueryDsl as _; use diesel_async::RunQueryDsl as _;
use ed25519_dalek::{Signature, VerifyingKey}; use kameo::{actor::ActorRef, error::SendError};
use kameo::error::SendError;
use tracing::error; use tracing::error;
use crate::{ use crate::{
actors::{ actors::{
client::{ClientConnection, ClientProfile}, client::{ClientConnection, ClientCredentials, ClientProfile},
flow_coordinator::{self, RequestClientApproval}, flow_coordinator::{self, RequestClientApproval},
keyholder::KeyHolder,
}, },
crypto::integrity::{self, AttestationStatus},
db::{ db::{
self, self,
models::{ProgramClientMetadata, SqliteTimestamp}, models::{ProgramClientMetadata, SqliteTimestamp},
@@ -25,25 +28,60 @@ use crate::{
}; };
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)] #[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
pub enum Error { pub enum ClientAuthError {
#[error("Database pool unavailable")]
DatabasePoolUnavailable,
#[error("Database operation failed")]
DatabaseOperationFailed,
#[error("Invalid challenge solution")]
InvalidChallengeSolution,
#[error("Client approval request failed")] #[error("Client approval request failed")]
ApproveError(#[from] ApproveError), ApproveError(#[from] ApproveError),
#[error("Database operation failed")]
DatabaseOperationFailed,
#[error("Database pool unavailable")]
DatabasePoolUnavailable,
#[error("Integrity check failed")]
IntegrityCheckFailed,
#[error("Invalid challenge solution")]
InvalidChallengeSolution,
#[error("Transport error")] #[error("Transport error")]
Transport, Transport,
} }
impl From<diesel::result::Error> for ClientAuthError {
fn from(e: diesel::result::Error) -> Self {
error!(?e, "Database error");
Self::DatabaseOperationFailed
}
}
impl From<ClientAuthError> for arbiter_proto::proto::client::auth::AuthResult {
fn from(value: ClientAuthError) -> Self {
match value {
ClientAuthError::ApproveError(e) => match e {
ApproveError::Denied => Self::ApprovalDenied,
ApproveError::Internal => Self::Internal,
ApproveError::Upstream(flow_coordinator::ApprovalError::NoUserAgentsConnected) => {
Self::NoUserAgentsOnline
} // ApproveError::Upstream(_) => Self::Internal,
},
ClientAuthError::DatabaseOperationFailed
| ClientAuthError::DatabasePoolUnavailable
| ClientAuthError::IntegrityCheckFailed
| ClientAuthError::Transport => Self::Internal,
ClientAuthError::InvalidChallengeSolution => Self::InvalidSignature,
}
}
}
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)] #[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
pub enum ApproveError { pub enum ApproveError {
#[error("Internal error")]
Internal,
#[error("Client connection denied by user agents")] #[error("Client connection denied by user agents")]
Denied, Denied,
#[error("Internal error")]
Internal,
#[error("Upstream error: {0}")] #[error("Upstream error: {0}")]
Upstream(flow_coordinator::ApprovalError), Upstream(flow_coordinator::ApprovalError),
} }
@@ -51,73 +89,147 @@ pub enum ApproveError {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum Inbound { pub enum Inbound {
AuthChallengeRequest { AuthChallengeRequest {
pubkey: VerifyingKey, pubkey: authn::PublicKey,
metadata: ClientMetadata, metadata: ClientMetadata,
}, },
AuthChallengeSolution { AuthChallengeSolution {
signature: Signature, signature: authn::Signature,
}, },
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum Outbound { pub enum Outbound {
AuthChallenge { pubkey: VerifyingKey, nonce: i32 }, AuthChallenge {
pubkey: authn::PublicKey,
nonce: i32,
},
AuthSuccess, AuthSuccess,
} }
pub struct ClientInfo { impl From<Outbound> for arbiter_proto::proto::client::auth::response::Payload {
pub id: i32, fn from(value: Outbound) -> Self {
pub current_nonce: i32, match value {
Outbound::AuthChallenge { pubkey, nonce } => Self::Challenge(ProtoAuthChallenge {
pubkey: pubkey.to_bytes(),
nonce,
}),
Outbound::AuthSuccess => Self::Result(ProtoAuthResult::Success.into()),
}
}
} }
/// Atomically reads and increments the nonce for a known client. /// Returns the current nonce and client ID for a registered client.
/// Returns `None` if the pubkey is not registered. /// Returns `None` if the pubkey is not registered.
async fn get_client_and_nonce( async fn get_current_nonce_and_id(
db: &db::DatabasePool, db: &db::DatabasePool,
pubkey: &VerifyingKey, pubkey: &authn::PublicKey,
) -> Result<Option<ClientInfo>, Error> { ) -> Result<Option<(i32, i32)>, ClientAuthError> {
let pubkey_bytes = pubkey.as_bytes().to_vec(); let pubkey_bytes = pubkey.to_bytes();
let mut conn = db.get().await.map_err(|e| {
error!(error = ?e, "Database pool error");
ClientAuthError::DatabasePoolUnavailable
})?;
program_client::table
.filter(program_client::public_key.eq(&pubkey_bytes))
.select((program_client::id, program_client::nonce))
.first::<(i32, i32)>(&mut conn)
.await
.optional()
.map_err(|e| {
error!(error = ?e, "Database error");
ClientAuthError::DatabaseOperationFailed
})
}
async fn verify_integrity(
db: &db::DatabasePool,
keyholder: &ActorRef<KeyHolder>,
pubkey: &authn::PublicKey,
) -> Result<(), ClientAuthError> {
let mut db_conn = db.get().await.map_err(|e| {
error!(error = ?e, "Database pool error");
ClientAuthError::DatabasePoolUnavailable
})?;
let (id, nonce) = get_current_nonce_and_id(db, pubkey).await?.ok_or_else(|| {
error!("Client not found during integrity verification");
ClientAuthError::DatabaseOperationFailed
})?;
let attestation = integrity::verify_entity(
&mut db_conn,
keyholder,
&ClientCredentials {
pubkey: pubkey.clone(),
nonce,
},
id,
)
.await
.map_err(|e| {
error!(?e, "Integrity verification failed");
ClientAuthError::IntegrityCheckFailed
})?;
if attestation != AttestationStatus::Attested {
error!("Integrity attestation unavailable for client {id}");
return Err(ClientAuthError::IntegrityCheckFailed);
}
Ok(())
}
/// Atomically increments the nonce and re-signs the integrity envelope.
/// Returns the new nonce, which is used as the challenge nonce.
async fn create_nonce(
db: &db::DatabasePool,
keyholder: &ActorRef<KeyHolder>,
pubkey: &authn::PublicKey,
) -> Result<i32, ClientAuthError> {
let pubkey_bytes = pubkey.to_bytes();
let pubkey = pubkey.clone();
let mut conn = db.get().await.map_err(|e| { let mut conn = db.get().await.map_err(|e| {
error!(error = ?e, "Database pool error"); error!(error = ?e, "Database pool error");
Error::DatabasePoolUnavailable ClientAuthError::DatabasePoolUnavailable
})?; })?;
conn.exclusive_transaction(|conn| { conn.exclusive_transaction(|conn| {
let keyholder = keyholder.clone();
let pubkey = pubkey.clone();
Box::pin(async move { Box::pin(async move {
let Some((client_id, current_nonce)) = program_client::table let (id, new_nonce): (i32, i32) = update(program_client::table)
.filter(program_client::public_key.eq(&pubkey_bytes)) .filter(program_client::public_key.eq(&pubkey_bytes))
.select((program_client::id, program_client::nonce)) .set(program_client::nonce.eq(program_client::nonce + 1))
.first::<(i32, i32)>(conn) .returning((program_client::id, program_client::nonce))
.await .get_result(conn)
.optional()?
else {
return Result::<_, diesel::result::Error>::Ok(None);
};
update(program_client::table)
.filter(program_client::public_key.eq(&pubkey_bytes))
.set(program_client::nonce.eq(current_nonce + 1))
.execute(conn)
.await?; .await?;
Ok(Some(ClientInfo { integrity::sign_entity(
id: client_id, conn,
current_nonce, &keyholder,
})) &ClientCredentials {
pubkey: pubkey.clone(),
nonce: new_nonce,
},
id,
)
.await
.map_err(|e| {
error!(?e, "Integrity sign failed after nonce update");
ClientAuthError::DatabaseOperationFailed
})?;
Ok(new_nonce)
}) })
}) })
.await .await
.map_err(|e| {
error!(error = ?e, "Database error");
Error::DatabaseOperationFailed
})
} }
async fn approve_new_client( async fn approve_new_client(
actors: &crate::actors::GlobalActors, actors: &crate::actors::GlobalActors,
profile: ClientProfile, profile: ClientProfile,
) -> Result<(), Error> { ) -> Result<(), ClientAuthError> {
let result = actors let result = actors
.flow_coordinator .flow_coordinator
.ask(RequestClientApproval { client: profile }) .ask(RequestClientApproval { client: profile })
@@ -125,73 +237,93 @@ async fn approve_new_client(
match result { match result {
Ok(true) => Ok(()), Ok(true) => Ok(()),
Ok(false) => Err(Error::ApproveError(ApproveError::Denied)), Ok(false) => Err(ClientAuthError::ApproveError(ApproveError::Denied)),
Err(SendError::HandlerError(e)) => { Err(SendError::HandlerError(e)) => {
error!(error = ?e, "Approval upstream error"); error!(error = ?e, "Approval upstream error");
Err(Error::ApproveError(ApproveError::Upstream(e))) Err(ClientAuthError::ApproveError(ApproveError::Upstream(e)))
} }
Err(e) => { Err(e) => {
error!(error = ?e, "Approval request to flow coordinator failed"); error!(error = ?e, "Approval request to flow coordinator failed");
Err(Error::ApproveError(ApproveError::Internal)) Err(ClientAuthError::ApproveError(ApproveError::Internal))
} }
} }
} }
async fn insert_client( async fn insert_client(
db: &db::DatabasePool, db: &db::DatabasePool,
pubkey: &VerifyingKey, keyholder: &ActorRef<KeyHolder>,
pubkey: &authn::PublicKey,
metadata: &ClientMetadata, metadata: &ClientMetadata,
) -> Result<i32, Error> { ) -> Result<i32, ClientAuthError> {
use crate::db::schema::{client_metadata, program_client}; use crate::db::schema::client_metadata;
let pubkey = pubkey.clone();
let metadata = metadata.clone();
let mut conn = db.get().await.map_err(|e| { let mut conn = db.get().await.map_err(|e| {
error!(error = ?e, "Database pool error"); error!(error = ?e, "Database pool error");
Error::DatabasePoolUnavailable ClientAuthError::DatabasePoolUnavailable
})?; })?;
let metadata_id = insert_into(client_metadata::table) conn.exclusive_transaction(|conn| {
.values(( let keyholder = keyholder.clone();
client_metadata::name.eq(&metadata.name), let pubkey = pubkey.clone();
client_metadata::description.eq(&metadata.description), Box::pin(async move {
client_metadata::version.eq(&metadata.version), const NONCE_START: i32 = 1;
))
.returning(client_metadata::id)
.get_result::<i32>(&mut conn)
.await
.map_err(|e| {
error!(error = ?e, "Failed to insert client metadata");
Error::DatabaseOperationFailed
})?;
let client_id = insert_into(program_client::table) let metadata_id = insert_into(client_metadata::table)
.values(( .values((
program_client::public_key.eq(pubkey.as_bytes().to_vec()), client_metadata::name.eq(&metadata.name),
program_client::metadata_id.eq(metadata_id), client_metadata::description.eq(&metadata.description),
program_client::nonce.eq(1), // pre-incremented; challenge uses 0 client_metadata::version.eq(&metadata.version),
)) ))
.on_conflict_do_nothing() .returning(client_metadata::id)
.returning(program_client::id) .get_result::<i32>(conn)
.get_result::<i32>(&mut conn) .await?;
.await
.map_err(|e| {
error!(error = ?e, "Failed to insert client metadata");
Error::DatabaseOperationFailed
})?;
Ok(client_id) let client_id = insert_into(program_client::table)
.values((
program_client::public_key.eq(pubkey.to_bytes()),
program_client::metadata_id.eq(metadata_id),
program_client::nonce.eq(NONCE_START),
))
.on_conflict_do_nothing()
.returning(program_client::id)
.get_result::<i32>(conn)
.await?;
integrity::sign_entity(
conn,
&keyholder,
&ClientCredentials {
pubkey: pubkey.clone(),
nonce: NONCE_START,
},
client_id,
)
.await
.map_err(|e| {
error!(error = ?e, "Failed to sign integrity tag for new client key");
ClientAuthError::DatabaseOperationFailed
})?;
Ok(client_id)
})
})
.await
} }
async fn sync_client_metadata( async fn sync_client_metadata(
db: &db::DatabasePool, db: &db::DatabasePool,
client_id: i32, client_id: i32,
metadata: &ClientMetadata, metadata: &ClientMetadata,
) -> Result<(), Error> { ) -> Result<(), ClientAuthError> {
use crate::db::schema::{client_metadata, client_metadata_history}; use crate::db::schema::{client_metadata, client_metadata_history};
let now = SqliteTimestamp(Utc::now()); let now = SqliteTimestamp(Utc::now());
let mut conn = db.get().await.map_err(|e| { let mut conn = db.get().await.map_err(|e| {
error!(error = ?e, "Database pool error"); error!(error = ?e, "Database pool error");
Error::DatabasePoolUnavailable ClientAuthError::DatabasePoolUnavailable
})?; })?;
conn.exclusive_transaction(|conn| { conn.exclusive_transaction(|conn| {
@@ -247,83 +379,84 @@ async fn sync_client_metadata(
.await .await
.map_err(|e| { .map_err(|e| {
error!(error = ?e, "Database error"); error!(error = ?e, "Database error");
Error::DatabaseOperationFailed ClientAuthError::DatabaseOperationFailed
}) })
} }
async fn challenge_client<T>( async fn challenge_client<T>(
transport: &mut T, transport: &mut T,
pubkey: VerifyingKey, pubkey: authn::PublicKey,
nonce: i32, nonce: i32,
) -> Result<(), Error> ) -> Result<(), ClientAuthError>
where where
T: Bi<Inbound, Result<Outbound, Error>> + ?Sized, T: Bi<Inbound, Result<Outbound, ClientAuthError>> + ?Sized,
{ {
transport transport
.send(Ok(Outbound::AuthChallenge { pubkey, nonce })) .send(Ok(Outbound::AuthChallenge {
pubkey: pubkey.clone(),
nonce,
}))
.await .await
.map_err(|e| { .map_err(|e| {
error!(error = ?e, "Failed to send auth challenge"); error!(error = ?e, "Failed to send auth challenge");
Error::Transport ClientAuthError::Transport
})?; })?;
let signature = expect_message(transport, |req: Inbound| match req { let signature = expect_message(transport, |req: Inbound| match req {
Inbound::AuthChallengeSolution { signature } => Some(signature), Inbound::AuthChallengeSolution { signature } => Some(signature),
_ => None, Inbound::AuthChallengeRequest { .. } => None,
}) })
.await .await
.map_err(|e| { .map_err(|e| {
error!(error = ?e, "Failed to receive challenge solution"); error!(error = ?e, "Failed to receive challenge solution");
Error::Transport ClientAuthError::Transport
})?; })?;
let formatted = format_challenge(nonce, pubkey.as_bytes()); if !pubkey.verify(nonce, CLIENT_CONTEXT, &signature) {
pubkey.verify_strict(&formatted, &signature).map_err(|_| {
error!("Challenge solution verification failed"); error!("Challenge solution verification failed");
Error::InvalidChallengeSolution return Err(ClientAuthError::InvalidChallengeSolution);
})?; }
Ok(()) Ok(())
} }
pub async fn authenticate<T>(props: &mut ClientConnection, transport: &mut T) -> Result<i32, Error> pub async fn authenticate<T>(
props: &mut ClientConnection,
transport: &mut T,
) -> Result<i32, ClientAuthError>
where where
T: Bi<Inbound, Result<Outbound, Error>> + Send + ?Sized, T: Bi<Inbound, Result<Outbound, ClientAuthError>> + Send + ?Sized,
{ {
let Some(Inbound::AuthChallengeRequest { pubkey, metadata }) = transport.recv().await else { let Some(Inbound::AuthChallengeRequest { pubkey, metadata }) = transport.recv().await else {
return Err(Error::Transport); return Err(ClientAuthError::Transport);
}; };
let info = match get_client_and_nonce(&props.db, &pubkey).await? { let client_id = if let Some((id, _)) = get_current_nonce_and_id(&props.db, &pubkey).await? {
Some(nonce) => nonce, verify_integrity(&props.db, &props.actors.key_holder, &pubkey).await?;
None => { id
approve_new_client( } else {
&props.actors, approve_new_client(
ClientProfile { &props.actors,
pubkey, ClientProfile {
metadata: metadata.clone(), pubkey: pubkey.clone(),
}, metadata: metadata.clone(),
) },
.await?; )
let client_id = insert_client(&props.db, &pubkey, &metadata).await?; .await?;
ClientInfo { insert_client(&props.db, &props.actors.key_holder, &pubkey, &metadata).await?
id: client_id,
current_nonce: 0,
}
}
}; };
sync_client_metadata(&props.db, info.id, &metadata).await?; sync_client_metadata(&props.db, client_id, &metadata).await?;
challenge_client(transport, pubkey, info.current_nonce).await?; let challenge_nonce = create_nonce(&props.db, &props.actors.key_holder, &pubkey).await?;
challenge_client(transport, pubkey, challenge_nonce).await?;
transport transport
.send(Ok(Outbound::AuthSuccess)) .send(Ok(Outbound::AuthSuccess))
.await .await
.map_err(|e| { .map_err(|e| {
error!(error = ?e, "Failed to send auth success"); error!(error = ?e, "Failed to send auth success");
Error::Transport ClientAuthError::Transport
})?; })?;
Ok(info.id) Ok(client_id)
} }

View File

@@ -1,25 +1,37 @@
use arbiter_crypto::authn;
use arbiter_proto::{ClientMetadata, transport::Bi}; use arbiter_proto::{ClientMetadata, transport::Bi};
use kameo::actor::Spawn; use kameo::actor::Spawn;
use tracing::{error, info}; use tracing::{error, info};
use crate::{ use crate::{
actors::{GlobalActors, client::session::ClientSession}, actors::{GlobalActors, client::session::ClientSession},
crypto::integrity::Integrable,
db, db,
}; };
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ClientProfile { pub struct ClientProfile {
pub pubkey: ed25519_dalek::VerifyingKey, pub pubkey: authn::PublicKey,
pub metadata: ClientMetadata, pub metadata: ClientMetadata,
} }
#[derive(arbiter_macros::Hashable)]
pub struct ClientCredentials {
pub pubkey: authn::PublicKey,
pub nonce: i32,
}
impl Integrable for ClientCredentials {
const KIND: &'static str = "client_credentials";
}
pub struct ClientConnection { pub struct ClientConnection {
pub(crate) db: db::DatabasePool, pub(crate) db: db::DatabasePool,
pub(crate) actors: GlobalActors, pub(crate) actors: GlobalActors,
} }
impl ClientConnection { impl ClientConnection {
pub fn new(db: db::DatabasePool, actors: GlobalActors) -> Self { pub const fn new(db: db::DatabasePool, actors: GlobalActors) -> Self {
Self { db, actors } Self { db, actors }
} }
} }
@@ -29,9 +41,11 @@ pub mod session;
pub async fn connect_client<T>(mut props: ClientConnection, transport: &mut T) pub async fn connect_client<T>(mut props: ClientConnection, transport: &mut T)
where where
T: Bi<auth::Inbound, Result<auth::Outbound, auth::Error>> + Send + ?Sized, T: Bi<auth::Inbound, Result<auth::Outbound, auth::ClientAuthError>> + Send + ?Sized,
{ {
match auth::authenticate(&mut props, transport).await { let fut = auth::authenticate(&mut props, transport);
println!("authenticate future size: {}", size_of_val(&fut));
match fut.await {
Ok(client_id) => { Ok(client_id) => {
ClientSession::spawn(ClientSession::new(props, client_id)); ClientSession::spawn(ClientSession::new(props, client_id));
info!("Client authenticated, session started"); info!("Client authenticated, session started");

View File

@@ -21,7 +21,7 @@ pub struct ClientSession {
} }
impl ClientSession { impl ClientSession {
pub(crate) fn new(props: ClientConnection, client_id: i32) -> Self { pub(crate) const fn new(props: ClientConnection, client_id: i32) -> Self {
Self { props, client_id } Self { props, client_id }
} }
} }
@@ -29,14 +29,16 @@ impl ClientSession {
#[messages] #[messages]
impl ClientSession { impl ClientSession {
#[message] #[message]
pub(crate) async fn handle_query_vault_state(&mut self) -> Result<KeyHolderState, Error> { pub(crate) async fn handle_query_vault_state(
&mut self,
) -> Result<KeyHolderState, ClientSessionError> {
use crate::actors::keyholder::GetState; use crate::actors::keyholder::GetState;
let vault_state = match self.props.actors.key_holder.ask(GetState {}).await { let vault_state = match self.props.actors.key_holder.ask(GetState {}).await {
Ok(state) => state, Ok(state) => state,
Err(err) => { Err(err) => {
error!(?err, actor = "client", "keyholder.query.failed"); error!(?err, actor = "client", "keyholder.query.failed");
return Err(Error::Internal); return Err(ClientSessionError::Internal);
} }
}; };
@@ -75,7 +77,7 @@ impl ClientSession {
impl Actor for ClientSession { impl Actor for ClientSession {
type Args = Self; type Args = Self;
type Error = Error; type Error = ClientSessionError;
async fn on_start( async fn on_start(
args: Self::Args, args: Self::Args,
@@ -86,13 +88,13 @@ impl Actor for ClientSession {
.flow_coordinator .flow_coordinator
.ask(RegisterClient { actor: this }) .ask(RegisterClient { actor: this })
.await .await
.map_err(|_| Error::ConnectionRegistrationFailed)?; .map_err(|_| ClientSessionError::ConnectionRegistrationFailed)?;
Ok(args) Ok(args)
} }
} }
impl ClientSession { impl ClientSession {
pub fn new_test(db: db::DatabasePool, actors: GlobalActors) -> Self { pub const fn new_test(db: db::DatabasePool, actors: GlobalActors) -> Self {
let props = ClientConnection::new(db, actors); let props = ClientConnection::new(db, actors);
Self { Self {
props, props,
@@ -102,7 +104,7 @@ impl ClientSession {
} }
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum Error { pub enum ClientSessionError {
#[error("Connection registration failed")] #[error("Connection registration failed")]
ConnectionRegistrationFailed, ConnectionRegistrationFailed,
#[error("Internal error")] #[error("Internal error")]
@@ -111,9 +113,9 @@ pub enum Error {
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum SignTransactionRpcError { pub enum SignTransactionRpcError {
#[error("Policy evaluation failed")]
Vet(#[from] VetError),
#[error("Internal error")] #[error("Internal error")]
Internal, Internal,
#[error("Policy evaluation failed")]
Vet(#[from] VetError),
} }

View File

@@ -1,4 +1,6 @@
use alloy::{consensus::TxEip1559, primitives::Address, signers::Signature}; use alloy::{
consensus::TxEip1559, network::TxSignerSync as _, primitives::Address, signers::Signature,
};
use diesel::{ use diesel::{
ExpressionMethods, OptionalExtension as _, QueryDsl, SelectableHelper as _, dsl::insert_into, ExpressionMethods, OptionalExtension as _, QueryDsl, SelectableHelper as _, dsl::insert_into,
}; };
@@ -8,20 +10,21 @@ use rand::{SeedableRng, rng, rngs::StdRng};
use crate::{ use crate::{
actors::keyholder::{CreateNew, Decrypt, KeyHolder}, actors::keyholder::{CreateNew, Decrypt, KeyHolder},
crypto::integrity,
db::{ db::{
DatabaseError, DatabasePool, DatabaseError, DatabasePool,
models::{self, SqliteTimestamp}, models::{self},
schema, schema,
}, },
evm::{ evm::{
self, RunKind, self, ListError, RunKind,
policies::{ policies::{
FullGrant, Grant, SharedGrantSettings, SpecificGrant, SpecificMeaning, CombinedSettings, Grant, SharedGrantSettings, SpecificGrant, SpecificMeaning,
ether_transfer::EtherTransfer, token_transfers::TokenTransfer, ether_transfer::EtherTransfer, token_transfers::TokenTransfer,
}, },
}, },
safe_cell::{SafeCell, SafeCellHandle as _},
}; };
use arbiter_crypto::safecell::{SafeCell, SafeCellHandle as _};
pub use crate::evm::safe_signer; pub use crate::evm::safe_signer;
@@ -34,7 +37,7 @@ pub enum SignTransactionError {
Database(#[from] DatabaseError), Database(#[from] DatabaseError),
#[error("Keyholder error: {0}")] #[error("Keyholder error: {0}")]
Keyholder(#[from] crate::actors::keyholder::Error), Keyholder(#[from] crate::actors::keyholder::KeyHolderError),
#[error("Keyholder mailbox error")] #[error("Keyholder mailbox error")]
KeyholderSend, KeyholderSend,
@@ -47,15 +50,18 @@ pub enum SignTransactionError {
} }
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum Error { pub enum EvmActorError {
#[error("Keyholder error: {0}")] #[error("Keyholder error: {0}")]
Keyholder(#[from] crate::actors::keyholder::Error), Keyholder(#[from] crate::actors::keyholder::KeyHolderError),
#[error("Keyholder mailbox error")] #[error("Keyholder mailbox error")]
KeyholderSend, KeyholderSend,
#[error("Database error: {0}")] #[error("Database error: {0}")]
Database(#[from] DatabaseError), Database(#[from] DatabaseError),
#[error("Integrity violation: {0}")]
Integrity(#[from] integrity::IntegrityError),
} }
#[derive(Actor)] #[derive(Actor)]
@@ -71,7 +77,7 @@ impl EvmActor {
// is it safe to seed rng from system once? // is it safe to seed rng from system once?
// todo: audit // todo: audit
let rng = StdRng::from_rng(&mut rng()); let rng = StdRng::from_rng(&mut rng());
let engine = evm::Engine::new(db.clone()); let engine = evm::Engine::new(db.clone(), keyholder.clone());
Self { Self {
keyholder, keyholder,
db, db,
@@ -84,7 +90,7 @@ impl EvmActor {
#[messages] #[messages]
impl EvmActor { impl EvmActor {
#[message] #[message]
pub async fn generate(&mut self) -> Result<(i32, Address), Error> { pub async fn generate(&mut self) -> Result<(i32, Address), EvmActorError> {
let (mut key_cell, address) = safe_signer::generate(&mut self.rng); let (mut key_cell, address) = safe_signer::generate(&mut self.rng);
let plaintext = key_cell.read_inline(|reader| SafeCell::new(reader.to_vec())); let plaintext = key_cell.read_inline(|reader| SafeCell::new(reader.to_vec()));
@@ -93,7 +99,7 @@ impl EvmActor {
.keyholder .keyholder
.ask(CreateNew { plaintext }) .ask(CreateNew { plaintext })
.await .await
.map_err(|_| Error::KeyholderSend)?; .map_err(|_| EvmActorError::KeyholderSend)?;
let mut conn = self.db.get().await.map_err(DatabaseError::from)?; let mut conn = self.db.get().await.map_err(DatabaseError::from)?;
let wallet_id = insert_into(schema::evm_wallet::table) let wallet_id = insert_into(schema::evm_wallet::table)
@@ -110,7 +116,7 @@ impl EvmActor {
} }
#[message] #[message]
pub async fn list_wallets(&self) -> Result<Vec<(i32, Address)>, Error> { pub async fn list_wallets(&self) -> Result<Vec<(i32, Address)>, EvmActorError> {
let mut conn = self.db.get().await.map_err(DatabaseError::from)?; let mut conn = self.db.get().await.map_err(DatabaseError::from)?;
let rows: Vec<models::EvmWallet> = schema::evm_wallet::table let rows: Vec<models::EvmWallet> = schema::evm_wallet::table
.select(models::EvmWallet::as_select()) .select(models::EvmWallet::as_select())
@@ -132,46 +138,64 @@ impl EvmActor {
&mut self, &mut self,
basic: SharedGrantSettings, basic: SharedGrantSettings,
grant: SpecificGrant, grant: SpecificGrant,
) -> Result<i32, DatabaseError> { ) -> Result<i32, EvmActorError> {
match grant { match grant {
SpecificGrant::EtherTransfer(settings) => { SpecificGrant::EtherTransfer(settings) => self
self.engine .engine
.create_grant::<EtherTransfer>(FullGrant { .create_grant::<EtherTransfer>(CombinedSettings {
basic, shared: basic,
specific: settings, specific: settings,
}) })
.await .await
} .map_err(EvmActorError::from),
SpecificGrant::TokenTransfer(settings) => { SpecificGrant::TokenTransfer(settings) => self
self.engine .engine
.create_grant::<TokenTransfer>(FullGrant { .create_grant::<TokenTransfer>(CombinedSettings {
basic, shared: basic,
specific: settings, specific: settings,
}) })
.await .await
} .map_err(EvmActorError::from),
} }
} }
#[message] #[message]
pub async fn useragent_delete_grant(&mut self, grant_id: i32) -> Result<(), Error> { #[expect(clippy::unused_async, reason = "reserved for impl")]
let mut conn = self.db.get().await.map_err(DatabaseError::from)?; pub async fn useragent_delete_grant(&mut self, _grant_id: i32) -> Result<(), EvmActorError> {
diesel::update(schema::evm_basic_grant::table) // let mut conn = self.db.get().await.map_err(DatabaseError::from)?;
.filter(schema::evm_basic_grant::id.eq(grant_id)) // let keyholder = self.keyholder.clone();
.set(schema::evm_basic_grant::revoked_at.eq(SqliteTimestamp::now()))
.execute(&mut conn) // diesel_async::AsyncConnection::transaction(&mut conn, |conn| {
.await // Box::pin(async move {
.map_err(DatabaseError::from)?; // diesel::update(schema::evm_basic_grant::table)
Ok(()) // .filter(schema::evm_basic_grant::id.eq(grant_id))
// .set(schema::evm_basic_grant::revoked_at.eq(SqliteTimestamp::now()))
// .execute(conn)
// .await?;
// let signed = integrity::evm::load_signed_grant_by_basic_id(conn, grant_id).await?;
// diesel::result::QueryResult::Ok(())
// })
// })
// .await
// .map_err(DatabaseError::from)?;
// Ok(())
todo!()
} }
#[message] #[message]
pub async fn useragent_list_grants(&mut self) -> Result<Vec<Grant<SpecificGrant>>, Error> { pub async fn useragent_list_grants(
Ok(self &mut self,
.engine ) -> Result<Vec<Grant<SpecificGrant>>, EvmActorError> {
.list_all_grants() match self.engine.list_all_grants().await {
.await Ok(grants) => Ok(grants),
.map_err(DatabaseError::from)?) Err(ListError::Database(db_err)) => Err(EvmActorError::Database(db_err)),
Err(ListError::Integrity(integrity_err)) => {
Err(EvmActorError::Integrity(integrity_err))
}
}
} }
#[message] #[message]
@@ -250,7 +274,6 @@ impl EvmActor {
.evaluate_transaction(wallet_access, transaction.clone(), RunKind::Execution) .evaluate_transaction(wallet_access, transaction.clone(), RunKind::Execution)
.await?; .await?;
use alloy::network::TxSignerSync as _;
Ok(signer.sign_transaction_sync(&mut transaction)?) Ok(signer.sign_transaction_sync(&mut transaction)?)
} }
} }

View File

@@ -41,7 +41,7 @@ impl Actor for ClientApprovalController {
async fn on_start( async fn on_start(
Args { Args {
client, client,
mut user_agents, user_agents,
reply, reply,
}: Self::Args, }: Self::Args,
actor_ref: ActorRef<Self>, actor_ref: ActorRef<Self>,
@@ -52,8 +52,9 @@ impl Actor for ClientApprovalController {
reply: Some(reply), reply: Some(reply),
}; };
for user_agent in user_agents.drain(..) { for user_agent in user_agents {
actor_ref.link(&user_agent).await; actor_ref.link(&user_agent).await;
let _ = user_agent let _ = user_agent
.tell(BeginNewClientApproval { .tell(BeginNewClientApproval {
client: client.clone(), client: client.clone(),
@@ -85,7 +86,7 @@ impl Actor for ClientApprovalController {
#[messages] #[messages]
impl ClientApprovalController { impl ClientApprovalController {
#[message(ctx)] #[message(ctx)]
pub async fn client_approval_answer(&mut self, approved: bool, ctx: &mut Context<Self, ()>) { pub fn client_approval_answer(&mut self, approved: bool, ctx: &mut Context<Self, ()>) {
if !approved { if !approved {
// Denial wins immediately regardless of other pending responses. // Denial wins immediately regardless of other pending responses.
self.send_reply(Ok(false)); self.send_reply(Ok(false));

View File

@@ -92,7 +92,7 @@ impl FlowCoordinator {
} }
#[message(ctx)] #[message(ctx)]
pub async fn request_client_approval( pub fn request_client_approval(
&mut self, &mut self,
client: ClientProfile, client: ClientProfile,
ctx: &mut Context<Self, DelegatedReply<Result<bool, ApprovalError>>>, ctx: &mut Context<Self, DelegatedReply<Result<bool, ApprovalError>>>,

View File

@@ -4,26 +4,22 @@ use diesel::{
dsl::{insert_into, update}, dsl::{insert_into, update},
}; };
use diesel_async::{AsyncConnection, RunQueryDsl}; use diesel_async::{AsyncConnection, RunQueryDsl};
use hmac::Mac as _;
use kameo::{Actor, Reply, messages}; use kameo::{Actor, Reply, messages};
use strum::{EnumDiscriminants, IntoDiscriminant}; use strum::{EnumDiscriminants, IntoDiscriminant};
use tracing::{error, info}; use tracing::{error, info};
use crate::{ use crate::crypto::{
crypto::{ KeyCell, derive_key,
KeyCell, derive_key, encryption::v1::{self, Nonce},
encryption::v1::{self, Nonce}, integrity::v1::HmacSha256,
integrity::v1::compute_integrity_tag,
},
safe_cell::SafeCell,
}; };
use crate::{ use crate::db::{
db::{ self,
self, models::{self, RootKeyHistory},
models::{self, RootKeyHistory}, schema::{self},
schema::{self},
},
safe_cell::SafeCellHandle as _,
}; };
use arbiter_crypto::safecell::{SafeCell, SafeCellHandle as _};
#[derive(Default, EnumDiscriminants)] #[derive(Default, EnumDiscriminants)]
#[strum_discriminants(derive(Reply), vis(pub), name(KeyHolderState))] #[strum_discriminants(derive(Reply), vis(pub), name(KeyHolderState))]
@@ -40,19 +36,12 @@ enum State {
} }
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum Error { pub enum KeyHolderError {
#[error("Keyholder is already bootstrapped")] #[error("Keyholder is already bootstrapped")]
AlreadyBootstrapped, AlreadyBootstrapped,
#[error("Keyholder is not bootstrapped")]
NotBootstrapped,
#[error("Invalid key provided")]
InvalidKey,
#[error("Requested aead entry not found")] #[error("Broken database")]
NotFound, BrokenDatabase,
#[error("Encryption error: {0}")]
Encryption(#[from] chacha20poly1305::aead::Error),
#[error("Database error: {0}")] #[error("Database error: {0}")]
DatabaseConnection(#[from] db::PoolError), DatabaseConnection(#[from] db::PoolError),
@@ -60,11 +49,21 @@ pub enum Error {
#[error("Database transaction error: {0}")] #[error("Database transaction error: {0}")]
DatabaseTransaction(#[from] diesel::result::Error), DatabaseTransaction(#[from] diesel::result::Error),
#[error("Broken database")] #[error("Encryption error: {0}")]
BrokenDatabase, Encryption(#[from] chacha20poly1305::aead::Error),
#[error("Invalid key provided")]
InvalidKey,
#[error("Keyholder is not bootstrapped")]
NotBootstrapped,
#[error("Requested aead entry not found")]
NotFound,
} }
/// Manages vault root key and tracks current state of the vault (bootstrapped/unbootstrapped, sealed/unsealed). /// Manages vault root key and tracks current state of the vault (bootstrapped/unbootstrapped, sealed/unsealed).
///
/// Provides API for encrypting and decrypting data using the vault root key. /// Provides API for encrypting and decrypting data using the vault root key.
/// Abstraction over database to make sure nonces are never reused and encryption keys are never exposed in plaintext outside of this actor. /// Abstraction over database to make sure nonces are never reused and encryption keys are never exposed in plaintext outside of this actor.
#[derive(Actor)] #[derive(Actor)]
@@ -75,7 +74,7 @@ pub struct KeyHolder {
#[messages] #[messages]
impl KeyHolder { impl KeyHolder {
pub async fn new(db: db::DatabasePool) -> Result<Self, Error> { pub async fn new(db: db::DatabasePool) -> Result<Self, KeyHolderError> {
let state = { let state = {
let mut conn = db.get().await?; let mut conn = db.get().await?;
@@ -98,7 +97,10 @@ impl KeyHolder {
// Exclusive transaction to avoid race condtions if multiple keyholders write // Exclusive transaction to avoid race condtions if multiple keyholders write
// additional layer of protection against nonce-reuse // additional layer of protection against nonce-reuse
async fn get_new_nonce(pool: &db::DatabasePool, root_key_id: i32) -> Result<Nonce, Error> { async fn get_new_nonce(
pool: &db::DatabasePool,
root_key_id: i32,
) -> Result<Nonce, KeyHolderError> {
let mut conn = pool.get().await?; let mut conn = pool.get().await?;
let nonce = conn let nonce = conn
@@ -110,12 +112,12 @@ impl KeyHolder {
.first(conn) .first(conn)
.await?; .await?;
let mut nonce = Nonce::try_from(current_nonce.as_slice()).map_err(|_| { let mut nonce = Nonce::try_from(current_nonce.as_slice()).map_err(|()| {
error!( error!(
"Broken database: invalid nonce for root key history id={}", "Broken database: invalid nonce for root key history id={}",
root_key_id root_key_id
); );
Error::BrokenDatabase KeyHolderError::BrokenDatabase
})?; })?;
nonce.increment(); nonce.increment();
@@ -125,7 +127,7 @@ impl KeyHolder {
.execute(conn) .execute(conn)
.await?; .await?;
Result::<_, Error>::Ok(nonce) Result::<_, KeyHolderError>::Ok(nonce)
}) })
}) })
.await?; .await?;
@@ -134,9 +136,12 @@ impl KeyHolder {
} }
#[message] #[message]
pub async fn bootstrap(&mut self, seal_key_raw: SafeCell<Vec<u8>>) -> Result<(), Error> { pub async fn bootstrap(
&mut self,
seal_key_raw: SafeCell<Vec<u8>>,
) -> Result<(), KeyHolderError> {
if !matches!(self.state, State::Unbootstrapped) { if !matches!(self.state, State::Unbootstrapped) {
return Err(Error::AlreadyBootstrapped); return Err(KeyHolderError::AlreadyBootstrapped);
} }
let salt = v1::generate_salt(); let salt = v1::generate_salt();
let mut seal_key = derive_key(seal_key_raw, &salt); let mut seal_key = derive_key(seal_key_raw, &salt);
@@ -152,7 +157,7 @@ impl KeyHolder {
.encrypt(&root_key_nonce, v1::ROOT_KEY_TAG, root_key_reader) .encrypt(&root_key_nonce, v1::ROOT_KEY_TAG, root_key_reader)
.map_err(|err| { .map_err(|err| {
error!(?err, "Fatal bootstrap error"); error!(?err, "Fatal bootstrap error");
Error::Encryption(err) KeyHolderError::Encryption(err)
}) })
})?; })?;
@@ -196,12 +201,15 @@ impl KeyHolder {
} }
#[message] #[message]
pub async fn try_unseal(&mut self, seal_key_raw: SafeCell<Vec<u8>>) -> Result<(), Error> { pub async fn try_unseal(
&mut self,
seal_key_raw: SafeCell<Vec<u8>>,
) -> Result<(), KeyHolderError> {
let State::Sealed { let State::Sealed {
root_key_history_id, root_key_history_id,
} = &self.state } = &self.state
else { else {
return Err(Error::NotBootstrapped); return Err(KeyHolderError::NotBootstrapped);
}; };
// We don't want to hold connection while doing expensive KDF work // We don't want to hold connection while doing expensive KDF work
@@ -217,16 +225,16 @@ impl KeyHolder {
let salt = &current_key.salt; let salt = &current_key.salt;
let salt = v1::Salt::try_from(salt.as_slice()).map_err(|_| { let salt = v1::Salt::try_from(salt.as_slice()).map_err(|_| {
error!("Broken database: invalid salt for root key"); error!("Broken database: invalid salt for root key");
Error::BrokenDatabase KeyHolderError::BrokenDatabase
})?; })?;
let mut seal_key = derive_key(seal_key_raw, &salt); let mut seal_key = derive_key(seal_key_raw, &salt);
let mut root_key = SafeCell::new(current_key.ciphertext.clone()); let mut root_key = SafeCell::new(current_key.ciphertext.clone());
let nonce = v1::Nonce::try_from(current_key.root_key_encryption_nonce.as_slice()).map_err( let nonce = Nonce::try_from(current_key.root_key_encryption_nonce.as_slice()).map_err(
|_| { |()| {
error!("Broken database: invalid nonce for root key"); error!("Broken database: invalid nonce for root key");
Error::BrokenDatabase KeyHolderError::BrokenDatabase
}, },
)?; )?;
@@ -234,14 +242,14 @@ impl KeyHolder {
.decrypt_in_place(&nonce, v1::ROOT_KEY_TAG, &mut root_key) .decrypt_in_place(&nonce, v1::ROOT_KEY_TAG, &mut root_key)
.map_err(|err| { .map_err(|err| {
error!(?err, "Failed to unseal root key: invalid seal key"); error!(?err, "Failed to unseal root key: invalid seal key");
Error::InvalidKey KeyHolderError::InvalidKey
})?; })?;
self.state = State::Unsealed { self.state = State::Unsealed {
root_key_history_id: current_key.id, root_key_history_id: current_key.id,
root_key: KeyCell::try_from(root_key).map_err(|err| { root_key: KeyCell::try_from(root_key).map_err(|err| {
error!(?err, "Broken database: invalid encryption key size"); error!(?err, "Broken database: invalid encryption key size");
Error::BrokenDatabase KeyHolderError::BrokenDatabase
})?, })?,
}; };
@@ -250,26 +258,10 @@ impl KeyHolder {
Ok(()) Ok(())
} }
// Signs a generic integrity payload using the vault-derived integrity key
#[message] #[message]
pub fn sign_integrity_tag( pub async fn decrypt(&mut self, aead_id: i32) -> Result<SafeCell<Vec<u8>>, KeyHolderError> {
&mut self,
purpose_tag: Vec<u8>,
data_parts: Vec<Vec<u8>>,
) -> Result<Vec<u8>, Error> {
let State::Unsealed { root_key, .. } = &mut self.state else { let State::Unsealed { root_key, .. } = &mut self.state else {
return Err(Error::NotBootstrapped); return Err(KeyHolderError::NotBootstrapped);
};
let tag =
compute_integrity_tag(root_key, &purpose_tag, data_parts.iter().map(Vec::as_slice));
Ok(tag.to_vec())
}
#[message]
pub async fn decrypt(&mut self, aead_id: i32) -> Result<SafeCell<Vec<u8>>, Error> {
let State::Unsealed { root_key, .. } = &mut self.state else {
return Err(Error::NotBootstrapped);
}; };
let row: models::AeadEncrypted = { let row: models::AeadEncrypted = {
@@ -280,15 +272,15 @@ impl KeyHolder {
.first(&mut conn) .first(&mut conn)
.await .await
.optional()? .optional()?
.ok_or(Error::NotFound)? .ok_or(KeyHolderError::NotFound)?
}; };
let nonce = v1::Nonce::try_from(row.current_nonce.as_slice()).map_err(|_| { let nonce = Nonce::try_from(row.current_nonce.as_slice()).map_err(|()| {
error!( error!(
"Broken database: invalid nonce for aead_encrypted id={}", "Broken database: invalid nonce for aead_encrypted id={}",
aead_id aead_id
); );
Error::BrokenDatabase KeyHolderError::BrokenDatabase
})?; })?;
let mut output = SafeCell::new(row.ciphertext); let mut output = SafeCell::new(row.ciphertext);
root_key.decrypt_in_place(&nonce, v1::TAG, &mut output)?; root_key.decrypt_in_place(&nonce, v1::TAG, &mut output)?;
@@ -297,14 +289,17 @@ impl KeyHolder {
// Creates new `aead_encrypted` entry in the database and returns it's ID // Creates new `aead_encrypted` entry in the database and returns it's ID
#[message] #[message]
pub async fn create_new(&mut self, mut plaintext: SafeCell<Vec<u8>>) -> Result<i32, Error> { pub async fn create_new(
&mut self,
mut plaintext: SafeCell<Vec<u8>>,
) -> Result<i32, KeyHolderError> {
let State::Unsealed { let State::Unsealed {
root_key, root_key,
root_key_history_id, root_key_history_id,
.. ..
} = &mut self.state } = &mut self.state
else { else {
return Err(Error::NotBootstrapped); return Err(KeyHolderError::NotBootstrapped);
}; };
// Order matters here - `get_new_nonce` acquires connection, so we need to call it before next acquire // Order matters here - `get_new_nonce` acquires connection, so we need to call it before next acquire
@@ -340,13 +335,63 @@ impl KeyHolder {
} }
#[message] #[message]
pub fn seal(&mut self) -> Result<(), Error> { pub fn sign_integrity(&mut self, mac_input: Vec<u8>) -> Result<(i32, Vec<u8>), KeyHolderError> {
let State::Unsealed {
root_key,
root_key_history_id,
} = &mut self.state
else {
return Err(KeyHolderError::NotBootstrapped);
};
let mut hmac = root_key.0.read_inline(|k| {
HmacSha256::new_from_slice(k)
.unwrap_or_else(|_| unreachable!("HMAC accepts keys of any size"))
});
hmac.update(&root_key_history_id.to_be_bytes());
hmac.update(&mac_input);
let mac = hmac.finalize().into_bytes().to_vec();
Ok((*root_key_history_id, mac))
}
#[message]
pub fn verify_integrity(
&mut self,
mac_input: Vec<u8>,
expected_mac: Vec<u8>,
key_version: i32,
) -> Result<bool, KeyHolderError> {
let State::Unsealed {
root_key,
root_key_history_id,
} = &mut self.state
else {
return Err(KeyHolderError::NotBootstrapped);
};
if *root_key_history_id != key_version {
return Ok(false);
}
let mut hmac = root_key.0.read_inline(|k| {
HmacSha256::new_from_slice(k)
.unwrap_or_else(|_| unreachable!("HMAC accepts keys of any size"))
});
hmac.update(&key_version.to_be_bytes());
hmac.update(&mac_input);
Ok(hmac.verify_slice(&expected_mac).is_ok())
}
#[message]
pub fn seal(&mut self) -> Result<(), KeyHolderError> {
let State::Unsealed { let State::Unsealed {
root_key_history_id, root_key_history_id,
.. ..
} = &self.state } = &self.state
else { else {
return Err(Error::NotBootstrapped); return Err(KeyHolderError::NotBootstrapped);
}; };
self.state = State::Sealed { self.state = State::Sealed {
root_key_history_id: *root_key_history_id, root_key_history_id: *root_key_history_id,
@@ -357,14 +402,7 @@ impl KeyHolder {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use diesel::SelectableHelper; use arbiter_crypto::safecell::SafeCellHandle as _;
use diesel_async::RunQueryDsl;
use crate::{
db::{self},
safe_cell::SafeCell,
};
use super::*; use super::*;
@@ -380,12 +418,12 @@ mod tests {
async fn nonce_monotonic_even_when_nonce_allocation_interleaves() { async fn nonce_monotonic_even_when_nonce_allocation_interleaves() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let mut actor = bootstrapped_actor(&db).await; let mut actor = bootstrapped_actor(&db).await;
let root_key_history_id = match actor.state { let State::Unsealed {
State::Unsealed { root_key_history_id,
root_key_history_id, ..
.. } = actor.state
} => root_key_history_id, else {
_ => panic!("expected unsealed state"), panic!("expected unsealed state");
}; };
let n1 = KeyHolder::get_new_nonce(&db, root_key_history_id) let n1 = KeyHolder::get_new_nonce(&db, root_key_history_id)
@@ -397,8 +435,8 @@ mod tests {
assert!(n2.to_vec() > n1.to_vec(), "nonce must increase"); assert!(n2.to_vec() > n1.to_vec(), "nonce must increase");
let mut conn = db.get().await.unwrap(); let mut conn = db.get().await.unwrap();
let root_row: models::RootKeyHistory = schema::root_key_history::table let root_row = schema::root_key_history::table
.select(models::RootKeyHistory::as_select()) .select(RootKeyHistory::as_select())
.first(&mut conn) .first(&mut conn)
.await .await
.unwrap(); .unwrap();

View File

@@ -11,18 +11,18 @@ use crate::{
pub mod bootstrap; pub mod bootstrap;
pub mod client; pub mod client;
mod evm; pub mod evm;
pub mod flow_coordinator; pub mod flow_coordinator;
pub mod keyholder; pub mod keyholder;
pub mod user_agent; pub mod user_agent;
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum SpawnError { pub enum GlobalActorsSpawnError {
#[error("Failed to spawn Bootstrapper actor")] #[error("Failed to spawn Bootstrapper actor")]
Bootstrapper(#[from] bootstrap::Error), Bootstrapper(#[from] bootstrap::BootstrappError),
#[error("Failed to spawn KeyHolder actor")] #[error("Failed to spawn KeyHolder actor")]
KeyHolder(#[from] keyholder::Error), KeyHolder(#[from] keyholder::KeyHolderError),
} }
/// Long-lived actors that are shared across all connections and handle global state and operations /// Long-lived actors that are shared across all connections and handle global state and operations
@@ -35,7 +35,7 @@ pub struct GlobalActors {
} }
impl GlobalActors { impl GlobalActors {
pub async fn spawn(db: db::DatabasePool) -> Result<Self, SpawnError> { pub async fn spawn(db: db::DatabasePool) -> Result<Self, GlobalActorsSpawnError> {
let key_holder = KeyHolder::spawn(KeyHolder::new(db.clone()).await?); let key_holder = KeyHolder::spawn(KeyHolder::new(db.clone()).await?);
Ok(Self { Ok(Self {
bootstrapper: Bootstrapper::spawn(Bootstrapper::new(&db).await?), bootstrapper: Bootstrapper::spawn(Bootstrapper::new(&db).await?),

View File

@@ -1,18 +1,20 @@
use arbiter_crypto::authn;
use arbiter_proto::transport::Bi; use arbiter_proto::transport::Bi;
use tracing::error; use tracing::error;
use crate::actors::user_agent::{ use crate::actors::user_agent::{
AuthPublicKey, UserAgentConnection, UserAgentConnection,
auth::state::{AuthContext, AuthStateMachine}, auth::state::{AuthContext, AuthStateMachine},
}; };
mod state; mod state;
use state::*; use state::{
AuthError, AuthEvents, AuthStates, BootstrapAuthRequest, ChallengeRequest, ChallengeSolution,
};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum Inbound { pub enum Inbound {
AuthChallengeRequest { AuthChallengeRequest {
pubkey: AuthPublicKey, pubkey: authn::PublicKey,
bootstrap_token: Option<String>, bootstrap_token: Option<String>,
}, },
AuthChallengeSolution { AuthChallengeSolution {
@@ -37,6 +39,13 @@ impl Error {
} }
} }
impl From<diesel::result::Error> for Error {
fn from(e: diesel::result::Error) -> Self {
error!(?e, "Database error");
Self::internal("Database error")
}
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum Outbound { pub enum Outbound {
AuthChallenge { nonce: i32 }, AuthChallenge { nonce: i32 },
@@ -64,7 +73,7 @@ fn parse_auth_event(payload: Inbound) -> AuthEvents {
pub async fn authenticate<T>( pub async fn authenticate<T>(
props: &mut UserAgentConnection, props: &mut UserAgentConnection,
transport: T, transport: T,
) -> Result<AuthPublicKey, Error> ) -> Result<authn::PublicKey, Error>
where where
T: Bi<Inbound, Result<Outbound, Error>> + Send, T: Bi<Inbound, Result<Outbound, Error>> + Send,
{ {

View File

@@ -1,39 +1,33 @@
use arbiter_crypto::authn::{self, USERAGENT_CONTEXT};
use arbiter_proto::transport::Bi; use arbiter_proto::transport::Bi;
use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, update}; use diesel::{ExpressionMethods as _, OptionalExtension as _, QueryDsl, update};
use diesel_async::RunQueryDsl; use diesel_async::{AsyncConnection, RunQueryDsl};
use kameo::error::SendError; use kameo::actor::ActorRef;
use tracing::error; use tracing::error;
use super::Error; use super::Error;
use crate::{ use crate::{
actors::{ actors::{
bootstrap::ConsumeToken, bootstrap::ConsumeToken,
keyholder::{self, SignIntegrityTag}, keyholder::KeyHolder,
user_agent::{AuthPublicKey, UserAgentConnection, auth::Outbound}, user_agent::{UserAgentConnection, UserAgentCredentials, auth::Outbound},
}, },
crypto::integrity::v1::USERAGENT_INTEGRITY_TAG, crypto::integrity,
db::schema, db::{DatabasePool, schema::useragent_client},
}; };
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AttestationStatus {
Attested,
NotAttested,
Unavailable,
}
pub struct ChallengeRequest { pub struct ChallengeRequest {
pub pubkey: AuthPublicKey, pub pubkey: authn::PublicKey,
} }
pub struct BootstrapAuthRequest { pub struct BootstrapAuthRequest {
pub pubkey: AuthPublicKey, pub pubkey: authn::PublicKey,
pub token: String, pub token: String,
} }
pub struct ChallengeContext { pub struct ChallengeContext {
pub challenge_nonce: i32, pub challenge_nonce: i32,
pub key: AuthPublicKey, pub key: authn::PublicKey,
} }
pub struct ChallengeSolution { pub struct ChallengeSolution {
@@ -45,16 +39,16 @@ smlang::statemachine!(
custom_error: true, custom_error: true,
transitions: { transitions: {
*Init + AuthRequest(ChallengeRequest) / async prepare_challenge = SentChallenge(ChallengeContext), *Init + AuthRequest(ChallengeRequest) / async prepare_challenge = SentChallenge(ChallengeContext),
Init + BootstrapAuthRequest(BootstrapAuthRequest) / async verify_bootstrap_token = AuthOk(AuthPublicKey), Init + BootstrapAuthRequest(BootstrapAuthRequest) / async verify_bootstrap_token = AuthOk(authn::PublicKey),
SentChallenge(ChallengeContext) + ReceivedSolution(ChallengeSolution) / async verify_solution = AuthOk(AuthPublicKey), SentChallenge(ChallengeContext) + ReceivedSolution(ChallengeSolution) / async verify_solution = AuthOk(authn::PublicKey),
} }
); );
async fn create_nonce( /// Returns the current nonce, ready to use for the challenge nonce.
db: &crate::db::DatabasePool, async fn get_current_nonce_and_id(
pubkey_bytes: &[u8], db: &DatabasePool,
key_type: crate::db::models::KeyType, key: &authn::PublicKey,
) -> Result<i32, Error> { ) -> Result<(i32, i32), Error> {
let mut db_conn = db.get().await.map_err(|e| { let mut db_conn = db.get().await.map_err(|e| {
error!(error = ?e, "Database pool error"); error!(error = ?e, "Database pool error");
Error::internal("Database unavailable") Error::internal("Database unavailable")
@@ -62,21 +56,11 @@ async fn create_nonce(
db_conn db_conn
.exclusive_transaction(|conn| { .exclusive_transaction(|conn| {
Box::pin(async move { Box::pin(async move {
let current_nonce = schema::useragent_client::table useragent_client::table
.filter(schema::useragent_client::public_key.eq(pubkey_bytes.to_vec())) .filter(useragent_client::public_key.eq(key.to_bytes()))
.filter(schema::useragent_client::key_type.eq(key_type)) .select((useragent_client::id, useragent_client::nonce))
.select(schema::useragent_client::nonce) .first::<(i32, i32)>(conn)
.first::<i32>(conn) .await
.await?;
update(schema::useragent_client::table)
.filter(schema::useragent_client::public_key.eq(pubkey_bytes.to_vec()))
.filter(schema::useragent_client::key_type.eq(key_type))
.set(schema::useragent_client::nonce.eq(current_nonce + 1))
.execute(conn)
.await?;
Result::<_, diesel::result::Error>::Ok(current_nonce)
}) })
}) })
.await .await
@@ -86,36 +70,130 @@ async fn create_nonce(
Error::internal("Database operation failed") Error::internal("Database operation failed")
})? })?
.ok_or_else(|| { .ok_or_else(|| {
error!(?pubkey_bytes, "Public key not found in database"); error!(?key, "Public key not found in database");
Error::UnregisteredPublicKey Error::UnregisteredPublicKey
}) })
} }
async fn register_key( async fn verify_integrity(
db: &crate::db::DatabasePool, db: &DatabasePool,
pubkey: &AuthPublicKey, keyholder: &ActorRef<KeyHolder>,
integrity_tag: Option<Vec<u8>>, pubkey: &authn::PublicKey,
) -> Result<(), Error> { ) -> Result<(), Error> {
let pubkey_bytes = pubkey.to_stored_bytes(); let mut db_conn = db.get().await.map_err(|e| {
let key_type = pubkey.key_type(); error!(error = ?e, "Database pool error");
Error::internal("Database unavailable")
})?;
let (id, nonce) = get_current_nonce_and_id(db, pubkey).await?;
let _result = integrity::verify_entity(
&mut db_conn,
keyholder,
&UserAgentCredentials {
pubkey: pubkey.clone(),
nonce,
},
id,
)
.await
.map_err(|e| {
error!(?e, "Integrity verification failed");
Error::internal("Integrity verification failed")
})?;
Ok(())
}
async fn create_nonce(
db: &DatabasePool,
keyholder: &ActorRef<KeyHolder>,
pubkey: &authn::PublicKey,
) -> Result<i32, Error> {
let mut db_conn = db.get().await.map_err(|e| {
error!(error = ?e, "Database pool error");
Error::internal("Database unavailable")
})?;
let new_nonce = db_conn
.exclusive_transaction(|conn| {
Box::pin(async move {
let (id, new_nonce): (i32, i32) = update(useragent_client::table)
.filter(useragent_client::public_key.eq(pubkey.to_bytes()))
.set(useragent_client::nonce.eq(useragent_client::nonce + 1))
.returning((useragent_client::id, useragent_client::nonce))
.get_result(conn)
.await
.map_err(|e| {
error!(error = ?e, "Database error");
Error::internal("Database operation failed")
})?;
integrity::sign_entity(
conn,
keyholder,
&UserAgentCredentials {
pubkey: pubkey.clone(),
nonce: new_nonce,
},
id,
)
.await
.map_err(|e| {
error!(?e, "Integrity signature update failed");
Error::internal("Database error")
})?;
Result::<_, Error>::Ok(new_nonce)
})
})
.await?;
Ok(new_nonce)
}
async fn register_key(
db: &DatabasePool,
keyholder: &ActorRef<KeyHolder>,
pubkey: &authn::PublicKey,
) -> Result<(), Error> {
let pubkey_bytes = pubkey.to_bytes();
let mut conn = db.get().await.map_err(|e| { let mut conn = db.get().await.map_err(|e| {
error!(error = ?e, "Database pool error"); error!(error = ?e, "Database pool error");
Error::internal("Database unavailable") Error::internal("Database unavailable")
})?; })?;
diesel::insert_into(schema::useragent_client::table) conn.transaction(|conn| {
.values(( Box::pin(async move {
schema::useragent_client::public_key.eq(pubkey_bytes), const NONCE_START: i32 = 1;
schema::useragent_client::nonce.eq(1),
schema::useragent_client::key_type.eq(key_type), let id: i32 = diesel::insert_into(useragent_client::table)
schema::useragent_client::pubkey_integrity_tag.eq(integrity_tag), .values((
)) useragent_client::public_key.eq(pubkey_bytes),
.execute(&mut conn) useragent_client::nonce.eq(NONCE_START),
.await ))
.map_err(|e| { .returning(useragent_client::id)
error!(error = ?e, "Database error"); .get_result(conn)
Error::internal("Database operation failed") .await
})?; .map_err(|e| {
error!(error = ?e, "Database error");
Error::internal("Database operation failed")
})?;
let entity = UserAgentCredentials {
pubkey: pubkey.clone(),
nonce: NONCE_START,
};
integrity::sign_entity(conn, keyholder, &entity, id)
.await
.map_err(|e| {
error!(error = ?e, "Failed to sign integrity tag for new user-agent key");
Error::internal("Failed to register public key")
})?;
Result::<_, Error>::Ok(())
})
})
.await?;
Ok(()) Ok(())
} }
@@ -126,14 +204,14 @@ pub struct AuthContext<'a, T> {
} }
impl<'a, T> AuthContext<'a, T> { impl<'a, T> AuthContext<'a, T> {
pub fn new(conn: &'a mut UserAgentConnection, transport: T) -> Self { pub const fn new(conn: &'a mut UserAgentConnection, transport: T) -> Self {
Self { conn, transport } Self { conn, transport }
} }
} }
impl<T> AuthStateMachineContext for AuthContext<'_, T> impl<T> AuthStateMachineContext for AuthContext<'_, T>
where where
T: Bi<super::Inbound, Result<super::Outbound, Error>> + Send, T: Bi<super::Inbound, Result<Outbound, Error>> + Send,
{ {
type Error = Error; type Error = Error;
@@ -141,15 +219,9 @@ where
&mut self, &mut self,
ChallengeRequest { pubkey }: ChallengeRequest, ChallengeRequest { pubkey }: ChallengeRequest,
) -> Result<ChallengeContext, Self::Error> { ) -> Result<ChallengeContext, Self::Error> {
match self.verify_pubkey_attestation_status(&pubkey).await? { verify_integrity(&self.conn.db, &self.conn.actors.key_holder, &pubkey).await?;
AttestationStatus::Attested | AttestationStatus::Unavailable => {}
AttestationStatus::NotAttested => {
return Err(Error::InvalidChallengeSolution);
}
}
let stored_bytes = pubkey.to_stored_bytes(); let nonce = create_nonce(&self.conn.db, &self.conn.actors.key_holder, &pubkey).await?;
let nonce = create_nonce(&self.conn.db, &stored_bytes, pubkey.key_type()).await?;
self.transport self.transport
.send(Ok(Outbound::AuthChallenge { nonce })) .send(Ok(Outbound::AuthChallenge { nonce }))
@@ -165,12 +237,10 @@ where
}) })
} }
#[allow(missing_docs)]
#[allow(clippy::result_unit_err)]
async fn verify_bootstrap_token( async fn verify_bootstrap_token(
&mut self, &mut self,
BootstrapAuthRequest { pubkey, token }: BootstrapAuthRequest, BootstrapAuthRequest { pubkey, token }: BootstrapAuthRequest,
) -> Result<AuthPublicKey, Self::Error> { ) -> Result<authn::PublicKey, Self::Error> {
let token_ok: bool = self let token_ok: bool = self
.conn .conn
.actors .actors
@@ -189,26 +259,23 @@ where
return Err(Error::InvalidBootstrapToken); return Err(Error::InvalidBootstrapToken);
} }
let integrity_tag = self if token_ok {
.try_sign_pubkey_integrity_tag(&pubkey) register_key(&self.conn.db, &self.conn.actors.key_holder, &pubkey).await?;
.await self.transport
.map_err(|err| { .send(Ok(Outbound::AuthSuccess))
error!(?err, "Failed to sign user-agent pubkey integrity tag"); .await
Error::internal("Failed to sign user-agent pubkey integrity tag") .map_err(|_| Error::Transport)?;
})?; Ok(pubkey)
} else {
register_key(&self.conn.db, &pubkey, integrity_tag).await?; error!("Invalid bootstrap token provided");
self.transport
self.transport .send(Err(Error::InvalidBootstrapToken))
.send(Ok(Outbound::AuthSuccess)) .await
.await .map_err(|_| Error::Transport)?;
.map_err(|_| Error::Transport)?; Err(Error::InvalidBootstrapToken)
}
Ok(pubkey)
} }
#[allow(missing_docs)]
#[allow(clippy::unused_unit)]
async fn verify_solution( async fn verify_solution(
&mut self, &mut self,
ChallengeContext { ChallengeContext {
@@ -216,141 +283,26 @@ where
key, key,
}: &ChallengeContext, }: &ChallengeContext,
ChallengeSolution { solution }: ChallengeSolution, ChallengeSolution { solution }: ChallengeSolution,
) -> Result<AuthPublicKey, Self::Error> { ) -> Result<authn::PublicKey, Self::Error> {
let formatted = arbiter_proto::format_challenge(*challenge_nonce, &key.to_stored_bytes()); let signature = authn::Signature::try_from(solution.as_slice()).map_err(|()| {
error!("Failed to decode signature in challenge solution");
Error::InvalidChallengeSolution
})?;
let valid = match key { let valid = key.verify(*challenge_nonce, USERAGENT_CONTEXT, &signature);
AuthPublicKey::Ed25519(vk) => {
let sig = solution.as_slice().try_into().map_err(|_| {
error!(?solution, "Invalid Ed25519 signature length");
Error::InvalidChallengeSolution
})?;
vk.verify_strict(&formatted, &sig).is_ok()
}
AuthPublicKey::EcdsaSecp256k1(vk) => {
use k256::ecdsa::signature::Verifier as _;
let sig = k256::ecdsa::Signature::try_from(solution.as_slice()).map_err(|_| {
error!(?solution, "Invalid ECDSA signature bytes");
Error::InvalidChallengeSolution
})?;
vk.verify(&formatted, &sig).is_ok()
}
AuthPublicKey::Rsa(pk) => {
use rsa::signature::Verifier as _;
let verifying_key = rsa::pss::VerifyingKey::<sha2::Sha256>::new(pk.clone());
let sig = rsa::pss::Signature::try_from(solution.as_slice()).map_err(|_| {
error!(?solution, "Invalid RSA signature bytes");
Error::InvalidChallengeSolution
})?;
verifying_key.verify(&formatted, &sig).is_ok()
}
};
match valid { if valid {
true => { self.transport
self.transport .send(Ok(Outbound::AuthSuccess))
.send(Ok(Outbound::AuthSuccess))
.await
.map_err(|_| Error::Transport)?;
Ok(key.clone())
}
false => {
self.transport
.send(Err(Error::InvalidChallengeSolution))
.await
.map_err(|_| Error::Transport)?;
Err(Error::InvalidChallengeSolution)
}
}
}
}
impl<T> AuthContext<'_, T>
where
T: Bi<super::Inbound, Result<super::Outbound, Error>> + Send,
{
async fn try_sign_pubkey_integrity_tag(
&self,
pubkey: &AuthPublicKey,
) -> Result<Option<Vec<u8>>, Error> {
let signed = self
.conn
.actors
.key_holder
.ask(SignIntegrityTag {
purpose_tag: USERAGENT_INTEGRITY_TAG.to_vec(),
data_parts: vec![
(pubkey.key_type() as i32).to_be_bytes().to_vec(),
pubkey.to_stored_bytes(),
],
})
.await;
match signed {
Ok(tag) => Ok(Some(tag)),
Err(SendError::HandlerError(keyholder::Error::NotBootstrapped)) => Ok(None),
Err(SendError::HandlerError(err)) => {
error!(
?err,
"Keyholder failed to sign user-agent pubkey integrity tag"
);
Err(Error::internal(
"Keyholder failed to sign user-agent pubkey integrity tag",
))
}
Err(err) => {
error!(
?err,
"Failed to contact keyholder for user-agent pubkey integrity tag"
);
Err(Error::internal(
"Failed to contact keyholder for user-agent pubkey integrity tag",
))
}
}
}
async fn verify_pubkey_attestation_status(
&self,
pubkey: &AuthPublicKey,
) -> Result<AttestationStatus, Error> {
let stored_tag: Option<Option<Vec<u8>>> = {
let mut conn = self.conn.db.get().await.map_err(|e| {
error!(error = ?e, "Database pool error");
Error::internal("Database unavailable")
})?;
schema::useragent_client::table
.filter(schema::useragent_client::public_key.eq(pubkey.to_stored_bytes()))
.filter(schema::useragent_client::key_type.eq(pubkey.key_type()))
.select(schema::useragent_client::pubkey_integrity_tag)
.first::<Option<Vec<u8>>>(&mut conn)
.await .await
.optional() .map_err(|_| Error::Transport)?;
.map_err(|e| { Ok(key.clone())
error!(error = ?e, "Database error"); } else {
Error::internal("Database operation failed") self.transport
})? .send(Err(Error::InvalidChallengeSolution))
}; .await
.map_err(|_| Error::Transport)?;
let Some(stored_tag) = stored_tag else { Err(Error::InvalidChallengeSolution)
return Err(Error::UnregisteredPublicKey);
};
let Some(expected_tag) = self.try_sign_pubkey_integrity_tag(pubkey).await? else {
return Ok(AttestationStatus::Unavailable);
};
match stored_tag {
Some(stored_tag) if stored_tag == expected_tag => Ok(AttestationStatus::Attested),
Some(_) => {
error!("User-agent pubkey integrity tag mismatch");
Ok(AttestationStatus::NotAttested)
}
None => {
error!("Missing pubkey integrity tag for registered key while vault is unsealed");
Ok(AttestationStatus::NotAttested)
}
} }
} }
} }

View File

@@ -1,79 +1,25 @@
use crate::{ use crate::{
actors::{GlobalActors, client::ClientProfile}, actors::{GlobalActors, client::ClientProfile},
db::{self, models::KeyType}, crypto::integrity::Integrable,
db,
}; };
use arbiter_crypto::authn;
/// Abstraction over Ed25519 / ECDSA-secp256k1 / RSA public keys used during the auth handshake. #[derive(Debug, arbiter_macros::Hashable)]
#[derive(Clone, Debug)] pub struct UserAgentCredentials {
pub enum AuthPublicKey { pub pubkey: authn::PublicKey,
Ed25519(ed25519_dalek::VerifyingKey), pub nonce: i32,
/// Compressed SEC1 public key; signature bytes are raw 64-byte (r||s).
EcdsaSecp256k1(k256::ecdsa::VerifyingKey),
/// RSA-2048+ public key (Windows Hello / KeyCredentialManager); signature bytes are PSS+SHA-256.
Rsa(rsa::RsaPublicKey),
} }
impl AuthPublicKey { impl Integrable for UserAgentCredentials {
/// Canonical bytes stored in DB and echoed back in the challenge. const KIND: &'static str = "useragent_credentials";
/// Ed25519: raw 32 bytes. ECDSA: SEC1 compressed 33 bytes. RSA: DER-encoded SPKI.
pub fn to_stored_bytes(&self) -> Vec<u8> {
match self {
AuthPublicKey::Ed25519(k) => k.to_bytes().to_vec(),
// SEC1 compressed (33 bytes) is the natural compact format for secp256k1
AuthPublicKey::EcdsaSecp256k1(k) => k.to_encoded_point(true).as_bytes().to_vec(),
AuthPublicKey::Rsa(k) => {
use rsa::pkcs8::EncodePublicKey as _;
#[allow(clippy::expect_used)]
k.to_public_key_der()
.expect("rsa SPKI encoding is infallible")
.to_vec()
}
}
}
pub fn key_type(&self) -> KeyType {
match self {
AuthPublicKey::Ed25519(_) => KeyType::Ed25519,
AuthPublicKey::EcdsaSecp256k1(_) => KeyType::EcdsaSecp256k1,
AuthPublicKey::Rsa(_) => KeyType::Rsa,
}
}
}
impl TryFrom<(KeyType, Vec<u8>)> for AuthPublicKey {
type Error = &'static str;
fn try_from(value: (KeyType, Vec<u8>)) -> Result<Self, Self::Error> {
let (key_type, bytes) = value;
match key_type {
KeyType::Ed25519 => {
let bytes: [u8; 32] = bytes.try_into().map_err(|_| "invalid Ed25519 key length")?;
let key = ed25519_dalek::VerifyingKey::from_bytes(&bytes)
.map_err(|_e| "invalid Ed25519 key")?;
Ok(AuthPublicKey::Ed25519(key))
}
KeyType::EcdsaSecp256k1 => {
let point =
k256::EncodedPoint::from_bytes(&bytes).map_err(|_e| "invalid ECDSA key")?;
let key = k256::ecdsa::VerifyingKey::from_encoded_point(&point)
.map_err(|_e| "invalid ECDSA key")?;
Ok(AuthPublicKey::EcdsaSecp256k1(key))
}
KeyType::Rsa => {
use rsa::pkcs8::DecodePublicKey as _;
let key = rsa::RsaPublicKey::from_public_key_der(&bytes)
.map_err(|_e| "invalid RSA key")?;
Ok(AuthPublicKey::Rsa(key))
}
}
}
} }
// Messages, sent by user agent to connection client without having a request // Messages, sent by user agent to connection client without having a request
#[derive(Debug)] #[derive(Debug)]
pub enum OutOfBand { pub enum OutOfBand {
ClientConnectionRequest { profile: ClientProfile }, ClientConnectionRequest { profile: ClientProfile },
ClientConnectionCancel { pubkey: ed25519_dalek::VerifyingKey }, ClientConnectionCancel { pubkey: authn::PublicKey },
} }
pub struct UserAgentConnection { pub struct UserAgentConnection {
@@ -82,7 +28,7 @@ pub struct UserAgentConnection {
} }
impl UserAgentConnection { impl UserAgentConnection {
pub fn new(db: db::DatabasePool, actors: GlobalActors) -> Self { pub const fn new(db: db::DatabasePool, actors: GlobalActors) -> Self {
Self { db, actors } Self { db, actors }
} }
} }

View File

@@ -1,8 +1,9 @@
use arbiter_crypto::authn;
use std::{borrow::Cow, collections::HashMap}; use std::{borrow::Cow, collections::HashMap};
use arbiter_proto::transport::Sender; use arbiter_proto::transport::Sender;
use async_trait::async_trait; use async_trait::async_trait;
use ed25519_dalek::VerifyingKey;
use kameo::{Actor, actor::ActorRef, messages}; use kameo::{Actor, actor::ActorRef, messages};
use thiserror::Error; use thiserror::Error;
use tracing::error; use tracing::error;
@@ -12,33 +13,32 @@ use crate::actors::{
flow_coordinator::{RegisterUserAgent, client_connect_approval::ClientApprovalController}, flow_coordinator::{RegisterUserAgent, client_connect_approval::ClientApprovalController},
user_agent::{OutOfBand, UserAgentConnection}, user_agent::{OutOfBand, UserAgentConnection},
}; };
mod state; mod state;
use state::{DummyContext, UserAgentEvents, UserAgentStateMachine}; use state::{DummyContext, UserAgentEvents, UserAgentStateMachine};
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum Error { pub enum UserAgentSessionError {
#[error("State transition failed")]
State,
#[error("Internal error: {message}")] #[error("Internal error: {message}")]
Internal { message: Cow<'static, str> }, Internal { message: Cow<'static, str> },
#[error("State transition failed")]
State,
} }
impl From<crate::db::PoolError> for Error { impl From<crate::db::PoolError> for UserAgentSessionError {
fn from(err: crate::db::PoolError) -> Self { fn from(err: crate::db::PoolError) -> Self {
error!(?err, "Database pool error"); error!(?err, "Database pool error");
Self::internal("Database pool error") Self::internal("Database pool error")
} }
} }
impl From<diesel::result::Error> for Error { impl From<diesel::result::Error> for UserAgentSessionError {
fn from(err: diesel::result::Error) -> Self { fn from(err: diesel::result::Error) -> Self {
error!(?err, "Database error"); error!(?err, "Database error");
Self::internal("Database error") Self::internal("Database error")
} }
} }
impl Error { impl UserAgentSessionError {
pub fn internal(message: impl Into<Cow<'static, str>>) -> Self { pub fn internal(message: impl Into<Cow<'static, str>>) -> Self {
Self::Internal { Self::Internal {
message: message.into(), message: message.into(),
@@ -47,6 +47,7 @@ impl Error {
} }
pub struct PendingClientApproval { pub struct PendingClientApproval {
pubkey: authn::PublicKey,
controller: ActorRef<ClientApprovalController>, controller: ActorRef<ClientApprovalController>,
} }
@@ -55,7 +56,7 @@ pub struct UserAgentSession {
state: UserAgentStateMachine<DummyContext>, state: UserAgentStateMachine<DummyContext>,
sender: Box<dyn Sender<OutOfBand>>, sender: Box<dyn Sender<OutOfBand>>,
pending_client_approvals: HashMap<VerifyingKey, PendingClientApproval>, pending_client_approvals: HashMap<Vec<u8>, PendingClientApproval>,
} }
pub mod connection; pub mod connection;
@@ -66,7 +67,7 @@ impl UserAgentSession {
props, props,
state: UserAgentStateMachine::new(DummyContext), state: UserAgentStateMachine::new(DummyContext),
sender, sender,
pending_client_approvals: Default::default(), pending_client_approvals: HashMap::default(),
} }
} }
@@ -86,10 +87,10 @@ impl UserAgentSession {
Self::new(UserAgentConnection::new(db, actors), Box::new(DummySender)) Self::new(UserAgentConnection::new(db, actors), Box::new(DummySender))
} }
fn transition(&mut self, event: UserAgentEvents) -> Result<(), Error> { fn transition(&mut self, event: UserAgentEvents) -> Result<(), UserAgentSessionError> {
self.state.process_event(event).map_err(|e| { self.state.process_event(event).map_err(|e| {
error!(?e, "State transition failed"); error!(?e, "State transition failed");
Error::State UserAgentSessionError::State
})?; })?;
Ok(()) Ok(())
} }
@@ -118,19 +119,24 @@ impl UserAgentSession {
return; return;
} }
self.pending_client_approvals self.pending_client_approvals.insert(
.insert(client.pubkey, PendingClientApproval { controller }); client.pubkey.to_bytes(),
PendingClientApproval {
pubkey: client.pubkey,
controller,
},
);
} }
} }
impl Actor for UserAgentSession { impl Actor for UserAgentSession {
type Args = Self; type Args = Self;
type Error = Error; type Error = UserAgentSessionError;
async fn on_start( async fn on_start(
args: Self::Args, args: Self::Args,
this: kameo::prelude::ActorRef<Self>, this: ActorRef<Self>,
) -> Result<Self, Self::Error> { ) -> Result<Self, Self::Error> {
args.props args.props
.actors .actors
@@ -144,7 +150,9 @@ impl Actor for UserAgentSession {
?err, ?err,
"Failed to register user agent connection with flow coordinator" "Failed to register user agent connection with flow coordinator"
); );
Error::internal("Failed to register user agent connection with flow coordinator") UserAgentSessionError::internal(
"Failed to register user agent connection with flow coordinator",
)
})?; })?;
Ok(args) Ok(args)
} }
@@ -158,14 +166,18 @@ impl Actor for UserAgentSession {
let cancelled_pubkey = self let cancelled_pubkey = self
.pending_client_approvals .pending_client_approvals
.iter() .iter()
.find_map(|(k, v)| (v.controller.id() == id).then_some(*k)); .find_map(|(k, v)| (v.controller.id() == id).then_some(k.clone()));
if let Some(pubkey) = cancelled_pubkey { if let Some(pubkey_bytes) = cancelled_pubkey {
self.pending_client_approvals.remove(&pubkey); let Some(approval) = self.pending_client_approvals.remove(&pubkey_bytes) else {
return Ok(std::ops::ControlFlow::Continue(()));
};
if let Err(e) = self if let Err(e) = self
.sender .sender
.send(OutOfBand::ClientConnectionCancel { pubkey }) .send(OutOfBand::ClientConnectionCancel {
pubkey: approval.pubkey,
})
.await .await
{ {
error!( error!(

View File

@@ -1,43 +1,46 @@
use std::sync::Mutex; use std::sync::Mutex;
use alloy::{consensus::TxEip1559, primitives::Address, signers::Signature}; use alloy::{consensus::TxEip1559, primitives::Address, signers::Signature};
use arbiter_crypto::{
authn,
safecell::{SafeCell, SafeCellHandle as _},
};
use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit}; use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit};
use diesel::{ExpressionMethods as _, QueryDsl as _, SelectableHelper}; use diesel::{ExpressionMethods as _, QueryDsl as _, SelectableHelper};
use diesel_async::{AsyncConnection, RunQueryDsl}; use diesel_async::{AsyncConnection, RunQueryDsl};
use kameo::error::SendError; use kameo::error::SendError;
use kameo::messages; use kameo::messages;
use kameo::prelude::Context; use kameo::prelude::Context;
use thiserror::Error;
use tracing::{error, info}; use tracing::{error, info};
use x25519_dalek::{EphemeralSecret, PublicKey}; use x25519_dalek::{EphemeralSecret, PublicKey};
use crate::actors::flow_coordinator::client_connect_approval::ClientApprovalAnswer; use crate::{actors::flow_coordinator::client_connect_approval::ClientApprovalAnswer, evm::policies::SharedGrantSettings};
use crate::actors::keyholder::KeyHolderState; use crate::actors::keyholder::KeyHolderState;
use crate::actors::user_agent::session::Error; use crate::actors::user_agent::session::UserAgentSessionError;
use crate::actors::{
evm::{
ClientSignTransaction, Generate, ListWallets, SignTransactionError as EvmSignError,
UseragentCreateGrant, UseragentListGrants,
},
keyholder::{self, Bootstrap, TryUnseal},
user_agent::session::{
UserAgentSession,
state::{UnsealContext, UserAgentEvents, UserAgentStates},
},
};
use crate::db::models::{ use crate::db::models::{
EvmWalletAccess, NewEvmWalletAccess, ProgramClient, ProgramClientMetadata, EvmWalletAccess, NewEvmWalletAccess, ProgramClient, ProgramClientMetadata,
}; };
use crate::evm::policies::{Grant, SpecificGrant}; use crate::evm::policies::{Grant, SpecificGrant};
use crate::safe_cell::SafeCell;
use crate::{
actors::{
evm::{
ClientSignTransaction, Generate, ListWallets, SignTransactionError as EvmSignError,
UseragentCreateGrant, UseragentDeleteGrant, UseragentListGrants,
},
keyholder::{self, Bootstrap, TryUnseal},
user_agent::session::{
UserAgentSession,
state::{UnsealContext, UserAgentEvents, UserAgentStates},
},
},
safe_cell::SafeCellHandle as _,
};
impl UserAgentSession { impl UserAgentSession {
fn take_unseal_secret(&mut self) -> Result<(EphemeralSecret, PublicKey), Error> { fn take_unseal_secret(&self) -> Result<(EphemeralSecret, PublicKey), UserAgentSessionError> {
let UserAgentStates::WaitingForUnsealKey(unseal_context) = self.state.state() else { let UserAgentStates::WaitingForUnsealKey(unseal_context) = self.state.state() else {
error!("Received encrypted key in invalid state"); error!("Received encrypted key in invalid state");
return Err(Error::internal("Invalid state for unseal encrypted key")); return Err(UserAgentSessionError::internal(
"Invalid state for unseal encrypted key",
));
}; };
let ephemeral_secret = { let ephemeral_secret = {
@@ -47,13 +50,14 @@ impl UserAgentSession {
)] )]
let mut secret_lock = unseal_context.secret.lock().unwrap(); let mut secret_lock = unseal_context.secret.lock().unwrap();
let secret = secret_lock.take(); let secret = secret_lock.take();
match secret { if let Some(secret) = secret {
Some(secret) => secret, secret
None => { } else {
drop(secret_lock); drop(secret_lock);
error!("Ephemeral secret already taken"); error!("Ephemeral secret already taken");
return Err(Error::internal("Ephemeral secret already taken")); return Err(UserAgentSessionError::internal(
} "Ephemeral secret already taken",
));
} }
}; };
@@ -79,7 +83,7 @@ impl UserAgentSession {
}); });
match decryption_result { match decryption_result {
Ok(_) => Ok(key_buffer), Ok(()) => Ok(key_buffer),
Err(err) => { Err(err) => {
error!(?err, "Failed to decrypt encrypted key material"); error!(?err, "Failed to decrypt encrypted key material");
Err(()) Err(())
@@ -97,7 +101,7 @@ pub enum UnsealError {
#[error("Invalid key provided for unsealing")] #[error("Invalid key provided for unsealing")]
InvalidKey, InvalidKey,
#[error("Internal error during unsealing process")] #[error("Internal error during unsealing process")]
General(#[from] super::Error), General(#[from] UserAgentSessionError),
} }
#[derive(Debug, Error)] #[derive(Debug, Error)]
@@ -108,7 +112,7 @@ pub enum BootstrapError {
AlreadyBootstrapped, AlreadyBootstrapped,
#[error("Internal error during bootstrapping process")] #[error("Internal error during bootstrapping process")]
General(#[from] super::Error), General(#[from] UserAgentSessionError),
} }
#[derive(Debug, Error)] #[derive(Debug, Error)]
@@ -120,19 +124,28 @@ pub enum SignTransactionError {
Internal, Internal,
} }
#[derive(Debug, Error)]
pub enum GrantMutationError {
#[error("Vault is sealed")]
VaultSealed,
#[error("Internal grant mutation error")]
Internal,
}
#[messages] #[messages]
impl UserAgentSession { impl UserAgentSession {
#[message] #[message]
pub async fn handle_unseal_request( pub fn handle_unseal_request(
&mut self, &mut self,
client_pubkey: x25519_dalek::PublicKey, client_pubkey: PublicKey,
) -> Result<UnsealStartResponse, Error> { ) -> Result<UnsealStartResponse, UserAgentSessionError> {
let secret = EphemeralSecret::random(); let secret = EphemeralSecret::random();
let public_key = PublicKey::from(&secret); let public_key = PublicKey::from(&secret);
self.transition(UserAgentEvents::UnsealRequest(UnsealContext { self.transition(UserAgentEvents::UnsealRequest(UnsealContext {
secret: Mutex::new(Some(secret)),
client_public_key: client_pubkey, client_public_key: client_pubkey,
secret: Mutex::new(Some(secret)),
}))?; }))?;
Ok(UnsealStartResponse { Ok(UnsealStartResponse {
@@ -149,27 +162,24 @@ impl UserAgentSession {
) -> Result<(), UnsealError> { ) -> Result<(), UnsealError> {
let (ephemeral_secret, client_public_key) = match self.take_unseal_secret() { let (ephemeral_secret, client_public_key) = match self.take_unseal_secret() {
Ok(values) => values, Ok(values) => values,
Err(Error::State) => { Err(UserAgentSessionError::State) => {
self.transition(UserAgentEvents::ReceivedInvalidKey)?; self.transition(UserAgentEvents::ReceivedInvalidKey)?;
return Err(UnsealError::InvalidKey); return Err(UnsealError::InvalidKey);
} }
Err(_err) => { Err(_err) => {
return Err(Error::internal("Failed to take unseal secret").into()); return Err(UserAgentSessionError::internal("Failed to take unseal secret").into());
} }
}; };
let seal_key_buffer = match Self::decrypt_client_key_material( let Ok(seal_key_buffer) = Self::decrypt_client_key_material(
ephemeral_secret, ephemeral_secret,
client_public_key, client_public_key,
&nonce, &nonce,
&ciphertext, &ciphertext,
&associated_data, &associated_data,
) { ) else {
Ok(buffer) => buffer, self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Err(()) => { return Err(UnsealError::InvalidKey);
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
return Err(UnsealError::InvalidKey);
}
}; };
match self match self
@@ -181,12 +191,12 @@ impl UserAgentSession {
}) })
.await .await
{ {
Ok(_) => { Ok(()) => {
info!("Successfully unsealed key with client-provided key"); info!("Successfully unsealed key with client-provided key");
self.transition(UserAgentEvents::ReceivedValidKey)?; self.transition(UserAgentEvents::ReceivedValidKey)?;
Ok(()) Ok(())
} }
Err(SendError::HandlerError(keyholder::Error::InvalidKey)) => { Err(SendError::HandlerError(keyholder::KeyHolderError::InvalidKey)) => {
self.transition(UserAgentEvents::ReceivedInvalidKey)?; self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Err(UnsealError::InvalidKey) Err(UnsealError::InvalidKey)
} }
@@ -198,7 +208,7 @@ impl UserAgentSession {
Err(err) => { Err(err) => {
error!(?err, "Failed to send unseal request to keyholder"); error!(?err, "Failed to send unseal request to keyholder");
self.transition(UserAgentEvents::ReceivedInvalidKey)?; self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Err(Error::internal("Vault actor error").into()) Err(UserAgentSessionError::internal("Vault actor error").into())
} }
} }
} }
@@ -212,25 +222,22 @@ impl UserAgentSession {
) -> Result<(), BootstrapError> { ) -> Result<(), BootstrapError> {
let (ephemeral_secret, client_public_key) = match self.take_unseal_secret() { let (ephemeral_secret, client_public_key) = match self.take_unseal_secret() {
Ok(values) => values, Ok(values) => values,
Err(Error::State) => { Err(UserAgentSessionError::State) => {
self.transition(UserAgentEvents::ReceivedInvalidKey)?; self.transition(UserAgentEvents::ReceivedInvalidKey)?;
return Err(BootstrapError::InvalidKey); return Err(BootstrapError::InvalidKey);
} }
Err(err) => return Err(err.into()), Err(err) => return Err(err.into()),
}; };
let seal_key_buffer = match Self::decrypt_client_key_material( let Ok(seal_key_buffer) = Self::decrypt_client_key_material(
ephemeral_secret, ephemeral_secret,
client_public_key, client_public_key,
&nonce, &nonce,
&ciphertext, &ciphertext,
&associated_data, &associated_data,
) { ) else {
Ok(buffer) => buffer, self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Err(()) => { return Err(BootstrapError::InvalidKey);
self.transition(UserAgentEvents::ReceivedInvalidKey)?;
return Err(BootstrapError::InvalidKey);
}
}; };
match self match self
@@ -242,12 +249,12 @@ impl UserAgentSession {
}) })
.await .await
{ {
Ok(_) => { Ok(()) => {
info!("Successfully bootstrapped vault with client-provided key"); info!("Successfully bootstrapped vault with client-provided key");
self.transition(UserAgentEvents::ReceivedValidKey)?; self.transition(UserAgentEvents::ReceivedValidKey)?;
Ok(()) Ok(())
} }
Err(SendError::HandlerError(keyholder::Error::AlreadyBootstrapped)) => { Err(SendError::HandlerError(keyholder::KeyHolderError::AlreadyBootstrapped)) => {
self.transition(UserAgentEvents::ReceivedInvalidKey)?; self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Err(BootstrapError::AlreadyBootstrapped) Err(BootstrapError::AlreadyBootstrapped)
} }
@@ -259,7 +266,7 @@ impl UserAgentSession {
Err(err) => { Err(err) => {
error!(?err, "Failed to send bootstrap request to keyholder"); error!(?err, "Failed to send bootstrap request to keyholder");
self.transition(UserAgentEvents::ReceivedInvalidKey)?; self.transition(UserAgentEvents::ReceivedInvalidKey)?;
Err(BootstrapError::General(Error::internal( Err(BootstrapError::General(UserAgentSessionError::internal(
"Vault actor error", "Vault actor error",
))) )))
} }
@@ -270,14 +277,16 @@ impl UserAgentSession {
#[messages] #[messages]
impl UserAgentSession { impl UserAgentSession {
#[message] #[message]
pub(crate) async fn handle_query_vault_state(&mut self) -> Result<KeyHolderState, Error> { pub(crate) async fn handle_query_vault_state(
&mut self,
) -> Result<KeyHolderState, UserAgentSessionError> {
use crate::actors::keyholder::GetState; use crate::actors::keyholder::GetState;
let vault_state = match self.props.actors.key_holder.ask(GetState {}).await { let vault_state = match self.props.actors.key_holder.ask(GetState {}).await {
Ok(state) => state, Ok(state) => state,
Err(err) => { Err(err) => {
error!(?err, actor = "useragent", "keyholder.query.failed"); error!(?err, actor = "useragent", "keyholder.query.failed");
return Err(Error::internal("Vault is in broken state")); return Err(UserAgentSessionError::internal("Vault is in broken state"));
} }
}; };
@@ -288,26 +297,32 @@ impl UserAgentSession {
#[messages] #[messages]
impl UserAgentSession { impl UserAgentSession {
#[message] #[message]
pub(crate) async fn handle_evm_wallet_create(&mut self) -> Result<(i32, Address), Error> { pub(crate) async fn handle_evm_wallet_create(
&mut self,
) -> Result<(i32, Address), UserAgentSessionError> {
match self.props.actors.evm.ask(Generate {}).await { match self.props.actors.evm.ask(Generate {}).await {
Ok(address) => Ok(address), Ok(address) => Ok(address),
Err(SendError::HandlerError(err)) => Err(Error::internal(format!( Err(SendError::HandlerError(err)) => Err(UserAgentSessionError::internal(format!(
"EVM wallet generation failed: {err}" "EVM wallet generation failed: {err}"
))), ))),
Err(err) => { Err(err) => {
error!(?err, "EVM actor unreachable during wallet create"); error!(?err, "EVM actor unreachable during wallet create");
Err(Error::internal("EVM actor unreachable")) Err(UserAgentSessionError::internal("EVM actor unreachable"))
} }
} }
} }
#[message] #[message]
pub(crate) async fn handle_evm_wallet_list(&mut self) -> Result<Vec<(i32, Address)>, Error> { pub(crate) async fn handle_evm_wallet_list(
&mut self,
) -> Result<Vec<(i32, Address)>, UserAgentSessionError> {
match self.props.actors.evm.ask(ListWallets {}).await { match self.props.actors.evm.ask(ListWallets {}).await {
Ok(wallets) => Ok(wallets), Ok(wallets) => Ok(wallets),
Err(err) => { Err(err) => {
error!(?err, "EVM wallet list failed"); error!(?err, "EVM wallet list failed");
Err(Error::internal("Failed to list EVM wallets")) Err(UserAgentSessionError::internal(
"Failed to list EVM wallets",
))
} }
} }
} }
@@ -316,12 +331,14 @@ impl UserAgentSession {
#[messages] #[messages]
impl UserAgentSession { impl UserAgentSession {
#[message] #[message]
pub(crate) async fn handle_grant_list(&mut self) -> Result<Vec<Grant<SpecificGrant>>, Error> { pub(crate) async fn handle_grant_list(
&mut self,
) -> Result<Vec<Grant<SpecificGrant>>, UserAgentSessionError> {
match self.props.actors.evm.ask(UseragentListGrants {}).await { match self.props.actors.evm.ask(UseragentListGrants {}).await {
Ok(grants) => Ok(grants), Ok(grants) => Ok(grants),
Err(err) => { Err(err) => {
error!(?err, "EVM grant list failed"); error!(?err, "EVM grant list failed");
Err(Error::internal("Failed to list EVM grants")) Err(UserAgentSessionError::internal("Failed to list EVM grants"))
} }
} }
} }
@@ -329,9 +346,9 @@ impl UserAgentSession {
#[message] #[message]
pub(crate) async fn handle_grant_create( pub(crate) async fn handle_grant_create(
&mut self, &mut self,
basic: crate::evm::policies::SharedGrantSettings, basic: SharedGrantSettings,
grant: crate::evm::policies::SpecificGrant, grant: SpecificGrant,
) -> Result<i32, Error> { ) -> Result<i32, GrantMutationError> {
match self match self
.props .props
.actors .actors
@@ -342,26 +359,32 @@ impl UserAgentSession {
Ok(grant_id) => Ok(grant_id), Ok(grant_id) => Ok(grant_id),
Err(err) => { Err(err) => {
error!(?err, "EVM grant create failed"); error!(?err, "EVM grant create failed");
Err(Error::internal("Failed to create EVM grant")) Err(GrantMutationError::Internal)
} }
} }
} }
#[message] #[message]
pub(crate) async fn handle_grant_delete(&mut self, grant_id: i32) -> Result<(), Error> { #[expect(clippy::unused_async, reason = "false positive")]
match self pub(crate) async fn handle_grant_delete(
.props &mut self,
.actors grant_id: i32,
.evm ) -> Result<(), GrantMutationError> {
.ask(UseragentDeleteGrant { grant_id }) // match self
.await // .props
{ // .actors
Ok(()) => Ok(()), // .evm
Err(err) => { // .ask(UseragentDeleteGrant { grant_id })
error!(?err, "EVM grant delete failed"); // .await
Err(Error::internal("Failed to delete EVM grant")) // {
} // Ok(()) => Ok(()),
} // Err(err) => {
// error!(?err, "EVM grant delete failed");
// Err(GrantMutationError::Internal)
// }
// }
let _ = grant_id;
todo!()
} }
#[message] #[message]
@@ -397,7 +420,7 @@ impl UserAgentSession {
pub(crate) async fn handle_grant_evm_wallet_access( pub(crate) async fn handle_grant_evm_wallet_access(
&mut self, &mut self,
entries: Vec<NewEvmWalletAccess>, entries: Vec<NewEvmWalletAccess>,
) -> Result<(), Error> { ) -> Result<(), UserAgentSessionError> {
let mut conn = self.props.db.get().await?; let mut conn = self.props.db.get().await?;
conn.transaction(|conn| { conn.transaction(|conn| {
Box::pin(async move { Box::pin(async move {
@@ -411,7 +434,7 @@ impl UserAgentSession {
.await?; .await?;
} }
Result::<_, Error>::Ok(()) Result::<_, UserAgentSessionError>::Ok(())
}) })
}) })
.await?; .await?;
@@ -422,7 +445,7 @@ impl UserAgentSession {
pub(crate) async fn handle_revoke_evm_wallet_access( pub(crate) async fn handle_revoke_evm_wallet_access(
&mut self, &mut self,
entries: Vec<i32>, entries: Vec<i32>,
) -> Result<(), Error> { ) -> Result<(), UserAgentSessionError> {
let mut conn = self.props.db.get().await?; let mut conn = self.props.db.get().await?;
conn.transaction(|conn| { conn.transaction(|conn| {
Box::pin(async move { Box::pin(async move {
@@ -434,7 +457,7 @@ impl UserAgentSession {
.await?; .await?;
} }
Result::<_, Error>::Ok(()) Result::<_, UserAgentSessionError>::Ok(())
}) })
}) })
.await?; .await?;
@@ -444,10 +467,9 @@ impl UserAgentSession {
#[message] #[message]
pub(crate) async fn handle_list_wallet_access( pub(crate) async fn handle_list_wallet_access(
&mut self, &mut self,
) -> Result<Vec<EvmWalletAccess>, Error> { ) -> Result<Vec<EvmWalletAccess>, UserAgentSessionError> {
let mut conn = self.props.db.get().await?; let mut conn = self.props.db.get().await?;
use crate::db::schema::evm_wallet_access; let access_entries = crate::db::schema::evm_wallet_access::table
let access_entries = evm_wallet_access::table
.select(EvmWalletAccess::as_select()) .select(EvmWalletAccess::as_select())
.load::<_>(&mut conn) .load::<_>(&mut conn)
.await?; .await?;
@@ -461,15 +483,15 @@ impl UserAgentSession {
pub(crate) async fn handle_new_client_approve( pub(crate) async fn handle_new_client_approve(
&mut self, &mut self,
approved: bool, approved: bool,
pubkey: ed25519_dalek::VerifyingKey, pubkey: authn::PublicKey,
ctx: &mut Context<Self, Result<(), Error>>, ctx: &Context<Self, Result<(), UserAgentSessionError>>,
) -> Result<(), Error> { ) -> Result<(), UserAgentSessionError> {
let pending_approval = match self.pending_client_approvals.remove(&pubkey) { let Some(pending_approval) = self.pending_client_approvals.remove(&pubkey.to_bytes())
Some(approval) => approval, else {
None => { error!("Received client connection response for unknown client");
error!("Received client connection response for unknown client"); return Err(UserAgentSessionError::internal(
return Err(Error::internal("Unknown client in connection response")); "Unknown client in connection response",
} ));
}; };
pending_approval pending_approval
@@ -481,7 +503,9 @@ impl UserAgentSession {
?err, ?err,
"Failed to send client approval response to controller" "Failed to send client approval response to controller"
); );
Error::internal("Failed to send client approval response to controller") UserAgentSessionError::internal(
"Failed to send client approval response to controller",
)
})?; })?;
ctx.actor_ref().unlink(&pending_approval.controller).await; ctx.actor_ref().unlink(&pending_approval.controller).await;
@@ -492,7 +516,7 @@ impl UserAgentSession {
#[message] #[message]
pub(crate) async fn handle_sdk_client_list( pub(crate) async fn handle_sdk_client_list(
&mut self, &mut self,
) -> Result<Vec<(ProgramClient, ProgramClientMetadata)>, Error> { ) -> Result<Vec<(ProgramClient, ProgramClientMetadata)>, UserAgentSessionError> {
use crate::db::schema::{client_metadata, program_client}; use crate::db::schema::{client_metadata, program_client};
let mut conn = self.props.db.get().await?; let mut conn = self.props.db.get().await?;

View File

@@ -19,8 +19,6 @@ smlang::statemachine!(
pub struct DummyContext; pub struct DummyContext;
impl UserAgentStateMachineContext for DummyContext { impl UserAgentStateMachineContext for DummyContext {
#[allow(missing_docs)]
#[allow(clippy::unused_unit)]
fn generate_temp_keypair(&mut self, event_data: UnsealContext) -> Result<UnsealContext, ()> { fn generate_temp_keypair(&mut self, event_data: UnsealContext) -> Result<UnsealContext, ()> {
Ok(event_data) Ok(event_data)
} }

View File

@@ -25,22 +25,22 @@ pub enum InitError {
Tls(#[from] tls::InitError), Tls(#[from] tls::InitError),
#[error("Actor spawn failed: {0}")] #[error("Actor spawn failed: {0}")]
ActorSpawn(#[from] crate::actors::SpawnError), ActorSpawn(#[from] crate::actors::GlobalActorsSpawnError),
#[error("I/O Error: {0}")] #[error("I/O Error: {0}")]
Io(#[from] std::io::Error), Io(#[from] std::io::Error),
} }
pub struct _ServerContextInner { pub struct __ServerContextInner {
pub db: db::DatabasePool, pub db: db::DatabasePool,
pub tls: TlsManager, pub tls: TlsManager,
pub actors: GlobalActors, pub actors: GlobalActors,
} }
#[derive(Clone)] #[derive(Clone)]
pub struct ServerContext(Arc<_ServerContextInner>); pub struct ServerContext(Arc<__ServerContextInner>);
impl std::ops::Deref for ServerContext { impl std::ops::Deref for ServerContext {
type Target = _ServerContextInner; type Target = __ServerContextInner;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
&self.0 &self.0
@@ -49,7 +49,7 @@ impl std::ops::Deref for ServerContext {
impl ServerContext { impl ServerContext {
pub async fn new(db: db::DatabasePool) -> Result<Self, InitError> { pub async fn new(db: db::DatabasePool) -> Result<Self, InitError> {
Ok(Self(Arc::new(_ServerContextInner { Ok(Self(Arc::new(__ServerContextInner {
actors: GlobalActors::spawn(db.clone()).await?, actors: GlobalActors::spawn(db.clone()).await?,
tls: TlsManager::new(db.clone()).await?, tls: TlsManager::new(db.clone()).await?,
db, db,

View File

@@ -22,9 +22,10 @@ use crate::db::{
}; };
const ENCODE_CONFIG: pem::EncodeConfig = { const ENCODE_CONFIG: pem::EncodeConfig = {
let line_ending = match cfg!(target_family = "windows") { let line_ending = if cfg!(target_family = "windows") {
true => pem::LineEnding::CRLF, pem::LineEnding::CRLF
false => pem::LineEnding::LF, } else {
pem::LineEnding::LF
}; };
pem::EncodeConfig::new().set_line_ending(line_ending) pem::EncodeConfig::new().set_line_ending(line_ending)
}; };
@@ -52,11 +53,14 @@ pub enum InitError {
pub type PemCert = String; pub type PemCert = String;
pub fn encode_cert_to_pem(cert: &CertificateDer) -> PemCert { pub fn encode_cert_to_pem(cert: &CertificateDer<'_>) -> PemCert {
pem::encode_config(&Pem::new("CERTIFICATE", cert.to_vec()), ENCODE_CONFIG) pem::encode_config(&Pem::new("CERTIFICATE", cert.to_vec()), ENCODE_CONFIG)
} }
#[allow(unused)] #[expect(
unused,
reason = "may be needed for future cert rotation implementation"
)]
struct SerializedTls { struct SerializedTls {
cert_pem: PemCert, cert_pem: PemCert,
cert_key_pem: String, cert_key_pem: String,
@@ -85,7 +89,7 @@ impl TlsCa {
let cert_key_pem = certified_issuer.key().serialize_pem(); let cert_key_pem = certified_issuer.key().serialize_pem();
#[allow( #[expect(
clippy::unwrap_used, clippy::unwrap_used,
reason = "Broken cert couldn't bootstrap server anyway" reason = "Broken cert couldn't bootstrap server anyway"
)] )]
@@ -124,7 +128,11 @@ impl TlsCa {
}) })
} }
#[allow(unused)] #[expect(
unused,
clippy::unnecessary_wraps,
reason = "may be needed for future cert rotation implementation"
)]
fn serialize(&self) -> Result<SerializedTls, InitError> { fn serialize(&self) -> Result<SerializedTls, InitError> {
let cert_key_pem = self.issuer.key().serialize_pem(); let cert_key_pem = self.issuer.key().serialize_pem();
Ok(SerializedTls { Ok(SerializedTls {
@@ -133,7 +141,10 @@ impl TlsCa {
}) })
} }
#[allow(unused)] #[expect(
unused,
reason = "may be needed for future cert rotation implementation"
)]
fn try_deserialize(cert_pem: &str, cert_key_pem: &str) -> Result<Self, InitError> { fn try_deserialize(cert_pem: &str, cert_key_pem: &str) -> Result<Self, InitError> {
let keypair = let keypair =
KeyPair::from_pem(cert_key_pem).map_err(InitError::KeyDeserializationError)?; KeyPair::from_pem(cert_key_pem).map_err(InitError::KeyDeserializationError)?;
@@ -234,10 +245,10 @@ impl TlsManager {
} }
} }
pub fn cert(&self) -> &CertificateDer<'static> { pub const fn cert(&self) -> &CertificateDer<'static> {
&self.cert &self.cert
} }
pub fn ca_cert(&self) -> &CertificateDer<'static> { pub const fn ca_cert(&self) -> &CertificateDer<'static> {
&self.ca_cert &self.ca_cert
} }

View File

@@ -1 +1,3 @@
pub mod v1; pub mod v1;
pub use v1::*;

View File

@@ -5,8 +5,8 @@ use rand::{
rngs::{StdRng, SysRng}, rngs::{StdRng, SysRng},
}; };
pub const ROOT_KEY_TAG: &[u8] = "arbiter/seal/v1".as_bytes(); pub const ROOT_KEY_TAG: &[u8] = b"arbiter/seal/v1";
pub const TAG: &[u8] = "arbiter/private-key/v1".as_bytes(); pub const TAG: &[u8] = b"arbiter/private-key/v1";
pub const NONCE_LENGTH: usize = 24; pub const NONCE_LENGTH: usize = 24;
@@ -15,11 +15,13 @@ pub struct Nonce(pub [u8; NONCE_LENGTH]);
impl Nonce { impl Nonce {
pub fn increment(&mut self) { pub fn increment(&mut self) {
for i in (0..self.0.len()).rev() { for i in (0..self.0.len()).rev() {
if self.0[i] == 0xFF { if let Some(byte) = self.0.get_mut(i) {
self.0[i] = 0; if *byte == 0xFF {
} else { *byte = 0;
self.0[i] += 1; } else {
break; *byte += 1;
break;
}
} }
} }
} }
@@ -45,24 +47,17 @@ pub type Salt = [u8; ArgonSalt::RECOMMENDED_LENGTH];
pub fn generate_salt() -> Salt { pub fn generate_salt() -> Salt {
let mut salt = Salt::default(); let mut salt = Salt::default();
#[allow( let mut rng =
clippy::unwrap_used, StdRng::try_from_rng(&mut SysRng).expect("Rng failure is unrecoverable and should panic");
reason = "Rng failure is unrecoverable and should panic"
)]
let mut rng = StdRng::try_from_rng(&mut SysRng).unwrap();
rng.fill_bytes(&mut salt); rng.fill_bytes(&mut salt);
salt salt
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::ops::Deref as _;
use super::*; use super::*;
use crate::{ use crate::crypto::derive_key;
crypto::derive_key, use arbiter_crypto::safecell::{SafeCell, SafeCellHandle as _};
safe_cell::{SafeCell, SafeCellHandle as _},
};
#[test] #[test]
pub fn derive_seal_key_deterministic() { pub fn derive_seal_key_deterministic() {
@@ -77,7 +72,7 @@ mod tests {
let key1_reader = key1.0.read(); let key1_reader = key1.0.read();
let key2_reader = key2.0.read(); let key2_reader = key2.0.read();
assert_eq!(key1_reader.deref(), key2_reader.deref()); assert_eq!(&*key1_reader, &*key2_reader);
} }
#[test] #[test]
@@ -88,14 +83,13 @@ mod tests {
let mut key = derive_key(password, &salt); let mut key = derive_key(password, &salt);
let key_reader = key.0.read(); let key_reader = key.0.read();
let key_ref = key_reader.deref();
assert_ne!(key_ref.as_slice(), &[0u8; 32][..]); assert_ne!(key_reader.as_slice(), &[0u8; 32][..]);
} }
#[test] #[test]
// We should fuzz this // We should fuzz this
pub fn test_nonce_increment() { pub fn nonce_increment() {
let mut nonce = Nonce([0u8; NONCE_LENGTH]); let mut nonce = Nonce([0u8; NONCE_LENGTH]);
nonce.increment(); nonce.increment();

View File

@@ -1 +1,3 @@
pub mod v1; pub mod v1;
pub use v1::*;

View File

@@ -1,78 +1,326 @@
use crate::{crypto::KeyCell, safe_cell::SafeCellHandle as _}; use crate::actors::keyholder;
use chacha20poly1305::Key; use arbiter_crypto::hashing::Hashable;
use hmac::Mac as _; use hmac::Hmac;
use sha2::Sha256;
pub const USERAGENT_INTEGRITY_DERIVE_TAG: &[u8] = "arbiter/useragent/integrity-key/v1".as_bytes(); use diesel::{ExpressionMethods as _, QueryDsl, dsl::insert_into, sqlite::Sqlite};
pub const USERAGENT_INTEGRITY_TAG: &[u8] = "arbiter/useragent/pubkey-entry/v1".as_bytes(); use diesel_async::{AsyncConnection, RunQueryDsl};
use kameo::{actor::ActorRef, error::SendError};
use sha2::Digest as _;
/// Computes an integrity tag for a specific domain and payload shape. use crate::{
pub fn compute_integrity_tag<'a, I>( actors::keyholder::{KeyHolder, SignIntegrity, VerifyIntegrity},
integrity_key: &mut KeyCell, db::{
purpose_tag: &[u8], self,
data_parts: I, models::{IntegrityEnvelope, NewIntegrityEnvelope},
) -> [u8; 32] schema::integrity_envelope,
where },
I: IntoIterator<Item = &'a [u8]>, };
{
type HmacSha256 = hmac::Hmac<sha2::Sha256>;
let mut output_tag = [0u8; 32]; #[derive(Debug, thiserror::Error)]
integrity_key.0.read_inline(|integrity_key_bytes: &Key| { pub enum IntegrityError {
let mut mac = <HmacSha256 as hmac::Mac>::new_from_slice(integrity_key_bytes.as_ref()) #[error("Database error: {0}")]
.expect("HMAC key initialization must not fail for 32-byte key"); Database(#[from] db::DatabaseError),
mac.update(purpose_tag);
for data_part in data_parts { #[error("KeyHolder error: {0}")]
mac.update(data_part); Keyholder(#[from] keyholder::KeyHolderError),
#[error("KeyHolder mailbox error")]
KeyholderSend,
#[error("Integrity envelope is missing for entity {entity_kind}")]
MissingEnvelope { entity_kind: &'static str },
#[error(
"Integrity payload version mismatch for entity {entity_kind}: expected {expected}, found {found}"
)]
PayloadVersionMismatch {
entity_kind: &'static str,
expected: i32,
found: i32,
},
#[error("Integrity MAC mismatch for entity {entity_kind}")]
MacMismatch { entity_kind: &'static str },
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AttestationStatus {
Attested,
Unavailable,
}
pub const CURRENT_PAYLOAD_VERSION: i32 = 1;
pub const INTEGRITY_SUBKEY_TAG: &[u8] = b"arbiter/db-integrity-key/v1";
pub type HmacSha256 = Hmac<Sha256>;
pub trait Integrable: Hashable {
const KIND: &'static str;
const VERSION: i32 = 1;
}
fn payload_hash(payload: &impl Hashable) -> [u8; 32] {
let mut hasher = Sha256::new();
payload.hash(&mut hasher);
hasher.finalize().into()
}
fn push_len_prefixed(out: &mut Vec<u8>, bytes: &[u8]) {
#[expect(
clippy::cast_possible_truncation,
clippy::as_conversions,
reason = "fixme! #85"
)]
out.extend_from_slice(&(bytes.len() as u32).to_be_bytes());
out.extend_from_slice(bytes);
}
fn build_mac_input(
entity_kind: &str,
entity_id: &[u8],
payload_version: i32,
payload_hash: &[u8; 32],
) -> Vec<u8> {
let mut out = Vec::with_capacity(8 + entity_kind.len() + entity_id.len() + 32);
push_len_prefixed(&mut out, entity_kind.as_bytes());
push_len_prefixed(&mut out, entity_id);
out.extend_from_slice(&payload_version.to_be_bytes());
out.extend_from_slice(payload_hash);
out
}
pub trait IntoId {
fn into_id(self) -> Vec<u8>;
}
impl IntoId for i32 {
fn into_id(self) -> Vec<u8> {
self.to_be_bytes().to_vec()
}
}
impl IntoId for &'_ [u8] {
fn into_id(self) -> Vec<u8> {
self.to_vec()
}
}
pub async fn sign_entity<E: Integrable>(
conn: &mut impl AsyncConnection<Backend = Sqlite>,
keyholder: &ActorRef<KeyHolder>,
entity: &E,
entity_id: impl IntoId,
) -> Result<(), IntegrityError> {
let payload_hash = payload_hash(&entity);
let entity_id = entity_id.into_id();
let mac_input = build_mac_input(E::KIND, &entity_id, E::VERSION, &payload_hash);
let (key_version, mac) = keyholder
.ask(SignIntegrity { mac_input })
.await
.map_err(|err| match err {
SendError::HandlerError(inner) => IntegrityError::Keyholder(inner),
_ => IntegrityError::KeyholderSend,
})?;
insert_into(integrity_envelope::table)
.values(NewIntegrityEnvelope {
entity_kind: E::KIND.to_owned(),
entity_id,
payload_version: E::VERSION,
key_version,
mac: mac.clone(),
})
.on_conflict((
integrity_envelope::entity_id,
integrity_envelope::entity_kind,
))
.do_update()
.set((
integrity_envelope::payload_version.eq(E::VERSION),
integrity_envelope::key_version.eq(key_version),
integrity_envelope::mac.eq(mac),
))
.execute(conn)
.await
.map_err(db::DatabaseError::from)?;
Ok(())
}
pub async fn verify_entity<E: Integrable>(
conn: &mut impl AsyncConnection<Backend = Sqlite>,
keyholder: &ActorRef<KeyHolder>,
entity: &E,
entity_id: impl IntoId,
) -> Result<AttestationStatus, IntegrityError> {
let entity_id = entity_id.into_id();
let envelope: IntegrityEnvelope = integrity_envelope::table
.filter(integrity_envelope::entity_kind.eq(E::KIND))
.filter(integrity_envelope::entity_id.eq(&entity_id))
.first(conn)
.await
.map_err(|err| match err {
diesel::result::Error::NotFound => IntegrityError::MissingEnvelope {
entity_kind: E::KIND,
},
other => IntegrityError::Database(db::DatabaseError::from(other)),
})?;
if envelope.payload_version != E::VERSION {
return Err(IntegrityError::PayloadVersionMismatch {
entity_kind: E::KIND,
expected: E::VERSION,
found: envelope.payload_version,
});
}
let payload_hash = payload_hash(&entity);
let mac_input = build_mac_input(E::KIND, &entity_id, envelope.payload_version, &payload_hash);
let result = keyholder
.ask(VerifyIntegrity {
mac_input,
expected_mac: envelope.mac,
key_version: envelope.key_version,
})
.await;
match result {
Ok(true) => Ok(AttestationStatus::Attested),
Ok(false) => Err(IntegrityError::MacMismatch {
entity_kind: E::KIND,
}),
Err(SendError::HandlerError(keyholder::KeyHolderError::NotBootstrapped)) => {
Ok(AttestationStatus::Unavailable)
} }
output_tag.copy_from_slice(&mac.finalize().into_bytes()); Err(_) => Err(IntegrityError::KeyholderSend),
}); }
output_tag
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use diesel::{ExpressionMethods as _, QueryDsl};
use diesel_async::RunQueryDsl;
use kameo::{actor::ActorRef, prelude::Spawn};
use crate::{ use crate::{
crypto::{derive_key, encryption::v1::generate_salt}, actors::keyholder::{Bootstrap, KeyHolder},
safe_cell::{SafeCell, SafeCellHandle as _}, db::{self, schema},
}; };
use arbiter_crypto::safecell::{SafeCell, SafeCellHandle as _};
use super::{USERAGENT_INTEGRITY_TAG, compute_integrity_tag}; use super::{Integrable, IntegrityError, sign_entity, verify_entity};
#[derive(Clone, arbiter_macros::Hashable)]
#[test] struct DummyEntity {
pub fn integrity_tag_deterministic() { payload_version: i32,
let salt = generate_salt(); payload: Vec<u8>,
let mut integrity_key = derive_key(SafeCell::new(b"password".to_vec()), &salt); }
let key_type = 1i32.to_be_bytes(); impl Integrable for DummyEntity {
let t1 = compute_integrity_tag( const KIND: &'static str = "dummy_entity";
&mut integrity_key,
USERAGENT_INTEGRITY_TAG,
[key_type.as_slice(), b"pubkey".as_ref()],
);
let t2 = compute_integrity_tag(
&mut integrity_key,
USERAGENT_INTEGRITY_TAG,
[key_type.as_slice(), b"pubkey".as_ref()],
);
assert_eq!(t1, t2);
} }
#[test] async fn bootstrapped_keyholder(db: &db::DatabasePool) -> ActorRef<KeyHolder> {
pub fn integrity_tag_changes_with_payload() { let actor = KeyHolder::spawn(KeyHolder::new(db.clone()).await.unwrap());
let salt = generate_salt(); actor
let mut integrity_key = derive_key(SafeCell::new(b"password".to_vec()), &salt); .ask(Bootstrap {
let key_type_1 = 1i32.to_be_bytes(); seal_key_raw: SafeCell::new(b"integrity-test-seal-key".to_vec()),
let key_type_2 = 2i32.to_be_bytes(); })
let t1 = compute_integrity_tag( .await
&mut integrity_key, .unwrap();
USERAGENT_INTEGRITY_TAG, actor
[key_type_1.as_slice(), b"pubkey".as_ref()], }
);
let t2 = compute_integrity_tag( #[tokio::test]
&mut integrity_key, async fn sign_writes_envelope_and_verify_passes() {
USERAGENT_INTEGRITY_TAG, const ENTITY_ID: &[u8] = b"entity-id-7";
[key_type_2.as_slice(), b"pubkey".as_ref()],
); let db = db::create_test_pool().await;
assert_ne!(t1, t2); let keyholder = bootstrapped_keyholder(&db).await;
let mut conn = db.get().await.unwrap();
let entity = DummyEntity {
payload_version: 1,
payload: b"payload-v1".to_vec(),
};
sign_entity(&mut conn, &keyholder, &entity, ENTITY_ID)
.await
.unwrap();
let count: i64 = schema::integrity_envelope::table
.filter(schema::integrity_envelope::entity_kind.eq("dummy_entity"))
.filter(schema::integrity_envelope::entity_id.eq(ENTITY_ID))
.count()
.get_result(&mut conn)
.await
.unwrap();
assert_eq!(count, 1, "envelope row must be created exactly once");
verify_entity(&mut conn, &keyholder, &entity, ENTITY_ID)
.await
.unwrap();
}
#[tokio::test]
async fn tampered_mac_fails_verification() {
const ENTITY_ID: &[u8] = b"entity-id-11";
let db = db::create_test_pool().await;
let keyholder = bootstrapped_keyholder(&db).await;
let mut conn = db.get().await.unwrap();
let entity = DummyEntity {
payload_version: 1,
payload: b"payload-v1".to_vec(),
};
sign_entity(&mut conn, &keyholder, &entity, ENTITY_ID)
.await
.unwrap();
diesel::update(schema::integrity_envelope::table)
.filter(schema::integrity_envelope::entity_kind.eq("dummy_entity"))
.filter(schema::integrity_envelope::entity_id.eq(ENTITY_ID))
.set(schema::integrity_envelope::mac.eq(vec![0u8; 32]))
.execute(&mut conn)
.await
.unwrap();
let err = verify_entity(&mut conn, &keyholder, &entity, ENTITY_ID)
.await
.unwrap_err();
assert!(matches!(err, IntegrityError::MacMismatch { .. }));
}
#[tokio::test]
async fn changed_payload_fails_verification() {
const ENTITY_ID: &[u8] = b"entity-id-21";
let db = db::create_test_pool().await;
let keyholder = bootstrapped_keyholder(&db).await;
let mut conn = db.get().await.unwrap();
let entity = DummyEntity {
payload_version: 1,
payload: b"payload-v1".to_vec(),
};
sign_entity(&mut conn, &keyholder, &entity, ENTITY_ID)
.await
.unwrap();
let tampered = DummyEntity {
payload: b"payload-v1-but-tampered".to_vec(),
..entity
};
let err = verify_entity(&mut conn, &keyholder, &tampered, ENTITY_ID)
.await
.unwrap_err();
assert!(matches!(err, IntegrityError::MacMismatch { .. }));
} }
} }

View File

@@ -1,5 +1,3 @@
use std::ops::Deref as _;
use argon2::{Algorithm, Argon2}; use argon2::{Algorithm, Argon2};
use chacha20poly1305::{ use chacha20poly1305::{
AeadInPlace, Key, KeyInit as _, XChaCha20Poly1305, XNonce, AeadInPlace, Key, KeyInit as _, XChaCha20Poly1305, XNonce,
@@ -10,7 +8,7 @@ use rand::{
rngs::{StdRng, SysRng}, rngs::{StdRng, SysRng},
}; };
use crate::safe_cell::{SafeCell, SafeCellHandle as _}; use arbiter_crypto::safecell::{SafeCell, SafeCellHandle as _};
pub mod encryption; pub mod encryption;
pub mod integrity; pub mod integrity;
@@ -41,11 +39,8 @@ impl TryFrom<SafeCell<Vec<u8>>> for KeyCell {
impl KeyCell { impl KeyCell {
pub fn new_secure_random() -> Self { pub fn new_secure_random() -> Self {
let key = SafeCell::new_inline(|key_buffer: &mut Key| { let key = SafeCell::new_inline(|key_buffer: &mut Key| {
#[allow( let mut rng = StdRng::try_from_rng(&mut SysRng)
clippy::unwrap_used, .expect("Rng failure is unrecoverable and should panic");
reason = "Rng failure is unrecoverable and should panic"
)]
let mut rng = StdRng::try_from_rng(&mut SysRng).unwrap();
rng.fill_bytes(key_buffer); rng.fill_bytes(key_buffer);
}); });
@@ -59,8 +54,7 @@ impl KeyCell {
mut buffer: impl AsMut<Vec<u8>>, mut buffer: impl AsMut<Vec<u8>>,
) -> Result<(), Error> { ) -> Result<(), Error> {
let key_reader = self.0.read(); let key_reader = self.0.read();
let key_ref = key_reader.deref(); let cipher = XChaCha20Poly1305::new(&key_reader);
let cipher = XChaCha20Poly1305::new(key_ref);
let nonce = XNonce::from_slice(nonce.0.as_ref()); let nonce = XNonce::from_slice(nonce.0.as_ref());
let buffer = buffer.as_mut(); let buffer = buffer.as_mut();
cipher.encrypt_in_place(nonce, associated_data, buffer) cipher.encrypt_in_place(nonce, associated_data, buffer)
@@ -72,8 +66,7 @@ impl KeyCell {
buffer: &mut SafeCell<Vec<u8>>, buffer: &mut SafeCell<Vec<u8>>,
) -> Result<(), Error> { ) -> Result<(), Error> {
let key_reader = self.0.read(); let key_reader = self.0.read();
let key_ref = key_reader.deref(); let cipher = XChaCha20Poly1305::new(&key_reader);
let cipher = XChaCha20Poly1305::new(key_ref);
let nonce = XNonce::from_slice(nonce.0.as_ref()); let nonce = XNonce::from_slice(nonce.0.as_ref());
let mut buffer = buffer.write(); let mut buffer = buffer.write();
let buffer: &mut Vec<u8> = buffer.as_mut(); let buffer: &mut Vec<u8> = buffer.as_mut();
@@ -87,8 +80,7 @@ impl KeyCell {
plaintext: impl AsRef<[u8]>, plaintext: impl AsRef<[u8]>,
) -> Result<Vec<u8>, Error> { ) -> Result<Vec<u8>, Error> {
let key_reader = self.0.read(); let key_reader = self.0.read();
let key_ref = key_reader.deref(); let mut cipher = XChaCha20Poly1305::new(&key_reader);
let mut cipher = XChaCha20Poly1305::new(key_ref);
let nonce = XNonce::from_slice(nonce.0.as_ref()); let nonce = XNonce::from_slice(nonce.0.as_ref());
let ciphertext = cipher.encrypt( let ciphertext = cipher.encrypt(
@@ -102,24 +94,29 @@ impl KeyCell {
} }
} }
/// User password might be of different length, have not enough entropy, etc...
/// Derive a fixed-length key from the password using Argon2id, which is designed for password hashing and key derivation. /// Derive a fixed-length key from the password using Argon2id, which is designed for password hashing and key derivation.
pub fn derive_key(mut password: SafeCell<Vec<u8>>, salt: &Salt) -> KeyCell { pub fn derive_key(mut password: SafeCell<Vec<u8>>, salt: &Salt) -> KeyCell {
#[allow(clippy::unwrap_used)] let params = {
let params = argon2::Params::new(262_144, 3, 4, None).unwrap(); #[cfg(debug_assertions)]
{
argon2::Params::new(8, 1, 1, None).unwrap()
}
#[cfg(not(debug_assertions))]
{
argon2::Params::new(262_144, 3, 4, None).unwrap()
}
};
let hasher = Argon2::new(Algorithm::Argon2id, argon2::Version::V0x13, params); let hasher = Argon2::new(Algorithm::Argon2id, argon2::Version::V0x13, params);
let mut key = SafeCell::new(Key::default()); let mut key = SafeCell::new(Key::default());
password.read_inline(|password_source| { password.read_inline(|password_source| {
let mut key_buffer = key.write(); let mut key_buffer = key.write();
let key_buffer: &mut [u8] = key_buffer.as_mut(); let key_buffer: &mut [u8] = key_buffer.as_mut();
#[allow(
clippy::unwrap_used,
reason = "Better fail completely than return a weak key"
)]
hasher hasher
.hash_password_into(password_source.deref(), salt, key_buffer) .hash_password_into(password_source, salt, key_buffer)
.unwrap(); .expect("Better fail completely than return a weak key");
}); });
key.into() key.into()
@@ -131,7 +128,7 @@ mod tests {
derive_key, derive_key,
encryption::v1::{Nonce, generate_salt}, encryption::v1::{Nonce, generate_salt},
}; };
use crate::safe_cell::{SafeCell, SafeCellHandle as _}; use arbiter_crypto::safecell::{SafeCell, SafeCellHandle as _};
#[test] #[test]
pub fn encrypt_decrypt() { pub fn encrypt_decrypt() {

View File

@@ -23,14 +23,14 @@ const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations");
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum DatabaseSetupError { pub enum DatabaseSetupError {
#[error("Failed to determine home directory")] #[error(transparent)]
HomeDir(std::io::Error), ConcurrencySetup(diesel::result::Error),
#[error(transparent)] #[error(transparent)]
Connection(diesel::ConnectionError), Connection(diesel::ConnectionError),
#[error(transparent)] #[error("Failed to determine home directory")]
ConcurrencySetup(diesel::result::Error), HomeDir(std::io::Error),
#[error(transparent)] #[error(transparent)]
Migration(Box<dyn std::error::Error + Send + Sync>), Migration(Box<dyn std::error::Error + Send + Sync>),
@@ -41,10 +41,11 @@ pub enum DatabaseSetupError {
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum DatabaseError { pub enum DatabaseError {
#[error("Database connection error")]
Pool(#[from] PoolError),
#[error("Database query error")] #[error("Database query error")]
Connection(#[from] diesel::result::Error), Connection(#[from] diesel::result::Error),
#[error("Database connection error")]
Pool(#[from] PoolError),
} }
#[tracing::instrument(level = "info")] #[tracing::instrument(level = "info")]
@@ -93,13 +94,16 @@ fn initialize_database(url: &str) -> Result<(), DatabaseSetupError> {
} }
#[tracing::instrument(level = "info")] #[tracing::instrument(level = "info")]
/// Creates a connection pool for the `SQLite` database.
///
/// # Panics
/// Panics if the database path is not valid UTF-8.
pub async fn create_pool(url: Option<&str>) -> Result<DatabasePool, DatabaseSetupError> { pub async fn create_pool(url: Option<&str>) -> Result<DatabasePool, DatabaseSetupError> {
let database_url = url.map(String::from).unwrap_or( let database_url = url.map(String::from).unwrap_or(
#[allow(clippy::expect_used)]
database_path()? database_path()?
.to_str() .to_str()
.expect("database path is not valid UTF-8") .expect("database path is not valid UTF-8")
.to_string(), .to_owned(),
); );
initialize_database(&database_url)?; initialize_database(&database_url)?;
@@ -133,19 +137,20 @@ pub async fn create_pool(url: Option<&str>) -> Result<DatabasePool, DatabaseSetu
Ok(pool) Ok(pool)
} }
#[mutants::skip]
#[expect(clippy::missing_panics_doc, reason = "Tests oriented function")]
/// Creates a test database pool with a temporary `SQLite` database file.
pub async fn create_test_pool() -> DatabasePool { pub async fn create_test_pool() -> DatabasePool {
use rand::distr::{Alphanumeric, SampleString as _}; use rand::distr::{Alphanumeric, SampleString as _};
let tempfile_name = Alphanumeric.sample_string(&mut rand::rng(), 16); let tempfile_name = Alphanumeric.sample_string(&mut rand::rng(), 16);
let file = std::env::temp_dir().join(tempfile_name); let file = std::env::temp_dir().join(tempfile_name);
#[allow(clippy::expect_used)]
let url = file let url = file
.to_str() .to_str()
.expect("temp file path is not valid UTF-8") .expect("temp file path is not valid UTF-8")
.to_string(); .to_owned();
#[allow(clippy::expect_used)]
create_pool(Some(&url)) create_pool(Some(&url))
.await .await
.expect("Failed to create test database pool") .expect("Failed to create test database pool")

View File

@@ -1,13 +1,14 @@
#![allow(unused)] #![allow(
#![allow(clippy::all)] clippy::duplicated_attributes,
reason = "restructed's #[view] causes false positives"
)]
use crate::db::schema::{ use crate::db::schema::{
self, aead_encrypted, arbiter_settings, evm_basic_grant, evm_ether_transfer_grant, self, aead_encrypted, arbiter_settings, evm_basic_grant, evm_ether_transfer_grant,
evm_ether_transfer_grant_target, evm_ether_transfer_limit, evm_token_transfer_grant, evm_ether_transfer_grant_target, evm_ether_transfer_limit, evm_token_transfer_grant,
evm_token_transfer_log, evm_token_transfer_volume_limit, evm_transaction_log, evm_wallet, evm_token_transfer_log, evm_token_transfer_volume_limit, evm_transaction_log, evm_wallet,
root_key_history, tls_history, integrity_envelope, root_key_history, tls_history,
}; };
use chrono::{DateTime, Utc};
use diesel::{prelude::*, sqlite::Sqlite}; use diesel::{prelude::*, sqlite::Sqlite};
use restructed::Models; use restructed::Models;
@@ -27,16 +28,16 @@ pub mod types {
pub struct SqliteTimestamp(pub DateTime<Utc>); pub struct SqliteTimestamp(pub DateTime<Utc>);
impl SqliteTimestamp { impl SqliteTimestamp {
pub fn now() -> Self { pub fn now() -> Self {
SqliteTimestamp(Utc::now()) Self(Utc::now())
} }
} }
impl From<chrono::DateTime<Utc>> for SqliteTimestamp { impl From<DateTime<Utc>> for SqliteTimestamp {
fn from(dt: chrono::DateTime<Utc>) -> Self { fn from(dt: DateTime<Utc>) -> Self {
SqliteTimestamp(dt) Self(dt)
} }
} }
impl From<SqliteTimestamp> for chrono::DateTime<Utc> { impl From<SqliteTimestamp> for DateTime<Utc> {
fn from(ts: SqliteTimestamp) -> Self { fn from(ts: SqliteTimestamp) -> Self {
ts.0 ts.0
} }
@@ -47,6 +48,11 @@ pub mod types {
&'b self, &'b self,
out: &mut diesel::serialize::Output<'b, '_, Sqlite>, out: &mut diesel::serialize::Output<'b, '_, Sqlite>,
) -> diesel::serialize::Result { ) -> diesel::serialize::Result {
#[expect(
clippy::cast_possible_truncation,
clippy::as_conversions,
reason = "fixme! #84; this will break up in 2038 :3"
)]
let unix_timestamp = self.0.timestamp() as i32; let unix_timestamp = self.0.timestamp() as i32;
out.set_value(unix_timestamp); out.set_value(unix_timestamp);
Ok(IsNull::No) Ok(IsNull::No)
@@ -69,41 +75,47 @@ pub mod types {
let datetime = let datetime =
DateTime::from_timestamp(unix_timestamp, 0).ok_or("Timestamp is out of bounds")?; DateTime::from_timestamp(unix_timestamp, 0).ok_or("Timestamp is out of bounds")?;
Ok(SqliteTimestamp(datetime)) Ok(Self(datetime))
} }
} }
/// Key algorithm stored in the `useragent_client.key_type` column. #[derive(Debug, FromSqlRow, AsExpression, Clone)]
/// Values must stay stable — they are persisted in the database.
#[derive(Debug, Clone, Copy, PartialEq, Eq, FromSqlRow, AsExpression, strum::FromRepr)]
#[diesel(sql_type = Integer)] #[diesel(sql_type = Integer)]
#[repr(i32)] #[repr(transparent)] // hint compiler to optimize the wrapper struct away
pub enum KeyType { pub struct ChainId(pub i32);
Ed25519 = 1,
EcdsaSecp256k1 = 2,
Rsa = 3,
}
impl ToSql<Integer, Sqlite> for KeyType { #[expect(
clippy::cast_sign_loss,
clippy::cast_possible_truncation,
clippy::as_conversions,
reason = "safe because chain_id is stored as i32 but is guaranteed to be a valid ChainId by the API when creating grants"
)]
const _: () = {
impl From<ChainId> for alloy::primitives::ChainId {
fn from(chain_id: ChainId) -> Self {
chain_id.0 as Self
}
}
impl From<alloy::primitives::ChainId> for ChainId {
fn from(chain_id: alloy::primitives::ChainId) -> Self {
Self(chain_id as _)
}
}
};
impl FromSql<Integer, Sqlite> for ChainId {
fn from_sql(
bytes: <Sqlite as diesel::backend::Backend>::RawValue<'_>,
) -> diesel::deserialize::Result<Self> {
FromSql::<Integer, Sqlite>::from_sql(bytes).map(Self)
}
}
impl ToSql<Integer, Sqlite> for ChainId {
fn to_sql<'b>( fn to_sql<'b>(
&'b self, &'b self,
out: &mut diesel::serialize::Output<'b, '_, Sqlite>, out: &mut diesel::serialize::Output<'b, '_, Sqlite>,
) -> diesel::serialize::Result { ) -> diesel::serialize::Result {
out.set_value(*self as i32); ToSql::<Integer, Sqlite>::to_sql(&self.0, out)
Ok(IsNull::No)
}
}
impl FromSql<Integer, Sqlite> for KeyType {
fn from_sql(
mut bytes: <Sqlite as diesel::backend::Backend>::RawValue<'_>,
) -> diesel::deserialize::Result<Self> {
let Some(SqliteType::Long) = bytes.value_type() else {
return Err("Expected Integer for KeyType".into());
};
let discriminant = bytes.read_long();
KeyType::from_repr(discriminant as i32)
.ok_or_else(|| format!("Unknown KeyType discriminant: {discriminant}").into())
} }
} }
} }
@@ -242,10 +254,8 @@ pub struct UseragentClient {
pub id: i32, pub id: i32,
pub nonce: i32, pub nonce: i32,
pub public_key: Vec<u8>, pub public_key: Vec<u8>,
pub pubkey_integrity_tag: Option<Vec<u8>>,
pub created_at: SqliteTimestamp, pub created_at: SqliteTimestamp,
pub updated_at: SqliteTimestamp, pub updated_at: SqliteTimestamp,
pub key_type: KeyType,
} }
#[derive(Models, Queryable, Debug, Insertable, Selectable)] #[derive(Models, Queryable, Debug, Insertable, Selectable)]
@@ -273,7 +283,7 @@ pub struct EvmEtherTransferLimit {
pub struct EvmBasicGrant { pub struct EvmBasicGrant {
pub id: i32, pub id: i32,
pub wallet_access_id: i32, // references evm_wallet_access.id pub wallet_access_id: i32, // references evm_wallet_access.id
pub chain_id: i32, pub chain_id: ChainId,
pub valid_from: Option<SqliteTimestamp>, pub valid_from: Option<SqliteTimestamp>,
pub valid_until: Option<SqliteTimestamp>, pub valid_until: Option<SqliteTimestamp>,
pub max_gas_fee_per_gas: Option<Vec<u8>>, pub max_gas_fee_per_gas: Option<Vec<u8>>,
@@ -296,7 +306,7 @@ pub struct EvmTransactionLog {
pub id: i32, pub id: i32,
pub grant_id: i32, pub grant_id: i32,
pub wallet_access_id: i32, pub wallet_access_id: i32,
pub chain_id: i32, pub chain_id: ChainId,
pub eth_value: Vec<u8>, pub eth_value: Vec<u8>,
pub signed_at: SqliteTimestamp, pub signed_at: SqliteTimestamp,
} }
@@ -371,9 +381,28 @@ pub struct EvmTokenTransferLog {
pub id: i32, pub id: i32,
pub grant_id: i32, pub grant_id: i32,
pub log_id: i32, pub log_id: i32,
pub chain_id: i32, pub chain_id: ChainId,
pub token_contract: Vec<u8>, pub token_contract: Vec<u8>,
pub recipient_address: Vec<u8>, pub recipient_address: Vec<u8>,
pub value: Vec<u8>, pub value: Vec<u8>,
pub created_at: SqliteTimestamp, pub created_at: SqliteTimestamp,
} }
#[derive(Models, Queryable, Debug, Insertable, Selectable)]
#[diesel(table_name = integrity_envelope, check_for_backend(Sqlite))]
#[view(
NewIntegrityEnvelope,
derive(Insertable),
omit(id, signed_at, created_at),
attributes_with = "deriveless"
)]
pub struct IntegrityEnvelope {
pub id: i32,
pub entity_kind: String,
pub entity_id: Vec<u8>,
pub payload_version: i32,
pub key_version: i32,
pub mac: Vec<u8>,
pub signed_at: SqliteTimestamp,
pub created_at: SqliteTimestamp,
}

View File

@@ -139,6 +139,19 @@ diesel::table! {
} }
} }
diesel::table! {
integrity_envelope (id) {
id -> Integer,
entity_kind -> Text,
entity_id -> Binary,
payload_version -> Integer,
key_version -> Integer,
mac -> Binary,
signed_at -> Integer,
created_at -> Integer,
}
}
diesel::table! { diesel::table! {
program_client (id) { program_client (id) {
id -> Integer, id -> Integer,
@@ -178,7 +191,6 @@ diesel::table! {
id -> Integer, id -> Integer,
nonce -> Integer, nonce -> Integer,
public_key -> Binary, public_key -> Binary,
pubkey_integrity_tag -> Nullable<Binary>,
key_type -> Integer, key_type -> Integer,
created_at -> Integer, created_at -> Integer,
updated_at -> Integer, updated_at -> Integer,
@@ -220,6 +232,7 @@ diesel::allow_tables_to_appear_in_same_query!(
evm_transaction_log, evm_transaction_log,
evm_wallet, evm_wallet,
evm_wallet_access, evm_wallet_access,
integrity_envelope,
program_client, program_client,
root_key_history, root_key_history,
tls_history, tls_history,

View File

@@ -45,7 +45,7 @@ sol! {
sol! { sol! {
/// Permit2 — Uniswap's canonical token approval manager. /// Permit2 — Uniswap's canonical token approval manager.
/// Replaces per-contract ERC-20 approve() with a single approval hub. /// Replaces per-contract ERC-20 `approve()` with a single approval hub.
#[derive(Debug)] #[derive(Debug)]
interface IPermit2 { interface IPermit2 {
struct TokenPermissions { struct TokenPermissions {

View File

@@ -8,8 +8,11 @@ use alloy::{
use chrono::Utc; use chrono::Utc;
use diesel::{ExpressionMethods as _, QueryDsl as _, QueryResult, insert_into, sqlite::Sqlite}; use diesel::{ExpressionMethods as _, QueryDsl as _, QueryResult, insert_into, sqlite::Sqlite};
use diesel_async::{AsyncConnection, RunQueryDsl}; use diesel_async::{AsyncConnection, RunQueryDsl};
use kameo::actor::ActorRef;
use crate::{ use crate::{
actors::keyholder::KeyHolder,
crypto::integrity,
db::{ db::{
self, DatabaseError, self, DatabaseError,
models::{ models::{
@@ -18,8 +21,8 @@ use crate::{
schema::{self, evm_transaction_log}, schema::{self, evm_transaction_log},
}, },
evm::policies::{ evm::policies::{
DatabaseID, EvalContext, EvalViolation, FullGrant, Grant, Policy, SharedGrantSettings, CombinedSettings, DatabaseID, EvalContext, EvalViolation, Grant, Policy,
SpecificGrant, SpecificMeaning, ether_transfer::EtherTransfer, SharedGrantSettings, SpecificGrant, SpecificMeaning, ether_transfer::EtherTransfer,
token_transfers::TokenTransfer, token_transfers::TokenTransfer,
}, },
}; };
@@ -31,11 +34,14 @@ mod utils;
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum PolicyError { pub enum PolicyError {
#[error("Database error")] #[error("Database error")]
Database(#[from] crate::db::DatabaseError), Database(#[from] DatabaseError),
#[error("Transaction violates policy: {0:?}")] #[error("Transaction violates policy: {0:?}")]
Violations(Vec<EvalViolation>), Violations(Vec<EvalViolation>),
#[error("No matching grant found")] #[error("No matching grant found")]
NoMatchingGrant, NoMatchingGrant,
#[error("Integrity error: {0}")]
Integrity(#[from] integrity::IntegrityError),
} }
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
@@ -57,6 +63,15 @@ pub enum AnalyzeError {
UnsupportedTransactionType, UnsupportedTransactionType,
} }
#[derive(Debug, thiserror::Error)]
pub enum ListError {
#[error("Database error")]
Database(#[from] DatabaseError),
#[error("Integrity verification failed for grant")]
Integrity(#[from] integrity::IntegrityError),
}
/// Controls whether a transaction should be executed or only validated /// Controls whether a transaction should be executed or only validated
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RunKind { pub enum RunKind {
@@ -75,6 +90,14 @@ async fn check_shared_constraints(
let mut violations = Vec::new(); let mut violations = Vec::new();
let now = Utc::now(); let now = Utc::now();
if shared.chain != context.chain {
violations.push(EvalViolation::MismatchingChainId {
expected: shared.chain,
actual: context.chain,
});
return Ok(violations);
}
// Validity window // Validity window
if shared.valid_from.is_some_and(|t| now < t) || shared.valid_until.is_some_and(|t| now > t) { if shared.valid_from.is_some_and(|t| now < t) || shared.valid_until.is_some_and(|t| now > t) {
violations.push(EvalViolation::InvalidTime); violations.push(EvalViolation::InvalidTime);
@@ -104,7 +127,7 @@ async fn check_shared_constraints(
.get_result(conn) .get_result(conn)
.await?; .await?;
if count >= rate_limit.count as i64 { if count >= rate_limit.count.into() {
violations.push(EvalViolation::RateLimitExceeded); violations.push(EvalViolation::RateLimitExceeded);
} }
} }
@@ -115,6 +138,7 @@ async fn check_shared_constraints(
// Supporting only EIP-1559 transactions for now, but we can easily extend this to support legacy transactions if needed // Supporting only EIP-1559 transactions for now, but we can easily extend this to support legacy transactions if needed
pub struct Engine { pub struct Engine {
db: db::DatabasePool, db: db::DatabasePool,
keyholder: ActorRef<KeyHolder>,
} }
impl Engine { impl Engine {
@@ -123,7 +147,10 @@ impl Engine {
context: EvalContext, context: EvalContext,
meaning: &P::Meaning, meaning: &P::Meaning,
run_kind: RunKind, run_kind: RunKind,
) -> Result<(), PolicyError> { ) -> Result<(), PolicyError>
where
P::Settings: Clone,
{
let mut conn = self.db.get().await.map_err(DatabaseError::from)?; let mut conn = self.db.get().await.map_err(DatabaseError::from)?;
let grant = P::try_find_grant(&context, &mut conn) let grant = P::try_find_grant(&context, &mut conn)
@@ -131,10 +158,16 @@ impl Engine {
.map_err(DatabaseError::from)? .map_err(DatabaseError::from)?
.ok_or(PolicyError::NoMatchingGrant)?; .ok_or(PolicyError::NoMatchingGrant)?;
let mut violations = integrity::verify_entity(&mut conn, &self.keyholder, &grant.settings, grant.id).await?;
check_shared_constraints(&context, &grant.shared, grant.shared_grant_id, &mut conn)
.await let mut violations = check_shared_constraints(
.map_err(DatabaseError::from)?; &context,
&grant.settings.shared,
grant.common_settings_id,
&mut conn,
)
.await
.map_err(DatabaseError::from)?;
violations.extend( violations.extend(
P::evaluate(&context, meaning, &grant, &mut conn) P::evaluate(&context, meaning, &grant, &mut conn)
.await .await
@@ -143,14 +176,16 @@ impl Engine {
if !violations.is_empty() { if !violations.is_empty() {
return Err(PolicyError::Violations(violations)); return Err(PolicyError::Violations(violations));
} else if run_kind == RunKind::Execution { }
if run_kind == RunKind::Execution {
conn.transaction(|conn| { conn.transaction(|conn| {
Box::pin(async move { Box::pin(async move {
let log_id: i32 = insert_into(evm_transaction_log::table) let log_id: i32 = insert_into(evm_transaction_log::table)
.values(&NewEvmTransactionLog { .values(&NewEvmTransactionLog {
grant_id: grant.shared_grant_id, grant_id: grant.common_settings_id,
wallet_access_id: context.target.id, wallet_access_id: context.target.id,
chain_id: context.chain as i32, chain_id: context.chain.into(),
eth_value: utils::u256_to_bytes(context.value).to_vec(), eth_value: utils::u256_to_bytes(context.value).to_vec(),
signed_at: Utc::now().into(), signed_at: Utc::now().into(),
}) })
@@ -172,42 +207,52 @@ impl Engine {
} }
impl Engine { impl Engine {
pub fn new(db: db::DatabasePool) -> Self { pub const fn new(db: db::DatabasePool, keyholder: ActorRef<KeyHolder>) -> Self {
Self { db } Self { db, keyholder }
} }
pub async fn create_grant<P: Policy>( pub async fn create_grant<P: Policy>(
&self, &self,
full_grant: FullGrant<P::Settings>, full_grant: CombinedSettings<P::Settings>,
) -> Result<i32, DatabaseError> { ) -> Result<i32, DatabaseError>
where
P::Settings: Clone,
{
let mut conn = self.db.get().await?; let mut conn = self.db.get().await?;
let keyholder = self.keyholder.clone();
let id = conn let id = conn
.transaction(|conn| { .transaction(|conn| {
Box::pin(async move { Box::pin(async move {
use schema::evm_basic_grant; use schema::evm_basic_grant;
#[expect(
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::as_conversions,
reason = "fixme! #86"
)]
let basic_grant: EvmBasicGrant = insert_into(evm_basic_grant::table) let basic_grant: EvmBasicGrant = insert_into(evm_basic_grant::table)
.values(&NewEvmBasicGrant { .values(&NewEvmBasicGrant {
chain_id: full_grant.basic.chain as i32, chain_id: full_grant.shared.chain.into(),
wallet_access_id: full_grant.basic.wallet_access_id, wallet_access_id: full_grant.shared.wallet_access_id,
valid_from: full_grant.basic.valid_from.map(SqliteTimestamp), valid_from: full_grant.shared.valid_from.map(SqliteTimestamp),
valid_until: full_grant.basic.valid_until.map(SqliteTimestamp), valid_until: full_grant.shared.valid_until.map(SqliteTimestamp),
max_gas_fee_per_gas: full_grant max_gas_fee_per_gas: full_grant
.basic .shared
.max_gas_fee_per_gas .max_gas_fee_per_gas
.map(|fee| utils::u256_to_bytes(fee).to_vec()), .map(|fee| utils::u256_to_bytes(fee).to_vec()),
max_priority_fee_per_gas: full_grant max_priority_fee_per_gas: full_grant
.basic .shared
.max_priority_fee_per_gas .max_priority_fee_per_gas
.map(|fee| utils::u256_to_bytes(fee).to_vec()), .map(|fee| utils::u256_to_bytes(fee).to_vec()),
rate_limit_count: full_grant rate_limit_count: full_grant
.basic .shared
.rate_limit .rate_limit
.as_ref() .as_ref()
.map(|rl| rl.count as i32), .map(|rl| rl.count as i32),
rate_limit_window_secs: full_grant rate_limit_window_secs: full_grant
.basic .shared
.rate_limit .rate_limit
.as_ref() .as_ref()
.map(|rl| rl.window.num_seconds() as i32), .map(|rl| rl.window.num_seconds() as i32),
@@ -217,7 +262,13 @@ impl Engine {
.get_result(conn) .get_result(conn)
.await?; .await?;
P::create_grant(&basic_grant, &full_grant.specific, conn).await P::create_grant(&basic_grant, &full_grant.specific, conn).await?;
integrity::sign_entity(conn, &keyholder, &full_grant, basic_grant.id)
.await
.map_err(|_| diesel::result::Error::RollbackTransaction)?;
QueryResult::Ok(basic_grant.id)
}) })
}) })
.await?; .await?;
@@ -225,33 +276,36 @@ impl Engine {
Ok(id) Ok(id)
} }
pub async fn list_all_grants(&self) -> Result<Vec<Grant<SpecificGrant>>, DatabaseError> { async fn list_one_kind<Kind: Policy, Y>(
let mut conn = self.db.get().await?; &self,
conn: &mut impl AsyncConnection<Backend = Sqlite>,
) -> Result<impl Iterator<Item = Grant<Y>>, ListError>
where
Y: From<Kind::Settings>,
{
let all_grants = Kind::find_all_grants(conn)
.await
.map_err(DatabaseError::from)?;
// Verify integrity of all grants before returning any results
for grant in &all_grants {
integrity::verify_entity(conn, &self.keyholder, &grant.settings, grant.id).await?;
}
Ok(all_grants.into_iter().map(|g| Grant {
id: g.id,
common_settings_id: g.common_settings_id,
settings: g.settings.generalize(),
}))
}
pub async fn list_all_grants(&self) -> Result<Vec<Grant<SpecificGrant>>, ListError> {
let mut conn = self.db.get().await.map_err(DatabaseError::from)?;
let mut grants: Vec<Grant<SpecificGrant>> = Vec::new(); let mut grants: Vec<Grant<SpecificGrant>> = Vec::new();
grants.extend( grants.extend(self.list_one_kind::<EtherTransfer, _>(&mut conn).await?);
EtherTransfer::find_all_grants(&mut conn) grants.extend(self.list_one_kind::<TokenTransfer, _>(&mut conn).await?);
.await?
.into_iter()
.map(|g| Grant {
id: g.id,
shared_grant_id: g.shared_grant_id,
shared: g.shared,
settings: SpecificGrant::EtherTransfer(g.settings),
}),
);
grants.extend(
TokenTransfer::find_all_grants(&mut conn)
.await?
.into_iter()
.map(|g| Grant {
id: g.id,
shared_grant_id: g.shared_grant_id,
shared: g.shared,
settings: SpecificGrant::TokenTransfer(g.settings),
}),
);
Ok(grants) Ok(grants)
} }
@@ -265,7 +319,7 @@ impl Engine {
let TxKind::Call(to) = transaction.to else { let TxKind::Call(to) = transaction.to else {
return Err(VetError::ContractCreationNotSupported); return Err(VetError::ContractCreationNotSupported);
}; };
let context = policies::EvalContext { let context = EvalContext {
target, target,
chain: transaction.chain_id, chain: transaction.chain_id,
to, to,
@@ -297,3 +351,261 @@ impl Engine {
Err(VetError::UnsupportedTransactionType) Err(VetError::UnsupportedTransactionType)
} }
} }
#[cfg(test)]
mod tests {
use alloy::primitives::{Address, Bytes, U256, address};
use chrono::{Duration, Utc};
use diesel::{SelectableHelper, insert_into};
use diesel_async::RunQueryDsl;
use rstest::rstest;
use crate::db::{
self, DatabaseConnection,
models::{
EvmBasicGrant, EvmWalletAccess, NewEvmBasicGrant, NewEvmTransactionLog, SqliteTimestamp,
},
schema::{evm_basic_grant, evm_transaction_log},
};
use crate::evm::policies::{
EvalContext, EvalViolation, SharedGrantSettings, TransactionRateLimit,
};
use super::check_shared_constraints;
const WALLET_ACCESS_ID: i32 = 1;
const CHAIN_ID: u64 = 1;
const RECIPIENT: Address = address!("1111111111111111111111111111111111111111");
fn context() -> EvalContext {
EvalContext {
target: EvmWalletAccess {
id: WALLET_ACCESS_ID,
wallet_id: 10,
client_id: 20,
created_at: SqliteTimestamp(Utc::now()),
},
chain: CHAIN_ID,
to: RECIPIENT,
value: U256::ZERO,
calldata: Bytes::new(),
max_fee_per_gas: 100,
max_priority_fee_per_gas: 10,
}
}
fn shared_settings() -> SharedGrantSettings {
SharedGrantSettings {
wallet_access_id: WALLET_ACCESS_ID,
chain: CHAIN_ID,
valid_from: None,
valid_until: None,
max_gas_fee_per_gas: None,
max_priority_fee_per_gas: None,
rate_limit: None,
}
}
async fn insert_basic_grant(
conn: &mut DatabaseConnection,
shared: &SharedGrantSettings,
) -> EvmBasicGrant {
#[expect(
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::as_conversions,
reason = "fixme! #86"
)]
insert_into(evm_basic_grant::table)
.values(NewEvmBasicGrant {
wallet_access_id: shared.wallet_access_id,
chain_id: shared.chain.into(),
valid_from: shared.valid_from.map(SqliteTimestamp),
valid_until: shared.valid_until.map(SqliteTimestamp),
max_gas_fee_per_gas: shared
.max_gas_fee_per_gas
.map(|fee| super::utils::u256_to_bytes(fee).to_vec()),
max_priority_fee_per_gas: shared
.max_priority_fee_per_gas
.map(|fee| super::utils::u256_to_bytes(fee).to_vec()),
rate_limit_count: shared.rate_limit.as_ref().map(|limit| limit.count as i32),
rate_limit_window_secs: shared
.rate_limit
.as_ref()
.map(|limit| limit.window.num_seconds() as i32),
revoked_at: None,
})
.returning(EvmBasicGrant::as_select())
.get_result(conn)
.await
.unwrap()
}
#[rstest]
#[case::matching_chain(CHAIN_ID, false)]
#[case::mismatching_chain(CHAIN_ID + 1, true)]
#[tokio::test]
async fn check_shared_constraints_enforces_chain_id(
#[case] context_chain: u64,
#[case] expect_mismatch: bool,
) {
let db = db::create_test_pool().await;
let mut conn = db.get().await.unwrap();
let context = EvalContext {
chain: context_chain,
..context()
};
let violations = check_shared_constraints(&context, &shared_settings(), 999, &mut *conn)
.await
.unwrap();
assert_eq!(
violations
.iter()
.any(|violation| matches!(violation, EvalViolation::MismatchingChainId { .. })),
expect_mismatch
);
if expect_mismatch {
assert_eq!(violations.len(), 1);
} else {
assert!(violations.is_empty());
}
}
#[rstest]
#[case::valid_from_in_bounds(Some(Utc::now() - Duration::hours(1)), None, false)]
#[case::valid_from_out_of_bounds(Some(Utc::now() + Duration::hours(1)), None, true)]
#[case::valid_until_in_bounds(None, Some(Utc::now() + Duration::hours(1)), false)]
#[case::valid_until_out_of_bounds(None, Some(Utc::now() - Duration::hours(1)), true)]
#[tokio::test]
async fn check_shared_constraints_enforces_validity_window(
#[case] valid_from: Option<chrono::DateTime<Utc>>,
#[case] valid_until: Option<chrono::DateTime<Utc>>,
#[case] expect_invalid_time: bool,
) {
let db = db::create_test_pool().await;
let mut conn = db.get().await.unwrap();
let shared = SharedGrantSettings {
valid_from,
valid_until,
..shared_settings()
};
let violations = check_shared_constraints(&context(), &shared, 999, &mut *conn)
.await
.unwrap();
assert_eq!(
violations
.iter()
.any(|violation| matches!(violation, EvalViolation::InvalidTime)),
expect_invalid_time
);
if expect_invalid_time {
assert_eq!(violations.len(), 1);
} else {
assert!(violations.is_empty());
}
}
#[rstest]
#[case::max_fee_within_limit(Some(U256::from(100u64)), None, 100, 10, false)]
#[case::max_fee_exceeded(Some(U256::from(99u64)), None, 100, 10, true)]
#[case::priority_fee_within_limit(None, Some(U256::from(10u64)), 100, 10, false)]
#[case::priority_fee_exceeded(None, Some(U256::from(9u64)), 100, 10, true)]
#[tokio::test]
async fn check_shared_constraints_enforces_gas_fee_caps(
#[case] max_gas_fee_per_gas: Option<U256>,
#[case] max_priority_fee_per_gas: Option<U256>,
#[case] actual_max_fee_per_gas: u128,
#[case] actual_max_priority_fee_per_gas: u128,
#[case] expect_gas_limit_violation: bool,
) {
let db = db::create_test_pool().await;
let mut conn = db.get().await.unwrap();
let context = EvalContext {
max_fee_per_gas: actual_max_fee_per_gas,
max_priority_fee_per_gas: actual_max_priority_fee_per_gas,
..context()
};
let shared = SharedGrantSettings {
max_gas_fee_per_gas,
max_priority_fee_per_gas,
..shared_settings()
};
let violations = check_shared_constraints(&context, &shared, 999, &mut *conn)
.await
.unwrap();
assert_eq!(
violations
.iter()
.any(|violation| matches!(violation, EvalViolation::GasLimitExceeded { .. })),
expect_gas_limit_violation
);
if expect_gas_limit_violation {
assert_eq!(violations.len(), 1);
} else {
assert!(violations.is_empty());
}
}
#[rstest]
#[case::under_rate_limit(2, false)]
#[case::at_rate_limit(1, true)]
#[tokio::test]
async fn check_shared_constraints_enforces_rate_limit(
#[case] rate_limit_count: u32,
#[case] expect_rate_limit_violation: bool,
) {
let db = db::create_test_pool().await;
let mut conn = db.get().await.unwrap();
let shared = SharedGrantSettings {
rate_limit: Some(TransactionRateLimit {
count: rate_limit_count,
window: Duration::hours(1),
}),
..shared_settings()
};
let basic_grant = insert_basic_grant(&mut conn, &shared).await;
insert_into(evm_transaction_log::table)
.values(NewEvmTransactionLog {
grant_id: basic_grant.id,
wallet_access_id: WALLET_ACCESS_ID,
chain_id: CHAIN_ID.into(),
eth_value: super::utils::u256_to_bytes(U256::ZERO).to_vec(),
signed_at: SqliteTimestamp(Utc::now()),
})
.execute(&mut *conn)
.await
.unwrap();
let violations = check_shared_constraints(&context(), &shared, basic_grant.id, &mut *conn)
.await
.unwrap();
assert_eq!(
violations
.iter()
.any(|violation| matches!(violation, EvalViolation::RateLimitExceeded)),
expect_rate_limit_violation
);
if expect_rate_limit_violation {
assert_eq!(violations.len(), 1);
} else {
assert!(violations.is_empty());
}
}
}

View File

@@ -10,7 +10,8 @@ use diesel_async::{AsyncConnection, RunQueryDsl};
use thiserror::Error; use thiserror::Error;
use crate::{ use crate::{
db::models::{self, EvmBasicGrant, EvmWalletAccess}, crypto::integrity::v1::Integrable,
db::models::{EvmBasicGrant, EvmWalletAccess},
evm::utils, evm::utils,
}; };
@@ -55,6 +56,9 @@ pub enum EvalViolation {
#[error("Transaction type is not allowed by this grant")] #[error("Transaction type is not allowed by this grant")]
InvalidTransactionType, InvalidTransactionType,
#[error("Mismatching chain ID")]
MismatchingChainId { expected: ChainId, actual: ChainId },
} }
pub type DatabaseID = i32; pub type DatabaseID = i32;
@@ -62,13 +66,12 @@ pub type DatabaseID = i32;
#[derive(Debug)] #[derive(Debug)]
pub struct Grant<PolicySettings> { pub struct Grant<PolicySettings> {
pub id: DatabaseID, pub id: DatabaseID,
pub shared_grant_id: DatabaseID, // ID of the basic grant for shared-logic checks like rate limits and validity periods pub common_settings_id: DatabaseID, // ID of the basic grant for shared-logic checks like rate limits and validity periods
pub shared: SharedGrantSettings, pub settings: CombinedSettings<PolicySettings>,
pub settings: PolicySettings,
} }
pub trait Policy: Sized { pub trait Policy: Sized {
type Settings: Send + Sync + 'static + Into<SpecificGrant>; type Settings: Send + Sync + 'static + Into<SpecificGrant> + Integrable;
type Meaning: Display + std::fmt::Debug + Send + Sync + 'static + Into<SpecificMeaning>; type Meaning: Display + std::fmt::Debug + Send + Sync + 'static + Into<SpecificMeaning>;
fn analyze(context: &EvalContext) -> Option<Self::Meaning>; fn analyze(context: &EvalContext) -> Option<Self::Meaning>;
@@ -84,10 +87,10 @@ pub trait Policy: Sized {
// Create a new grant in the database based on the provided grant details, and return its ID // Create a new grant in the database based on the provided grant details, and return its ID
fn create_grant( fn create_grant(
basic: &models::EvmBasicGrant, basic: &EvmBasicGrant,
grant: &Self::Settings, grant: &Self::Settings,
conn: &mut impl AsyncConnection<Backend = Sqlite>, conn: &mut impl AsyncConnection<Backend = Sqlite>,
) -> impl std::future::Future<Output = QueryResult<DatabaseID>> + Send; ) -> impl Future<Output = QueryResult<DatabaseID>> + Send;
// Try to find an existing grant that matches the transaction context, and return its details if found // Try to find an existing grant that matches the transaction context, and return its details if found
// Additionally, return ID of basic grant for shared-logic checks like rate limits and validity periods // Additionally, return ID of basic grant for shared-logic checks like rate limits and validity periods
@@ -124,19 +127,19 @@ pub enum SpecificMeaning {
TokenTransfer(token_transfers::Meaning), TokenTransfer(token_transfers::Meaning),
} }
#[derive(Clone, Debug, PartialEq, Eq, Hash)] #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, arbiter_macros::Hashable)]
pub struct TransactionRateLimit { pub struct TransactionRateLimit {
pub count: u32, pub count: u32,
pub window: Duration, pub window: Duration,
} }
#[derive(Clone, Debug, PartialEq, Eq, Hash)] #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, arbiter_macros::Hashable)]
pub struct VolumeRateLimit { pub struct VolumeRateLimit {
pub max_volume: U256, pub max_volume: U256,
pub window: Duration, pub window: Duration,
} }
#[derive(Clone, Debug, PartialEq, Eq, Hash)] #[derive(Clone, Debug, PartialEq, Eq, Hash, arbiter_macros::Hashable)]
pub struct SharedGrantSettings { pub struct SharedGrantSettings {
pub wallet_access_id: i32, pub wallet_access_id: i32,
pub chain: ChainId, pub chain: ChainId,
@@ -151,10 +154,10 @@ pub struct SharedGrantSettings {
} }
impl SharedGrantSettings { impl SharedGrantSettings {
fn try_from_model(model: EvmBasicGrant) -> QueryResult<Self> { pub(crate) fn try_from_model(model: EvmBasicGrant) -> QueryResult<Self> {
Ok(Self { Ok(Self {
wallet_access_id: model.wallet_access_id, wallet_access_id: model.wallet_access_id,
chain: model.chain_id as u64, // safe because chain_id is stored as i32 but is guaranteed to be a valid ChainId by the API when creating grants chain: model.chain_id.into(),
valid_from: model.valid_from.map(Into::into), valid_from: model.valid_from.map(Into::into),
valid_until: model.valid_until.map(Into::into), valid_until: model.valid_until.map(Into::into),
max_gas_fee_per_gas: model max_gas_fee_per_gas: model
@@ -165,10 +168,11 @@ impl SharedGrantSettings {
.max_priority_fee_per_gas .max_priority_fee_per_gas
.map(|b| utils::try_bytes_to_u256(&b)) .map(|b| utils::try_bytes_to_u256(&b))
.transpose()?, .transpose()?,
#[expect(clippy::cast_sign_loss, clippy::as_conversions, reason = "fixme! #86")]
rate_limit: match (model.rate_limit_count, model.rate_limit_window_secs) { rate_limit: match (model.rate_limit_count, model.rate_limit_window_secs) {
(Some(count), Some(window_secs)) => Some(TransactionRateLimit { (Some(count), Some(window_secs)) => Some(TransactionRateLimit {
count: count as u32, count: count as u32,
window: Duration::seconds(window_secs as i64), window: Duration::seconds(window_secs.into()),
}), }),
_ => None, _ => None,
}, },
@@ -178,7 +182,7 @@ impl SharedGrantSettings {
pub async fn query_by_id( pub async fn query_by_id(
conn: &mut impl AsyncConnection<Backend = Sqlite>, conn: &mut impl AsyncConnection<Backend = Sqlite>,
id: i32, id: i32,
) -> diesel::result::QueryResult<Self> { ) -> QueryResult<Self> {
use crate::db::schema::evm_basic_grant; use crate::db::schema::evm_basic_grant;
let basic_grant: EvmBasicGrant = evm_basic_grant::table let basic_grant: EvmBasicGrant = evm_basic_grant::table
@@ -197,7 +201,22 @@ pub enum SpecificGrant {
TokenTransfer(token_transfers::Settings), TokenTransfer(token_transfers::Settings),
} }
pub struct FullGrant<PolicyGrant> { #[derive(Debug, arbiter_macros::Hashable)]
pub basic: SharedGrantSettings, pub struct CombinedSettings<PolicyGrant> {
pub shared: SharedGrantSettings,
pub specific: PolicyGrant, pub specific: PolicyGrant,
} }
impl<P> CombinedSettings<P> {
pub fn generalize<Y: From<P>>(self) -> CombinedSettings<Y> {
CombinedSettings {
shared: self.shared,
specific: self.specific.into(),
}
}
}
impl<P: Integrable> Integrable for CombinedSettings<P> {
const KIND: &'static str = P::KIND;
const VERSION: i32 = P::VERSION;
}

View File

@@ -4,21 +4,22 @@ use std::fmt::Display;
use alloy::primitives::{Address, U256}; use alloy::primitives::{Address, U256};
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use diesel::dsl::{auto_type, insert_into}; use diesel::dsl::{auto_type, insert_into};
use diesel::prelude::*;
use diesel::sqlite::Sqlite; use diesel::sqlite::Sqlite;
use diesel::{ExpressionMethods, JoinOnDsl, prelude::*};
use diesel_async::{AsyncConnection, RunQueryDsl}; use diesel_async::{AsyncConnection, RunQueryDsl};
use crate::crypto::integrity::v1::Integrable;
use crate::db::models::{ use crate::db::models::{
EvmBasicGrant, EvmEtherTransferGrant, EvmEtherTransferGrantTarget, EvmEtherTransferLimit, EvmBasicGrant, EvmEtherTransferGrant, EvmEtherTransferGrantTarget, EvmEtherTransferLimit,
NewEvmEtherTransferLimit, SqliteTimestamp, NewEvmEtherTransferLimit, SqliteTimestamp,
}; };
use crate::db::schema::{evm_basic_grant, evm_ether_transfer_limit, evm_transaction_log}; use crate::db::schema::{evm_basic_grant, evm_ether_transfer_limit, evm_transaction_log};
use crate::evm::policies::{ use crate::evm::policies::{
Grant, SharedGrantSettings, SpecificGrant, SpecificMeaning, VolumeRateLimit, CombinedSettings, Grant, SharedGrantSettings, SpecificGrant, SpecificMeaning, VolumeRateLimit,
}; };
use crate::{ use crate::{
db::{ db::{
models::{self, NewEvmEtherTransferGrant, NewEvmEtherTransferGrantTarget}, models::{NewEvmEtherTransferGrant, NewEvmEtherTransferGrantTarget},
schema::{evm_ether_transfer_grant, evm_ether_transfer_grant_target}, schema::{evm_ether_transfer_grant, evm_ether_transfer_grant_target},
}, },
evm::{policies::Policy, utils}, evm::{policies::Policy, utils},
@@ -45,21 +46,24 @@ impl Display for Meaning {
} }
} }
impl From<Meaning> for SpecificMeaning { impl From<Meaning> for SpecificMeaning {
fn from(val: Meaning) -> SpecificMeaning { fn from(val: Meaning) -> Self {
SpecificMeaning::EtherTransfer(val) Self::EtherTransfer(val)
} }
} }
// A grant for ether transfers, which can be scoped to specific target addresses and volume limits // A grant for ether transfers, which can be scoped to specific target addresses and volume limits
#[derive(Debug, Clone)] #[derive(Debug, Clone, arbiter_macros::Hashable)]
pub struct Settings { pub struct Settings {
pub target: Vec<Address>, pub target: Vec<Address>,
pub limit: VolumeRateLimit, pub limit: VolumeRateLimit,
} }
impl Integrable for Settings {
const KIND: &'static str = "EtherTransfer";
}
impl From<Settings> for SpecificGrant { impl From<Settings> for SpecificGrant {
fn from(val: Settings) -> SpecificGrant { fn from(val: Settings) -> Self {
SpecificGrant::EtherTransfer(val) Self::EtherTransfer(val)
} }
} }
@@ -70,9 +74,7 @@ async fn query_relevant_past_transaction(
) -> QueryResult<Vec<(U256, DateTime<Utc>)>> { ) -> QueryResult<Vec<(U256, DateTime<Utc>)>> {
let past_transactions: Vec<(Vec<u8>, SqliteTimestamp)> = evm_transaction_log::table let past_transactions: Vec<(Vec<u8>, SqliteTimestamp)> = evm_transaction_log::table
.filter(evm_transaction_log::grant_id.eq(grant_id)) .filter(evm_transaction_log::grant_id.eq(grant_id))
.filter( .filter(evm_transaction_log::signed_at.ge(SqliteTimestamp(Utc::now() - longest_window)))
evm_transaction_log::signed_at.ge(SqliteTimestamp(chrono::Utc::now() - longest_window)),
)
.select(( .select((
evm_transaction_log::eth_value, evm_transaction_log::eth_value,
evm_transaction_log::signed_at, evm_transaction_log::signed_at,
@@ -95,17 +97,17 @@ async fn check_rate_limits(
db: &mut impl AsyncConnection<Backend = Sqlite>, db: &mut impl AsyncConnection<Backend = Sqlite>,
) -> QueryResult<Vec<EvalViolation>> { ) -> QueryResult<Vec<EvalViolation>> {
let mut violations = Vec::new(); let mut violations = Vec::new();
let window = grant.settings.limit.window; let window = grant.settings.specific.limit.window;
let past_transaction = query_relevant_past_transaction(grant.id, window, db).await?; let past_transaction = query_relevant_past_transaction(grant.id, window, db).await?;
let window_start = chrono::Utc::now() - grant.settings.limit.window; let window_start = Utc::now() - grant.settings.specific.limit.window;
let prospective_cumulative_volume: U256 = past_transaction let prospective_cumulative_volume: U256 = past_transaction
.iter() .iter()
.filter(|(_, timestamp)| timestamp >= &window_start) .filter(|(_, timestamp)| timestamp >= &window_start)
.fold(current_transfer_value, |acc, (value, _)| acc + *value); .fold(current_transfer_value, |acc, (value, _)| acc + *value);
if prospective_cumulative_volume > grant.settings.limit.max_volume { if prospective_cumulative_volume > grant.settings.specific.limit.max_volume {
violations.push(EvalViolation::VolumetricLimitExceeded); violations.push(EvalViolation::VolumetricLimitExceeded);
} }
@@ -138,7 +140,7 @@ impl Policy for EtherTransfer {
let mut violations = Vec::new(); let mut violations = Vec::new();
// Check if the target address is within the grant's allowed targets // Check if the target address is within the grant's allowed targets
if !grant.settings.target.contains(&meaning.to) { if !grant.settings.specific.target.contains(&meaning.to) {
violations.push(EvalViolation::InvalidTarget { target: meaning.to }); violations.push(EvalViolation::InvalidTarget { target: meaning.to });
} }
@@ -149,10 +151,15 @@ impl Policy for EtherTransfer {
} }
async fn create_grant( async fn create_grant(
basic: &models::EvmBasicGrant, basic: &EvmBasicGrant,
grant: &Self::Settings, grant: &Self::Settings,
conn: &mut impl AsyncConnection<Backend = Sqlite>, conn: &mut impl AsyncConnection<Backend = Sqlite>,
) -> diesel::result::QueryResult<DatabaseID> { ) -> QueryResult<DatabaseID> {
#[expect(
clippy::cast_possible_truncation,
clippy::as_conversions,
reason = "fixme! #86"
)]
let limit_id: i32 = insert_into(evm_ether_transfer_limit::table) let limit_id: i32 = insert_into(evm_ether_transfer_limit::table)
.values(NewEvmEtherTransferLimit { .values(NewEvmEtherTransferLimit {
window_secs: grant.limit.window.num_seconds() as i32, window_secs: grant.limit.window.num_seconds() as i32,
@@ -187,7 +194,7 @@ impl Policy for EtherTransfer {
async fn try_find_grant( async fn try_find_grant(
context: &EvalContext, context: &EvalContext,
conn: &mut impl AsyncConnection<Backend = Sqlite>, conn: &mut impl AsyncConnection<Backend = Sqlite>,
) -> diesel::result::QueryResult<Option<Grant<Self::Settings>>> { ) -> QueryResult<Option<Grant<Self::Settings>>> {
let target_bytes = context.to.to_vec(); let target_bytes = context.to.to_vec();
// Find a grant where: // Find a grant where:
@@ -241,15 +248,17 @@ impl Policy for EtherTransfer {
limit: VolumeRateLimit { limit: VolumeRateLimit {
max_volume: utils::try_bytes_to_u256(&limit.max_volume) max_volume: utils::try_bytes_to_u256(&limit.max_volume)
.map_err(|err| diesel::result::Error::DeserializationError(Box::new(err)))?, .map_err(|err| diesel::result::Error::DeserializationError(Box::new(err)))?,
window: chrono::Duration::seconds(limit.window_secs as i64), window: Duration::seconds(limit.window_secs.into()),
}, },
}; };
Ok(Some(Grant { Ok(Some(Grant {
id: grant.id, id: grant.id,
shared_grant_id: grant.basic_grant_id, common_settings_id: grant.basic_grant_id,
shared: SharedGrantSettings::try_from_model(basic_grant)?, settings: CombinedSettings {
settings, shared: SharedGrantSettings::try_from_model(basic_grant)?,
specific: settings,
},
})) }))
} }
@@ -259,7 +268,7 @@ impl Policy for EtherTransfer {
_log_id: i32, _log_id: i32,
_grant: &Grant<Self::Settings>, _grant: &Grant<Self::Settings>,
_conn: &mut impl AsyncConnection<Backend = Sqlite>, _conn: &mut impl AsyncConnection<Backend = Sqlite>,
) -> diesel::result::QueryResult<()> { ) -> QueryResult<()> {
// Basic log is sufficient // Basic log is sufficient
Ok(()) Ok(())
@@ -312,7 +321,7 @@ impl Policy for EtherTransfer {
.map(|(basic, specific)| { .map(|(basic, specific)| {
let targets: Vec<Address> = targets_by_grant let targets: Vec<Address> = targets_by_grant
.get(&specific.id) .get(&specific.id)
.map(|v| v.as_slice()) .map(Vec::as_slice)
.unwrap_or_default() .unwrap_or_default()
.iter() .iter()
.filter_map(|t| { .filter_map(|t| {
@@ -327,15 +336,17 @@ impl Policy for EtherTransfer {
Ok(Grant { Ok(Grant {
id: specific.id, id: specific.id,
shared_grant_id: specific.basic_grant_id, common_settings_id: specific.basic_grant_id,
shared: SharedGrantSettings::try_from_model(basic)?, settings: CombinedSettings {
settings: Settings { shared: SharedGrantSettings::try_from_model(basic)?,
target: targets, specific: Settings {
limit: VolumeRateLimit { target: targets,
max_volume: utils::try_bytes_to_u256(&limit.max_volume).map_err( limit: VolumeRateLimit {
|e| diesel::result::Error::DeserializationError(Box::new(e)), max_volume: utils::try_bytes_to_u256(&limit.max_volume).map_err(
)?, |e| diesel::result::Error::DeserializationError(Box::new(e)),
window: Duration::seconds(limit.window_secs as i64), )?,
window: Duration::seconds(limit.window_secs.into()),
},
}, },
}, },
}) })

View File

@@ -11,14 +11,17 @@ use crate::db::{
schema::{evm_basic_grant, evm_transaction_log}, schema::{evm_basic_grant, evm_transaction_log},
}; };
use crate::evm::{ use crate::evm::{
policies::{EvalContext, EvalViolation, Grant, Policy, SharedGrantSettings, VolumeRateLimit}, policies::{
CombinedSettings, EvalContext, EvalViolation, Grant, Policy, SharedGrantSettings,
VolumeRateLimit,
},
utils, utils,
}; };
use super::{EtherTransfer, Settings}; use super::{EtherTransfer, Settings};
const WALLET_ACCESS_ID: i32 = 1; const WALLET_ACCESS_ID: i32 = 1;
const CHAIN_ID: u64 = 1; const CHAIN_ID: alloy::primitives::ChainId = 1;
const ALLOWED: Address = address!("1111111111111111111111111111111111111111"); const ALLOWED: Address = address!("1111111111111111111111111111111111111111");
const OTHER: Address = address!("2222222222222222222222222222222222222222"); const OTHER: Address = address!("2222222222222222222222222222222222222222");
@@ -44,7 +47,7 @@ async fn insert_basic(conn: &mut DatabaseConnection, revoked: bool) -> EvmBasicG
insert_into(evm_basic_grant::table) insert_into(evm_basic_grant::table)
.values(NewEvmBasicGrant { .values(NewEvmBasicGrant {
wallet_access_id: WALLET_ACCESS_ID, wallet_access_id: WALLET_ACCESS_ID,
chain_id: CHAIN_ID as i32, chain_id: CHAIN_ID.into(),
valid_from: None, valid_from: None,
valid_until: None, valid_until: None,
max_gas_fee_per_gas: None, max_gas_fee_per_gas: None,
@@ -81,8 +84,6 @@ fn shared() -> SharedGrantSettings {
} }
} }
// ── analyze ─────────────────────────────────────────────────────────────
#[test] #[test]
fn analyze_matches_empty_calldata() { fn analyze_matches_empty_calldata() {
let m = EtherTransfer::analyze(&ctx(ALLOWED, U256::from(1_000u64))).unwrap(); let m = EtherTransfer::analyze(&ctx(ALLOWED, U256::from(1_000u64))).unwrap();
@@ -99,8 +100,6 @@ fn analyze_rejects_nonempty_calldata() {
assert!(EtherTransfer::analyze(&context).is_none()); assert!(EtherTransfer::analyze(&context).is_none());
} }
// ── evaluate ────────────────────────────────────────────────────────────
#[tokio::test] #[tokio::test]
async fn evaluate_passes_for_allowed_target() { async fn evaluate_passes_for_allowed_target() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
@@ -108,9 +107,11 @@ async fn evaluate_passes_for_allowed_target() {
let grant = Grant { let grant = Grant {
id: 999, id: 999,
shared_grant_id: 999, common_settings_id: 999,
shared: shared(), settings: CombinedSettings {
settings: make_settings(vec![ALLOWED], 1_000_000), shared: shared(),
specific: make_settings(vec![ALLOWED], 1_000_000),
},
}; };
let context = ctx(ALLOWED, U256::from(100u64)); let context = ctx(ALLOWED, U256::from(100u64));
let m = EtherTransfer::analyze(&context).unwrap(); let m = EtherTransfer::analyze(&context).unwrap();
@@ -127,9 +128,11 @@ async fn evaluate_rejects_disallowed_target() {
let grant = Grant { let grant = Grant {
id: 999, id: 999,
shared_grant_id: 999, common_settings_id: 999,
shared: shared(), settings: CombinedSettings {
settings: make_settings(vec![ALLOWED], 1_000_000), shared: shared(),
specific: make_settings(vec![ALLOWED], 1_000_000),
},
}; };
let context = ctx(OTHER, U256::from(100u64)); let context = ctx(OTHER, U256::from(100u64));
let m = EtherTransfer::analyze(&context).unwrap(); let m = EtherTransfer::analyze(&context).unwrap();
@@ -157,7 +160,7 @@ async fn evaluate_passes_when_volume_within_limit() {
.values(NewEvmTransactionLog { .values(NewEvmTransactionLog {
grant_id, grant_id,
wallet_access_id: WALLET_ACCESS_ID, wallet_access_id: WALLET_ACCESS_ID,
chain_id: CHAIN_ID as i32, chain_id: CHAIN_ID.into(),
eth_value: utils::u256_to_bytes(U256::from(500u64)).to_vec(), eth_value: utils::u256_to_bytes(U256::from(500u64)).to_vec(),
signed_at: SqliteTimestamp(Utc::now()), signed_at: SqliteTimestamp(Utc::now()),
}) })
@@ -167,9 +170,11 @@ async fn evaluate_passes_when_volume_within_limit() {
let grant = Grant { let grant = Grant {
id: grant_id, id: grant_id,
shared_grant_id: basic.id, common_settings_id: basic.id,
shared: shared(), settings: CombinedSettings {
settings, shared: shared(),
specific: settings,
},
}; };
let context = ctx(ALLOWED, U256::from(100u64)); let context = ctx(ALLOWED, U256::from(100u64));
let m = EtherTransfer::analyze(&context).unwrap(); let m = EtherTransfer::analyze(&context).unwrap();
@@ -197,7 +202,7 @@ async fn evaluate_rejects_volume_over_limit() {
.values(NewEvmTransactionLog { .values(NewEvmTransactionLog {
grant_id, grant_id,
wallet_access_id: WALLET_ACCESS_ID, wallet_access_id: WALLET_ACCESS_ID,
chain_id: CHAIN_ID as i32, chain_id: CHAIN_ID.into(),
eth_value: utils::u256_to_bytes(U256::from(1_000u64)).to_vec(), eth_value: utils::u256_to_bytes(U256::from(1_000u64)).to_vec(),
signed_at: SqliteTimestamp(Utc::now()), signed_at: SqliteTimestamp(Utc::now()),
}) })
@@ -207,9 +212,11 @@ async fn evaluate_rejects_volume_over_limit() {
let grant = Grant { let grant = Grant {
id: grant_id, id: grant_id,
shared_grant_id: basic.id, common_settings_id: basic.id,
shared: shared(), settings: CombinedSettings {
settings, shared: shared(),
specific: settings,
},
}; };
let context = ctx(ALLOWED, U256::from(1u64)); let context = ctx(ALLOWED, U256::from(1u64));
let m = EtherTransfer::analyze(&context).unwrap(); let m = EtherTransfer::analyze(&context).unwrap();
@@ -238,7 +245,7 @@ async fn evaluate_passes_at_exactly_volume_limit() {
.values(NewEvmTransactionLog { .values(NewEvmTransactionLog {
grant_id, grant_id,
wallet_access_id: WALLET_ACCESS_ID, wallet_access_id: WALLET_ACCESS_ID,
chain_id: CHAIN_ID as i32, chain_id: CHAIN_ID.into(),
eth_value: utils::u256_to_bytes(U256::from(900u64)).to_vec(), eth_value: utils::u256_to_bytes(U256::from(900u64)).to_vec(),
signed_at: SqliteTimestamp(Utc::now()), signed_at: SqliteTimestamp(Utc::now()),
}) })
@@ -248,9 +255,11 @@ async fn evaluate_passes_at_exactly_volume_limit() {
let grant = Grant { let grant = Grant {
id: grant_id, id: grant_id,
shared_grant_id: basic.id, common_settings_id: basic.id,
shared: shared(), settings: CombinedSettings {
settings, shared: shared(),
specific: settings,
},
}; };
let context = ctx(ALLOWED, U256::from(100u64)); let context = ctx(ALLOWED, U256::from(100u64));
let m = EtherTransfer::analyze(&context).unwrap(); let m = EtherTransfer::analyze(&context).unwrap();
@@ -263,8 +272,6 @@ async fn evaluate_passes_at_exactly_volume_limit() {
); );
} }
// ── try_find_grant ───────────────────────────────────────────────────────
#[tokio::test] #[tokio::test]
async fn try_find_grant_roundtrip() { async fn try_find_grant_roundtrip() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
@@ -282,8 +289,11 @@ async fn try_find_grant_roundtrip() {
assert!(found.is_some()); assert!(found.is_some());
let g = found.unwrap(); let g = found.unwrap();
assert_eq!(g.settings.target, vec![ALLOWED]); assert_eq!(g.settings.specific.target, vec![ALLOWED]);
assert_eq!(g.settings.limit.max_volume, U256::from(1_000_000u64)); assert_eq!(
g.settings.specific.limit.max_volume,
U256::from(1_000_000u64)
);
} }
#[tokio::test] #[tokio::test]
@@ -320,7 +330,36 @@ async fn try_find_grant_wrong_target_returns_none() {
assert!(found.is_none()); assert!(found.is_none());
} }
// ── find_all_grants ────────────────────────────────────────────────────── proptest::proptest! {
#[test]
fn target_order_does_not_affect_hash(
raw_addrs in proptest::collection::vec(proptest::prelude::any::<[u8; 20]>(), 0..8),
seed in proptest::prelude::any::<u64>(),
max_volume in proptest::prelude::any::<u64>(),
window_secs in 1i64..=86400,
) {
use rand::{SeedableRng, seq::SliceRandom};
use sha2::Digest;
use arbiter_crypto::hashing::Hashable;
let addrs: Vec<Address> = raw_addrs.iter().map(|b| Address::from(*b)).collect();
let mut shuffled = addrs.clone();
shuffled.shuffle(&mut rand::rngs::StdRng::seed_from_u64(seed));
let limit = VolumeRateLimit {
max_volume: U256::from(max_volume),
window: Duration::seconds(window_secs),
};
let mut h1 = sha2::Sha256::new();
Settings { target: addrs, limit: limit.clone() }.hash(&mut h1);
let mut h2 = sha2::Sha256::new();
Settings { target: shuffled, limit }.hash(&mut h2);
proptest::prop_assert_eq!(h1.finalize(), h2.finalize());
}
}
#[tokio::test] #[tokio::test]
async fn find_all_grants_empty_db() { async fn find_all_grants_empty_db() {
@@ -347,7 +386,7 @@ async fn find_all_grants_excludes_revoked() {
let all = EtherTransfer::find_all_grants(&mut *conn).await.unwrap(); let all = EtherTransfer::find_all_grants(&mut *conn).await.unwrap();
assert_eq!(all.len(), 1); assert_eq!(all.len(), 1);
assert_eq!(all[0].settings.target, vec![ALLOWED]); assert_eq!(all[0].settings.specific.target, vec![ALLOWED]);
} }
#[tokio::test] #[tokio::test]
@@ -363,8 +402,11 @@ async fn find_all_grants_multiple_targets() {
let all = EtherTransfer::find_all_grants(&mut *conn).await.unwrap(); let all = EtherTransfer::find_all_grants(&mut *conn).await.unwrap();
assert_eq!(all.len(), 1); assert_eq!(all.len(), 1);
assert_eq!(all[0].settings.target.len(), 2); assert_eq!(all[0].settings.specific.target.len(), 2);
assert_eq!(all[0].settings.limit.max_volume, U256::from(1_000_000u64)); assert_eq!(
all[0].settings.specific.limit.max_volume,
U256::from(1_000_000u64)
);
} }
#[tokio::test] #[tokio::test]

View File

@@ -1,20 +1,5 @@
use std::collections::HashMap; use std::collections::HashMap;
use alloy::{
primitives::{Address, U256},
sol_types::SolCall,
};
use arbiter_tokens_registry::evm::nonfungible::{self, TokenInfo};
use chrono::{DateTime, Duration, Utc};
use diesel::dsl::{auto_type, insert_into};
use diesel::sqlite::Sqlite;
use diesel::{ExpressionMethods, prelude::*};
use diesel_async::{AsyncConnection, RunQueryDsl};
use crate::db::models::{
EvmBasicGrant, EvmTokenTransferGrant, EvmTokenTransferVolumeLimit, NewEvmTokenTransferGrant,
NewEvmTokenTransferLog, NewEvmTokenTransferVolumeLimit, SqliteTimestamp,
};
use crate::db::schema::{ use crate::db::schema::{
evm_basic_grant, evm_token_transfer_grant, evm_token_transfer_log, evm_basic_grant, evm_token_transfer_grant, evm_token_transfer_log,
evm_token_transfer_volume_limit, evm_token_transfer_volume_limit,
@@ -26,6 +11,25 @@ use crate::evm::{
}, },
utils, utils,
}; };
use crate::{
crypto::integrity::Integrable,
db::models::{
EvmBasicGrant, EvmTokenTransferGrant, EvmTokenTransferVolumeLimit,
NewEvmTokenTransferGrant, NewEvmTokenTransferLog, NewEvmTokenTransferVolumeLimit,
SqliteTimestamp,
},
evm::policies::CombinedSettings,
};
use alloy::{
primitives::{Address, U256},
sol_types::SolCall,
};
use arbiter_tokens_registry::evm::nonfungible::{self, TokenInfo};
use chrono::{DateTime, Duration, Utc};
use diesel::dsl::{auto_type, insert_into};
use diesel::prelude::*;
use diesel::sqlite::Sqlite;
use diesel_async::{AsyncConnection, RunQueryDsl};
use super::{DatabaseID, EvalContext, EvalViolation}; use super::{DatabaseID, EvalContext, EvalViolation};
@@ -38,9 +42,9 @@ fn grant_join() -> _ {
#[derive(Clone, Debug, PartialEq, Eq, Hash)] #[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Meaning { pub struct Meaning {
pub(crate) token: &'static TokenInfo, pub token: &'static TokenInfo,
pub(crate) to: Address, pub to: Address,
pub(crate) value: U256, pub value: U256,
} }
impl std::fmt::Display for Meaning { impl std::fmt::Display for Meaning {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -52,21 +56,25 @@ impl std::fmt::Display for Meaning {
} }
} }
impl From<Meaning> for SpecificMeaning { impl From<Meaning> for SpecificMeaning {
fn from(val: Meaning) -> SpecificMeaning { fn from(val: Meaning) -> Self {
SpecificMeaning::TokenTransfer(val) Self::TokenTransfer(val)
} }
} }
// A grant for token transfers, which can be scoped to specific target addresses and volume limits // A grant for token transfers, which can be scoped to specific target addresses and volume limits
#[derive(Debug, Clone)] #[derive(Debug, Clone, arbiter_macros::Hashable)]
pub struct Settings { pub struct Settings {
pub token_contract: Address, pub token_contract: Address,
pub target: Option<Address>, pub target: Option<Address>,
pub volume_limits: Vec<VolumeRateLimit>, pub volume_limits: Vec<VolumeRateLimit>,
} }
impl Integrable for Settings {
const KIND: &'static str = "TokenTransfer";
}
impl From<Settings> for SpecificGrant { impl From<Settings> for SpecificGrant {
fn from(val: Settings) -> SpecificGrant { fn from(val: Settings) -> Self {
SpecificGrant::TokenTransfer(val) Self::TokenTransfer(val)
} }
} }
@@ -77,10 +85,7 @@ async fn query_relevant_past_transfers(
) -> QueryResult<Vec<(U256, DateTime<Utc>)>> { ) -> QueryResult<Vec<(U256, DateTime<Utc>)>> {
let past_logs: Vec<(Vec<u8>, SqliteTimestamp)> = evm_token_transfer_log::table let past_logs: Vec<(Vec<u8>, SqliteTimestamp)> = evm_token_transfer_log::table
.filter(evm_token_transfer_log::grant_id.eq(grant_id)) .filter(evm_token_transfer_log::grant_id.eq(grant_id))
.filter( .filter(evm_token_transfer_log::created_at.ge(SqliteTimestamp(Utc::now() - longest_window)))
evm_token_transfer_log::created_at
.ge(SqliteTimestamp(chrono::Utc::now() - longest_window)),
)
.select(( .select((
evm_token_transfer_log::value, evm_token_transfer_log::value,
evm_token_transfer_log::created_at, evm_token_transfer_log::created_at,
@@ -106,14 +111,21 @@ async fn check_volume_rate_limits(
) -> QueryResult<Vec<EvalViolation>> { ) -> QueryResult<Vec<EvalViolation>> {
let mut violations = Vec::new(); let mut violations = Vec::new();
let Some(longest_window) = grant.settings.volume_limits.iter().map(|l| l.window).max() else { let Some(longest_window) = grant
.settings
.specific
.volume_limits
.iter()
.map(|l| l.window)
.max()
else {
return Ok(violations); return Ok(violations);
}; };
let past_transfers = query_relevant_past_transfers(grant.id, longest_window, db).await?; let past_transfers = query_relevant_past_transfers(grant.id, longest_window, db).await?;
for limit in &grant.settings.volume_limits { for limit in &grant.settings.specific.volume_limits {
let window_start = chrono::Utc::now() - limit.window; let window_start = Utc::now() - limit.window;
let prospective_cumulative_volume: U256 = past_transfers let prospective_cumulative_volume: U256 = past_transfers
.iter() .iter()
.filter(|(_, timestamp)| timestamp >= &window_start) .filter(|(_, timestamp)| timestamp >= &window_start)
@@ -158,7 +170,7 @@ impl Policy for TokenTransfer {
return Ok(violations); return Ok(violations);
} }
if let Some(allowed) = grant.settings.target if let Some(allowed) = grant.settings.specific.target
&& allowed != meaning.to && allowed != meaning.to
{ {
violations.push(EvalViolation::InvalidTarget { target: meaning.to }); violations.push(EvalViolation::InvalidTarget { target: meaning.to });
@@ -189,6 +201,11 @@ impl Policy for TokenTransfer {
.await?; .await?;
for limit in &grant.volume_limits { for limit in &grant.volume_limits {
#[expect(
clippy::cast_possible_truncation,
clippy::as_conversions,
reason = "fixme! #86"
)]
insert_into(evm_token_transfer_volume_limit::table) insert_into(evm_token_transfer_volume_limit::table)
.values(NewEvmTokenTransferVolumeLimit { .values(NewEvmTokenTransferVolumeLimit {
grant_id, grant_id,
@@ -238,7 +255,7 @@ impl Policy for TokenTransfer {
max_volume: utils::try_bytes_to_u256(&row.max_volume).map_err(|err| { max_volume: utils::try_bytes_to_u256(&row.max_volume).map_err(|err| {
diesel::result::Error::DeserializationError(Box::new(err)) diesel::result::Error::DeserializationError(Box::new(err))
})?, })?,
window: Duration::seconds(row.window_secs as i64), window: Duration::seconds(row.window_secs.into()),
}) })
}) })
.collect::<QueryResult<Vec<_>>>()?; .collect::<QueryResult<Vec<_>>>()?;
@@ -269,9 +286,11 @@ impl Policy for TokenTransfer {
Ok(Some(Grant { Ok(Some(Grant {
id: token_grant.id, id: token_grant.id,
shared_grant_id: token_grant.basic_grant_id, common_settings_id: token_grant.basic_grant_id,
shared: SharedGrantSettings::try_from_model(basic_grant)?, settings: CombinedSettings {
settings, shared: SharedGrantSettings::try_from_model(basic_grant)?,
specific: settings,
},
})) }))
} }
@@ -286,7 +305,7 @@ impl Policy for TokenTransfer {
.values(NewEvmTokenTransferLog { .values(NewEvmTokenTransferLog {
grant_id: grant.id, grant_id: grant.id,
log_id, log_id,
chain_id: context.chain as i32, chain_id: context.chain.into(),
token_contract: context.to.to_vec(), token_contract: context.to.to_vec(),
recipient_address: meaning.to.to_vec(), recipient_address: meaning.to.to_vec(),
value: utils::u256_to_bytes(meaning.value).to_vec(), value: utils::u256_to_bytes(meaning.value).to_vec(),
@@ -335,7 +354,7 @@ impl Policy for TokenTransfer {
.map(|(basic, specific)| { .map(|(basic, specific)| {
let volume_limits: Vec<VolumeRateLimit> = limits_by_grant let volume_limits: Vec<VolumeRateLimit> = limits_by_grant
.get(&specific.id) .get(&specific.id)
.map(|v| v.as_slice()) .map(Vec::as_slice)
.unwrap_or_default() .unwrap_or_default()
.iter() .iter()
.map(|row| { .map(|row| {
@@ -343,7 +362,7 @@ impl Policy for TokenTransfer {
max_volume: utils::try_bytes_to_u256(&row.max_volume).map_err(|e| { max_volume: utils::try_bytes_to_u256(&row.max_volume).map_err(|e| {
diesel::result::Error::DeserializationError(Box::new(e)) diesel::result::Error::DeserializationError(Box::new(e))
})?, })?,
window: Duration::seconds(row.window_secs as i64), window: Duration::seconds(row.window_secs.into()),
}) })
}) })
.collect::<QueryResult<Vec<_>>>()?; .collect::<QueryResult<Vec<_>>>()?;
@@ -369,12 +388,14 @@ impl Policy for TokenTransfer {
Ok(Grant { Ok(Grant {
id: specific.id, id: specific.id,
shared_grant_id: specific.basic_grant_id, common_settings_id: specific.basic_grant_id,
shared: SharedGrantSettings::try_from_model(basic)?, settings: CombinedSettings {
settings: Settings { shared: SharedGrantSettings::try_from_model(basic)?,
token_contract: Address::from(token_contract), specific: Settings {
target, token_contract: Address::from(token_contract),
volume_limits, target,
volume_limits,
},
}, },
}) })
}) })

View File

@@ -11,7 +11,10 @@ use crate::db::{
}; };
use crate::evm::{ use crate::evm::{
abi::IERC20::transferCall, abi::IERC20::transferCall,
policies::{EvalContext, EvalViolation, Grant, Policy, SharedGrantSettings, VolumeRateLimit}, policies::{
CombinedSettings, EvalContext, EvalViolation, Grant, Policy, SharedGrantSettings,
VolumeRateLimit,
},
utils, utils,
}; };
@@ -56,7 +59,7 @@ async fn insert_basic(conn: &mut DatabaseConnection, revoked: bool) -> EvmBasicG
insert_into(evm_basic_grant::table) insert_into(evm_basic_grant::table)
.values(NewEvmBasicGrant { .values(NewEvmBasicGrant {
wallet_access_id: WALLET_ACCESS_ID, wallet_access_id: WALLET_ACCESS_ID,
chain_id: CHAIN_ID as i32, chain_id: CHAIN_ID.into(),
valid_from: None, valid_from: None,
valid_until: None, valid_until: None,
max_gas_fee_per_gas: None, max_gas_fee_per_gas: None,
@@ -98,8 +101,6 @@ fn shared() -> SharedGrantSettings {
} }
} }
// ── analyze ─────────────────────────────────────────────────────────────
#[test] #[test]
fn analyze_known_token_valid_calldata() { fn analyze_known_token_valid_calldata() {
let calldata = transfer_calldata(RECIPIENT, U256::from(100u64)); let calldata = transfer_calldata(RECIPIENT, U256::from(100u64));
@@ -125,8 +126,6 @@ fn analyze_empty_calldata_returns_none() {
assert!(TokenTransfer::analyze(&ctx(DAI, Bytes::new())).is_none()); assert!(TokenTransfer::analyze(&ctx(DAI, Bytes::new())).is_none());
} }
// ── evaluate ────────────────────────────────────────────────────────────
#[tokio::test] #[tokio::test]
async fn evaluate_rejects_nonzero_eth_value() { async fn evaluate_rejects_nonzero_eth_value() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
@@ -134,9 +133,11 @@ async fn evaluate_rejects_nonzero_eth_value() {
let grant = Grant { let grant = Grant {
id: 999, id: 999,
shared_grant_id: 999, common_settings_id: 999,
shared: shared(), settings: CombinedSettings {
settings: make_settings(None, None), shared: shared(),
specific: make_settings(None, None),
},
}; };
let calldata = transfer_calldata(RECIPIENT, U256::from(100u64)); let calldata = transfer_calldata(RECIPIENT, U256::from(100u64));
let mut context = ctx(DAI, calldata); let mut context = ctx(DAI, calldata);
@@ -163,9 +164,11 @@ async fn evaluate_passes_any_recipient_when_no_restriction() {
let grant = Grant { let grant = Grant {
id: 999, id: 999,
shared_grant_id: 999, common_settings_id: 999,
shared: shared(), settings: CombinedSettings {
settings: make_settings(None, None), shared: shared(),
specific: make_settings(None, None),
},
}; };
let calldata = transfer_calldata(RECIPIENT, U256::from(100u64)); let calldata = transfer_calldata(RECIPIENT, U256::from(100u64));
let context = ctx(DAI, calldata); let context = ctx(DAI, calldata);
@@ -183,9 +186,11 @@ async fn evaluate_passes_matching_restricted_recipient() {
let grant = Grant { let grant = Grant {
id: 999, id: 999,
shared_grant_id: 999, common_settings_id: 999,
shared: shared(), settings: CombinedSettings {
settings: make_settings(Some(RECIPIENT), None), shared: shared(),
specific: make_settings(Some(RECIPIENT), None),
},
}; };
let calldata = transfer_calldata(RECIPIENT, U256::from(100u64)); let calldata = transfer_calldata(RECIPIENT, U256::from(100u64));
let context = ctx(DAI, calldata); let context = ctx(DAI, calldata);
@@ -203,9 +208,11 @@ async fn evaluate_rejects_wrong_restricted_recipient() {
let grant = Grant { let grant = Grant {
id: 999, id: 999,
shared_grant_id: 999, common_settings_id: 999,
shared: shared(), settings: CombinedSettings {
settings: make_settings(Some(RECIPIENT), None), shared: shared(),
specific: make_settings(Some(RECIPIENT), None),
},
}; };
let calldata = transfer_calldata(OTHER, U256::from(100u64)); let calldata = transfer_calldata(OTHER, U256::from(100u64));
let context = ctx(DAI, calldata); let context = ctx(DAI, calldata);
@@ -231,12 +238,11 @@ async fn evaluate_passes_volume_at_exact_limit() {
.unwrap(); .unwrap();
// Record a past transfer of 900, with current transfer 100 => exactly 1000 limit // Record a past transfer of 900, with current transfer 100 => exactly 1000 limit
use crate::db::{models::NewEvmTokenTransferLog, schema::evm_token_transfer_log}; insert_into(db::schema::evm_token_transfer_log::table)
insert_into(evm_token_transfer_log::table) .values(db::models::NewEvmTokenTransferLog {
.values(NewEvmTokenTransferLog {
grant_id, grant_id,
log_id: 0, log_id: 0,
chain_id: CHAIN_ID as i32, chain_id: CHAIN_ID.into(),
token_contract: DAI.to_vec(), token_contract: DAI.to_vec(),
recipient_address: RECIPIENT.to_vec(), recipient_address: RECIPIENT.to_vec(),
value: utils::u256_to_bytes(U256::from(900u64)).to_vec(), value: utils::u256_to_bytes(U256::from(900u64)).to_vec(),
@@ -247,9 +253,11 @@ async fn evaluate_passes_volume_at_exact_limit() {
let grant = Grant { let grant = Grant {
id: grant_id, id: grant_id,
shared_grant_id: basic.id, common_settings_id: basic.id,
shared: shared(), settings: CombinedSettings {
settings, shared: shared(),
specific: settings,
},
}; };
let calldata = transfer_calldata(RECIPIENT, U256::from(100u64)); let calldata = transfer_calldata(RECIPIENT, U256::from(100u64));
let context = ctx(DAI, calldata); let context = ctx(DAI, calldata);
@@ -274,12 +282,11 @@ async fn evaluate_rejects_volume_over_limit() {
.await .await
.unwrap(); .unwrap();
use crate::db::{models::NewEvmTokenTransferLog, schema::evm_token_transfer_log}; insert_into(db::schema::evm_token_transfer_log::table)
insert_into(evm_token_transfer_log::table) .values(db::models::NewEvmTokenTransferLog {
.values(NewEvmTokenTransferLog {
grant_id, grant_id,
log_id: 0, log_id: 0,
chain_id: CHAIN_ID as i32, chain_id: CHAIN_ID.into(),
token_contract: DAI.to_vec(), token_contract: DAI.to_vec(),
recipient_address: RECIPIENT.to_vec(), recipient_address: RECIPIENT.to_vec(),
value: utils::u256_to_bytes(U256::from(1_000u64)).to_vec(), value: utils::u256_to_bytes(U256::from(1_000u64)).to_vec(),
@@ -290,9 +297,11 @@ async fn evaluate_rejects_volume_over_limit() {
let grant = Grant { let grant = Grant {
id: grant_id, id: grant_id,
shared_grant_id: basic.id, common_settings_id: basic.id,
shared: shared(), settings: CombinedSettings {
settings, shared: shared(),
specific: settings,
},
}; };
let calldata = transfer_calldata(RECIPIENT, U256::from(1u64)); let calldata = transfer_calldata(RECIPIENT, U256::from(1u64));
let context = ctx(DAI, calldata); let context = ctx(DAI, calldata);
@@ -313,9 +322,11 @@ async fn evaluate_no_volume_limits_always_passes() {
let grant = Grant { let grant = Grant {
id: 999, id: 999,
shared_grant_id: 999, common_settings_id: 999,
shared: shared(), settings: CombinedSettings {
settings: make_settings(None, None), // no volume limits shared: shared(),
specific: make_settings(None, None), // no volume limits
},
}; };
let calldata = transfer_calldata(RECIPIENT, U256::from(u64::MAX)); let calldata = transfer_calldata(RECIPIENT, U256::from(u64::MAX));
let context = ctx(DAI, calldata); let context = ctx(DAI, calldata);
@@ -349,10 +360,13 @@ async fn try_find_grant_roundtrip() {
assert!(found.is_some()); assert!(found.is_some());
let g = found.unwrap(); let g = found.unwrap();
assert_eq!(g.settings.token_contract, DAI); assert_eq!(g.settings.specific.token_contract, DAI);
assert_eq!(g.settings.target, Some(RECIPIENT)); assert_eq!(g.settings.specific.target, Some(RECIPIENT));
assert_eq!(g.settings.volume_limits.len(), 1); assert_eq!(g.settings.specific.volume_limits.len(), 1);
assert_eq!(g.settings.volume_limits[0].max_volume, U256::from(5_000u64)); assert_eq!(
g.settings.specific.volume_limits[0].max_volume,
U256::from(5_000u64)
);
} }
#[tokio::test] #[tokio::test]
@@ -392,7 +406,39 @@ async fn try_find_grant_unknown_token_returns_none() {
assert!(found.is_none()); assert!(found.is_none());
} }
// ── find_all_grants ────────────────────────────────────────────────────── proptest::proptest! {
#[test]
fn volume_limits_order_does_not_affect_hash(
raw_limits in proptest::collection::vec(
(proptest::prelude::any::<u64>(), 1i64..=86400),
0..8,
),
seed in proptest::prelude::any::<u64>(),
) {
use rand::{SeedableRng, seq::SliceRandom};
use sha2::Digest;
use arbiter_crypto::hashing::Hashable;
let limits: Vec<VolumeRateLimit> = raw_limits
.iter()
.map(|(max_vol, window_secs)| VolumeRateLimit {
max_volume: U256::from(*max_vol),
window: Duration::seconds(*window_secs),
})
.collect();
let mut shuffled = limits.clone();
shuffled.shuffle(&mut rand::rngs::StdRng::seed_from_u64(seed));
let mut h1 = sha2::Sha256::new();
Settings { token_contract: DAI, target: None, volume_limits: limits }.hash(&mut h1);
let mut h2 = sha2::Sha256::new();
Settings { token_contract: DAI, target: None, volume_limits: shuffled }.hash(&mut h2);
proptest::prop_assert_eq!(h1.finalize(), h2.finalize());
}
}
#[tokio::test] #[tokio::test]
async fn find_all_grants_empty_db() { async fn find_all_grants_empty_db() {
@@ -434,9 +480,9 @@ async fn find_all_grants_loads_volume_limits() {
let all = TokenTransfer::find_all_grants(&mut *conn).await.unwrap(); let all = TokenTransfer::find_all_grants(&mut *conn).await.unwrap();
assert_eq!(all.len(), 1); assert_eq!(all.len(), 1);
assert_eq!(all[0].settings.volume_limits.len(), 1); assert_eq!(all[0].settings.specific.volume_limits.len(), 1);
assert_eq!( assert_eq!(
all[0].settings.volume_limits[0].max_volume, all[0].settings.specific.volume_limits[0].max_volume,
U256::from(9_999u64) U256::from(9_999u64)
); );
} }

View File

@@ -1,12 +1,12 @@
use std::sync::Mutex; use std::sync::Mutex;
use crate::safe_cell::{SafeCell, SafeCellHandle as _};
use alloy::{ use alloy::{
consensus::SignableTransaction, consensus::SignableTransaction,
network::{TxSigner, TxSignerSync}, network::{TxSigner, TxSignerSync},
primitives::{Address, B256, ChainId, Signature}, primitives::{Address, B256, ChainId, Signature},
signers::{Error, Result, Signer, SignerSync, utils::secret_key_to_address}, signers::{Error, Result, Signer, SignerSync, utils::secret_key_to_address},
}; };
use arbiter_crypto::safecell::{SafeCell, SafeCellHandle as _};
use async_trait::async_trait; use async_trait::async_trait;
use k256::ecdsa::{self, RecoveryId, SigningKey, signature::hazmat::PrehashSigner}; use k256::ecdsa::{self, RecoveryId, SigningKey, signature::hazmat::PrehashSigner};
@@ -82,8 +82,8 @@ impl SafeSigner {
}) })
} }
#[expect(clippy::significant_drop_tightening, reason = "false positive")]
fn sign_hash_inner(&self, hash: &B256) -> Result<Signature> { fn sign_hash_inner(&self, hash: &B256) -> Result<Signature> {
#[allow(clippy::expect_used)]
let mut cell = self.key.lock().expect("SafeSigner mutex poisoned"); let mut cell = self.key.lock().expect("SafeSigner mutex poisoned");
let reader = cell.read(); let reader = cell.read();
let sig: (ecdsa::Signature, RecoveryId) = reader.sign_prehash(hash.as_ref())?; let sig: (ecdsa::Signature, RecoveryId) = reader.sign_prehash(hash.as_ref())?;
@@ -96,7 +96,6 @@ impl SafeSigner {
{ {
return Err(Error::TransactionChainIdMismatch { return Err(Error::TransactionChainIdMismatch {
signer: chain_id, signer: chain_id,
#[allow(clippy::expect_used)]
tx: tx.chain_id().expect("Chain ID is guaranteed to be set"), tx: tx.chain_id().expect("Chain ID is guaranteed to be set"),
}); });
} }

View File

@@ -7,7 +7,7 @@ pub struct LengthError {
pub actual: usize, pub actual: usize,
} }
pub fn u256_to_bytes(value: U256) -> [u8; 32] { pub const fn u256_to_bytes(value: U256) -> [u8; 32] {
value.to_le_bytes() value.to_le_bytes()
} }
pub fn bytes_to_u256(bytes: &[u8]) -> Option<U256> { pub fn bytes_to_u256(bytes: &[u8]) -> Option<U256> {

View File

@@ -98,8 +98,7 @@ pub async fn start(mut conn: ClientConnection, mut bi: GrpcBi<ClientRequest, Cli
Err(err) => { Err(err) => {
let _ = bi let _ = bi
.send(Err(Status::unauthenticated(format!( .send(Err(Status::unauthenticated(format!(
"Authentication failed: {}", "Authentication failed: {err}",
err
)))) ))))
.await; .await;
warn!(error = ?err, "Client authentication failed"); warn!(error = ?err, "Client authentication failed");

View File

@@ -1,11 +1,11 @@
use arbiter_crypto::authn;
use arbiter_proto::{ use arbiter_proto::{
ClientMetadata, ClientMetadata,
proto::{ proto::{
client::{ client::{
ClientRequest, ClientResponse, ClientRequest, ClientResponse,
auth::{ auth::{
self as proto_auth, AuthChallenge as ProtoAuthChallenge, self as proto_auth, AuthChallengeRequest as ProtoAuthChallengeRequest,
AuthChallengeRequest as ProtoAuthChallengeRequest,
AuthChallengeSolution as ProtoAuthChallengeSolution, AuthResult as ProtoAuthResult, AuthChallengeSolution as ProtoAuthChallengeSolution, AuthResult as ProtoAuthResult,
request::Payload as AuthRequestPayload, response::Payload as AuthResponsePayload, request::Payload as AuthRequestPayload, response::Payload as AuthResponsePayload,
}, },
@@ -21,7 +21,7 @@ use tonic::Status;
use tracing::warn; use tracing::warn;
use crate::{ use crate::{
actors::client::{self, ClientConnection, auth}, actors::client::{ClientConnection, auth},
grpc::request_tracker::RequestTracker, grpc::request_tracker::RequestTracker,
}; };
@@ -31,7 +31,7 @@ pub struct AuthTransportAdapter<'a> {
} }
impl<'a> AuthTransportAdapter<'a> { impl<'a> AuthTransportAdapter<'a> {
pub fn new( pub const fn new(
bi: &'a mut GrpcBi<ClientRequest, ClientResponse>, bi: &'a mut GrpcBi<ClientRequest, ClientResponse>,
request_tracker: &'a mut RequestTracker, request_tracker: &'a mut RequestTracker,
) -> Self { ) -> Self {
@@ -41,39 +41,6 @@ impl<'a> AuthTransportAdapter<'a> {
} }
} }
fn response_to_proto(response: auth::Outbound) -> AuthResponsePayload {
match response {
auth::Outbound::AuthChallenge { pubkey, nonce } => {
AuthResponsePayload::Challenge(ProtoAuthChallenge {
pubkey: pubkey.to_bytes().to_vec(),
nonce,
})
}
auth::Outbound::AuthSuccess => {
AuthResponsePayload::Result(ProtoAuthResult::Success.into())
}
}
}
fn error_to_proto(error: auth::Error) -> AuthResponsePayload {
AuthResponsePayload::Result(
match error {
auth::Error::InvalidChallengeSolution => ProtoAuthResult::InvalidSignature,
auth::Error::ApproveError(auth::ApproveError::Denied) => {
ProtoAuthResult::ApprovalDenied
}
auth::Error::ApproveError(auth::ApproveError::Upstream(
crate::actors::flow_coordinator::ApprovalError::NoUserAgentsConnected,
)) => ProtoAuthResult::NoUserAgentsOnline,
auth::Error::ApproveError(auth::ApproveError::Internal)
| auth::Error::DatabasePoolUnavailable
| auth::Error::DatabaseOperationFailed
| auth::Error::Transport => ProtoAuthResult::Internal,
}
.into(),
)
}
async fn send_client_response( async fn send_client_response(
&mut self, &mut self,
payload: AuthResponsePayload, payload: AuthResponsePayload,
@@ -95,14 +62,14 @@ impl<'a> AuthTransportAdapter<'a> {
} }
#[async_trait] #[async_trait]
impl Sender<Result<auth::Outbound, auth::Error>> for AuthTransportAdapter<'_> { impl Sender<Result<auth::Outbound, auth::ClientAuthError>> for AuthTransportAdapter<'_> {
async fn send( async fn send(
&mut self, &mut self,
item: Result<auth::Outbound, auth::Error>, item: Result<auth::Outbound, auth::ClientAuthError>,
) -> Result<(), TransportError> { ) -> Result<(), TransportError> {
let payload = match item { let payload = match item {
Ok(message) => AuthTransportAdapter::response_to_proto(message), Ok(message) => message.into(),
Err(err) => AuthTransportAdapter::error_to_proto(err), Err(err) => AuthResponsePayload::Result(ProtoAuthResult::from(err).into()),
}; };
self.send_client_response(payload).await self.send_client_response(payload).await
@@ -159,11 +126,7 @@ impl Receiver<auth::Inbound> for AuthTransportAdapter<'_> {
.await; .await;
return None; return None;
}; };
let Ok(pubkey) = <[u8; 32]>::try_from(pubkey) else { let Ok(pubkey) = authn::PublicKey::try_from(pubkey.as_slice()) else {
let _ = self.send_auth_result(ProtoAuthResult::InvalidKey).await;
return None;
};
let Ok(pubkey) = ed25519_dalek::VerifyingKey::from_bytes(&pubkey) else {
let _ = self.send_auth_result(ProtoAuthResult::InvalidKey).await; let _ = self.send_auth_result(ProtoAuthResult::InvalidKey).await;
return None; return None;
}; };
@@ -173,7 +136,7 @@ impl Receiver<auth::Inbound> for AuthTransportAdapter<'_> {
}) })
} }
AuthRequestPayload::ChallengeSolution(ProtoAuthChallengeSolution { signature }) => { AuthRequestPayload::ChallengeSolution(ProtoAuthChallengeSolution { signature }) => {
let Ok(signature) = ed25519_dalek::Signature::try_from(signature.as_slice()) else { let Ok(signature) = authn::Signature::try_from(signature.as_slice()) else {
let _ = self let _ = self
.send_auth_result(ProtoAuthResult::InvalidSignature) .send_auth_result(ProtoAuthResult::InvalidSignature)
.await; .await;
@@ -185,7 +148,7 @@ impl Receiver<auth::Inbound> for AuthTransportAdapter<'_> {
} }
} }
impl Bi<auth::Inbound, Result<auth::Outbound, auth::Error>> for AuthTransportAdapter<'_> {} impl Bi<auth::Inbound, Result<auth::Outbound, auth::ClientAuthError>> for AuthTransportAdapter<'_> {}
fn client_metadata_from_proto(metadata: ProtoClientInfo) -> ClientMetadata { fn client_metadata_from_proto(metadata: ProtoClientInfo) -> ClientMetadata {
ClientMetadata { ClientMetadata {
@@ -199,7 +162,7 @@ pub async fn start(
conn: &mut ClientConnection, conn: &mut ClientConnection,
bi: &mut GrpcBi<ClientRequest, ClientResponse>, bi: &mut GrpcBi<ClientRequest, ClientResponse>,
request_tracker: &mut RequestTracker, request_tracker: &mut RequestTracker,
) -> Result<i32, auth::Error> { ) -> Result<i32, auth::ClientAuthError> {
let mut transport = AuthTransportAdapter::new(bi, request_tracker); let mut transport = AuthTransportAdapter::new(bi, request_tracker);
client::auth::authenticate(conn, &mut transport).await auth::authenticate(conn, &mut transport).await
} }

View File

@@ -23,7 +23,7 @@ use crate::{
}, },
}; };
fn wrap_response(payload: EvmResponsePayload) -> ClientResponsePayload { const fn wrap_response(payload: EvmResponsePayload) -> ClientResponsePayload {
ClientResponsePayload::Evm(proto_evm::Response { ClientResponsePayload::Evm(proto_evm::Response {
payload: Some(payload), payload: Some(payload),
}) })

View File

@@ -13,7 +13,7 @@ use tonic::Status;
use tracing::warn; use tracing::warn;
use crate::actors::{ use crate::actors::{
client::session::{ClientSession, Error, HandleQueryVaultState}, client::session::{ClientSession, ClientSessionError, HandleQueryVaultState},
keyholder::KeyHolderState, keyholder::KeyHolderState,
}; };
@@ -28,12 +28,14 @@ pub(super) async fn dispatch(
}; };
match payload { match payload {
VaultRequestPayload::QueryState(_) => { VaultRequestPayload::QueryState(()) => {
let state = match actor.ask(HandleQueryVaultState {}).await { let state = match actor.ask(HandleQueryVaultState {}).await {
Ok(KeyHolderState::Unbootstrapped) => ProtoVaultState::Unbootstrapped, Ok(KeyHolderState::Unbootstrapped) => ProtoVaultState::Unbootstrapped,
Ok(KeyHolderState::Sealed) => ProtoVaultState::Sealed, Ok(KeyHolderState::Sealed) => ProtoVaultState::Sealed,
Ok(KeyHolderState::Unsealed) => ProtoVaultState::Unsealed, Ok(KeyHolderState::Unsealed) => ProtoVaultState::Unsealed,
Err(SendError::HandlerError(Error::Internal)) => ProtoVaultState::Error, Err(SendError::HandlerError(ClientSessionError::Internal)) => {
ProtoVaultState::Error
}
Err(err) => { Err(err) => {
warn!(error = ?err, "Failed to query vault state"); warn!(error = ?err, "Failed to query vault state");
ProtoVaultState::Error ProtoVaultState::Error

View File

@@ -8,7 +8,7 @@ use arbiter_proto::proto::{
EvalViolation as ProtoEvalViolation, GasLimitExceededViolation, NoMatchingGrantError, EvalViolation as ProtoEvalViolation, GasLimitExceededViolation, NoMatchingGrantError,
PolicyViolationsError, SpecificMeaning as ProtoSpecificMeaning, PolicyViolationsError, SpecificMeaning as ProtoSpecificMeaning,
TokenInfo as ProtoTokenInfo, TransactionEvalError as ProtoTransactionEvalError, TokenInfo as ProtoTokenInfo, TransactionEvalError as ProtoTransactionEvalError,
eval_violation::Kind as ProtoEvalViolationKind, eval_violation as proto_eval_violation, eval_violation::Kind as ProtoEvalViolationKind,
specific_meaning::Meaning as ProtoSpecificMeaningKind, specific_meaning::Meaning as ProtoSpecificMeaningKind,
transaction_eval_error::Kind as ProtoTransactionEvalErrorKind, transaction_eval_error::Kind as ProtoTransactionEvalErrorKind,
}, },
@@ -31,16 +31,16 @@ impl Convert for SpecificMeaning {
fn convert(self) -> Self::Output { fn convert(self) -> Self::Output {
let kind = match self { let kind = match self {
SpecificMeaning::EtherTransfer(meaning) => ProtoSpecificMeaningKind::EtherTransfer( Self::EtherTransfer(meaning) => ProtoSpecificMeaningKind::EtherTransfer(
arbiter_proto::proto::shared::evm::EtherTransferMeaning { arbiter_proto::proto::shared::evm::EtherTransferMeaning {
to: meaning.to.to_vec(), to: meaning.to.to_vec(),
value: u256_to_proto_bytes(meaning.value), value: u256_to_proto_bytes(meaning.value),
}, },
), ),
SpecificMeaning::TokenTransfer(meaning) => ProtoSpecificMeaningKind::TokenTransfer( Self::TokenTransfer(meaning) => ProtoSpecificMeaningKind::TokenTransfer(
arbiter_proto::proto::shared::evm::TokenTransferMeaning { arbiter_proto::proto::shared::evm::TokenTransferMeaning {
token: Some(ProtoTokenInfo { token: Some(ProtoTokenInfo {
symbol: meaning.token.symbol.to_string(), symbol: meaning.token.symbol.to_owned(),
address: meaning.token.contract.to_vec(), address: meaning.token.contract.to_vec(),
chain_id: meaning.token.chain, chain_id: meaning.token.chain,
}), }),
@@ -61,23 +61,25 @@ impl Convert for EvalViolation {
fn convert(self) -> Self::Output { fn convert(self) -> Self::Output {
let kind = match self { let kind = match self {
EvalViolation::InvalidTarget { target } => { Self::InvalidTarget { target } => {
ProtoEvalViolationKind::InvalidTarget(target.to_vec()) ProtoEvalViolationKind::InvalidTarget(target.to_vec())
} }
EvalViolation::GasLimitExceeded { Self::GasLimitExceeded {
max_gas_fee_per_gas, max_gas_fee_per_gas,
max_priority_fee_per_gas, max_priority_fee_per_gas,
} => ProtoEvalViolationKind::GasLimitExceeded(GasLimitExceededViolation { } => ProtoEvalViolationKind::GasLimitExceeded(GasLimitExceededViolation {
max_gas_fee_per_gas: max_gas_fee_per_gas.map(u256_to_proto_bytes), max_gas_fee_per_gas: max_gas_fee_per_gas.map(u256_to_proto_bytes),
max_priority_fee_per_gas: max_priority_fee_per_gas.map(u256_to_proto_bytes), max_priority_fee_per_gas: max_priority_fee_per_gas.map(u256_to_proto_bytes),
}), }),
EvalViolation::RateLimitExceeded => ProtoEvalViolationKind::RateLimitExceeded(()), Self::RateLimitExceeded => ProtoEvalViolationKind::RateLimitExceeded(()),
EvalViolation::VolumetricLimitExceeded => { Self::VolumetricLimitExceeded => ProtoEvalViolationKind::VolumetricLimitExceeded(()),
ProtoEvalViolationKind::VolumetricLimitExceeded(()) Self::InvalidTime => ProtoEvalViolationKind::InvalidTime(()),
} Self::InvalidTransactionType => ProtoEvalViolationKind::InvalidTransactionType(()),
EvalViolation::InvalidTime => ProtoEvalViolationKind::InvalidTime(()), Self::MismatchingChainId { expected, actual } => {
EvalViolation::InvalidTransactionType => { ProtoEvalViolationKind::ChainIdMismatch(proto_eval_violation::ChainIdMismatch {
ProtoEvalViolationKind::InvalidTransactionType(()) expected,
actual,
})
} }
}; };
@@ -90,13 +92,13 @@ impl Convert for VetError {
fn convert(self) -> Self::Output { fn convert(self) -> Self::Output {
let kind = match self { let kind = match self {
VetError::ContractCreationNotSupported => { Self::ContractCreationNotSupported => {
ProtoTransactionEvalErrorKind::ContractCreationNotSupported(()) ProtoTransactionEvalErrorKind::ContractCreationNotSupported(())
} }
VetError::UnsupportedTransactionType => { Self::UnsupportedTransactionType => {
ProtoTransactionEvalErrorKind::UnsupportedTransactionType(()) ProtoTransactionEvalErrorKind::UnsupportedTransactionType(())
} }
VetError::Evaluated(meaning, policy_error) => match policy_error { Self::Evaluated(meaning, policy_error) => match policy_error {
PolicyError::NoMatchingGrant => { PolicyError::NoMatchingGrant => {
ProtoTransactionEvalErrorKind::NoMatchingGrant(NoMatchingGrantError { ProtoTransactionEvalErrorKind::NoMatchingGrant(NoMatchingGrantError {
meaning: Some(meaning.convert()), meaning: Some(meaning.convert()),
@@ -108,12 +110,12 @@ impl Convert for VetError {
violations: violations.into_iter().map(Convert::convert).collect(), violations: violations.into_iter().map(Convert::convert).collect(),
}) })
} }
PolicyError::Database(_) => { PolicyError::Database(_) | PolicyError::Integrity(_) => {
return EvmSignTransactionResult::Error(ProtoEvmError::Internal.into()); return EvmSignTransactionResult::Error(ProtoEvmError::Internal.into());
} }
}, },
}; };
EvmSignTransactionResult::EvalError(ProtoTransactionEvalError { kind: Some(kind) }.into()) EvmSignTransactionResult::EvalError(ProtoTransactionEvalError { kind: Some(kind) })
} }
} }

View File

@@ -20,7 +20,7 @@ impl RequestTracker {
// This is used to set the response id for auth responses, which need to match the request id of the auth challenge request. // This is used to set the response id for auth responses, which need to match the request id of the auth challenge request.
// -1 offset is needed because request() increments the next_request_id after returning the current request id. // -1 offset is needed because request() increments the next_request_id after returning the current request id.
pub fn current_request_id(&self) -> i32 { pub const fn current_request_id(&self) -> i32 {
self.next_request_id - 1 self.next_request_id - 1
} }
} }

View File

@@ -1,3 +1,4 @@
use arbiter_crypto::authn;
use arbiter_proto::{ use arbiter_proto::{
proto::user_agent::{ proto::user_agent::{
UserAgentRequest, UserAgentResponse, UserAgentRequest, UserAgentResponse,
@@ -5,8 +6,7 @@ use arbiter_proto::{
self as proto_auth, AuthChallenge as ProtoAuthChallenge, self as proto_auth, AuthChallenge as ProtoAuthChallenge,
AuthChallengeRequest as ProtoAuthChallengeRequest, AuthChallengeRequest as ProtoAuthChallengeRequest,
AuthChallengeSolution as ProtoAuthChallengeSolution, AuthResult as ProtoAuthResult, AuthChallengeSolution as ProtoAuthChallengeSolution, AuthResult as ProtoAuthResult,
KeyType as ProtoKeyType, request::Payload as AuthRequestPayload, request::Payload as AuthRequestPayload, response::Payload as AuthResponsePayload,
response::Payload as AuthResponsePayload,
}, },
user_agent_request::Payload as UserAgentRequestPayload, user_agent_request::Payload as UserAgentRequestPayload,
user_agent_response::Payload as UserAgentResponsePayload, user_agent_response::Payload as UserAgentResponsePayload,
@@ -18,8 +18,7 @@ use tonic::Status;
use tracing::warn; use tracing::warn;
use crate::{ use crate::{
actors::user_agent::{AuthPublicKey, UserAgentConnection, auth}, actors::user_agent::{UserAgentConnection, auth},
db::models::KeyType,
grpc::request_tracker::RequestTracker, grpc::request_tracker::RequestTracker,
}; };
@@ -29,7 +28,7 @@ pub struct AuthTransportAdapter<'a> {
} }
impl<'a> AuthTransportAdapter<'a> { impl<'a> AuthTransportAdapter<'a> {
pub fn new( pub const fn new(
bi: &'a mut GrpcBi<UserAgentRequest, UserAgentResponse>, bi: &'a mut GrpcBi<UserAgentRequest, UserAgentResponse>,
request_tracker: &'a mut RequestTracker, request_tracker: &'a mut RequestTracker,
) -> Self { ) -> Self {
@@ -141,28 +140,9 @@ impl Receiver<auth::Inbound> for AuthTransportAdapter<'_> {
AuthRequestPayload::ChallengeRequest(ProtoAuthChallengeRequest { AuthRequestPayload::ChallengeRequest(ProtoAuthChallengeRequest {
pubkey, pubkey,
bootstrap_token, bootstrap_token,
key_type, ..
}) => { }) => {
let Ok(key_type) = ProtoKeyType::try_from(key_type) else { let Ok(pubkey) = authn::PublicKey::try_from(pubkey.as_slice()) else {
warn!(
event = "received request with invalid key type",
"grpc.useragent.auth_adapter"
);
return None;
};
let key_type = match key_type {
ProtoKeyType::Ed25519 => KeyType::Ed25519,
ProtoKeyType::EcdsaSecp256k1 => KeyType::EcdsaSecp256k1,
ProtoKeyType::Rsa => KeyType::Rsa,
ProtoKeyType::Unspecified => {
warn!(
event = "received request with unspecified key type",
"grpc.useragent.auth_adapter"
);
return None;
}
};
let Ok(pubkey) = AuthPublicKey::try_from((key_type, pubkey)) else {
warn!( warn!(
event = "received request with invalid public key", event = "received request with invalid public key",
"grpc.useragent.auth_adapter" "grpc.useragent.auth_adapter"
@@ -188,7 +168,7 @@ pub async fn start(
conn: &mut UserAgentConnection, conn: &mut UserAgentConnection,
bi: &mut GrpcBi<UserAgentRequest, UserAgentResponse>, bi: &mut GrpcBi<UserAgentRequest, UserAgentResponse>,
request_tracker: &mut RequestTracker, request_tracker: &mut RequestTracker,
) -> Result<AuthPublicKey, auth::Error> { ) -> Result<authn::PublicKey, auth::Error> {
let transport = AuthTransportAdapter::new(bi, request_tracker); let transport = AuthTransportAdapter::new(bi, request_tracker);
auth::authenticate(conn, transport).await auth::authenticate(conn, transport).await
} }

View File

@@ -26,8 +26,8 @@ use crate::{
actors::user_agent::{ actors::user_agent::{
UserAgentSession, UserAgentSession,
session::connection::{ session::connection::{
HandleEvmWalletCreate, HandleEvmWalletList, HandleGrantCreate, HandleGrantDelete, GrantMutationError, HandleEvmWalletCreate, HandleEvmWalletList, HandleGrantCreate,
HandleGrantList, HandleSignTransaction, HandleGrantDelete, HandleGrantList, HandleSignTransaction,
SignTransactionError as SessionSignTransactionError, SignTransactionError as SessionSignTransactionError,
}, },
}, },
@@ -37,7 +37,7 @@ use crate::{
}, },
}; };
fn wrap_evm_response(payload: EvmResponsePayload) -> UserAgentResponsePayload { const fn wrap_evm_response(payload: EvmResponsePayload) -> UserAgentResponsePayload {
UserAgentResponsePayload::Evm(proto_evm::Response { UserAgentResponsePayload::Evm(proto_evm::Response {
payload: Some(payload), payload: Some(payload),
}) })
@@ -52,8 +52,8 @@ pub(super) async fn dispatch(
}; };
match payload { match payload {
EvmRequestPayload::WalletCreate(_) => handle_wallet_create(actor).await, EvmRequestPayload::WalletCreate(()) => handle_wallet_create(actor).await,
EvmRequestPayload::WalletList(_) => handle_wallet_list(actor).await, EvmRequestPayload::WalletList(()) => handle_wallet_list(actor).await,
EvmRequestPayload::GrantCreate(req) => handle_grant_create(actor, req).await, EvmRequestPayload::GrantCreate(req) => handle_grant_create(actor, req).await,
EvmRequestPayload::GrantDelete(req) => handle_grant_delete(actor, req).await, EvmRequestPayload::GrantDelete(req) => handle_grant_delete(actor, req).await,
EvmRequestPayload::GrantList(_) => handle_grant_list(actor).await, EvmRequestPayload::GrantList(_) => handle_grant_list(actor).await,
@@ -114,10 +114,10 @@ async fn handle_grant_list(
grants: grants grants: grants
.into_iter() .into_iter()
.map(|grant| GrantEntry { .map(|grant| GrantEntry {
id: grant.id, id: grant.common_settings_id,
wallet_access_id: grant.shared.wallet_access_id, wallet_access_id: grant.settings.shared.wallet_access_id,
shared: Some(grant.shared.convert()), shared: Some(grant.settings.shared.convert()),
specific: Some(grant.settings.convert()), specific: Some(grant.settings.specific.convert()),
}) })
.collect(), .collect(),
}), }),
@@ -148,6 +148,9 @@ async fn handle_grant_create(
let result = match actor.ask(HandleGrantCreate { basic, grant }).await { let result = match actor.ask(HandleGrantCreate { basic, grant }).await {
Ok(grant_id) => EvmGrantCreateResult::GrantId(grant_id), Ok(grant_id) => EvmGrantCreateResult::GrantId(grant_id),
Err(kameo::error::SendError::HandlerError(GrantMutationError::VaultSealed)) => {
EvmGrantCreateResult::Error(ProtoEvmError::VaultSealed.into())
}
Err(err) => { Err(err) => {
warn!(error = ?err, "Failed to create EVM grant"); warn!(error = ?err, "Failed to create EVM grant");
EvmGrantCreateResult::Error(ProtoEvmError::Internal.into()) EvmGrantCreateResult::Error(ProtoEvmError::Internal.into())
@@ -171,6 +174,9 @@ async fn handle_grant_delete(
.await .await
{ {
Ok(()) => EvmGrantDeleteResult::Ok(()), Ok(()) => EvmGrantDeleteResult::Ok(()),
Err(kameo::error::SendError::HandlerError(GrantMutationError::VaultSealed)) => {
EvmGrantDeleteResult::Error(ProtoEvmError::VaultSealed.into())
}
Err(err) => { Err(err) => {
warn!(error = ?err, "Failed to delete EVM grant"); warn!(error = ?err, "Failed to delete EVM grant");
EvmGrantDeleteResult::Error(ProtoEvmError::Internal.into()) EvmGrantDeleteResult::Error(ProtoEvmError::Internal.into())

View File

@@ -22,11 +22,11 @@ use crate::{
grpc::TryConvert, grpc::TryConvert,
}; };
fn address_from_bytes(bytes: Vec<u8>) -> Result<Address, Status> { fn address_from_bytes(bytes: &[u8]) -> Result<Address, Status> {
if bytes.len() != 20 { if bytes.len() != 20 {
return Err(Status::invalid_argument("Invalid EVM address")); return Err(Status::invalid_argument("Invalid EVM address"));
} }
Ok(Address::from_slice(&bytes)) Ok(Address::from_slice(bytes))
} }
fn u256_from_proto_bytes(bytes: &[u8]) -> Result<U256, Status> { fn u256_from_proto_bytes(bytes: &[u8]) -> Result<U256, Status> {
@@ -41,7 +41,7 @@ impl TryConvert for ProtoTimestamp {
type Error = Status; type Error = Status;
fn try_convert(self) -> Result<DateTime<Utc>, Status> { fn try_convert(self) -> Result<DateTime<Utc>, Status> {
Utc.timestamp_opt(self.seconds, self.nanos as u32) Utc.timestamp_opt(self.seconds, self.nanos.try_into().unwrap_or_default())
.single() .single()
.ok_or_else(|| Status::invalid_argument("Invalid timestamp")) .ok_or_else(|| Status::invalid_argument("Invalid timestamp"))
} }
@@ -116,7 +116,8 @@ impl TryConvert for ProtoSpecificGrant {
limit, limit,
})) => Ok(SpecificGrant::EtherTransfer(ether_transfer::Settings { })) => Ok(SpecificGrant::EtherTransfer(ether_transfer::Settings {
target: targets target: targets
.into_iter() .iter()
.map(Vec::as_slice)
.map(address_from_bytes) .map(address_from_bytes)
.collect::<Result<_, _>>()?, .collect::<Result<_, _>>()?,
limit: limit limit: limit
@@ -130,8 +131,10 @@ impl TryConvert for ProtoSpecificGrant {
target, target,
volume_limits, volume_limits,
})) => Ok(SpecificGrant::TokenTransfer(token_transfers::Settings { })) => Ok(SpecificGrant::TokenTransfer(token_transfers::Settings {
token_contract: address_from_bytes(token_contract)?, token_contract: address_from_bytes(&token_contract)?,
target: target.map(address_from_bytes).transpose()?, target: target
.map(|target| address_from_bytes(&target))
.transpose()?,
volume_limits: volume_limits volume_limits: volume_limits
.into_iter() .into_iter()
.map(ProtoVolumeRateLimit::try_convert) .map(ProtoVolumeRateLimit::try_convert)

View File

@@ -22,7 +22,7 @@ impl Convert for DateTime<Utc> {
fn convert(self) -> ProtoTimestamp { fn convert(self) -> ProtoTimestamp {
ProtoTimestamp { ProtoTimestamp {
seconds: self.timestamp(), seconds: self.timestamp(),
nanos: self.timestamp_subsec_nanos() as i32, nanos: self.timestamp_subsec_nanos().try_into().unwrap_or(i32::MAX),
} }
} }
} }
@@ -74,13 +74,13 @@ impl Convert for SpecificGrant {
fn convert(self) -> ProtoSpecificGrant { fn convert(self) -> ProtoSpecificGrant {
let grant = match self { let grant = match self {
SpecificGrant::EtherTransfer(s) => { Self::EtherTransfer(s) => {
ProtoSpecificGrantType::EtherTransfer(ProtoEtherTransferSettings { ProtoSpecificGrantType::EtherTransfer(ProtoEtherTransferSettings {
targets: s.target.into_iter().map(|a| a.to_vec()).collect(), targets: s.target.into_iter().map(|a| a.to_vec()).collect(),
limit: Some(s.limit.convert()), limit: Some(s.limit.convert()),
}) })
} }
SpecificGrant::TokenTransfer(s) => { Self::TokenTransfer(s) => {
ProtoSpecificGrantType::TokenTransfer(ProtoTokenTransferSettings { ProtoSpecificGrantType::TokenTransfer(ProtoTokenTransferSettings {
token_contract: s.token_contract.to_vec(), token_contract: s.token_contract.to_vec(),
target: s.target.map(|a| a.to_vec()), target: s.target.map(|a| a.to_vec()),

View File

@@ -1,3 +1,4 @@
use arbiter_crypto::authn;
use arbiter_proto::proto::{ use arbiter_proto::proto::{
shared::ClientInfo as ProtoClientMetadata, shared::ClientInfo as ProtoClientMetadata,
user_agent::{ user_agent::{
@@ -31,7 +32,7 @@ use crate::{
grpc::Convert, grpc::Convert,
}; };
fn wrap_sdk_client_response(payload: SdkClientResponsePayload) -> UserAgentResponsePayload { const fn wrap_sdk_client_response(payload: SdkClientResponsePayload) -> UserAgentResponsePayload {
UserAgentResponsePayload::SdkClient(proto_sdk_client::Response { UserAgentResponsePayload::SdkClient(proto_sdk_client::Response {
payload: Some(payload), payload: Some(payload),
}) })
@@ -41,7 +42,7 @@ pub(super) fn out_of_band_payload(oob: OutOfBand) -> UserAgentResponsePayload {
match oob { match oob {
OutOfBand::ClientConnectionRequest { profile } => wrap_sdk_client_response( OutOfBand::ClientConnectionRequest { profile } => wrap_sdk_client_response(
SdkClientResponsePayload::ConnectionRequest(ProtoSdkClientConnectionRequest { SdkClientResponsePayload::ConnectionRequest(ProtoSdkClientConnectionRequest {
pubkey: profile.pubkey.to_bytes().to_vec(), pubkey: profile.pubkey.to_bytes(),
info: Some(ProtoClientMetadata { info: Some(ProtoClientMetadata {
name: profile.metadata.name, name: profile.metadata.name,
description: profile.metadata.description, description: profile.metadata.description,
@@ -51,7 +52,7 @@ pub(super) fn out_of_band_payload(oob: OutOfBand) -> UserAgentResponsePayload {
), ),
OutOfBand::ClientConnectionCancel { pubkey } => wrap_sdk_client_response( OutOfBand::ClientConnectionCancel { pubkey } => wrap_sdk_client_response(
SdkClientResponsePayload::ConnectionCancel(ProtoSdkClientConnectionCancel { SdkClientResponsePayload::ConnectionCancel(ProtoSdkClientConnectionCancel {
pubkey: pubkey.to_bytes().to_vec(), pubkey: pubkey.to_bytes(),
}), }),
), ),
} }
@@ -74,14 +75,14 @@ pub(super) async fn dispatch(
SdkClientRequestPayload::Revoke(_) => Err(Status::unimplemented( SdkClientRequestPayload::Revoke(_) => Err(Status::unimplemented(
"SdkClientRevoke is not yet implemented", "SdkClientRevoke is not yet implemented",
)), )),
SdkClientRequestPayload::List(_) => handle_list(actor).await, SdkClientRequestPayload::List(()) => handle_list(actor).await,
SdkClientRequestPayload::GrantWalletAccess(req) => { SdkClientRequestPayload::GrantWalletAccess(req) => {
handle_grant_wallet_access(actor, req).await handle_grant_wallet_access(actor, req).await
} }
SdkClientRequestPayload::RevokeWalletAccess(req) => { SdkClientRequestPayload::RevokeWalletAccess(req) => {
handle_revoke_wallet_access(actor, req).await handle_revoke_wallet_access(actor, req).await
} }
SdkClientRequestPayload::ListWalletAccess(_) => handle_list_wallet_access(actor).await, SdkClientRequestPayload::ListWalletAccess(()) => handle_list_wallet_access(actor).await,
} }
} }
@@ -89,10 +90,8 @@ async fn handle_connection_response(
actor: &ActorRef<UserAgentSession>, actor: &ActorRef<UserAgentSession>,
resp: ProtoSdkClientConnectionResponse, resp: ProtoSdkClientConnectionResponse,
) -> Result<Option<UserAgentResponsePayload>, Status> { ) -> Result<Option<UserAgentResponsePayload>, Status> {
let pubkey_bytes = <[u8; 32]>::try_from(resp.pubkey) let pubkey = authn::PublicKey::try_from(resp.pubkey.as_slice())
.map_err(|_| Status::invalid_argument("Invalid Ed25519 public key length"))?; .map_err(|()| Status::invalid_argument("Invalid ML-DSA public key"))?;
let pubkey = ed25519_dalek::VerifyingKey::from_bytes(&pubkey_bytes)
.map_err(|_| Status::invalid_argument("Invalid Ed25519 public key"))?;
actor actor
.ask(HandleNewClientApprove { .ask(HandleNewClientApprove {
@@ -117,12 +116,17 @@ async fn handle_list(
.into_iter() .into_iter()
.map(|(client, metadata)| ProtoSdkClientEntry { .map(|(client, metadata)| ProtoSdkClientEntry {
id: client.id, id: client.id,
pubkey: client.public_key, pubkey: client.public_key.clone(),
info: Some(ProtoClientMetadata { info: Some(ProtoClientMetadata {
name: metadata.name, name: metadata.name,
description: metadata.description, description: metadata.description,
version: metadata.version, version: metadata.version,
}), }),
#[expect(
clippy::cast_possible_truncation,
clippy::as_conversions,
reason = "fixme! #84"
)]
created_at: client.created_at.0.timestamp() as i32, created_at: client.created_at.0.timestamp() as i32,
}) })
.collect(), .collect(),
@@ -143,7 +147,7 @@ async fn handle_grant_wallet_access(
actor: &ActorRef<UserAgentSession>, actor: &ActorRef<UserAgentSession>,
req: ProtoSdkClientGrantWalletAccess, req: ProtoSdkClientGrantWalletAccess,
) -> Result<Option<UserAgentResponsePayload>, Status> { ) -> Result<Option<UserAgentResponsePayload>, Status> {
let entries: Vec<NewEvmWalletAccess> = req.accesses.into_iter().map(|a| a.convert()).collect(); let entries: Vec<NewEvmWalletAccess> = req.accesses.into_iter().map(Convert::convert).collect();
match actor.ask(HandleGrantEvmWalletAccess { entries }).await { match actor.ask(HandleGrantEvmWalletAccess { entries }).await {
Ok(()) => { Ok(()) => {
info!("Successfully granted wallet access"); info!("Successfully granted wallet access");
@@ -183,7 +187,7 @@ async fn handle_list_wallet_access(
match actor.ask(HandleListWalletAccess {}).await { match actor.ask(HandleListWalletAccess {}).await {
Ok(accesses) => Ok(Some(wrap_sdk_client_response( Ok(accesses) => Ok(Some(wrap_sdk_client_response(
SdkClientResponsePayload::ListWalletAccess(ListWalletAccessResponse { SdkClientResponsePayload::ListWalletAccess(ListWalletAccessResponse {
accesses: accesses.into_iter().map(|a| a.convert()).collect(), accesses: accesses.into_iter().map(Convert::convert).collect(),
}), }),
))), ))),
Err(err) => { Err(err) => {

View File

@@ -31,13 +31,13 @@ use crate::actors::{
}, },
}; };
fn wrap_vault_response(payload: VaultResponsePayload) -> UserAgentResponsePayload { const fn wrap_vault_response(payload: VaultResponsePayload) -> UserAgentResponsePayload {
UserAgentResponsePayload::Vault(proto_vault::Response { UserAgentResponsePayload::Vault(proto_vault::Response {
payload: Some(payload), payload: Some(payload),
}) })
} }
fn wrap_unseal_response(payload: UnsealResponsePayload) -> UserAgentResponsePayload { const fn wrap_unseal_response(payload: UnsealResponsePayload) -> UserAgentResponsePayload {
wrap_vault_response(VaultResponsePayload::Unseal(proto_unseal::Response { wrap_vault_response(VaultResponsePayload::Unseal(proto_unseal::Response {
payload: Some(payload), payload: Some(payload),
})) }))
@@ -58,7 +58,7 @@ pub(super) async fn dispatch(
}; };
match payload { match payload {
VaultRequestPayload::QueryState(_) => handle_query_vault_state(actor).await, VaultRequestPayload::QueryState(()) => handle_query_vault_state(actor).await,
VaultRequestPayload::Unseal(req) => dispatch_unseal_request(actor, req).await, VaultRequestPayload::Unseal(req) => dispatch_unseal_request(actor, req).await,
VaultRequestPayload::Bootstrap(req) => handle_bootstrap_request(actor, req).await, VaultRequestPayload::Bootstrap(req) => handle_bootstrap_request(actor, req).await,
} }

View File

@@ -1,4 +1,3 @@
#![forbid(unsafe_code)]
use crate::context::ServerContext; use crate::context::ServerContext;
pub mod actors; pub mod actors;
@@ -7,7 +6,6 @@ pub mod crypto;
pub mod db; pub mod db;
pub mod evm; pub mod evm;
pub mod grpc; pub mod grpc;
pub mod safe_cell;
pub mod utils; pub mod utils;
pub struct Server { pub struct Server {
@@ -15,7 +13,7 @@ pub struct Server {
} }
impl Server { impl Server {
pub fn new(context: ServerContext) -> Self { pub const fn new(context: ServerContext) -> Self {
Self { context } Self { context }
} }
} }

View File

@@ -10,6 +10,7 @@ use tracing::info;
const PORT: u16 = 50051; const PORT: u16 = 50051;
#[tokio::main] #[tokio::main]
#[mutants::skip]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
aws_lc_rs::default_provider().install_default().unwrap(); aws_lc_rs::default_provider().install_default().unwrap();

View File

@@ -1,13 +1,21 @@
use arbiter_crypto::{
authn::{self, CLIENT_CONTEXT, format_challenge},
safecell::{SafeCell, SafeCellHandle as _},
};
use arbiter_proto::ClientMetadata; use arbiter_proto::ClientMetadata;
use arbiter_proto::transport::{Receiver, Sender}; use arbiter_proto::transport::{Receiver, Sender};
use arbiter_server::actors::GlobalActors;
use arbiter_server::{ use arbiter_server::{
actors::client::{ClientConnection, auth, connect_client}, actors::{
db, GlobalActors,
client::{ClientConnection, ClientCredentials, auth, connect_client},
keyholder::Bootstrap,
},
crypto::integrity,
db::{self, schema},
}; };
use diesel::{ExpressionMethods as _, NullableExpressionMethods as _, QueryDsl as _, insert_into}; use diesel::{ExpressionMethods as _, NullableExpressionMethods as _, QueryDsl as _, insert_into};
use diesel_async::RunQueryDsl; use diesel_async::RunQueryDsl;
use ed25519_dalek::Signer as _; use ml_dsa::{KeyGen, MlDsa87, SigningKey, VerifyingKey, signature::Keypair};
use super::common::ChannelTransport; use super::common::ChannelTransport;
@@ -19,9 +27,14 @@ fn metadata(name: &str, description: Option<&str>, version: Option<&str>) -> Cli
} }
} }
fn verifying_key(key: &SigningKey<MlDsa87>) -> VerifyingKey<MlDsa87> {
<SigningKey<MlDsa87> as Keypair>::verifying_key(key)
}
async fn insert_registered_client( async fn insert_registered_client(
db: &db::DatabasePool, db: &db::DatabasePool,
pubkey: Vec<u8>, actors: &GlobalActors,
pubkey: VerifyingKey<MlDsa87>,
metadata: &ClientMetadata, metadata: &ClientMetadata,
) { ) {
use arbiter_server::db::schema::{client_metadata, program_client}; use arbiter_server::db::schema::{client_metadata, program_client};
@@ -37,34 +50,90 @@ async fn insert_registered_client(
.get_result(&mut conn) .get_result(&mut conn)
.await .await
.unwrap(); .unwrap();
insert_into(program_client::table) let client_id: i32 = insert_into(program_client::table)
.values(( .values((
program_client::public_key.eq(pubkey), program_client::public_key.eq(pubkey.encode().0.to_vec()),
program_client::metadata_id.eq(metadata_id), program_client::metadata_id.eq(metadata_id),
)) ))
.returning(program_client::id)
.get_result(&mut conn)
.await
.unwrap();
integrity::sign_entity(
&mut conn,
&actors.key_holder,
&ClientCredentials {
pubkey: pubkey.into(),
nonce: 1,
},
client_id,
)
.await
.unwrap();
}
fn sign_client_challenge(
key: &SigningKey<MlDsa87>,
nonce: i32,
pubkey: &authn::PublicKey,
) -> authn::Signature {
let challenge = format_challenge(nonce, &pubkey.to_bytes());
key.signing_key()
.sign_deterministic(&challenge, CLIENT_CONTEXT)
.unwrap()
.into()
}
async fn insert_bootstrap_sentinel_useragent(db: &db::DatabasePool) {
let mut conn = db.get().await.unwrap();
let sentinel_key = verifying_key(&MlDsa87::key_gen(&mut rand::rng()))
.encode()
.0
.to_vec();
insert_into(schema::useragent_client::table)
.values((
schema::useragent_client::public_key.eq(sentinel_key),
schema::useragent_client::key_type.eq(1i32),
))
.execute(&mut conn) .execute(&mut conn)
.await .await
.unwrap(); .unwrap();
} }
async fn spawn_test_actors(db: &db::DatabasePool) -> GlobalActors {
insert_bootstrap_sentinel_useragent(db).await;
let actors = GlobalActors::spawn(db.clone()).await.unwrap();
actors
.key_holder
.ask(Bootstrap {
seal_key_raw: SafeCell::new(b"test-seal-key".to_vec()),
})
.await
.unwrap();
actors
}
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
pub async fn test_unregistered_pubkey_rejected() { pub async fn unregistered_pubkey_rejected() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let (server_transport, mut test_transport) = ChannelTransport::new(); let (server_transport, mut test_transport) = ChannelTransport::new();
let actors = GlobalActors::spawn(db.clone()).await.unwrap(); let actors = spawn_test_actors(&db).await;
let props = ClientConnection::new(db.clone(), actors); let props = ClientConnection::new(db.clone(), actors);
let task = tokio::spawn(async move { let task = tokio::spawn(async move {
let mut server_transport = server_transport; let mut server_transport = server_transport;
connect_client(props, &mut server_transport).await; connect_client(props, &mut server_transport).await;
}); });
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); let new_key = MlDsa87::key_gen(&mut rand::rng());
test_transport test_transport
.send(auth::Inbound::AuthChallengeRequest { .send(auth::Inbound::AuthChallengeRequest {
pubkey: new_key.verifying_key(), pubkey: verifying_key(&new_key).into(),
metadata: metadata("client", Some("desc"), Some("1.0.0")), metadata: metadata("client", Some("desc"), Some("1.0.0")),
}) })
.await .await
@@ -76,22 +145,21 @@ pub async fn test_unregistered_pubkey_rejected() {
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
pub async fn test_challenge_auth() { pub async fn challenge_auth() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let actors = spawn_test_actors(&db).await;
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); let new_key = MlDsa87::key_gen(&mut rand::rng());
let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec();
insert_registered_client( Box::pin(insert_registered_client(
&db, &db,
pubkey_bytes.clone(), &actors,
verifying_key(&new_key),
&metadata("client", Some("desc"), Some("1.0.0")), &metadata("client", Some("desc"), Some("1.0.0")),
) ))
.await; .await;
let (server_transport, mut test_transport) = ChannelTransport::new(); let (server_transport, mut test_transport) = ChannelTransport::new();
let actors = GlobalActors::spawn(db.clone()).await.unwrap();
let props = ClientConnection::new(db.clone(), actors); let props = ClientConnection::new(db.clone(), actors);
let task = tokio::spawn(async move { let task = tokio::spawn(async move {
let mut server_transport = server_transport; let mut server_transport = server_transport;
@@ -101,7 +169,7 @@ pub async fn test_challenge_auth() {
// Send challenge request // Send challenge request
test_transport test_transport
.send(auth::Inbound::AuthChallengeRequest { .send(auth::Inbound::AuthChallengeRequest {
pubkey: new_key.verifying_key(), pubkey: verifying_key(&new_key).into(),
metadata: metadata("client", Some("desc"), Some("1.0.0")), metadata: metadata("client", Some("desc"), Some("1.0.0")),
}) })
.await .await
@@ -115,14 +183,13 @@ pub async fn test_challenge_auth() {
let challenge = match response { let challenge = match response {
Ok(resp) => match resp { Ok(resp) => match resp {
auth::Outbound::AuthChallenge { pubkey, nonce } => (pubkey, nonce), auth::Outbound::AuthChallenge { pubkey, nonce } => (pubkey, nonce),
other => panic!("Expected AuthChallenge, got {other:?}"), other @ auth::Outbound::AuthSuccess => panic!("Expected AuthChallenge, got {other:?}"),
}, },
Err(err) => panic!("Expected Ok response, got Err({err:?})"), Err(err) => panic!("Expected Ok response, got Err({err:?})"),
}; };
// Sign the challenge and send solution // Sign the challenge and send solution
let formatted_challenge = arbiter_proto::format_challenge(challenge.1, challenge.0.as_bytes()); let signature = sign_client_challenge(&new_key, challenge.1, &challenge.0);
let signature = new_key.sign(&formatted_challenge);
test_transport test_transport
.send(auth::Inbound::AuthChallengeSolution { signature }) .send(auth::Inbound::AuthChallengeSolution { signature })
@@ -145,36 +212,21 @@ pub async fn test_challenge_auth() {
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
pub async fn test_metadata_unchanged_does_not_append_history() { pub async fn metadata_unchanged_does_not_append_history() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let actors = GlobalActors::spawn(db.clone()).await.unwrap(); let actors = spawn_test_actors(&db).await;
let props = ClientConnection::new(db.clone(), actors); let new_key = MlDsa87::key_gen(&mut rand::rng());
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
let requested = metadata("client", Some("desc"), Some("1.0.0")); let requested = metadata("client", Some("desc"), Some("1.0.0"));
{ Box::pin(insert_registered_client(
use arbiter_server::db::schema::{client_metadata, program_client}; &db,
let mut conn = db.get().await.unwrap(); &actors,
let metadata_id: i32 = insert_into(client_metadata::table) verifying_key(&new_key),
.values(( &requested,
client_metadata::name.eq(&requested.name), ))
client_metadata::description.eq(&requested.description), .await;
client_metadata::version.eq(&requested.version),
)) let props = ClientConnection::new(db.clone(), actors);
.returning(client_metadata::id)
.get_result(&mut conn)
.await
.unwrap();
insert_into(program_client::table)
.values((
program_client::public_key.eq(new_key.verifying_key().to_bytes().to_vec()),
program_client::metadata_id.eq(metadata_id),
))
.execute(&mut conn)
.await
.unwrap();
}
let (server_transport, mut test_transport) = ChannelTransport::new(); let (server_transport, mut test_transport) = ChannelTransport::new();
let task = tokio::spawn(async move { let task = tokio::spawn(async move {
@@ -184,7 +236,7 @@ pub async fn test_metadata_unchanged_does_not_append_history() {
test_transport test_transport
.send(auth::Inbound::AuthChallengeRequest { .send(auth::Inbound::AuthChallengeRequest {
pubkey: new_key.verifying_key(), pubkey: verifying_key(&new_key).into(),
metadata: requested, metadata: requested,
}) })
.await .await
@@ -193,9 +245,9 @@ pub async fn test_metadata_unchanged_does_not_append_history() {
let response = test_transport.recv().await.unwrap().unwrap(); let response = test_transport.recv().await.unwrap().unwrap();
let (pubkey, nonce) = match response { let (pubkey, nonce) = match response {
auth::Outbound::AuthChallenge { pubkey, nonce } => (pubkey, nonce), auth::Outbound::AuthChallenge { pubkey, nonce } => (pubkey, nonce),
other => panic!("Expected AuthChallenge, got {other:?}"), auth::Outbound::AuthSuccess => panic!("Expected AuthChallenge, got AuthSuccess"),
}; };
let signature = new_key.sign(&arbiter_proto::format_challenge(nonce, pubkey.as_bytes())); let signature = sign_client_challenge(&new_key, nonce, &pubkey);
test_transport test_transport
.send(auth::Inbound::AuthChallengeSolution { signature }) .send(auth::Inbound::AuthChallengeSolution { signature })
.await .await
@@ -223,36 +275,21 @@ pub async fn test_metadata_unchanged_does_not_append_history() {
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
pub async fn test_metadata_change_appends_history_and_repoints_binding() { pub async fn metadata_change_appends_history_and_repoints_binding() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let actors = GlobalActors::spawn(db.clone()).await.unwrap(); let actors = spawn_test_actors(&db).await;
let new_key = MlDsa87::key_gen(&mut rand::rng());
Box::pin(insert_registered_client(
&db,
&actors,
verifying_key(&new_key),
&metadata("client", Some("old"), Some("1.0.0")),
))
.await;
let props = ClientConnection::new(db.clone(), actors); let props = ClientConnection::new(db.clone(), actors);
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng());
{
use arbiter_server::db::schema::{client_metadata, program_client};
let mut conn = db.get().await.unwrap();
let metadata_id: i32 = insert_into(client_metadata::table)
.values((
client_metadata::name.eq("client"),
client_metadata::description.eq(Some("old")),
client_metadata::version.eq(Some("1.0.0")),
))
.returning(client_metadata::id)
.get_result(&mut conn)
.await
.unwrap();
insert_into(program_client::table)
.values((
program_client::public_key.eq(new_key.verifying_key().to_bytes().to_vec()),
program_client::metadata_id.eq(metadata_id),
))
.execute(&mut conn)
.await
.unwrap();
}
let (server_transport, mut test_transport) = ChannelTransport::new(); let (server_transport, mut test_transport) = ChannelTransport::new();
let task = tokio::spawn(async move { let task = tokio::spawn(async move {
let mut server_transport = server_transport; let mut server_transport = server_transport;
@@ -261,7 +298,7 @@ pub async fn test_metadata_change_appends_history_and_repoints_binding() {
test_transport test_transport
.send(auth::Inbound::AuthChallengeRequest { .send(auth::Inbound::AuthChallengeRequest {
pubkey: new_key.verifying_key(), pubkey: verifying_key(&new_key).into(),
metadata: metadata("client", Some("new"), Some("2.0.0")), metadata: metadata("client", Some("new"), Some("2.0.0")),
}) })
.await .await
@@ -270,14 +307,14 @@ pub async fn test_metadata_change_appends_history_and_repoints_binding() {
let response = test_transport.recv().await.unwrap().unwrap(); let response = test_transport.recv().await.unwrap().unwrap();
let (pubkey, nonce) = match response { let (pubkey, nonce) = match response {
auth::Outbound::AuthChallenge { pubkey, nonce } => (pubkey, nonce), auth::Outbound::AuthChallenge { pubkey, nonce } => (pubkey, nonce),
other => panic!("Expected AuthChallenge, got {other:?}"), auth::Outbound::AuthSuccess => panic!("Expected AuthChallenge, got AuthSuccess"),
}; };
let signature = new_key.sign(&arbiter_proto::format_challenge(nonce, pubkey.as_bytes())); let signature = sign_client_challenge(&new_key, nonce, &pubkey);
test_transport test_transport
.send(auth::Inbound::AuthChallengeSolution { signature }) .send(auth::Inbound::AuthChallengeSolution { signature })
.await .await
.unwrap(); .unwrap();
let _ = test_transport.recv().await.unwrap(); drop(test_transport.recv().await.unwrap());
task.await.unwrap(); task.await.unwrap();
{ {
@@ -322,3 +359,62 @@ pub async fn test_metadata_change_appends_history_and_repoints_binding() {
); );
} }
} }
#[tokio::test]
#[test_log::test]
pub async fn challenge_auth_rejects_integrity_tag_mismatch() {
let db = db::create_test_pool().await;
let actors = spawn_test_actors(&db).await;
let new_key = MlDsa87::key_gen(&mut rand::rng());
let requested = metadata("client", Some("desc"), Some("1.0.0"));
{
use arbiter_server::db::schema::{client_metadata, program_client};
let mut conn = db.get().await.unwrap();
let metadata_id: i32 = insert_into(client_metadata::table)
.values((
client_metadata::name.eq(&requested.name),
client_metadata::description.eq(&requested.description),
client_metadata::version.eq(&requested.version),
))
.returning(client_metadata::id)
.get_result(&mut conn)
.await
.unwrap();
insert_into(program_client::table)
.values((
program_client::public_key.eq(verifying_key(&new_key).encode().0.to_vec()),
program_client::metadata_id.eq(metadata_id),
))
.execute(&mut conn)
.await
.unwrap();
}
let (server_transport, mut test_transport) = ChannelTransport::new();
let props = ClientConnection::new(db.clone(), actors);
let task = tokio::spawn(async move {
let mut server_transport = server_transport;
connect_client(props, &mut server_transport).await;
});
test_transport
.send(auth::Inbound::AuthChallengeRequest {
pubkey: verifying_key(&new_key).into(),
metadata: requested,
})
.await
.unwrap();
let response = test_transport
.recv()
.await
.expect("should receive auth rejection");
assert!(matches!(
response,
Err(auth::ClientAuthError::IntegrityCheckFailed)
));
task.await.unwrap();
}

View File

@@ -1,15 +1,16 @@
#![allow(dead_code, reason = "Common test utilities that may not be used in every test")]
use arbiter_crypto::safecell::{SafeCell, SafeCellHandle as _};
use arbiter_proto::transport::{Bi, Error, Receiver, Sender}; use arbiter_proto::transport::{Bi, Error, Receiver, Sender};
use arbiter_server::{ use arbiter_server::{
actors::keyholder::KeyHolder, actors::keyholder::KeyHolder,
db::{self, schema}, db::{self, schema},
safe_cell::{SafeCell, SafeCellHandle as _},
}; };
use async_trait::async_trait; use async_trait::async_trait;
use diesel::QueryDsl; use diesel::QueryDsl;
use diesel_async::RunQueryDsl; use diesel_async::RunQueryDsl;
use tokio::sync::mpsc; use tokio::sync::mpsc;
#[allow(dead_code)]
pub async fn bootstrapped_keyholder(db: &db::DatabasePool) -> KeyHolder { pub async fn bootstrapped_keyholder(db: &db::DatabasePool) -> KeyHolder {
let mut actor = KeyHolder::new(db.clone()).await.unwrap(); let mut actor = KeyHolder::new(db.clone()).await.unwrap();
actor actor
@@ -19,7 +20,6 @@ pub async fn bootstrapped_keyholder(db: &db::DatabasePool) -> KeyHolder {
actor actor
} }
#[allow(dead_code)]
pub async fn root_key_history_id(db: &db::DatabasePool) -> i32 { pub async fn root_key_history_id(db: &db::DatabasePool) -> i32 {
let mut conn = db.get().await.unwrap(); let mut conn = db.get().await.unwrap();
let id = schema::arbiter_settings::table let id = schema::arbiter_settings::table
@@ -30,14 +30,12 @@ pub async fn root_key_history_id(db: &db::DatabasePool) -> i32 {
id.expect("root_key_id should be set after bootstrap") id.expect("root_key_id should be set after bootstrap")
} }
#[allow(dead_code)]
pub struct ChannelTransport<T, Y> { pub struct ChannelTransport<T, Y> {
receiver: mpsc::Receiver<T>, receiver: mpsc::Receiver<T>,
sender: mpsc::Sender<Y>, sender: mpsc::Sender<Y>,
} }
impl<T, Y> ChannelTransport<T, Y> { impl<T, Y> ChannelTransport<T, Y> {
#[allow(dead_code)]
pub fn new() -> (Self, ChannelTransport<Y, T>) { pub fn new() -> (Self, ChannelTransport<Y, T>) {
let (tx1, rx1) = mpsc::channel(10); let (tx1, rx1) = mpsc::channel(10);
let (tx2, rx2) = mpsc::channel(10); let (tx2, rx2) = mpsc::channel(10);

View File

@@ -1,10 +1,11 @@
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use arbiter_crypto::safecell::{SafeCell, SafeCellHandle as _};
use arbiter_server::{ use arbiter_server::{
actors::keyholder::{CreateNew, Error, KeyHolder}, actors::keyholder::{CreateNew, KeyHolder, KeyHolderError},
db::{self, models, schema}, db::{self, models, schema},
safe_cell::{SafeCell, SafeCellHandle as _},
}; };
use diesel::{ExpressionMethods as _, QueryDsl, SelectableHelper, dsl::sql_query}; use diesel::{ExpressionMethods as _, QueryDsl, SelectableHelper, dsl::sql_query};
use diesel_async::RunQueryDsl; use diesel_async::RunQueryDsl;
use kameo::actor::{ActorRef, Spawn as _}; use kameo::actor::{ActorRef, Spawn as _};
@@ -121,7 +122,7 @@ async fn insert_failure_does_not_create_partial_row() {
.create_new(SafeCell::new(b"should fail".to_vec())) .create_new(SafeCell::new(b"should fail".to_vec()))
.await .await
.unwrap_err(); .unwrap_err();
assert!(matches!(err, Error::DatabaseTransaction(_))); assert!(matches!(err, KeyHolderError::DatabaseTransaction(_)));
let mut conn = db.get().await.unwrap(); let mut conn = db.get().await.unwrap();
sql_query("DROP TRIGGER fail_aead_insert;") sql_query("DROP TRIGGER fail_aead_insert;")

View File

@@ -1,9 +1,10 @@
use arbiter_crypto::safecell::{SafeCell, SafeCellHandle as _};
use arbiter_server::{ use arbiter_server::{
actors::keyholder::{Error, KeyHolder}, actors::keyholder::{KeyHolder, KeyHolderError},
crypto::encryption::v1::{Nonce, ROOT_KEY_TAG}, crypto::encryption::v1::{Nonce, ROOT_KEY_TAG},
db::{self, models, schema}, db::{self, models, schema},
safe_cell::{SafeCell, SafeCellHandle as _},
}; };
use diesel::{QueryDsl, SelectableHelper}; use diesel::{QueryDsl, SelectableHelper};
use diesel_async::RunQueryDsl; use diesel_async::RunQueryDsl;
@@ -11,7 +12,7 @@ use crate::common;
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
async fn test_bootstrap() { async fn bootstrap() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let mut actor = KeyHolder::new(db.clone()).await.unwrap(); let mut actor = KeyHolder::new(db.clone()).await.unwrap();
@@ -34,18 +35,18 @@ async fn test_bootstrap() {
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
async fn test_bootstrap_rejects_double() { async fn bootstrap_rejects_double() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let mut actor = common::bootstrapped_keyholder(&db).await; let mut actor = common::bootstrapped_keyholder(&db).await;
let seal_key2 = SafeCell::new(b"test-seal-key".to_vec()); let seal_key2 = SafeCell::new(b"test-seal-key".to_vec());
let err = actor.bootstrap(seal_key2).await.unwrap_err(); let err = actor.bootstrap(seal_key2).await.unwrap_err();
assert!(matches!(err, Error::AlreadyBootstrapped)); assert!(matches!(err, KeyHolderError::AlreadyBootstrapped));
} }
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
async fn test_create_new_before_bootstrap_fails() { async fn create_new_before_bootstrap_fails() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let mut actor = KeyHolder::new(db).await.unwrap(); let mut actor = KeyHolder::new(db).await.unwrap();
@@ -53,34 +54,34 @@ async fn test_create_new_before_bootstrap_fails() {
.create_new(SafeCell::new(b"data".to_vec())) .create_new(SafeCell::new(b"data".to_vec()))
.await .await
.unwrap_err(); .unwrap_err();
assert!(matches!(err, Error::NotBootstrapped)); assert!(matches!(err, KeyHolderError::NotBootstrapped));
} }
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
async fn test_decrypt_before_bootstrap_fails() { async fn decrypt_before_bootstrap_fails() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let mut actor = KeyHolder::new(db).await.unwrap(); let mut actor = KeyHolder::new(db).await.unwrap();
let err = actor.decrypt(1).await.unwrap_err(); let err = actor.decrypt(1).await.unwrap_err();
assert!(matches!(err, Error::NotBootstrapped)); assert!(matches!(err, KeyHolderError::NotBootstrapped));
} }
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
async fn test_new_restores_sealed_state() { async fn new_restores_sealed_state() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let actor = common::bootstrapped_keyholder(&db).await; let actor = common::bootstrapped_keyholder(&db).await;
drop(actor); drop(actor);
let mut actor2 = KeyHolder::new(db).await.unwrap(); let mut actor2 = KeyHolder::new(db).await.unwrap();
let err = actor2.decrypt(1).await.unwrap_err(); let err = actor2.decrypt(1).await.unwrap_err();
assert!(matches!(err, Error::NotBootstrapped)); assert!(matches!(err, KeyHolderError::NotBootstrapped));
} }
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
async fn test_unseal_correct_password() { async fn unseal_correct_password() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let mut actor = common::bootstrapped_keyholder(&db).await; let mut actor = common::bootstrapped_keyholder(&db).await;
@@ -101,7 +102,7 @@ async fn test_unseal_correct_password() {
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
async fn test_unseal_wrong_then_correct_password() { async fn unseal_wrong_then_correct_password() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let mut actor = common::bootstrapped_keyholder(&db).await; let mut actor = common::bootstrapped_keyholder(&db).await;
@@ -116,7 +117,7 @@ async fn test_unseal_wrong_then_correct_password() {
let bad_key = SafeCell::new(b"wrong-password".to_vec()); let bad_key = SafeCell::new(b"wrong-password".to_vec());
let err = actor.try_unseal(bad_key).await.unwrap_err(); let err = actor.try_unseal(bad_key).await.unwrap_err();
assert!(matches!(err, Error::InvalidKey)); assert!(matches!(err, KeyHolderError::InvalidKey));
let good_key = SafeCell::new(b"test-seal-key".to_vec()); let good_key = SafeCell::new(b"test-seal-key".to_vec());
actor.try_unseal(good_key).await.unwrap(); actor.try_unseal(good_key).await.unwrap();

View File

@@ -1,11 +1,12 @@
use std::collections::HashSet; use std::collections::HashSet;
use arbiter_crypto::safecell::{SafeCell, SafeCellHandle as _};
use arbiter_server::{ use arbiter_server::{
actors::keyholder::Error, actors::keyholder::KeyHolderError,
crypto::encryption::v1::Nonce, crypto::encryption::v1::Nonce,
db::{self, models, schema}, db::{self, models, schema},
safe_cell::{SafeCell, SafeCellHandle as _},
}; };
use diesel::{ExpressionMethods as _, QueryDsl, SelectableHelper, dsl::update}; use diesel::{ExpressionMethods as _, QueryDsl, SelectableHelper, dsl::update};
use diesel_async::RunQueryDsl; use diesel_async::RunQueryDsl;
@@ -13,7 +14,7 @@ use crate::common;
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
async fn test_create_decrypt_roundtrip() { async fn create_decrypt_roundtrip() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let mut actor = common::bootstrapped_keyholder(&db).await; let mut actor = common::bootstrapped_keyholder(&db).await;
@@ -29,17 +30,17 @@ async fn test_create_decrypt_roundtrip() {
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
async fn test_decrypt_nonexistent_returns_not_found() { async fn decrypt_nonexistent_returns_not_found() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let mut actor = common::bootstrapped_keyholder(&db).await; let mut actor = common::bootstrapped_keyholder(&db).await;
let err = actor.decrypt(9999).await.unwrap_err(); let err = actor.decrypt(9999).await.unwrap_err();
assert!(matches!(err, Error::NotFound)); assert!(matches!(err, KeyHolderError::NotFound));
} }
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
async fn test_ciphertext_differs_across_entries() { async fn ciphertext_differs_across_entries() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let mut actor = common::bootstrapped_keyholder(&db).await; let mut actor = common::bootstrapped_keyholder(&db).await;
@@ -77,7 +78,7 @@ async fn test_ciphertext_differs_across_entries() {
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
async fn test_nonce_never_reused() { async fn nonce_never_reused() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let mut actor = common::bootstrapped_keyholder(&db).await; let mut actor = common::bootstrapped_keyholder(&db).await;
@@ -141,7 +142,7 @@ async fn broken_db_nonce_format_fails_closed() {
.create_new(SafeCell::new(b"must fail".to_vec())) .create_new(SafeCell::new(b"must fail".to_vec()))
.await .await
.unwrap_err(); .unwrap_err();
assert!(matches!(err, Error::BrokenDatabase)); assert!(matches!(err, KeyHolderError::BrokenDatabase));
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let mut actor = common::bootstrapped_keyholder(&db).await; let mut actor = common::bootstrapped_keyholder(&db).await;
@@ -158,5 +159,5 @@ async fn broken_db_nonce_format_fails_closed() {
drop(conn); drop(conn);
let err = actor.decrypt(id).await.unwrap_err(); let err = actor.decrypt(id).await.unwrap_err();
assert!(matches!(err, Error::BrokenDatabase)); assert!(matches!(err, KeyHolderError::BrokenDatabase));
} }

View File

@@ -1,25 +1,53 @@
use arbiter_crypto::{
authn::{self, USERAGENT_CONTEXT, format_challenge},
safecell::{SafeCell, SafeCellHandle as _},
};
use arbiter_proto::transport::{Receiver, Sender}; use arbiter_proto::transport::{Receiver, Sender};
use arbiter_server::{ use arbiter_server::{
actors::{ actors::{
GlobalActors, GlobalActors,
bootstrap::GetToken, bootstrap::GetToken,
keyholder::Bootstrap, keyholder::Bootstrap,
user_agent::{AuthPublicKey, UserAgentConnection, auth}, user_agent::{UserAgentConnection, UserAgentCredentials, auth},
}, },
crypto::integrity,
db::{self, schema}, db::{self, schema},
safe_cell::{SafeCell, SafeCellHandle as _},
}; };
use diesel::{ExpressionMethods as _, QueryDsl, insert_into}; use diesel::{ExpressionMethods as _, QueryDsl, insert_into};
use diesel_async::RunQueryDsl; use diesel_async::RunQueryDsl;
use ed25519_dalek::Signer as _; use ml_dsa::{KeyGen, MlDsa87, SigningKey, VerifyingKey, signature::Keypair};
use super::common::ChannelTransport; use super::common::ChannelTransport;
fn verifying_key(key: &SigningKey<MlDsa87>) -> VerifyingKey<MlDsa87> {
<SigningKey<MlDsa87> as Keypair>::verifying_key(key)
}
fn sign_useragent_challenge(
key: &SigningKey<MlDsa87>,
nonce: i32,
pubkey_bytes: &[u8],
) -> authn::Signature {
let challenge = format_challenge(nonce, pubkey_bytes);
key.signing_key()
.sign_deterministic(&challenge, USERAGENT_CONTEXT)
.unwrap()
.into()
}
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
pub async fn test_bootstrap_token_auth() { pub async fn bootstrap_token_auth() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let actors = GlobalActors::spawn(db.clone()).await.unwrap(); let actors = GlobalActors::spawn(db.clone()).await.unwrap();
actors
.key_holder
.ask(Bootstrap {
seal_key_raw: SafeCell::new(b"test-seal-key".to_vec()),
})
.await
.unwrap();
let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap(); let token = actors.bootstrapper.ask(GetToken).await.unwrap().unwrap();
let (server_transport, mut test_transport) = ChannelTransport::new(); let (server_transport, mut test_transport) = ChannelTransport::new();
@@ -29,10 +57,10 @@ pub async fn test_bootstrap_token_auth() {
auth::authenticate(&mut props, server_transport).await auth::authenticate(&mut props, server_transport).await
}); });
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); let new_key = MlDsa87::key_gen(&mut rand::rng());
test_transport test_transport
.send(auth::Inbound::AuthChallengeRequest { .send(auth::Inbound::AuthChallengeRequest {
pubkey: AuthPublicKey::Ed25519(new_key.verifying_key()), pubkey: verifying_key(&new_key).into(),
bootstrap_token: Some(token), bootstrap_token: Some(token),
}) })
.await .await
@@ -55,12 +83,12 @@ pub async fn test_bootstrap_token_auth() {
.first::<Vec<u8>>(&mut conn) .first::<Vec<u8>>(&mut conn)
.await .await
.unwrap(); .unwrap();
assert_eq!(stored_pubkey, new_key.verifying_key().to_bytes().to_vec()); assert_eq!(stored_pubkey, verifying_key(&new_key).encode().0.to_vec());
} }
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
pub async fn test_bootstrap_invalid_token_auth() { pub async fn bootstrap_invalid_token_auth() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let actors = GlobalActors::spawn(db.clone()).await.unwrap(); let actors = GlobalActors::spawn(db.clone()).await.unwrap();
@@ -71,11 +99,11 @@ pub async fn test_bootstrap_invalid_token_auth() {
auth::authenticate(&mut props, server_transport).await auth::authenticate(&mut props, server_transport).await
}); });
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); let new_key = MlDsa87::key_gen(&mut rand::rng());
test_transport test_transport
.send(auth::Inbound::AuthChallengeRequest { .send(auth::Inbound::AuthChallengeRequest {
pubkey: AuthPublicKey::Ed25519(new_key.verifying_key()), pubkey: verifying_key(&new_key).into(),
bootstrap_token: Some("invalid_token".to_string()), bootstrap_token: Some("invalid_token".to_owned()),
}) })
.await .await
.unwrap(); .unwrap();
@@ -96,23 +124,42 @@ pub async fn test_bootstrap_invalid_token_auth() {
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
pub async fn test_challenge_auth() { pub async fn challenge_auth() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let actors = GlobalActors::spawn(db.clone()).await.unwrap(); let actors = GlobalActors::spawn(db.clone()).await.unwrap();
actors
.key_holder
.ask(Bootstrap {
seal_key_raw: SafeCell::new(b"test-seal-key".to_vec()),
})
.await
.unwrap();
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); let new_key = MlDsa87::key_gen(&mut rand::rng());
let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); let pubkey_bytes = authn::PublicKey::from(verifying_key(&new_key)).to_bytes();
{ {
let mut conn = db.get().await.unwrap(); let mut conn = db.get().await.unwrap();
insert_into(schema::useragent_client::table) let id: i32 = insert_into(schema::useragent_client::table)
.values(( .values((
schema::useragent_client::public_key.eq(pubkey_bytes.clone()), schema::useragent_client::public_key.eq(pubkey_bytes.clone()),
schema::useragent_client::key_type.eq(1i32), schema::useragent_client::key_type.eq(1i32),
)) ))
.execute(&mut conn) .returning(schema::useragent_client::id)
.get_result(&mut conn)
.await .await
.unwrap(); .unwrap();
integrity::sign_entity(
&mut conn,
&actors.key_holder,
&UserAgentCredentials {
pubkey: verifying_key(&new_key).into(),
nonce: 1,
},
id,
)
.await
.unwrap();
} }
let (server_transport, mut test_transport) = ChannelTransport::new(); let (server_transport, mut test_transport) = ChannelTransport::new();
@@ -124,7 +171,7 @@ pub async fn test_challenge_auth() {
test_transport test_transport
.send(auth::Inbound::AuthChallengeRequest { .send(auth::Inbound::AuthChallengeRequest {
pubkey: AuthPublicKey::Ed25519(new_key.verifying_key()), pubkey: verifying_key(&new_key).into(),
bootstrap_token: None, bootstrap_token: None,
}) })
.await .await
@@ -137,17 +184,16 @@ pub async fn test_challenge_auth() {
let challenge = match response { let challenge = match response {
Ok(resp) => match resp { Ok(resp) => match resp {
auth::Outbound::AuthChallenge { nonce } => nonce, auth::Outbound::AuthChallenge { nonce } => nonce,
other => panic!("Expected AuthChallenge, got {other:?}"), auth::Outbound::AuthSuccess => panic!("Expected AuthChallenge, got AuthSuccess"),
}, },
Err(err) => panic!("Expected Ok response, got Err({err:?})"), Err(err) => panic!("Expected Ok response, got Err({err:?})"),
}; };
let formatted_challenge = arbiter_proto::format_challenge(challenge, &pubkey_bytes); let signature = sign_useragent_challenge(&new_key, challenge, &pubkey_bytes);
let signature = new_key.sign(&formatted_challenge);
test_transport test_transport
.send(auth::Inbound::AuthChallengeSolution { .send(auth::Inbound::AuthChallengeSolution {
signature: signature.to_bytes().to_vec(), signature: signature.to_bytes(),
}) })
.await .await
.unwrap(); .unwrap();
@@ -166,7 +212,7 @@ pub async fn test_challenge_auth() {
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
pub async fn test_challenge_auth_rejects_integrity_tag_mismatch_when_unsealed() { pub async fn challenge_auth_rejects_integrity_tag_mismatch_when_unsealed() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let actors = GlobalActors::spawn(db.clone()).await.unwrap(); let actors = GlobalActors::spawn(db.clone()).await.unwrap();
@@ -178,8 +224,8 @@ pub async fn test_challenge_auth_rejects_integrity_tag_mismatch_when_unsealed()
.await .await
.unwrap(); .unwrap();
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); let new_key = MlDsa87::key_gen(&mut rand::rng());
let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); let pubkey_bytes = authn::PublicKey::from(verifying_key(&new_key)).to_bytes();
{ {
let mut conn = db.get().await.unwrap(); let mut conn = db.get().await.unwrap();
@@ -187,7 +233,6 @@ pub async fn test_challenge_auth_rejects_integrity_tag_mismatch_when_unsealed()
.values(( .values((
schema::useragent_client::public_key.eq(pubkey_bytes.clone()), schema::useragent_client::public_key.eq(pubkey_bytes.clone()),
schema::useragent_client::key_type.eq(1i32), schema::useragent_client::key_type.eq(1i32),
schema::useragent_client::pubkey_integrity_tag.eq(Some(vec![0u8; 32])),
)) ))
.execute(&mut conn) .execute(&mut conn)
.await .await
@@ -203,7 +248,7 @@ pub async fn test_challenge_auth_rejects_integrity_tag_mismatch_when_unsealed()
test_transport test_transport
.send(auth::Inbound::AuthChallengeRequest { .send(auth::Inbound::AuthChallengeRequest {
pubkey: AuthPublicKey::Ed25519(new_key.verifying_key()), pubkey: verifying_key(&new_key).into(),
bootstrap_token: None, bootstrap_token: None,
}) })
.await .await
@@ -211,29 +256,48 @@ pub async fn test_challenge_auth_rejects_integrity_tag_mismatch_when_unsealed()
assert!(matches!( assert!(matches!(
task.await.unwrap(), task.await.unwrap(),
Err(auth::Error::InvalidChallengeSolution) Err(auth::Error::Internal { .. })
)); ));
} }
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
pub async fn test_challenge_auth_rejects_invalid_signature() { pub async fn challenge_auth_rejects_invalid_signature() {
let db = db::create_test_pool().await; let db = db::create_test_pool().await;
let actors = GlobalActors::spawn(db.clone()).await.unwrap(); let actors = GlobalActors::spawn(db.clone()).await.unwrap();
actors
.key_holder
.ask(Bootstrap {
seal_key_raw: SafeCell::new(b"test-seal-key".to_vec()),
})
.await
.unwrap();
let new_key = ed25519_dalek::SigningKey::generate(&mut rand::rng()); let new_key = MlDsa87::key_gen(&mut rand::rng());
let pubkey_bytes = new_key.verifying_key().to_bytes().to_vec(); let pubkey_bytes = authn::PublicKey::from(verifying_key(&new_key)).to_bytes();
{ {
let mut conn = db.get().await.unwrap(); let mut conn = db.get().await.unwrap();
insert_into(schema::useragent_client::table) let id: i32 = insert_into(schema::useragent_client::table)
.values(( .values((
schema::useragent_client::public_key.eq(pubkey_bytes.clone()), schema::useragent_client::public_key.eq(pubkey_bytes.clone()),
schema::useragent_client::key_type.eq(1i32), schema::useragent_client::key_type.eq(1i32),
)) ))
.execute(&mut conn) .returning(schema::useragent_client::id)
.get_result(&mut conn)
.await .await
.unwrap(); .unwrap();
integrity::sign_entity(
&mut conn,
&actors.key_holder,
&UserAgentCredentials {
pubkey: verifying_key(&new_key).into(),
nonce: 1,
},
id,
)
.await
.unwrap();
} }
let (server_transport, mut test_transport) = ChannelTransport::new(); let (server_transport, mut test_transport) = ChannelTransport::new();
@@ -245,7 +309,7 @@ pub async fn test_challenge_auth_rejects_invalid_signature() {
test_transport test_transport
.send(auth::Inbound::AuthChallengeRequest { .send(auth::Inbound::AuthChallengeRequest {
pubkey: AuthPublicKey::Ed25519(new_key.verifying_key()), pubkey: verifying_key(&new_key).into(),
bootstrap_token: None, bootstrap_token: None,
}) })
.await .await
@@ -258,17 +322,16 @@ pub async fn test_challenge_auth_rejects_invalid_signature() {
let challenge = match response { let challenge = match response {
Ok(resp) => match resp { Ok(resp) => match resp {
auth::Outbound::AuthChallenge { nonce } => nonce, auth::Outbound::AuthChallenge { nonce } => nonce,
other => panic!("Expected AuthChallenge, got {other:?}"), auth::Outbound::AuthSuccess => panic!("Expected AuthChallenge, got AuthSuccess"),
}, },
Err(err) => panic!("Expected Ok response, got Err({err:?})"), Err(err) => panic!("Expected Ok response, got Err({err:?})"),
}; };
let wrong_challenge = arbiter_proto::format_challenge(challenge + 1, &pubkey_bytes); let signature = sign_useragent_challenge(&new_key, challenge + 1, &pubkey_bytes);
let signature = new_key.sign(&wrong_challenge);
test_transport test_transport
.send(auth::Inbound::AuthChallengeSolution { .send(auth::Inbound::AuthChallengeSolution {
signature: signature.to_bytes().to_vec(), signature: signature.to_bytes(),
}) })
.await .await
.unwrap(); .unwrap();

View File

@@ -1,3 +1,4 @@
use arbiter_crypto::safecell::{SafeCell, SafeCellHandle as _};
use arbiter_server::{ use arbiter_server::{
actors::{ actors::{
GlobalActors, GlobalActors,
@@ -8,11 +9,9 @@ use arbiter_server::{
}, },
}, },
db, db,
safe_cell::{SafeCell, SafeCellHandle as _},
}; };
use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit}; use chacha20poly1305::{AeadInPlace, XChaCha20Poly1305, XNonce, aead::KeyInit};
use diesel::{ExpressionMethods as _, QueryDsl as _, insert_into};
use diesel_async::RunQueryDsl;
use kameo::actor::Spawn as _; use kameo::actor::Spawn as _;
use x25519_dalek::{EphemeralSecret, PublicKey}; use x25519_dalek::{EphemeralSecret, PublicKey};
@@ -70,7 +69,7 @@ async fn client_dh_encrypt(
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
pub async fn test_unseal_success() { pub async fn unseal_success() {
let seal_key = b"test-seal-key"; let seal_key = b"test-seal-key";
let (_db, user_agent) = setup_sealed_user_agent(seal_key).await; let (_db, user_agent) = setup_sealed_user_agent(seal_key).await;
@@ -82,7 +81,7 @@ pub async fn test_unseal_success() {
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
pub async fn test_unseal_wrong_seal_key() { pub async fn unseal_wrong_seal_key() {
let (_db, user_agent) = setup_sealed_user_agent(b"correct-key").await; let (_db, user_agent) = setup_sealed_user_agent(b"correct-key").await;
let encrypted_key = client_dh_encrypt(&user_agent, b"wrong-key").await; let encrypted_key = client_dh_encrypt(&user_agent, b"wrong-key").await;
@@ -98,7 +97,7 @@ pub async fn test_unseal_wrong_seal_key() {
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
pub async fn test_unseal_corrupted_ciphertext() { pub async fn unseal_corrupted_ciphertext() {
let (_db, user_agent) = setup_sealed_user_agent(b"test-key").await; let (_db, user_agent) = setup_sealed_user_agent(b"test-key").await;
let client_secret = EphemeralSecret::random(); let client_secret = EphemeralSecret::random();
@@ -129,7 +128,7 @@ pub async fn test_unseal_corrupted_ciphertext() {
#[tokio::test] #[tokio::test]
#[test_log::test] #[test_log::test]
pub async fn test_unseal_retry_after_invalid_key() { pub async fn unseal_retry_after_invalid_key() {
let seal_key = b"real-seal-key"; let seal_key = b"real-seal-key";
let (_db, user_agent) = setup_sealed_user_agent(seal_key).await; let (_db, user_agent) = setup_sealed_user_agent(seal_key).await;
@@ -152,42 +151,3 @@ pub async fn test_unseal_retry_after_invalid_key() {
assert!(matches!(response, Ok(()))); assert!(matches!(response, Ok(())));
} }
} }
#[tokio::test]
#[test_log::test]
pub async fn test_unseal_backfills_missing_pubkey_integrity_tags() {
let seal_key = b"test-seal-key";
let (db, user_agent) = setup_sealed_user_agent(seal_key).await;
{
let mut conn = db.get().await.unwrap();
insert_into(arbiter_server::db::schema::useragent_client::table)
.values((
arbiter_server::db::schema::useragent_client::public_key
.eq(vec![1u8, 2u8, 3u8, 4u8]),
arbiter_server::db::schema::useragent_client::key_type.eq(1i32),
arbiter_server::db::schema::useragent_client::pubkey_integrity_tag
.eq(Option::<Vec<u8>>::None),
))
.execute(&mut conn)
.await
.unwrap();
}
let encrypted_key = client_dh_encrypt(&user_agent, seal_key).await;
let response = user_agent.ask(encrypted_key).await;
assert!(matches!(response, Ok(())));
{
let mut conn = db.get().await.unwrap();
let tags: Vec<Option<Vec<u8>>> = arbiter_server::db::schema::useragent_client::table
.select(arbiter_server::db::schema::useragent_client::pubkey_integrity_tag)
.load(&mut conn)
.await
.unwrap();
assert!(
tags.iter()
.all(|tag| matches!(tag, Some(v) if v.len() == 32))
);
}
}