diff --git a/django_filters/filters.py b/django_filters/filters.py index 0017291ce..d42b42125 100644 --- a/django_filters/filters.py +++ b/django_filters/filters.py @@ -83,6 +83,16 @@ def __init__(self, field_name=None, lookup_expr=None, *, label=None, self.creation_counter = Filter.creation_counter Filter.creation_counter += 1 + def bind(self, attr, parent): + """Bind the filter to its parent filterset. + + Provides both the filter's attribute name on the filterset and the + parent filterset instance. Called when the parent is initialized. + """ + self.attr = attr + self.parent = parent + self.model = parent.queryset.model + def get_method(self, qs): """Return filter method based on whether we're excluding or simply filtering. diff --git a/django_filters/filterset.py b/django_filters/filterset.py index d174718c0..e69797771 100644 --- a/django_filters/filterset.py +++ b/django_filters/filterset.py @@ -190,7 +190,6 @@ class BaseFilterSet(object): def __init__(self, data=None, queryset=None, *, request=None, prefix=None): if queryset is None: queryset = self._meta.model._default_manager.all() - model = queryset.model self.is_bound = data is not None self.data = data or {} @@ -200,10 +199,8 @@ def __init__(self, data=None, queryset=None, *, request=None, prefix=None): self.filters = copy.deepcopy(self.base_filters) - # propagate the model and filterset to the filters - for filter_ in self.filters.values(): - filter_.model = model - filter_.parent = self + for filter_name, f in self.filters.items(): + f.bind(filter_name, self) def is_valid(self): """ diff --git a/tests/test_filters.py b/tests/test_filters.py index a46f7567e..f44552d2b 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -94,6 +94,15 @@ def test_field_with_single_lookup_expr(self): field = f.field self.assertIsInstance(field, forms.Field) + def test_filterset_bind(self): + m = mock.Mock(queryset=User.objects.all()) + f = Filter() + f.bind('name', m) + + self.assertEqual(f.attr, 'name') + self.assertIs(f.parent, m) + self.assertIs(f.model, User) + def test_field_params(self): with mock.patch.object(Filter, 'field_class', spec=['__call__']) as mocked: @@ -161,6 +170,7 @@ def test_filter_using_method(self): qs = mock.NonCallableMock(spec=[]) method = mock.Mock() f = Filter(method=method) + f.bind('name', mock.Mock()) result = f.filter(qs, 'value') method.assert_called_once_with(qs, None, 'value') self.assertNotEqual(qs, result) @@ -721,7 +731,7 @@ def test_callable_queryset(self): qs_callable = mock.Mock(return_value=qs) f = ModelChoiceFilter(queryset=qs_callable) - f.parent = mock.Mock(request=request) + f.bind('name', mock.Mock(request=request)) field = f.field qs_callable.assert_called_with(request) @@ -735,7 +745,7 @@ class F(ModelChoiceFilter): get_queryset = mock.create_autospec(ModelChoiceFilter.get_queryset, return_value=qs) f = F() - f.parent = mock.Mock(request=request) + f.bind('name', mock.Mock(request=request)) field = f.field f.get_queryset.assert_called_with(f, request) @@ -787,7 +797,7 @@ def test_callable_queryset(self): qs_callable = mock.Mock(return_value=qs) f = ModelMultipleChoiceFilter(queryset=qs_callable) - f.parent = mock.Mock(request=request) + f.bind('name', mock.Mock(request=request)) field = f.field qs_callable.assert_called_with(request) @@ -1261,19 +1271,14 @@ def test_default_field_without_assigning_model(self): f.field def test_default_field_with_assigning_model(self): - mocked = mock.Mock() - chained_call = '.'.join(['_default_manager', 'distinct.return_value', - 'order_by.return_value', - 'values_list.return_value']) - mocked.configure_mock(**{chained_call: iter([])}) - f = AllValuesFilter() - f.model = mocked - field = f.field - self.assertIsInstance(field, forms.ChoiceField) + f = AllValuesFilter(field_name='username') + f.bind('name', mock.Mock(queryset=User.objects.all())) + + self.assertIsInstance(f.field, forms.ChoiceField) def test_empty_value_in_choices(self): f = AllValuesFilter(field_name='username') - f.model = User + f.bind('name', mock.Mock(queryset=User.objects.all())) self.assertEqual(list(f.field.choices), [ ('', '---------'), @@ -1295,7 +1300,7 @@ def test_normalize_lookup_with_display_label(self): def test_lookup_choices_default(self): # Lookup choices should default to the model field's registered lookups f = LookupChoiceFilter(field_name='username', lookup_choices=None) - f.model = User + f.bind('name', mock.Mock(queryset=User.objects.all())) choice_field = f.field.fields[1] self.assertEqual(