From 3b96ace6c653fb5c34fd96bddb5023fe0acddf8d Mon Sep 17 00:00:00 2001 From: Alberto Islas Date: Wed, 9 Jul 2025 15:27:55 -0600 Subject: [PATCH 1/4] feat(drf-dynamic-fields): Added support for filtering nested fields. --- .gitignore | 1 + drf_dynamic_fields/__init__.py | 189 ++++++++++++++++++++++++++++++-- tests/models.py | 9 ++ tests/serializers.py | 15 ++- tests/test_mixins.py | 193 ++++++++++++++++++++++++++++++++- 5 files changed, 396 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index 39ca09e..203fcd0 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ dist/ *.egg-info/ build/ .tox/ +.idea diff --git a/drf_dynamic_fields/__init__.py b/drf_dynamic_fields/__init__.py index 97694e4..e2a819e 100644 --- a/drf_dynamic_fields/__init__.py +++ b/drf_dynamic_fields/__init__.py @@ -1,8 +1,10 @@ """ Mixin to dynamically select only a subset of fields per DRF resource. """ + import warnings +from collections import defaultdict from django.conf import settings from django.utils.functional import cached_property @@ -20,7 +22,7 @@ def fields(self): A blank `fields` parameter (?fields) will remove all fields. Not passing `fields` will pass all fields individual fields are comma - separated (?fields=id,name,url,email). + separated (?fields=id,name,url,email,teachers__age). """ fields = super(DynamicFieldsMixin, self).fields @@ -58,7 +60,7 @@ def fields(self): try: filter_fields = params.get("fields", None).split(",") except AttributeError: - filter_fields = None + filter_fields = [] try: omit_fields = params.get("omit", None).split(",") @@ -66,15 +68,38 @@ def fields(self): omit_fields = [] # Drop any fields that are not specified in the `fields` argument. + self._nested_allow = defaultdict(list) + self._nested_omit = defaultdict(list) + self._flat_allow = set() + self._flat_omit = set() + + # store top-level and nested fields specified in the `fields` argument. + for filtered_field in filter_fields: + if "__" in filtered_field: + parent, child = filtered_field.split("__", 1) + self._nested_allow[parent].append(child) + # If a nested field is allowed the related parent level field + # must also be allowed + self._flat_allow.add(parent) + else: + self._flat_allow.add(filtered_field) + + # store top-level and nested fields in the `omit` argument. + for omitted_field in omit_fields: + if "__" in omitted_field: + parent, child = omitted_field.split("__", 1) + self._nested_omit[parent].append(child) + else: + self._flat_omit.add(omitted_field) + + # Drop top-level fields existing = set(fields.keys()) - if filter_fields is None: + if "fields" in params: + allowed = self._flat_allow + else: # no fields param given, don't filter. allowed = existing - else: - allowed = set(filter(None, filter_fields)) - - # omit fields in the `omit` argument. - omitted = set(filter(None, omit_fields)) + omitted = self._flat_omit for field in existing: @@ -84,4 +109,152 @@ def fields(self): if field in omitted: fields.pop(field, None) + # Drop omitted and non-allowed child fields from nested serializers + self.prune_nested_fields(fields) + return fields + + @staticmethod + def _get_nested_serializer(parent, fields): + """Return the nested serializer for a parent field, or None.""" + field = fields.get(parent) + if not field: + return None + nested = getattr(field, "child", field) + return nested if hasattr(nested, "fields") else None + + def prune_nested_fields(self, fields): + """Prune valid nested serializer fields based on the _nested_omit and _nested_allow lists.""" + valid_parents = ( + self._nested_omit.keys() | self._nested_allow.keys() + ) & fields.keys() + for parent in valid_parents: + nested_serializer = self._get_nested_serializer(parent, fields) + if nested_serializer is None: + continue + + # Drop omitted child fields from nested serializers + for name in self._nested_omit.get(parent, ()): + nested_serializer.fields.pop(name, None) + + # Drop non-allowed child fields from the nested serializers + allow_list = self._nested_allow.get(parent) + if allow_list: + nested_serializer.fields = { + name: field + for name, field in nested_serializer.fields.items() + if name in allow_list + } + + def _filter_top_level_fields_to_defer(self, field_list, keep_if): + """Method to retrieve the top-level fields to defer, given a list of + field names (allow or omit) and a condition that determines which + fields to defer. + """ + model = getattr(self.Meta, "model", None) + if not field_list or model is None: + return [] + + all_fields = [ + f.name for f in model._meta.get_fields() if getattr(f, "concrete", False) + ] + return [name for name in all_fields if keep_if(name, field_list)] + + def _get_disallowed_top_level_fields_to_defer(self): + """Determine which top-level model fields should be deferred when an + explicit fields filter is in use. + Other model fields not explicitly included in 'fields' are deferred. + """ + allow = getattr(self, "_flat_allow", None) + return self._filter_top_level_fields_to_defer( + allow, keep_if=lambda name, allow: name not in allow + ) + + def _get_omit_top_level_fields_to_defer(self): + """ + Determine which top-level model fields should be deferred when an + explicit omit filter is in use. Valid database fields in the omit list + will be deferred. + """ + omit = getattr(self, "_flat_omit", None) + return self._filter_top_level_fields_to_defer( + omit, keep_if=lambda name, omit: name in omit + ) + + def _get_nested_level_fields_to_defer(self, nested_mapping, should_defer): + """Method to retrieve nested fields to defer based on a mapping and a + defer condition. + """ + fields_to_defer = [] + for parent, items in nested_mapping.items(): + field = self.fields.get(parent) + if not field: + continue + + child_serializer = getattr(field, "child", field) + nested_model = getattr(child_serializer.Meta, "model", None) + if nested_model is None: + continue + + # Filter out nested fields that have a database column associated + # with them. + field_names = [ + f.name + for f in nested_model._meta.get_fields() + if getattr(f, "concrete", False) + ] + + # Determine which nested fields to defer + for name in field_names: + if should_defer(name, items): + fields_to_defer.append(f"{parent}__{name}") + + return fields_to_defer + + def _get_disallowed_nested_level_fields_to_defer(self): + """Determine which top-level model fields should be deferred when an + explicit fields filter is in use. + Other model fields not explicitly included in 'fields' are deferred. + """ + allow_map = getattr(self, "_nested_allow", {}) + return self._get_nested_level_fields_to_defer( + allow_map, + should_defer=lambda name, allow_list: name not in allow_list, + ) + + def _get_omit_nested_level_fields_to_defer(self): + """ + Determine which nested-model fields should be deferred for each nested + serializer when an explicit omit filter is in use. + """ + omit_map = getattr(self, "_nested_omit", {}) + return self._get_nested_level_fields_to_defer( + omit_map, should_defer=lambda name, omit_list: name in omit_list + ) + + def get_model_fields_to_defer(self): + """ + Returns a flat list of filtered model-fields; top-level and nested. + Ensures that parsing of 'fields'/'omit' has run by accessing '.fields'. + """ + + # Trigger parsing of required attributes if not already set + if not all( + hasattr(self, attr) + for attr in ( + "_flat_omit", + "_nested_omit", + "_flat_allow", + "_nested_allow", + ) + ): + _ = self.fields + + deferred = [ + *self._get_omit_top_level_fields_to_defer(), + *self._get_omit_nested_level_fields_to_defer(), + *self._get_disallowed_top_level_fields_to_defer(), + *self._get_disallowed_nested_level_fields_to_defer(), + ] + # Remove any duplicate fields + return list(set(deferred)) diff --git a/tests/models.py b/tests/models.py index 714098e..884697a 100644 --- a/tests/models.py +++ b/tests/models.py @@ -14,3 +14,12 @@ class School(models.Model): name = models.CharField(max_length=30) teachers = models.ManyToManyField(Teacher) + + +class Child(models.Model): + secret = models.CharField(max_length=100) + public = models.CharField(max_length=100) + + +class Parent(models.Model): + child = models.ForeignKey(Child, on_delete=models.CASCADE) \ No newline at end of file diff --git a/tests/serializers.py b/tests/serializers.py index 3619d59..766891b 100644 --- a/tests/serializers.py +++ b/tests/serializers.py @@ -5,7 +5,7 @@ from drf_dynamic_fields import DynamicFieldsMixin -from .models import Teacher, School +from .models import Teacher, School, Child class TeacherSerializer(DynamicFieldsMixin, serializers.ModelSerializer): @@ -40,3 +40,16 @@ class SchoolSerializer(DynamicFieldsMixin, serializers.ModelSerializer): class Meta: model = School fields = ("id", "teachers", "name") + + +class ChildSerializer(DynamicFieldsMixin, serializers.Serializer): + secret = serializers.CharField() + public = serializers.CharField() + + class Meta: + model = Child + + +class ParentSerializer(DynamicFieldsMixin, serializers.Serializer): + id = serializers.IntegerField() + child = ChildSerializer() \ No newline at end of file diff --git a/tests/test_mixins.py b/tests/test_mixins.py index f54e179..dfbce95 100644 --- a/tests/test_mixins.py +++ b/tests/test_mixins.py @@ -11,8 +11,8 @@ from django.test import TestCase, RequestFactory -from .serializers import SchoolSerializer, TeacherSerializer -from .models import Teacher, School +from .serializers import SchoolSerializer, TeacherSerializer, ParentSerializer +from .models import Teacher, School, Child, Parent class TestDynamicFieldsMixin(TestCase): @@ -20,6 +20,37 @@ class TestDynamicFieldsMixin(TestCase): Test case for the DynamicFieldsMixin """ + def _assert_nested_fields(self, data, expected_fields): + """ + Assert nested fields match the expected fields. + """ + for parent, nested_fields in expected_fields.items(): + with self.subTest(parent=parent): + items = data[parent] + if nested_fields is None: + continue + expected_set = set(nested_fields) + for obj in items: + with self.subTest(parent=parent): + actual_set = set(obj.keys()) + self.assertEqual( + actual_set, + expected_set, + f"{parent} fields mismatch: expected " + f"exactly {nested_fields}, got {list(obj.keys())}", + ) + + @staticmethod + def _prepare_school_instance(): + """Prepare school instance for testing.""" + school = School.objects.create(name="Python Heights High") + teachers = [ + Teacher.objects.create(name="Shane", age=45), + Teacher.objects.create(name="Kaz", age=29), + ] + school.teachers.add(*teachers) + return school + def test_removes_fields(self): """ Does it actually remove fields? @@ -175,3 +206,161 @@ def test_serializer_reuse_with_changing_request(self): request2 = rf.get("/api/v1/schools/1/?fields=id,name") serializer.context["request"] = request2 self.assertEqual(set(serializer.fields.keys()), {"id"}) + + def test_omit_nested_field(self): + """Omitting a nested field""" + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?omit=invalid,name,teachers__age,teachers__invalid") + + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + data = serializer.data + + # Confirm omitted fields are in deferred list + deferred = set(serializer.get_model_fields_to_defer()) + self.assertEqual({"name", "teachers__age"}, deferred) + + expected_fields = {"id": None, "teachers": ["id", "name", "request_info"]} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_omit_everything_nested_field(self): + """Omitting all fields within a nested field""" + rf = RequestFactory() + request = rf.get( + "/api/v1/schools/1/?omit=teachers__id,teachers__age,teachers__name,teachers__request_info" + ) + + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + data = serializer.data + + expected_fields = {"id": None, "name": None, "teachers": []} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_omit_top_field_and_keep_all_nested_fields(self): + """Omitting a top-level field while keeping all nested fields""" + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?omit=name") + + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + data = serializer.data + + expected_fields = { + "id": None, + "teachers": ["id", "name", "request_info", "age"], + } + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_allow_nested_field(self): + """Select only the requested fields, including nested-level fields.""" + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?fields=invalid,id,teachers__age,teachers__invalid") + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + + # Confirm omitted fields are in deferred list + deferred = set(serializer.get_model_fields_to_defer()) + self.assertEqual({"name","teachers__name", "teachers__id"}, deferred) + + data = serializer.data + expected_fields = {"id": None, "teachers": ["age"]} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_fields_all_gone_nested(self): + """If no fields are selected, all fields are omitted, including those + from the nested serializer. + """ + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?fields") + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + + data = serializer.data + expected_fields = {} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_nested_omit_and_fields_used(self): + """Omit and fields can be used together at the nested field level.""" + rf = RequestFactory() + request = rf.get( + "/api/v1/schools/1/?fields=id,name,teachers__name,teachers__age&omit=name,teachers__name" + ) + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + + data = serializer.data + expected_fields = {"id": None, "teachers": ["age"]} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_omit_nothing_nested(self): + """ + Blank omit doesn't affect nested fields. + """ + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?omit") + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + + data = serializer.data + expected_fields = { + "id": None, + "name": None, + "teachers": ["id", "age", "name", "request_info"], + } + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_single_nested_instance_omit_field(self): + """Omit also works for filtering fields on single nested instances""" + child = Child(secret="secret_key", public="public_key") + parent = Parent(id=1, child=child) + rf = RequestFactory() + request = rf.get("/api/v1/parent/1/?omit=id,child__secret") + serializer = ParentSerializer(parent, context={"request": request}) + data = serializer.data + + self.assertNotIn("id", data) + self.assertNotIn("secret", data["child"]) + self.assertEqual(data["child"]["public"], "public_key") + + def test_single_nested_instance_allow_field(self): + """Fields selection also works for filtering fields on single nested instances""" + child = Child(secret="secret_key", public="public_key") + parent = Parent(id=1, child=child) + rf = RequestFactory() + request = rf.get("/api/v1/parent/1/?fields=id,child__secret") + serializer = ParentSerializer(parent, context={"request": request}) + data = serializer.data + + self.assertEqual(data["id"], 1) + self.assertIn("secret", data["child"]) + self.assertEqual(data["child"]["secret"], "secret_key") + self.assertNotIn("public", data["child"]) \ No newline at end of file From 059eed0cb6c0218b9aceba14ebd52b6c0a20e5fe Mon Sep 17 00:00:00 2001 From: Jervis Whitley Date: Thu, 10 Jul 2025 03:33:23 +0000 Subject: [PATCH 2/4] Uses a nested filtering approach to filter any level of serializer Introduces the new code in a separate serializer that inherits from the existing one so that dynamic field enjoyers can opt into the new functionality. However the new field seems to be 100% backwards compatible with the old one, save for some performance hit since we are doing extra work. --- drf_dynamic_fields/__init__.py | 273 +++++++++++---------------------- tests/serializers.py | 39 +++-- tests/test_mixins.py | 118 ++++++++------ 3 files changed, 180 insertions(+), 250 deletions(-) diff --git a/drf_dynamic_fields/__init__.py b/drf_dynamic_fields/__init__.py index e2a819e..8a4082f 100644 --- a/drf_dynamic_fields/__init__.py +++ b/drf_dynamic_fields/__init__.py @@ -4,10 +4,11 @@ import warnings -from collections import defaultdict from django.conf import settings from django.utils.functional import cached_property +from rest_framework import serializers + class DynamicFieldsMixin(object): """ @@ -15,6 +16,15 @@ class DynamicFieldsMixin(object): which fields should be displayed. """ + @property + def is_preventing_nested_serializers(self): + is_root = self.root == self + parent_is_list_root = self.parent == self.root and getattr( + self.parent, "many", False + ) + + return not (is_root or parent_is_list_root) + @cached_property def fields(self): """ @@ -22,7 +32,7 @@ def fields(self): A blank `fields` parameter (?fields) will remove all fields. Not passing `fields` will pass all fields individual fields are comma - separated (?fields=id,name,url,email,teachers__age). + separated (?fields=id,name,url,email). """ fields = super(DynamicFieldsMixin, self).fields @@ -31,13 +41,7 @@ def fields(self): # We are being called before a request cycle return fields - # Only filter if this is the root serializer, or if the parent is the - # root serializer with many=True - is_root = self.root == self - parent_is_list_root = self.parent == self.root and getattr( - self.parent, "many", False - ) - if not (is_root or parent_is_list_root): + if self.is_preventing_nested_serializers: return fields try: @@ -57,49 +61,22 @@ def fields(self): if params is None: warnings.warn("Request object does not contain query parameters") - try: - filter_fields = params.get("fields", None).split(",") - except AttributeError: - filter_fields = [] + source = get_source_path(self) + level = compute_level(self) - try: - omit_fields = params.get("omit", None).split(",") - except AttributeError: - omit_fields = [] + filter_fields = self.get_filter_fields(params.get("fields", None), level, source) + omit_fields = self.get_omit_fields(params.get("omit", None), level, source) # Drop any fields that are not specified in the `fields` argument. - self._nested_allow = defaultdict(list) - self._nested_omit = defaultdict(list) - self._flat_allow = set() - self._flat_omit = set() - - # store top-level and nested fields specified in the `fields` argument. - for filtered_field in filter_fields: - if "__" in filtered_field: - parent, child = filtered_field.split("__", 1) - self._nested_allow[parent].append(child) - # If a nested field is allowed the related parent level field - # must also be allowed - self._flat_allow.add(parent) - else: - self._flat_allow.add(filtered_field) - - # store top-level and nested fields in the `omit` argument. - for omitted_field in omit_fields: - if "__" in omitted_field: - parent, child = omitted_field.split("__", 1) - self._nested_omit[parent].append(child) - else: - self._flat_omit.add(omitted_field) - - # Drop top-level fields existing = set(fields.keys()) - if "fields" in params: - allowed = self._flat_allow - else: + if filter_fields is None: # no fields param given, don't filter. allowed = existing - omitted = self._flat_omit + else: + allowed = set(filter(None, filter_fields)) + + # omit fields in the `omit` argument. + omitted = set(filter(None, omit_fields)) for field in existing: @@ -109,152 +86,74 @@ def fields(self): if field in omitted: fields.pop(field, None) - # Drop omitted and non-allowed child fields from nested serializers - self.prune_nested_fields(fields) - return fields - @staticmethod - def _get_nested_serializer(parent, fields): - """Return the nested serializer for a parent field, or None.""" - field = fields.get(parent) - if not field: - return None - nested = getattr(field, "child", field) - return nested if hasattr(nested, "fields") else None - - def prune_nested_fields(self, fields): - """Prune valid nested serializer fields based on the _nested_omit and _nested_allow lists.""" - valid_parents = ( - self._nested_omit.keys() | self._nested_allow.keys() - ) & fields.keys() - for parent in valid_parents: - nested_serializer = self._get_nested_serializer(parent, fields) - if nested_serializer is None: - continue - - # Drop omitted child fields from nested serializers - for name in self._nested_omit.get(parent, ()): - nested_serializer.fields.pop(name, None) - - # Drop non-allowed child fields from the nested serializers - allow_list = self._nested_allow.get(parent) - if allow_list: - nested_serializer.fields = { - name: field - for name, field in nested_serializer.fields.items() - if name in allow_list - } - - def _filter_top_level_fields_to_defer(self, field_list, keep_if): - """Method to retrieve the top-level fields to defer, given a list of - field names (allow or omit) and a condition that determines which - fields to defer. - """ - model = getattr(self.Meta, "model", None) - if not field_list or model is None: - return [] - - all_fields = [ - f.name for f in model._meta.get_fields() if getattr(f, "concrete", False) - ] - return [name for name in all_fields if keep_if(name, field_list)] - - def _get_disallowed_top_level_fields_to_defer(self): - """Determine which top-level model fields should be deferred when an - explicit fields filter is in use. - Other model fields not explicitly included in 'fields' are deferred. - """ - allow = getattr(self, "_flat_allow", None) - return self._filter_top_level_fields_to_defer( - allow, keep_if=lambda name, allow: name not in allow - ) + def get_filter_fields(self, params, level, source, default=None, include_parent=True): + try: + return params.split(",") + except AttributeError: + return default - def _get_omit_top_level_fields_to_defer(self): - """ - Determine which top-level model fields should be deferred when an - explicit omit filter is in use. Valid database fields in the omit list - will be deferred. - """ - omit = getattr(self, "_flat_omit", None) - return self._filter_top_level_fields_to_defer( - omit, keep_if=lambda name, omit: name in omit - ) - def _get_nested_level_fields_to_defer(self, nested_mapping, should_defer): - """Method to retrieve nested fields to defer based on a mapping and a - defer condition. - """ - fields_to_defer = [] - for parent, items in nested_mapping.items(): - field = self.fields.get(parent) - if not field: - continue - - child_serializer = getattr(field, "child", field) - nested_model = getattr(child_serializer.Meta, "model", None) - if nested_model is None: - continue - - # Filter out nested fields that have a database column associated - # with them. - field_names = [ - f.name - for f in nested_model._meta.get_fields() - if getattr(f, "concrete", False) - ] - - # Determine which nested fields to defer - for name in field_names: - if should_defer(name, items): - fields_to_defer.append(f"{parent}__{name}") - - return fields_to_defer - - def _get_disallowed_nested_level_fields_to_defer(self): - """Determine which top-level model fields should be deferred when an - explicit fields filter is in use. - Other model fields not explicitly included in 'fields' are deferred. - """ - allow_map = getattr(self, "_nested_allow", {}) - return self._get_nested_level_fields_to_defer( - allow_map, - should_defer=lambda name, allow_list: name not in allow_list, - ) + def get_omit_fields(self, params, level, source): + return self.get_filter_fields(params, level, source, default=[], include_parent=False) - def _get_omit_nested_level_fields_to_defer(self): - """ - Determine which nested-model fields should be deferred for each nested - serializer when an explicit omit filter is in use. - """ - omit_map = getattr(self, "_nested_omit", {}) - return self._get_nested_level_fields_to_defer( - omit_map, should_defer=lambda name, omit_list: name in omit_list - ) - def get_model_fields_to_defer(self): - """ - Returns a flat list of filtered model-fields; top-level and nested. - Ensures that parsing of 'fields'/'omit' has run by accessing '.fields'. - """ +class NestedDynamicFieldsMixin(DynamicFieldsMixin): - # Trigger parsing of required attributes if not already set - if not all( - hasattr(self, attr) - for attr in ( - "_flat_omit", - "_nested_omit", - "_flat_allow", - "_nested_allow", + @property + def is_preventing_nested_serializers(self): + return False + + def get_filter_fields(self, params, level, source, default=None, include_parent=True): + fields = super().get_filter_fields(params, level, source, default, include_parent) + return get_fields_for_level_and_prefix( + fields, + level, + source, + default=default, + include_parent=include_parent ) - ): - _ = self.fields - - deferred = [ - *self._get_omit_top_level_fields_to_defer(), - *self._get_omit_nested_level_fields_to_defer(), - *self._get_disallowed_top_level_fields_to_defer(), - *self._get_disallowed_nested_level_fields_to_defer(), - ] - # Remove any duplicate fields - return list(set(deferred)) + +def get_source_path(serializer): + parts = [] + current = serializer + while current.parent is not None: + if hasattr(current, 'field_name'): + parts.insert(0, current.field_name) + current = current.parent + return "__".join(filter(None, parts)) + +def get_fields_for_level_and_prefix(fields_list, level, source, include_parent, default): + if not fields_list: + return default + + allowed = set() + prefix = source.split("__") if source else [] + for f in fields_list: + parts = f.split("__") + if parts[:level] != prefix: + continue + if len(parts) <= level + 1: + allowed.add(parts[-1]) + elif len(parts) > level + 1 and include_parent: + # include parent field to ensure nesting proceeds + allowed.add(parts[level]) + if set(prefix) == allowed: + return default + return allowed + +def compute_level(serializer): + level = 0 + current = serializer + while hasattr(current, 'parent') and current.parent is not None: + parent = current.parent + + # Handle ListSerializer by skipping over it + if isinstance(parent, serializers.ListSerializer): + current = parent.parent + else: + current = parent + + level += 1 + return level diff --git a/tests/serializers.py b/tests/serializers.py index 766891b..9e37b49 100644 --- a/tests/serializers.py +++ b/tests/serializers.py @@ -3,16 +3,12 @@ """ from rest_framework import serializers -from drf_dynamic_fields import DynamicFieldsMixin +from drf_dynamic_fields import DynamicFieldsMixin, NestedDynamicFieldsMixin from .models import Teacher, School, Child -class TeacherSerializer(DynamicFieldsMixin, serializers.ModelSerializer): - """ - The request_info field is to highlight the issue accessing request during - a nested serializer. - """ +class BaseTeacherSerializer(serializers.ModelSerializer): request_info = serializers.SerializerMethodField() @@ -29,20 +25,37 @@ def get_request_info(self, teacher): return request.build_absolute_uri("/api/v1/teacher/{}".format(teacher.pk)) -class SchoolSerializer(DynamicFieldsMixin, serializers.ModelSerializer): +class TeacherSerializer(DynamicFieldsMixin, BaseTeacherSerializer): + pass + + +class NestableTeacherSerializer(NestedDynamicFieldsMixin, BaseTeacherSerializer): """ - Interesting enough serializer because the TeacherSerializer - will use ListSerializer due to the `many=True` + The request_info field is to highlight the issue accessing request during + a nested serializer. + """ - teachers = TeacherSerializer(many=True, read_only=True) +class BaseSchoolSerializer(serializers.ModelSerializer): + class Meta: model = School fields = ("id", "teachers", "name") -class ChildSerializer(DynamicFieldsMixin, serializers.Serializer): +class SchoolSerializer(DynamicFieldsMixin, BaseSchoolSerializer): + teachers = TeacherSerializer(many=True, read_only=True) + + +class NestableSchoolSerializer(NestedDynamicFieldsMixin, BaseSchoolSerializer): + """ + Interesting enough serializer because the TeacherSerializer + will use ListSerializer due to the `many=True` + """ + teachers = NestableTeacherSerializer(many=True, read_only=True) + +class ChildSerializer(NestedDynamicFieldsMixin, serializers.Serializer): secret = serializers.CharField() public = serializers.CharField() @@ -50,6 +63,6 @@ class Meta: model = Child -class ParentSerializer(DynamicFieldsMixin, serializers.Serializer): +class ParentSerializer(NestedDynamicFieldsMixin, serializers.Serializer): id = serializers.IntegerField() - child = ChildSerializer() \ No newline at end of file + child = ChildSerializer() diff --git a/tests/test_mixins.py b/tests/test_mixins.py index dfbce95..8c9a1a1 100644 --- a/tests/test_mixins.py +++ b/tests/test_mixins.py @@ -11,7 +11,13 @@ from django.test import TestCase, RequestFactory -from .serializers import SchoolSerializer, TeacherSerializer, ParentSerializer +from .serializers import ( + NestableSchoolSerializer, + NestableTeacherSerializer, + SchoolSerializer, + TeacherSerializer, + ParentSerializer, +) from .models import Teacher, School, Child, Parent @@ -20,36 +26,8 @@ class TestDynamicFieldsMixin(TestCase): Test case for the DynamicFieldsMixin """ - def _assert_nested_fields(self, data, expected_fields): - """ - Assert nested fields match the expected fields. - """ - for parent, nested_fields in expected_fields.items(): - with self.subTest(parent=parent): - items = data[parent] - if nested_fields is None: - continue - expected_set = set(nested_fields) - for obj in items: - with self.subTest(parent=parent): - actual_set = set(obj.keys()) - self.assertEqual( - actual_set, - expected_set, - f"{parent} fields mismatch: expected " - f"exactly {nested_fields}, got {list(obj.keys())}", - ) - - @staticmethod - def _prepare_school_instance(): - """Prepare school instance for testing.""" - school = School.objects.create(name="Python Heights High") - teachers = [ - Teacher.objects.create(name="Shane", age=45), - Teacher.objects.create(name="Kaz", age=29), - ] - school.teachers.add(*teachers) - return school + SchoolSerializer = SchoolSerializer + TeacherSerializer = TeacherSerializer def test_removes_fields(self): """ @@ -57,7 +35,7 @@ def test_removes_fields(self): """ rf = RequestFactory() request = rf.get("/api/v1/schools/1/?fields=id") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual(set(serializer.fields.keys()), set(("id",))) @@ -67,7 +45,7 @@ def test_fields_left_alone(self): """ rf = RequestFactory() request = rf.get("/api/v1/schools/1/") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual( set(serializer.fields.keys()), set(("id", "request_info", "age", "name")) @@ -79,7 +57,7 @@ def test_fields_all_gone(self): """ rf = RequestFactory() request = rf.get("/api/v1/schools/1/?fields") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual(set(serializer.fields.keys()), set()) @@ -91,7 +69,7 @@ def test_ordinary_serializer(self): request = rf.get("/api/v1/schools/1/?fields=id,age") teacher = Teacher.objects.create(name="Susan", age=34) - serializer = TeacherSerializer(teacher, context={"request": request}) + serializer = self.TeacherSerializer(teacher, context={"request": request}) self.assertEqual(serializer.data, {"id": teacher.id, "age": teacher.age}) @@ -101,7 +79,7 @@ def test_omit(self): """ rf = RequestFactory() request = rf.get("/api/v1/schools/1/?omit=request_info") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual(set(serializer.fields.keys()), set(("id", "name", "age"))) @@ -111,7 +89,7 @@ def test_omit_and_fields_used(self): """ rf = RequestFactory() request = rf.get("/api/v1/schools/1/?fields=id,request_info&omit=request_info") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual(set(serializer.fields.keys()), set(("id",))) @@ -121,7 +99,7 @@ def test_omit_everything(self): """ rf = RequestFactory() request = rf.get("/api/v1/schools/1/?omit=id,request_info,age,name") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual(set(serializer.fields.keys()), set()) @@ -131,7 +109,7 @@ def test_omit_nothing(self): """ rf = RequestFactory() request = rf.get("/api/v1/schools/1/?omit") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual( set(serializer.fields.keys()), set(("id", "request_info", "name", "age")) @@ -140,7 +118,7 @@ def test_omit_nothing(self): def test_omit_non_existant_field(self): rf = RequestFactory() request = rf.get("/api/v1/schools/1/?omit=pretend") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual( set(serializer.fields.keys()), set(("id", "request_info", "name", "age")) @@ -160,7 +138,7 @@ def test_as_nested_serializer(self): ] school.teachers.add(*teachers) - serializer = SchoolSerializer(school, context={"request": request}) + serializer = self.SchoolSerializer(school, context={"request": request}) request_info = "http://testserver/api/v1/teacher/{}" @@ -199,7 +177,7 @@ def test_serializer_reuse_with_changing_request(self): rf = RequestFactory() request = rf.get("/api/v1/schools/1/?fields=id") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual(set(serializer.fields.keys()), {"id"}) # now change the request on this instantiated serializer. @@ -207,13 +185,51 @@ def test_serializer_reuse_with_changing_request(self): serializer.context["request"] = request2 self.assertEqual(set(serializer.fields.keys()), {"id"}) +class TestNestedDynamicFieldsMixin(TestDynamicFieldsMixin): + """ + Test case for the NestedDynamicFieldsMixin + """ + SchoolSerializer = NestableSchoolSerializer + TeacherSerializer = NestableTeacherSerializer + + def _assert_nested_fields(self, data, expected_fields): + """ + Assert nested fields match the expected fields. + """ + for parent, nested_fields in expected_fields.items(): + with self.subTest(parent=parent): + items = data[parent] + if nested_fields is None: + continue + expected_set = set(nested_fields) + for obj in items: + with self.subTest(parent=parent): + actual_set = set(obj.keys()) + self.assertEqual( + actual_set, + expected_set, + f"{parent} fields mismatch: expected " + f"exactly {nested_fields}, got {list(obj.keys())}", + ) + + @staticmethod + def _prepare_school_instance(): + """Prepare school instance for testing.""" + school = School.objects.create(name="Python Heights High") + teachers = [ + Teacher.objects.create(name="Shane", age=45), + Teacher.objects.create(name="Kaz", age=29), + ] + school.teachers.add(*teachers) + return school + def test_omit_nested_field(self): """Omitting a nested field""" rf = RequestFactory() request = rf.get("/api/v1/schools/1/?omit=invalid,name,teachers__age,teachers__invalid") school = self._prepare_school_instance() - serializer = SchoolSerializer(school, context={"request": request}) + serializer = self.SchoolSerializer(school, context={"request": request}) data = serializer.data # Confirm omitted fields are in deferred list @@ -235,7 +251,7 @@ def test_omit_everything_nested_field(self): ) school = self._prepare_school_instance() - serializer = SchoolSerializer(school, context={"request": request}) + serializer = self.SchoolSerializer(school, context={"request": request}) data = serializer.data expected_fields = {"id": None, "name": None, "teachers": []} @@ -251,7 +267,7 @@ def test_omit_top_field_and_keep_all_nested_fields(self): request = rf.get("/api/v1/schools/1/?omit=name") school = self._prepare_school_instance() - serializer = SchoolSerializer(school, context={"request": request}) + serializer = self.SchoolSerializer(school, context={"request": request}) data = serializer.data expected_fields = { @@ -269,7 +285,7 @@ def test_allow_nested_field(self): rf = RequestFactory() request = rf.get("/api/v1/schools/1/?fields=invalid,id,teachers__age,teachers__invalid") school = self._prepare_school_instance() - serializer = SchoolSerializer(school, context={"request": request}) + serializer = self.SchoolSerializer(school, context={"request": request}) # Confirm omitted fields are in deferred list deferred = set(serializer.get_model_fields_to_defer()) @@ -290,7 +306,7 @@ def test_fields_all_gone_nested(self): rf = RequestFactory() request = rf.get("/api/v1/schools/1/?fields") school = self._prepare_school_instance() - serializer = SchoolSerializer(school, context={"request": request}) + serializer = self.SchoolSerializer(school, context={"request": request}) data = serializer.data expected_fields = {} @@ -307,7 +323,7 @@ def test_nested_omit_and_fields_used(self): "/api/v1/schools/1/?fields=id,name,teachers__name,teachers__age&omit=name,teachers__name" ) school = self._prepare_school_instance() - serializer = SchoolSerializer(school, context={"request": request}) + serializer = self.SchoolSerializer(school, context={"request": request}) data = serializer.data expected_fields = {"id": None, "teachers": ["age"]} @@ -324,7 +340,7 @@ def test_omit_nothing_nested(self): rf = RequestFactory() request = rf.get("/api/v1/schools/1/?omit") school = self._prepare_school_instance() - serializer = SchoolSerializer(school, context={"request": request}) + serializer = self.SchoolSerializer(school, context={"request": request}) data = serializer.data expected_fields = { @@ -363,4 +379,6 @@ def test_single_nested_instance_allow_field(self): self.assertEqual(data["id"], 1) self.assertIn("secret", data["child"]) self.assertEqual(data["child"]["secret"], "secret_key") - self.assertNotIn("public", data["child"]) \ No newline at end of file + self.assertNotIn("public", data["child"]) + + From 18ab2f3706905342009ca203c9a1481d69c4ed64 Mon Sep 17 00:00:00 2001 From: Alberto Islas Date: Fri, 11 Jul 2025 18:20:22 -0600 Subject: [PATCH 3/4] fix(drf): Refined logic to allow nested serializers to filter themselves --- drf_dynamic_fields/__init__.py | 135 ++++++++++++++++++++----------- tests/models.py | 6 +- tests/serializers.py | 25 +++++- tests/test_mixins.py | 140 ++++++++++++++++++++++++--------- 4 files changed, 218 insertions(+), 88 deletions(-) diff --git a/drf_dynamic_fields/__init__.py b/drf_dynamic_fields/__init__.py index 8a4082f..38e47ff 100644 --- a/drf_dynamic_fields/__init__.py +++ b/drf_dynamic_fields/__init__.py @@ -17,14 +17,13 @@ class DynamicFieldsMixin(object): """ @property - def is_preventing_nested_serializers(self): - is_root = self.root == self - parent_is_list_root = self.parent == self.root and getattr( - self.parent, "many", False + def prevent_nested_processing(self) -> bool: + """True when this serializer is not the root nor a root’s list-child.""" + return not ( + self is self.root + or (self.parent is self.root and getattr(self.parent, "many", False)) ) - return not (is_root or parent_is_list_root) - @cached_property def fields(self): """ @@ -41,7 +40,7 @@ def fields(self): # We are being called before a request cycle return fields - if self.is_preventing_nested_serializers: + if self.prevent_nested_processing: return fields try: @@ -64,7 +63,9 @@ def fields(self): source = get_source_path(self) level = compute_level(self) - filter_fields = self.get_filter_fields(params.get("fields", None), level, source) + filter_fields = self.get_filter_fields( + params.get("fields", None), level, source + ) omit_fields = self.get_omit_fields(params.get("omit", None), level, source) # Drop any fields that are not specified in the `fields` argument. @@ -88,72 +89,114 @@ def fields(self): return fields - def get_filter_fields(self, params, level, source, default=None, include_parent=True): + def get_filter_fields( + self, params, level, source, default=None, include_parent=True + ): try: return params.split(",") except AttributeError: return default - def get_omit_fields(self, params, level, source): - return self.get_filter_fields(params, level, source, default=[], include_parent=False) + return self.get_filter_fields( + params, level, source, default=[], include_parent=False + ) class NestedDynamicFieldsMixin(DynamicFieldsMixin): + """A serializer mixin that extends DynamicFieldsMixin to allow nested serializers + to filter their fields based on the original `fields` query parameter. + + Unlike the base mixin—which only applies filtering at the root serializer, + this subclass: + + - Disables the `prevent_nested_processing` guard, allowing each level of nested + serializer to apply field filtering independently. + - Overrides `get_filter_fields` to slice the raw `fields` string + down to exactly those names relevant at this serializer’s + current nesting depth (using get_fields_for_level_and_prefix). + - `get_filter_fields` first delegates to the super method for splitting + the comma‐separated string, then calls a helper that: + • Selects only the fields that are nested under this serializer's path in + the hierarchy + • Returns direct children at depth `level + 1` + """ @property - def is_preventing_nested_serializers(self): + def prevent_nested_processing(self): return False - def get_filter_fields(self, params, level, source, default=None, include_parent=True): - fields = super().get_filter_fields(params, level, source, default, include_parent) + def get_filter_fields( + self, params, level, source, default=None, include_parent=True + ): + """ + Parse the raw `fields` parameter and return the subset of fields + that apply at this serializer’s nesting level under the given + source prefix. + """ + fields = super().get_filter_fields( + params, level, source, default, include_parent + ) return get_fields_for_level_and_prefix( - fields, - level, - source, - default=default, - include_parent=include_parent - ) - -def get_source_path(serializer): - parts = [] - current = serializer - while current.parent is not None: - if hasattr(current, 'field_name'): - parts.insert(0, current.field_name) - current = current.parent - return "__".join(filter(None, parts)) - -def get_fields_for_level_and_prefix(fields_list, level, source, include_parent, default): + fields, level, source, default=default, include_parent=include_parent + ) + + +def get_source_path(serializer) -> str: + """Recursively walks up the serializer tree to build the nested field path.""" + parent = getattr(serializer, "parent", None) + if not parent: + return "" + parent_path = get_source_path(parent) + name = getattr(serializer, "field_name", None) + if not name: + return parent_path + return f"{parent_path}__{name}" if parent_path else name + + +def get_fields_for_level_and_prefix( + fields_list, level, source, include_parent, default +): + """Filter a list of dotted field names down to those relevant at a given + nesting level and prefix. + """ if not fields_list: return default - allowed = set() prefix = source.split("__") if source else [] + allowed = set() for f in fields_list: parts = f.split("__") + if parts[:level] != prefix: continue + if len(parts) <= level + 1: allowed.add(parts[-1]) - elif len(parts) > level + 1 and include_parent: + continue + + if len(parts) > level + 1 and include_parent: # include parent field to ensure nesting proceeds allowed.add(parts[level]) - if set(prefix) == allowed: + continue + + if allowed == set(prefix): return default + return allowed -def compute_level(serializer): - level = 0 - current = serializer - while hasattr(current, 'parent') and current.parent is not None: - parent = current.parent - # Handle ListSerializer by skipping over it - if isinstance(parent, serializers.ListSerializer): - current = parent.parent - else: - current = parent +def compute_level(serializer) -> int: + """Recursively count how many ancestors of `serializer` are not + ListSerializer instances. Stops when parent is None. + """ + parent = getattr(serializer, "parent", None) + if parent is None: + # base case, reached the top + return 0 + + # if this immediate parent is a ListSerializer, don’t count it, otherwise 1 + this_level = 0 if isinstance(parent, serializers.ListSerializer) else 1 - level += 1 - return level + # recurse on the parent itself + return this_level + compute_level(parent) diff --git a/tests/models.py b/tests/models.py index 884697a..a6d1403 100644 --- a/tests/models.py +++ b/tests/models.py @@ -4,10 +4,14 @@ from django.db import models -class Teacher(models.Model): +class Student(models.Model): name = models.CharField(max_length=30) age = models.IntegerField() +class Teacher(models.Model): + name = models.CharField(max_length=30) + age = models.IntegerField() + students = models.ManyToManyField(Student) class School(models.Model): """Schools just have teachers, no students.""" diff --git a/tests/serializers.py b/tests/serializers.py index 9e37b49..13f9c17 100644 --- a/tests/serializers.py +++ b/tests/serializers.py @@ -5,7 +5,7 @@ from drf_dynamic_fields import DynamicFieldsMixin, NestedDynamicFieldsMixin -from .models import Teacher, School, Child +from .models import Teacher, School, Child, Student class BaseTeacherSerializer(serializers.ModelSerializer): @@ -29,13 +29,34 @@ class TeacherSerializer(DynamicFieldsMixin, BaseTeacherSerializer): pass +class NestableStudentSerializer(NestedDynamicFieldsMixin, serializers.ModelSerializer): + + class Meta: + model = Student + fields = ("id", "name", "age") + + class NestableTeacherSerializer(NestedDynamicFieldsMixin, BaseTeacherSerializer): """ The request_info field is to highlight the issue accessing request during a nested serializer. - """ + students = NestableStudentSerializer(many=True, read_only=True) + + class Meta: + model = Teacher + fields = ("id", "request_info", "age", "name", "students") + + def get_request_info(self, teacher): + """ + a meaningless method that attempts + to access the request object. + """ + request = self.context["request"] + return request.build_absolute_uri("/api/v1/teacher/{}".format(teacher.pk)) + + class BaseSchoolSerializer(serializers.ModelSerializer): diff --git a/tests/test_mixins.py b/tests/test_mixins.py index 8c9a1a1..c5609e4 100644 --- a/tests/test_mixins.py +++ b/tests/test_mixins.py @@ -18,7 +18,7 @@ TeacherSerializer, ParentSerializer, ) -from .models import Teacher, School, Child, Parent +from .models import Teacher, School, Child, Parent, Student class TestDynamicFieldsMixin(TestCase): @@ -185,7 +185,7 @@ def test_serializer_reuse_with_changing_request(self): serializer.context["request"] = request2 self.assertEqual(set(serializer.fields.keys()), {"id"}) -class TestNestedDynamicFieldsMixin(TestDynamicFieldsMixin): +class TestNestedDynamicFieldsMixin(TestCase): """ Test case for the NestedDynamicFieldsMixin """ @@ -221,43 +221,69 @@ def _prepare_school_instance(): Teacher.objects.create(name="Kaz", age=29), ] school.teachers.add(*teachers) - return school + student = Student.objects.create(name="Shannon", age=23) + teachers[0].students.add(student) + + school_2 = School.objects.create(name="Python Heights High") + teachers = [ + Teacher.objects.create(name="Shane 2", age=46), + Teacher.objects.create(name="Kaz 2", age=30), + ] + school_2.teachers.add(*teachers) + + return [school, school_2] def test_omit_nested_field(self): """Omitting a nested field""" rf = RequestFactory() - request = rf.get("/api/v1/schools/1/?omit=invalid,name,teachers__age,teachers__invalid") + request = rf.get("/api/v1/schools/1/?omit=invalid,name,teachers__age,teachers__invalid,teachers__students__age") - school = self._prepare_school_instance() - serializer = self.SchoolSerializer(school, context={"request": request}) + # Single nested instance. + schools = self._prepare_school_instance() + serializer = self.SchoolSerializer(schools[0], context={"request": request}) data = serializer.data - - # Confirm omitted fields are in deferred list - deferred = set(serializer.get_model_fields_to_defer()) - self.assertEqual({"name", "teachers__age"}, deferred) - - expected_fields = {"id": None, "teachers": ["id", "name", "request_info"]} + expected_fields = {"id": None, "teachers": ["id", "name", "request_info", "students"]} + third_level_expected_fields = {"id":None, "name":None, "request_info":None, "students":["id", "name"]} # Assert top‐level keys exactly match self.assertEqual(set(data.keys()), set(expected_fields.keys())) + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + # Assert third level fields: + self._assert_nested_fields(data["teachers"][0], third_level_expected_fields) + # Multiple nested instances: + serializer = self.SchoolSerializer(schools, many=True, context={"request": request}) + data = serializer.data[0] + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) # Assert nested fields. self._assert_nested_fields(data, expected_fields) + # Assert third level fields: + self._assert_nested_fields(data["teachers"][0], third_level_expected_fields) + def test_omit_everything_nested_field(self): """Omitting all fields within a nested field""" rf = RequestFactory() request = rf.get( - "/api/v1/schools/1/?omit=teachers__id,teachers__age,teachers__name,teachers__request_info" + "/api/v1/schools/1/?omit=teachers__id,teachers__age,teachers__name,teachers__request_info,teachers__students" ) - - school = self._prepare_school_instance() - serializer = self.SchoolSerializer(school, context={"request": request}) + # Single nested instance. + schools = self._prepare_school_instance() + serializer = self.SchoolSerializer(schools[0], context={"request": request}) data = serializer.data expected_fields = {"id": None, "name": None, "teachers": []} # Assert top‐level keys exactly match self.assertEqual(set(data.keys()), set(expected_fields.keys())) + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + # Multiple nested instances: + serializer = self.SchoolSerializer(schools, many=True, context={"request": request}) + data = serializer.data[0] + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) # Assert nested fields. self._assert_nested_fields(data, expected_fields) @@ -266,38 +292,53 @@ def test_omit_top_field_and_keep_all_nested_fields(self): rf = RequestFactory() request = rf.get("/api/v1/schools/1/?omit=name") - school = self._prepare_school_instance() - serializer = self.SchoolSerializer(school, context={"request": request}) + schools = self._prepare_school_instance() + # Single nested instance. + serializer = self.SchoolSerializer(schools[0], context={"request": request}) data = serializer.data - expected_fields = { "id": None, - "teachers": ["id", "name", "request_info", "age"], + "teachers": ["id", "name", "request_info", "age", "students"], } # Assert top‐level keys exactly match self.assertEqual(set(data.keys()), set(expected_fields.keys())) + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + # Multiple nested instances: + serializer = self.SchoolSerializer(schools, many=True, context={"request": request}) + data = serializer.data[0] + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) # Assert nested fields. self._assert_nested_fields(data, expected_fields) def test_allow_nested_field(self): """Select only the requested fields, including nested-level fields.""" rf = RequestFactory() - request = rf.get("/api/v1/schools/1/?fields=invalid,id,teachers__age,teachers__invalid") - school = self._prepare_school_instance() - serializer = self.SchoolSerializer(school, context={"request": request}) - - # Confirm omitted fields are in deferred list - deferred = set(serializer.get_model_fields_to_defer()) - self.assertEqual({"name","teachers__name", "teachers__id"}, deferred) - + request = rf.get("/api/v1/schools/1/?fields=invalid,id,teachers__age,teachers__invalid,teachers__students__age") + schools = self._prepare_school_instance() + # Single nested instance. + serializer = self.SchoolSerializer(schools[0], context={"request": request}) data = serializer.data - expected_fields = {"id": None, "teachers": ["age"]} + expected_fields = {"id": None, "teachers": ["age", "students"]} + third_level_expected_fields = {"age":None, "students":["age"]} # Assert top‐level keys exactly match self.assertEqual(set(data.keys()), set(expected_fields.keys())) + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + # Assert third level fields: + self._assert_nested_fields(data["teachers"][0], third_level_expected_fields) + # Multiple nested instances: + serializer = self.SchoolSerializer(schools, many=True, context={"request": request}) + data = serializer.data[0] + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) # Assert nested fields. self._assert_nested_fields(data, expected_fields) + # Assert third level fields: + self._assert_nested_fields(data["teachers"][0], third_level_expected_fields) def test_fields_all_gone_nested(self): """If no fields are selected, all fields are omitted, including those @@ -305,14 +346,21 @@ def test_fields_all_gone_nested(self): """ rf = RequestFactory() request = rf.get("/api/v1/schools/1/?fields") - school = self._prepare_school_instance() - serializer = self.SchoolSerializer(school, context={"request": request}) - + schools = self._prepare_school_instance() + # Single nested instance. + serializer = self.SchoolSerializer(schools[0], context={"request": request}) data = serializer.data expected_fields = {} # Assert top‐level keys exactly match self.assertEqual(set(data.keys()), set(expected_fields.keys())) + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + # Multiple nested instances: + serializer = self.SchoolSerializer(schools, many=True, context={"request": request}) + data = serializer.data[0] + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) # Assert nested fields. self._assert_nested_fields(data, expected_fields) @@ -322,14 +370,21 @@ def test_nested_omit_and_fields_used(self): request = rf.get( "/api/v1/schools/1/?fields=id,name,teachers__name,teachers__age&omit=name,teachers__name" ) - school = self._prepare_school_instance() - serializer = self.SchoolSerializer(school, context={"request": request}) - + schools = self._prepare_school_instance() + # Single nested instance. + serializer = self.SchoolSerializer(schools[0], context={"request": request}) data = serializer.data expected_fields = {"id": None, "teachers": ["age"]} # Assert top‐level keys exactly match self.assertEqual(set(data.keys()), set(expected_fields.keys())) + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + # Multiple nested instances: + serializer = self.SchoolSerializer(schools, many=True, context={"request": request}) + data = serializer.data[0] + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) # Assert nested fields. self._assert_nested_fields(data, expected_fields) @@ -339,18 +394,25 @@ def test_omit_nothing_nested(self): """ rf = RequestFactory() request = rf.get("/api/v1/schools/1/?omit") - school = self._prepare_school_instance() - serializer = self.SchoolSerializer(school, context={"request": request}) - + schools = self._prepare_school_instance() + # Single nested instance. + serializer = self.SchoolSerializer(schools[0], context={"request": request}) data = serializer.data expected_fields = { "id": None, "name": None, - "teachers": ["id", "age", "name", "request_info"], + "teachers": ["id", "age", "name", "request_info", "students"], } # Assert top‐level keys exactly match self.assertEqual(set(data.keys()), set(expected_fields.keys())) + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + # Multiple nested instances: + serializer = self.SchoolSerializer(schools, many=True, context={"request": request}) + data = serializer.data[0] + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) # Assert nested fields. self._assert_nested_fields(data, expected_fields) From e600536adb041385269279d60357244cbd43748c Mon Sep 17 00:00:00 2001 From: Alberto Islas Date: Thu, 24 Jul 2025 13:08:48 -0600 Subject: [PATCH 4/4] fix(drf): Fixed outdated docstring --- drf_dynamic_fields/__init__.py | 7 +++++-- tests/test_mixins.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/drf_dynamic_fields/__init__.py b/drf_dynamic_fields/__init__.py index 38e47ff..c4f8df0 100644 --- a/drf_dynamic_fields/__init__.py +++ b/drf_dynamic_fields/__init__.py @@ -157,8 +157,8 @@ def get_source_path(serializer) -> str: def get_fields_for_level_and_prefix( fields_list, level, source, include_parent, default ): - """Filter a list of dotted field names down to those relevant at a given - nesting level and prefix. + """Extract the field names relevant to a specific nesting depth + from a list of double‑underscore lookup strings. """ if not fields_list: return default @@ -180,6 +180,9 @@ def get_fields_for_level_and_prefix( allowed.add(parts[level]) continue + + # If the only allowed fields are exactly the prefix itself, + # fall back to default if allowed == set(prefix): return default diff --git a/tests/test_mixins.py b/tests/test_mixins.py index c5609e4..b03e923 100644 --- a/tests/test_mixins.py +++ b/tests/test_mixins.py @@ -185,6 +185,7 @@ def test_serializer_reuse_with_changing_request(self): serializer.context["request"] = request2 self.assertEqual(set(serializer.fields.keys()), {"id"}) + class TestNestedDynamicFieldsMixin(TestCase): """ Test case for the NestedDynamicFieldsMixin @@ -261,7 +262,6 @@ def test_omit_nested_field(self): # Assert third level fields: self._assert_nested_fields(data["teachers"][0], third_level_expected_fields) - def test_omit_everything_nested_field(self): """Omitting all fields within a nested field""" rf = RequestFactory()