refactor(hashing): introduce Hashable derive macro and migrate server types
This commit is contained in:
18
server/crates/arbiter-macros/Cargo.toml
Normal file
18
server/crates/arbiter-macros/Cargo.toml
Normal file
@@ -0,0 +1,18 @@
|
||||
[package]
|
||||
name = "arbiter-macros"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[lib]
|
||||
proc-macro = true
|
||||
|
||||
[dependencies]
|
||||
proc-macro2 = "1.0"
|
||||
quote = "1.0"
|
||||
syn = { version = "2.0", features = ["derive", "fold", "full", "visit-mut"] }
|
||||
|
||||
[dev-dependencies]
|
||||
arbiter-crypto = { path = "../arbiter-crypto" }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
133
server/crates/arbiter-macros/src/hashable.rs
Normal file
133
server/crates/arbiter-macros/src/hashable.rs
Normal file
@@ -0,0 +1,133 @@
|
||||
use proc_macro2::{Span, TokenStream, TokenTree};
|
||||
use quote::quote;
|
||||
use syn::parse_quote;
|
||||
use syn::spanned::Spanned;
|
||||
use syn::{DataStruct, DeriveInput, Fields, Generics, Index};
|
||||
|
||||
use crate::utils::{HASHABLE_TRAIT_PATH, HMAC_DIGEST_PATH};
|
||||
|
||||
pub(crate) fn derive(input: &DeriveInput) -> TokenStream {
|
||||
match &input.data {
|
||||
syn::Data::Struct(struct_data) => hashable_struct(input, struct_data),
|
||||
syn::Data::Enum(_) => {
|
||||
syn::Error::new_spanned(input, "Hashable can currently be derived only for structs")
|
||||
.to_compile_error()
|
||||
}
|
||||
syn::Data::Union(_) => {
|
||||
syn::Error::new_spanned(input, "Hashable cannot be derived for unions")
|
||||
.to_compile_error()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn hashable_struct(input: &DeriveInput, struct_data: &syn::DataStruct) -> TokenStream {
|
||||
let ident = &input.ident;
|
||||
let hashable_trait = HASHABLE_TRAIT_PATH.to_path();
|
||||
let hmac_digest = HMAC_DIGEST_PATH.to_path();
|
||||
let generics = add_hashable_bounds(input.generics.clone(), &hashable_trait);
|
||||
let field_accesses = collect_field_accesses(struct_data);
|
||||
let hash_calls = build_hash_calls(&field_accesses, &hashable_trait);
|
||||
|
||||
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
|
||||
|
||||
quote! {
|
||||
#[automatically_derived]
|
||||
impl #impl_generics #hashable_trait for #ident #ty_generics #where_clause {
|
||||
fn hash<H: #hmac_digest>(&self, hasher: &mut H) {
|
||||
#(#hash_calls)*
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn add_hashable_bounds(mut generics: Generics, hashable_trait: &syn::Path) -> Generics {
|
||||
for type_param in generics.type_params_mut() {
|
||||
type_param.bounds.push(parse_quote!(#hashable_trait));
|
||||
}
|
||||
|
||||
generics
|
||||
}
|
||||
|
||||
struct FieldAccess {
|
||||
access: TokenStream,
|
||||
span: Span,
|
||||
}
|
||||
|
||||
fn collect_field_accesses(struct_data: &DataStruct) -> Vec<FieldAccess> {
|
||||
match &struct_data.fields {
|
||||
Fields::Named(fields) => {
|
||||
// Keep deterministic alphabetical order for named fields.
|
||||
// Do not remove this sort, because it keeps hash output stable regardless of source order.
|
||||
let mut named_fields = fields
|
||||
.named
|
||||
.iter()
|
||||
.map(|field| {
|
||||
let name = field
|
||||
.ident
|
||||
.as_ref()
|
||||
.expect("Fields::Named(fields) must have names")
|
||||
.clone();
|
||||
(name.to_string(), name)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
named_fields.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
|
||||
named_fields
|
||||
.into_iter()
|
||||
.map(|(_, name)| FieldAccess {
|
||||
access: quote! { #name },
|
||||
span: name.span(),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
Fields::Unnamed(fields) => fields
|
||||
.unnamed
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, field)| FieldAccess {
|
||||
access: {
|
||||
let index = Index::from(i);
|
||||
quote! { #index }
|
||||
},
|
||||
span: field.ty.span(),
|
||||
})
|
||||
.collect(),
|
||||
Fields::Unit => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_hash_calls(
|
||||
field_accesses: &[FieldAccess],
|
||||
hashable_trait: &syn::Path,
|
||||
) -> Vec<TokenStream> {
|
||||
field_accesses
|
||||
.iter()
|
||||
.map(|field| {
|
||||
let access = &field.access;
|
||||
let call = quote! {
|
||||
#hashable_trait::hash(&self.#access, hasher);
|
||||
};
|
||||
|
||||
respan(call, field.span)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Recursively set span on all tokens, including interpolated ones.
|
||||
fn respan(tokens: TokenStream, span: Span) -> TokenStream {
|
||||
tokens
|
||||
.into_iter()
|
||||
.map(|tt| match tt {
|
||||
TokenTree::Group(g) => {
|
||||
let mut new = proc_macro2::Group::new(g.delimiter(), respan(g.stream(), span));
|
||||
new.set_span(span);
|
||||
TokenTree::Group(new)
|
||||
}
|
||||
mut other => {
|
||||
other.set_span(span);
|
||||
other
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
10
server/crates/arbiter-macros/src/lib.rs
Normal file
10
server/crates/arbiter-macros/src/lib.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
use syn::{DeriveInput, parse_macro_input};
|
||||
|
||||
mod hashable;
|
||||
mod utils;
|
||||
|
||||
#[proc_macro_derive(Hashable)]
|
||||
pub fn derive_hashable(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
|
||||
let input = parse_macro_input!(input as DeriveInput);
|
||||
hashable::derive(&input).into()
|
||||
}
|
||||
19
server/crates/arbiter-macros/src/utils.rs
Normal file
19
server/crates/arbiter-macros/src/utils.rs
Normal file
@@ -0,0 +1,19 @@
|
||||
pub struct ToPath(pub &'static str);
|
||||
|
||||
impl ToPath {
|
||||
pub fn to_path(&self) -> syn::Path {
|
||||
syn::parse_str(self.0).expect("Invalid path")
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! ensure_path {
|
||||
($path:path) => {{
|
||||
#[cfg(test)]
|
||||
#[expect(unused_imports)]
|
||||
use $path as _;
|
||||
ToPath(stringify!($path))
|
||||
}};
|
||||
}
|
||||
|
||||
pub const HASHABLE_TRAIT_PATH: ToPath = ensure_path!(::arbiter_crypto::hashing::Hashable);
|
||||
pub const HMAC_DIGEST_PATH: ToPath = ensure_path!(::arbiter_crypto::hashing::Digest);
|
||||
Reference in New Issue
Block a user