diff --git a/schemars_derive/src/doc_attrs.rs b/schemars_derive/src/attr/doc.rs similarity index 72% rename from schemars_derive/src/doc_attrs.rs rename to schemars_derive/src/attr/doc.rs index aea190c..26a5090 100644 --- a/schemars_derive/src/doc_attrs.rs +++ b/schemars_derive/src/attr/doc.rs @@ -1,13 +1,13 @@ use syn::{Attribute, Lit::Str, Meta::NameValue, MetaNameValue}; -pub fn get_title_and_desc_from_docs(attrs: &[Attribute]) -> (Option, Option) { - let docs = match get_docs(attrs) { +pub fn get_title_and_desc_from_doc(attrs: &[Attribute]) -> (Option, Option) { + let doc = match get_doc(attrs) { None => return (None, None), - Some(docs) => docs, + Some(doc) => doc, }; - if docs.starts_with('#') { - let mut split = docs.splitn(2, '\n'); + if doc.starts_with('#') { + let mut split = doc.splitn(2, '\n'); let title = split .next() .unwrap() @@ -17,12 +17,12 @@ pub fn get_title_and_desc_from_docs(attrs: &[Attribute]) -> (Option, Opt let maybe_desc = split.next().and_then(merge_description_lines); (none_if_empty(title), maybe_desc) } else { - (None, merge_description_lines(&docs)) + (None, merge_description_lines(&doc)) } } -fn merge_description_lines(docs: &str) -> Option { - let desc = docs +fn merge_description_lines(doc: &str) -> Option { + let desc = doc .trim() .split("\n\n") .filter_map(|line| none_if_empty(line.trim().replace('\n', " "))) @@ -31,8 +31,8 @@ fn merge_description_lines(docs: &str) -> Option { none_if_empty(desc) } -fn get_docs(attrs: &[Attribute]) -> Option { - let docs = attrs +fn get_doc(attrs: &[Attribute]) -> Option { + let doc = attrs .iter() .filter_map(|attr| { if !attr.path.is_ident("doc") { @@ -53,7 +53,7 @@ fn get_docs(attrs: &[Attribute]) -> Option { .skip_while(|s| *s == "") .collect::>() .join("\n"); - none_if_empty(docs) + none_if_empty(doc) } fn none_if_empty(s: String) -> Option { diff --git a/schemars_derive/src/attr/mod.rs b/schemars_derive/src/attr/mod.rs new file mode 100644 index 0000000..9db24f2 --- /dev/null +++ b/schemars_derive/src/attr/mod.rs @@ -0,0 +1,71 @@ +mod doc; +mod schemars_to_serde; + +pub use doc::get_title_and_desc_from_doc; +pub use schemars_to_serde::process_serde_attrs; + +use proc_macro2::{Group, Span, TokenStream, TokenTree}; +use syn::parse::{self, Parse}; + +pub fn get_with_from_attrs(field: &syn::Field) -> Option> { + field + .attrs + .iter() + .filter(|at| match at.path.get_ident() { + // FIXME this is relying on order of attributes (schemars before serde) from preprocess.rs + Some(i) => i == "schemars" || i == "serde", + None => false, + }) + .filter_map(get_with_from_attr) + .next() + .map(|lit| parse_lit_str(&lit)) +} + +fn get_with_from_attr(attr: &syn::Attribute) -> Option { + use syn::*; + let nested_metas = match attr.parse_meta() { + Ok(Meta::List(meta)) => meta.nested, + _ => return None, + }; + for nm in nested_metas { + if let NestedMeta::Meta(Meta::NameValue(MetaNameValue { + path, + lit: Lit::Str(with), + .. + })) = nm + { + if path.is_ident("with") { + return Some(with); + } + } + } + None +} + +fn parse_lit_str(s: &syn::LitStr) -> parse::Result +where + T: Parse, +{ + let tokens = spanned_tokens(s)?; + syn::parse2(tokens) +} + +fn spanned_tokens(s: &syn::LitStr) -> parse::Result { + let stream = syn::parse_str(&s.value())?; + Ok(respan_token_stream(stream, s.span())) +} + +fn respan_token_stream(stream: TokenStream, span: Span) -> TokenStream { + stream + .into_iter() + .map(|token| respan_token_tree(token, span)) + .collect() +} + +fn respan_token_tree(mut token: TokenTree, span: Span) -> TokenTree { + if let TokenTree::Group(g) = &mut token { + *g = Group::new(g.delimiter(), respan_token_stream(g.stream(), span)); + } + token.set_span(span); + token +} diff --git a/schemars_derive/src/preprocess.rs b/schemars_derive/src/attr/schemars_to_serde.rs similarity index 93% rename from schemars_derive/src/preprocess.rs rename to schemars_derive/src/attr/schemars_to_serde.rs index 34a71d5..b27f6e9 100644 --- a/schemars_derive/src/preprocess.rs +++ b/schemars_derive/src/attr/schemars_to_serde.rs @@ -2,7 +2,7 @@ use quote::ToTokens; use serde_derive_internals::Ctxt; use std::collections::HashSet; use syn::parse::Parser; -use syn::{Attribute, Data, DeriveInput, Field, GenericParam, Generics, Meta, NestedMeta, Variant}; +use syn::{Attribute, Data, Field, Meta, NestedMeta, Variant}; // List of keywords that can appear in #[serde(...)]/#[schemars(...)] attributes which we want serde_derive_internals to parse for us. static SERDE_KEYWORDS: &[&str] = &[ @@ -28,17 +28,9 @@ static SERDE_KEYWORDS: &[&str] = &[ "with", ]; -pub fn add_trait_bounds(generics: &mut Generics) { - for param in &mut generics.params { - if let GenericParam::Type(ref mut type_param) = *param { - type_param.bounds.push(parse_quote!(schemars::JsonSchema)); - } - } -} - // If a struct/variant/field has any #[schemars] attributes, then create copies of them // as #[serde] attributes so that serde_derive_internals will parse them for us. -pub fn process_serde_attrs(input: &mut DeriveInput) -> Result<(), Vec> { +pub fn process_serde_attrs(input: &mut syn::DeriveInput) -> Result<(), Vec> { let ctxt = Ctxt::new(); process_attrs(&ctxt, &mut input.attrs); match input.data { diff --git a/schemars_derive/src/lib.rs b/schemars_derive/src/lib.rs index 446703f..ff3e7c9 100644 --- a/schemars_derive/src/lib.rs +++ b/schemars_derive/src/lib.rs @@ -4,32 +4,30 @@ extern crate quote; extern crate syn; extern crate proc_macro; -mod doc_attrs; +mod attr; mod metadata; -mod preprocess; use metadata::*; -use proc_macro2::{Group, Span, TokenStream, TokenTree}; +use proc_macro2::TokenStream; use quote::ToTokens; use serde_derive_internals::ast::{Container, Data, Field, Style, Variant}; -use serde_derive_internals::attr::{self, Default as SerdeDefault, TagType}; +use serde_derive_internals::attr::{self as serde_attr, Default as SerdeDefault, TagType}; use serde_derive_internals::{Ctxt, Derive}; -use syn::parse::{self, Parse}; use syn::spanned::Spanned; #[proc_macro_derive(JsonSchema, attributes(schemars, serde, doc))] pub fn derive_json_schema(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let mut input = parse_macro_input!(input as syn::DeriveInput); - preprocess::add_trait_bounds(&mut input.generics); - if let Err(e) = preprocess::process_serde_attrs(&mut input) { - return compile_error(e).into(); + add_trait_bounds(&mut input.generics); + if let Err(e) = attr::process_serde_attrs(&mut input) { + return compile_error(&e).into(); } let ctxt = Ctxt::new(); let cont = Container::from_ast(&ctxt, &input, Derive::Deserialize); if let Err(e) = ctxt.check() { - return compile_error(e).into(); + return compile_error(&e).into(); } let cont = cont.expect("from_ast set no errors on Ctxt, so should have returned Some"); @@ -84,6 +82,14 @@ pub fn derive_json_schema(input: proc_macro::TokenStream) -> proc_macro::TokenSt proc_macro::TokenStream::from(impl_block) } +fn add_trait_bounds(generics: &mut syn::Generics) { + for param in &mut generics.params { + if let syn::GenericParam::Type(ref mut type_param) = *param { + type_param.bounds.push(parse_quote!(schemars::JsonSchema)); + } + } +} + fn wrap_schema_fields(schema_contents: TokenStream) -> TokenStream { quote! { schemars::schema::Schema::Object( @@ -94,8 +100,8 @@ fn wrap_schema_fields(schema_contents: TokenStream) -> TokenStream { } } -fn compile_error(errors: Vec) -> TokenStream { - let compile_errors = errors.iter().map(syn::Error::to_compile_error); +fn compile_error<'a>(errors: impl IntoIterator) -> TokenStream { + let compile_errors = errors.into_iter().map(syn::Error::to_compile_error); quote! { #(#compile_errors)* } @@ -108,7 +114,7 @@ fn is_unit_variant(v: &Variant) -> bool { } } -fn schema_for_enum(variants: &[Variant], cattrs: &attr::Container) -> TokenStream { +fn schema_for_enum(variants: &[Variant], cattrs: &serde_attr::Container) -> TokenStream { let variants = variants.iter().filter(|v| !v.attrs.skip_deserializing()); match cattrs.tag() { TagType::External => schema_for_external_tagged_enum(variants, cattrs), @@ -120,7 +126,7 @@ fn schema_for_enum(variants: &[Variant], cattrs: &attr::Container) -> TokenStrea fn schema_for_external_tagged_enum<'a>( variants: impl Iterator>, - cattrs: &attr::Container, + cattrs: &serde_attr::Container, ) -> TokenStream { let (unit_variants, complex_variants): (Vec<_>, Vec<_>) = variants.partition(|v| is_unit_variant(v)); @@ -174,7 +180,7 @@ fn schema_for_external_tagged_enum<'a>( fn schema_for_internal_tagged_enum<'a>( variants: impl Iterator>, - cattrs: &attr::Container, + cattrs: &serde_attr::Container, tag_name: &str, ) -> TokenStream { let schemas = variants.map(|variant| { @@ -229,7 +235,7 @@ fn schema_for_internal_tagged_enum<'a>( fn schema_for_untagged_enum<'a>( variants: impl Iterator>, - cattrs: &attr::Container, + cattrs: &serde_attr::Container, ) -> TokenStream { let schemas = variants.map(|variant| { let schema_expr = schema_for_untagged_enum_variant(variant, cattrs); @@ -244,7 +250,10 @@ fn schema_for_untagged_enum<'a>( }) } -fn schema_for_untagged_enum_variant(variant: &Variant, cattrs: &attr::Container) -> TokenStream { +fn schema_for_untagged_enum_variant( + variant: &Variant, + cattrs: &serde_attr::Container, +) -> TokenStream { match variant.style { Style::Unit => schema_for_unit_struct(), Style::Newtype => schema_for_newtype_struct(&variant.fields[0]), @@ -276,7 +285,7 @@ fn schema_for_tuple_struct(fields: &[Field]) -> TokenStream { } } -fn schema_for_struct(fields: &[Field], cattrs: &attr::Container) -> TokenStream { +fn schema_for_struct(fields: &[Field], cattrs: &serde_attr::Container) -> TokenStream { let (flat, nested): (Vec<_>, Vec<_>) = fields .iter() .filter(|f| !f.attrs.skip_deserializing() || !f.attrs.skip_serializing()) @@ -284,49 +293,14 @@ fn schema_for_struct(fields: &[Field], cattrs: &attr::Container) -> TokenStream let set_container_default = match cattrs.default() { SerdeDefault::None => None, - SerdeDefault::Default => Some(quote!(let cdefault = Self::default();)), - SerdeDefault::Path(path) => Some(quote!(let cdefault = #path();)), + SerdeDefault::Default => Some(quote!(let container_default = Self::default();)), + SerdeDefault::Path(path) => Some(quote!(let container_default = #path();)), }; let mut required = Vec::new(); let recurse = nested.iter().map(|field| { let name = field.attrs.name().deserialize_name(); - let ty = field.ty; - - let default = match field.attrs.default() { - _ if field.attrs.skip_serializing() => None, - SerdeDefault::None if set_container_default.is_none() => None, - SerdeDefault::None => { - let field_ident = field - .original - .ident - .as_ref() - .expect("This is not a tuple struct, so field should be named"); - Some(quote!(cdefault.#field_ident)) - } - SerdeDefault::Default => Some(quote!(<#ty>::default())), - SerdeDefault::Path(path) => Some(quote!(#path())), - } - .map(|d| match field.attrs.serialize_with() { - Some(ser_with) => quote! { - { - struct _SchemarsDefaultSerialize(T); - - impl serde::Serialize for _SchemarsDefaultSerialize<#ty> - { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer - { - #ser_with(&self.0, serializer) - } - } - - _SchemarsDefaultSerialize(#d) - } - }, - None => d, - }); + let default = field_default_expr(field, set_container_default.is_some()); if default.is_none() { required.push(name.clone()); @@ -384,73 +358,50 @@ fn schema_for_struct(fields: &[Field], cattrs: &attr::Container) -> TokenStream } } -fn get_json_schema_type(field: &Field) -> Box { - // TODO support [schemars(schema_with= "...")] or equivalent - match field - .original - .attrs - .iter() - .filter(|at| match at.path.get_ident() { - // FIXME this is relying on order of attributes (schemars before serde) from preprocess.rs - Some(i) => i == "schemars" || i == "serde", - None => false, - }) - .filter_map(get_with_from_attr) - .next() - { - Some(with) => match parse_lit_str::(&with) { - Ok(expr_path) => Box::new(expr_path), - Err(e) => Box::new(compile_error(vec![e])), - }, - None => Box::new(field.ty.clone()), +fn field_default_expr(field: &Field, container_has_default: bool) -> Option { + let field_default = field.attrs.default(); + if field.attrs.skip_serializing() || (field_default.is_none() && !container_has_default) { + return None; } -} -fn get_with_from_attr(attr: &syn::Attribute) -> Option { - use syn::*; - let nested_metas = match attr.parse_meta() { - Ok(Meta::List(meta)) => meta.nested, - _ => return None, + let ty = field.ty; + let default_expr = match field_default { + SerdeDefault::None => { + let member = &field.member; + quote!(container_default.#member) + } + SerdeDefault::Default => quote!(<#ty>::default()), + SerdeDefault::Path(path) => quote!(#path()), }; - for nm in nested_metas { - if let NestedMeta::Meta(Meta::NameValue(MetaNameValue { - path, - lit: Lit::Str(with), - .. - })) = nm - { - if path.is_ident("with") { - return Some(with); + + Some(if let Some(ser_with) = field.attrs.serialize_with() { + quote! { + { + struct _SchemarsDefaultSerialize(T); + + impl serde::Serialize for _SchemarsDefaultSerialize<#ty> + { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer + { + #ser_with(&self.0, serializer) + } + } + + _SchemarsDefaultSerialize(#default_expr) } } + } else { + default_expr + }) +} + +fn get_json_schema_type(field: &Field) -> Box { + // TODO support [schemars(schema_with= "...")] or equivalent + match attr::get_with_from_attrs(&field.original) { + None => Box::new(field.ty.clone()), + Some(Ok(expr_path)) => Box::new(expr_path), + Some(Err(e)) => Box::new(compile_error(&[e])), } - None -} - -fn parse_lit_str(s: &syn::LitStr) -> parse::Result -where - T: Parse, -{ - let tokens = spanned_tokens(s)?; - syn::parse2(tokens) -} - -fn spanned_tokens(s: &syn::LitStr) -> parse::Result { - let stream = syn::parse_str(&s.value())?; - Ok(respan_token_stream(stream, s.span())) -} - -fn respan_token_stream(stream: TokenStream, span: Span) -> TokenStream { - stream - .into_iter() - .map(|token| respan_token_tree(token, span)) - .collect() -} - -fn respan_token_tree(mut token: TokenTree, span: Span) -> TokenTree { - if let TokenTree::Group(g) = &mut token { - *g = Group::new(g.delimiter(), respan_token_stream(g.stream(), span)); - } - token.set_span(span); - token } diff --git a/schemars_derive/src/metadata.rs b/schemars_derive/src/metadata.rs index f2b2b8e..f5bc9e9 100644 --- a/schemars_derive/src/metadata.rs +++ b/schemars_derive/src/metadata.rs @@ -1,4 +1,4 @@ -use crate::doc_attrs; +use crate::attr; use proc_macro2::TokenStream; use syn::{Attribute, ExprPath}; @@ -21,7 +21,7 @@ pub fn set_metadata_on_schema_from_docs( } pub fn get_metadata_from_docs(attrs: &[Attribute]) -> SchemaMetadata { - let (title, description) = doc_attrs::get_title_and_desc_from_docs(attrs); + let (title, description) = attr::get_title_and_desc_from_doc(attrs); SchemaMetadata { title, description,