Set "required" schema attribute

This commit is contained in:
Graham Esau 2019-08-08 18:34:47 +01:00
parent 6b64cedb91
commit 998e6c9f0f
7 changed files with 68 additions and 21 deletions

View file

@ -3,6 +3,7 @@ use crate::{MakeSchema, MakeSchemaError, Result};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::BTreeMap as Map;
use std::collections::BTreeSet as Set;
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, MakeSchema)]
#[serde(untagged)]
@ -30,13 +31,13 @@ impl From<SchemaRef> for Schema {
}
}
impl Schema {
pub fn flatten(self, other: Self) -> Result {
fn extend<A, E: Extend<A>>(mut a: E, b: impl IntoIterator<Item = A>) -> E {
a.extend(b);
a
}
impl Schema {
pub fn flatten(self, other: Self) -> Result {
let s1 = self.ensure_flattenable()?;
let s2 = other.ensure_flattenable()?;
Ok(Schema::Object(SchemaObject {
@ -112,8 +113,8 @@ pub struct SchemaObject {
pub items: Option<SingleOrVec<Schema>>,
#[serde(skip_serializing_if = "Map::is_empty")]
pub properties: Map<String, Schema>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub required: Vec<String>,
#[serde(skip_serializing_if = "Set::is_empty")]
pub required: Set<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub all_of: Option<Vec<Schema>>,
#[serde(skip_serializing_if = "Option::is_none")]

View file

@ -12,5 +12,10 @@
"new_name_2": {
"type": "integer"
}
}
},
"required": [
"camelCase",
"new_name_1",
"new_name_2"
]
}

View file

@ -22,6 +22,13 @@
}
}
},
"required": [
"inner",
"t",
"u",
"v",
"w"
],
"definitions": {
"another-new-name": {
"type": "object"

View file

@ -22,6 +22,13 @@
}
}
},
"required": [
"inner",
"t",
"u",
"v",
"w"
],
"definitions": {
"MySimpleStruct": {
"type": "object"

View file

@ -151,7 +151,10 @@
"$ref": {
"type": "string"
}
}
},
"required": [
"$ref"
]
},
"SingleOrVec_For_InstanceType": {
"anyOf": [

View file

@ -187,7 +187,10 @@
"$ref": {
"type": "string"
}
}
},
"required": [
"$ref"
]
},
"SingleOrVec_For_InstanceType": {
"anyOf": [

View file

@ -9,7 +9,7 @@ mod preprocess;
use proc_macro2::{Span, TokenStream};
use serde_derive_internals::ast::{Container, Data, Field, Style, Variant};
use serde_derive_internals::attr::{self, EnumTag};
use serde_derive_internals::attr::{self, Default as SerdeDefault, EnumTag};
use serde_derive_internals::{Ctxt, Derive};
use syn::spanned::Spanned;
use syn::DeriveInput;
@ -28,7 +28,7 @@ pub fn derive_make_schema(input: proc_macro::TokenStream) -> proc_macro::TokenSt
}
let schema = match cont.data {
Data::Struct(Style::Struct, ref fields) => schema_for_struct(fields),
Data::Struct(Style::Struct, ref fields) => schema_for_struct(fields, &cont.attrs),
Data::Enum(ref variants) => schema_for_enum(variants, &cont.attrs),
_ => unimplemented!("work in progress!"),
};
@ -102,13 +102,13 @@ fn is_unit_variant(v: &&Variant) -> bool {
fn schema_for_enum(variants: &[Variant], cattrs: &attr::Container) -> TokenStream {
match cattrs.tag() {
EnumTag::External => schema_for_external_tagged_enum(variants),
EnumTag::None => schema_for_untagged_enum(variants),
EnumTag::External => schema_for_external_tagged_enum(variants, cattrs),
EnumTag::None => schema_for_untagged_enum(variants, cattrs),
_ => unimplemented!("Adjacent/internal tagged enums not yet supported."),
}
}
fn schema_for_external_tagged_enum(variants: &[Variant]) -> TokenStream {
fn schema_for_external_tagged_enum(variants: &[Variant], cattrs: &attr::Container) -> TokenStream {
let (unit_variants, complex_variants): (Vec<_>, Vec<_>) =
variants.into_iter().partition(is_unit_variant);
let unit_count = unit_variants.len();
@ -131,7 +131,7 @@ fn schema_for_external_tagged_enum(variants: &[Variant]) -> TokenStream {
schemas.extend(complex_variants.into_iter().map(|variant| {
let name = variant.attrs.name().deserialize_name();
let sub_schema = schema_for_untagged_enum_variant(variant);
let sub_schema = schema_for_untagged_enum_variant(variant, cattrs);
wrap_schema_fields(quote! {
instance_type: Some(schemars::schema::InstanceType::Object.into()),
properties: {
@ -139,6 +139,7 @@ fn schema_for_external_tagged_enum(variants: &[Variant]) -> TokenStream {
props.insert(#name.to_owned(), #sub_schema);
props
},
required: vec![#name.to_owned()],
})
}));
@ -147,15 +148,17 @@ fn schema_for_external_tagged_enum(variants: &[Variant]) -> TokenStream {
})
}
fn schema_for_untagged_enum(variants: &[Variant]) -> TokenStream {
let schemas = variants.into_iter().map(schema_for_untagged_enum_variant);
fn schema_for_untagged_enum(variants: &[Variant], cattrs: &attr::Container) -> TokenStream {
let schemas = variants
.into_iter()
.map(|v| schema_for_untagged_enum_variant(v, cattrs));
wrap_schema_fields(quote! {
any_of: Some(vec![#(#schemas),*]),
})
}
fn schema_for_untagged_enum_variant(variant: &Variant) -> TokenStream {
fn schema_for_untagged_enum_variant(variant: &Variant, cattrs: &attr::Container) -> TokenStream {
match variant.style {
Style::Unit => quote! {
gen.subschema_for::<()>()?
@ -173,19 +176,25 @@ fn schema_for_untagged_enum_variant(variant: &Variant) -> TokenStream {
gen.subschema_for::<(#(#types),*)>()?
}
}
Style::Struct => schema_for_struct(&variant.fields),
Style::Struct => schema_for_struct(&variant.fields, cattrs),
}
}
fn schema_for_struct(fields: &[Field]) -> TokenStream {
fn schema_for_struct(fields: &[Field], cattrs: &attr::Container) -> TokenStream {
let (nested, flat): (Vec<_>, Vec<_>) = fields.iter().partition(|f| !f.attrs.flatten());
let container_has_default = has_default(cattrs.default());
let mut required = Vec::new();
let recurse = nested.iter().map(|f| {
let name = f.attrs.name().deserialize_name();
if !container_has_default && !has_default(f.attrs.default()) {
required.push(name.clone());
}
let ty = f.ty;
quote_spanned! {f.original.span()=>
props.insert(#name.to_owned(), gen.subschema_for::<#ty>()?);
}
});
let schema = wrap_schema_fields(quote! {
instance_type: Some(schemars::schema::InstanceType::Object.into()),
properties: {
@ -193,6 +202,11 @@ fn schema_for_struct(fields: &[Field]) -> TokenStream {
#(#recurse)*
props
},
required: {
let mut required = std::collections::BTreeSet::new();
#(required.insert(#required.to_owned());)*
required
},
});
let flattens = flat.iter().map(|f| {
@ -206,3 +220,10 @@ fn schema_for_struct(fields: &[Field]) -> TokenStream {
#schema #(#flattens)*
}
}
fn has_default(d: &SerdeDefault) -> bool {
match d {
SerdeDefault::None => false,
_ => true,
}
}