diff --git a/.gitignore b/.gitignore index 39ca09e..203fcd0 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ dist/ *.egg-info/ build/ .tox/ +.idea diff --git a/drf_dynamic_fields/__init__.py b/drf_dynamic_fields/__init__.py index 97694e4..e2a819e 100644 --- a/drf_dynamic_fields/__init__.py +++ b/drf_dynamic_fields/__init__.py @@ -1,8 +1,10 @@ """ Mixin to dynamically select only a subset of fields per DRF resource. """ + import warnings +from collections import defaultdict from django.conf import settings from django.utils.functional import cached_property @@ -20,7 +22,7 @@ def fields(self): A blank `fields` parameter (?fields) will remove all fields. Not passing `fields` will pass all fields individual fields are comma - separated (?fields=id,name,url,email). + separated (?fields=id,name,url,email,teachers__age). """ fields = super(DynamicFieldsMixin, self).fields @@ -58,7 +60,7 @@ def fields(self): try: filter_fields = params.get("fields", None).split(",") except AttributeError: - filter_fields = None + filter_fields = [] try: omit_fields = params.get("omit", None).split(",") @@ -66,15 +68,38 @@ def fields(self): omit_fields = [] # Drop any fields that are not specified in the `fields` argument. + self._nested_allow = defaultdict(list) + self._nested_omit = defaultdict(list) + self._flat_allow = set() + self._flat_omit = set() + + # store top-level and nested fields specified in the `fields` argument. + for filtered_field in filter_fields: + if "__" in filtered_field: + parent, child = filtered_field.split("__", 1) + self._nested_allow[parent].append(child) + # If a nested field is allowed the related parent level field + # must also be allowed + self._flat_allow.add(parent) + else: + self._flat_allow.add(filtered_field) + + # store top-level and nested fields in the `omit` argument. + for omitted_field in omit_fields: + if "__" in omitted_field: + parent, child = omitted_field.split("__", 1) + self._nested_omit[parent].append(child) + else: + self._flat_omit.add(omitted_field) + + # Drop top-level fields existing = set(fields.keys()) - if filter_fields is None: + if "fields" in params: + allowed = self._flat_allow + else: # no fields param given, don't filter. allowed = existing - else: - allowed = set(filter(None, filter_fields)) - - # omit fields in the `omit` argument. - omitted = set(filter(None, omit_fields)) + omitted = self._flat_omit for field in existing: @@ -84,4 +109,152 @@ def fields(self): if field in omitted: fields.pop(field, None) + # Drop omitted and non-allowed child fields from nested serializers + self.prune_nested_fields(fields) + return fields + + @staticmethod + def _get_nested_serializer(parent, fields): + """Return the nested serializer for a parent field, or None.""" + field = fields.get(parent) + if not field: + return None + nested = getattr(field, "child", field) + return nested if hasattr(nested, "fields") else None + + def prune_nested_fields(self, fields): + """Prune valid nested serializer fields based on the _nested_omit and _nested_allow lists.""" + valid_parents = ( + self._nested_omit.keys() | self._nested_allow.keys() + ) & fields.keys() + for parent in valid_parents: + nested_serializer = self._get_nested_serializer(parent, fields) + if nested_serializer is None: + continue + + # Drop omitted child fields from nested serializers + for name in self._nested_omit.get(parent, ()): + nested_serializer.fields.pop(name, None) + + # Drop non-allowed child fields from the nested serializers + allow_list = self._nested_allow.get(parent) + if allow_list: + nested_serializer.fields = { + name: field + for name, field in nested_serializer.fields.items() + if name in allow_list + } + + def _filter_top_level_fields_to_defer(self, field_list, keep_if): + """Method to retrieve the top-level fields to defer, given a list of + field names (allow or omit) and a condition that determines which + fields to defer. + """ + model = getattr(self.Meta, "model", None) + if not field_list or model is None: + return [] + + all_fields = [ + f.name for f in model._meta.get_fields() if getattr(f, "concrete", False) + ] + return [name for name in all_fields if keep_if(name, field_list)] + + def _get_disallowed_top_level_fields_to_defer(self): + """Determine which top-level model fields should be deferred when an + explicit fields filter is in use. + Other model fields not explicitly included in 'fields' are deferred. + """ + allow = getattr(self, "_flat_allow", None) + return self._filter_top_level_fields_to_defer( + allow, keep_if=lambda name, allow: name not in allow + ) + + def _get_omit_top_level_fields_to_defer(self): + """ + Determine which top-level model fields should be deferred when an + explicit omit filter is in use. Valid database fields in the omit list + will be deferred. + """ + omit = getattr(self, "_flat_omit", None) + return self._filter_top_level_fields_to_defer( + omit, keep_if=lambda name, omit: name in omit + ) + + def _get_nested_level_fields_to_defer(self, nested_mapping, should_defer): + """Method to retrieve nested fields to defer based on a mapping and a + defer condition. + """ + fields_to_defer = [] + for parent, items in nested_mapping.items(): + field = self.fields.get(parent) + if not field: + continue + + child_serializer = getattr(field, "child", field) + nested_model = getattr(child_serializer.Meta, "model", None) + if nested_model is None: + continue + + # Filter out nested fields that have a database column associated + # with them. + field_names = [ + f.name + for f in nested_model._meta.get_fields() + if getattr(f, "concrete", False) + ] + + # Determine which nested fields to defer + for name in field_names: + if should_defer(name, items): + fields_to_defer.append(f"{parent}__{name}") + + return fields_to_defer + + def _get_disallowed_nested_level_fields_to_defer(self): + """Determine which top-level model fields should be deferred when an + explicit fields filter is in use. + Other model fields not explicitly included in 'fields' are deferred. + """ + allow_map = getattr(self, "_nested_allow", {}) + return self._get_nested_level_fields_to_defer( + allow_map, + should_defer=lambda name, allow_list: name not in allow_list, + ) + + def _get_omit_nested_level_fields_to_defer(self): + """ + Determine which nested-model fields should be deferred for each nested + serializer when an explicit omit filter is in use. + """ + omit_map = getattr(self, "_nested_omit", {}) + return self._get_nested_level_fields_to_defer( + omit_map, should_defer=lambda name, omit_list: name in omit_list + ) + + def get_model_fields_to_defer(self): + """ + Returns a flat list of filtered model-fields; top-level and nested. + Ensures that parsing of 'fields'/'omit' has run by accessing '.fields'. + """ + + # Trigger parsing of required attributes if not already set + if not all( + hasattr(self, attr) + for attr in ( + "_flat_omit", + "_nested_omit", + "_flat_allow", + "_nested_allow", + ) + ): + _ = self.fields + + deferred = [ + *self._get_omit_top_level_fields_to_defer(), + *self._get_omit_nested_level_fields_to_defer(), + *self._get_disallowed_top_level_fields_to_defer(), + *self._get_disallowed_nested_level_fields_to_defer(), + ] + # Remove any duplicate fields + return list(set(deferred)) diff --git a/tests/models.py b/tests/models.py index 714098e..884697a 100644 --- a/tests/models.py +++ b/tests/models.py @@ -14,3 +14,12 @@ class School(models.Model): name = models.CharField(max_length=30) teachers = models.ManyToManyField(Teacher) + + +class Child(models.Model): + secret = models.CharField(max_length=100) + public = models.CharField(max_length=100) + + +class Parent(models.Model): + child = models.ForeignKey(Child, on_delete=models.CASCADE) \ No newline at end of file diff --git a/tests/serializers.py b/tests/serializers.py index 3619d59..766891b 100644 --- a/tests/serializers.py +++ b/tests/serializers.py @@ -5,7 +5,7 @@ from drf_dynamic_fields import DynamicFieldsMixin -from .models import Teacher, School +from .models import Teacher, School, Child class TeacherSerializer(DynamicFieldsMixin, serializers.ModelSerializer): @@ -40,3 +40,16 @@ class SchoolSerializer(DynamicFieldsMixin, serializers.ModelSerializer): class Meta: model = School fields = ("id", "teachers", "name") + + +class ChildSerializer(DynamicFieldsMixin, serializers.Serializer): + secret = serializers.CharField() + public = serializers.CharField() + + class Meta: + model = Child + + +class ParentSerializer(DynamicFieldsMixin, serializers.Serializer): + id = serializers.IntegerField() + child = ChildSerializer() \ No newline at end of file diff --git a/tests/test_mixins.py b/tests/test_mixins.py index f54e179..dfbce95 100644 --- a/tests/test_mixins.py +++ b/tests/test_mixins.py @@ -11,8 +11,8 @@ from django.test import TestCase, RequestFactory -from .serializers import SchoolSerializer, TeacherSerializer -from .models import Teacher, School +from .serializers import SchoolSerializer, TeacherSerializer, ParentSerializer +from .models import Teacher, School, Child, Parent class TestDynamicFieldsMixin(TestCase): @@ -20,6 +20,37 @@ class TestDynamicFieldsMixin(TestCase): Test case for the DynamicFieldsMixin """ + def _assert_nested_fields(self, data, expected_fields): + """ + Assert nested fields match the expected fields. + """ + for parent, nested_fields in expected_fields.items(): + with self.subTest(parent=parent): + items = data[parent] + if nested_fields is None: + continue + expected_set = set(nested_fields) + for obj in items: + with self.subTest(parent=parent): + actual_set = set(obj.keys()) + self.assertEqual( + actual_set, + expected_set, + f"{parent} fields mismatch: expected " + f"exactly {nested_fields}, got {list(obj.keys())}", + ) + + @staticmethod + def _prepare_school_instance(): + """Prepare school instance for testing.""" + school = School.objects.create(name="Python Heights High") + teachers = [ + Teacher.objects.create(name="Shane", age=45), + Teacher.objects.create(name="Kaz", age=29), + ] + school.teachers.add(*teachers) + return school + def test_removes_fields(self): """ Does it actually remove fields? @@ -175,3 +206,161 @@ def test_serializer_reuse_with_changing_request(self): request2 = rf.get("/api/v1/schools/1/?fields=id,name") serializer.context["request"] = request2 self.assertEqual(set(serializer.fields.keys()), {"id"}) + + def test_omit_nested_field(self): + """Omitting a nested field""" + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?omit=invalid,name,teachers__age,teachers__invalid") + + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + data = serializer.data + + # Confirm omitted fields are in deferred list + deferred = set(serializer.get_model_fields_to_defer()) + self.assertEqual({"name", "teachers__age"}, deferred) + + expected_fields = {"id": None, "teachers": ["id", "name", "request_info"]} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_omit_everything_nested_field(self): + """Omitting all fields within a nested field""" + rf = RequestFactory() + request = rf.get( + "/api/v1/schools/1/?omit=teachers__id,teachers__age,teachers__name,teachers__request_info" + ) + + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + data = serializer.data + + expected_fields = {"id": None, "name": None, "teachers": []} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_omit_top_field_and_keep_all_nested_fields(self): + """Omitting a top-level field while keeping all nested fields""" + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?omit=name") + + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + data = serializer.data + + expected_fields = { + "id": None, + "teachers": ["id", "name", "request_info", "age"], + } + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_allow_nested_field(self): + """Select only the requested fields, including nested-level fields.""" + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?fields=invalid,id,teachers__age,teachers__invalid") + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + + # Confirm omitted fields are in deferred list + deferred = set(serializer.get_model_fields_to_defer()) + self.assertEqual({"name","teachers__name", "teachers__id"}, deferred) + + data = serializer.data + expected_fields = {"id": None, "teachers": ["age"]} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_fields_all_gone_nested(self): + """If no fields are selected, all fields are omitted, including those + from the nested serializer. + """ + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?fields") + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + + data = serializer.data + expected_fields = {} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_nested_omit_and_fields_used(self): + """Omit and fields can be used together at the nested field level.""" + rf = RequestFactory() + request = rf.get( + "/api/v1/schools/1/?fields=id,name,teachers__name,teachers__age&omit=name,teachers__name" + ) + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + + data = serializer.data + expected_fields = {"id": None, "teachers": ["age"]} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_omit_nothing_nested(self): + """ + Blank omit doesn't affect nested fields. + """ + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?omit") + school = self._prepare_school_instance() + serializer = SchoolSerializer(school, context={"request": request}) + + data = serializer.data + expected_fields = { + "id": None, + "name": None, + "teachers": ["id", "age", "name", "request_info"], + } + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_single_nested_instance_omit_field(self): + """Omit also works for filtering fields on single nested instances""" + child = Child(secret="secret_key", public="public_key") + parent = Parent(id=1, child=child) + rf = RequestFactory() + request = rf.get("/api/v1/parent/1/?omit=id,child__secret") + serializer = ParentSerializer(parent, context={"request": request}) + data = serializer.data + + self.assertNotIn("id", data) + self.assertNotIn("secret", data["child"]) + self.assertEqual(data["child"]["public"], "public_key") + + def test_single_nested_instance_allow_field(self): + """Fields selection also works for filtering fields on single nested instances""" + child = Child(secret="secret_key", public="public_key") + parent = Parent(id=1, child=child) + rf = RequestFactory() + request = rf.get("/api/v1/parent/1/?fields=id,child__secret") + serializer = ParentSerializer(parent, context={"request": request}) + data = serializer.data + + self.assertEqual(data["id"], 1) + self.assertIn("secret", data["child"]) + self.assertEqual(data["child"]["secret"], "secret_key") + self.assertNotIn("public", data["child"]) \ No newline at end of file