From 3567e85a1df99e5f6dd2e763f8fa98d68502b707 Mon Sep 17 00:00:00 2001 From: Cormac Relf Date: Mon, 17 Nov 2025 13:20:07 +1100 Subject: [PATCH] Implement #[serde(with = "...")] for containers --- serde_derive/src/de.rs | 16 ++- serde_derive/src/internals/attr.rs | 40 ++++++ serde_derive/src/ser.rs | 16 ++- test_suite/tests/test_annotations.rs | 134 ++++++++++++++++++ .../tests/ui/with-container/incorrect_type.rs | 26 ++++ .../ui/with-container/incorrect_type.stderr | 111 +++++++++++++++ 6 files changed, 341 insertions(+), 2 deletions(-) create mode 100644 test_suite/tests/ui/with-container/incorrect_type.rs create mode 100644 test_suite/tests/ui/with-container/incorrect_type.stderr diff --git a/serde_derive/src/de.rs b/serde_derive/src/de.rs index 38408e9f1..0ef9b6cff 100644 --- a/serde_derive/src/de.rs +++ b/serde_derive/src/de.rs @@ -304,7 +304,9 @@ fn borrowed_lifetimes(cont: &Container) -> BorrowedLifetimes { } fn deserialize_body(cont: &Container, params: &Parameters) -> Fragment { - if cont.attrs.transparent() { + if let Some(path) = cont.attrs.deserialize_with() { + deserialize_with(path) + } else if cont.attrs.transparent() { deserialize_transparent(cont, params) } else if let Some(type_from) = cont.attrs.type_from() { deserialize_from(type_from) @@ -336,6 +338,7 @@ fn deserialize_in_place_body(cont: &Container, params: &Parameters) -> Option Fragment { } } +fn deserialize_with(deserialize_with: &syn::ExprPath) -> Fragment { + // Attach type errors to the path in #[serialize(deserialize_with = "path")] + let deserializer_arg = quote!(__deserializer); + let wrapper_deserialize = quote_spanned! {deserialize_with.span()=> + #deserialize_with(#deserializer_arg) + }; + quote_block! { + #wrapper_deserialize + } +} + enum TupleForm<'a> { Tuple, /// Contains a variant name diff --git a/serde_derive/src/internals/attr.rs b/serde_derive/src/internals/attr.rs index df0f33908..e1fed9830 100644 --- a/serde_derive/src/internals/attr.rs +++ b/serde_derive/src/internals/attr.rs @@ -172,6 +172,8 @@ pub struct Container { /// Error message generated when type can't be deserialized expecting: Option, non_exhaustive: bool, + serialize_with: Option, + deserialize_with: Option, } /// Styles of representing an enum. @@ -258,6 +260,8 @@ impl Container { let mut serde_path = Attr::none(cx, CRATE); let mut expecting = Attr::none(cx, EXPECTING); let mut non_exhaustive = false; + let mut serialize_with = Attr::none(cx, SERIALIZE_WITH); + let mut deserialize_with = Attr::none(cx, DESERIALIZE_WITH); for attr in &item.attrs { if attr.path() != SERDE { @@ -490,6 +494,32 @@ impl Container { if let Some(s) = get_lit_str(cx, EXPECTING, &meta)? { expecting.set(&meta.path, s.value()); } + } else if meta.path == WITH { + // #[serde(with = "...")] + if let Some(path) = parse_lit_into_expr_path(cx, WITH, &meta)? { + let mut ser_path = path.clone(); + ser_path + .path + .segments + .push(Ident::new("serialize", ser_path.span()).into()); + serialize_with.set(&meta.path, ser_path); + let mut de_path = path; + de_path + .path + .segments + .push(Ident::new("deserialize", de_path.span()).into()); + deserialize_with.set(&meta.path, de_path); + } + } else if meta.path == SERIALIZE_WITH { + // #[serde(serialize_with = "...")] + if let Some(path) = parse_lit_into_expr_path(cx, SERIALIZE_WITH, &meta)? { + serialize_with.set(&meta.path, path); + } + } else if meta.path == DESERIALIZE_WITH { + // #[serde(deserialize_with = "...")] + if let Some(path) = parse_lit_into_expr_path(cx, DESERIALIZE_WITH, &meta)? { + deserialize_with.set(&meta.path, path); + } } else { let path = meta.path.to_token_stream().to_string().replace(' ', ""); return Err( @@ -541,6 +571,8 @@ impl Container { is_packed, expecting: expecting.get(), non_exhaustive, + serialize_with: serialize_with.get(), + deserialize_with: deserialize_with.get(), } } @@ -617,6 +649,14 @@ impl Container { pub fn non_exhaustive(&self) -> bool { self.non_exhaustive } + + pub fn serialize_with(&self) -> Option<&syn::ExprPath> { + self.serialize_with.as_ref() + } + + pub fn deserialize_with(&self) -> Option<&syn::ExprPath> { + self.deserialize_with.as_ref() + } } fn decide_tag( diff --git a/serde_derive/src/ser.rs b/serde_derive/src/ser.rs index 6cec87cf5..9d4dcb4bb 100644 --- a/serde_derive/src/ser.rs +++ b/serde_derive/src/ser.rs @@ -168,7 +168,9 @@ fn needs_serialize_bound(field: &attr::Field, variant: Option<&attr::Variant>) - } fn serialize_body(cont: &Container, params: &Parameters) -> Fragment { - if cont.attrs.transparent() { + if let Some(path) = cont.attrs.serialize_with() { + serialize_with(params, path) + } else if cont.attrs.transparent() { serialize_transparent(cont, params) } else if let Some(type_into) = cont.attrs.type_into() { serialize_into(params, type_into) @@ -219,6 +221,18 @@ fn serialize_into(params: &Parameters, type_into: &syn::Type) -> Fragment { } } +fn serialize_with(params: &Parameters, path: &syn::ExprPath) -> Fragment { + let self_var = ¶ms.self_var; + // Attach type errors to the path in #[serialize(serialize_with = "path")] + let serializer_var = quote!(__serializer); + let wrapper_serialize = quote_spanned! {path.span()=> + #path(#self_var, #serializer_var) + }; + quote_expr! { + #wrapper_serialize + } +} + fn serialize_unit_struct(cattrs: &attr::Container) -> Fragment { let type_name = cattrs.name().serialize_name(); diff --git a/test_suite/tests/test_annotations.rs b/test_suite/tests/test_annotations.rs index 3f02a61fc..6aaabf944 100644 --- a/test_suite/tests/test_annotations.rs +++ b/test_suite/tests/test_annotations.rs @@ -42,6 +42,27 @@ trait DeserializeWith: Sized { D: Deserializer<'de>; } +mod with { + use super::{DeserializeWith, SerializeWith}; + use serde::{Deserializer, Serializer}; + + pub fn serialize(value: &T, ser: S) -> Result + where + S: Serializer, + T: SerializeWith, + { + T::serialize_with(value, ser) + } + + pub fn deserialize<'de, D, T>(de: D) -> Result + where + D: Deserializer<'de>, + T: DeserializeWith, + { + T::deserialize_with(de) + } +} + impl MyDefault for i32 { fn my_default() -> Self { 123 @@ -1158,6 +1179,119 @@ fn test_serialize_with_enum() { ); } +macro_rules! bool_container { + ($name:ident, $container:meta) => { + #[derive(Debug, PartialEq, Serialize, Deserialize)] + #[$container] + enum $name { + True, + False, + } + + impl SerializeWith for $name { + fn serialize_with(&self, ser: S) -> Result + where + S: Serializer, + { + let boolean = match self { + $name::True => true, + $name::False => false, + }; + boolean.serialize(ser) + } + } + + impl DeserializeWith for $name { + fn deserialize_with<'de, D>(de: D) -> Result + where + D: Deserializer<'de>, + { + let b = bool::deserialize(de)?; + match b { + true => Ok($name::True), + false => Ok($name::False), + } + } + } + }; +} + +macro_rules! int_container { + ($name:ident, $container:meta) => { + #[derive(Debug, PartialEq, Serialize, Deserialize)] + #[$container] + struct $name { + field: i32, + } + + impl SerializeWith for $name { + fn serialize_with(&self, ser: S) -> Result + where + S: Serializer, + { + self.field.to_string().serialize(ser) + } + } + + impl DeserializeWith for $name { + fn deserialize_with<'de, D>(de: D) -> Result + where + D: Deserializer<'de>, + { + let s = String::deserialize(de)?; + let field = s.parse().unwrap(); + Ok(Self { field }) + } + } + }; +} + +#[test] +fn test_container_with_enum() { + bool_container!( + BoolSerializeWith, + serde(serialize_with = "SerializeWith::serialize_with") + ); + assert_ser_tokens(&BoolSerializeWith::True, &[Token::Bool(true)]); + assert_ser_tokens(&BoolSerializeWith::False, &[Token::Bool(false)]); + + bool_container!( + BoolDeserializeWith, + serde(deserialize_with = "DeserializeWith::deserialize_with") + ); + assert_de_tokens(&BoolDeserializeWith::True, &[Token::Bool(true)]); + assert_de_tokens(&BoolDeserializeWith::False, &[Token::Bool(false)]); + + bool_container!(BoolWith, serde(with = "with")); + assert_ser_tokens(&BoolWith::True, &[Token::Bool(true)]); + assert_ser_tokens(&BoolWith::False, &[Token::Bool(false)]); + assert_de_tokens(&BoolWith::True, &[Token::Bool(true)]); + assert_de_tokens(&BoolWith::False, &[Token::Bool(false)]); +} + +#[test] +fn test_container_with_struct() { + int_container!( + IntS, + serde(serialize_with = "SerializeWith::serialize_with") + ); + assert_ser_tokens(&IntS { field: 42 }, &[Token::Str("42")]); + assert_ser_tokens(&IntS { field: 123 }, &[Token::Str("123")]); + + int_container!( + IntD, + serde(deserialize_with = "DeserializeWith::deserialize_with") + ); + assert_de_tokens(&IntD { field: 42 }, &[Token::Str("42")]); + assert_de_tokens(&IntD { field: 123 }, &[Token::Str("123")]); + + int_container!(IntW, serde(with = "with")); + assert_ser_tokens(&IntW { field: 42 }, &[Token::Str("42")]); + assert_ser_tokens(&IntW { field: 123 }, &[Token::Str("123")]); + assert_de_tokens(&IntW { field: 42 }, &[Token::Str("42")]); + assert_de_tokens(&IntW { field: 123 }, &[Token::Str("123")]); +} + #[derive(Debug, PartialEq, Serialize, Deserialize)] enum WithVariant { #[serde(serialize_with = "serialize_unit_variant_as_i8")] diff --git a/test_suite/tests/ui/with-container/incorrect_type.rs b/test_suite/tests/ui/with-container/incorrect_type.rs new file mode 100644 index 000000000..38aa74f36 --- /dev/null +++ b/test_suite/tests/ui/with-container/incorrect_type.rs @@ -0,0 +1,26 @@ +use serde_derive::{Deserialize, Serialize}; + +mod w { + use serde::{Deserializer, Serializer}; + + pub fn deserialize<'de, D: Deserializer<'de>>(_: D) -> Result<(), D::Error> { + unimplemented!() + } + pub fn serialize(_: S) -> Result { + unimplemented!() + } +} + +#[derive(Serialize, Deserialize)] +#[serde(with = "w")] +struct W(u8); + +#[derive(Serialize, Deserialize)] +#[serde(serialize_with = "w::serialize")] +struct S(u8); + +#[derive(Serialize, Deserialize)] +#[serde(deserialize_with = "w::deserialize")] +struct D(u8); + +fn main() {} diff --git a/test_suite/tests/ui/with-container/incorrect_type.stderr b/test_suite/tests/ui/with-container/incorrect_type.stderr new file mode 100644 index 000000000..41abbec56 --- /dev/null +++ b/test_suite/tests/ui/with-container/incorrect_type.stderr @@ -0,0 +1,111 @@ +error[E0277]: the trait bound `&W: serde::Serializer` is not satisfied + --> tests/ui/with-container/incorrect_type.rs:14:10 + | +14 | #[derive(Serialize, Deserialize)] + | ^^^^^^^^^ the trait `Serializer` is not implemented for `&W` +15 | #[serde(with = "w")] + | --- required by a bound introduced by this call + | +help: the trait `Serializer` is implemented for `&mut Formatter<'a>` + --> $WORKSPACE/serde_core/src/ser/fmt.rs + | + | impl<'a> Serializer for &mut fmt::Formatter<'a> { + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +note: required by a bound in `w::serialize` + --> tests/ui/with-container/incorrect_type.rs:9:28 + | + 9 | pub fn serialize(_: S) -> Result { + | ^^^^^^^^^^ required by this bound in `serialize` + +error[E0061]: this function takes 1 argument but 2 arguments were supplied + --> tests/ui/with-container/incorrect_type.rs:15:16 + | +14 | #[derive(Serialize, Deserialize)] + | --------- unexpected argument #2 of type `__S` +15 | #[serde(with = "w")] + | ^^^ + | +note: function defined here + --> tests/ui/with-container/incorrect_type.rs:9:12 + | + 9 | pub fn serialize(_: S) -> Result { + | ^^^^^^^^^ + +error[E0277]: the trait bound `&W: serde::Serializer` is not satisfied + --> tests/ui/with-container/incorrect_type.rs:15:16 + | +15 | #[serde(with = "w")] + | ^^^ the trait `Serializer` is not implemented for `&W` + | +help: the trait `Serializer` is implemented for `&mut Formatter<'a>` + --> $WORKSPACE/serde_core/src/ser/fmt.rs + | + | impl<'a> Serializer for &mut fmt::Formatter<'a> { + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +error[E0308]: mismatched types + --> tests/ui/with-container/incorrect_type.rs:15:16 + | +14 | #[derive(Serialize, Deserialize)] + | ----------- expected `Result>::Error>` because of return type +15 | #[serde(with = "w")] + | ^^^ expected `Result`, found `Result<(), ...>` + | + = note: expected enum `Result>::Error>` + found enum `Result<(), <__D as Deserializer<'_>>::Error>` + +error[E0277]: the trait bound `&S: serde::Serializer` is not satisfied + --> tests/ui/with-container/incorrect_type.rs:18:10 + | +18 | #[derive(Serialize, Deserialize)] + | ^^^^^^^^^ the trait `Serializer` is not implemented for `&S` +19 | #[serde(serialize_with = "w::serialize")] + | -------------- required by a bound introduced by this call + | +help: the trait `Serializer` is implemented for `&mut Formatter<'a>` + --> $WORKSPACE/serde_core/src/ser/fmt.rs + | + | impl<'a> Serializer for &mut fmt::Formatter<'a> { + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +note: required by a bound in `w::serialize` + --> tests/ui/with-container/incorrect_type.rs:9:28 + | + 9 | pub fn serialize(_: S) -> Result { + | ^^^^^^^^^^ required by this bound in `serialize` + +error[E0061]: this function takes 1 argument but 2 arguments were supplied + --> tests/ui/with-container/incorrect_type.rs:19:26 + | +18 | #[derive(Serialize, Deserialize)] + | --------- unexpected argument #2 of type `__S` +19 | #[serde(serialize_with = "w::serialize")] + | ^^^^^^^^^^^^^^ + | +note: function defined here + --> tests/ui/with-container/incorrect_type.rs:9:12 + | + 9 | pub fn serialize(_: S) -> Result { + | ^^^^^^^^^ + +error[E0277]: the trait bound `&S: serde::Serializer` is not satisfied + --> tests/ui/with-container/incorrect_type.rs:19:26 + | +19 | #[serde(serialize_with = "w::serialize")] + | ^^^^^^^^^^^^^^ the trait `Serializer` is not implemented for `&S` + | +help: the trait `Serializer` is implemented for `&mut Formatter<'a>` + --> $WORKSPACE/serde_core/src/ser/fmt.rs + | + | impl<'a> Serializer for &mut fmt::Formatter<'a> { + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +error[E0308]: mismatched types + --> tests/ui/with-container/incorrect_type.rs:23:28 + | +22 | #[derive(Serialize, Deserialize)] + | ----------- expected `Result>::Error>` because of return type +23 | #[serde(deserialize_with = "w::deserialize")] + | ^^^^^^^^^^^^^^^^ expected `Result`, found `Result<(), ...>` + | + = note: expected enum `Result>::Error>` + found enum `Result<(), <__D as Deserializer<'_>>::Error>`