diff --git a/.gitignore b/.gitignore index 39ca09e..203fcd0 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ dist/ *.egg-info/ build/ .tox/ +.idea diff --git a/drf_dynamic_fields/__init__.py b/drf_dynamic_fields/__init__.py index 97694e4..c4f8df0 100644 --- a/drf_dynamic_fields/__init__.py +++ b/drf_dynamic_fields/__init__.py @@ -1,11 +1,14 @@ """ Mixin to dynamically select only a subset of fields per DRF resource. """ + import warnings from django.conf import settings from django.utils.functional import cached_property +from rest_framework import serializers + class DynamicFieldsMixin(object): """ @@ -13,6 +16,14 @@ class DynamicFieldsMixin(object): which fields should be displayed. """ + @property + def prevent_nested_processing(self) -> bool: + """True when this serializer is not the root nor a root’s list-child.""" + return not ( + self is self.root + or (self.parent is self.root and getattr(self.parent, "many", False)) + ) + @cached_property def fields(self): """ @@ -29,13 +40,7 @@ def fields(self): # We are being called before a request cycle return fields - # Only filter if this is the root serializer, or if the parent is the - # root serializer with many=True - is_root = self.root == self - parent_is_list_root = self.parent == self.root and getattr( - self.parent, "many", False - ) - if not (is_root or parent_is_list_root): + if self.prevent_nested_processing: return fields try: @@ -55,15 +60,13 @@ def fields(self): if params is None: warnings.warn("Request object does not contain query parameters") - try: - filter_fields = params.get("fields", None).split(",") - except AttributeError: - filter_fields = None + source = get_source_path(self) + level = compute_level(self) - try: - omit_fields = params.get("omit", None).split(",") - except AttributeError: - omit_fields = [] + filter_fields = self.get_filter_fields( + params.get("fields", None), level, source + ) + omit_fields = self.get_omit_fields(params.get("omit", None), level, source) # Drop any fields that are not specified in the `fields` argument. existing = set(fields.keys()) @@ -85,3 +88,118 @@ def fields(self): fields.pop(field, None) return fields + + def get_filter_fields( + self, params, level, source, default=None, include_parent=True + ): + try: + return params.split(",") + except AttributeError: + return default + + def get_omit_fields(self, params, level, source): + return self.get_filter_fields( + params, level, source, default=[], include_parent=False + ) + + +class NestedDynamicFieldsMixin(DynamicFieldsMixin): + """A serializer mixin that extends DynamicFieldsMixin to allow nested serializers + to filter their fields based on the original `fields` query parameter. + + Unlike the base mixin—which only applies filtering at the root serializer, + this subclass: + + - Disables the `prevent_nested_processing` guard, allowing each level of nested + serializer to apply field filtering independently. + - Overrides `get_filter_fields` to slice the raw `fields` string + down to exactly those names relevant at this serializer’s + current nesting depth (using get_fields_for_level_and_prefix). + - `get_filter_fields` first delegates to the super method for splitting + the comma‐separated string, then calls a helper that: + • Selects only the fields that are nested under this serializer's path in + the hierarchy + • Returns direct children at depth `level + 1` + """ + + @property + def prevent_nested_processing(self): + return False + + def get_filter_fields( + self, params, level, source, default=None, include_parent=True + ): + """ + Parse the raw `fields` parameter and return the subset of fields + that apply at this serializer’s nesting level under the given + source prefix. + """ + fields = super().get_filter_fields( + params, level, source, default, include_parent + ) + return get_fields_for_level_and_prefix( + fields, level, source, default=default, include_parent=include_parent + ) + + +def get_source_path(serializer) -> str: + """Recursively walks up the serializer tree to build the nested field path.""" + parent = getattr(serializer, "parent", None) + if not parent: + return "" + parent_path = get_source_path(parent) + name = getattr(serializer, "field_name", None) + if not name: + return parent_path + return f"{parent_path}__{name}" if parent_path else name + + +def get_fields_for_level_and_prefix( + fields_list, level, source, include_parent, default +): + """Extract the field names relevant to a specific nesting depth + from a list of double‑underscore lookup strings. + """ + if not fields_list: + return default + + prefix = source.split("__") if source else [] + allowed = set() + for f in fields_list: + parts = f.split("__") + + if parts[:level] != prefix: + continue + + if len(parts) <= level + 1: + allowed.add(parts[-1]) + continue + + if len(parts) > level + 1 and include_parent: + # include parent field to ensure nesting proceeds + allowed.add(parts[level]) + continue + + + # If the only allowed fields are exactly the prefix itself, + # fall back to default + if allowed == set(prefix): + return default + + return allowed + + +def compute_level(serializer) -> int: + """Recursively count how many ancestors of `serializer` are not + ListSerializer instances. Stops when parent is None. + """ + parent = getattr(serializer, "parent", None) + if parent is None: + # base case, reached the top + return 0 + + # if this immediate parent is a ListSerializer, don’t count it, otherwise 1 + this_level = 0 if isinstance(parent, serializers.ListSerializer) else 1 + + # recurse on the parent itself + return this_level + compute_level(parent) diff --git a/tests/models.py b/tests/models.py index 714098e..a6d1403 100644 --- a/tests/models.py +++ b/tests/models.py @@ -4,13 +4,26 @@ from django.db import models -class Teacher(models.Model): +class Student(models.Model): name = models.CharField(max_length=30) age = models.IntegerField() +class Teacher(models.Model): + name = models.CharField(max_length=30) + age = models.IntegerField() + students = models.ManyToManyField(Student) class School(models.Model): """Schools just have teachers, no students.""" name = models.CharField(max_length=30) teachers = models.ManyToManyField(Teacher) + + +class Child(models.Model): + secret = models.CharField(max_length=100) + public = models.CharField(max_length=100) + + +class Parent(models.Model): + child = models.ForeignKey(Child, on_delete=models.CASCADE) \ No newline at end of file diff --git a/tests/serializers.py b/tests/serializers.py index 3619d59..13f9c17 100644 --- a/tests/serializers.py +++ b/tests/serializers.py @@ -3,22 +3,50 @@ """ from rest_framework import serializers -from drf_dynamic_fields import DynamicFieldsMixin +from drf_dynamic_fields import DynamicFieldsMixin, NestedDynamicFieldsMixin -from .models import Teacher, School +from .models import Teacher, School, Child, Student -class TeacherSerializer(DynamicFieldsMixin, serializers.ModelSerializer): +class BaseTeacherSerializer(serializers.ModelSerializer): + + request_info = serializers.SerializerMethodField() + + class Meta: + model = Teacher + fields = ("id", "request_info", "age", "name") + + def get_request_info(self, teacher): + """ + a meaningless method that attempts + to access the request object. + """ + request = self.context["request"] + return request.build_absolute_uri("/api/v1/teacher/{}".format(teacher.pk)) + + +class TeacherSerializer(DynamicFieldsMixin, BaseTeacherSerializer): + pass + + +class NestableStudentSerializer(NestedDynamicFieldsMixin, serializers.ModelSerializer): + + class Meta: + model = Student + fields = ("id", "name", "age") + + +class NestableTeacherSerializer(NestedDynamicFieldsMixin, BaseTeacherSerializer): """ The request_info field is to highlight the issue accessing request during a nested serializer. """ - request_info = serializers.SerializerMethodField() + students = NestableStudentSerializer(many=True, read_only=True) class Meta: model = Teacher - fields = ("id", "request_info", "age", "name") + fields = ("id", "request_info", "age", "name", "students") def get_request_info(self, teacher): """ @@ -29,14 +57,33 @@ def get_request_info(self, teacher): return request.build_absolute_uri("/api/v1/teacher/{}".format(teacher.pk)) -class SchoolSerializer(DynamicFieldsMixin, serializers.ModelSerializer): +class BaseSchoolSerializer(serializers.ModelSerializer): + + + class Meta: + model = School + fields = ("id", "teachers", "name") + + +class SchoolSerializer(DynamicFieldsMixin, BaseSchoolSerializer): + teachers = TeacherSerializer(many=True, read_only=True) + + +class NestableSchoolSerializer(NestedDynamicFieldsMixin, BaseSchoolSerializer): """ Interesting enough serializer because the TeacherSerializer will use ListSerializer due to the `many=True` """ + teachers = NestableTeacherSerializer(many=True, read_only=True) - teachers = TeacherSerializer(many=True, read_only=True) +class ChildSerializer(NestedDynamicFieldsMixin, serializers.Serializer): + secret = serializers.CharField() + public = serializers.CharField() class Meta: - model = School - fields = ("id", "teachers", "name") + model = Child + + +class ParentSerializer(NestedDynamicFieldsMixin, serializers.Serializer): + id = serializers.IntegerField() + child = ChildSerializer() diff --git a/tests/test_mixins.py b/tests/test_mixins.py index f54e179..b03e923 100644 --- a/tests/test_mixins.py +++ b/tests/test_mixins.py @@ -11,8 +11,14 @@ from django.test import TestCase, RequestFactory -from .serializers import SchoolSerializer, TeacherSerializer -from .models import Teacher, School +from .serializers import ( + NestableSchoolSerializer, + NestableTeacherSerializer, + SchoolSerializer, + TeacherSerializer, + ParentSerializer, +) +from .models import Teacher, School, Child, Parent, Student class TestDynamicFieldsMixin(TestCase): @@ -20,13 +26,16 @@ class TestDynamicFieldsMixin(TestCase): Test case for the DynamicFieldsMixin """ + SchoolSerializer = SchoolSerializer + TeacherSerializer = TeacherSerializer + def test_removes_fields(self): """ Does it actually remove fields? """ rf = RequestFactory() request = rf.get("/api/v1/schools/1/?fields=id") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual(set(serializer.fields.keys()), set(("id",))) @@ -36,7 +45,7 @@ def test_fields_left_alone(self): """ rf = RequestFactory() request = rf.get("/api/v1/schools/1/") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual( set(serializer.fields.keys()), set(("id", "request_info", "age", "name")) @@ -48,7 +57,7 @@ def test_fields_all_gone(self): """ rf = RequestFactory() request = rf.get("/api/v1/schools/1/?fields") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual(set(serializer.fields.keys()), set()) @@ -60,7 +69,7 @@ def test_ordinary_serializer(self): request = rf.get("/api/v1/schools/1/?fields=id,age") teacher = Teacher.objects.create(name="Susan", age=34) - serializer = TeacherSerializer(teacher, context={"request": request}) + serializer = self.TeacherSerializer(teacher, context={"request": request}) self.assertEqual(serializer.data, {"id": teacher.id, "age": teacher.age}) @@ -70,7 +79,7 @@ def test_omit(self): """ rf = RequestFactory() request = rf.get("/api/v1/schools/1/?omit=request_info") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual(set(serializer.fields.keys()), set(("id", "name", "age"))) @@ -80,7 +89,7 @@ def test_omit_and_fields_used(self): """ rf = RequestFactory() request = rf.get("/api/v1/schools/1/?fields=id,request_info&omit=request_info") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual(set(serializer.fields.keys()), set(("id",))) @@ -90,7 +99,7 @@ def test_omit_everything(self): """ rf = RequestFactory() request = rf.get("/api/v1/schools/1/?omit=id,request_info,age,name") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual(set(serializer.fields.keys()), set()) @@ -100,7 +109,7 @@ def test_omit_nothing(self): """ rf = RequestFactory() request = rf.get("/api/v1/schools/1/?omit") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual( set(serializer.fields.keys()), set(("id", "request_info", "name", "age")) @@ -109,7 +118,7 @@ def test_omit_nothing(self): def test_omit_non_existant_field(self): rf = RequestFactory() request = rf.get("/api/v1/schools/1/?omit=pretend") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual( set(serializer.fields.keys()), set(("id", "request_info", "name", "age")) @@ -129,7 +138,7 @@ def test_as_nested_serializer(self): ] school.teachers.add(*teachers) - serializer = SchoolSerializer(school, context={"request": request}) + serializer = self.SchoolSerializer(school, context={"request": request}) request_info = "http://testserver/api/v1/teacher/{}" @@ -168,10 +177,270 @@ def test_serializer_reuse_with_changing_request(self): rf = RequestFactory() request = rf.get("/api/v1/schools/1/?fields=id") - serializer = TeacherSerializer(context={"request": request}) + serializer = self.TeacherSerializer(context={"request": request}) self.assertEqual(set(serializer.fields.keys()), {"id"}) # now change the request on this instantiated serializer. request2 = rf.get("/api/v1/schools/1/?fields=id,name") serializer.context["request"] = request2 self.assertEqual(set(serializer.fields.keys()), {"id"}) + + +class TestNestedDynamicFieldsMixin(TestCase): + """ + Test case for the NestedDynamicFieldsMixin + """ + SchoolSerializer = NestableSchoolSerializer + TeacherSerializer = NestableTeacherSerializer + + def _assert_nested_fields(self, data, expected_fields): + """ + Assert nested fields match the expected fields. + """ + for parent, nested_fields in expected_fields.items(): + with self.subTest(parent=parent): + items = data[parent] + if nested_fields is None: + continue + expected_set = set(nested_fields) + for obj in items: + with self.subTest(parent=parent): + actual_set = set(obj.keys()) + self.assertEqual( + actual_set, + expected_set, + f"{parent} fields mismatch: expected " + f"exactly {nested_fields}, got {list(obj.keys())}", + ) + + @staticmethod + def _prepare_school_instance(): + """Prepare school instance for testing.""" + school = School.objects.create(name="Python Heights High") + teachers = [ + Teacher.objects.create(name="Shane", age=45), + Teacher.objects.create(name="Kaz", age=29), + ] + school.teachers.add(*teachers) + student = Student.objects.create(name="Shannon", age=23) + teachers[0].students.add(student) + + school_2 = School.objects.create(name="Python Heights High") + teachers = [ + Teacher.objects.create(name="Shane 2", age=46), + Teacher.objects.create(name="Kaz 2", age=30), + ] + school_2.teachers.add(*teachers) + + return [school, school_2] + + def test_omit_nested_field(self): + """Omitting a nested field""" + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?omit=invalid,name,teachers__age,teachers__invalid,teachers__students__age") + + # Single nested instance. + schools = self._prepare_school_instance() + serializer = self.SchoolSerializer(schools[0], context={"request": request}) + data = serializer.data + expected_fields = {"id": None, "teachers": ["id", "name", "request_info", "students"]} + third_level_expected_fields = {"id":None, "name":None, "request_info":None, "students":["id", "name"]} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + # Assert third level fields: + self._assert_nested_fields(data["teachers"][0], third_level_expected_fields) + + # Multiple nested instances: + serializer = self.SchoolSerializer(schools, many=True, context={"request": request}) + data = serializer.data[0] + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + # Assert third level fields: + self._assert_nested_fields(data["teachers"][0], third_level_expected_fields) + + def test_omit_everything_nested_field(self): + """Omitting all fields within a nested field""" + rf = RequestFactory() + request = rf.get( + "/api/v1/schools/1/?omit=teachers__id,teachers__age,teachers__name,teachers__request_info,teachers__students" + ) + # Single nested instance. + schools = self._prepare_school_instance() + serializer = self.SchoolSerializer(schools[0], context={"request": request}) + data = serializer.data + + expected_fields = {"id": None, "name": None, "teachers": []} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + # Multiple nested instances: + serializer = self.SchoolSerializer(schools, many=True, context={"request": request}) + data = serializer.data[0] + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_omit_top_field_and_keep_all_nested_fields(self): + """Omitting a top-level field while keeping all nested fields""" + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?omit=name") + + schools = self._prepare_school_instance() + # Single nested instance. + serializer = self.SchoolSerializer(schools[0], context={"request": request}) + data = serializer.data + expected_fields = { + "id": None, + "teachers": ["id", "name", "request_info", "age", "students"], + } + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + # Multiple nested instances: + serializer = self.SchoolSerializer(schools, many=True, context={"request": request}) + data = serializer.data[0] + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_allow_nested_field(self): + """Select only the requested fields, including nested-level fields.""" + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?fields=invalid,id,teachers__age,teachers__invalid,teachers__students__age") + schools = self._prepare_school_instance() + # Single nested instance. + serializer = self.SchoolSerializer(schools[0], context={"request": request}) + data = serializer.data + expected_fields = {"id": None, "teachers": ["age", "students"]} + third_level_expected_fields = {"age":None, "students":["age"]} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + # Assert third level fields: + self._assert_nested_fields(data["teachers"][0], third_level_expected_fields) + + # Multiple nested instances: + serializer = self.SchoolSerializer(schools, many=True, context={"request": request}) + data = serializer.data[0] + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + # Assert third level fields: + self._assert_nested_fields(data["teachers"][0], third_level_expected_fields) + + def test_fields_all_gone_nested(self): + """If no fields are selected, all fields are omitted, including those + from the nested serializer. + """ + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?fields") + schools = self._prepare_school_instance() + # Single nested instance. + serializer = self.SchoolSerializer(schools[0], context={"request": request}) + data = serializer.data + expected_fields = {} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + # Multiple nested instances: + serializer = self.SchoolSerializer(schools, many=True, context={"request": request}) + data = serializer.data[0] + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_nested_omit_and_fields_used(self): + """Omit and fields can be used together at the nested field level.""" + rf = RequestFactory() + request = rf.get( + "/api/v1/schools/1/?fields=id,name,teachers__name,teachers__age&omit=name,teachers__name" + ) + schools = self._prepare_school_instance() + # Single nested instance. + serializer = self.SchoolSerializer(schools[0], context={"request": request}) + data = serializer.data + expected_fields = {"id": None, "teachers": ["age"]} + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + # Multiple nested instances: + serializer = self.SchoolSerializer(schools, many=True, context={"request": request}) + data = serializer.data[0] + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_omit_nothing_nested(self): + """ + Blank omit doesn't affect nested fields. + """ + rf = RequestFactory() + request = rf.get("/api/v1/schools/1/?omit") + schools = self._prepare_school_instance() + # Single nested instance. + serializer = self.SchoolSerializer(schools[0], context={"request": request}) + data = serializer.data + expected_fields = { + "id": None, + "name": None, + "teachers": ["id", "age", "name", "request_info", "students"], + } + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + # Multiple nested instances: + serializer = self.SchoolSerializer(schools, many=True, context={"request": request}) + data = serializer.data[0] + # Assert top‐level keys exactly match + self.assertEqual(set(data.keys()), set(expected_fields.keys())) + # Assert nested fields. + self._assert_nested_fields(data, expected_fields) + + def test_single_nested_instance_omit_field(self): + """Omit also works for filtering fields on single nested instances""" + child = Child(secret="secret_key", public="public_key") + parent = Parent(id=1, child=child) + rf = RequestFactory() + request = rf.get("/api/v1/parent/1/?omit=id,child__secret") + serializer = ParentSerializer(parent, context={"request": request}) + data = serializer.data + + self.assertNotIn("id", data) + self.assertNotIn("secret", data["child"]) + self.assertEqual(data["child"]["public"], "public_key") + + def test_single_nested_instance_allow_field(self): + """Fields selection also works for filtering fields on single nested instances""" + child = Child(secret="secret_key", public="public_key") + parent = Parent(id=1, child=child) + rf = RequestFactory() + request = rf.get("/api/v1/parent/1/?fields=id,child__secret") + serializer = ParentSerializer(parent, context={"request": request}) + data = serializer.data + + self.assertEqual(data["id"], 1) + self.assertIn("secret", data["child"]) + self.assertEqual(data["child"]["secret"], "secret_key") + self.assertNotIn("public", data["child"]) + +