diff --git a/serde-reflection/src/de.rs b/serde-reflection/src/de.rs index c08baa634..d7707f4b8 100644 --- a/serde-reflection/src/de.rs +++ b/serde-reflection/src/de.rs @@ -20,18 +20,15 @@ use std::collections::btree_map::{BTreeMap, Entry}; /// `&'a mut` references used to return tracing results. /// * The lifetime 'de is fixed and the `&'de` reference meant to let us /// borrow values from previous serialization runs. -pub(crate) struct Deserializer<'de, 'a> { +pub struct Deserializer<'de, 'a> { tracer: &'a mut Tracer, samples: &'de Samples, format: &'a mut Format, } impl<'de, 'a> Deserializer<'de, 'a> { - pub(crate) fn new( - tracer: &'a mut Tracer, - samples: &'de Samples, - format: &'a mut Format, - ) -> Self { + /// Create a new Deserializer + pub fn new(tracer: &'a mut Tracer, samples: &'de Samples, format: &'a mut Format) -> Self { Deserializer { tracer, samples, @@ -422,9 +419,11 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> { _ => unreachable!(), }; - // If the enum is already marked as incomplete, visit the first index, hoping - // to avoid recursion. - if self.tracer.incomplete_enums.contains_key(enum_name) { + // If the enum is already marked as incomplete and not pending, visit the first index, + // hoping to avoid recursion. + if let Some(EnumProgress::IndexedVariantsRemaining | EnumProgress::NamedVariantsRemaining) = + self.tracer.incomplete_enums.get(enum_name) + { return visitor.visit_enum(EnumDeserializer::new( self.tracer, self.samples, diff --git a/serde-reflection/src/lib.rs b/serde-reflection/src/lib.rs index b34302117..5689cec3b 100644 --- a/serde-reflection/src/lib.rs +++ b/serde-reflection/src/lib.rs @@ -368,7 +368,9 @@ mod value; #[cfg(feature = "json")] pub mod json_converter; +pub use de::Deserializer; pub use error::{Error, Result}; pub use format::{ContainerFormat, Format, FormatHolder, Named, Variable, VariantFormat}; -pub use trace::{Registry, Samples, Tracer, TracerConfig}; +pub use ser::Serializer; +pub use trace::{EnumProgress, Registry, Samples, Tracer, TracerConfig}; pub use value::Value; diff --git a/serde-reflection/src/ser.rs b/serde-reflection/src/ser.rs index f221878dc..7b15e28d4 100644 --- a/serde-reflection/src/ser.rs +++ b/serde-reflection/src/ser.rs @@ -12,13 +12,14 @@ use serde::{ser, Serialize}; /// Serialize a single value. /// The lifetime 'a is set by the serialization call site and the `&'a mut` /// references used to return tracing results and serialization samples. -pub(crate) struct Serializer<'a> { +pub struct Serializer<'a> { tracer: &'a mut Tracer, samples: &'a mut Samples, } impl<'a> Serializer<'a> { - pub(crate) fn new(tracer: &'a mut Tracer, samples: &'a mut Samples) -> Self { + /// Create a new Serializer + pub fn new(tracer: &'a mut Tracer, samples: &'a mut Samples) -> Self { Self { tracer, samples } } } diff --git a/serde-reflection/src/trace.rs b/serde-reflection/src/trace.rs index b3fcf3b46..32e22fa2a 100644 --- a/serde-reflection/src/trace.rs +++ b/serde-reflection/src/trace.rs @@ -36,12 +36,15 @@ pub struct Tracer { pub(crate) discriminants: BTreeMap<(TypeId, VariantId<'static>), Discriminant>, } +/// Type of untraced enum variants #[derive(Copy, Clone, Debug)] -pub(crate) enum EnumProgress { +pub enum EnumProgress { /// There are variant names that have not yet been traced. NamedVariantsRemaining, /// There are variant numbers that have not yet been traced. IndexedVariantsRemaining, + /// Tracing of further variants is pending. + Pending, } #[derive(Eq, PartialEq, Ord, PartialOrd, Debug)] @@ -243,6 +246,24 @@ impl Tracer { Ok((format, value)) } + /// Enable tracing of further variants of a incomplete enum. + /// + /// Marks an enum name as pending in the map of incomplete enums + /// and returns which type of variant tracing still needs to be performed. + /// + /// Call this in order to (simultaneously): + /// + /// * determine whether all variants of an enum have been traced, + /// * determine which type of variant tracing ([`EnumProgress`]) still needs to be + /// performed, and + /// * allow `Deserializer`/`trace_type_once` to make progress on a top level enum by + /// enabling tracing the next variant. + pub fn pend_enum(&mut self, name: &str) -> Option { + self.incomplete_enums + .get_mut(name) + .map(|p| std::mem::replace(p, EnumProgress::Pending)) + } + /// Same as `trace_type_once` but if `T` is an enum, we repeat the process /// until all variants of `T` are covered. /// We accumulate and return all the sampled values at the end. @@ -255,9 +276,12 @@ impl Tracer { let (format, value) = self.trace_type_once::(samples)?; values.push(value); if let Format::TypeName(name) = &format { - if let Some(&progress) = self.incomplete_enums.get(name) { + if let Some(progress) = self.pend_enum(name) { + assert!( + !matches!(progress, EnumProgress::Pending), + "failed to make progress tracing enum {name}" + ); // Restart the analysis to find more variants of T. - self.incomplete_enums.remove(name); if let EnumProgress::NamedVariantsRemaining = progress { values.pop().unwrap(); } @@ -294,9 +318,12 @@ impl Tracer { let (format, value) = self.trace_type_once_with_seed(samples, seed.clone())?; values.push(value); if let Format::TypeName(name) = &format { - if let Some(&progress) = self.incomplete_enums.get(name) { + if let Some(progress) = self.pend_enum(name) { + assert!( + !matches!(progress, EnumProgress::Pending), + "failed to make progress tracing enum {name}" + ); // Restart the analysis to find more variants of T. - self.incomplete_enums.remove(name); if let EnumProgress::NamedVariantsRemaining = progress { values.pop().unwrap(); } diff --git a/serde-reflection/tests/serde.rs b/serde-reflection/tests/serde.rs index 7a617d9dd..5a9f0fabd 100644 --- a/serde-reflection/tests/serde.rs +++ b/serde-reflection/tests/serde.rs @@ -3,8 +3,8 @@ use serde::{de::IntoDeserializer, Deserialize, Serialize}; use serde_reflection::{ - ContainerFormat, Error, Format, FormatHolder, Named, Samples, Tracer, TracerConfig, Value, - VariantFormat, + ContainerFormat, Deserializer, Error, Format, FormatHolder, Named, Samples, Tracer, + TracerConfig, Value, VariantFormat, }; use std::collections::BTreeMap; @@ -505,3 +505,109 @@ fn test_default_value_for_primitive_types() { assert_eq!(format, Format::Str); assert_eq!(value, "A borrowed str"); } + +#[test] +fn test_deserializer() { + #![allow(unused)] + + // This shows using `serde_reflection::Deserializer` directly + // and not through `Tracer::trace_*()`. + // An analogous use case exists for `serde_reflection::Serializer`. + + #[derive(Deserialize)] + enum E { + A, + B(f32), + } + + struct S { + e: E, + f: i8, + } + + impl S { + const FIELDS: &[&str] = &["e", "f"]; + + /// Probe a field given name and deserializer + fn probe<'de, D>(key: &str, value: D) -> Result<(), D::Error> + where + D: serde::Deserializer<'de>, + { + match key { + "e" => { + E::deserialize(value)?; + } + "f" => { + i8::deserialize(value)?; + } + _ => unimplemented!(), + } + Ok(()) + } + } + + // Now build a schema for S + + let mut tracer = Tracer::new(TracerConfig::default()); + let samples = Samples::new(); + + let formats: Vec<_> = S::FIELDS + .iter() + .map(|field| { + loop { + let mut format = Format::unknown(); + let deserializer = Deserializer::new(&mut tracer, &samples, &mut format); + S::probe(field, deserializer).unwrap(); + format.reduce(); + if let Format::TypeName(name) = &format { + if let Some(progress) = tracer.pend_enum(name) { + // If an attempt is made at retreiving the registry at this point + // the incomplete enum is still in incomplete_enums and we get an Err(). + // If we had removed it from incomplete_enums, a registry retrieval at this point + // would (wrongly) succeed. + // assert!(tracer.registry().is_err()); + assert!( + !matches!(progress, serde_reflection::EnumProgress::Pending), + "failed to make progress tracing enum {name}" + ); + // Restart the analysis to find more variants. + continue; + } + } + break (*field, format); + } + }) + .collect(); + + assert_eq!( + formats, + vec![("e", Format::TypeName("E".into())), ("f", Format::I8)] + ); + + assert_eq!( + tracer.registry().unwrap(), + [( + "E".into(), + ContainerFormat::Enum( + [ + ( + 0, + Named { + name: "A".into(), + value: VariantFormat::Unit + } + ), + ( + 1, + Named { + name: "B".into(), + value: VariantFormat::NewType(Format::F32.into()) + } + ) + ] + .into() + ) + )] + .into() + ); +}