Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ dist/
*.egg-info/
build/
.tox/
.idea
148 changes: 133 additions & 15 deletions drf_dynamic_fields/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,29 @@
"""
Mixin to dynamically select only a subset of fields per DRF resource.
"""

import warnings

from django.conf import settings
from django.utils.functional import cached_property

from rest_framework import serializers


class DynamicFieldsMixin(object):
"""
A serializer mixin that takes an additional `fields` argument that controls
which fields should be displayed.
"""

@property
def prevent_nested_processing(self) -> bool:
"""True when this serializer is not the root nor a root’s list-child."""
return not (
self is self.root
or (self.parent is self.root and getattr(self.parent, "many", False))
)

@cached_property
def fields(self):
"""
Expand All @@ -29,13 +40,7 @@ def fields(self):
# We are being called before a request cycle
return fields

# Only filter if this is the root serializer, or if the parent is the
# root serializer with many=True
is_root = self.root == self
parent_is_list_root = self.parent == self.root and getattr(
self.parent, "many", False
)
if not (is_root or parent_is_list_root):
if self.prevent_nested_processing:
return fields

try:
Expand All @@ -55,15 +60,13 @@ def fields(self):
if params is None:
warnings.warn("Request object does not contain query parameters")

try:
filter_fields = params.get("fields", None).split(",")
except AttributeError:
filter_fields = None
source = get_source_path(self)
level = compute_level(self)

try:
omit_fields = params.get("omit", None).split(",")
except AttributeError:
omit_fields = []
filter_fields = self.get_filter_fields(
params.get("fields", None), level, source
)
omit_fields = self.get_omit_fields(params.get("omit", None), level, source)

# Drop any fields that are not specified in the `fields` argument.
existing = set(fields.keys())
Expand All @@ -85,3 +88,118 @@ def fields(self):
fields.pop(field, None)

return fields

def get_filter_fields(
self, params, level, source, default=None, include_parent=True
):
try:
return params.split(",")
except AttributeError:
return default

def get_omit_fields(self, params, level, source):
return self.get_filter_fields(
params, level, source, default=[], include_parent=False
)


class NestedDynamicFieldsMixin(DynamicFieldsMixin):
"""A serializer mixin that extends DynamicFieldsMixin to allow nested serializers
to filter their fields based on the original `fields` query parameter.

Unlike the base mixin—which only applies filtering at the root serializer,
this subclass:

- Disables the `prevent_nested_processing` guard, allowing each level of nested
serializer to apply field filtering independently.
- Overrides `get_filter_fields` to slice the raw `fields` string
down to exactly those names relevant at this serializer’s
current nesting depth (using get_fields_for_level_and_prefix).
- `get_filter_fields` first delegates to the super method for splitting
the comma‐separated string, then calls a helper that:
• Selects only the fields that are nested under this serializer's path in
the hierarchy
• Returns direct children at depth `level + 1`
"""

@property
def prevent_nested_processing(self):
return False

def get_filter_fields(
self, params, level, source, default=None, include_parent=True
):
"""
Parse the raw `fields` parameter and return the subset of fields
that apply at this serializer’s nesting level under the given
source prefix.
"""
fields = super().get_filter_fields(
params, level, source, default, include_parent
)
return get_fields_for_level_and_prefix(
fields, level, source, default=default, include_parent=include_parent
)


def get_source_path(serializer) -> str:
"""Recursively walks up the serializer tree to build the nested field path."""
parent = getattr(serializer, "parent", None)
if not parent:
return ""
parent_path = get_source_path(parent)
name = getattr(serializer, "field_name", None)
if not name:
return parent_path
return f"{parent_path}__{name}" if parent_path else name


def get_fields_for_level_and_prefix(
fields_list, level, source, include_parent, default
):
"""Extract the field names relevant to a specific nesting depth
from a list of double‑underscore lookup strings.
"""
if not fields_list:
return default

prefix = source.split("__") if source else []
allowed = set()
for f in fields_list:
parts = f.split("__")

if parts[:level] != prefix:
continue

if len(parts) <= level + 1:
allowed.add(parts[-1])
continue

if len(parts) > level + 1 and include_parent:
# include parent field to ensure nesting proceeds
allowed.add(parts[level])
continue


# If the only allowed fields are exactly the prefix itself,
# fall back to default
if allowed == set(prefix):
return default

return allowed


def compute_level(serializer) -> int:
"""Recursively count how many ancestors of `serializer` are not
ListSerializer instances. Stops when parent is None.
"""
parent = getattr(serializer, "parent", None)
if parent is None:
# base case, reached the top
return 0

# if this immediate parent is a ListSerializer, don’t count it, otherwise 1
this_level = 0 if isinstance(parent, serializers.ListSerializer) else 1

# recurse on the parent itself
return this_level + compute_level(parent)
15 changes: 14 additions & 1 deletion tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,26 @@
from django.db import models


class Teacher(models.Model):
class Student(models.Model):
name = models.CharField(max_length=30)
age = models.IntegerField()

class Teacher(models.Model):
name = models.CharField(max_length=30)
age = models.IntegerField()
students = models.ManyToManyField(Student)

class School(models.Model):
"""Schools just have teachers, no students."""

name = models.CharField(max_length=30)
teachers = models.ManyToManyField(Teacher)


class Child(models.Model):
secret = models.CharField(max_length=100)
public = models.CharField(max_length=100)


class Parent(models.Model):
child = models.ForeignKey(Child, on_delete=models.CASCADE)
65 changes: 56 additions & 9 deletions tests/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,50 @@
"""
from rest_framework import serializers

from drf_dynamic_fields import DynamicFieldsMixin
from drf_dynamic_fields import DynamicFieldsMixin, NestedDynamicFieldsMixin

from .models import Teacher, School
from .models import Teacher, School, Child, Student


class TeacherSerializer(DynamicFieldsMixin, serializers.ModelSerializer):
class BaseTeacherSerializer(serializers.ModelSerializer):

request_info = serializers.SerializerMethodField()

class Meta:
model = Teacher
fields = ("id", "request_info", "age", "name")

def get_request_info(self, teacher):
"""
a meaningless method that attempts
to access the request object.
"""
request = self.context["request"]
return request.build_absolute_uri("/api/v1/teacher/{}".format(teacher.pk))


class TeacherSerializer(DynamicFieldsMixin, BaseTeacherSerializer):
pass


class NestableStudentSerializer(NestedDynamicFieldsMixin, serializers.ModelSerializer):

class Meta:
model = Student
fields = ("id", "name", "age")


class NestableTeacherSerializer(NestedDynamicFieldsMixin, BaseTeacherSerializer):
"""
The request_info field is to highlight the issue accessing request during
a nested serializer.
"""

request_info = serializers.SerializerMethodField()
students = NestableStudentSerializer(many=True, read_only=True)

class Meta:
model = Teacher
fields = ("id", "request_info", "age", "name")
fields = ("id", "request_info", "age", "name", "students")

def get_request_info(self, teacher):
"""
Expand All @@ -29,14 +57,33 @@ def get_request_info(self, teacher):
return request.build_absolute_uri("/api/v1/teacher/{}".format(teacher.pk))


class SchoolSerializer(DynamicFieldsMixin, serializers.ModelSerializer):
class BaseSchoolSerializer(serializers.ModelSerializer):


class Meta:
model = School
fields = ("id", "teachers", "name")


class SchoolSerializer(DynamicFieldsMixin, BaseSchoolSerializer):
teachers = TeacherSerializer(many=True, read_only=True)


class NestableSchoolSerializer(NestedDynamicFieldsMixin, BaseSchoolSerializer):
"""
Interesting enough serializer because the TeacherSerializer
will use ListSerializer due to the `many=True`
"""
teachers = NestableTeacherSerializer(many=True, read_only=True)

teachers = TeacherSerializer(many=True, read_only=True)
class ChildSerializer(NestedDynamicFieldsMixin, serializers.Serializer):
secret = serializers.CharField()
public = serializers.CharField()

class Meta:
model = School
fields = ("id", "teachers", "name")
model = Child


class ParentSerializer(NestedDynamicFieldsMixin, serializers.Serializer):
id = serializers.IntegerField()
child = ChildSerializer()
Loading