diff --git a/schemars/src/json_schema_impls/core.rs b/schemars/src/json_schema_impls/core.rs index 5cefbdd..647030f 100644 --- a/schemars/src/json_schema_impls/core.rs +++ b/schemars/src/json_schema_impls/core.rs @@ -62,13 +62,17 @@ impl JsonSchema for Option { parent: &mut SchemaObject, name: String, metadata: Option, - _required: bool, + required: Option, ) { - let mut schema = gen.subschema_for::(); - schema = gen.apply_metadata(schema, metadata); + if required == Some(true) { + T::add_schema_as_property(gen, parent, name, metadata, required) + } else { + let mut schema = gen.subschema_for::(); + schema = gen.apply_metadata(schema, metadata); - let object = parent.object(); - object.properties.insert(name, schema); + let object = parent.object(); + object.properties.insert(name, schema); + } } } diff --git a/schemars/src/json_schema_impls/mod.rs b/schemars/src/json_schema_impls/mod.rs index d493ea0..4d96ee2 100644 --- a/schemars/src/json_schema_impls/mod.rs +++ b/schemars/src/json_schema_impls/mod.rs @@ -30,7 +30,7 @@ macro_rules! forward_impl { parent: &mut crate::schema::SchemaObject, name: String, metadata: Option, - required: bool, + required: Option, ) { <$target>::add_schema_as_property(gen, parent, name, metadata, required) } diff --git a/schemars/src/lib.rs b/schemars/src/lib.rs index eb6d5ac..9a9e65c 100644 --- a/schemars/src/lib.rs +++ b/schemars/src/lib.rs @@ -375,13 +375,13 @@ pub trait JsonSchema { parent: &mut SchemaObject, name: String, metadata: Option, - required: bool, + required: Option, ) { let mut schema = gen.subschema_for::(); schema = gen.apply_metadata(schema, metadata); let object = parent.object(); - if required { + if required.unwrap_or(true) { object.required.insert(name.clone()); } object.properties.insert(name, schema); diff --git a/schemars/src/schema.rs b/schemars/src/schema.rs index a4c6e32..e9a0124 100644 --- a/schemars/src/schema.rs +++ b/schemars/src/schema.rs @@ -9,6 +9,7 @@ use crate::JsonSchema; use crate::{Map, Set}; use serde::{Deserialize, Serialize}; use serde_json::Value; +use std::ops::Deref; /// A JSON Schema. #[allow(clippy::large_enum_variant)] @@ -191,7 +192,13 @@ where macro_rules! get_or_insert_default_fn { ($name:ident, $ret:ty) => { get_or_insert_default_fn!( - concat!("Returns a mutable reference to this schema's [`", stringify!($ret), "`](#structfield.", stringify!($name), "), creating it if it was `None`."), + concat!( + "Returns a mutable reference to this schema's [`", + stringify!($ret), + "`](#structfield.", + stringify!($name), + "), creating it if it was `None`." + ), $name, $ret ); @@ -224,6 +231,13 @@ impl SchemaObject { self.reference.is_some() } + // TODO document + pub fn has_type(&self, ty: InstanceType) -> bool { + self.instance_type + .as_ref() + .map_or(true, |x| x.contains(&ty)) + } + get_or_insert_default_fn!(metadata, Metadata); get_or_insert_default_fn!(subschemas, SubschemaValidation); get_or_insert_default_fn!(number, NumberValidation); @@ -506,3 +520,13 @@ impl From> for SingleOrVec { SingleOrVec::Vec(vec) } } + +impl SingleOrVec { + // TODO document + pub fn contains(&self, x: &T) -> bool { + match self { + SingleOrVec::Single(s) => s.deref() == x, + SingleOrVec::Vec(v) => v.contains(x), + } + } +} diff --git a/schemars/tests/expected/validate.json b/schemars/tests/expected/validate.json new file mode 100644 index 0000000..228f249 --- /dev/null +++ b/schemars/tests/expected/validate.json @@ -0,0 +1,81 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "Struct", + "type": "object", + "required": [ + "contains_str1", + "contains_str2", + "email_address", + "homepage", + "map_contains", + "min_max", + "non_empty_str", + "pair", + "regex_str1", + "regex_str2", + "required_option", + "tel" + ], + "properties": { + "min_max": { + "type": "number", + "format": "float", + "maximum": 100.0, + "minimum": 0.01 + }, + "regex_str1": { + "type": "string", + "pattern": "^[Hh]ello\\b" + }, + "regex_str2": { + "type": "string", + "pattern": "^[Hh]ello\\b" + }, + "contains_str1": { + "type": "string", + "pattern": "substring\\.\\.\\." + }, + "contains_str2": { + "type": "string", + "pattern": "substring\\.\\.\\." + }, + "email_address": { + "type": "string", + "format": "email" + }, + "tel": { + "type": "string", + "format": "phone" + }, + "homepage": { + "type": "string", + "format": "uri" + }, + "non_empty_str": { + "type": "string", + "maxLength": 100, + "minLength": 1 + }, + "pair": { + "type": "array", + "items": { + "type": "integer", + "format": "int32" + }, + "maxItems": 2, + "minItems": 2 + }, + "map_contains": { + "type": "object", + "required": [ + "map_key" + ], + "additionalProperties": { + "type": "null" + } + }, + "required_option": { + "type": "boolean" + } + } +} \ No newline at end of file diff --git a/schemars/tests/validate.rs b/schemars/tests/validate.rs new file mode 100644 index 0000000..6d67531 --- /dev/null +++ b/schemars/tests/validate.rs @@ -0,0 +1,40 @@ +mod util; +use schemars::JsonSchema; +use std::collections::HashMap; +use util::*; + +// In real code, this would typically be a Regex, potentially created in a `lazy_static!`. +static STARTS_WITH_HELLO: &'static str = r"^[Hh]ello\b"; + +#[derive(Debug, JsonSchema)] +pub struct Struct { + #[validate(range(min = 0.01, max = 100))] + min_max: f32, + #[validate(regex = "STARTS_WITH_HELLO")] + regex_str1: String, + #[validate(regex(path = "STARTS_WITH_HELLO", code = "foo"))] + regex_str2: String, + #[validate(contains = "substring...")] + contains_str1: String, + #[validate(contains(pattern = "substring...", message = "bar"))] + contains_str2: String, + #[validate(email)] + email_address: String, + #[validate(phone)] + tel: String, + #[validate(url)] + homepage: String, + #[validate(length(min = 1, max = 100))] + non_empty_str: String, + #[validate(length(equal = 2))] + pair: Vec, + #[validate(contains = "map_key")] + map_contains: HashMap, + #[validate(required)] + required_option: Option, +} + +#[test] +fn validate() -> TestResult { + test_default_generated_schema::("validate") +} diff --git a/schemars_derive/src/ast/from_serde.rs b/schemars_derive/src/ast/from_serde.rs index 0d9add3..db2e092 100644 --- a/schemars_derive/src/ast/from_serde.rs +++ b/schemars_derive/src/ast/from_serde.rs @@ -73,6 +73,7 @@ impl<'a> FromSerde for Field<'a> { ty: serde.ty, original: serde.original, attrs: Attrs::new(&serde.original.attrs, errors), + validation_attrs: ValidationAttrs::new(&serde.original.attrs), }) } } diff --git a/schemars_derive/src/ast/mod.rs b/schemars_derive/src/ast/mod.rs index a394acd..99fe188 100644 --- a/schemars_derive/src/ast/mod.rs +++ b/schemars_derive/src/ast/mod.rs @@ -1,6 +1,6 @@ mod from_serde; -use crate::attr::Attrs; +use crate::attr::{Attrs, ValidationAttrs}; use from_serde::FromSerde; use serde_derive_internals::ast as serde_ast; use serde_derive_internals::{Ctxt, Derive}; @@ -34,6 +34,7 @@ pub struct Field<'a> { pub ty: &'a syn::Type, pub original: &'a syn::Field, pub attrs: Attrs, + pub validation_attrs: ValidationAttrs, } impl<'a> Container<'a> { diff --git a/schemars_derive/src/attr/mod.rs b/schemars_derive/src/attr/mod.rs index 65667d9..d7a6628 100644 --- a/schemars_derive/src/attr/mod.rs +++ b/schemars_derive/src/attr/mod.rs @@ -1,7 +1,9 @@ mod doc; mod schemars_to_serde; +mod validation; pub use schemars_to_serde::process_serde_attrs; +pub use validation::ValidationAttrs; use proc_macro2::{Group, Span, TokenStream, TokenTree}; use quote::ToTokens; diff --git a/schemars_derive/src/attr/validation.rs b/schemars_derive/src/attr/validation.rs new file mode 100644 index 0000000..8f13168 --- /dev/null +++ b/schemars_derive/src/attr/validation.rs @@ -0,0 +1,308 @@ +use super::parse_lit_str; +use proc_macro2::TokenStream; +use syn::ExprLit; +use syn::NestedMeta; +use syn::{Expr, Lit, Meta, MetaNameValue, Path}; + +#[derive(Debug, Default)] +pub struct ValidationAttrs { + pub length_min: Option, + pub length_max: Option, + pub length_equal: Option, + pub range_min: Option, + pub range_max: Option, + pub regex: Option, + pub contains: Option, + pub required: bool, + pub format: Option<&'static str>, +} + +impl ValidationAttrs { + pub fn new(attrs: &[syn::Attribute]) -> Self { + // TODO allow setting "validate" attributes through #[schemars(...)] + ValidationAttrs::default().populate(attrs) + } + + fn populate(mut self, attrs: &[syn::Attribute]) -> Self { + // TODO don't silently ignore unparseable attributes + for meta_item in attrs + .iter() + .flat_map(|attr| get_meta_items(attr, "validate")) + .flatten() + { + match &meta_item { + NestedMeta::Meta(Meta::List(meta_list)) if meta_list.path.is_ident("length") => { + for nested in meta_list.nested.iter() { + match nested { + NestedMeta::Meta(Meta::NameValue(nv)) if nv.path.is_ident("min") => { + self.length_min = str_or_num_to_expr(&nv.lit); + } + NestedMeta::Meta(Meta::NameValue(nv)) if nv.path.is_ident("max") => { + self.length_max = str_or_num_to_expr(&nv.lit); + } + NestedMeta::Meta(Meta::NameValue(nv)) if nv.path.is_ident("equal") => { + self.length_equal = str_or_num_to_expr(&nv.lit); + } + _ => {} + } + } + } + + NestedMeta::Meta(Meta::List(meta_list)) if meta_list.path.is_ident("range") => { + for nested in meta_list.nested.iter() { + match nested { + NestedMeta::Meta(Meta::NameValue(nv)) if nv.path.is_ident("min") => { + self.range_min = str_or_num_to_expr(&nv.lit); + } + NestedMeta::Meta(Meta::NameValue(nv)) if nv.path.is_ident("max") => { + self.range_max = str_or_num_to_expr(&nv.lit); + } + _ => {} + } + } + } + + NestedMeta::Meta(m) + if m.path().is_ident("required") || m.path().is_ident("required_nested") => + { + self.required = true; + } + + NestedMeta::Meta(m) if m.path().is_ident("email") => { + self.format = Some("email"); + } + + NestedMeta::Meta(m) if m.path().is_ident("url") => { + self.format = Some("uri"); + } + + NestedMeta::Meta(m) if m.path().is_ident("phone") => { + self.format = Some("phone"); + } + + NestedMeta::Meta(Meta::NameValue(MetaNameValue { + path, + lit: Lit::Str(regex), + .. + })) if path.is_ident("regex") => self.regex = parse_lit_str(regex).ok(), + + NestedMeta::Meta(Meta::List(meta_list)) if meta_list.path.is_ident("regex") => { + self.regex = meta_list.nested.iter().find_map(|x| match x { + NestedMeta::Meta(Meta::NameValue(MetaNameValue { + path, + lit: Lit::Str(regex), + .. + })) if path.is_ident("path") => parse_lit_str(regex).ok(), + _ => None, + }); + } + + NestedMeta::Meta(Meta::NameValue(MetaNameValue { + path, + lit: Lit::Str(contains), + .. + })) if path.is_ident("contains") => self.contains = Some(contains.value()), + + NestedMeta::Meta(Meta::List(meta_list)) if meta_list.path.is_ident("contains") => { + self.contains = meta_list.nested.iter().find_map(|x| match x { + NestedMeta::Meta(Meta::NameValue(MetaNameValue { + path, + lit: Lit::Str(contains), + .. + })) if path.is_ident("pattern") => Some(contains.value()), + _ => None, + }); + } + + _ => {} + } + } + self + } + + pub fn validation_statements(&self, field_name: &str) -> TokenStream { + // Assume that the result will be interpolated in a context with the local variable + // `schema_object` - the SchemaObject for the struct that contains this field. + let mut statements = Vec::new(); + + if self.required { + statements.push(quote! { + schema_object.object().required.insert(#field_name.to_owned()); + }); + } + + let mut array_validation = Vec::new(); + let mut number_validation = Vec::new(); + let mut object_validation = Vec::new(); + let mut string_validation = Vec::new(); + + if let Some(length_min) = self + .length_min + .as_ref() + .or_else(|| self.length_equal.as_ref()) + { + string_validation.push(quote! { + validation.min_length = Some(#length_min as u32); + }); + array_validation.push(quote! { + validation.min_items = Some(#length_min as u32); + }); + } + + if let Some(length_max) = self + .length_max + .as_ref() + .or_else(|| self.length_equal.as_ref()) + { + string_validation.push(quote! { + validation.max_length = Some(#length_max as u32); + }); + array_validation.push(quote! { + validation.max_items = Some(#length_max as u32); + }); + } + + if let Some(range_min) = &self.range_min { + number_validation.push(quote! { + validation.minimum = Some(#range_min as f64); + }); + } + + if let Some(range_max) = &self.range_max { + number_validation.push(quote! { + validation.maximum = Some(#range_max as f64); + }); + } + + if let Some(regex) = &self.regex { + string_validation.push(quote! { + validation.pattern = Some(#regex.to_string()); + }); + } + + if let Some(contains) = &self.contains { + object_validation.push(quote! { + validation.required.insert(#contains.to_string()); + }); + + if self.regex.is_none() { + let pattern = crate::regex_syntax::escape(contains); + string_validation.push(quote! { + validation.pattern = Some(#pattern.to_string()); + }); + } + } + + let format = self.format.as_ref().map(|f| { + quote! { + prop_schema_object.format = Some(#f.to_string()); + } + }); + + let array_validation = wrap_array_validation(array_validation); + let number_validation = wrap_number_validation(number_validation); + let object_validation = wrap_object_validation(object_validation); + let string_validation = wrap_string_validation(string_validation); + + if array_validation.is_some() + || number_validation.is_some() + || object_validation.is_some() + || string_validation.is_some() + || format.is_some() + { + statements.push(quote! { + if let Some(schemars::schema::Schema::Object(prop_schema_object)) = schema_object + .object + .as_mut() + .and_then(|o| o.properties.get_mut(#field_name)) + { + #array_validation + #number_validation + #object_validation + #string_validation + #format + } + }); + } + + statements.into_iter().collect() + } +} + +fn wrap_array_validation(v: Vec) -> Option { + if v.is_empty() { + None + } else { + Some(quote! { + if prop_schema_object.has_type(schemars::schema::InstanceType::Array) { + let validation = prop_schema_object.array(); + #(#v)* + } + }) + } +} + +fn wrap_number_validation(v: Vec) -> Option { + if v.is_empty() { + None + } else { + Some(quote! { + if prop_schema_object.has_type(schemars::schema::InstanceType::Integer) + || prop_schema_object.has_type(schemars::schema::InstanceType::Number) { + let validation = prop_schema_object.number(); + #(#v)* + } + }) + } +} + +fn wrap_object_validation(v: Vec) -> Option { + if v.is_empty() { + None + } else { + Some(quote! { + if prop_schema_object.has_type(schemars::schema::InstanceType::Object) { + let validation = prop_schema_object.object(); + #(#v)* + } + }) + } +} + +fn wrap_string_validation(v: Vec) -> Option { + if v.is_empty() { + None + } else { + Some(quote! { + if prop_schema_object.has_type(schemars::schema::InstanceType::String) { + let validation = prop_schema_object.string(); + #(#v)* + } + }) + } +} + +fn get_meta_items( + attr: &syn::Attribute, + attr_type: &'static str, +) -> Result, ()> { + if !attr.path.is_ident(attr_type) { + return Ok(Vec::new()); + } + + match attr.parse_meta() { + Ok(Meta::List(meta)) => Ok(meta.nested.into_iter().collect()), + _ => Err(()), + } +} + +fn str_or_num_to_expr(lit: &Lit) -> Option { + match lit { + Lit::Str(s) => parse_lit_str::(s).ok().map(Expr::Path), + Lit::Int(_) | Lit::Float(_) => Some(Expr::Lit(ExprLit { + attrs: Vec::new(), + lit: lit.clone(), + })), + _ => None, + } +} diff --git a/schemars_derive/src/lib.rs b/schemars_derive/src/lib.rs index c81eb40..c737380 100644 --- a/schemars_derive/src/lib.rs +++ b/schemars_derive/src/lib.rs @@ -9,12 +9,13 @@ extern crate proc_macro; mod ast; mod attr; mod metadata; +mod regex_syntax; mod schema_exprs; use ast::*; use proc_macro2::TokenStream; -#[proc_macro_derive(JsonSchema, attributes(schemars, serde))] +#[proc_macro_derive(JsonSchema, attributes(schemars, serde, validate))] pub fn derive_json_schema_wrapper(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = parse_macro_input!(input as syn::DeriveInput); derive_json_schema(input, false) @@ -72,7 +73,7 @@ fn derive_json_schema( parent: &mut schemars::schema::SchemaObject, name: String, metadata: Option, - required: bool, + required: Option, ) { <#ty as schemars::JsonSchema>::add_schema_as_property(gen, parent, name, metadata, required) } diff --git a/schemars_derive/src/regex_syntax.rs b/schemars_derive/src/regex_syntax.rs new file mode 100644 index 0000000..353bf8d --- /dev/null +++ b/schemars_derive/src/regex_syntax.rs @@ -0,0 +1,26 @@ +// Copied from regex_syntax crate to avoid pulling in the whole crate just for a utility function +// https://github.com/rust-lang/regex/blob/ff283badce21dcebd581909d38b81f2c8c9bfb54/regex-syntax/src/lib.rs + +pub fn escape(text: &str) -> String { + let mut quoted = String::new(); + escape_into(text, &mut quoted); + quoted +} + +fn escape_into(text: &str, buf: &mut String) { + buf.reserve(text.len()); + for c in text.chars() { + if is_meta_character(c) { + buf.push('\\'); + } + buf.push(c); + } +} + +fn is_meta_character(c: char) -> bool { + match c { + '\\' | '.' | '+' | '*' | '?' | '(' | ')' | '|' | '[' | ']' | '{' | '}' | '^' | '$' + | '#' | '&' | '-' | '~' => true, + _ => false, + } +} diff --git a/schemars_derive/src/schema_exprs.rs b/schemars_derive/src/schema_exprs.rs index a5cf2ca..4965fd5 100644 --- a/schemars_derive/src/schema_exprs.rs +++ b/schemars_derive/src/schema_exprs.rs @@ -390,32 +390,43 @@ fn expr_for_struct( let mut type_defs = Vec::new(); - let properties: Vec<_> = property_fields.into_iter().map(|field| { - let name = field.name(); - let default = field_default_expr(field, set_container_default.is_some()); + let properties: Vec<_> = property_fields + .into_iter() + .map(|field| { + let name = field.name(); + let default = field_default_expr(field, set_container_default.is_some()); - let required = match default { - Some(_) => quote!(false), - None => quote!(true), - }; + let required = match (&default, field.validation_attrs.required) { + (Some(_), _) => quote!(Some(false)), + (None, false) => quote!(None), + (None, true) => quote!(Some(true)), + }; - let metadata = &SchemaMetadata { - read_only: field.serde_attrs.skip_deserializing(), - write_only: field.serde_attrs.skip_serializing(), - default, - ..SchemaMetadata::from_attrs(&field.attrs) - }; + let metadata = &SchemaMetadata { + read_only: field.serde_attrs.skip_deserializing(), + write_only: field.serde_attrs.skip_serializing(), + default, + ..SchemaMetadata::from_attrs(&field.attrs) + }; - let (ty, type_def) = type_for_schema(field, type_defs.len()); - if let Some(type_def) = type_def { - type_defs.push(type_def); - } + let (ty, type_def) = type_for_schema(field, type_defs.len()); + if let Some(type_def) = type_def { + type_defs.push(type_def); + } - quote_spanned! {ty.span()=> - <#ty as schemars::JsonSchema>::add_schema_as_property(gen, &mut schema_object, #name.to_owned(), #metadata, #required); - } + let validation = field.validation_attrs.validation_statements(&name); - }).collect(); + quote_spanned! {ty.span()=> + <#ty as schemars::JsonSchema>::add_schema_as_property( + gen, + &mut schema_object, + #name.to_owned(), + #metadata, + #required); + #validation + } + }) + .collect(); let flattens: Vec<_> = flattened_fields .into_iter()