Refactoring of schemars_derive

This commit is contained in:
Graham Esau 2019-12-09 20:57:38 +00:00
parent dca9e2d920
commit 3fb625e08c
5 changed files with 155 additions and 141 deletions

View file

@ -4,32 +4,30 @@ extern crate quote;
extern crate syn;
extern crate proc_macro;
mod doc_attrs;
mod attr;
mod metadata;
mod preprocess;
use metadata::*;
use proc_macro2::{Group, Span, TokenStream, TokenTree};
use proc_macro2::TokenStream;
use quote::ToTokens;
use serde_derive_internals::ast::{Container, Data, Field, Style, Variant};
use serde_derive_internals::attr::{self, Default as SerdeDefault, TagType};
use serde_derive_internals::attr::{self as serde_attr, Default as SerdeDefault, TagType};
use serde_derive_internals::{Ctxt, Derive};
use syn::parse::{self, Parse};
use syn::spanned::Spanned;
#[proc_macro_derive(JsonSchema, attributes(schemars, serde, doc))]
pub fn derive_json_schema(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let mut input = parse_macro_input!(input as syn::DeriveInput);
preprocess::add_trait_bounds(&mut input.generics);
if let Err(e) = preprocess::process_serde_attrs(&mut input) {
return compile_error(e).into();
add_trait_bounds(&mut input.generics);
if let Err(e) = attr::process_serde_attrs(&mut input) {
return compile_error(&e).into();
}
let ctxt = Ctxt::new();
let cont = Container::from_ast(&ctxt, &input, Derive::Deserialize);
if let Err(e) = ctxt.check() {
return compile_error(e).into();
return compile_error(&e).into();
}
let cont = cont.expect("from_ast set no errors on Ctxt, so should have returned Some");
@ -84,6 +82,14 @@ pub fn derive_json_schema(input: proc_macro::TokenStream) -> proc_macro::TokenSt
proc_macro::TokenStream::from(impl_block)
}
fn add_trait_bounds(generics: &mut syn::Generics) {
for param in &mut generics.params {
if let syn::GenericParam::Type(ref mut type_param) = *param {
type_param.bounds.push(parse_quote!(schemars::JsonSchema));
}
}
}
fn wrap_schema_fields(schema_contents: TokenStream) -> TokenStream {
quote! {
schemars::schema::Schema::Object(
@ -94,8 +100,8 @@ fn wrap_schema_fields(schema_contents: TokenStream) -> TokenStream {
}
}
fn compile_error(errors: Vec<syn::Error>) -> TokenStream {
let compile_errors = errors.iter().map(syn::Error::to_compile_error);
fn compile_error<'a>(errors: impl IntoIterator<Item = &'a syn::Error>) -> TokenStream {
let compile_errors = errors.into_iter().map(syn::Error::to_compile_error);
quote! {
#(#compile_errors)*
}
@ -108,7 +114,7 @@ fn is_unit_variant(v: &Variant) -> bool {
}
}
fn schema_for_enum(variants: &[Variant], cattrs: &attr::Container) -> TokenStream {
fn schema_for_enum(variants: &[Variant], cattrs: &serde_attr::Container) -> TokenStream {
let variants = variants.iter().filter(|v| !v.attrs.skip_deserializing());
match cattrs.tag() {
TagType::External => schema_for_external_tagged_enum(variants, cattrs),
@ -120,7 +126,7 @@ fn schema_for_enum(variants: &[Variant], cattrs: &attr::Container) -> TokenStrea
fn schema_for_external_tagged_enum<'a>(
variants: impl Iterator<Item = &'a Variant<'a>>,
cattrs: &attr::Container,
cattrs: &serde_attr::Container,
) -> TokenStream {
let (unit_variants, complex_variants): (Vec<_>, Vec<_>) =
variants.partition(|v| is_unit_variant(v));
@ -174,7 +180,7 @@ fn schema_for_external_tagged_enum<'a>(
fn schema_for_internal_tagged_enum<'a>(
variants: impl Iterator<Item = &'a Variant<'a>>,
cattrs: &attr::Container,
cattrs: &serde_attr::Container,
tag_name: &str,
) -> TokenStream {
let schemas = variants.map(|variant| {
@ -229,7 +235,7 @@ fn schema_for_internal_tagged_enum<'a>(
fn schema_for_untagged_enum<'a>(
variants: impl Iterator<Item = &'a Variant<'a>>,
cattrs: &attr::Container,
cattrs: &serde_attr::Container,
) -> TokenStream {
let schemas = variants.map(|variant| {
let schema_expr = schema_for_untagged_enum_variant(variant, cattrs);
@ -244,7 +250,10 @@ fn schema_for_untagged_enum<'a>(
})
}
fn schema_for_untagged_enum_variant(variant: &Variant, cattrs: &attr::Container) -> TokenStream {
fn schema_for_untagged_enum_variant(
variant: &Variant,
cattrs: &serde_attr::Container,
) -> TokenStream {
match variant.style {
Style::Unit => schema_for_unit_struct(),
Style::Newtype => schema_for_newtype_struct(&variant.fields[0]),
@ -276,7 +285,7 @@ fn schema_for_tuple_struct(fields: &[Field]) -> TokenStream {
}
}
fn schema_for_struct(fields: &[Field], cattrs: &attr::Container) -> TokenStream {
fn schema_for_struct(fields: &[Field], cattrs: &serde_attr::Container) -> TokenStream {
let (flat, nested): (Vec<_>, Vec<_>) = fields
.iter()
.filter(|f| !f.attrs.skip_deserializing() || !f.attrs.skip_serializing())
@ -284,49 +293,14 @@ fn schema_for_struct(fields: &[Field], cattrs: &attr::Container) -> TokenStream
let set_container_default = match cattrs.default() {
SerdeDefault::None => None,
SerdeDefault::Default => Some(quote!(let cdefault = Self::default();)),
SerdeDefault::Path(path) => Some(quote!(let cdefault = #path();)),
SerdeDefault::Default => Some(quote!(let container_default = Self::default();)),
SerdeDefault::Path(path) => Some(quote!(let container_default = #path();)),
};
let mut required = Vec::new();
let recurse = nested.iter().map(|field| {
let name = field.attrs.name().deserialize_name();
let ty = field.ty;
let default = match field.attrs.default() {
_ if field.attrs.skip_serializing() => None,
SerdeDefault::None if set_container_default.is_none() => None,
SerdeDefault::None => {
let field_ident = field
.original
.ident
.as_ref()
.expect("This is not a tuple struct, so field should be named");
Some(quote!(cdefault.#field_ident))
}
SerdeDefault::Default => Some(quote!(<#ty>::default())),
SerdeDefault::Path(path) => Some(quote!(#path())),
}
.map(|d| match field.attrs.serialize_with() {
Some(ser_with) => quote! {
{
struct _SchemarsDefaultSerialize<T>(T);
impl serde::Serialize for _SchemarsDefaultSerialize<#ty>
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer
{
#ser_with(&self.0, serializer)
}
}
_SchemarsDefaultSerialize(#d)
}
},
None => d,
});
let default = field_default_expr(field, set_container_default.is_some());
if default.is_none() {
required.push(name.clone());
@ -384,73 +358,50 @@ fn schema_for_struct(fields: &[Field], cattrs: &attr::Container) -> TokenStream
}
}
fn get_json_schema_type(field: &Field) -> Box<dyn ToTokens> {
// TODO support [schemars(schema_with= "...")] or equivalent
match field
.original
.attrs
.iter()
.filter(|at| match at.path.get_ident() {
// FIXME this is relying on order of attributes (schemars before serde) from preprocess.rs
Some(i) => i == "schemars" || i == "serde",
None => false,
})
.filter_map(get_with_from_attr)
.next()
{
Some(with) => match parse_lit_str::<syn::ExprPath>(&with) {
Ok(expr_path) => Box::new(expr_path),
Err(e) => Box::new(compile_error(vec![e])),
},
None => Box::new(field.ty.clone()),
fn field_default_expr(field: &Field, container_has_default: bool) -> Option<TokenStream> {
let field_default = field.attrs.default();
if field.attrs.skip_serializing() || (field_default.is_none() && !container_has_default) {
return None;
}
}
fn get_with_from_attr(attr: &syn::Attribute) -> Option<syn::LitStr> {
use syn::*;
let nested_metas = match attr.parse_meta() {
Ok(Meta::List(meta)) => meta.nested,
_ => return None,
let ty = field.ty;
let default_expr = match field_default {
SerdeDefault::None => {
let member = &field.member;
quote!(container_default.#member)
}
SerdeDefault::Default => quote!(<#ty>::default()),
SerdeDefault::Path(path) => quote!(#path()),
};
for nm in nested_metas {
if let NestedMeta::Meta(Meta::NameValue(MetaNameValue {
path,
lit: Lit::Str(with),
..
})) = nm
{
if path.is_ident("with") {
return Some(with);
Some(if let Some(ser_with) = field.attrs.serialize_with() {
quote! {
{
struct _SchemarsDefaultSerialize<T>(T);
impl serde::Serialize for _SchemarsDefaultSerialize<#ty>
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer
{
#ser_with(&self.0, serializer)
}
}
_SchemarsDefaultSerialize(#default_expr)
}
}
} else {
default_expr
})
}
fn get_json_schema_type(field: &Field) -> Box<dyn ToTokens> {
// TODO support [schemars(schema_with= "...")] or equivalent
match attr::get_with_from_attrs(&field.original) {
None => Box::new(field.ty.clone()),
Some(Ok(expr_path)) => Box::new(expr_path),
Some(Err(e)) => Box::new(compile_error(&[e])),
}
None
}
fn parse_lit_str<T>(s: &syn::LitStr) -> parse::Result<T>
where
T: Parse,
{
let tokens = spanned_tokens(s)?;
syn::parse2(tokens)
}
fn spanned_tokens(s: &syn::LitStr) -> parse::Result<TokenStream> {
let stream = syn::parse_str(&s.value())?;
Ok(respan_token_stream(stream, s.span()))
}
fn respan_token_stream(stream: TokenStream, span: Span) -> TokenStream {
stream
.into_iter()
.map(|token| respan_token_tree(token, span))
.collect()
}
fn respan_token_tree(mut token: TokenTree, span: Span) -> TokenTree {
if let TokenTree::Group(g) = &mut token {
*g = Group::new(g.delimiter(), respan_token_stream(g.stream(), span));
}
token.set_span(span);
token
}