Skip to content
4 changes: 1 addition & 3 deletions src/serializers/fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,7 @@ impl SerField {
fn exclude_default(value: &Bound<'_, PyAny>, extra: &Extra, serializer: &CombinedSerializer) -> PyResult<bool> {
if extra.exclude_defaults {
if let Some(default) = serializer.get_default(value.py())? {
if value.eq(default)? {
return Ok(true);
}
return Ok(value.eq(default).unwrap_or(false));
}
}
Ok(false)
Expand Down
43 changes: 43 additions & 0 deletions tests/serializers/test_typed_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,49 @@ def test_exclude_default():
assert v.to_json({'foo': 1, 'bar': b'[default]'}, exclude_defaults=True) == b'{"foo":1}'


def test_exclude_incomparable_default():
"""Values that can't be compared with eq are treated as not equal to the default"""

def ser_x(*args):
return [1, 2, 3]

cls_schema = core_schema.any_schema(serialization=core_schema.plain_serializer_function_ser_schema(ser_x))

class Incomparable:
__pydantic_serializer__ = SchemaSerializer(cls_schema)

def __get_pydantic_core_schema__(*args):
return cls_schema

def __eq__(self, other):
raise NotImplementedError("Can't be compared!")

class NeqComparable(Incomparable):
def __eq__(self, other):
return False

class EqComparable(Incomparable):
def __eq__(self, other):
return True

v = SchemaSerializer(
core_schema.typed_dict_schema(
{
'foo': core_schema.typed_dict_field(
core_schema.with_default_schema(core_schema.any_schema(), default=None)
),
}
)
)

assert v.to_python({'foo': Incomparable()}, exclude_defaults=True)['foo'] == [1, 2, 3]
assert v.to_json({'foo': Incomparable()}, exclude_defaults=True) == b'{"foo":[1,2,3]}'
assert v.to_python({'foo': NeqComparable()}, exclude_defaults=True)['foo'] == [1, 2, 3]
assert v.to_json({'foo': NeqComparable()}, exclude_defaults=True) == b'{"foo":[1,2,3]}'
assert v.to_python({'foo': EqComparable()}, exclude_defaults=True) == {}
assert v.to_json({'foo': EqComparable()}, exclude_defaults=True) == b'{}'


def test_function_plain_field_serializer_to_python():
class Model(TypedDict):
x: int
Expand Down