diff --git a/schemars/tests/expected/remote_derive.json b/schemars/tests/expected/remote_derive.json
index d5a040f..65de9ad 100644
--- a/schemars/tests/expected/remote_derive.json
+++ b/schemars/tests/expected/remote_derive.json
@@ -11,10 +11,7 @@
       "type": "string"
     },
     "system_cpu_time": {
-      "default": {
-        "nanos": 0,
-        "secs": 0
-      },
+      "default": "0.000000000s",
       "allOf": [
         {
           "$ref": "#/definitions/DurationDef"
diff --git a/schemars/tests/remote_derive.rs b/schemars/tests/remote_derive.rs
index 74a2c8c..c2f352e 100644
--- a/schemars/tests/remote_derive.rs
+++ b/schemars/tests/remote_derive.rs
@@ -34,7 +34,6 @@ struct Process {
     wall_time: Duration,
     #[serde(default, with = "DurationDef")]
     user_cpu_time: Duration,
-    // FIXME this should serialize the default as "0.000000000s"
     #[serde(default, serialize_with = "custom_serialize")]
     #[schemars(with = "DurationDef")]
     system_cpu_time: Duration,
diff --git a/schemars_derive/src/lib.rs b/schemars_derive/src/lib.rs
index 60fbadd..84409d4 100644
--- a/schemars_derive/src/lib.rs
+++ b/schemars_derive/src/lib.rs
@@ -9,11 +9,12 @@ mod metadata;
 mod preprocess;
 
 use metadata::*;
-use proc_macro2::TokenStream;
+use proc_macro2::{Group, Span, TokenStream, TokenTree};
 use quote::ToTokens;
 use serde_derive_internals::ast::{Container, Data, Field, Style, Variant};
 use serde_derive_internals::attr::{self, Default as SerdeDefault, TagType};
 use serde_derive_internals::{Ctxt, Derive};
+use syn::parse::{self, Parse};
 use syn::spanned::Spanned;
 
 #[proc_macro_derive(JsonSchema, attributes(schemars, serde, doc))]
@@ -382,35 +383,70 @@ fn schema_for_struct(fields: &[Field], cattrs: &attr::Container) -> TokenStream
 }
 
 fn get_json_schema_type(field: &Field) -> Box<dyn ToTokens> {
-    // TODO it would probably be simpler to parse attributes manually here, instead of
-    // using the serde-parsed attributes
-    let de_with_segments = without_last_element(field.attrs.deserialize_with(), "deserialize");
-    let se_with_segments = without_last_element(field.attrs.serialize_with(), "serialize");
-    if de_with_segments == se_with_segments {
-        if let Some(expr_path) = de_with_segments {
-            return Box::new(expr_path);
-        }
+    // TODO support [schemars(schema_with= "...")] or equivalent
+    match field
+        .original
+        .attrs
+        .iter()
+        .filter(|at| match at.path.get_ident() {
+            // FIXME this is relying on order of attributes (schemars before serde) from preprocess.rs
+            Some(i) => i == "schemars" || i == "serde",
+            None => false,
+        })
+        .filter_map(get_with_from_attr)
+        .next()
+    {
+        Some(with) => match parse_lit_str::<syn::ExprPath>(&with) {
+            Ok(expr_path) => Box::new(expr_path),
+            Err(e) => Box::new(compile_error(vec![e])),
+        },
+        None => Box::new(field.ty.clone()),
     }
-    Box::new(field.ty.clone())
 }
 
-fn without_last_element(path: Option<&syn::ExprPath>, last: &str) -> Option<syn::ExprPath> {
-    match path {
-        Some(expr_path)
-            if expr_path
-                .path
-                .segments
-                .last()
-                .map(|p| p.ident == last)
-                .unwrap_or(false) =>
-        {
-            let mut expr_path = expr_path.clone();
-            expr_path.path.segments.pop();
-            if let Some(segment) = expr_path.path.segments.pop() {
-                expr_path.path.segments.push(segment.into_value())
-            }
-            Some(expr_path)
+fn get_with_from_attr(attr: &syn::Attribute) -> Option<syn::LitStr> {
+    use syn::*;
+    let nested_metas = match attr.parse_meta() {
+        Ok(Meta::List(meta)) => meta.nested,
+        _ => return None,
+    };
+    for nm in nested_metas {
+        match nm {
+            NestedMeta::Meta(Meta::NameValue(MetaNameValue {
+                path,
+                lit: Lit::Str(with),
+                ..
+            })) if path.is_ident("with") => return Some(with),
+            _ => {}
         }
-        _ => None,
     }
+    None
+}
+
+fn parse_lit_str<T>(s: &syn::LitStr) -> parse::Result<T>
+where
+    T: Parse,
+{
+    let tokens = spanned_tokens(s)?;
+    syn::parse2(tokens)
+}
+
+fn spanned_tokens(s: &syn::LitStr) -> parse::Result<TokenStream> {
+    let stream = syn::parse_str(&s.value())?;
+    Ok(respan_token_stream(stream, s.span()))
+}
+
+fn respan_token_stream(stream: TokenStream, span: Span) -> TokenStream {
+    stream
+        .into_iter()
+        .map(|token| respan_token_tree(token, span))
+        .collect()
+}
+
+fn respan_token_tree(mut token: TokenTree, span: Span) -> TokenTree {
+    if let TokenTree::Group(g) = &mut token {
+        *g = Group::new(g.delimiter(), respan_token_stream(g.stream(), span));
+    }
+    token.set_span(span);
+    token
 }
diff --git a/schemars_derive/src/preprocess.rs b/schemars_derive/src/preprocess.rs
index 4b59c5a..ed5a61a 100644
--- a/schemars_derive/src/preprocess.rs
+++ b/schemars_derive/src/preprocess.rs
@@ -2,9 +2,30 @@ use quote::ToTokens;
 use serde_derive_internals::Ctxt;
 use std::collections::BTreeSet;
 use syn::parse::Parser;
-use syn::{
-    Attribute, Data, DeriveInput, Field, GenericParam, Generics, Ident, Meta, NestedMeta, Variant,
-};
+use syn::{Attribute, Data, DeriveInput, Field, GenericParam, Generics, Meta, NestedMeta, Variant};
+
+// List of keywords that can appear in #[serde(...)]/#[schemars(...)] attributes, which we want serde to parse for us.
+static SERDE_KEYWORDS: &[&str] = &[
+    "rename",
+    "rename_all",
+    "deny_unknown_fields",
+    "tag",
+    "content",
+    "untagged",
+    "bound",
+    "default",
+    "remote",
+    "alias",
+    "skip",
+    "skip_serializing",
+    "skip_deserializing",
+    "other",
+    "flatten",
+    // special cases - these keywords are not copied from schemars attrs to serde attrs
+    "serialize_with",
+    "deserialize_with",
+    "with",
+];
 
 pub fn add_trait_bounds(generics: &mut Generics) {
     for param in &mut generics.params {
@@ -42,52 +63,48 @@ fn process_serde_field_attrs<'a>(ctxt: &Ctxt, fields: impl Iterator
 }
 
 fn process_attrs(ctxt: &Ctxt, attrs: &mut Vec<Attribute>) {
-    let mut schemars_attrs = Vec::<Attribute>::new();
-    let mut serde_attrs = Vec::<Attribute>::new();
-    let mut misc_attrs = Vec::<Attribute>::new();
+    let (serde_attrs, other_attrs): (Vec<_>, Vec<_>) =
+        attrs.drain(..).partition(|at| at.path.is_ident("serde"));
 
-    for attr in attrs.drain(..) {
-        if attr.path.is_ident("schemars") {
-            schemars_attrs.push(attr)
-        } else if attr.path.is_ident("serde") {
-            serde_attrs.push(attr)
-        } else {
-            misc_attrs.push(attr)
+    *attrs = other_attrs;
+
+    let schemars_attrs: Vec<_> = attrs
+        .iter()
+        .filter(|at| at.path.is_ident("schemars"))
+        .collect();
+
+    let (mut serde_meta, mut schemars_meta_names): (Vec<_>, BTreeSet<_>) = schemars_attrs
+        .iter()
+        .flat_map(|at| get_meta_items(&ctxt, at))
+        .flatten()
+        .filter_map(|meta| {
+            let keyword = get_meta_ident(&ctxt, &meta).ok()?;
+            if keyword.ends_with("with") || !SERDE_KEYWORDS.contains(&keyword.as_ref()) {
+                None
+            } else {
+                Some((meta, keyword))
+            }
+        })
+        .unzip();
+
+    if schemars_meta_names.contains("skip") {
+        schemars_meta_names.insert("skip_serializing".to_string());
+        schemars_meta_names.insert("skip_deserializing".to_string());
+    }
+
+    for meta in serde_attrs
+        .into_iter()
+        .flat_map(|at| get_meta_items(&ctxt, &at))
+        .flatten()
+    {
+        if let Ok(i) = get_meta_ident(&ctxt, &meta) {
+            if !schemars_meta_names.contains(&i) && SERDE_KEYWORDS.contains(&i.as_ref()) {
+                serde_meta.push(meta);
+            }
         }
     }
 
-    for attr in schemars_attrs.iter_mut() {
-        let schemars_ident = attr.path.segments.pop().unwrap().into_value().ident;
-        attr.path
-            .segments
-            .push(Ident::new("serde", schemars_ident.span()).into());
-    }
-
-    let mut schemars_meta_names: BTreeSet<String> = schemars_attrs
-        .iter()
-        .flat_map(|attr| get_meta_items(&ctxt, attr))
-        .flatten()
-        .flat_map(|m| get_meta_ident(&ctxt, &m))
-        .collect();
-    if schemars_meta_names.contains("with") {
-        schemars_meta_names.insert("serialize_with".to_string());
-        schemars_meta_names.insert("deserialize_with".to_string());
-    }
-
-    let mut serde_meta = serde_attrs
-        .iter()
-        .flat_map(|attr| get_meta_items(&ctxt, attr))
-        .flatten()
-        .filter(|m| {
-            get_meta_ident(&ctxt, m)
-                .map(|i| !schemars_meta_names.contains(&i))
-                .unwrap_or(false)
-        })
-        .peekable();
-
-    *attrs = schemars_attrs;
-
-    if serde_meta.peek().is_some() {
+    if !serde_meta.is_empty() {
         let new_serde_attr = quote! {
             #[serde(#(#serde_meta),*)]
         };
@@ -98,8 +115,6 @@ fn process_attrs(ctxt: &Ctxt, attrs: &mut Vec<Attribute>) {
             Err(e) => ctxt.error_spanned_by(to_tokens(attrs), e),
         }
     }
-
-    attrs.extend(misc_attrs)
 }
 
 fn to_tokens(attrs: &[Attribute]) -> impl ToTokens {
@@ -149,33 +164,35 @@ mod tests {
     #[test]
     fn test_process_serde_attrs() {
         let mut input: DeriveInput = parse_quote! {
-            #[serde(container, container2 = "blah")]
-            #[serde(container3(foo, bar))]
-            #[schemars(container2 = "overridden", container4)]
+            #[serde(rename(serialize = "ser_name"), rename_all = "camelCase")]
+            #[serde(default, unknown_word)]
+            #[schemars(rename = "overriden", another_unknown_word)]
             #[misc]
             struct MyStruct {
                 /// blah blah blah
-                #[serde(field, field2)]
+                #[serde(alias = "first")]
                 field1: i32,
-                #[serde(field, field2, serialize_with = "se", deserialize_with = "de")]
-                #[schemars(field = "overridden", with = "with")]
+                #[serde(serialize_with = "se", deserialize_with = "de")]
+                #[schemars(with = "with")]
                 field2: i32,
-                #[schemars(field)]
+                #[schemars(skip)]
+                #[serde(skip_serializing)]
                 field3: i32,
             }
         };
        let expected: DeriveInput = parse_quote! {
-            #[serde(container2 = "overridden", container4)]
-            #[serde(container, container3(foo, bar))]
+            #[schemars(rename = "overriden", another_unknown_word)]
             #[misc]
+            #[serde(rename = "overriden", rename_all = "camelCase", default)]
             struct MyStruct {
-                #[serde(field, field2)]
                 #[doc = r" blah blah blah"]
+                #[serde(alias = "first")]
                 field1: i32,
-                #[serde(field = "overridden", with = "with")]
-                #[serde(field2)]
+                #[schemars(with = "with")]
+                #[serde(serialize_with = "se", deserialize_with = "de")]
                 field2: i32,
-                #[serde(field)]
+                #[schemars(skip)]
+                #[serde(skip)]
                 field3: i32,
             }
         };
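Standalone sketch, not part of the diff above: the parse_lit_str/spanned_tokens helpers added in schemars_derive/src/lib.rs follow the usual pattern of parsing the contents of a `with = "..."` string literal as tokens and respanning every token onto the literal, so a malformed value reports an error at the attribute rather than at the call site. The snippet below only illustrates that pattern under stated assumptions (syn 1.x with default features and proc-macro2 as ordinary dependencies); get_with, respan, and main are illustrative names, not code from this change set.

    // Illustrative only; assumes syn = "1" (default features) and proc-macro2.
    use proc_macro2::{Group, Span, TokenStream, TokenTree};
    use syn::parse::Parse;

    // Find `with = "..."` inside a #[serde(...)]/#[schemars(...)] attribute.
    fn get_with(attr: &syn::Attribute) -> Option<syn::LitStr> {
        let nested = match attr.parse_meta() {
            Ok(syn::Meta::List(list)) => list.nested,
            _ => return None,
        };
        nested.into_iter().find_map(|nm| match nm {
            syn::NestedMeta::Meta(syn::Meta::NameValue(syn::MetaNameValue {
                path,
                lit: syn::Lit::Str(s),
                ..
            })) if path.is_ident("with") => Some(s),
            _ => None,
        })
    }

    // Parse the literal's contents as any syn type, pointing spans at the literal.
    fn parse_lit_str<T: Parse>(s: &syn::LitStr) -> syn::parse::Result<T> {
        let stream: TokenStream = syn::parse_str(&s.value())?;
        syn::parse2(respan(stream, s.span()))
    }

    // Recursively replace every token's span, including tokens inside groups.
    fn respan(stream: TokenStream, span: Span) -> TokenStream {
        stream
            .into_iter()
            .map(|mut tt| {
                if let TokenTree::Group(g) = &mut tt {
                    *g = Group::new(g.delimiter(), respan(g.stream(), span));
                }
                tt.set_span(span);
                tt
            })
            .collect()
    }

    fn main() {
        let attr: syn::Attribute = syn::parse_quote!(#[schemars(with = "DurationDef")]);
        let lit = get_with(&attr).expect("attribute has a `with` value");
        let path: syn::ExprPath = parse_lit_str(&lit).expect("value is a valid path");
        assert!(path.path.is_ident("DurationDef"));
    }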