Refactoring of schemars_derive

This commit is contained in:
Graham Esau 2019-12-09 20:57:38 +00:00
parent dca9e2d920
commit 3fb625e08c
5 changed files with 155 additions and 141 deletions

View file

@ -1,13 +1,13 @@
use syn::{Attribute, Lit::Str, Meta::NameValue, MetaNameValue}; use syn::{Attribute, Lit::Str, Meta::NameValue, MetaNameValue};
pub fn get_title_and_desc_from_docs(attrs: &[Attribute]) -> (Option<String>, Option<String>) { pub fn get_title_and_desc_from_doc(attrs: &[Attribute]) -> (Option<String>, Option<String>) {
let docs = match get_docs(attrs) { let doc = match get_doc(attrs) {
None => return (None, None), None => return (None, None),
Some(docs) => docs, Some(doc) => doc,
}; };
if docs.starts_with('#') { if doc.starts_with('#') {
let mut split = docs.splitn(2, '\n'); let mut split = doc.splitn(2, '\n');
let title = split let title = split
.next() .next()
.unwrap() .unwrap()
@ -17,12 +17,12 @@ pub fn get_title_and_desc_from_docs(attrs: &[Attribute]) -> (Option<String>, Opt
let maybe_desc = split.next().and_then(merge_description_lines); let maybe_desc = split.next().and_then(merge_description_lines);
(none_if_empty(title), maybe_desc) (none_if_empty(title), maybe_desc)
} else { } else {
(None, merge_description_lines(&docs)) (None, merge_description_lines(&doc))
} }
} }
fn merge_description_lines(docs: &str) -> Option<String> { fn merge_description_lines(doc: &str) -> Option<String> {
let desc = docs let desc = doc
.trim() .trim()
.split("\n\n") .split("\n\n")
.filter_map(|line| none_if_empty(line.trim().replace('\n', " "))) .filter_map(|line| none_if_empty(line.trim().replace('\n', " ")))
@ -31,8 +31,8 @@ fn merge_description_lines(docs: &str) -> Option<String> {
none_if_empty(desc) none_if_empty(desc)
} }
fn get_docs(attrs: &[Attribute]) -> Option<String> { fn get_doc(attrs: &[Attribute]) -> Option<String> {
let docs = attrs let doc = attrs
.iter() .iter()
.filter_map(|attr| { .filter_map(|attr| {
if !attr.path.is_ident("doc") { if !attr.path.is_ident("doc") {
@ -53,7 +53,7 @@ fn get_docs(attrs: &[Attribute]) -> Option<String> {
.skip_while(|s| *s == "") .skip_while(|s| *s == "")
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join("\n"); .join("\n");
none_if_empty(docs) none_if_empty(doc)
} }
fn none_if_empty(s: String) -> Option<String> { fn none_if_empty(s: String) -> Option<String> {

View file

@ -0,0 +1,71 @@
mod doc;
mod schemars_to_serde;
pub use doc::get_title_and_desc_from_doc;
pub use schemars_to_serde::process_serde_attrs;
use proc_macro2::{Group, Span, TokenStream, TokenTree};
use syn::parse::{self, Parse};
pub fn get_with_from_attrs(field: &syn::Field) -> Option<syn::Result<syn::ExprPath>> {
field
.attrs
.iter()
.filter(|at| match at.path.get_ident() {
// FIXME this is relying on order of attributes (schemars before serde) from preprocess.rs
Some(i) => i == "schemars" || i == "serde",
None => false,
})
.filter_map(get_with_from_attr)
.next()
.map(|lit| parse_lit_str(&lit))
}
fn get_with_from_attr(attr: &syn::Attribute) -> Option<syn::LitStr> {
use syn::*;
let nested_metas = match attr.parse_meta() {
Ok(Meta::List(meta)) => meta.nested,
_ => return None,
};
for nm in nested_metas {
if let NestedMeta::Meta(Meta::NameValue(MetaNameValue {
path,
lit: Lit::Str(with),
..
})) = nm
{
if path.is_ident("with") {
return Some(with);
}
}
}
None
}
/// Parses the contents of a string literal as a `T`, re-spanning the tokens
/// first so that any parse error points at the original literal in the
/// user's source rather than at generated code.
fn parse_lit_str<T>(s: &syn::LitStr) -> parse::Result<T>
where
    T: Parse,
{
    spanned_tokens(s).and_then(syn::parse2)
}
/// Tokenizes the string literal's contents and stamps every resulting token
/// with the literal's own span.
fn spanned_tokens(s: &syn::LitStr) -> parse::Result<TokenStream> {
    syn::parse_str(&s.value()).map(|stream| respan_token_stream(stream, s.span()))
}
fn respan_token_stream(stream: TokenStream, span: Span) -> TokenStream {
stream
.into_iter()
.map(|token| respan_token_tree(token, span))
.collect()
}
/// Sets `span` on a single token; a `Group` is rebuilt so that its nested
/// stream is re-spanned as well (a group's stream cannot be mutated in place).
fn respan_token_tree(token: TokenTree, span: Span) -> TokenTree {
    let mut respanned = match token {
        TokenTree::Group(group) => {
            let inner = respan_token_stream(group.stream(), span);
            TokenTree::Group(Group::new(group.delimiter(), inner))
        }
        other => other,
    };
    respanned.set_span(span);
    respanned
}

View file

@ -2,7 +2,7 @@ use quote::ToTokens;
use serde_derive_internals::Ctxt; use serde_derive_internals::Ctxt;
use std::collections::HashSet; use std::collections::HashSet;
use syn::parse::Parser; use syn::parse::Parser;
use syn::{Attribute, Data, DeriveInput, Field, GenericParam, Generics, Meta, NestedMeta, Variant}; use syn::{Attribute, Data, Field, Meta, NestedMeta, Variant};
// List of keywords that can appear in #[serde(...)]/#[schemars(...)] attributes which we want serde_derive_internals to parse for us. // List of keywords that can appear in #[serde(...)]/#[schemars(...)] attributes which we want serde_derive_internals to parse for us.
static SERDE_KEYWORDS: &[&str] = &[ static SERDE_KEYWORDS: &[&str] = &[
@ -28,17 +28,9 @@ static SERDE_KEYWORDS: &[&str] = &[
"with", "with",
]; ];
pub fn add_trait_bounds(generics: &mut Generics) {
for param in &mut generics.params {
if let GenericParam::Type(ref mut type_param) = *param {
type_param.bounds.push(parse_quote!(schemars::JsonSchema));
}
}
}
// If a struct/variant/field has any #[schemars] attributes, then create copies of them // If a struct/variant/field has any #[schemars] attributes, then create copies of them
// as #[serde] attributes so that serde_derive_internals will parse them for us. // as #[serde] attributes so that serde_derive_internals will parse them for us.
pub fn process_serde_attrs(input: &mut DeriveInput) -> Result<(), Vec<syn::Error>> { pub fn process_serde_attrs(input: &mut syn::DeriveInput) -> Result<(), Vec<syn::Error>> {
let ctxt = Ctxt::new(); let ctxt = Ctxt::new();
process_attrs(&ctxt, &mut input.attrs); process_attrs(&ctxt, &mut input.attrs);
match input.data { match input.data {

View file

@ -4,32 +4,30 @@ extern crate quote;
extern crate syn; extern crate syn;
extern crate proc_macro; extern crate proc_macro;
mod doc_attrs; mod attr;
mod metadata; mod metadata;
mod preprocess;
use metadata::*; use metadata::*;
use proc_macro2::{Group, Span, TokenStream, TokenTree}; use proc_macro2::TokenStream;
use quote::ToTokens; use quote::ToTokens;
use serde_derive_internals::ast::{Container, Data, Field, Style, Variant}; use serde_derive_internals::ast::{Container, Data, Field, Style, Variant};
use serde_derive_internals::attr::{self, Default as SerdeDefault, TagType}; use serde_derive_internals::attr::{self as serde_attr, Default as SerdeDefault, TagType};
use serde_derive_internals::{Ctxt, Derive}; use serde_derive_internals::{Ctxt, Derive};
use syn::parse::{self, Parse};
use syn::spanned::Spanned; use syn::spanned::Spanned;
#[proc_macro_derive(JsonSchema, attributes(schemars, serde, doc))] #[proc_macro_derive(JsonSchema, attributes(schemars, serde, doc))]
pub fn derive_json_schema(input: proc_macro::TokenStream) -> proc_macro::TokenStream { pub fn derive_json_schema(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let mut input = parse_macro_input!(input as syn::DeriveInput); let mut input = parse_macro_input!(input as syn::DeriveInput);
preprocess::add_trait_bounds(&mut input.generics); add_trait_bounds(&mut input.generics);
if let Err(e) = preprocess::process_serde_attrs(&mut input) { if let Err(e) = attr::process_serde_attrs(&mut input) {
return compile_error(e).into(); return compile_error(&e).into();
} }
let ctxt = Ctxt::new(); let ctxt = Ctxt::new();
let cont = Container::from_ast(&ctxt, &input, Derive::Deserialize); let cont = Container::from_ast(&ctxt, &input, Derive::Deserialize);
if let Err(e) = ctxt.check() { if let Err(e) = ctxt.check() {
return compile_error(e).into(); return compile_error(&e).into();
} }
let cont = cont.expect("from_ast set no errors on Ctxt, so should have returned Some"); let cont = cont.expect("from_ast set no errors on Ctxt, so should have returned Some");
@ -84,6 +82,14 @@ pub fn derive_json_schema(input: proc_macro::TokenStream) -> proc_macro::TokenSt
proc_macro::TokenStream::from(impl_block) proc_macro::TokenStream::from(impl_block)
} }
/// Adds a `schemars::JsonSchema` bound to every type parameter so the
/// generated impl can require the trait on fields of generic type.
fn add_trait_bounds(generics: &mut syn::Generics) {
    generics.params.iter_mut().for_each(|param| {
        if let syn::GenericParam::Type(type_param) = param {
            type_param.bounds.push(parse_quote!(schemars::JsonSchema));
        }
    });
}
fn wrap_schema_fields(schema_contents: TokenStream) -> TokenStream { fn wrap_schema_fields(schema_contents: TokenStream) -> TokenStream {
quote! { quote! {
schemars::schema::Schema::Object( schemars::schema::Schema::Object(
@ -94,8 +100,8 @@ fn wrap_schema_fields(schema_contents: TokenStream) -> TokenStream {
} }
} }
fn compile_error(errors: Vec<syn::Error>) -> TokenStream { fn compile_error<'a>(errors: impl IntoIterator<Item = &'a syn::Error>) -> TokenStream {
let compile_errors = errors.iter().map(syn::Error::to_compile_error); let compile_errors = errors.into_iter().map(syn::Error::to_compile_error);
quote! { quote! {
#(#compile_errors)* #(#compile_errors)*
} }
@ -108,7 +114,7 @@ fn is_unit_variant(v: &Variant) -> bool {
} }
} }
fn schema_for_enum(variants: &[Variant], cattrs: &attr::Container) -> TokenStream { fn schema_for_enum(variants: &[Variant], cattrs: &serde_attr::Container) -> TokenStream {
let variants = variants.iter().filter(|v| !v.attrs.skip_deserializing()); let variants = variants.iter().filter(|v| !v.attrs.skip_deserializing());
match cattrs.tag() { match cattrs.tag() {
TagType::External => schema_for_external_tagged_enum(variants, cattrs), TagType::External => schema_for_external_tagged_enum(variants, cattrs),
@ -120,7 +126,7 @@ fn schema_for_enum(variants: &[Variant], cattrs: &attr::Container) -> TokenStrea
fn schema_for_external_tagged_enum<'a>( fn schema_for_external_tagged_enum<'a>(
variants: impl Iterator<Item = &'a Variant<'a>>, variants: impl Iterator<Item = &'a Variant<'a>>,
cattrs: &attr::Container, cattrs: &serde_attr::Container,
) -> TokenStream { ) -> TokenStream {
let (unit_variants, complex_variants): (Vec<_>, Vec<_>) = let (unit_variants, complex_variants): (Vec<_>, Vec<_>) =
variants.partition(|v| is_unit_variant(v)); variants.partition(|v| is_unit_variant(v));
@ -174,7 +180,7 @@ fn schema_for_external_tagged_enum<'a>(
fn schema_for_internal_tagged_enum<'a>( fn schema_for_internal_tagged_enum<'a>(
variants: impl Iterator<Item = &'a Variant<'a>>, variants: impl Iterator<Item = &'a Variant<'a>>,
cattrs: &attr::Container, cattrs: &serde_attr::Container,
tag_name: &str, tag_name: &str,
) -> TokenStream { ) -> TokenStream {
let schemas = variants.map(|variant| { let schemas = variants.map(|variant| {
@ -229,7 +235,7 @@ fn schema_for_internal_tagged_enum<'a>(
fn schema_for_untagged_enum<'a>( fn schema_for_untagged_enum<'a>(
variants: impl Iterator<Item = &'a Variant<'a>>, variants: impl Iterator<Item = &'a Variant<'a>>,
cattrs: &attr::Container, cattrs: &serde_attr::Container,
) -> TokenStream { ) -> TokenStream {
let schemas = variants.map(|variant| { let schemas = variants.map(|variant| {
let schema_expr = schema_for_untagged_enum_variant(variant, cattrs); let schema_expr = schema_for_untagged_enum_variant(variant, cattrs);
@ -244,7 +250,10 @@ fn schema_for_untagged_enum<'a>(
}) })
} }
fn schema_for_untagged_enum_variant(variant: &Variant, cattrs: &attr::Container) -> TokenStream { fn schema_for_untagged_enum_variant(
variant: &Variant,
cattrs: &serde_attr::Container,
) -> TokenStream {
match variant.style { match variant.style {
Style::Unit => schema_for_unit_struct(), Style::Unit => schema_for_unit_struct(),
Style::Newtype => schema_for_newtype_struct(&variant.fields[0]), Style::Newtype => schema_for_newtype_struct(&variant.fields[0]),
@ -276,7 +285,7 @@ fn schema_for_tuple_struct(fields: &[Field]) -> TokenStream {
} }
} }
fn schema_for_struct(fields: &[Field], cattrs: &attr::Container) -> TokenStream { fn schema_for_struct(fields: &[Field], cattrs: &serde_attr::Container) -> TokenStream {
let (flat, nested): (Vec<_>, Vec<_>) = fields let (flat, nested): (Vec<_>, Vec<_>) = fields
.iter() .iter()
.filter(|f| !f.attrs.skip_deserializing() || !f.attrs.skip_serializing()) .filter(|f| !f.attrs.skip_deserializing() || !f.attrs.skip_serializing())
@ -284,49 +293,14 @@ fn schema_for_struct(fields: &[Field], cattrs: &attr::Container) -> TokenStream
let set_container_default = match cattrs.default() { let set_container_default = match cattrs.default() {
SerdeDefault::None => None, SerdeDefault::None => None,
SerdeDefault::Default => Some(quote!(let cdefault = Self::default();)), SerdeDefault::Default => Some(quote!(let container_default = Self::default();)),
SerdeDefault::Path(path) => Some(quote!(let cdefault = #path();)), SerdeDefault::Path(path) => Some(quote!(let container_default = #path();)),
}; };
let mut required = Vec::new(); let mut required = Vec::new();
let recurse = nested.iter().map(|field| { let recurse = nested.iter().map(|field| {
let name = field.attrs.name().deserialize_name(); let name = field.attrs.name().deserialize_name();
let ty = field.ty; let default = field_default_expr(field, set_container_default.is_some());
let default = match field.attrs.default() {
_ if field.attrs.skip_serializing() => None,
SerdeDefault::None if set_container_default.is_none() => None,
SerdeDefault::None => {
let field_ident = field
.original
.ident
.as_ref()
.expect("This is not a tuple struct, so field should be named");
Some(quote!(cdefault.#field_ident))
}
SerdeDefault::Default => Some(quote!(<#ty>::default())),
SerdeDefault::Path(path) => Some(quote!(#path())),
}
.map(|d| match field.attrs.serialize_with() {
Some(ser_with) => quote! {
{
struct _SchemarsDefaultSerialize<T>(T);
impl serde::Serialize for _SchemarsDefaultSerialize<#ty>
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer
{
#ser_with(&self.0, serializer)
}
}
_SchemarsDefaultSerialize(#d)
}
},
None => d,
});
if default.is_none() { if default.is_none() {
required.push(name.clone()); required.push(name.clone());
@ -384,73 +358,50 @@ fn schema_for_struct(fields: &[Field], cattrs: &attr::Container) -> TokenStream
} }
} }
/// Builds the token stream for a field's default value, or `None` when the
/// field has no usable default (it is skipped during serialization, or
/// neither the field nor the container declares one).
///
/// NOTE(review): the `SerdeDefault::None` arm emits `container_default.#member`,
/// so the caller is expected to have bound a `container_default` local —
/// presumably via the container-level default; confirm against the caller.
fn field_default_expr(field: &Field, container_has_default: bool) -> Option<TokenStream> {
    let field_default = field.attrs.default();
    if field.attrs.skip_serializing() || (field_default.is_none() && !container_has_default) {
        return None;
    }

    let ty = field.ty;
    let default_expr = match field_default {
        // No field-level default: read the value off the container's default.
        SerdeDefault::None => {
            let member = &field.member;
            quote!(container_default.#member)
        }
        SerdeDefault::Default => quote!(<#ty>::default()),
        SerdeDefault::Path(path) => quote!(#path()),
    };

    match field.attrs.serialize_with() {
        None => Some(default_expr),
        // A custom serializer must also apply to the default value, so wrap
        // it in a shim whose Serialize impl delegates to `#ser_with`.
        Some(ser_with) => Some(quote! {
            {
                struct _SchemarsDefaultSerialize<T>(T);
                impl serde::Serialize for _SchemarsDefaultSerialize<#ty>
                {
                    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
                    where
                        S: serde::Serializer
                    {
                        #ser_with(&self.0, serializer)
                    }
                }
                _SchemarsDefaultSerialize(#default_expr)
            }
        }),
    }
}
fn get_json_schema_type(field: &Field) -> Box<dyn ToTokens> { fn get_json_schema_type(field: &Field) -> Box<dyn ToTokens> {
// TODO support [schemars(schema_with= "...")] or equivalent // TODO support [schemars(schema_with= "...")] or equivalent
match field match attr::get_with_from_attrs(&field.original) {
.original
.attrs
.iter()
.filter(|at| match at.path.get_ident() {
// FIXME this is relying on order of attributes (schemars before serde) from preprocess.rs
Some(i) => i == "schemars" || i == "serde",
None => false,
})
.filter_map(get_with_from_attr)
.next()
{
Some(with) => match parse_lit_str::<syn::ExprPath>(&with) {
Ok(expr_path) => Box::new(expr_path),
Err(e) => Box::new(compile_error(vec![e])),
},
None => Box::new(field.ty.clone()), None => Box::new(field.ty.clone()),
Some(Ok(expr_path)) => Box::new(expr_path),
Some(Err(e)) => Box::new(compile_error(&[e])),
} }
} }
fn get_with_from_attr(attr: &syn::Attribute) -> Option<syn::LitStr> {
use syn::*;
let nested_metas = match attr.parse_meta() {
Ok(Meta::List(meta)) => meta.nested,
_ => return None,
};
for nm in nested_metas {
if let NestedMeta::Meta(Meta::NameValue(MetaNameValue {
path,
lit: Lit::Str(with),
..
})) = nm
{
if path.is_ident("with") {
return Some(with);
}
}
}
None
}
fn parse_lit_str<T>(s: &syn::LitStr) -> parse::Result<T>
where
T: Parse,
{
let tokens = spanned_tokens(s)?;
syn::parse2(tokens)
}
fn spanned_tokens(s: &syn::LitStr) -> parse::Result<TokenStream> {
let stream = syn::parse_str(&s.value())?;
Ok(respan_token_stream(stream, s.span()))
}
fn respan_token_stream(stream: TokenStream, span: Span) -> TokenStream {
stream
.into_iter()
.map(|token| respan_token_tree(token, span))
.collect()
}
fn respan_token_tree(mut token: TokenTree, span: Span) -> TokenTree {
if let TokenTree::Group(g) = &mut token {
*g = Group::new(g.delimiter(), respan_token_stream(g.stream(), span));
}
token.set_span(span);
token
}

View file

@ -1,4 +1,4 @@
use crate::doc_attrs; use crate::attr;
use proc_macro2::TokenStream; use proc_macro2::TokenStream;
use syn::{Attribute, ExprPath}; use syn::{Attribute, ExprPath};
@ -21,7 +21,7 @@ pub fn set_metadata_on_schema_from_docs(
} }
pub fn get_metadata_from_docs(attrs: &[Attribute]) -> SchemaMetadata { pub fn get_metadata_from_docs(attrs: &[Attribute]) -> SchemaMetadata {
let (title, description) = doc_attrs::get_title_and_desc_from_docs(attrs); let (title, description) = attr::get_title_and_desc_from_doc(attrs);
SchemaMetadata { SchemaMetadata {
title, title,
description, description,