From ded3a164453347c38460cdaaa846160812b87ddf Mon Sep 17 00:00:00 2001 From: Alberto Islas Date: Thu, 3 Jul 2025 17:46:52 -0600 Subject: [PATCH 1/4] feat(drf): Added support for filtering fields at nested levels - Also added support for deferring filtered fields --- .gitignore | 1 + drf_dynamic_fields/__init__.py | 224 +++++++++++++++++++++++++++++++-- setup.py | 41 +++--- tests/models.py | 10 ++ tests/serializers.py | 17 ++- tests/test_mixins.py | 194 +++++++++++++++++++++++++++- tests/test_requests.py | 77 ++++++++++++ tests/views.py | 19 +++ 8 files changed, 546 insertions(+), 37 deletions(-) diff --git a/.gitignore b/.gitignore index 39ca09e..82934a3 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ dist/ *.egg-info/ build/ .tox/ +.idea \ No newline at end of file diff --git a/drf_dynamic_fields/__init__.py b/drf_dynamic_fields/__init__.py index 97694e4..2e70468 100644 --- a/drf_dynamic_fields/__init__.py +++ b/drf_dynamic_fields/__init__.py @@ -1,9 +1,11 @@ """ Mixin to dynamically select only a subset of fields per DRF resource. """ + import warnings from django.conf import settings +from django.db.models import Prefetch from django.utils.functional import cached_property @@ -24,7 +26,6 @@ def fields(self): """ fields = super(DynamicFieldsMixin, self).fields - if not hasattr(self, "_context"): # We are being called before a request cycle return fields @@ -37,7 +38,6 @@ def fields(self): ) if not (is_root or parent_is_list_root): return fields - try: request = self.context["request"] except KeyError: @@ -58,30 +58,228 @@ def fields(self): try: filter_fields = params.get("fields", None).split(",") except AttributeError: - filter_fields = None + filter_fields = [] try: omit_fields = params.get("omit", None).split(",") except AttributeError: omit_fields = [] - # Drop any fields that are not specified in the `fields` argument. + # Save for deferred logic + self._flat_allow = set() + self._flat_omit = set() + self._nested_allow = {} + self._nested_omit = {} + + for filtered_field in filter_fields: + if "__" in filtered_field: + parent, child = filtered_field.split("__", 1) + self._nested_allow.setdefault(parent, []).append(child) + # If a nested field is allowed the related parent level field + # must also be allowed + self._flat_allow.add(parent) + else: + self._flat_allow.add(filtered_field) + + for omitted_field in omit_fields: + if "__" in omitted_field: + parent, child = omitted_field.split("__", 1) + self._nested_omit.setdefault(parent, []).append(child) + else: + self._flat_omit.add(omitted_field) + + # Drop top-level fields existing = set(fields.keys()) - if filter_fields is None: - # no fields param given, don't filter. - allowed = existing + if "fields" in params: + allowed = self._flat_allow else: - allowed = set(filter(None, filter_fields)) - - # omit fields in the `omit` argument. - omitted = set(filter(None, omit_fields)) + allowed = existing + omitted = self._flat_omit for field in existing: - if field not in allowed: fields.pop(field, None) - if field in omitted: fields.pop(field, None) return fields + + def to_representation(self, instance): + """Use this method to prune filtered fields from a nested serializer.""" + representation = super(DynamicFieldsMixin, self).to_representation(instance) + + # Apply nested omit on dicts and lists of dicts + for parent, omit_list in getattr(self, "_nested_omit", {}).items(): + if parent not in representation: + continue + parent_instance = representation[parent] + + # helper to drop keys on a single dict + def do_omit(d): + for child in omit_list: + d.pop(child, None) + + if isinstance(parent_instance, dict): + do_omit(parent_instance) + elif isinstance(parent_instance, list): + for item in parent_instance: + if isinstance(item, dict): + do_omit(item) + + # Apply nested allow on dicts and lists of dicts + for parent, allow_list in getattr(self, "_nested_allow", {}).items(): + if parent not in representation: + continue + + parent_instance = representation[parent] + + def do_allow(d): + return { + field_name: field_value + for field_name, field_value in d.items() + if field_name in allow_list + } + + if isinstance(parent_instance, dict): + representation[parent] = do_allow(parent_instance) + elif isinstance(parent_instance, list): + representation[parent] = [ + do_allow(item) if isinstance(item, dict) else item + for item in parent_instance + ] + + return representation + + def _flat_whitelist_deferred(self): + """ + Determine which top-level model fields should be deferred when an explicit + omit/fields filter is in use. + """ + allow = getattr(self, "_flat_allow", None) + model = getattr(self.Meta, "model", None) + if not allow or model is None: + return [] + + names = [ + fld.name + for fld in model._meta.get_fields() + if getattr(fld, "concrete", False) + ] + return [ + name + for name in names + if name not in allow and name not in getattr(self, "_flat_omit", []) + ] + + def _nested_whitelist_deferred(self): + """ + Determine which nested-model fields should be deferred for each parent serializer + when an explicit fields/omit filter is in use. + """ + results = [] + for parent, allow_list in getattr(self, "_nested_allow", {}).items(): + field = self.fields.get(parent) + if not field: + continue + + child_ser = getattr(field, "child", field) + nested_model = getattr(child_ser.Meta, "model", None) + if nested_model is None: + continue + + omitted = set(getattr(self, "_nested_omit", {}).get(parent, [])) + concrete_names = [ + fld.name + for fld in nested_model._meta.get_fields() + if getattr(fld, "concrete", False) + ] + for name in concrete_names: + if name not in allow_list and name not in omitted: + results.append(f"{parent}__{name}") + + return results + + def get_deferred_model_fields(self): + """ + Returns flat list of omitted model-fields; top-level and nested. + Ensures that parsing of "fields"/"omit" has run by accessing ".fields". + """ + + # Trigger parsing of _flat_omit and _nested_omit if not already set + if not hasattr(self, "_flat_omit") or not hasattr(self, "_nested_omit"): + _ = self.fields # trigger parsing + + flat_omit = getattr(self, "_flat_omit", []) + nested_omit = getattr(self, "_nested_omit", {}) + + deferred = [] + # top-level + deferred.extend(flat_omit) + # nested + deferred.extend( + f"{parent}__{child}" + for parent, children in nested_omit.items() + for child in children + ) + + deferred.extend(self._flat_whitelist_deferred()) + deferred.extend(self._nested_whitelist_deferred()) + + # dedupe and preserve order + return list(dict.fromkeys(deferred)) + + +class DeferredFieldsMixin: + """ViewSet Mixin that: + - defers top‐level model columns based on omit/fields + - builds a Prefetch for each nested relation to defer its columns too + """ + + @staticmethod + def _split_deferred_fields(fields): + """Split deferred fields into top‐level fields and nested relations.""" + parent, nested = [], {} + for field in fields: + if "__" in field: + rel, fld = field.split("__", 1) + nested.setdefault(rel, []).append(fld) + else: + parent.append(field) + return parent, nested + + @staticmethod + def _apply_nested_prefetch(qs, nested_map, serializer): + """For each nested relation, add a Prefetch that defers its specified + fields. + """ + for rel, child_fields in nested_map.items(): + field = serializer.fields.get(rel) + if not field: + continue + + child_ser = getattr(field, "child", field) + model = getattr(child_ser.Meta, "model", None) + if not model: + continue + + qs = qs.prefetch_related( + Prefetch(rel, queryset=model.objects.defer(*child_fields)) + ) + return qs + + def get_queryset(self): + qs = super().get_queryset() + # instantiate serializer so deferred fields are calculated + serializer = self.get_serializer_class()(context=self.get_serializer_context()) + + # split deferred fields into top-level and nested + fields = serializer.get_deferred_model_fields() + parent_fields, nested_map = self._split_deferred_fields(fields) + + # defer top-level fields + if parent_fields: + qs = qs.defer(*parent_fields) + + # defer nested fields via Prefetch + qs = self._apply_nested_prefetch(qs, nested_map, serializer) + return qs diff --git a/setup.py b/setup.py index 7fa9229..1a70244 100644 --- a/setup.py +++ b/setup.py @@ -1,24 +1,25 @@ from setuptools import setup -readme = open('README.rst').read() +readme = open("README.rst").read() -setup(name='drf_dynamic_fields', - version='0.4.0', - description='Dynamically return subset of Django REST Framework serializer fields', - author='Danilo Bargen', - author_email='mail@dbrgn.ch', - url='https://github.com/dbrgn/drf-dynamic-fields', - packages=['drf_dynamic_fields'], - zip_safe=True, - include_package_data=True, - license='MIT', - keywords='drf restframework rest_framework django_rest_framework serializers', - long_description=readme, - classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 3', - 'Framework :: Django', - 'Environment :: Web Environment', - ], +setup( + name="drf_dynamic_fields", + version="0.4.0", + description="Dynamically return subset of Django REST Framework serializer fields", + author="Danilo Bargen", + author_email="mail@dbrgn.ch", + url="https://github.com/dbrgn/drf-dynamic-fields", + packages=["drf_dynamic_fields"], + zip_safe=True, + include_package_data=True, + license="MIT", + keywords="drf restframework rest_framework django_rest_framework serializers", + long_description=readme, + classifiers=[ + "Development Status :: 5 - Production/Stable", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Framework :: Django", + "Environment :: Web Environment", + ], ) diff --git a/tests/models.py b/tests/models.py index 714098e..f2bdb33 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,6 +1,7 @@ """ Some models for the tests. We are modelling a school. """ + from django.db import models @@ -14,3 +15,12 @@ class School(models.Model): name = models.CharField(max_length=30) teachers = models.ManyToManyField(Teacher) + + +class Child(models.Model): + secret = models.CharField(max_length=100) + public = models.CharField(max_length=100) + + +class Parent(models.Model): + child = models.ForeignKey(Child, on_delete=models.CASCADE) diff --git a/tests/serializers.py b/tests/serializers.py index 3619d59..1e0b22d 100644 --- a/tests/serializers.py +++ b/tests/serializers.py @@ -1,9 +1,10 @@ """ For the tests. """ + from rest_framework import serializers -from drf_dynamic_fields import DynamicFieldsMixin +from drf_dynamic_fields import DynamicFieldsMixin, DeferredFieldsMixin from .models import Teacher, School @@ -29,7 +30,9 @@ def get_request_info(self, teacher): return request.build_absolute_uri("/api/v1/teacher/{}".format(teacher.pk)) -class SchoolSerializer(DynamicFieldsMixin, serializers.ModelSerializer): +class SchoolSerializer( + DynamicFieldsMixin, DeferredFieldsMixin, serializers.ModelSerializer +): """ Interesting enough serializer because the TeacherSerializer will use ListSerializer due to the `many=True` @@ -40,3 +43,13 @@ class SchoolSerializer(DynamicFieldsMixin, serializers.ModelSerializer): class Meta: model = School fields = ("id", "teachers", "name") + + +class ChildSerializer(DynamicFieldsMixin, serializers.Serializer): + secret = serializers.CharField() + public = serializers.CharField() + + +class ParentSerializer(DynamicFieldsMixin, serializers.Serializer): + id = serializers.IntegerField() + child = ChildSerializer() diff --git a/tests/test_mixins.py b/tests/test_mixins.py index f54e179..bf5a83c 100644 --- a/tests/test_mixins.py +++ b/tests/test_mixins.py @@ -11,8 +11,8 @@ from django.test import TestCase, RequestFactory -from .serializers import SchoolSerializer, TeacherSerializer -from .models import Teacher, School +from .serializers import SchoolSerializer, TeacherSerializer, ParentSerializer +from .models import Teacher, School, Child, Parent class TestDynamicFieldsMixin(TestCase): @@ -20,6 +20,37 @@ class TestDynamicFieldsMixin(TestCase): Test case for the DynamicFieldsMixin """ + def _assert_nested_fields(self, data, expected_fields): + """ + Assert nested fields match the expected fields. + """ + for parent, nested_fields in expected_fields.items(): + with self.subTest(parent=parent): + items = data[parent] + if nested_fields is None: + continue + expected_set = set(nested_fields) + for obj in items: + with self.subTest(parent=parent): + actual_set = set(obj.keys()) + self.assertEqual( + actual_set, + expected_set, + f"{parent} fields mismatch: expected " + f"exactly {nested_fields}, got {list(obj.keys())}", + ) + + @staticmethod + def _prepare_school_instance(): + """Prepare school instance for testing.""" + school = School.objects.create(name="Python Heights High") + teachers = [ + Teacher.objects.create(name="Shane", age=45), + Teacher.objects.create(name="Kaz", age=29), + ] + school.teachers.add(*teachers) + return school + def test_removes_fields(self): """ Does it actually remove fields? @@ -175,3 +206,162 @@ def test_serializer_reuse_with_changing_request(self): request2 = rf.get("/api/v1/schools/1/?fields=id,name") serializer.context["request"] = request2 self.assertEqual(set(serializer.fields.keys()), {"id"}) + + def test_omit_nested_field(self): + """Omitting a nested field""" + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?omit=name,teachers__age") + + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + data = serializer.data + + # Confirm omitted fields are in deferred list + deferred = serializer.get_deferred_model_fields() + self.assertIn("name", deferred) + self.assertIn("teachers__age", deferred) + + expected_fields = {"id": None, "teachers": ["id", "name", "request_info"]} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_omit_everything_nested_field(self): + """Omitting all fields within a nested field""" + rf = RequestFactory() + request = rf.get( + "/api/v1/schools/1/?omit=teachers__id,teachers__age,teachers__name,teachers__request_info" + ) + + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + data = serializer.data + + expected_fields = {"id": None, "name": None, "teachers": []} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_omit_top_field_and_keep_all_nested_fields(self): + """Omitting a top-level field while keeping all nested fields""" + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?omit=name") + + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + data = serializer.data + + expected_fields = { + "id": None, + "teachers": ["id", "name", "request_info", "age"], + } + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_allow_nested_field(self): + """Select only the requested fields, including nested-level fields.""" + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?fields=id,teachers__age") + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + + # Confirm omitted fields are in deferred list + deferred = serializer.get_deferred_model_fields() + self.assertIn("name", deferred) + + data = serializer.data + expected_fields = {"id": None, "teachers": ["age"]} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_fields_all_gone_nested(self): + """If no fields are selected, all fields are omitted, including those + from the nested serializer. + """ + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?fields") + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + + data = serializer.data + expected_fields = {} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_nested_omit_and_fields_used(self): + """Omit and fields can be used together at the nested field level.""" + rf = RequestFactory() + request = rf.get( + "/api/v1/schools/1/?fields=id,name,teachers__name,teachers__age&omit=name,teachers__name" + ) + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + + data = serializer.data + expected_fields = {"id": None, "teachers": ["age"]} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_omit_nothing_nested(self): + """ + Blank omit doesn't affect nested fields. + """ + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?omit") + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + + data = serializer.data + expected_fields = { + "id": None, + "name": None, + "teachers": ["id", "age", "name", "request_info"], + } + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_single_nested_instance_omit_field(self): + """Omit also works for filtering fields on single nested instances""" + child = Child(secret="secret_key", public="public_key") + parent = Parent(id=1, child=child) + rf = RequestFactory() + request = rf.get("/api/v1/parent/1/?omit=id,child__secret") + serializer = ParentSerializer(parent, context={"request": request}) + data = serializer.data + + self.assertNotIn("id", data) + self.assertNotIn("secret", data["child"]) + self.assertEqual(data["child"]["public"], "public_key") + + def test_single_nested_instance_allow_field(self): + """Fields selection also works for filtering fields on single nested instances""" + child = Child(secret="secret_key", public="public_key") + parent = Parent(id=1, child=child) + rf = RequestFactory() + request = rf.get("/api/v1/parent/1/?fields=id,child__secret") + serializer = ParentSerializer(parent, context={"request": request}) + data = serializer.data + + self.assertEqual(data["id"], 1) + self.assertIn("secret", data["child"]) + self.assertEqual(data["child"]["secret"], "secret_key") + self.assertNotIn("public", data["child"]) diff --git a/tests/test_requests.py b/tests/test_requests.py index 38e8e22..d4b2a72 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -8,13 +8,16 @@ Test for the full request cycle using dynamic fields mixns """ from collections import OrderedDict +from unittest.mock import MagicMock +from django.db.models import Prefetch from django.test import TestCase, RequestFactory from rest_framework.reverse import reverse from .serializers import SchoolSerializer, TeacherSerializer from .models import Teacher, School +from .views import SchoolDeferredViewSet class TestDynamicFieldsViews(TestCase): @@ -62,3 +65,77 @@ def test_nested_teacher_fields(self): self.assertEqual( school["teachers"][0].keys(), {"id", "request_info", "age", "name"} ) + + +class DeferredFieldsMixinTests(TestCase): + def setUp(self): + self.factory = RequestFactory() + self.view = SchoolDeferredViewSet.as_view({"get": "list"}) + + @staticmethod + def make_qs_mock(): + """ + Mock that records .defer() and .prefetch_related() calls for testing. + """ + qs_mock = MagicMock() + qs_mock._deferred_args = [] + qs_mock._prefetch_args = [] + + def mock_defer(*args): + qs_mock._deferred_args.extend(args) + return qs_mock + + qs_mock.defer.side_effect = mock_defer + + def mock_prefetch(*lookups): + for lookup in lookups: + if isinstance(lookup, Prefetch): + qs_mock._prefetch_args.append(lookup) + return qs_mock + + qs_mock.prefetch_related.side_effect = mock_prefetch + + return qs_mock + + def test_deffer_nested_omitted_fields(self): + """Nested level fields omitted should be deferred.""" + + qs_mock = self.make_qs_mock() + SchoolDeferredViewSet.queryset = qs_mock + + request = self.factory.get("/", {"omit": "name,teachers__age"}) + response = self.view(request) + self.assertEqual(response.status_code, 200) + + # top level name field should be deferred + self.assertEqual(qs_mock._deferred_args, ["name"]) + + # One Prefetch on teachers + self.assertEqual(response.data["prefetches"], ["teachers"]) + self.assertEqual(len(qs_mock._prefetch_args), 1) + + # Confirm the prefetch deferred field matches. + prefetch = qs_mock._prefetch_args[0] + deferred_fields, _ = prefetch.queryset.query.deferred_loading + self.assertIn("age", deferred_fields) + + def test_deffer_nested_no_selected_fields(self): + """No selected nested level fields should be deferred.""" + qs_mock = self.make_qs_mock() + SchoolDeferredViewSet.queryset = qs_mock + + request = self.factory.get("/", {"fields": "id,teachers__name"}) + response = self.view(request) + self.assertEqual(response.status_code, 200) + + # name field should be deferred + self.assertEqual(qs_mock._deferred_args, ["name"]) + + # One Prefetch for teachers + self.assertEqual(response.data["prefetches"], ["teachers"]) + self.assertEqual(len(qs_mock._prefetch_args), 1) + + # Confirm the prefetch deferred field matches the no selected fields. + prefetch = qs_mock._prefetch_args[0] + deferred_fields, _ = prefetch.queryset.query.deferred_loading + self.assertEqual(deferred_fields, {"age", "id"}) diff --git a/tests/views.py b/tests/views.py index 16de45e..22d7390 100644 --- a/tests/views.py +++ b/tests/views.py @@ -1,7 +1,10 @@ from rest_framework.viewsets import ModelViewSet +from rest_framework.response import Response +from rest_framework import status from .models import School, Teacher from .serializers import SchoolSerializer, TeacherSerializer +from drf_dynamic_fields import DeferredFieldsMixin class TeacherViewSet(ModelViewSet): @@ -12,3 +15,19 @@ class TeacherViewSet(ModelViewSet): class SchoolViewSet(ModelViewSet): queryset = School.objects.all() serializer_class = SchoolSerializer + + +class SchoolDeferredViewSet(DeferredFieldsMixin, ModelViewSet): + serializer_class = SchoolSerializer + + def list(self, request): + qs = self.get_queryset() + prefetch_names = [ + getattr(p, "lookup", None) or getattr(p, "prefetch_to", None) + for p in qs._prefetch_args + ] + + return Response( + {"deferred": qs._deferred_args, "prefetches": prefetch_names}, + status=status.HTTP_200_OK, + ) From fe661cb8a613c2d803fb81d61a231772e7509b58 Mon Sep 17 00:00:00 2001 From: Alberto Islas Date: Fri, 4 Jul 2025 15:30:33 -0600 Subject: [PATCH 2/4] fix(drf): Refined logic to defer top and nested fields --- drf_dynamic_fields/__init__.py | 111 ++++++++++++++++++--------------- tests/test_requests.py | 27 ++++++++ 2 files changed, 89 insertions(+), 49 deletions(-) diff --git a/drf_dynamic_fields/__init__.py b/drf_dynamic_fields/__init__.py index 2e70468..86851a9 100644 --- a/drf_dynamic_fields/__init__.py +++ b/drf_dynamic_fields/__init__.py @@ -22,9 +22,10 @@ def fields(self): A blank `fields` parameter (?fields) will remove all fields. Not passing `fields` will pass all fields individual fields are comma - separated (?fields=id,name,url,email). + separated (?fields=id,name,url,email,teachers__age). """ + fields = super(DynamicFieldsMixin, self).fields if not hasattr(self, "_context"): # We are being called before a request cycle @@ -38,6 +39,7 @@ def fields(self): ) if not (is_root or parent_is_list_root): return fields + try: request = self.context["request"] except KeyError: @@ -65,12 +67,12 @@ def fields(self): except AttributeError: omit_fields = [] - # Save for deferred logic self._flat_allow = set() self._flat_omit = set() self._nested_allow = {} self._nested_omit = {} + # store top-level and nested fields specified in the `fields` argument. for filtered_field in filter_fields: if "__" in filtered_field: parent, child = filtered_field.split("__", 1) @@ -81,6 +83,7 @@ def fields(self): else: self._flat_allow.add(filtered_field) + # store top-level and nested fields in the `omit` argument. for omitted_field in omit_fields: if "__" in omitted_field: parent, child = omitted_field.split("__", 1) @@ -105,7 +108,7 @@ def fields(self): return fields def to_representation(self, instance): - """Use this method to prune filtered fields from a nested serializer.""" + """This method prunes filtered fields from a nested serializer.""" representation = super(DynamicFieldsMixin, self).to_representation(instance) # Apply nested omit on dicts and lists of dicts @@ -114,8 +117,8 @@ def to_representation(self, instance): continue parent_instance = representation[parent] - # helper to drop keys on a single dict def do_omit(d): + """Helper to drop fields on a single dict""" for child in omit_list: d.pop(child, None) @@ -134,6 +137,7 @@ def do_omit(d): parent_instance = representation[parent] def do_allow(d): + """Helper to include fields allowed on a single dict""" return { field_name: field_value for field_name, field_value in d.items() @@ -150,83 +154,87 @@ def do_allow(d): return representation - def _flat_whitelist_deferred(self): + def _get_disallowed_top_level_fields_to_defer(self): """ Determine which top-level model fields should be deferred when an explicit - omit/fields filter is in use. + fields filter is in use. + Other model fields not explicitly included in 'fields' are deferred. """ allow = getattr(self, "_flat_allow", None) model = getattr(self.Meta, "model", None) if not allow or model is None: return [] - names = [ - fld.name - for fld in model._meta.get_fields() - if getattr(fld, "concrete", False) - ] - return [ - name - for name in names - if name not in allow and name not in getattr(self, "_flat_omit", []) + # Filter out fields that have a database column associated with them. + field_names = [ + field.name + for field in model._meta.get_fields() + if getattr(field, "concrete", False) ] + return [field_name for field_name in field_names if field_name not in allow] - def _nested_whitelist_deferred(self): + def _get_disallowed_nested_level_fields_to_defer(self): """ Determine which nested-model fields should be deferred for each parent serializer - when an explicit fields/omit filter is in use. + when an explicit fields filter is in use. + Other model nested fields not explicitly included in 'fields' are deferred. """ - results = [] + fields_to_defer = [] for parent, allow_list in getattr(self, "_nested_allow", {}).items(): field = self.fields.get(parent) if not field: continue - child_ser = getattr(field, "child", field) - nested_model = getattr(child_ser.Meta, "model", None) + child_serializer = getattr(field, "child", field) + nested_model = getattr(child_serializer.Meta, "model", None) if nested_model is None: continue - omitted = set(getattr(self, "_nested_omit", {}).get(parent, [])) - concrete_names = [ - fld.name - for fld in nested_model._meta.get_fields() - if getattr(fld, "concrete", False) + # Filter out nested fields that have a database column associated + # with them. + field_names = [ + field.name + for field in nested_model._meta.get_fields() + if getattr(field, "concrete", False) ] - for name in concrete_names: - if name not in allow_list and name not in omitted: - results.append(f"{parent}__{name}") + for field_name in field_names: + if field_name not in allow_list: + fields_to_defer.append(f"{parent}__{field_name}") - return results + return fields_to_defer def get_deferred_model_fields(self): """ - Returns flat list of omitted model-fields; top-level and nested. + Returns a flat list of omitted model-fields; top-level and nested. Ensures that parsing of "fields"/"omit" has run by accessing ".fields". """ - # Trigger parsing of _flat_omit and _nested_omit if not already set - if not hasattr(self, "_flat_omit") or not hasattr(self, "_nested_omit"): - _ = self.fields # trigger parsing + # Trigger parsing of required attributes if not already set + if not all( + hasattr(self, attr) + for attr in ("_flat_omit", "_nested_omit", "_flat_allow", "_nested_allow") + ): + _ = self.fields flat_omit = getattr(self, "_flat_omit", []) nested_omit = getattr(self, "_nested_omit", {}) - deferred = [] - # top-level + # Set omit top-level fields to defer deferred.extend(flat_omit) - # nested + + # Set omit nested-level fields to defer deferred.extend( f"{parent}__{child}" for parent, children in nested_omit.items() for child in children ) + # Set disallowed top-level fields to defer + deferred.extend(self._get_disallowed_top_level_fields_to_defer()) + # Set disallowed nested-level fields to defer + deferred.extend(self._get_disallowed_nested_level_fields_to_defer()) - deferred.extend(self._flat_whitelist_deferred()) - deferred.extend(self._nested_whitelist_deferred()) - - # dedupe and preserve order - return list(dict.fromkeys(deferred)) + # Remove any duplicate fields + return list(set(deferred)) class DeferredFieldsMixin: @@ -238,11 +246,12 @@ class DeferredFieldsMixin: @staticmethod def _split_deferred_fields(fields): """Split deferred fields into top‐level fields and nested relations.""" - parent, nested = [], {} + parent = [] + nested = {} for field in fields: if "__" in field: - rel, fld = field.split("__", 1) - nested.setdefault(rel, []).append(fld) + parent_field, child_field = field.split("__", 1) + nested.setdefault(parent_field, []).append(child_field) else: parent.append(field) return parent, nested @@ -252,22 +261,26 @@ def _apply_nested_prefetch(qs, nested_map, serializer): """For each nested relation, add a Prefetch that defers its specified fields. """ - for rel, child_fields in nested_map.items(): - field = serializer.fields.get(rel) + for parent_field, child_fields in nested_map.items(): + field = serializer.fields.get(parent_field) if not field: continue - child_ser = getattr(field, "child", field) - model = getattr(child_ser.Meta, "model", None) + child_serializer = getattr(field, "child", field) + model = getattr(child_serializer.Meta, "model", None) if not model: continue qs = qs.prefetch_related( - Prefetch(rel, queryset=model.objects.defer(*child_fields)) + Prefetch(parent_field, queryset=model.objects.defer(*child_fields)) ) return qs def get_queryset(self): + """ + Returns a queryset with top-level and nested fields deferred to + optimize database retrieval. + """ qs = super().get_queryset() # instantiate serializer so deferred fields are calculated serializer = self.get_serializer_class()(context=self.get_serializer_context()) diff --git a/tests/test_requests.py b/tests/test_requests.py index d4b2a72..d699d0b 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -139,3 +139,30 @@ def test_deffer_nested_no_selected_fields(self): prefetch = qs_mock._prefetch_args[0] deferred_fields, _ = prefetch.queryset.query.deferred_loading self.assertEqual(deferred_fields, {"age", "id"}) + + def test_deffer_nested_fields_combining_fields_and_omit(self): + """Omit and allowed fields used together are deferred.""" + qs_mock = self.make_qs_mock() + SchoolDeferredViewSet.queryset = qs_mock + + request = self.factory.get( + "/", + { + "fields": "id,name,teachers__name,teachers__age", + "omit": "name,teachers__age", + }, + ) + response = self.view(request) + self.assertEqual(response.status_code, 200) + + # name field should be deferred + self.assertEqual(qs_mock._deferred_args, ["name"]) + + # One Prefetch for teachers + self.assertEqual(response.data["prefetches"], ["teachers"]) + self.assertEqual(len(qs_mock._prefetch_args), 1) + + # Confirm the prefetch deferred field matches the no selected fields. + prefetch = qs_mock._prefetch_args[0] + deferred_fields, _ = prefetch.queryset.query.deferred_loading + self.assertEqual(deferred_fields, {"age", "id"}) From 1aaeaccaed62968f4fd5594fac29358926372243 Mon Sep 17 00:00:00 2001 From: Alberto Islas Date: Mon, 7 Jul 2025 13:16:18 -0600 Subject: [PATCH 3/4] fix(drf): Changed approach to drop child fields from the nested serializers --- drf_dynamic_fields/__init__.py | 60 +++++++++------------------------- 1 file changed, 15 insertions(+), 45 deletions(-) diff --git a/drf_dynamic_fields/__init__.py b/drf_dynamic_fields/__init__.py index 86851a9..bcaa416 100644 --- a/drf_dynamic_fields/__init__.py +++ b/drf_dynamic_fields/__init__.py @@ -105,54 +105,24 @@ def fields(self): if field in omitted: fields.pop(field, None) - return fields - - def to_representation(self, instance): - """This method prunes filtered fields from a nested serializer.""" - representation = super(DynamicFieldsMixin, self).to_representation(instance) - - # Apply nested omit on dicts and lists of dicts - for parent, omit_list in getattr(self, "_nested_omit", {}).items(): - if parent not in representation: - continue - parent_instance = representation[parent] - - def do_omit(d): - """Helper to drop fields on a single dict""" + # Drop omitted child fields from nested serializers + for parent, omit_list in self._nested_omit.items(): + field = fields[parent] + nested_serializer = getattr(field, "child", field) + if hasattr(nested_serializer, "fields"): for child in omit_list: - d.pop(child, None) + nested_serializer.fields.pop(child, None) - if isinstance(parent_instance, dict): - do_omit(parent_instance) - elif isinstance(parent_instance, list): - for item in parent_instance: - if isinstance(item, dict): - do_omit(item) + # Drop non-allowed child fields from the nested serializers + for parent, allow_list in self._nested_allow.items(): + field = fields[parent] + nested_serializer = getattr(field, "child", field) + if hasattr(nested_serializer, "fields"): + for child_name in list(nested_serializer.fields): + if child_name not in allow_list: + nested_serializer.fields.pop(child_name, None) - # Apply nested allow on dicts and lists of dicts - for parent, allow_list in getattr(self, "_nested_allow", {}).items(): - if parent not in representation: - continue - - parent_instance = representation[parent] - - def do_allow(d): - """Helper to include fields allowed on a single dict""" - return { - field_name: field_value - for field_name, field_value in d.items() - if field_name in allow_list - } - - if isinstance(parent_instance, dict): - representation[parent] = do_allow(parent_instance) - elif isinstance(parent_instance, list): - representation[parent] = [ - do_allow(item) if isinstance(item, dict) else item - for item in parent_instance - ] - - return representation + return fields def _get_disallowed_top_level_fields_to_defer(self): """ From 8d80a67d20b09e3b6fbb25217faf29d98e6906b0 Mon Sep 17 00:00:00 2001 From: Alberto Islas Date: Tue, 8 Jul 2025 15:57:19 -0600 Subject: [PATCH 4/4] fix(drf): Added support for deferring fields in viewsets with complex querysets. --- drf_dynamic_fields/__init__.py | 201 +++++++++++++++++++++------------ tests/models.py | 11 ++ tests/serializers.py | 22 +++- tests/test_requests.py | 97 ++++++++++------ tests/views.py | 21 +++- 5 files changed, 239 insertions(+), 113 deletions(-) diff --git a/drf_dynamic_fields/__init__.py b/drf_dynamic_fields/__init__.py index bcaa416..4623664 100644 --- a/drf_dynamic_fields/__init__.py +++ b/drf_dynamic_fields/__init__.py @@ -107,6 +107,8 @@ def fields(self): # Drop omitted child fields from nested serializers for parent, omit_list in self._nested_omit.items(): + if parent not in fields: + continue field = fields[parent] nested_serializer = getattr(field, "child", field) if hasattr(nested_serializer, "fields"): @@ -115,6 +117,8 @@ def fields(self): # Drop non-allowed child fields from the nested serializers for parent, allow_list in self._nested_allow.items(): + if parent not in fields or not allow_list: + continue field = fields[parent] nested_serializer = getattr(field, "child", field) if hasattr(nested_serializer, "fields"): @@ -124,33 +128,47 @@ def fields(self): return fields - def _get_disallowed_top_level_fields_to_defer(self): + def _filter_top_level_fields_to_defer(self, field_list, keep_if): + """Method to retrieve the top-level fields to defer, given a list of + field names (allow or omit) and a condition that determines which + fields to defer. """ - Determine which top-level model fields should be deferred when an explicit - fields filter is in use. - Other model fields not explicitly included in 'fields' are deferred. - """ - allow = getattr(self, "_flat_allow", None) model = getattr(self.Meta, "model", None) - if not allow or model is None: + if not field_list or model is None: return [] - # Filter out fields that have a database column associated with them. - field_names = [ - field.name - for field in model._meta.get_fields() - if getattr(field, "concrete", False) + all_fields = [ + f.name for f in model._meta.get_fields() if getattr(f, "concrete", False) ] - return [field_name for field_name in field_names if field_name not in allow] + return [name for name in all_fields if keep_if(name, field_list)] - def _get_disallowed_nested_level_fields_to_defer(self): + def _get_disallowed_top_level_fields_to_defer(self): + """Determine which top-level model fields should be deferred when an + explicit fields filter is in use. + Other model fields not explicitly included in 'fields' are deferred. + """ + allow = getattr(self, "_flat_allow", None) + return self._filter_top_level_fields_to_defer( + allow, keep_if=lambda name, allow: name not in allow + ) + + def _get_omit_top_level_fields_to_defer(self): """ - Determine which nested-model fields should be deferred for each parent serializer - when an explicit fields filter is in use. - Other model nested fields not explicitly included in 'fields' are deferred. + Determine which top-level model fields should be deferred when an + explicit omit filter is in use. Valid database fields in the omit list + will be deferred. + """ + omit = getattr(self, "_flat_omit", None) + return self._filter_top_level_fields_to_defer( + omit, keep_if=lambda name, omit: name in omit + ) + + def _get_nested_level_fields_to_defer(self, nested_mapping, should_defer): + """Method to retrieve nested fields to defer based on a mapping and a + defer condition. """ fields_to_defer = [] - for parent, allow_list in getattr(self, "_nested_allow", {}).items(): + for parent, items in nested_mapping.items(): field = self.fields.get(parent) if not field: continue @@ -163,16 +181,39 @@ def _get_disallowed_nested_level_fields_to_defer(self): # Filter out nested fields that have a database column associated # with them. field_names = [ - field.name - for field in nested_model._meta.get_fields() - if getattr(field, "concrete", False) + f.name + for f in nested_model._meta.get_fields() + if getattr(f, "concrete", False) ] - for field_name in field_names: - if field_name not in allow_list: - fields_to_defer.append(f"{parent}__{field_name}") + + # Determine which nested fields to defer + for name in field_names: + if should_defer(name, items): + fields_to_defer.append(f"{parent}__{name}") return fields_to_defer + def _get_disallowed_nested_level_fields_to_defer(self): + """Determine which top-level model fields should be deferred when an + explicit fields filter is in use. + Other model fields not explicitly included in 'fields' are deferred. + """ + allow_map = getattr(self, "_nested_allow", {}) + return self._get_nested_level_fields_to_defer( + allow_map, + should_defer=lambda name, allow_list: name not in allow_list, + ) + + def _get_omit_nested_level_fields_to_defer(self): + """ + Determine which nested-model fields should be deferred for each nested + serializer when an explicit omit filter is in use. + """ + omit_map = getattr(self, "_nested_omit", {}) + return self._get_nested_level_fields_to_defer( + omit_map, should_defer=lambda name, omit_list: name in omit_list + ) + def get_deferred_model_fields(self): """ Returns a flat list of omitted model-fields; top-level and nested. @@ -182,27 +223,21 @@ def get_deferred_model_fields(self): # Trigger parsing of required attributes if not already set if not all( hasattr(self, attr) - for attr in ("_flat_omit", "_nested_omit", "_flat_allow", "_nested_allow") + for attr in ( + "_flat_omit", + "_nested_omit", + "_flat_allow", + "_nested_allow", + ) ): _ = self.fields - flat_omit = getattr(self, "_flat_omit", []) - nested_omit = getattr(self, "_nested_omit", {}) - deferred = [] - # Set omit top-level fields to defer - deferred.extend(flat_omit) - - # Set omit nested-level fields to defer - deferred.extend( - f"{parent}__{child}" - for parent, children in nested_omit.items() - for child in children - ) - # Set disallowed top-level fields to defer - deferred.extend(self._get_disallowed_top_level_fields_to_defer()) - # Set disallowed nested-level fields to defer - deferred.extend(self._get_disallowed_nested_level_fields_to_defer()) - + deferred = [ + *self._get_omit_top_level_fields_to_defer(), + *self._get_omit_nested_level_fields_to_defer(), + *self._get_disallowed_top_level_fields_to_defer(), + *self._get_disallowed_nested_level_fields_to_defer(), + ] # Remove any duplicate fields return list(set(deferred)) @@ -210,7 +245,9 @@ def get_deferred_model_fields(self): class DeferredFieldsMixin: """ViewSet Mixin that: - defers top‐level model columns based on omit/fields - - builds a Prefetch for each nested relation to defer its columns too + - omit deferring select related fields + - builds a Prefetch for each nested relation to defer its columns, + merging cleanly with any existing prefetches in the right order. """ @staticmethod @@ -226,43 +263,59 @@ def _split_deferred_fields(fields): parent.append(field) return parent, nested - @staticmethod - def _apply_nested_prefetch(qs, nested_map, serializer): - """For each nested relation, add a Prefetch that defers its specified - fields. - """ + def get_queryset(self): + qs = super().get_queryset() + # Skip for queries that uses values() or annotate() + if qs.query.values_select or qs.query.annotations: + return qs + + existing_select_related = set(qs.query.select_related or ()) + original_only_or_deferred_fields, _ = qs.query.deferred_loading + serializer = self.get_serializer_class()(context=self.get_serializer_context()) + all_fields = serializer.get_deferred_model_fields() + parent_fields, nested_map = self._split_deferred_fields(all_fields) + + # Remove select_related fields from top level deferred fields + parent_defer_to_keep = set(parent_fields) - existing_select_related + + # Only defer top-level fields if neither only() nor defer() was used in + # the original queryset. + if not original_only_or_deferred_fields: + qs = qs.defer(*parent_defer_to_keep) + + # Prepare to rebuild prefetch_related in correct order + nested_prefetches = [] + existing_simple_lookups = list(qs._prefetch_related_lookups) for parent_field, child_fields in nested_map.items(): field = serializer.fields.get(parent_field) if not field: continue child_serializer = getattr(field, "child", field) - model = getattr(child_serializer.Meta, "model", None) - if not model: + child_model = getattr(child_serializer.Meta, "model", None) + parent_model = serializer.Meta.model + if not child_model or not parent_model: continue - qs = qs.prefetch_related( - Prefetch(parent_field, queryset=model.objects.defer(*child_fields)) - ) - return qs - - def get_queryset(self): - """ - Returns a queryset with top-level and nested fields deferred to - optimize database retrieval. - """ - qs = super().get_queryset() - # instantiate serializer so deferred fields are calculated - serializer = self.get_serializer_class()(context=self.get_serializer_context()) - - # split deferred fields into top-level and nested - fields = serializer.get_deferred_model_fields() - parent_fields, nested_map = self._split_deferred_fields(fields) - - # defer top-level fields - if parent_fields: - qs = qs.defer(*parent_fields) + # Look for FKs in the child serializer linked to the parent model. + # f.many_to_one is True for ForeignKey fields + # and f.remote_field.model is the parent model. + fk_names = { + f.name + for f in child_model._meta.get_fields() + if getattr(f, "many_to_one", False) + and getattr(f.remote_field, "model", None) == parent_model + } + child_fields = [f for f in child_fields if f not in fk_names] + if child_fields: + nested_prefetches.append( + Prefetch( + parent_field, + queryset=child_model.objects.defer(*child_fields), + ) + ) - # defer nested fields via Prefetch - qs = self._apply_nested_prefetch(qs, nested_map, serializer) - return qs + new_qs = qs._clone() + # Apply prefetches in the correct order to prevent conflicts. + new_qs._prefetch_related_lookups = nested_prefetches + existing_simple_lookups + return new_qs diff --git a/tests/models.py b/tests/models.py index f2bdb33..627b32a 100644 --- a/tests/models.py +++ b/tests/models.py @@ -24,3 +24,14 @@ class Child(models.Model): class Parent(models.Model): child = models.ForeignKey(Child, on_delete=models.CASCADE) + + +class GrantParent(models.Model): + name = models.CharField(max_length=30) + + +class ParentMany(models.Model): + name = models.CharField(max_length=30) + age = models.IntegerField() + grant_parent = models.ForeignKey(GrantParent, on_delete=models.CASCADE) + child = models.ManyToManyField(Child) \ No newline at end of file diff --git a/tests/serializers.py b/tests/serializers.py index 1e0b22d..3f0de92 100644 --- a/tests/serializers.py +++ b/tests/serializers.py @@ -4,9 +4,9 @@ from rest_framework import serializers -from drf_dynamic_fields import DynamicFieldsMixin, DeferredFieldsMixin +from drf_dynamic_fields import DynamicFieldsMixin -from .models import Teacher, School +from .models import Teacher, School, ParentMany, Child class TeacherSerializer(DynamicFieldsMixin, serializers.ModelSerializer): @@ -31,7 +31,7 @@ def get_request_info(self, teacher): class SchoolSerializer( - DynamicFieldsMixin, DeferredFieldsMixin, serializers.ModelSerializer + DynamicFieldsMixin, serializers.ModelSerializer ): """ Interesting enough serializer because the TeacherSerializer @@ -49,7 +49,23 @@ class ChildSerializer(DynamicFieldsMixin, serializers.Serializer): secret = serializers.CharField() public = serializers.CharField() + class Meta: + model = Child + class ParentSerializer(DynamicFieldsMixin, serializers.Serializer): id = serializers.IntegerField() child = ChildSerializer() + +class GrantParentSerializer(serializers.Serializer): + name = serializers.CharField() + +class ParentManySerializer( + DynamicFieldsMixin, serializers.ModelSerializer +): + grant_parent = GrantParentSerializer(read_only=True) + child = ChildSerializer(many=True, read_only=True) + + class Meta: + model = ParentMany + fields = ("id", "name", "age","grant_parent", "child") \ No newline at end of file diff --git a/tests/test_requests.py b/tests/test_requests.py index d699d0b..4ebce01 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -7,17 +7,16 @@ Test for the full request cycle using dynamic fields mixns """ -from collections import OrderedDict from unittest.mock import MagicMock -from django.db.models import Prefetch +import types from django.test import TestCase, RequestFactory +from django.db.models import Prefetch from rest_framework.reverse import reverse -from .serializers import SchoolSerializer, TeacherSerializer from .models import Teacher, School -from .views import SchoolDeferredViewSet +from .views import SchoolDeferredViewSet, ParentManyDeferredViewSet class TestDynamicFieldsViews(TestCase): @@ -71,31 +70,31 @@ class DeferredFieldsMixinTests(TestCase): def setUp(self): self.factory = RequestFactory() self.view = SchoolDeferredViewSet.as_view({"get": "list"}) + self.view_parent = ParentManyDeferredViewSet.as_view({"get": "list"}) @staticmethod def make_qs_mock(): - """ - Mock that records .defer() and .prefetch_related() calls for testing. - """ - qs_mock = MagicMock() - qs_mock._deferred_args = [] - qs_mock._prefetch_args = [] - - def mock_defer(*args): - qs_mock._deferred_args.extend(args) - return qs_mock + qs = MagicMock() + + qs.query = types.SimpleNamespace( + values_select=False, + annotations={}, + select_related=[], + deferred_loading=([], []), + deferred_fields=set(), + ) - qs_mock.defer.side_effect = mock_defer + qs._deferred_args = [] + def _defer(*fields): + qs._deferred_args.extend(fields) + return qs - def mock_prefetch(*lookups): - for lookup in lookups: - if isinstance(lookup, Prefetch): - qs_mock._prefetch_args.append(lookup) - return qs_mock + qs.defer.side_effect = _defer - qs_mock.prefetch_related.side_effect = mock_prefetch + qs._prefetch_related_lookups = [] + qs._clone.return_value = qs - return qs_mock + return qs def test_deffer_nested_omitted_fields(self): """Nested level fields omitted should be deferred.""" @@ -111,12 +110,15 @@ def test_deffer_nested_omitted_fields(self): self.assertEqual(qs_mock._deferred_args, ["name"]) # One Prefetch on teachers - self.assertEqual(response.data["prefetches"], ["teachers"]) - self.assertEqual(len(qs_mock._prefetch_args), 1) + prefetches = [ + p for p in qs_mock._prefetch_related_lookups + if isinstance(p, Prefetch) + ] + self.assertEqual(len(prefetches), 1) + self.assertEqual(prefetches[0].prefetch_to, "teachers") # Confirm the prefetch deferred field matches. - prefetch = qs_mock._prefetch_args[0] - deferred_fields, _ = prefetch.queryset.query.deferred_loading + deferred_fields, _ = prefetches[0].queryset.query.deferred_loading self.assertIn("age", deferred_fields) def test_deffer_nested_no_selected_fields(self): @@ -132,12 +134,15 @@ def test_deffer_nested_no_selected_fields(self): self.assertEqual(qs_mock._deferred_args, ["name"]) # One Prefetch for teachers - self.assertEqual(response.data["prefetches"], ["teachers"]) - self.assertEqual(len(qs_mock._prefetch_args), 1) + prefetches = [ + p for p in qs_mock._prefetch_related_lookups + if isinstance(p, Prefetch) + ] + self.assertEqual(len(prefetches), 1) + self.assertEqual(prefetches[0].prefetch_to, "teachers") # Confirm the prefetch deferred field matches the no selected fields. - prefetch = qs_mock._prefetch_args[0] - deferred_fields, _ = prefetch.queryset.query.deferred_loading + deferred_fields, _ = prefetches[0].queryset.query.deferred_loading self.assertEqual(deferred_fields, {"age", "id"}) def test_deffer_nested_fields_combining_fields_and_omit(self): @@ -159,10 +164,34 @@ def test_deffer_nested_fields_combining_fields_and_omit(self): self.assertEqual(qs_mock._deferred_args, ["name"]) # One Prefetch for teachers - self.assertEqual(response.data["prefetches"], ["teachers"]) - self.assertEqual(len(qs_mock._prefetch_args), 1) + prefetches = [ + p for p in qs_mock._prefetch_related_lookups + if isinstance(p, Prefetch) + ] + self.assertEqual(len(prefetches), 1) + self.assertEqual(prefetches[0].prefetch_to, "teachers") # Confirm the prefetch deferred field matches the no selected fields. - prefetch = qs_mock._prefetch_args[0] - deferred_fields, _ = prefetch.queryset.query.deferred_loading + deferred_fields, _ = prefetches[0].queryset.query.deferred_loading self.assertEqual(deferred_fields, {"age", "id"}) + + def test_deffer_fields_custom_queryset(self): + """ Confirms that the deferring fields logic works correctly with a custom + queryset that uses select_related and prefetch_related. + """ + request = self.factory.get( + "/", + { + "fields": "name,age,child__secret,child__public", + "omit": "age,child__public", + }, + ) + response = self.view_parent(request) + self.assertEqual(response.status_code, 200) + + deferred_fields = response.data["deferred"] + self.assertEqual(deferred_fields, {"age", "id"}) + + prefetches = response.data["prefetches"] + self.assertEqual(len(prefetches), 2) + self.assertIn("child", prefetches) diff --git a/tests/views.py b/tests/views.py index 22d7390..86fd5ea 100644 --- a/tests/views.py +++ b/tests/views.py @@ -2,8 +2,8 @@ from rest_framework.response import Response from rest_framework import status -from .models import School, Teacher -from .serializers import SchoolSerializer, TeacherSerializer +from .models import School, Teacher, ParentMany +from .serializers import SchoolSerializer, TeacherSerializer, ParentManySerializer from drf_dynamic_fields import DeferredFieldsMixin @@ -19,6 +19,7 @@ class SchoolViewSet(ModelViewSet): class SchoolDeferredViewSet(DeferredFieldsMixin, ModelViewSet): serializer_class = SchoolSerializer + queryset = School.objects.all() def list(self, request): qs = self.get_queryset() @@ -31,3 +32,19 @@ def list(self, request): {"deferred": qs._deferred_args, "prefetches": prefetch_names}, status=status.HTTP_200_OK, ) + +class ParentManyDeferredViewSet(DeferredFieldsMixin, ModelViewSet): + serializer_class = ParentManySerializer + queryset = ParentMany.objects.select_related( + "grant_parent", + ).prefetch_related( + "child", + ).order_by("-id") + + def list(self, request): + qs = self.get_queryset() + deferred_fields, _ = qs.query.deferred_loading + return Response( + {"deferred": deferred_fields, "prefetches": qs._prefetch_related_lookups}, + status=status.HTTP_200_OK, + ) \ No newline at end of file