diff --git a/schemars_derive/src/lib.rs b/schemars_derive/src/lib.rs index f126e73..d135804 100644 --- a/schemars_derive/src/lib.rs +++ b/schemars_derive/src/lib.rs @@ -23,7 +23,7 @@ pub fn derive_make_schema(input: proc_macro::TokenStream) -> proc_macro::TokenSt let name = cont.ident; let (impl_generics, ty_generics, where_clause) = cont.generics.split_for_impl(); - let schema_contents = match cont.data { + let schema = match cont.data { Data::Struct(Style::Struct, ref fields) => schema_for_struct(fields), Data::Enum(ref variants) => schema_for_enum(variants), _ => unimplemented!("work in progress!"), @@ -33,41 +33,84 @@ pub fn derive_make_schema(input: proc_macro::TokenStream) -> proc_macro::TokenSt #[automatically_derived] impl #impl_generics schemars::make_schema::MakeSchema for #name #ty_generics #where_clause { fn make_schema(gen: &mut schemars::SchemaGenerator) -> schemars::Schema { - schemars::SchemaObject { - #schema_contents - ..Default::default() - } - .into() + #schema } }; }; proc_macro::TokenStream::from(impl_block) } +fn wrap_schema_fields(schema_contents: TokenStream) -> TokenStream { + quote! { + schemars::SchemaObject { + #schema_contents + ..Default::default() + } + .into() + } +} + fn compile_error(span: Span, message: String) -> TokenStream { quote_spanned! {span=> compile_error!(#message); } } -fn name_for_unit_variant(v: &Variant) -> Option { +fn is_unit_variant(v: &&Variant) -> bool { match v.style { - Style::Unit => Some(v.attrs.name().deserialize_name()), - _ => None, + Style::Unit => true, + _ => false, } } fn schema_for_enum(variants: &[Variant]) -> TokenStream { // TODO handle untagged or adjacently tagged enums - let unit_names: Vec<_> = variants.iter().filter_map(name_for_unit_variant).collect(); + let (unit_variants, complex_variants): (Vec<_>, Vec<_>) = + variants.into_iter().partition(is_unit_variant); + let unit_count = unit_variants.len(); - if unit_names.len() == variants.len() { - return quote! { - enum_values: Some(vec![#(#unit_names.into()),*]), - }; + let unit_names = unit_variants + .into_iter() + .map(|v| v.attrs.name().deserialize_name()); + let unit_schema = wrap_schema_fields(quote! { + enum_values: Some(vec![#(#unit_names.into()),*]), + }); + + if complex_variants.is_empty() { + return unit_schema; } - unimplemented!("work in progress!") + let mut schemas = Vec::new(); + if unit_count > 0 { + schemas.push(unit_schema); + } + + schemas.extend(complex_variants.into_iter().map(|variant| { + let sub_schema = match variant.style { + Style::Newtype => { + let f = &variant.fields[0]; + let ty = f.ty; + quote_spanned! {f.original.span()=> + gen.subschema_for::<#ty>() + } + } + Style::Tuple => unimplemented!("work in progress!"), + Style::Struct => unimplemented!("work in progress!"), + Style::Unit => unreachable!("Unit variants already filtered out"), + }; + let name = variant.attrs.name().deserialize_name(); + wrap_schema_fields(quote! { + properties: { + let mut props = std::collections::BTreeMap::new(); + props.insert(#name.to_owned(), #sub_schema); + props + }, + }) + })); + + wrap_schema_fields(quote! { + any_of: Some(vec![#(#schemas),*]), + }) } fn schema_for_struct(fields: &[Field]) -> TokenStream { @@ -78,11 +121,11 @@ fn schema_for_struct(fields: &[Field]) -> TokenStream { props.insert(#name.to_owned(), gen.subschema_for::<#ty>()); } }); - quote! { + wrap_schema_fields(quote! { properties: { let mut props = std::collections::BTreeMap::new(); #(#recurse)* props }, - } + }) }