diff --git a/schemars/src/_private.rs b/schemars/src/_private.rs index d4ac4a7..3de9614 100644 --- a/schemars/src/_private.rs +++ b/schemars/src/_private.rs @@ -2,6 +2,7 @@ use crate::gen::SchemaGenerator; use crate::JsonSchema; use crate::Schema; use serde::Serialize; +use serde_json::json; use serde_json::Map; use serde_json::Value; @@ -73,6 +74,45 @@ pub fn new_externally_tagged_enum(variant: &str, sub_schema: Schema) -> Schema { }) } +pub fn apply_internal_enum_tag( + schema: &mut Schema, + tag_name: &str, + variant: &str, + deny_unknown_fields: bool, +) { + let obj = schema.ensure_object(); + let is_unit = obj.get("type").is_some_and(|t| t.as_str() == Some("null")); + + obj.insert("type".to_owned(), "object".into()); + + if let Some(properties) = obj + .entry("properties") + .or_insert(Value::Object(Map::new())) + .as_object_mut() + { + properties.insert( + tag_name.to_string(), + json!({ + "type": "string", + // TODO switch from single-valued "enum" to "const" + "enum": [variant] + }), + ); + } + + if let Some(required) = obj + .entry("required") + .or_insert(Value::Array(Vec::new())) + .as_array_mut() + { + required.insert(0, tag_name.into()); + } + + if deny_unknown_fields && is_unit { + obj.entry("additionalProperties").or_insert(false.into()); + } +} + /// Create a schema for an internally tagged enum pub fn new_internally_tagged_enum( tag_name: &str, diff --git a/schemars/src/flatten.rs b/schemars/src/flatten.rs index fc66b74..69b0361 100644 --- a/schemars/src/flatten.rs +++ b/schemars/src/flatten.rs @@ -9,33 +9,9 @@ impl Schema { /// It should not be considered part of the public API. #[doc(hidden)] pub fn flatten(mut self, other: Self) -> Schema { - // This special null-type-schema handling is here for backward-compatibility, but needs reviewing. - // I think it's only needed to make internally-tagged enum unit variants behave correctly, but that - // should be handled entirely within schemars_derive. - if other - .as_object() - .and_then(|o| o.get("type")) - .and_then(|t| t.as_str()) - == Some("null") - { - return self; - } - - if let Value::Object(mut obj2) = other.to_value() { + if let Value::Object(obj2) = other.to_value() { let obj1 = self.ensure_object(); - let ap2 = obj2.remove("additionalProperties"); - if let Entry::Occupied(mut ap1) = obj1.entry("additionalProperties") { - match ap2 { - Some(ap2) => { - flatten_additional_properties(ap1.get_mut(), ap2); - } - None => { - ap1.remove(); - } - } - } - for (key, value2) in obj2 { match obj1.entry(key) { Entry::Vacant(vacant) => { @@ -93,19 +69,3 @@ impl Schema { self } } - -// TODO validate behaviour when flattening a normal struct into a struct with deny_unknown_fields -fn flatten_additional_properties(v1: &mut Value, v2: Value) { - match (v1, v2) { - (v1, Value::Bool(true)) => { - *v1 = Value::Bool(true); - } - (v1 @ Value::Bool(false), v2) => { - *v1 = v2; - } - (Value::Object(o1), Value::Object(o2)) => { - o1.extend(o2); - } - _ => {} - } -} diff --git a/schemars/tests/expected/enum-internal-duf.json b/schemars/tests/expected/enum-internal-duf.json index 501f13b..7a9bcd7 100644 --- a/schemars/tests/expected/enum-internal-duf.json +++ b/schemars/tests/expected/enum-internal-duf.json @@ -112,10 +112,7 @@ "additionalProperties": false }, { - "type": [ - "object", - "integer" - ], + "type": "object", "format": "int32", "required": [ "typeProperty" diff --git a/schemars/tests/expected/enum-internal.json b/schemars/tests/expected/enum-internal.json index 37739b0..115dbf3 100644 --- a/schemars/tests/expected/enum-internal.json +++ b/schemars/tests/expected/enum-internal.json @@ -28,6 +28,9 @@ "StringMap" ] } + }, + "additionalProperties": { + "type": "string" } }, { @@ -105,10 +108,7 @@ } }, { - "type": [ - "object", - "integer" - ], + "type": "object", "format": "int32", "required": [ "typeProperty" diff --git a/schemars/tests/expected/schema_with-enum-internal.json b/schemars/tests/expected/schema_with-enum-internal.json index 75b28dc..c4a0cc1 100644 --- a/schemars/tests/expected/schema_with-enum-internal.json +++ b/schemars/tests/expected/schema_with-enum-internal.json @@ -21,10 +21,7 @@ } }, { - "type": [ - "object", - "boolean" - ], + "type": "object", "required": [ "typeProperty" ], @@ -38,10 +35,7 @@ } }, { - "type": [ - "object", - "boolean" - ], + "type": "object", "required": [ "typeProperty" ], diff --git a/schemars_derive/src/schema_exprs.rs b/schemars_derive/src/schema_exprs.rs index 585184c..f22831e 100644 --- a/schemars_derive/src/schema_exprs.rs +++ b/schemars_derive/src/schema_exprs.rs @@ -231,19 +231,14 @@ fn expr_for_internal_tagged_enum<'a>( let name = variant.name(); - let mut tag_schema = quote! { - schemars::_private::new_internally_tagged_enum(#tag_name, #name, #deny_unknown_fields) - }; + let mut schema_expr = expr_for_internal_tagged_enum_variant(variant, deny_unknown_fields); + variant.attrs.as_metadata().apply_to_schema(&mut schema_expr); - variant.attrs.as_metadata().apply_to_schema(&mut tag_schema); - - if let Some(variant_schema) = - expr_for_untagged_enum_variant_for_flatten(variant, deny_unknown_fields) - { - tag_schema.extend(quote!(.flatten(#variant_schema))) - } - - tag_schema + quote!({ + let mut schema = #schema_expr; + schemars::_private::apply_internal_enum_tag(&mut schema, #tag_name, #name, #deny_unknown_fields); + schema + }) }) .collect(); @@ -383,10 +378,10 @@ fn expr_for_untagged_enum_variant(variant: &Variant, deny_unknown_fields: bool) } } -fn expr_for_untagged_enum_variant_for_flatten( +fn expr_for_internal_tagged_enum_variant( variant: &Variant, deny_unknown_fields: bool, -) -> Option { +) -> TokenStream { if let Some(with_attr) = &variant.attrs.with { let (ty, type_def) = type_for_schema(with_attr); let gen = quote!(gen); @@ -395,15 +390,15 @@ fn expr_for_untagged_enum_variant_for_flatten( }; prepend_type_def(type_def, &mut schema_expr); - return Some(schema_expr); + return schema_expr; } - Some(match variant.style { - Style::Unit => return None, + match variant.style { + Style::Unit => expr_for_unit_struct(), Style::Newtype => expr_for_field(&variant.fields[0], false), Style::Tuple => expr_for_tuple_struct(&variant.fields), Style::Struct => expr_for_struct(&variant.fields, &SerdeDefault::None, deny_unknown_fields), - }) + } } fn expr_for_unit_struct() -> TokenStream {