diff --git a/.gitignore b/.gitignore index 39ca09e..82934a3 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ dist/ *.egg-info/ build/ .tox/ +.idea \ No newline at end of file diff --git a/drf_dynamic_fields/__init__.py b/drf_dynamic_fields/__init__.py index 97694e4..4623664 100644 --- a/drf_dynamic_fields/__init__.py +++ b/drf_dynamic_fields/__init__.py @@ -1,9 +1,11 @@ """ Mixin to dynamically select only a subset of fields per DRF resource. """ + import warnings from django.conf import settings +from django.db.models import Prefetch from django.utils.functional import cached_property @@ -20,11 +22,11 @@ def fields(self): A blank `fields` parameter (?fields) will remove all fields. Not passing `fields` will pass all fields individual fields are comma - separated (?fields=id,name,url,email). + separated (?fields=id,name,url,email,teachers__age). """ - fields = super(DynamicFieldsMixin, self).fields + fields = super(DynamicFieldsMixin, self).fields if not hasattr(self, "_context"): # We are being called before a request cycle return fields @@ -58,30 +60,262 @@ def fields(self): try: filter_fields = params.get("fields", None).split(",") except AttributeError: - filter_fields = None + filter_fields = [] try: omit_fields = params.get("omit", None).split(",") except AttributeError: omit_fields = [] - # Drop any fields that are not specified in the `fields` argument. + self._flat_allow = set() + self._flat_omit = set() + self._nested_allow = {} + self._nested_omit = {} + + # store top-level and nested fields specified in the `fields` argument. + for filtered_field in filter_fields: + if "__" in filtered_field: + parent, child = filtered_field.split("__", 1) + self._nested_allow.setdefault(parent, []).append(child) + # If a nested field is allowed the related parent level field + # must also be allowed + self._flat_allow.add(parent) + else: + self._flat_allow.add(filtered_field) + + # store top-level and nested fields in the `omit` argument. + for omitted_field in omit_fields: + if "__" in omitted_field: + parent, child = omitted_field.split("__", 1) + self._nested_omit.setdefault(parent, []).append(child) + else: + self._flat_omit.add(omitted_field) + + # Drop top-level fields existing = set(fields.keys()) - if filter_fields is None: - # no fields param given, don't filter. - allowed = existing + if "fields" in params: + allowed = self._flat_allow else: - allowed = set(filter(None, filter_fields)) - - # omit fields in the `omit` argument. - omitted = set(filter(None, omit_fields)) + allowed = existing + omitted = self._flat_omit for field in existing: - if field not in allowed: fields.pop(field, None) - if field in omitted: fields.pop(field, None) + # Drop omitted child fields from nested serializers + for parent, omit_list in self._nested_omit.items(): + if parent not in fields: + continue + field = fields[parent] + nested_serializer = getattr(field, "child", field) + if hasattr(nested_serializer, "fields"): + for child in omit_list: + nested_serializer.fields.pop(child, None) + + # Drop non-allowed child fields from the nested serializers + for parent, allow_list in self._nested_allow.items(): + if parent not in fields or not allow_list: + continue + field = fields[parent] + nested_serializer = getattr(field, "child", field) + if hasattr(nested_serializer, "fields"): + for child_name in list(nested_serializer.fields): + if child_name not in allow_list: + nested_serializer.fields.pop(child_name, None) + return fields + + def _filter_top_level_fields_to_defer(self, field_list, keep_if): + """Method to retrieve the top-level fields to defer, given a list of + field names (allow or omit) and a condition that determines which + fields to defer. + """ + model = getattr(self.Meta, "model", None) + if not field_list or model is None: + return [] + + all_fields = [ + f.name for f in model._meta.get_fields() if getattr(f, "concrete", False) + ] + return [name for name in all_fields if keep_if(name, field_list)] + + def _get_disallowed_top_level_fields_to_defer(self): + """Determine which top-level model fields should be deferred when an + explicit fields filter is in use. + Other model fields not explicitly included in 'fields' are deferred. + """ + allow = getattr(self, "_flat_allow", None) + return self._filter_top_level_fields_to_defer( + allow, keep_if=lambda name, allow: name not in allow + ) + + def _get_omit_top_level_fields_to_defer(self): + """ + Determine which top-level model fields should be deferred when an + explicit omit filter is in use. Valid database fields in the omit list + will be deferred. + """ + omit = getattr(self, "_flat_omit", None) + return self._filter_top_level_fields_to_defer( + omit, keep_if=lambda name, omit: name in omit + ) + + def _get_nested_level_fields_to_defer(self, nested_mapping, should_defer): + """Method to retrieve nested fields to defer based on a mapping and a + defer condition. + """ + fields_to_defer = [] + for parent, items in nested_mapping.items(): + field = self.fields.get(parent) + if not field: + continue + + child_serializer = getattr(field, "child", field) + nested_model = getattr(child_serializer.Meta, "model", None) + if nested_model is None: + continue + + # Filter out nested fields that have a database column associated + # with them. + field_names = [ + f.name + for f in nested_model._meta.get_fields() + if getattr(f, "concrete", False) + ] + + # Determine which nested fields to defer + for name in field_names: + if should_defer(name, items): + fields_to_defer.append(f"{parent}__{name}") + + return fields_to_defer + + def _get_disallowed_nested_level_fields_to_defer(self): + """Determine which top-level model fields should be deferred when an + explicit fields filter is in use. + Other model fields not explicitly included in 'fields' are deferred. + """ + allow_map = getattr(self, "_nested_allow", {}) + return self._get_nested_level_fields_to_defer( + allow_map, + should_defer=lambda name, allow_list: name not in allow_list, + ) + + def _get_omit_nested_level_fields_to_defer(self): + """ + Determine which nested-model fields should be deferred for each nested + serializer when an explicit omit filter is in use. + """ + omit_map = getattr(self, "_nested_omit", {}) + return self._get_nested_level_fields_to_defer( + omit_map, should_defer=lambda name, omit_list: name in omit_list + ) + + def get_deferred_model_fields(self): + """ + Returns a flat list of omitted model-fields; top-level and nested. + Ensures that parsing of "fields"/"omit" has run by accessing ".fields". + """ + + # Trigger parsing of required attributes if not already set + if not all( + hasattr(self, attr) + for attr in ( + "_flat_omit", + "_nested_omit", + "_flat_allow", + "_nested_allow", + ) + ): + _ = self.fields + + deferred = [ + *self._get_omit_top_level_fields_to_defer(), + *self._get_omit_nested_level_fields_to_defer(), + *self._get_disallowed_top_level_fields_to_defer(), + *self._get_disallowed_nested_level_fields_to_defer(), + ] + # Remove any duplicate fields + return list(set(deferred)) + + +class DeferredFieldsMixin: + """ViewSet Mixin that: + - defers top‐level model columns based on omit/fields + - omit deferring select related fields + - builds a Prefetch for each nested relation to defer its columns, + merging cleanly with any existing prefetches in the right order. + """ + + @staticmethod + def _split_deferred_fields(fields): + """Split deferred fields into top‐level fields and nested relations.""" + parent = [] + nested = {} + for field in fields: + if "__" in field: + parent_field, child_field = field.split("__", 1) + nested.setdefault(parent_field, []).append(child_field) + else: + parent.append(field) + return parent, nested + + def get_queryset(self): + qs = super().get_queryset() + # Skip for queries that uses values() or annotate() + if qs.query.values_select or qs.query.annotations: + return qs + + existing_select_related = set(qs.query.select_related or ()) + original_only_or_deferred_fields, _ = qs.query.deferred_loading + serializer = self.get_serializer_class()(context=self.get_serializer_context()) + all_fields = serializer.get_deferred_model_fields() + parent_fields, nested_map = self._split_deferred_fields(all_fields) + + # Remove select_related fields from top level deferred fields + parent_defer_to_keep = set(parent_fields) - existing_select_related + + # Only defer top-level fields if neither only() nor defer() was used in + # the original queryset. + if not original_only_or_deferred_fields: + qs = qs.defer(*parent_defer_to_keep) + + # Prepare to rebuild prefetch_related in correct order + nested_prefetches = [] + existing_simple_lookups = list(qs._prefetch_related_lookups) + for parent_field, child_fields in nested_map.items(): + field = serializer.fields.get(parent_field) + if not field: + continue + + child_serializer = getattr(field, "child", field) + child_model = getattr(child_serializer.Meta, "model", None) + parent_model = serializer.Meta.model + if not child_model or not parent_model: + continue + + # Look for FKs in the child serializer linked to the parent model. + # f.many_to_one is True for ForeignKey fields + # and f.remote_field.model is the parent model. + fk_names = { + f.name + for f in child_model._meta.get_fields() + if getattr(f, "many_to_one", False) + and getattr(f.remote_field, "model", None) == parent_model + } + child_fields = [f for f in child_fields if f not in fk_names] + if child_fields: + nested_prefetches.append( + Prefetch( + parent_field, + queryset=child_model.objects.defer(*child_fields), + ) + ) + + new_qs = qs._clone() + # Apply prefetches in the correct order to prevent conflicts. + new_qs._prefetch_related_lookups = nested_prefetches + existing_simple_lookups + return new_qs diff --git a/setup.py b/setup.py index 7fa9229..1a70244 100644 --- a/setup.py +++ b/setup.py @@ -1,24 +1,25 @@ from setuptools import setup -readme = open('README.rst').read() +readme = open("README.rst").read() -setup(name='drf_dynamic_fields', - version='0.4.0', - description='Dynamically return subset of Django REST Framework serializer fields', - author='Danilo Bargen', - author_email='mail@dbrgn.ch', - url='https://github.com/dbrgn/drf-dynamic-fields', - packages=['drf_dynamic_fields'], - zip_safe=True, - include_package_data=True, - license='MIT', - keywords='drf restframework rest_framework django_rest_framework serializers', - long_description=readme, - classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 3', - 'Framework :: Django', - 'Environment :: Web Environment', - ], +setup( + name="drf_dynamic_fields", + version="0.4.0", + description="Dynamically return subset of Django REST Framework serializer fields", + author="Danilo Bargen", + author_email="mail@dbrgn.ch", + url="https://github.com/dbrgn/drf-dynamic-fields", + packages=["drf_dynamic_fields"], + zip_safe=True, + include_package_data=True, + license="MIT", + keywords="drf restframework rest_framework django_rest_framework serializers", + long_description=readme, + classifiers=[ + "Development Status :: 5 - Production/Stable", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Framework :: Django", + "Environment :: Web Environment", + ], ) diff --git a/tests/models.py b/tests/models.py index 714098e..627b32a 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,6 +1,7 @@ """ Some models for the tests. We are modelling a school. """ + from django.db import models @@ -14,3 +15,23 @@ class School(models.Model): name = models.CharField(max_length=30) teachers = models.ManyToManyField(Teacher) + + +class Child(models.Model): + secret = models.CharField(max_length=100) + public = models.CharField(max_length=100) + + +class Parent(models.Model): + child = models.ForeignKey(Child, on_delete=models.CASCADE) + + +class GrantParent(models.Model): + name = models.CharField(max_length=30) + + +class ParentMany(models.Model): + name = models.CharField(max_length=30) + age = models.IntegerField() + grant_parent = models.ForeignKey(GrantParent, on_delete=models.CASCADE) + child = models.ManyToManyField(Child) \ No newline at end of file diff --git a/tests/serializers.py b/tests/serializers.py index 3619d59..3f0de92 100644 --- a/tests/serializers.py +++ b/tests/serializers.py @@ -1,11 +1,12 @@ """ For the tests. """ + from rest_framework import serializers from drf_dynamic_fields import DynamicFieldsMixin -from .models import Teacher, School +from .models import Teacher, School, ParentMany, Child class TeacherSerializer(DynamicFieldsMixin, serializers.ModelSerializer): @@ -29,7 +30,9 @@ def get_request_info(self, teacher): return request.build_absolute_uri("/api/v1/teacher/{}".format(teacher.pk)) -class SchoolSerializer(DynamicFieldsMixin, serializers.ModelSerializer): +class SchoolSerializer( + DynamicFieldsMixin, serializers.ModelSerializer +): """ Interesting enough serializer because the TeacherSerializer will use ListSerializer due to the `many=True` @@ -40,3 +43,29 @@ class SchoolSerializer(DynamicFieldsMixin, serializers.ModelSerializer): class Meta: model = School fields = ("id", "teachers", "name") + + +class ChildSerializer(DynamicFieldsMixin, serializers.Serializer): + secret = serializers.CharField() + public = serializers.CharField() + + class Meta: + model = Child + + +class ParentSerializer(DynamicFieldsMixin, serializers.Serializer): + id = serializers.IntegerField() + child = ChildSerializer() + +class GrantParentSerializer(serializers.Serializer): + name = serializers.CharField() + +class ParentManySerializer( + DynamicFieldsMixin, serializers.ModelSerializer +): + grant_parent = GrantParentSerializer(read_only=True) + child = ChildSerializer(many=True, read_only=True) + + class Meta: + model = ParentMany + fields = ("id", "name", "age","grant_parent", "child") \ No newline at end of file diff --git a/tests/test_mixins.py b/tests/test_mixins.py index f54e179..bf5a83c 100644 --- a/tests/test_mixins.py +++ b/tests/test_mixins.py @@ -11,8 +11,8 @@ from django.test import TestCase, RequestFactory -from .serializers import SchoolSerializer, TeacherSerializer -from .models import Teacher, School +from .serializers import SchoolSerializer, TeacherSerializer, ParentSerializer +from .models import Teacher, School, Child, Parent class TestDynamicFieldsMixin(TestCase): @@ -20,6 +20,37 @@ class TestDynamicFieldsMixin(TestCase): Test case for the DynamicFieldsMixin """ + def _assert_nested_fields(self, data, expected_fields): + """ + Assert nested fields match the expected fields. + """ + for parent, nested_fields in expected_fields.items(): + with self.subTest(parent=parent): + items = data[parent] + if nested_fields is None: + continue + expected_set = set(nested_fields) + for obj in items: + with self.subTest(parent=parent): + actual_set = set(obj.keys()) + self.assertEqual( + actual_set, + expected_set, + f"{parent} fields mismatch: expected " + f"exactly {nested_fields}, got {list(obj.keys())}", + ) + + @staticmethod + def _prepare_school_instance(): + """Prepare school instance for testing.""" + school = School.objects.create(name="Python Heights High") + teachers = [ + Teacher.objects.create(name="Shane", age=45), + Teacher.objects.create(name="Kaz", age=29), + ] + school.teachers.add(*teachers) + return school + def test_removes_fields(self): """ Does it actually remove fields? @@ -175,3 +206,162 @@ def test_serializer_reuse_with_changing_request(self): request2 = rf.get("/api/v1/schools/1/?fields=id,name") serializer.context["request"] = request2 self.assertEqual(set(serializer.fields.keys()), {"id"}) + + def test_omit_nested_field(self): + """Omitting a nested field""" + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?omit=name,teachers__age") + + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + data = serializer.data + + # Confirm omitted fields are in deferred list + deferred = serializer.get_deferred_model_fields() + self.assertIn("name", deferred) + self.assertIn("teachers__age", deferred) + + expected_fields = {"id": None, "teachers": ["id", "name", "request_info"]} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_omit_everything_nested_field(self): + """Omitting all fields within a nested field""" + rf = RequestFactory() + request = rf.get( + "/api/v1/schools/1/?omit=teachers__id,teachers__age,teachers__name,teachers__request_info" + ) + + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + data = serializer.data + + expected_fields = {"id": None, "name": None, "teachers": []} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_omit_top_field_and_keep_all_nested_fields(self): + """Omitting a top-level field while keeping all nested fields""" + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?omit=name") + + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + data = serializer.data + + expected_fields = { + "id": None, + "teachers": ["id", "name", "request_info", "age"], + } + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_allow_nested_field(self): + """Select only the requested fields, including nested-level fields.""" + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?fields=id,teachers__age") + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + + # Confirm omitted fields are in deferred list + deferred = serializer.get_deferred_model_fields() + self.assertIn("name", deferred) + + data = serializer.data + expected_fields = {"id": None, "teachers": ["age"]} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_fields_all_gone_nested(self): + """If no fields are selected, all fields are omitted, including those + from the nested serializer. + """ + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?fields") + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + + data = serializer.data + expected_fields = {} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_nested_omit_and_fields_used(self): + """Omit and fields can be used together at the nested field level.""" + rf = RequestFactory() + request = rf.get( + "/api/v1/schools/1/?fields=id,name,teachers__name,teachers__age&omit=name,teachers__name" + ) + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + + data = serializer.data + expected_fields = {"id": None, "teachers": ["age"]} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_omit_nothing_nested(self): + """ + Blank omit doesn't affect nested fields. + """ + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?omit") + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + + data = serializer.data + expected_fields = { + "id": None, + "name": None, + "teachers": ["id", "age", "name", "request_info"], + } + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_single_nested_instance_omit_field(self): + """Omit also works for filtering fields on single nested instances""" + child = Child(secret="secret_key", public="public_key") + parent = Parent(id=1, child=child) + rf = RequestFactory() + request = rf.get("/api/v1/parent/1/?omit=id,child__secret") + serializer = ParentSerializer(parent, context={"request": request}) + data = serializer.data + + self.assertNotIn("id", data) + self.assertNotIn("secret", data["child"]) + self.assertEqual(data["child"]["public"], "public_key") + + def test_single_nested_instance_allow_field(self): + """Fields selection also works for filtering fields on single nested instances""" + child = Child(secret="secret_key", public="public_key") + parent = Parent(id=1, child=child) + rf = RequestFactory() + request = rf.get("/api/v1/parent/1/?fields=id,child__secret") + serializer = ParentSerializer(parent, context={"request": request}) + data = serializer.data + + self.assertEqual(data["id"], 1) + self.assertIn("secret", data["child"]) + self.assertEqual(data["child"]["secret"], "secret_key") + self.assertNotIn("public", data["child"]) diff --git a/tests/test_requests.py b/tests/test_requests.py index 38e8e22..4ebce01 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -7,14 +7,16 @@ Test for the full request cycle using dynamic fields mixns """ -from collections import OrderedDict +from unittest.mock import MagicMock +import types from django.test import TestCase, RequestFactory +from django.db.models import Prefetch from rest_framework.reverse import reverse -from .serializers import SchoolSerializer, TeacherSerializer from .models import Teacher, School +from .views import SchoolDeferredViewSet, ParentManyDeferredViewSet class TestDynamicFieldsViews(TestCase): @@ -62,3 +64,134 @@ def test_nested_teacher_fields(self): self.assertEqual( school["teachers"][0].keys(), {"id", "request_info", "age", "name"} ) + + +class DeferredFieldsMixinTests(TestCase): + def setUp(self): + self.factory = RequestFactory() + self.view = SchoolDeferredViewSet.as_view({"get": "list"}) + self.view_parent = ParentManyDeferredViewSet.as_view({"get": "list"}) + + @staticmethod + def make_qs_mock(): + qs = MagicMock() + + qs.query = types.SimpleNamespace( + values_select=False, + annotations={}, + select_related=[], + deferred_loading=([], []), + deferred_fields=set(), + ) + + qs._deferred_args = [] + def _defer(*fields): + qs._deferred_args.extend(fields) + return qs + + qs.defer.side_effect = _defer + + qs._prefetch_related_lookups = [] + qs._clone.return_value = qs + + return qs + + def test_deffer_nested_omitted_fields(self): + """Nested level fields omitted should be deferred.""" + + qs_mock = self.make_qs_mock() + SchoolDeferredViewSet.queryset = qs_mock + + request = self.factory.get("/", {"omit": "name,teachers__age"}) + response = self.view(request) + self.assertEqual(response.status_code, 200) + + # top level name field should be deferred + self.assertEqual(qs_mock._deferred_args, ["name"]) + + # One Prefetch on teachers + prefetches = [ + p for p in qs_mock._prefetch_related_lookups + if isinstance(p, Prefetch) + ] + self.assertEqual(len(prefetches), 1) + self.assertEqual(prefetches[0].prefetch_to, "teachers") + + # Confirm the prefetch deferred field matches. + deferred_fields, _ = prefetches[0].queryset.query.deferred_loading + self.assertIn("age", deferred_fields) + + def test_deffer_nested_no_selected_fields(self): + """No selected nested level fields should be deferred.""" + qs_mock = self.make_qs_mock() + SchoolDeferredViewSet.queryset = qs_mock + + request = self.factory.get("/", {"fields": "id,teachers__name"}) + response = self.view(request) + self.assertEqual(response.status_code, 200) + + # name field should be deferred + self.assertEqual(qs_mock._deferred_args, ["name"]) + + # One Prefetch for teachers + prefetches = [ + p for p in qs_mock._prefetch_related_lookups + if isinstance(p, Prefetch) + ] + self.assertEqual(len(prefetches), 1) + self.assertEqual(prefetches[0].prefetch_to, "teachers") + + # Confirm the prefetch deferred field matches the no selected fields. + deferred_fields, _ = prefetches[0].queryset.query.deferred_loading + self.assertEqual(deferred_fields, {"age", "id"}) + + def test_deffer_nested_fields_combining_fields_and_omit(self): + """Omit and allowed fields used together are deferred.""" + qs_mock = self.make_qs_mock() + SchoolDeferredViewSet.queryset = qs_mock + + request = self.factory.get( + "/", + { + "fields": "id,name,teachers__name,teachers__age", + "omit": "name,teachers__age", + }, + ) + response = self.view(request) + self.assertEqual(response.status_code, 200) + + # name field should be deferred + self.assertEqual(qs_mock._deferred_args, ["name"]) + + # One Prefetch for teachers + prefetches = [ + p for p in qs_mock._prefetch_related_lookups + if isinstance(p, Prefetch) + ] + self.assertEqual(len(prefetches), 1) + self.assertEqual(prefetches[0].prefetch_to, "teachers") + + # Confirm the prefetch deferred field matches the no selected fields. + deferred_fields, _ = prefetches[0].queryset.query.deferred_loading + self.assertEqual(deferred_fields, {"age", "id"}) + + def test_deffer_fields_custom_queryset(self): + """ Confirms that the deferring fields logic works correctly with a custom + queryset that uses select_related and prefetch_related. + """ + request = self.factory.get( + "/", + { + "fields": "name,age,child__secret,child__public", + "omit": "age,child__public", + }, + ) + response = self.view_parent(request) + self.assertEqual(response.status_code, 200) + + deferred_fields = response.data["deferred"] + self.assertEqual(deferred_fields, {"age", "id"}) + + prefetches = response.data["prefetches"] + self.assertEqual(len(prefetches), 2) + self.assertIn("child", prefetches) diff --git a/tests/views.py b/tests/views.py index 16de45e..86fd5ea 100644 --- a/tests/views.py +++ b/tests/views.py @@ -1,7 +1,10 @@ from rest_framework.viewsets import ModelViewSet +from rest_framework.response import Response +from rest_framework import status -from .models import School, Teacher -from .serializers import SchoolSerializer, TeacherSerializer +from .models import School, Teacher, ParentMany +from .serializers import SchoolSerializer, TeacherSerializer, ParentManySerializer +from drf_dynamic_fields import DeferredFieldsMixin class TeacherViewSet(ModelViewSet): @@ -12,3 +15,36 @@ class TeacherViewSet(ModelViewSet): class SchoolViewSet(ModelViewSet): queryset = School.objects.all() serializer_class = SchoolSerializer + + +class SchoolDeferredViewSet(DeferredFieldsMixin, ModelViewSet): + serializer_class = SchoolSerializer + queryset = School.objects.all() + + def list(self, request): + qs = self.get_queryset() + prefetch_names = [ + getattr(p, "lookup", None) or getattr(p, "prefetch_to", None) + for p in qs._prefetch_args + ] + + return Response( + {"deferred": qs._deferred_args, "prefetches": prefetch_names}, + status=status.HTTP_200_OK, + ) + +class ParentManyDeferredViewSet(DeferredFieldsMixin, ModelViewSet): + serializer_class = ParentManySerializer + queryset = ParentMany.objects.select_related( + "grant_parent", + ).prefetch_related( + "child", + ).order_by("-id") + + def list(self, request): + qs = self.get_queryset() + deferred_fields, _ = qs.query.deferred_loading + return Response( + {"deferred": deferred_fields, "prefetches": qs._prefetch_related_lookups}, + status=status.HTTP_200_OK, + ) \ No newline at end of file