diff --git a/schemars_derive/src/ast/from_serde.rs b/schemars_derive/src/ast/from_serde.rs new file mode 100644 index 0000000..eb91b48 --- /dev/null +++ b/schemars_derive/src/ast/from_serde.rs @@ -0,0 +1,76 @@ +use super::*; +use crate::attr::get_with_from_attrs; +use serde_derive_internals::ast as serde_ast; +use serde_derive_internals::Ctxt; + +pub trait FromSerde: Sized { + type SerdeType; + + fn from_serde(errors: &Ctxt, serde: Self::SerdeType) -> Result; + + fn vec_from_serde(errors: &Ctxt, serdes: Vec) -> Result, ()> { + let mut result = Vec::with_capacity(serdes.len()); + for s in serdes { + result.push(Self::from_serde(errors, s)?) + } + Ok(result) + } +} + +impl<'a> FromSerde for Container<'a> { + type SerdeType = serde_ast::Container<'a>; + + fn from_serde(errors: &Ctxt, serde: Self::SerdeType) -> Result { + Ok(Self { + ident: serde.ident, + serde_attrs: serde.attrs, + data: Data::from_serde(errors, serde.data)?, + generics: serde.generics, + original: serde.original, + }) + } +} + +impl<'a> FromSerde for Data<'a> { + type SerdeType = serde_ast::Data<'a>; + + fn from_serde(errors: &Ctxt, serde: Self::SerdeType) -> Result { + Ok(match serde { + Self::SerdeType::Enum(variants) => { + Self::Enum(Variant::vec_from_serde(errors, variants)?) + } + Self::SerdeType::Struct(style, fields) => { + Self::Struct(style, Field::vec_from_serde(errors, fields)?) + } + }) + } +} + +impl<'a> FromSerde for Variant<'a> { + type SerdeType = serde_ast::Variant<'a>; + + fn from_serde(errors: &Ctxt, serde: Self::SerdeType) -> Result { + Ok(Self { + ident: serde.ident, + serde_attrs: serde.attrs, + style: serde.style, + fields: Field::vec_from_serde(errors, serde.fields)?, + original: serde.original, + with: get_with_from_attrs(&serde.original.attrs, errors)?, + }) + } +} + +impl<'a> FromSerde for Field<'a> { + type SerdeType = serde_ast::Field<'a>; + + fn from_serde(errors: &Ctxt, serde: Self::SerdeType) -> Result { + Ok(Self { + member: serde.member, + serde_attrs: serde.attrs, + ty: serde.ty, + original: serde.original, + with: get_with_from_attrs(&serde.original.attrs, errors)?, + }) + } +} diff --git a/schemars_derive/src/ast/mod.rs b/schemars_derive/src/ast/mod.rs new file mode 100644 index 0000000..b67850e --- /dev/null +++ b/schemars_derive/src/ast/mod.rs @@ -0,0 +1,63 @@ +mod from_serde; + +use from_serde::FromSerde; +use serde_derive_internals::ast as serde_ast; +use serde_derive_internals::{Ctxt, Derive}; + +pub struct Container<'a> { + pub ident: syn::Ident, + pub serde_attrs: serde_derive_internals::attr::Container, + pub data: Data<'a>, + pub generics: &'a syn::Generics, + pub original: &'a syn::DeriveInput, +} + +pub enum Data<'a> { + Enum(Vec>), + Struct(serde_ast::Style, Vec>), +} + +pub struct Variant<'a> { + pub ident: syn::Ident, + pub serde_attrs: serde_derive_internals::attr::Variant, + pub style: serde_ast::Style, + pub fields: Vec>, + pub original: &'a syn::Variant, + pub with: Option, +} + +pub struct Field<'a> { + pub member: syn::Member, + pub serde_attrs: serde_derive_internals::attr::Field, + pub ty: &'a syn::Type, + pub original: &'a syn::Field, + pub with: Option, +} + +impl<'a> Container<'a> { + pub fn from_ast(item: &'a syn::DeriveInput) -> Result, Vec> { + let ctxt = Ctxt::new(); + let result = serde_ast::Container::from_ast(&ctxt, item, Derive::Deserialize) + .ok_or(()) + .and_then(|serde| Self::from_serde(&ctxt, serde)); + + ctxt.check() + .map(|_| result.expect("from_ast set no errors on Ctxt, so should have returned Ok")) + } + + pub fn name(&self) -> String { + self.serde_attrs.name().deserialize_name() + } +} + +impl<'a> Variant<'a> { + pub fn name(&self) -> String { + self.serde_attrs.name().deserialize_name() + } +} + +impl<'a> Field<'a> { + pub fn name(&self) -> String { + self.serde_attrs.name().deserialize_name() + } +} diff --git a/schemars_derive/src/attr/mod.rs b/schemars_derive/src/attr/mod.rs index eb04f55..841228b 100644 --- a/schemars_derive/src/attr/mod.rs +++ b/schemars_derive/src/attr/mod.rs @@ -5,9 +5,13 @@ pub use doc::get_title_and_desc_from_doc; pub use schemars_to_serde::process_serde_attrs; use proc_macro2::{Group, Span, TokenStream, TokenTree}; +use serde_derive_internals::Ctxt; use syn::parse::{self, Parse}; -pub fn get_with_from_attrs(attrs: &[syn::Attribute]) -> Option> { +pub fn get_with_from_attrs( + attrs: &[syn::Attribute], + errors: &Ctxt, +) -> Result, ()> { attrs .iter() .filter(|at| match at.path.get_ident() { @@ -17,7 +21,13 @@ pub fn get_with_from_attrs(attrs: &[syn::Attribute]) -> Option(&lit) { + Ok(t) => Ok(Some(t)), + Err(e) => { + errors.error_spanned_by(lit, e); + Err(()) + } + }) } fn get_with_from_attr(attr: &syn::Attribute) -> Option { diff --git a/schemars_derive/src/lib.rs b/schemars_derive/src/lib.rs index 8e59735..6b69957 100644 --- a/schemars_derive/src/lib.rs +++ b/schemars_derive/src/lib.rs @@ -4,52 +4,54 @@ extern crate quote; extern crate syn; extern crate proc_macro; +mod ast; mod attr; mod metadata; +use ast::*; use metadata::*; use proc_macro2::TokenStream; use quote::ToTokens; -use serde_derive_internals::ast::{Container, Data, Field, Style, Variant as SerdeVariant}; +use serde_derive_internals::ast::Style; use serde_derive_internals::attr::{self as serde_attr, Default as SerdeDefault, TagType}; -use serde_derive_internals::{Ctxt, Derive}; -use std::ops::Deref; use syn::spanned::Spanned; #[proc_macro_derive(JsonSchema, attributes(schemars, serde))] -pub fn derive_json_schema(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - let mut input = parse_macro_input!(input as syn::DeriveInput); +pub fn derive_json_schema_wrapper(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(input as syn::DeriveInput); + derive_json_schema(input).into() +} +fn derive_json_schema(mut input: syn::DeriveInput) -> TokenStream { add_trait_bounds(&mut input.generics); + if let Err(e) = attr::process_serde_attrs(&mut input) { - return compile_error(&e).into(); + return compile_error(&e); } - let ctxt = Ctxt::new(); - let cont = Container::from_ast(&ctxt, &input, Derive::Deserialize); - if let Err(e) = ctxt.check() { - return compile_error(&e).into(); - } - let cont = cont.expect("from_ast set no errors on Ctxt, so should have returned Some"); + let cont = match Container::from_ast(&input) { + Ok(c) => c, + Err(e) => return compile_error(&e), + }; - let schema_expr = match cont.data { + let schema_expr = match &cont.data { Data::Struct(Style::Unit, _) => schema_for_unit_struct(), - Data::Struct(Style::Newtype, ref fields) => schema_for_newtype_struct(&fields[0]), - Data::Struct(Style::Tuple, ref fields) => schema_for_tuple_struct(fields), - Data::Struct(Style::Struct, ref fields) => schema_for_struct(fields, Some(&cont.attrs)), - Data::Enum(variants) => schema_for_enum(&Variant::vec_new(variants), &cont.attrs), + Data::Struct(Style::Newtype, fields) => schema_for_newtype_struct(&fields[0]), + Data::Struct(Style::Tuple, fields) => schema_for_tuple_struct(fields), + Data::Struct(Style::Struct, fields) => schema_for_struct(fields, Some(&cont.serde_attrs)), + Data::Enum(variants) => schema_for_enum(variants, &cont.serde_attrs), }; let doc_metadata = SchemaMetadata::from_doc_attrs(&cont.original.attrs); let schema_expr = doc_metadata.apply_to_schema(schema_expr); - let type_name = cont.ident; + let type_name = &cont.ident; let type_params: Vec<_> = cont.generics.type_params().map(|ty| &ty.ident).collect(); - let mut schema_base_name = cont.attrs.name().deserialize_name(); - let schema_is_renamed = type_name != schema_base_name; + let mut schema_base_name = cont.name(); + let schema_is_renamed = *type_name != schema_base_name; if !schema_is_renamed { - if let Some(path) = cont.attrs.remote() { + if let Some(path) = cont.serde_attrs.remote() { if let Some(segment) = path.segments.last() { schema_base_name = segment.ident.to_string(); } @@ -79,7 +81,7 @@ pub fn derive_json_schema(input: proc_macro::TokenStream) -> proc_macro::TokenSt let (impl_generics, ty_generics, where_clause) = cont.generics.split_for_impl(); - let impl_block = quote! { + quote! { #[automatically_derived] impl #impl_generics schemars::JsonSchema for #type_name #ty_generics #where_clause { fn schema_name() -> std::string::String { @@ -90,8 +92,7 @@ pub fn derive_json_schema(input: proc_macro::TokenStream) -> proc_macro::TokenSt #schema_expr } }; - }; - proc_macro::TokenStream::from(impl_block) + } } fn add_trait_bounds(generics: &mut syn::Generics) { @@ -127,7 +128,9 @@ fn is_unit_variant(v: &Variant) -> bool { } fn schema_for_enum(variants: &[Variant], cattrs: &serde_attr::Container) -> TokenStream { - let variants = variants.iter().filter(|v| !v.attrs.skip_deserializing()); + let variants = variants + .iter() + .filter(|v| !v.serde_attrs.skip_deserializing()); match cattrs.tag() { TagType::External => schema_for_external_tagged_enum(variants), TagType::None => schema_for_untagged_enum(variants), @@ -145,9 +148,7 @@ fn schema_for_external_tagged_enum<'a>( variants.partition(|v| is_unit_variant(v)); let unit_count = unit_variants.len(); - let unit_names = unit_variants - .into_iter() - .map(|v| v.attrs.name().deserialize_name()); + let unit_names = unit_variants.into_iter().map(|v| v.name()); let unit_schema = wrap_schema_fields(quote! { enum_values: Some(vec![#(#unit_names.into()),*]), }); @@ -162,7 +163,7 @@ fn schema_for_external_tagged_enum<'a>( } schemas.extend(complex_variants.into_iter().map(|variant| { - let name = variant.attrs.name().deserialize_name(); + let name = variant.name(); let sub_schema = schema_for_untagged_enum_variant(variant); let schema_expr = wrap_schema_fields(quote! { instance_type: Some(schemars::schema::InstanceType::Object.into()), @@ -197,7 +198,7 @@ fn schema_for_internal_tagged_enum<'a>( tag_name: &str, ) -> TokenStream { let schemas = variants.map(|variant| { - let name = variant.attrs.name().deserialize_name(); + let name = variant.name(); let type_schema = wrap_schema_fields(quote! { instance_type: Some(schemars::schema::InstanceType::String.into()), enum_values: Some(vec![#name.into()]), @@ -315,7 +316,7 @@ fn schema_for_adjacent_tagged_enum<'a>( }) .unwrap_or_default(); - let name = variant.attrs.name().deserialize_name(); + let name = variant.name(); let tag_schema = wrap_schema_fields(quote! { instance_type: Some(schemars::schema::InstanceType::String.into()), enum_values: Some(vec![#name.into()]), @@ -368,7 +369,7 @@ fn schema_for_newtype_struct(field: &Field) -> TokenStream { fn schema_for_tuple_struct(fields: &[Field]) -> TokenStream { let types = fields .iter() - .filter(|f| !f.attrs.skip_deserializing()) + .filter(|f| !f.serde_attrs.skip_deserializing()) .map(get_json_schema_type); quote! { gen.subschema_for::<(#(#types),*)>() @@ -378,8 +379,8 @@ fn schema_for_tuple_struct(fields: &[Field]) -> TokenStream { fn schema_for_struct(fields: &[Field], cattrs: Option<&serde_attr::Container>) -> TokenStream { let (flattened_fields, property_fields): (Vec<_>, Vec<_>) = fields .iter() - .filter(|f| !f.attrs.skip_deserializing() || !f.attrs.skip_serializing()) - .partition(|f| f.attrs.flatten()); + .filter(|f| !f.serde_attrs.skip_deserializing() || !f.serde_attrs.skip_serializing()) + .partition(|f| f.serde_attrs.flatten()); let set_container_default = match cattrs.map_or(&SerdeDefault::None, |c| c.default()) { SerdeDefault::None => None, @@ -388,7 +389,7 @@ fn schema_for_struct(fields: &[Field], cattrs: Option<&serde_attr::Container>) - }; let properties = property_fields.iter().map(|field| { - let name = field.attrs.name().deserialize_name(); + let name = field.name(); let default = field_default_expr(field, set_container_default.is_some()); let required = match default { @@ -397,8 +398,8 @@ fn schema_for_struct(fields: &[Field], cattrs: Option<&serde_attr::Container>) - }; let metadata = &SchemaMetadata { - read_only: field.attrs.skip_deserializing(), - write_only: field.attrs.skip_serializing(), + read_only: field.serde_attrs.skip_deserializing(), + write_only: field.serde_attrs.skip_serializing(), default, ..SchemaMetadata::from_doc_attrs(&field.original.attrs) }; @@ -435,8 +436,8 @@ fn schema_for_struct(fields: &[Field], cattrs: Option<&serde_attr::Container>) - } fn field_default_expr(field: &Field, container_has_default: bool) -> Option { - let field_default = field.attrs.default(); - if field.attrs.skip_serializing() || (field_default.is_none() && !container_has_default) { + let field_default = field.serde_attrs.default(); + if field.serde_attrs.skip_serializing() || (field_default.is_none() && !container_has_default) { return None; } @@ -450,7 +451,7 @@ fn field_default_expr(field: &Field, container_has_default: bool) -> Option quote!(#path()), }; - let default_expr = match field.attrs.skip_serializing_if() { + let default_expr = match field.serde_attrs.skip_serializing_if() { Some(skip_if) => { quote! { { @@ -466,7 +467,7 @@ fn field_default_expr(field: &Field, container_has_default: bool) -> Option quote!(Some(#default_expr)), }; - Some(if let Some(ser_with) = field.attrs.serialize_with() { + Some(if let Some(ser_with) = field.serde_attrs.serialize_with() { quote! { { struct _SchemarsDefaultSerialize(T); @@ -491,37 +492,8 @@ fn field_default_expr(field: &Field, container_has_default: bool) -> Option TokenStream { // TODO support [schemars(schema_with= "...")] or equivalent - match attr::get_with_from_attrs(&field.original.attrs) { - None => field.ty.to_token_stream(), - Some(Ok(expr_path)) => expr_path.to_token_stream(), - Some(Err(e)) => compile_error(&[e]), - } -} - -struct Variant<'a> { - serde: SerdeVariant<'a>, - with: Option, -} - -impl<'a> Variant<'a> { - fn new(serde: SerdeVariant<'a>) -> Self { - let with = match attr::get_with_from_attrs(&serde.original.attrs) { - None => None, - Some(Ok(expr_path)) => Some(expr_path.to_token_stream()), - Some(Err(e)) => Some(compile_error(&[e])), - }; - Self { serde, with } - } - - fn vec_new(serdes: Vec>) -> Vec> { - serdes.into_iter().map(Self::new).collect() - } -} - -impl<'a> Deref for Variant<'a> { - type Target = SerdeVariant<'a>; - - fn deref(&self) -> &Self::Target { - &self.serde - } + field + .with + .as_ref() + .map_or_else(|| field.ty.to_token_stream(), |w| w.to_token_stream()) }