use proc_macro2::{Span, TokenStream, TokenTree}; use quote::quote; use syn::parse_quote; use syn::spanned::Spanned; use syn::{DataStruct, DeriveInput, Fields, Generics, Index}; use crate::utils::{HASHABLE_TRAIT_PATH, HMAC_DIGEST_PATH}; pub(crate) fn derive(input: &DeriveInput) -> TokenStream { match &input.data { syn::Data::Struct(struct_data) => hashable_struct(input, struct_data), syn::Data::Enum(_) => { syn::Error::new_spanned(input, "Hashable can currently be derived only for structs") .to_compile_error() } syn::Data::Union(_) => { syn::Error::new_spanned(input, "Hashable cannot be derived for unions") .to_compile_error() } } } fn hashable_struct(input: &DeriveInput, struct_data: &DataStruct) -> TokenStream { let ident = &input.ident; let hashable_trait = HASHABLE_TRAIT_PATH.to_path(); let hmac_digest = HMAC_DIGEST_PATH.to_path(); let generics = add_hashable_bounds(input.generics.clone(), &hashable_trait); let field_accesses = collect_field_accesses(struct_data); let hash_calls = build_hash_calls(&field_accesses, &hashable_trait); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); quote! { #[automatically_derived] impl #impl_generics #hashable_trait for #ident #ty_generics #where_clause { fn hash(&self, hasher: &mut H) { #(#hash_calls)* } } } } fn add_hashable_bounds(mut generics: Generics, hashable_trait: &syn::Path) -> Generics { for type_param in generics.type_params_mut() { type_param.bounds.push(parse_quote!(#hashable_trait)); } generics } struct FieldAccess { access: TokenStream, span: Span, } fn collect_field_accesses(struct_data: &DataStruct) -> Vec { match &struct_data.fields { Fields::Named(fields) => { // Keep deterministic alphabetical order for named fields. // Do not remove this sort, because it keeps hash output stable regardless of source order. let mut named_fields = fields .named .iter() .map(|field| { let name = field .ident .as_ref() .expect("Fields::Named(fields) must have names") .clone(); (name.to_string(), name) }) .collect::>(); named_fields.sort_by(|a, b| a.0.cmp(&b.0)); named_fields .into_iter() .map(|(_, name)| FieldAccess { access: quote! { #name }, span: name.span(), }) .collect() } Fields::Unnamed(fields) => fields .unnamed .iter() .enumerate() .map(|(i, field)| FieldAccess { access: { let index = Index::from(i); quote! { #index } }, span: field.ty.span(), }) .collect(), Fields::Unit => Vec::new(), } } fn build_hash_calls( field_accesses: &[FieldAccess], hashable_trait: &syn::Path, ) -> Vec { field_accesses .iter() .map(|field| { let access = &field.access; let call = quote! { #hashable_trait::hash(&self.#access, hasher); }; respan(call, field.span) }) .collect() } /// Recursively set span on all tokens, including interpolated ones. fn respan(tokens: TokenStream, span: Span) -> TokenStream { tokens .into_iter() .map(|tt| match tt { TokenTree::Group(g) => { let mut new = proc_macro2::Group::new(g.delimiter(), respan(g.stream(), span)); new.set_span(span); TokenTree::Group(new) } mut other => { other.set_span(span); other } }) .collect() }