use quote::ToTokens; use serde_derive_internals::Ctxt; use std::collections::BTreeSet; use syn::parse::Parser; use syn::{Attribute, Data, DeriveInput, Field, GenericParam, Generics, Meta, NestedMeta, Variant}; // List of keywords that can appear in #[serde(...)]/#[schemars(...)] attributes, which we want serde to parse for us. static SERDE_KEYWORDS: &[&str] = &[ "rename", "rename_all", "deny_unknown_fields", "tag", "content", "untagged", "bound", "default", "remote", "alias", "skip", "skip_serializing", "skip_serializing_if", "skip_deserializing", "other", "flatten", // special cases - these keywords are not copied from schemars attrs to serde attrs "serialize_with", "deserialize_with", "with", ]; pub fn add_trait_bounds(generics: &mut Generics) { for param in &mut generics.params { if let GenericParam::Type(ref mut type_param) = *param { type_param.bounds.push(parse_quote!(schemars::JsonSchema)); } } } // If a struct/variant/field has any #[schemars] attributes, then rename them // to #[serde] so that serde_derive_internals will parse them for us. pub fn process_serde_attrs(input: &mut DeriveInput) -> Result<(), Vec> { let ctxt = Ctxt::new(); process_attrs(&ctxt, &mut input.attrs); match input.data { Data::Struct(ref mut s) => process_serde_field_attrs(&ctxt, s.fields.iter_mut()), Data::Enum(ref mut e) => process_serde_variant_attrs(&ctxt, e.variants.iter_mut()), Data::Union(ref mut u) => process_serde_field_attrs(&ctxt, u.fields.named.iter_mut()), }; ctxt.check() } fn process_serde_variant_attrs<'a>(ctxt: &Ctxt, variants: impl Iterator) { for v in variants { process_attrs(&ctxt, &mut v.attrs); process_serde_field_attrs(&ctxt, v.fields.iter_mut()); } } fn process_serde_field_attrs<'a>(ctxt: &Ctxt, fields: impl Iterator) { for f in fields { process_attrs(&ctxt, &mut f.attrs); } } fn process_attrs(ctxt: &Ctxt, attrs: &mut Vec) { let (serde_attrs, other_attrs): (Vec<_>, Vec<_>) = attrs.drain(..).partition(|at| at.path.is_ident("serde")); *attrs = other_attrs; let schemars_attrs: Vec<_> = attrs .iter() .filter(|at| at.path.is_ident("schemars")) .collect(); let (mut serde_meta, mut schemars_meta_names): (Vec<_>, BTreeSet<_>) = schemars_attrs .iter() .flat_map(|at| get_meta_items(&ctxt, at)) .flatten() .filter_map(|meta| { let keyword = get_meta_ident(&ctxt, &meta).ok()?; if keyword.ends_with("with") || !SERDE_KEYWORDS.contains(&keyword.as_ref()) { None } else { Some((meta, keyword)) } }) .unzip(); if schemars_meta_names.contains("skip") { schemars_meta_names.insert("skip_serializing".to_string()); schemars_meta_names.insert("skip_deserializing".to_string()); } for meta in serde_attrs .into_iter() .flat_map(|at| get_meta_items(&ctxt, &at)) .flatten() { if let Ok(i) = get_meta_ident(&ctxt, &meta) { if !schemars_meta_names.contains(&i) && SERDE_KEYWORDS.contains(&i.as_ref()) { serde_meta.push(meta); } } } if !serde_meta.is_empty() { let new_serde_attr = quote! { #[serde(#(#serde_meta),*)] }; let parser = Attribute::parse_outer; match parser.parse2(new_serde_attr) { Ok(ref mut parsed) => attrs.append(parsed), Err(e) => ctxt.error_spanned_by(to_tokens(attrs), e), } } } fn to_tokens(attrs: &[Attribute]) -> impl ToTokens { let mut tokens = proc_macro2::TokenStream::new(); for attr in attrs { attr.to_tokens(&mut tokens); } tokens } fn get_meta_items(ctxt: &Ctxt, attr: &Attribute) -> Result, ()> { match attr.parse_meta() { Ok(Meta::List(meta)) => Ok(meta.nested.into_iter().collect()), Ok(_) => { ctxt.error_spanned_by(attr, "expected #[schemars(...)] or #[serde(...)]"); Err(()) } Err(err) => { ctxt.error_spanned_by(attr, err); Err(()) } } } fn get_meta_ident(ctxt: &Ctxt, meta: &NestedMeta) -> Result { match meta { NestedMeta::Meta(m) => m.path().get_ident().map(|i| i.to_string()).ok_or(()), NestedMeta::Lit(lit) => { ctxt.error_spanned_by( meta, format!( "unexpected literal in attribute: {}", lit.into_token_stream() ), ); Err(()) } } } #[cfg(test)] mod tests { use super::*; use pretty_assertions::assert_eq; use syn::DeriveInput; #[test] fn test_process_serde_attrs() { let mut input: DeriveInput = parse_quote! { #[serde(rename(serialize = "ser_name"), rename_all = "camelCase")] #[serde(default, unknown_word)] #[schemars(rename = "overriden", another_unknown_word)] #[misc] struct MyStruct { /// blah blah blah #[serde(alias = "first")] field1: i32, #[serde(serialize_with = "se", deserialize_with = "de")] #[schemars(with = "with")] field2: i32, #[schemars(skip)] #[serde(skip_serializing)] field3: i32, } }; let expected: DeriveInput = parse_quote! { #[schemars(rename = "overriden", another_unknown_word)] #[misc] #[serde(rename = "overriden", rename_all = "camelCase", default)] struct MyStruct { #[doc = r" blah blah blah"] #[serde(alias = "first")] field1: i32, #[schemars(with = "with")] #[serde(serialize_with = "se", deserialize_with = "de")] field2: i32, #[schemars(skip)] #[serde(skip)] field3: i32, } }; if let Err(e) = process_serde_attrs(&mut input) { panic!("process_serde_attrs returned error: {}", e[0]) }; assert_eq!(input, expected); } }