diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index ac9a238..07feb1c 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -19,3 +19,10 @@ repos:
hooks:
- id: django-upgrade
args: ["--target-version=3.1"]
+
+- repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.14.9
+ hooks:
+ - id: ruff-check
+ args: [ --fix ]
+ - id: ruff-format
diff --git a/docs/conf.py b/docs/conf.py
index aaa936e..61b0445 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -4,30 +4,31 @@
import re
import sys
from datetime import date
+from importlib.util import find_spec
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
-sys.path.insert(0, os.path.abspath('..'))
-os.environ['RUNTIME_ENV'] = 'TESTSUITE'
-os.environ['DJANGO_SETTINGS_MODULE'] = 'tests.settings'
-try:
- import django
- import sphinx_rtd_theme
+sys.path.insert(0, os.path.abspath(".."))
+os.environ["RUNTIME_ENV"] = "TESTSUITE"
+os.environ["DJANGO_SETTINGS_MODULE"] = "tests.settings"
+if find_spec("django") and find_spec("sphinx_rtd_theme"):
use_sphinx_rtd_theme = True
+
+ import django
+
if hasattr(django, "setup"):
django.setup()
-
-except ImportError:
- use_sphinx_rtd_theme = os.environ.get('READTHEDOCS', False)
+else:
+ use_sphinx_rtd_theme = os.environ.get("READTHEDOCS", False)
def get_version(package):
"""
Return package version as listed in `__version__` in `init.py`.
"""
- init_py = open(os.path.join(package, '__init__.py')).read()
+ init_py = open(os.path.join(package, "__init__.py")).read()
return re.search("__version__ = ['\"]([^'\"]+)['\"]", init_py).group(1)
@@ -40,35 +41,37 @@ def get_version(package):
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
-extensions = ['sphinx.ext.autodoc',
- 'sphinx.ext.intersphinx',
- 'sphinx.ext.extlinks',
- 'sphinx.ext.todo',
- 'sphinx.ext.viewcode']
+extensions = [
+ "sphinx.ext.autodoc",
+ "sphinx.ext.intersphinx",
+ "sphinx.ext.extlinks",
+ "sphinx.ext.todo",
+ "sphinx.ext.viewcode",
+]
# Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
-source_suffix = '.rst'
+source_suffix = ".rst"
# The master toctree document.
-master_doc = 'index'
+master_doc = "index"
# General information about the project.
-project = 'drf_haystack'
-copyright = '%d, Rolf Håvard Blindheim' % date.today().year
-author = 'Rolf Håvard Blindheim'
+project = "drf_haystack"
+copyright = "%d, Rolf Håvard Blindheim" % date.today().year
+author = "Rolf Håvard Blindheim"
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The full version, including alpha/beta/rc tags.
-version = release = get_version(os.path.join('..', 'drf_haystack'))
+version = release = get_version(os.path.join("..", "drf_haystack"))
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
@@ -80,10 +83,10 @@ def get_version(package):
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This patterns also effect to html_static_path and html_extra_path
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
# The name of the Pygments (syntax highlighting) style to use.
-pygments_style = 'sphinx'
+pygments_style = "sphinx"
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True
@@ -94,7 +97,7 @@ def get_version(package):
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
-html_theme = 'sphinx_rtd_theme' if use_sphinx_rtd_theme else 'alabaster'
+html_theme = "sphinx_rtd_theme" if use_sphinx_rtd_theme else "alabaster"
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
@@ -105,28 +108,25 @@ def get_version(package):
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+html_static_path = ["_static"]
# -- Options for HTMLHelp output ------------------------------------------
# Output file base name for HTML help builder.
-htmlhelp_basename = 'drfhaystackdoc'
+htmlhelp_basename = "drfhaystackdoc"
# -- Options for LaTeX output ---------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
- 'papersize': 'a4paper',
-
+ "papersize": "a4paper",
# The font size ('10pt', '11pt' or '12pt').
- 'pointsize': '11pt',
-
+ "pointsize": "11pt",
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
-
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
@@ -136,8 +136,7 @@ def get_version(package):
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
- ('index', 'drf-haystack.tex', 'drf-haystack documentation',
- 'Inonit', 'manual'),
+ ("index", "drf-haystack.tex", "drf-haystack documentation", "Inonit", "manual"),
]
@@ -145,10 +144,7 @@ def get_version(package):
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
-man_pages = [
- ('index', 'drf-haystack', 'drf-haystack documentation',
- ['Inonit'], 1)
-]
+man_pages = [("index", "drf-haystack", "drf-haystack documentation", ["Inonit"], 1)]
# -- Options for Texinfo output -------------------------------------------
@@ -157,21 +153,27 @@ def get_version(package):
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
- ('index', 'drf-haystack', 'drf-haystack documentation',
- 'Inonit', 'drf-haystack', 'Haystack for Django REST Framework',
- 'Miscellaneous'),
+ (
+ "index",
+ "drf-haystack",
+ "drf-haystack documentation",
+ "Inonit",
+ "drf-haystack",
+ "Haystack for Django REST Framework",
+ "Miscellaneous",
+ ),
]
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
- 'http://docs.python.org/': None,
- 'django': ('https://django.readthedocs.io/en/latest/', None),
- 'haystack': ('https://django-haystack.readthedocs.io/en/latest/', None)
+ "http://docs.python.org/": None,
+ "django": ("https://django.readthedocs.io/en/latest/", None),
+ "haystack": ("https://django-haystack.readthedocs.io/en/latest/", None),
}
# Configurations for extlinks
extlinks = {
- 'drf-pr': ('https://github.com/rhblind/drf-haystack/pull/%s', 'PR#'),
- 'drf-issue': ('https://github.com/rhblind/drf-haystack/issues/%s', '#'),
- 'haystack-issue': ('https://github.com/django-haystack/django-haystack/issues/%s', '#')
+ "drf-pr": ("https://github.com/rhblind/drf-haystack/pull/%s", "PR#"),
+ "drf-issue": ("https://github.com/rhblind/drf-haystack/issues/%s", "#"),
+ "haystack-issue": ("https://github.com/django-haystack/django-haystack/issues/%s", "#"),
}
diff --git a/drf_haystack/__init__.py b/drf_haystack/__init__.py
index b87f922..55c42b3 100644
--- a/drf_haystack/__init__.py
+++ b/drf_haystack/__init__.py
@@ -1,6 +1,5 @@
import warnings
-
__title__ = "drf-haystack"
__version__ = "1.9.1"
__author__ = "Rolf Haavard Blindheim"
diff --git a/drf_haystack/fields.py b/drf_haystack/fields.py
index d89d2c9..2d1fac0 100644
--- a/drf_haystack/fields.py
+++ b/drf_haystack/fields.py
@@ -1,4 +1,3 @@
-import six
from rest_framework import fields
@@ -6,7 +5,7 @@ class DRFHaystackFieldMixin:
prefix_field_names = False
def __init__(self, **kwargs):
- self.prefix_field_names = kwargs.pop('prefix_field_names', False)
+ self.prefix_field_names = kwargs.pop("prefix_field_names", False)
super().__init__(**kwargs)
def bind(self, field_name, parent):
@@ -23,8 +22,7 @@ def bind(self, field_name, parent):
assert self.source != field_name, (
"It is redundant to specify `source='%s'` on field '%s' in "
"serializer '%s', because it is the same as the field name. "
- "Remove the `source` keyword argument." %
- (field_name, self.__class__.__name__, parent.__class__.__name__)
+ "Remove the `source` keyword argument." % (field_name, self.__class__.__name__, parent.__class__.__name__)
)
self.field_name = field_name
@@ -32,7 +30,7 @@ def bind(self, field_name, parent):
# `self.label` should default to being based on the field name.
if self.label is None:
- self.label = field_name.replace('_', ' ').capitalize()
+ self.label = field_name.replace("_", " ").capitalize()
# self.source should default to being the same as the field name.
if self.source is None:
@@ -40,10 +38,10 @@ def bind(self, field_name, parent):
# self.source_attrs is a list of attributes that need to be looked up
# when serializing the instance, or populating the validated data.
- if self.source == '*':
+ if self.source == "*":
self.source_attrs = []
else:
- self.source_attrs = self.source.split('.')
+ self.source_attrs = self.source.split(".")
def convert_field_name(self, field_name):
if not self.prefix_field_names:
@@ -91,10 +89,7 @@ class FacetDictField(fields.DictField):
"""
def to_representation(self, value):
- return {
- str(key): self.child.to_representation(key, val)
- for key, val in value.items()
- }
+ return {str(key): self.child.to_representation(key, val) for key, val in value.items()}
class FacetListField(fields.ListField):
diff --git a/drf_haystack/filters.py b/drf_haystack/filters.py
index 8a5aeb4..b55c603 100644
--- a/drf_haystack/filters.py
+++ b/drf_haystack/filters.py
@@ -1,12 +1,12 @@
import operator
-import six
from functools import reduce
+import six
from django.core.exceptions import ImproperlyConfigured
from haystack.query import SearchQuerySet
from rest_framework.filters import BaseFilterBackend, OrderingFilter
-from drf_haystack.query import BoostQueryBuilder, FilterQueryBuilder, FacetQueryBuilder, SpatialQueryBuilder
+from drf_haystack.query import BoostQueryBuilder, FacetQueryBuilder, FilterQueryBuilder, SpatialQueryBuilder
class BaseHaystackFilterBackend(BaseFilterBackend):
@@ -57,7 +57,7 @@ def filter_queryset(self, request, queryset, view):
return self.apply_filters(
queryset=queryset,
applicable_filters=self.process_filters(applicable_filters, queryset, view),
- applicable_exclusions=self.process_filters(applicable_exclusions, queryset, view)
+ applicable_exclusions=self.process_filters(applicable_exclusions, queryset, view),
)
def get_query_builder(self, *args, **kwargs):
@@ -109,9 +109,7 @@ def process_filters(self, filters, queryset, view):
for field_name, query in filters.children:
for word in query.split(" "):
bit = queryset.query.clean(word.strip())
- kwargs = {
- field_name: bit
- }
+ kwargs = {field_name: bit}
query_bits.append(view.query_object(**kwargs))
return six.moves.reduce(operator.and_, filter(lambda x: x, query_bits))
@@ -243,16 +241,14 @@ def get_valid_fields(self, queryset, view, context={}):
"Cannot use %s with '__all__' as 'ordering_fields' attribute on a view "
"which has no 'index_models' set. Either specify some 'ordering_fields', "
"set the 'index_models' attribute or override the 'get_queryset' "
- "method and pass some 'index_models'."
- % self.__class__.__name__)
+ "method and pass some 'index_models'." % self.__class__.__name__
+ )
- model_fields = map(lambda model: [(field.name, field.verbose_name) for field in model._meta.fields],
- queryset.query.models)
+ model_fields = map(
+ lambda model: [(field.name, field.verbose_name) for field in model._meta.fields], queryset.query.models
+ )
valid_fields = list(set(reduce(operator.concat, model_fields)))
else:
- valid_fields = [
- (item, item) if isinstance(item, str) else item
- for item in valid_fields
- ]
+ valid_fields = [(item, item) if isinstance(item, str) else item for item in valid_fields]
return valid_fields
diff --git a/drf_haystack/generics.py b/drf_haystack/generics.py
index 52a640b..923eb1d 100644
--- a/drf_haystack/generics.py
+++ b/drf_haystack/generics.py
@@ -1,8 +1,5 @@
-import six
-
from django.contrib.contenttypes.models import ContentType
from django.http import Http404
-
from haystack.backends import SQ
from haystack.query import SearchQuerySet
from rest_framework.generics import GenericAPIView
@@ -14,6 +11,7 @@ class HaystackGenericAPIView(GenericAPIView):
"""
Base class for all haystack generic views.
"""
+
# Use `index_models` to filter on which search index models we
# should include in the search result.
index_models = []
@@ -72,8 +70,10 @@ def get_object(self):
ctype = ContentType.objects.get(app_label=app_label, model=model)
queryset = self.get_queryset(index_models=[ctype.model_class()])
except (ValueError, ContentType.DoesNotExist):
- raise Http404("Could not find any models matching '%s'. Make sure to use a valid "
- "'app_label.model' name for the 'model' query parameter." % self.request.query_params["model"])
+ raise Http404(
+ "Could not find any models matching '%s'. Make sure to use a valid "
+ "'app_label.model' name for the 'model' query parameter." % self.request.query_params["model"]
+ )
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
if lookup_url_kwarg not in self.kwargs:
diff --git a/drf_haystack/mixins.py b/drf_haystack/mixins.py
index 15579e5..9aea457 100644
--- a/drf_haystack/mixins.py
+++ b/drf_haystack/mixins.py
@@ -37,7 +37,7 @@ class FacetMixin:
facet_filter_backends = [HaystackFacetFilter]
facet_serializer_class = None
facet_objects_serializer_class = None
- facet_query_params_text = 'selected_facets'
+ facet_query_params_text = "selected_facets"
@action(detail=False, methods=["get"], url_path="facets")
def facets(self, request):
@@ -48,7 +48,6 @@ def facets(self, request):
queryset = self.filter_facet_queryset(self.get_queryset())
for facet in request.query_params.getlist(self.facet_query_params_text):
-
if ":" not in facet:
continue
@@ -95,8 +94,7 @@ def get_facet_serializer_class(self):
if self.facet_serializer_class is None:
raise AttributeError(
"%(cls)s should either include a `facet_serializer_class` attribute, "
- "or override %(cls)s.get_facet_serializer_class() method." %
- {"cls": self.__class__.__name__}
+ "or override %(cls)s.get_facet_serializer_class() method." % {"cls": self.__class__.__name__}
)
return self.facet_serializer_class
diff --git a/drf_haystack/query.py b/drf_haystack/query.py
index abf06aa..17c81c1 100644
--- a/drf_haystack/query.py
+++ b/drf_haystack/query.py
@@ -1,9 +1,8 @@
import operator
-import six
import warnings
from itertools import chain
-
+import six
from dateutil import parser
from drf_haystack import constants
@@ -56,14 +55,14 @@ def build_query(self, **filters):
try:
term, val = chain.from_iterable(zip(self.tokenize(value, self.view.lookup_sep)))
except ValueError:
- raise ValueError("Cannot convert the '%s' query parameter to a valid boost filter."
- % query_param)
+ raise ValueError("Cannot convert the '%s' query parameter to a valid boost filter." % query_param)
else:
try:
applicable_filters = {"term": term, "boost": float(val)}
except ValueError:
- raise ValueError("Cannot convert boost to float value. Make sure to provide a "
- "numerical boost value.")
+ raise ValueError(
+ "Cannot convert boost to float value. Make sure to provide a numerical boost value."
+ )
return applicable_filters
@@ -79,7 +78,8 @@ def __init__(self, backend, view):
assert getattr(self.backend, "default_operator", None) in (operator.and_, operator.or_), (
"{cls}.default_operator must be either 'operator.and_' or 'operator.or_'.".format(
cls=self.backend.__class__.__name__
- ))
+ )
+ )
self.default_operator = self.backend.default_operator
self.default_same_param_operator = getattr(self.backend, "default_same_param_operator", self.default_operator)
@@ -118,23 +118,26 @@ def build_query(self, **filters):
param = param.replace("__%s" % negation_keyword, "") # haystack wouldn't understand our negation
if self.view.serializer_class:
- if hasattr(self.view.serializer_class.Meta, 'field_aliases'):
+ if hasattr(self.view.serializer_class.Meta, "field_aliases"):
old_base = base_param
base_param = self.view.serializer_class.Meta.field_aliases.get(base_param, base_param)
param = param.replace(old_base, base_param) # need to replace the alias
- fields = getattr(self.view.serializer_class.Meta, 'fields', [])
- exclude = getattr(self.view.serializer_class.Meta, 'exclude', [])
- search_fields = getattr(self.view.serializer_class.Meta, 'search_fields', [])
+ fields = getattr(self.view.serializer_class.Meta, "fields", [])
+ exclude = getattr(self.view.serializer_class.Meta, "exclude", [])
+ search_fields = getattr(self.view.serializer_class.Meta, "search_fields", [])
# Skip if the parameter is not listed in the serializer's `fields`
# or if it's in the `exclude` list.
- if ((fields or search_fields) and base_param not in
- chain(fields, search_fields)) or base_param in exclude or not value:
+ if (
+ ((fields or search_fields) and base_param not in chain(fields, search_fields))
+ or base_param in exclude
+ or not value
+ ):
continue
param_queries = []
- if len(param_parts) > 1 and param_parts[-1] in ('in', 'range'):
+ if len(param_parts) > 1 and param_parts[-1] in ("in", "range"):
# `in` and `range` filters expects a list of values
param_queries.append(self.view.query_object((param, list(self.tokenize(value, self.view.lookup_sep)))))
else:
@@ -149,11 +152,17 @@ def build_query(self, **filters):
else:
applicable_filters.append(term)
- applicable_filters = six.moves.reduce(
- self.default_operator, filter(lambda x: x, applicable_filters)) if applicable_filters else self.view.query_object()
+ applicable_filters = (
+ six.moves.reduce(self.default_operator, filter(lambda x: x, applicable_filters))
+ if applicable_filters
+ else self.view.query_object()
+ )
- applicable_exclusions = six.moves.reduce(
- self.default_operator, filter(lambda x: x, applicable_exclusions)) if applicable_exclusions else self.view.query_object()
+ applicable_exclusions = (
+ six.moves.reduce(self.default_operator, filter(lambda x: x, applicable_exclusions))
+ if applicable_exclusions
+ else self.view.query_object()
+ )
return applicable_filters, applicable_exclusions
@@ -178,16 +187,17 @@ def build_query(self, **filters):
facet_serializer_cls = self.view.get_facet_serializer_class()
if self.view.lookup_sep == ":":
- raise AttributeError("The %(cls)s.lookup_sep attribute conflicts with the HaystackFacetFilter "
- "query parameter parser. Please choose another `lookup_sep` attribute "
- "for %(cls)s." % {"cls": self.view.__class__.__name__})
+ raise AttributeError(
+ "The %(cls)s.lookup_sep attribute conflicts with the HaystackFacetFilter "
+ "query parameter parser. Please choose another `lookup_sep` attribute "
+ "for %(cls)s." % {"cls": self.view.__class__.__name__}
+ )
fields = facet_serializer_cls.Meta.fields
exclude = facet_serializer_cls.Meta.exclude
field_options = facet_serializer_cls.Meta.field_options
for field, options in filters.items():
-
if field not in fields or field in exclude:
continue
@@ -196,12 +206,10 @@ def build_query(self, **filters):
valid_gap = ("year", "month", "day", "hour", "minute", "second")
for field, options in field_options.items():
if any([k in options for k in ("start_date", "end_date", "gap_by", "gap_amount")]):
-
if not all(("start_date", "end_date", "gap_by" in options)):
- raise ValueError("Date faceting requires at least 'start_date', 'end_date' "
- "and 'gap_by' to be set.")
+ raise ValueError("Date faceting requires at least 'start_date', 'end_date' and 'gap_by' to be set.")
- if not options["gap_by"] in valid_gap:
+ if options["gap_by"] not in valid_gap:
raise ValueError("The 'gap_by' parameter must be one of %s." % ", ".join(valid_gap))
options.setdefault("gap_amount", 1)
@@ -210,11 +218,7 @@ def build_query(self, **filters):
else:
field_facets[field] = field_options[field]
- return {
- "date_facets": date_facets,
- "field_facets": field_facets,
- "query_facets": query_facets
- }
+ return {"date_facets": date_facets, "field_facets": field_facets, "query_facets": query_facets}
def parse_field_options(self, *options):
"""
@@ -227,14 +231,15 @@ def parse_field_options(self, *options):
for token in tokens:
if not len(token.split(":")) == 2:
- warnings.warn("The %s token is not properly formatted. Tokens need to be "
- "formatted as 'token:value' pairs." % token)
+ warnings.warn(
+ "The %s token is not properly formatted. Tokens need to be "
+ "formatted as 'token:value' pairs." % token
+ )
continue
param, value = token.split(":", 1)
if any([k == param for k in ("start_date", "end_date", "gap_amount")]):
-
if param in ("start_date", "end_date"):
value = parser.parse(value)
@@ -256,18 +261,21 @@ def __init__(self, backend, view):
assert getattr(self.backend, "point_field", None) is not None, (
"%(cls)s.point_field cannot be None. Set the %(cls)s.point_field "
- "to the name of the `LocationField` you want to filter on your index class." % {
- "cls": self.backend.__class__.__name__
- })
+ "to the name of the `LocationField` you want to filter on your index class."
+ % {"cls": self.backend.__class__.__name__}
+ )
try:
from haystack.utils.geo import D, Point
+
self.D = D
self.Point = Point
except ImportError:
- warnings.warn("Make sure you've installed the `libgeos` library. "
- "Run `apt-get install libgeos` on debian based linux systems, "
- "or `brew install geos` on OS X.")
+ warnings.warn(
+ "Make sure you've installed the `libgeos` library. "
+ "Run `apt-get install libgeos` on debian based linux systems, "
+ "or `brew install geos` on OS X."
+ )
raise
def build_query(self, **filters):
@@ -288,17 +296,23 @@ def build_query(self, **filters):
applicable_filters = None
- filters = {k: filters[k] for k in chain(self.D.UNITS.keys(),
- [constants.DRF_HAYSTACK_SPATIAL_QUERY_PARAM]) if k in filters}
+ filters = {
+ k: filters[k]
+ for k in chain(self.D.UNITS.keys(), [constants.DRF_HAYSTACK_SPATIAL_QUERY_PARAM])
+ if k in filters
+ }
distance = {k: v for k, v in filters.items() if k in self.D.UNITS.keys()}
try:
- latitude, longitude = map(float, self.tokenize(filters[constants.DRF_HAYSTACK_SPATIAL_QUERY_PARAM],
- self.view.lookup_sep))
+ latitude, longitude = map(
+ float, self.tokenize(filters[constants.DRF_HAYSTACK_SPATIAL_QUERY_PARAM], self.view.lookup_sep)
+ )
point = self.Point(longitude, latitude, srid=constants.GEO_SRID)
except ValueError:
- raise ValueError("Cannot convert `from=latitude,longitude` query parameter to "
- "float values. Make sure to provide numerical values only!")
+ raise ValueError(
+ "Cannot convert `from=latitude,longitude` query parameter to "
+ "float values. Make sure to provide numerical values only!"
+ )
except KeyError:
# If the user has not provided any `from` query string parameter,
# just return.
@@ -311,15 +325,8 @@ def build_query(self, **filters):
if point and distance:
applicable_filters = {
- "dwithin": {
- "field": self.backend.point_field,
- "point": point,
- "distance": self.D(**distance)
- },
- "distance": {
- "field": self.backend.point_field,
- "point": point
- }
+ "dwithin": {"field": self.backend.point_field, "point": point, "distance": self.D(**distance)},
+ "distance": {"field": self.backend.point_field, "point": point},
}
return applicable_filters
diff --git a/drf_haystack/serializers.py b/drf_haystack/serializers.py
index beef428..90d8752 100644
--- a/drf_haystack/serializers.py
+++ b/drf_haystack/serializers.py
@@ -1,28 +1,31 @@
import copy
-import six
-import warnings
-from itertools import chain
from datetime import datetime
+from itertools import chain
try:
from collections import OrderedDict
except ImportError:
from django.utils.datastructures import SortedDict as OrderedDict
-from django.core.exceptions import ImproperlyConfigured, FieldDoesNotExist
-
+from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
from haystack import fields as haystack_fields
from haystack.query import EmptySearchQuerySet
from haystack.utils.highlighting import Highlighter
-
from rest_framework import serializers
from rest_framework.fields import empty
from rest_framework.utils.field_mapping import ClassLookupDict, get_field_kwargs
from drf_haystack.fields import (
- HaystackBooleanField, HaystackCharField, HaystackDateField, HaystackDateTimeField,
- HaystackDecimalField, HaystackFloatField, HaystackIntegerField, HaystackMultiValueField,
- FacetDictField, FacetListField
+ FacetDictField,
+ FacetListField,
+ HaystackBooleanField,
+ HaystackCharField,
+ HaystackDateField,
+ HaystackDateTimeField,
+ HaystackDecimalField,
+ HaystackFloatField,
+ HaystackIntegerField,
+ HaystackMultiValueField,
)
@@ -57,7 +60,6 @@ def __delattr__(cls, key, value):
class HaystackSerializerMeta(serializers.SerializerMetaclass):
-
"""
Metaclass for the HaystackSerializer that ensures that all declared subclasses implemented a Meta.
"""
@@ -110,8 +112,9 @@ def __init__(self, instance=None, data=empty, **kwargs):
super().__init__(instance, data, **kwargs)
if not self.Meta.index_classes and not self.Meta.serializers:
- raise ImproperlyConfigured("You must set either the 'index_classes' or 'serializers' "
- "attribute on the serializer Meta class.")
+ raise ImproperlyConfigured(
+ "You must set either the 'index_classes' or 'serializers' attribute on the serializer Meta class."
+ )
if not self.instance:
self.instance = EmptySearchQuerySet()
@@ -156,7 +159,7 @@ def _get_index_class_name(self, index_cls):
"""
cls_name = index_cls.__name__
aliases = self.Meta.index_aliases
- return aliases.get(cls_name, cls_name.split('.')[-1])
+ return aliases.get(cls_name, cls_name.split(".")[-1])
def get_fields(self):
"""
@@ -199,7 +202,7 @@ def get_fields(self):
# in order to correctly instantiate the serializer field.
model = index_cls().get_model()
kwargs = self._get_default_field_kwargs(model, field_type)
- kwargs['prefix_field_names'] = prefix_field_names
+ kwargs["prefix_field_names"] = prefix_field_names
field_mapping[field_name] = self._field_mapping[field_type](**kwargs)
# Add any explicitly declared fields. They *will* override any index fields
@@ -302,10 +305,9 @@ def get_paginate_by_param(self):
raise AttributeError(
"%(root_cls)s is missing a `paginate_by_param` attribute. "
"Define a %(root_cls)s.paginate_by_param or override "
- "%(cls)s.get_paginate_by_param()." % {
- "root_cls": self.root.__class__.__name__,
- "cls": self.__class__.__name__
- })
+ "%(cls)s.get_paginate_by_param()."
+ % {"root_cls": self.root.__class__.__name__, "cls": self.__class__.__name__}
+ )
def get_text(self, instance):
"""
@@ -378,10 +380,11 @@ def get_fields(self):
"""
field_mapping = OrderedDict()
for field, data in self.instance.items():
- field_mapping.update(
- {field: self.facet_dict_field_class(
- child=self.facet_list_field_class(child=self.facet_field_serializer_class(data)), required=False)}
- )
+ field_mapping.update({
+ field: self.facet_dict_field_class(
+ child=self.facet_list_field_class(child=self.facet_field_serializer_class(data)), required=False
+ )
+ })
if self.serialize_objects is True:
field_mapping["objects"] = serializers.SerializerMethodField()
@@ -402,7 +405,7 @@ def get_objects(self, instance):
("count", self.get_count(queryset)),
("next", view.paginator.get_next_link()),
("previous", view.paginator.get_previous_link()),
- ("results", serializer.data)
+ ("results", serializer.data),
])
serializer = view.get_serializer(queryset, many=True)
@@ -451,8 +454,7 @@ def get_highlighter(self):
if not self.highlighter_class:
raise ImproperlyConfigured(
"%(cls)s is missing a highlighter_class. Define %(cls)s.highlighter_class, "
- "or override %(cls)s.get_highlighter()." %
- {"cls": self.__class__.__name__}
+ "or override %(cls)s.get_highlighter()." % {"cls": self.__class__.__name__}
)
return self.highlighter_class
@@ -477,14 +479,17 @@ def to_representation(self, instance):
ret = super().to_representation(instance)
terms = self.get_terms(ret)
if terms:
- highlighter = self.get_highlighter()(terms, **{
- "html_tag": self.highlighter_html_tag,
- "css_class": self.highlighter_css_class,
- "max_length": self.highlighter_max_length
- })
+ highlighter = self.get_highlighter()(
+ terms,
+ **{
+ "html_tag": self.highlighter_html_tag,
+ "css_class": self.highlighter_css_class,
+ "max_length": self.highlighter_max_length,
+ },
+ )
document_field = self.get_document_field(instance)
if highlighter and document_field:
# Handle case where this data is None, but highlight expects it to be a string
- data_to_highlight = getattr(instance, self.highlighter_field or document_field) or ''
+ data_to_highlight = getattr(instance, self.highlighter_field or document_field) or ""
ret["highlighted"] = highlighter.highlight(data_to_highlight)
return ret
diff --git a/drf_haystack/utils.py b/drf_haystack/utils.py
index af8341e..af4b062 100644
--- a/drf_haystack/utils.py
+++ b/drf_haystack/utils.py
@@ -1,4 +1,3 @@
-import six
from copy import deepcopy
diff --git a/drf_haystack/viewsets.py b/drf_haystack/viewsets.py
index c17d0e6..b9940da 100644
--- a/drf_haystack/viewsets.py
+++ b/drf_haystack/viewsets.py
@@ -9,4 +9,5 @@ class HaystackViewSet(RetrieveModelMixin, ListModelMixin, ViewSetMixin, Haystack
The HaystackViewSet class provides the default ``list()`` and
``retrieve()`` actions with a haystack index as it's data source.
"""
+
pass
diff --git a/ez_setup.py b/ez_setup.py
index da497e6..3c36d41 100644
--- a/ez_setup.py
+++ b/ez_setup.py
@@ -13,12 +13,13 @@
This file can also be run as a script to install or upgrade setuptools.
"""
+
+import fnmatch
import os
import sys
-import time
-import fnmatch
-import tempfile
import tarfile
+import tempfile
+import time
from distutils import log
try:
@@ -38,19 +39,23 @@ def _python_cmd(*args):
def _python_cmd(*args):
args = (sys.executable,) + args
# quoting arguments if windows
- if sys.platform == 'win32':
+ if sys.platform == "win32":
+
def quote(arg):
- if ' ' in arg:
+ if " " in arg:
return '"%s"' % arg
return arg
+
args = [quote(arg) for arg in args]
return os.spawnl(os.P_WAIT, sys.executable, *args) == 0
+
DEFAULT_VERSION = "0.6.14"
DEFAULT_URL = "http://pypi.python.org/packages/source/d/distribute/"
SETUPTOOLS_FAKED_VERSION = "0.6c11"
-SETUPTOOLS_PKG_INFO = """\
+SETUPTOOLS_PKG_INFO = (
+ """\
Metadata-Version: 1.0
Name: setuptools
Version: %s
@@ -60,13 +65,15 @@ def quote(arg):
Author-email: xxx
License: xxx
Description: xxx
-""" % SETUPTOOLS_FAKED_VERSION
+"""
+ % SETUPTOOLS_FAKED_VERSION
+)
def _install(tarball):
# extracting the tarball
tmpdir = tempfile.mkdtemp()
- log.warn('Extracting in %s', tmpdir)
+ log.warn("Extracting in %s", tmpdir)
old_wd = os.getcwd()
try:
os.chdir(tmpdir)
@@ -77,13 +84,13 @@ def _install(tarball):
# going in the directory
subdir = os.path.join(tmpdir, os.listdir(tmpdir)[0])
os.chdir(subdir)
- log.warn('Now working in %s', subdir)
+ log.warn("Now working in %s", subdir)
# installing
- log.warn('Installing Distribute')
- if not _python_cmd('setup.py', 'install'):
- log.warn('Something went wrong during the installation.')
- log.warn('See the error message above.')
+ log.warn("Installing Distribute")
+ if not _python_cmd("setup.py", "install"):
+ log.warn("Something went wrong during the installation.")
+ log.warn("See the error message above.")
finally:
os.chdir(old_wd)
@@ -91,7 +98,7 @@ def _install(tarball):
def _build_egg(egg, tarball, to_dir):
# extracting the tarball
tmpdir = tempfile.mkdtemp()
- log.warn('Extracting in %s', tmpdir)
+ log.warn("Extracting in %s", tmpdir)
old_wd = os.getcwd()
try:
os.chdir(tmpdir)
@@ -102,73 +109,72 @@ def _build_egg(egg, tarball, to_dir):
# going in the directory
subdir = os.path.join(tmpdir, os.listdir(tmpdir)[0])
os.chdir(subdir)
- log.warn('Now working in %s', subdir)
+ log.warn("Now working in %s", subdir)
# building an egg
- log.warn('Building a Distribute egg in %s', to_dir)
- _python_cmd('setup.py', '-q', 'bdist_egg', '--dist-dir', to_dir)
+ log.warn("Building a Distribute egg in %s", to_dir)
+ _python_cmd("setup.py", "-q", "bdist_egg", "--dist-dir", to_dir)
finally:
os.chdir(old_wd)
# returning the result
log.warn(egg)
if not os.path.exists(egg):
- raise OSError('Could not build the egg.')
+ raise OSError("Could not build the egg.")
def _do_download(version, download_base, to_dir, download_delay):
- egg = os.path.join(to_dir, 'distribute-%s-py%d.%d.egg'
- % (version, sys.version_info[0], sys.version_info[1]))
+ egg = os.path.join(to_dir, "distribute-%s-py%d.%d.egg" % (version, sys.version_info[0], sys.version_info[1]))
if not os.path.exists(egg):
- tarball = download_setuptools(version, download_base,
- to_dir, download_delay)
+ tarball = download_setuptools(version, download_base, to_dir, download_delay)
_build_egg(egg, tarball, to_dir)
sys.path.insert(0, egg)
import setuptools
+
setuptools.bootstrap_install_from = egg
-def use_setuptools(version=DEFAULT_VERSION, download_base=DEFAULT_URL,
- to_dir=os.curdir, download_delay=15, no_fake=True):
+def use_setuptools(
+ version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir, download_delay=15, no_fake=True
+):
# making sure we use the absolute path
to_dir = os.path.abspath(to_dir)
- was_imported = 'pkg_resources' in sys.modules or \
- 'setuptools' in sys.modules
+ was_imported = "pkg_resources" in sys.modules or "setuptools" in sys.modules
try:
try:
import pkg_resources
- if not hasattr(pkg_resources, '_distribute'):
+
+ if not hasattr(pkg_resources, "_distribute"):
if not no_fake:
_fake_setuptools()
raise ImportError
except ImportError:
return _do_download(version, download_base, to_dir, download_delay)
try:
- pkg_resources.require("distribute>="+version)
+ pkg_resources.require("distribute>=" + version)
return
except pkg_resources.VersionConflict:
e = sys.exc_info()[1]
if was_imported:
sys.stderr.write(
- "The required version of distribute (>=%s) is not available,\n"
- "and can't be installed while this script is running. Please\n"
- "install a more recent version first, using\n"
- "'easy_install -U distribute'."
- "\n\n(Currently using %r)\n" % (version, e.args[0]))
+ "The required version of distribute (>=%s) is not available,\n"
+ "and can't be installed while this script is running. Please\n"
+ "install a more recent version first, using\n"
+ "'easy_install -U distribute'."
+ "\n\n(Currently using %r)\n" % (version, e.args[0])
+ )
sys.exit(2)
else:
- del pkg_resources, sys.modules['pkg_resources'] # reload ok
- return _do_download(version, download_base, to_dir,
- download_delay)
+ del pkg_resources, sys.modules["pkg_resources"] # reload ok
+ return _do_download(version, download_base, to_dir, download_delay)
except pkg_resources.DistributionNotFound:
- return _do_download(version, download_base, to_dir,
- download_delay)
+ return _do_download(version, download_base, to_dir, download_delay)
finally:
if not no_fake:
_create_fake_setuptools_pkg_info(to_dir)
-def download_setuptools(version=DEFAULT_VERSION, download_base=DEFAULT_URL,
- to_dir=os.curdir, delay=15):
+
+def download_setuptools(version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir, delay=15):
"""Download distribute from a specified location and return its filename
`version` should be a valid distribute version number that is available
@@ -203,13 +209,17 @@ def download_setuptools(version=DEFAULT_VERSION, download_base=DEFAULT_URL,
dst.close()
return os.path.realpath(saveto)
+
def _no_sandbox(function):
def __no_sandbox(*args, **kw):
try:
from setuptools.sandbox import DirectorySandbox
- if not hasattr(DirectorySandbox, '_old'):
+
+ if not hasattr(DirectorySandbox, "_old"):
+
def violation(*args):
pass
+
DirectorySandbox._old = DirectorySandbox._violation
DirectorySandbox._violation = violation
patched = True
@@ -227,47 +237,52 @@ def violation(*args):
return __no_sandbox
+
def _patch_file(path, content):
"""Will backup the file then patch it"""
existing_content = open(path).read()
if existing_content == content:
# already patched
- log.warn('Already patched.')
+ log.warn("Already patched.")
return False
- log.warn('Patching...')
+ log.warn("Patching...")
_rename_path(path)
- f = open(path, 'w')
+ f = open(path, "w")
try:
f.write(content)
finally:
f.close()
return True
+
_patch_file = _no_sandbox(_patch_file)
+
def _same_content(path, content):
return open(path).read() == content
+
def _rename_path(path):
- new_name = path + '.OLD.%s' % time.time()
- log.warn('Renaming %s into %s', path, new_name)
+ new_name = path + ".OLD.%s" % time.time()
+ log.warn("Renaming %s into %s", path, new_name)
os.rename(path, new_name)
return new_name
+
def _remove_flat_installation(placeholder):
if not os.path.isdir(placeholder):
- log.warn('Unkown installation at %s', placeholder)
+ log.warn("Unkown installation at %s", placeholder)
return False
found = False
for file in os.listdir(placeholder):
- if fnmatch.fnmatch(file, 'setuptools*.egg-info'):
+ if fnmatch.fnmatch(file, "setuptools*.egg-info"):
found = True
break
if not found:
- log.warn('Could not locate setuptools*.egg-info')
+ log.warn("Could not locate setuptools*.egg-info")
return
- log.warn('Removing elements out of the way...')
+ log.warn("Removing elements out of the way...")
pkg_info = os.path.join(placeholder, file)
if os.path.isdir(pkg_info):
patched = _patch_egg_dir(pkg_info)
@@ -275,169 +290,172 @@ def _remove_flat_installation(placeholder):
patched = _patch_file(pkg_info, SETUPTOOLS_PKG_INFO)
if not patched:
- log.warn('%s already patched.', pkg_info)
+ log.warn("%s already patched.", pkg_info)
return False
# now let's move the files out of the way
- for element in ('setuptools', 'pkg_resources.py', 'site.py'):
+ for element in ("setuptools", "pkg_resources.py", "site.py"):
element = os.path.join(placeholder, element)
if os.path.exists(element):
_rename_path(element)
else:
- log.warn('Could not find the %s element of the '
- 'Setuptools distribution', element)
+ log.warn("Could not find the %s element of the Setuptools distribution", element)
return True
+
_remove_flat_installation = _no_sandbox(_remove_flat_installation)
+
def _after_install(dist):
- log.warn('After install bootstrap.')
- placeholder = dist.get_command_obj('install').install_purelib
+ log.warn("After install bootstrap.")
+ placeholder = dist.get_command_obj("install").install_purelib
_create_fake_setuptools_pkg_info(placeholder)
+
def _create_fake_setuptools_pkg_info(placeholder):
if not placeholder or not os.path.exists(placeholder):
- log.warn('Could not find the install location')
+ log.warn("Could not find the install location")
return
- pyver = f'{sys.version_info[0]}.{sys.version_info[1]}'
- setuptools_file = 'setuptools-%s-py%s.egg-info' % \
- (SETUPTOOLS_FAKED_VERSION, pyver)
+ pyver = f"{sys.version_info[0]}.{sys.version_info[1]}"
+ setuptools_file = f"setuptools-{SETUPTOOLS_FAKED_VERSION}-py{pyver}.egg-info"
pkg_info = os.path.join(placeholder, setuptools_file)
if os.path.exists(pkg_info):
- log.warn('%s already exists', pkg_info)
+ log.warn("%s already exists", pkg_info)
return
- log.warn('Creating %s', pkg_info)
- f = open(pkg_info, 'w')
+ log.warn("Creating %s", pkg_info)
+ f = open(pkg_info, "w")
try:
f.write(SETUPTOOLS_PKG_INFO)
finally:
f.close()
- pth_file = os.path.join(placeholder, 'setuptools.pth')
- log.warn('Creating %s', pth_file)
- f = open(pth_file, 'w')
+ pth_file = os.path.join(placeholder, "setuptools.pth")
+ log.warn("Creating %s", pth_file)
+ f = open(pth_file, "w")
try:
f.write(os.path.join(os.curdir, setuptools_file))
finally:
f.close()
+
_create_fake_setuptools_pkg_info = _no_sandbox(_create_fake_setuptools_pkg_info)
+
def _patch_egg_dir(path):
# let's check if it's already patched
- pkg_info = os.path.join(path, 'EGG-INFO', 'PKG-INFO')
+ pkg_info = os.path.join(path, "EGG-INFO", "PKG-INFO")
if os.path.exists(pkg_info):
if _same_content(pkg_info, SETUPTOOLS_PKG_INFO):
- log.warn('%s already patched.', pkg_info)
+ log.warn("%s already patched.", pkg_info)
return False
_rename_path(path)
os.mkdir(path)
- os.mkdir(os.path.join(path, 'EGG-INFO'))
- pkg_info = os.path.join(path, 'EGG-INFO', 'PKG-INFO')
- f = open(pkg_info, 'w')
+ os.mkdir(os.path.join(path, "EGG-INFO"))
+ pkg_info = os.path.join(path, "EGG-INFO", "PKG-INFO")
+ f = open(pkg_info, "w")
try:
f.write(SETUPTOOLS_PKG_INFO)
finally:
f.close()
return True
+
_patch_egg_dir = _no_sandbox(_patch_egg_dir)
+
def _before_install():
- log.warn('Before install bootstrap.')
+ log.warn("Before install bootstrap.")
_fake_setuptools()
def _under_prefix(location):
- if 'install' not in sys.argv:
+ if "install" not in sys.argv:
return True
- args = sys.argv[sys.argv.index('install')+1:]
+ args = sys.argv[sys.argv.index("install") + 1 :]
for index, arg in enumerate(args):
- for option in ('--root', '--prefix'):
- if arg.startswith('%s=' % option):
- top_dir = arg.split('root=')[-1]
+ for option in ("--root", "--prefix"):
+ if arg.startswith("%s=" % option):
+ top_dir = arg.split("root=")[-1]
return location.startswith(top_dir)
elif arg == option:
if len(args) > index:
- top_dir = args[index+1]
+ top_dir = args[index + 1]
return location.startswith(top_dir)
- if arg == '--user' and USER_SITE is not None:
+ if arg == "--user" and USER_SITE is not None:
return location.startswith(USER_SITE)
return True
def _fake_setuptools():
- log.warn('Scanning installed packages')
+ log.warn("Scanning installed packages")
try:
import pkg_resources
except ImportError:
# we're cool
- log.warn('Setuptools or Distribute does not seem to be installed.')
+ log.warn("Setuptools or Distribute does not seem to be installed.")
return
ws = pkg_resources.working_set
try:
- setuptools_dist = ws.find(pkg_resources.Requirement.parse('setuptools',
- replacement=False))
+ setuptools_dist = ws.find(pkg_resources.Requirement.parse("setuptools", replacement=False))
except TypeError:
# old distribute API
- setuptools_dist = ws.find(pkg_resources.Requirement.parse('setuptools'))
+ setuptools_dist = ws.find(pkg_resources.Requirement.parse("setuptools"))
if setuptools_dist is None:
- log.warn('No setuptools distribution found')
+ log.warn("No setuptools distribution found")
return
# detecting if it was already faked
setuptools_location = setuptools_dist.location
- log.warn('Setuptools installation detected at %s', setuptools_location)
+ log.warn("Setuptools installation detected at %s", setuptools_location)
# if --root or --preix was provided, and if
# setuptools is not located in them, we don't patch it
if not _under_prefix(setuptools_location):
- log.warn('Not patching, --root or --prefix is installing Distribute'
- ' in another location')
+ log.warn("Not patching, --root or --prefix is installing Distribute in another location")
return
# let's see if its an egg
- if not setuptools_location.endswith('.egg'):
- log.warn('Non-egg installation')
+ if not setuptools_location.endswith(".egg"):
+ log.warn("Non-egg installation")
res = _remove_flat_installation(setuptools_location)
if not res:
return
else:
- log.warn('Egg installation')
- pkg_info = os.path.join(setuptools_location, 'EGG-INFO', 'PKG-INFO')
- if (os.path.exists(pkg_info) and
- _same_content(pkg_info, SETUPTOOLS_PKG_INFO)):
- log.warn('Already patched.')
+ log.warn("Egg installation")
+ pkg_info = os.path.join(setuptools_location, "EGG-INFO", "PKG-INFO")
+ if os.path.exists(pkg_info) and _same_content(pkg_info, SETUPTOOLS_PKG_INFO):
+ log.warn("Already patched.")
return
- log.warn('Patching...')
+ log.warn("Patching...")
# let's create a fake egg replacing setuptools one
res = _patch_egg_dir(setuptools_location)
if not res:
return
- log.warn('Patched done.')
+ log.warn("Patched done.")
_relaunch()
def _relaunch():
- log.warn('Relaunching...')
+ log.warn("Relaunching...")
# we have to relaunch the process
# pip marker to avoid a relaunch bug
- if sys.argv[:3] == ['-c', 'install', '--single-version-externally-managed']:
- sys.argv[0] = 'setup.py'
+ if sys.argv[:3] == ["-c", "install", "--single-version-externally-managed"]:
+ sys.argv[0] = "setup.py"
args = [sys.executable] + sys.argv
sys.exit(subprocess.call(args))
def _extractall(self, path=".", members=None):
"""Extract all members from the archive to the current working
- directory and set owner, modification time and permissions on
- directories afterwards. `path' specifies a different directory
- to extract to. `members' is optional and must be a subset of the
- list returned by getmembers().
+ directory and set owner, modification time and permissions on
+ directories afterwards. `path' specifies a different directory
+ to extract to. `members' is optional and must be a subset of the
+ list returned by getmembers().
"""
import copy
import operator
from tarfile import ExtractError
+
directories = []
if members is None:
@@ -448,17 +466,11 @@ def _extractall(self, path=".", members=None):
# Extract directories with a safe mode.
directories.append(tarinfo)
tarinfo = copy.copy(tarinfo)
- tarinfo.mode = 448 # decimal for oct 0700
+ tarinfo.mode = 448 # decimal for oct 0700
self.extract(tarinfo, path)
# Reverse sort directories.
- if sys.version_info < (2, 4):
- def sorter(dir1, dir2):
- return cmp(dir1.name, dir2.name)
- directories.sort(sorter)
- directories.reverse()
- else:
- directories.sort(key=operator.attrgetter('name'), reverse=True)
+ directories.sort(key=operator.attrgetter("name"), reverse=True)
# Set correct owner, mtime and filemode on directories.
for tarinfo in directories:
@@ -481,5 +493,5 @@ def main(argv, version=DEFAULT_VERSION):
_install(tarball)
-if __name__ == '__main__':
+if __name__ == "__main__":
main(sys.argv[1:])
diff --git a/ruff.toml b/ruff.toml
new file mode 100644
index 0000000..f537058
--- /dev/null
+++ b/ruff.toml
@@ -0,0 +1,13 @@
+line-length = 120
+
+target-version = "py38"
+
+[format]
+preview = true
+
+[lint]
+preview = true
+extend-select = [
+ # https://docs.astral.sh/ruff/rules/#isort-i
+ "I",
+]
diff --git a/setup.py b/setup.py
index f6924b0..30472e0 100644
--- a/setup.py
+++ b/setup.py
@@ -1,5 +1,5 @@
-import re
import os
+import re
try:
from setuptools import setup
diff --git a/tests/__init__.py b/tests/__init__.py
index 4b9fc75..7d620da 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -1,41 +1,46 @@
import os
-from django.core.exceptions import ImproperlyConfigured
+from importlib.util import find_spec
+
+import django
test_runner = None
old_config = None
os.environ["DJANGO_SETTINGS_MODULE"] = "tests.settings"
-import django
+
if hasattr(django, "setup"):
django.setup()
def _geospatial_support():
- try:
- import geopy
- from haystack.utils.geo import Point
- except (ImportError, ImproperlyConfigured):
- return False
- else:
- return True
+ return find_spec("geopy") and find_spec("haystack.utils.geo.Point")
+
+
geospatial_support = _geospatial_support()
def _restframework_version():
import rest_framework
+
return tuple(map(int, rest_framework.VERSION.split(".")))
+
+
restframework_version = _restframework_version()
def _elasticsearch_version():
import elasticsearch
+
return elasticsearch.VERSION
+
+
elasticsearch_version = _elasticsearch_version()
def setup():
from django.test.runner import DiscoverRunner
+
global test_runner
global old_config
diff --git a/tests/constants.py b/tests/constants.py
index bd4cc04..bb902cc 100644
--- a/tests/constants.py
+++ b/tests/constants.py
@@ -1,5 +1,6 @@
-import os
import json
+import os
+
from django.conf import settings
with open(os.path.join(settings.BASE_DIR, "mockapp", "fixtures", "mocklocation.json")) as f:
diff --git a/tests/mockapp/admin.py b/tests/mockapp/admin.py
index 8c38f3f..846f6b4 100644
--- a/tests/mockapp/admin.py
+++ b/tests/mockapp/admin.py
@@ -1,3 +1 @@
-from django.contrib import admin
-
# Register your models here.
diff --git a/tests/mockapp/migrations/0001_initial.py b/tests/mockapp/migrations/0001_initial.py
index 50ca3ff..9995553 100644
--- a/tests/mockapp/migrations/0001_initial.py
+++ b/tests/mockapp/migrations/0001_initial.py
@@ -1,26 +1,23 @@
-from django.db import models, migrations
+from django.db import migrations, models
class Migration(migrations.Migration):
-
- dependencies = [
- ]
+ dependencies = []
operations = [
migrations.CreateModel(
- name='MockLocation',
+ name="MockLocation",
fields=[
- ('id', models.AutoField(primary_key=True, auto_created=True, serialize=False, verbose_name='ID')),
- ('latitude', models.FloatField()),
- ('longitude', models.FloatField()),
- ('address', models.CharField(max_length=100)),
- ('city', models.CharField(max_length=30)),
- ('zip_code', models.CharField(max_length=10)),
- ('created', models.DateTimeField(auto_now_add=True)),
- ('updated', models.DateTimeField(auto_now=True)),
+ ("id", models.AutoField(primary_key=True, auto_created=True, serialize=False, verbose_name="ID")),
+ ("latitude", models.FloatField()),
+ ("longitude", models.FloatField()),
+ ("address", models.CharField(max_length=100)),
+ ("city", models.CharField(max_length=30)),
+ ("zip_code", models.CharField(max_length=10)),
+ ("created", models.DateTimeField(auto_now_add=True)),
+ ("updated", models.DateTimeField(auto_now=True)),
],
- options={
- },
+ options={},
bases=(models.Model,),
),
]
diff --git a/tests/mockapp/migrations/0002_mockperson.py b/tests/mockapp/migrations/0002_mockperson.py
index f2ec97b..ac6ab49 100644
--- a/tests/mockapp/migrations/0002_mockperson.py
+++ b/tests/mockapp/migrations/0002_mockperson.py
@@ -1,24 +1,22 @@
-from django.db import models, migrations
+from django.db import migrations, models
class Migration(migrations.Migration):
-
dependencies = [
- ('mockapp', '0001_initial'),
+ ("mockapp", "0001_initial"),
]
operations = [
migrations.CreateModel(
- name='MockPerson',
+ name="MockPerson",
fields=[
- ('id', models.AutoField(auto_created=True, verbose_name='ID', primary_key=True, serialize=False)),
- ('firstname', models.CharField(max_length=20)),
- ('lastname', models.CharField(max_length=20)),
- ('created', models.DateTimeField(auto_now_add=True)),
- ('updated', models.DateTimeField(auto_now=True)),
+ ("id", models.AutoField(auto_created=True, verbose_name="ID", primary_key=True, serialize=False)),
+ ("firstname", models.CharField(max_length=20)),
+ ("lastname", models.CharField(max_length=20)),
+ ("created", models.DateTimeField(auto_now_add=True)),
+ ("updated", models.DateTimeField(auto_now=True)),
],
- options={
- },
+ options={},
bases=(models.Model,),
),
]
diff --git a/tests/mockapp/migrations/0003_mockpet.py b/tests/mockapp/migrations/0003_mockpet.py
index d2ff986..6d12207 100644
--- a/tests/mockapp/migrations/0003_mockpet.py
+++ b/tests/mockapp/migrations/0003_mockpet.py
@@ -1,25 +1,23 @@
-from django.db import models, migrations
+from django.db import migrations, models
class Migration(migrations.Migration):
-
dependencies = [
- ('mockapp', '0001_initial'),
- ('mockapp', '0002_mockperson'),
+ ("mockapp", "0001_initial"),
+ ("mockapp", "0002_mockperson"),
]
operations = [
migrations.CreateModel(
- name='MockPet',
+ name="MockPet",
fields=[
- ('id', models.AutoField(auto_created=True, verbose_name='ID', primary_key=True, serialize=False)),
- ('name', models.CharField(max_length=20)),
- ('species', models.CharField(max_length=20)),
- ('created', models.DateTimeField(auto_now_add=True)),
- ('updated', models.DateTimeField(auto_now=True)),
+ ("id", models.AutoField(auto_created=True, verbose_name="ID", primary_key=True, serialize=False)),
+ ("name", models.CharField(max_length=20)),
+ ("species", models.CharField(max_length=20)),
+ ("created", models.DateTimeField(auto_now_add=True)),
+ ("updated", models.DateTimeField(auto_now=True)),
],
- options={
- },
+ options={},
bases=(models.Model,),
),
]
diff --git a/tests/mockapp/migrations/0004_load_fixtures.py b/tests/mockapp/migrations/0004_load_fixtures.py
index a2adfc5..120fb8f 100644
--- a/tests/mockapp/migrations/0004_load_fixtures.py
+++ b/tests/mockapp/migrations/0004_load_fixtures.py
@@ -1,7 +1,7 @@
import os
from django.core import serializers
-from django.db import models, migrations
+from django.db import migrations
def load_data(apps, schema_editor):
@@ -26,6 +26,7 @@ def load_data(apps, schema_editor):
for obj in objects:
obj.save()
+
def unload_data(apps, schema_editor):
"""
Unload fixtures for MockPerson, MockPet and MockLocation
@@ -41,7 +42,6 @@ def unload_data(apps, schema_editor):
class Migration(migrations.Migration):
-
dependencies = [
("mockapp", "0001_initial"),
("mockapp", "0002_mockperson"),
@@ -49,6 +49,4 @@ class Migration(migrations.Migration):
("mockapp", "0003_mockpet"),
]
- operations = [
- migrations.RunPython(load_data, reverse_code=unload_data)
- ]
+ operations = [migrations.RunPython(load_data, reverse_code=unload_data)]
diff --git a/tests/mockapp/migrations/0005_mockperson_birthdate.py b/tests/mockapp/migrations/0005_mockperson_birthdate.py
index 4f418b7..a2c4ca6 100644
--- a/tests/mockapp/migrations/0005_mockperson_birthdate.py
+++ b/tests/mockapp/migrations/0005_mockperson_birthdate.py
@@ -1,19 +1,19 @@
# Generated by Django 1.9.4 on 2016-04-16 07:05
from django.db import migrations, models
+
import tests.mockapp.models
class Migration(migrations.Migration):
-
dependencies = [
- ('mockapp', '0002_mockperson'),
+ ("mockapp", "0002_mockperson"),
]
operations = [
migrations.AddField(
- model_name='mockperson',
- name='birthdate',
+ model_name="mockperson",
+ name="birthdate",
field=models.DateField(default=tests.mockapp.models.get_random_date, null=True),
),
]
diff --git a/tests/mockapp/migrations/0006_mockallfield.py b/tests/mockapp/migrations/0006_mockallfield.py
index 704deec..9787a2c 100644
--- a/tests/mockapp/migrations/0006_mockallfield.py
+++ b/tests/mockapp/migrations/0006_mockallfield.py
@@ -1,27 +1,27 @@
# Generated by Django 1.11.9 on 2018-03-22 22:46
from django.db import migrations, models
+
import tests.mockapp.models
class Migration(migrations.Migration):
-
dependencies = [
- ('mockapp', '0004_load_fixtures'),
+ ("mockapp", "0004_load_fixtures"),
]
operations = [
migrations.CreateModel(
- name='MockAllField',
+ name="MockAllField",
fields=[
- ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
- ('charfield', models.CharField(max_length=100)),
- ('integerfield', models.IntegerField()),
- ('floatfield', models.FloatField()),
- ('decimalfield', models.DecimalField(decimal_places=2, max_digits=5)),
- ('boolfield', models.BooleanField(default=False)),
- ('datefield', models.DateField(default=tests.mockapp.models.get_random_date)),
- ('datetimefield', models.DateTimeField(default=tests.mockapp.models.get_random_datetime)),
+ ("id", models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
+ ("charfield", models.CharField(max_length=100)),
+ ("integerfield", models.IntegerField()),
+ ("floatfield", models.FloatField()),
+ ("decimalfield", models.DecimalField(decimal_places=2, max_digits=5)),
+ ("boolfield", models.BooleanField(default=False)),
+ ("datefield", models.DateField(default=tests.mockapp.models.get_random_date)),
+ ("datetimefield", models.DateTimeField(default=tests.mockapp.models.get_random_datetime)),
],
),
]
diff --git a/tests/mockapp/models.py b/tests/mockapp/models.py
index 53bc07e..1eaa608 100644
--- a/tests/mockapp/models.py
+++ b/tests/mockapp/models.py
@@ -1,7 +1,7 @@
-import pytz
from datetime import date, datetime, timedelta
-from random import randrange, randint
+from random import randint, randrange
+import pytz
from django.db import models
@@ -9,7 +9,7 @@ def get_random_date(start=date(1950, 1, 1), end=date.today()):
"""
:return a random date between `start` and `end`
"""
- delta = ((end - start).days * 24 * 60 * 60)
+ delta = (end - start).days * 24 * 60 * 60
return start + timedelta(seconds=randrange(delta))
@@ -17,12 +17,11 @@ def get_random_datetime(start=datetime(1950, 1, 1, 0, 0), end=datetime.today()):
"""
:return a random datetime
"""
- delta = ((end - start).total_seconds())
+ delta = (end - start).total_seconds()
return (start + timedelta(seconds=randint(0, int(delta)))).replace(tzinfo=pytz.UTC)
class MockLocation(models.Model):
-
latitude = models.FloatField()
longitude = models.FloatField()
address = models.CharField(max_length=100)
@@ -46,7 +45,6 @@ def coordinates(self):
class MockPerson(models.Model):
-
firstname = models.CharField(max_length=20)
lastname = models.CharField(max_length=20)
birthdate = models.DateField(null=True, default=get_random_date)
@@ -59,7 +57,6 @@ def __str__(self):
class MockPet(models.Model):
-
name = models.CharField(max_length=20)
species = models.CharField(max_length=20)
@@ -71,7 +68,6 @@ def __str__(self):
class MockAllField(models.Model):
-
charfield = models.CharField(max_length=100)
integerfield = models.IntegerField()
floatfield = models.FloatField()
diff --git a/tests/mockapp/search_indexes.py b/tests/mockapp/search_indexes.py
index 83a97d0..ef4e31c 100644
--- a/tests/mockapp/search_indexes.py
+++ b/tests/mockapp/search_indexes.py
@@ -1,11 +1,10 @@
from django.utils import timezone
from haystack import indexes
-from .models import MockLocation, MockPerson, MockPet, MockAllField
+from .models import MockAllField, MockLocation, MockPerson, MockPet
class MockLocationIndex(indexes.SearchIndex, indexes.Indexable):
-
text = indexes.CharField(document=True, use_template=True)
address = indexes.CharField(model_attr="address")
city = indexes.CharField(model_attr="city")
@@ -16,21 +15,16 @@ class MockLocationIndex(indexes.SearchIndex, indexes.Indexable):
@staticmethod
def prepare_autocomplete(obj):
- return " ".join((
- obj.address, obj.city, obj.zip_code
- ))
+ return " ".join((obj.address, obj.city, obj.zip_code))
def get_model(self):
return MockLocation
def index_queryset(self, using=None):
- return self.get_model().objects.filter(
- created__lte=timezone.now()
- )
+ return self.get_model().objects.filter(created__lte=timezone.now())
class MockPersonIndex(indexes.SearchIndex, indexes.Indexable):
-
text = indexes.CharField(document=True, use_template=True)
firstname = indexes.CharField(model_attr="firstname", faceted=True)
lastname = indexes.CharField(model_attr="lastname", faceted=True)
@@ -62,13 +56,10 @@ def get_model(self):
return MockPerson
def index_queryset(self, using=None):
- return self.get_model().objects.filter(
- created__lte=timezone.now()
- )
+ return self.get_model().objects.filter(created__lte=timezone.now())
class MockPetIndex(indexes.SearchIndex, indexes.Indexable):
-
text = indexes.CharField(document=True, use_template=True)
name = indexes.CharField(model_attr="name")
species = indexes.CharField(model_attr="species")
@@ -94,7 +85,6 @@ def get_model(self):
class MockAllFieldIndex(indexes.SearchIndex, indexes.Indexable):
-
text = indexes.CharField(document=True, use_template=False)
charfield = indexes.CharField(model_attr="charfield")
integerfield = indexes.IntegerField(model_attr="integerfield")
@@ -107,7 +97,7 @@ class MockAllFieldIndex(indexes.SearchIndex, indexes.Indexable):
@staticmethod
def prepare_multivaluefield(obj):
- return obj.charfield.split(' ', 1)
+ return obj.charfield.split(" ", 1)
def get_model(self):
return MockAllField
diff --git a/tests/mockapp/serializers.py b/tests/mockapp/serializers.py
index 9a78683..d84e89c 100644
--- a/tests/mockapp/serializers.py
+++ b/tests/mockapp/serializers.py
@@ -1,47 +1,46 @@
from datetime import datetime, timedelta
+
from rest_framework.serializers import HyperlinkedIdentityField
-from drf_haystack.serializers import HaystackSerializer, HaystackFacetSerializer, HighlighterMixin
-from .search_indexes import MockPersonIndex, MockLocationIndex
+from drf_haystack.serializers import HaystackFacetSerializer, HaystackSerializer, HighlighterMixin
+from .search_indexes import MockLocationIndex, MockPersonIndex
-class SearchSerializer(HaystackSerializer):
+class SearchSerializer(HaystackSerializer):
class Meta:
index_classes = [MockPersonIndex, MockLocationIndex]
fields = [
- "firstname", "lastname", "birthdate", "full_name", "text",
- "address", "city", "zip_code", "highlighted"
+ "firstname",
+ "lastname",
+ "birthdate",
+ "full_name",
+ "text",
+ "address",
+ "city",
+ "zip_code",
+ "highlighted",
]
class HighlighterSerializer(HighlighterMixin, HaystackSerializer):
-
highlighter_css_class = "my-highlighter-class"
highlighter_html_tag = "em"
class Meta:
index_classes = [MockPersonIndex, MockLocationIndex]
- fields = [
- "firstname", "lastname", "full_name",
- "address", "city", "zip_code", "coordinates"
- ]
+ fields = ["firstname", "lastname", "full_name", "address", "city", "zip_code", "coordinates"]
class MoreLikeThisSerializer(HaystackSerializer):
-
more_like_this = HyperlinkedIdentityField(view_name="search-person-mlt-more-like-this", read_only=True)
class Meta:
index_classes = [MockPersonIndex]
- fields = [
- "firstname", "lastname", "full_name",
- "autocomplete"
- ]
+ fields = ["firstname", "lastname", "full_name", "autocomplete"]
class MockPersonFacetSerializer(HaystackFacetSerializer):
-
serialize_objects = True
class Meta:
@@ -54,6 +53,6 @@ class Meta:
"start_date": datetime.now() - timedelta(days=3 * 365),
"end_date": datetime.now(),
"gap_by": "day",
- "gap_amount": 10
- }
+ "gap_amount": 10,
+ },
}
diff --git a/tests/mockapp/views.py b/tests/mockapp/views.py
index cb237de..6291177 100644
--- a/tests/mockapp/views.py
+++ b/tests/mockapp/views.py
@@ -1,14 +1,13 @@
-from rest_framework.pagination import PageNumberPagination, LimitOffsetPagination
+from rest_framework.pagination import LimitOffsetPagination, PageNumberPagination
-from drf_haystack.filters import HaystackFilter, HaystackBoostFilter, HaystackHighlightFilter, HaystackAutocompleteFilter, HaystackGEOSpatialFilter
-from drf_haystack.viewsets import HaystackViewSet
+from drf_haystack.filters import (
+ HaystackFilter,
+)
from drf_haystack.mixins import FacetMixin, MoreLikeThisMixin
+from drf_haystack.viewsets import HaystackViewSet
-from .models import MockPerson, MockLocation
-from .serializers import (
- SearchSerializer, HighlighterSerializer,
- MoreLikeThisSerializer, MockPersonFacetSerializer
-)
+from .models import MockPerson
+from .serializers import MockPersonFacetSerializer, MoreLikeThisSerializer, SearchSerializer
class BasicPageNumberPagination(PageNumberPagination):
diff --git a/tests/run_tests.py b/tests/run_tests.py
index 43df728..747ff90 100644
--- a/tests/run_tests.py
+++ b/tests/run_tests.py
@@ -1,7 +1,6 @@
#!/usr/bin/env python
-
import os
import sys
diff --git a/tests/settings.py b/tests/settings.py
index cb317a7..bcbe8b5 100644
--- a/tests/settings.py
+++ b/tests/settings.py
@@ -1,118 +1,103 @@
import os
-BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(os.path.dirname(__file__)), 'tests'))
-SECRET_KEY = 'NOBODY expects the Spanish Inquisition!'
+BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(os.path.dirname(__file__)), "tests"))
+
+SECRET_KEY = "NOBODY expects the Spanish Inquisition!"
DEBUG = True
ALLOWED_HOSTS = ["*"]
-DATABASES = {
- 'default': {
- 'ENGINE': 'django.db.backends.sqlite3',
- 'NAME': os.path.join(BASE_DIR, os.pardir, 'test.db')
- }
-}
+DATABASES = {"default": {"ENGINE": "django.db.backends.sqlite3", "NAME": os.path.join(BASE_DIR, os.pardir, "test.db")}}
INSTALLED_APPS = (
- 'django.contrib.auth',
- 'django.contrib.contenttypes',
- 'django.contrib.sessions',
- 'django.contrib.staticfiles',
-
- 'haystack',
- 'rest_framework',
-
- 'tests.mockapp',
+ "django.contrib.auth",
+ "django.contrib.contenttypes",
+ "django.contrib.sessions",
+ "django.contrib.staticfiles",
+ "haystack",
+ "rest_framework",
+ "tests.mockapp",
)
MIDDLEWARE_CLASSES = (
- 'django.contrib.sessions.middleware.SessionMiddleware',
- 'django.middleware.common.CommonMiddleware',
- 'django.middleware.csrf.CsrfViewMiddleware',
- 'django.contrib.auth.middleware.AuthenticationMiddleware',
- 'django.contrib.auth.middleware.SessionAuthenticationMiddleware',
- 'django.contrib.messages.middleware.MessageMiddleware',
- 'django.middleware.clickjacking.XFrameOptionsMiddleware',
+ "django.contrib.sessions.middleware.SessionMiddleware",
+ "django.middleware.common.CommonMiddleware",
+ "django.middleware.csrf.CsrfViewMiddleware",
+ "django.contrib.auth.middleware.AuthenticationMiddleware",
+ "django.contrib.auth.middleware.SessionAuthenticationMiddleware",
+ "django.contrib.messages.middleware.MessageMiddleware",
+ "django.middleware.clickjacking.XFrameOptionsMiddleware",
)
TEMPLATES = [
{
- 'BACKEND': 'django.template.backends.django.DjangoTemplates',
- 'OPTIONS': {'debug': True},
- 'APP_DIRS': True,
+ "BACKEND": "django.template.backends.django.DjangoTemplates",
+ "OPTIONS": {"debug": True},
+ "APP_DIRS": True,
},
]
-REST_FRAMEWORK = {
- 'DEFAULT_PERMISSION_CLASSES': (
- 'rest_framework.permissions.AllowAny',
- )
-}
+REST_FRAMEWORK = {"DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.AllowAny",)}
-ROOT_URLCONF = 'tests.urls'
-WSGI_APPLICATION = 'tests.wsgi.application'
-LANGUAGE_CODE = 'en-us'
-TIME_ZONE = 'UTC'
+ROOT_URLCONF = "tests.urls"
+WSGI_APPLICATION = "tests.wsgi.application"
+LANGUAGE_CODE = "en-us"
+TIME_ZONE = "UTC"
USE_I18N = True
USE_L10N = True
USE_TZ = True
-STATIC_URL = '/static/'
+STATIC_URL = "/static/"
HAYSTACK_CONNECTIONS = {
- 'default': {
- 'ENGINE': 'haystack.backends.elasticsearch_backend.ElasticsearchSearchEngine',
- 'URL': os.environ.get('ELASTICSEARCH_URL', 'http://localhost:9200/'),
- 'INDEX_NAME': 'drf-haystack-test',
- 'INCLUDE_SPELLING': True,
- 'TIMEOUT': 300,
+ "default": {
+ "ENGINE": "haystack.backends.elasticsearch_backend.ElasticsearchSearchEngine",
+ "URL": os.environ.get("ELASTICSEARCH_URL", "http://localhost:9200/"),
+ "INDEX_NAME": "drf-haystack-test",
+ "INCLUDE_SPELLING": True,
+ "TIMEOUT": 300,
},
}
-DEFAULT_LOG_DIR = os.path.join(BASE_DIR, 'logs')
+DEFAULT_LOG_DIR = os.path.join(BASE_DIR, "logs")
LOGGING = {
- 'version': 1,
- 'disable_existing_loggers': False,
- 'formatters': {
- 'standard': {
- 'format': '%(levelname)s %(asctime)s %(module)s %(process)d %(thread)d %(message)s'
- },
+ "version": 1,
+ "disable_existing_loggers": False,
+ "formatters": {
+ "standard": {"format": "%(levelname)s %(asctime)s %(module)s %(process)d %(thread)d %(message)s"},
},
- 'handlers': {
- 'console_handler': {
- 'level': 'DEBUG',
- 'class': 'logging.StreamHandler',
- 'formatter': 'standard'
- },
- 'file_handler': {
- 'level': 'DEBUG',
- 'class': 'logging.FileHandler',
- 'filename': os.path.join(DEFAULT_LOG_DIR, 'tests.log'),
+ "handlers": {
+ "console_handler": {"level": "DEBUG", "class": "logging.StreamHandler", "formatter": "standard"},
+ "file_handler": {
+ "level": "DEBUG",
+ "class": "logging.FileHandler",
+ "filename": os.path.join(DEFAULT_LOG_DIR, "tests.log"),
},
},
- 'loggers': {
- 'default': {
- 'handlers': ['file_handler'],
- 'level': 'INFO',
- 'propagate': True,
+ "loggers": {
+ "default": {
+ "handlers": ["file_handler"],
+ "level": "INFO",
+ "propagate": True,
},
- 'elasticsearch': {
- 'handlers': ['file_handler'],
- 'level': 'ERROR',
- 'propagate': True,
+ "elasticsearch": {
+ "handlers": ["file_handler"],
+ "level": "ERROR",
+ "propagate": True,
},
- 'elasticsearch.trace': {
- 'handlers': ['file_handler'],
- 'level': 'ERROR',
- 'propagate': True,
+ "elasticsearch.trace": {
+ "handlers": ["file_handler"],
+ "level": "ERROR",
+ "propagate": True,
},
},
}
try:
import elasticsearch
- if (2, ) <= elasticsearch.VERSION <= (3, ):
- HAYSTACK_CONNECTIONS['default'].update({
- 'ENGINE': 'haystack.backends.elasticsearch2_backend.Elasticsearch2SearchEngine'
+
+ if (2,) <= elasticsearch.VERSION <= (3,):
+ HAYSTACK_CONNECTIONS["default"].update({
+ "ENGINE": "haystack.backends.elasticsearch2_backend.Elasticsearch2SearchEngine"
})
-except ImportError as e:
- del HAYSTACK_CONNECTIONS['default'] # This will intentionally cause everything to break!
+except ImportError:
+ del HAYSTACK_CONNECTIONS["default"] # This will intentionally cause everything to break!
diff --git a/tests/test_filters.py b/tests/test_filters.py
index 4980dc1..8f25dcc 100644
--- a/tests/test_filters.py
+++ b/tests/test_filters.py
@@ -5,26 +5,26 @@
import json
from datetime import date, datetime, timedelta
-
from unittest import skipIf
from django.test import TestCase
-
-from rest_framework import status
-from rest_framework import serializers
+from rest_framework import serializers, status
from rest_framework.test import APIRequestFactory
-from drf_haystack.viewsets import HaystackViewSet
-from drf_haystack.serializers import HaystackSerializer, HaystackFacetSerializer
from drf_haystack.filters import (
- HaystackAutocompleteFilter, HaystackBoostFilter,
- HaystackFacetFilter, HaystackFilter,
- HaystackGEOSpatialFilter, HaystackHighlightFilter,
- HaystackOrderingFilter
+ HaystackAutocompleteFilter,
+ HaystackBoostFilter,
+ HaystackFacetFilter,
+ HaystackFilter,
+ HaystackGEOSpatialFilter,
+ HaystackHighlightFilter,
+ HaystackOrderingFilter,
)
from drf_haystack.mixins import FacetMixin
+from drf_haystack.serializers import HaystackFacetSerializer, HaystackSerializer
+from drf_haystack.viewsets import HaystackViewSet
-from . import geospatial_support, elasticsearch_version
+from . import elasticsearch_version, geospatial_support
from .constants import MOCKLOCATION_DATA_SET_SIZE, MOCKPERSON_DATA_SET_SIZE
from .mixins import WarningTestCaseMixin
from .mockapp.models import MockAllField, MockLocation, MockPerson
@@ -34,7 +34,6 @@
class HaystackFilterTestCase(TestCase):
-
fixtures = ["mockperson", "mockallfield"]
def setUp(self):
@@ -44,12 +43,8 @@ def setUp(self):
class Serializer1(HaystackSerializer):
class Meta:
index_classes = [MockPersonIndex]
- fields = ["text", "firstname", "lastname",
- "full_name", "birthdate", "autocomplete"]
- field_aliases = {
- "q": "autocomplete",
- "name": "full_name"
- }
+ fields = ["text", "firstname", "lastname", "full_name", "birthdate", "autocomplete"]
+ field_aliases = {"q": "autocomplete", "name": "full_name"}
class Serializer2(HaystackSerializer):
class Meta:
@@ -147,10 +142,9 @@ def test_filter_multiple_fields(self):
def test_filter_multiple_fields_OR_same_fields(self):
# Test filtering multiple fields for multiple values. The values should be OR'ed between
# same parameters, and AND'ed between them
- request = factory.get(path="/", data={
- "lastname": "Hickman,Hood",
- "firstname": "Walker,Bruno"
- }) # Should return 2 result
+ request = factory.get(
+ path="/", data={"lastname": "Hickman,Hood", "firstname": "Walker,Bruno"}
+ ) # Should return 2 result
response = self.view1.as_view(actions={"get": "list"})(request)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data), 2)
@@ -159,7 +153,9 @@ def test_filter_excluded_field(self):
request = factory.get(path="/", data={"lastname": "Hood"}, content_type="application/json")
response = self.view2.as_view(actions={"get": "list"})(request)
self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(len(response.data), MOCKPERSON_DATA_SET_SIZE) # Should return all results since, field is ignored
+ self.assertEqual(
+ len(response.data), MOCKPERSON_DATA_SET_SIZE
+ ) # Should return all results since, field is ignored
def test_filter_with_non_searched_excluded_field(self):
request = factory.get(path="/", data={"text": "John"}, content_type="application/json")
@@ -168,8 +164,9 @@ def test_filter_with_non_searched_excluded_field(self):
self.assertEqual(len(response.data), 3)
def test_filter_unicode_characters(self):
- request = factory.get(path="/", data={"firstname": "åsmund", "lastname": "sørensen"},
- content_type="application/json")
+ request = factory.get(
+ path="/", data={"firstname": "åsmund", "lastname": "sørensen"}, content_type="application/json"
+ )
response = self.view1.as_view(actions={"get": "list"})(request)
self.assertEqual(len(response.data), 1)
@@ -186,7 +183,9 @@ def test_filter_negated_field_with_lookup(self):
self.assertEqual(len(response.data), 99)
def test_filter_negated_field_with_other_field(self):
- request = factory.get(path="/", data={"firstname": "John", "lastname__not": "McClane"}, content_type="application/json")
+ request = factory.get(
+ path="/", data={"firstname": "John", "lastname__not": "McClane"}, content_type="application/json"
+ )
response = self.view1.as_view(actions={"get": "list"})(request)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data), 2)
@@ -213,14 +212,12 @@ def test_filter_range_integerfield(self):
class HaystackAutocompleteFilterTestCase(TestCase):
-
fixtures = ["mockperson"]
def setUp(self):
MockPersonIndex().reindex()
class Serializer(HaystackSerializer):
-
class Meta:
index_classes = [MockPersonIndex]
fields = ["text", "firstname", "lastname", "autocomplete"]
@@ -251,8 +248,9 @@ def test_filter_autocomplete_multiple_terms(self):
self.assertEqual(len(response.data), 2)
def test_filter_autocomplete_multiple_parameters(self):
- request = factory.get(path="/", data={"autocomplete": "jer fowler", "firstname": "jeremy"},
- content_type="application/json")
+ request = factory.get(
+ path="/", data={"autocomplete": "jer fowler", "firstname": "jeremy"}, content_type="application/json"
+ )
response = self.view.as_view(actions={"get": "list"})(request)
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -264,18 +262,19 @@ def test_filter_autocomplete_single_field_OR(self):
@skipIf(not geospatial_support, "Skipped due to lack of GEO spatial features")
class HaystackGEOSpatialFilterTestCase(TestCase):
-
fixtures = ["mocklocation"]
def setUp(self):
MockLocationIndex().reindex()
class Serializer(HaystackSerializer):
-
class Meta:
index_classes = [MockLocationIndex]
fields = [
- "text", "address", "city", "zip_code",
+ "text",
+ "address",
+ "city",
+ "zip_code",
"coordinates",
]
@@ -304,22 +303,19 @@ def test_filter_dwithin_without_range_unit(self):
self.assertEqual(len(response.data), MOCKLOCATION_DATA_SET_SIZE)
def test_filter_dwithin_invalid_params(self):
- request = factory.get(path="/", data={"from": "i am not numeric,10.739370", "km": 1}, content_type="application/json")
- self.assertRaises(
- ValueError,
- self.view.as_view(actions={"get": "list"}), request
+ request = factory.get(
+ path="/", data={"from": "i am not numeric,10.739370", "km": 1}, content_type="application/json"
)
+ self.assertRaises(ValueError, self.view.as_view(actions={"get": "list"}), request)
class HaystackHighlightFilterTestCase(TestCase):
-
fixtures = ["mockperson"]
def setUp(self):
MockPersonIndex().reindex()
class Serializer(HaystackSerializer):
-
class Meta:
index_classes = [MockPersonIndex]
fields = ["firstname", "lastname"]
@@ -334,28 +330,23 @@ class ViewSet(HaystackViewSet):
def tearDown(self):
MockPersonIndex().clear()
- @skipIf(not elasticsearch_version < (2, ), "Highlighting is not yet supported for the Elasticsearch2 backend")
+ @skipIf(not elasticsearch_version < (2,), "Highlighting is not yet supported for the Elasticsearch2 backend")
def test_filter_highlighter_filter(self):
request = factory.get(path="/", data={"firstname": "jeremy"}, content_type="application/json")
response = self.view.as_view(actions={"get": "list"})(request)
response.render()
for result in json.loads(response.content.decode()):
self.assertTrue("highlighted" in result)
- self.assertEqual(
- result["highlighted"],
- " ".join(("Jeremy", "%s\n" % result["lastname"]))
- )
+ self.assertEqual(result["highlighted"], " ".join(("Jeremy", "%s\n" % result["lastname"])))
class HaystackBoostFilterTestCase(TestCase):
-
fixtures = ["mockperson"]
def setUp(self):
MockPersonIndex().reindex()
class Serializer(HaystackSerializer):
-
class Meta:
index_classes = [MockPersonIndex]
fields = ["firstname", "lastname"]
@@ -409,8 +400,7 @@ def test_filter_boost_invalid_non_numeric(self):
self.fail("Did not raise ValueError when called with a non-numeric boost value.")
except ValueError as e:
self.assertEqual(
- str(e),
- "Cannot convert boost to float value. Make sure to provide a numerical boost value."
+ str(e), "Cannot convert boost to float value. Make sure to provide a numerical boost value."
)
def test_filter_boost_invalid_malformed_query_params(self):
@@ -421,26 +411,22 @@ def test_filter_boost_invalid_malformed_query_params(self):
except ValueError as e:
self.assertEqual(
str(e),
- "Cannot convert the '%s' query parameter to a valid boost filter."
- % HaystackBoostFilter.query_param
+ "Cannot convert the '%s' query parameter to a valid boost filter." % HaystackBoostFilter.query_param,
)
class HaystackFacetFilterTestCase(WarningTestCaseMixin, TestCase):
-
fixtures = ["mockperson"]
def setUp(self):
MockPersonIndex().reindex()
class FacetSerializer1(HaystackFacetSerializer):
-
class Meta:
index_classes = [MockPersonIndex]
fields = ["firstname", "lastname", "created"]
class FacetSerializer2(HaystackFacetSerializer):
-
class Meta:
index_classes = [MockPersonIndex]
fields = ["firstname", "lastname", "created"]
@@ -451,8 +437,8 @@ class Meta:
"start_date": datetime.now() - timedelta(days=3 * 365),
"end_date": datetime.now(),
"gap_by": "day",
- "gap_amount": 10
- }
+ "gap_amount": 10,
+ },
}
class ViewSet1(FacetMixin, HaystackViewSet):
@@ -483,19 +469,18 @@ def test_filter_facet_serializer_no_field_options_missing_required_query_paramet
request = factory.get("/", data={"created": "start_date:Oct 3rd 2015"}, content_type="application/json")
try:
self.view1.as_view(actions={"get": "facets"})(request)
- self.fail("Did not raise ValueError when called without all required "
- "attributes and no default field_options is set.")
- except ValueError as e:
- self.assertEqual(
- str(e),
- "Date faceting requires at least 'start_date', 'end_date' and 'gap_by' to be set."
+ self.fail(
+ "Did not raise ValueError when called without all required "
+ "attributes and no default field_options is set."
)
+ except ValueError as e:
+ self.assertEqual(str(e), "Date faceting requires at least 'start_date', 'end_date' and 'gap_by' to be set.")
def test_filter_facet_no_field_options_valid_required_query_parameters(self):
request = factory.get(
"/",
data={"created": "start_date:Jan 1th 2010,end_date:Dec 31th 2020,gap_by:month,gap_amount:1"},
- content_type="application/json"
+ content_type="application/json",
)
response = self.view1.as_view(actions={"get": "facets"})(request)
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -511,7 +496,6 @@ def test_filter_facet_warn_on_inproperly_formatted_token(self):
class OrderedHaystackViewSetTestCase(TestCase):
-
fixtures = ["mockallfield"]
def setUp(self):
@@ -519,8 +503,7 @@ def setUp(self):
class Serializer(HaystackSerializer):
class Meta:
- fields = ("charfield", "integerfield", "floatfield",
- "decimalfield", "boolfield")
+ fields = ("charfield", "integerfield", "floatfield", "decimalfield", "boolfield")
index_classes = [MockAllFieldIndex]
class ViewSet1(HaystackViewSet):
@@ -552,7 +535,7 @@ def test_viewset_default_ordering(self):
self.assertEqual(
[result["integerfield"] for result in content],
- list(MockAllField.objects.values_list("integerfield", flat=True).order_by("integerfield"))
+ list(MockAllField.objects.values_list("integerfield", flat=True).order_by("integerfield")),
)
def test_viewset_default_reverse_ordering(self):
@@ -564,7 +547,7 @@ def test_viewset_default_reverse_ordering(self):
self.assertEqual(
[result["integerfield"] for result in content],
- list(MockAllField.objects.values_list("integerfield", flat=True).order_by("-integerfield"))
+ list(MockAllField.objects.values_list("integerfield", flat=True).order_by("-integerfield")),
)
def test_viewset_order_by_single_query_param(self):
@@ -576,7 +559,7 @@ def test_viewset_order_by_single_query_param(self):
self.assertEqual(
[result["integerfield"] for result in content],
- list(MockAllField.objects.values_list("integerfield", flat=True).order_by("integerfield"))
+ list(MockAllField.objects.values_list("integerfield", flat=True).order_by("integerfield")),
)
def test_viewset_order_by_multiple_query_params(self):
@@ -588,5 +571,5 @@ def test_viewset_order_by_multiple_query_params(self):
self.assertEqual(
[result["integerfield"] for result in content],
- list(MockAllField.objects.values_list("integerfield", flat=True).order_by("integerfield", "boolfield"))
+ list(MockAllField.objects.values_list("integerfield", flat=True).order_by("integerfield", "boolfield")),
)
diff --git a/tests/test_serializers.py b/tests/test_serializers.py
index 2b252df..ff05a44 100644
--- a/tests/test_serializers.py
+++ b/tests/test_serializers.py
@@ -6,29 +6,30 @@
import json
from datetime import datetime, timedelta
-import six
-from django.urls import path, include
from django.core.exceptions import ImproperlyConfigured
from django.http import QueryDict
-from django.test import TestCase, SimpleTestCase, override_settings
+from django.test import SimpleTestCase, TestCase, override_settings
+from django.urls import include, path
from haystack.query import SearchQuerySet
-
from rest_framework import serializers
from rest_framework.fields import CharField, IntegerField
from rest_framework.routers import DefaultRouter
from rest_framework.test import APIRequestFactory, APITestCase
from drf_haystack import fields
+from drf_haystack.mixins import FacetMixin, MoreLikeThisMixin
from drf_haystack.serializers import (
- HighlighterMixin, HaystackSerializer,
- HaystackSerializerMixin, HaystackFacetSerializer,
- HaystackSerializerMeta)
+ HaystackFacetSerializer,
+ HaystackSerializer,
+ HaystackSerializerMeta,
+ HaystackSerializerMixin,
+ HighlighterMixin,
+)
from drf_haystack.viewsets import HaystackViewSet
-from drf_haystack.mixins import MoreLikeThisMixin, FacetMixin
from .mixins import WarningTestCaseMixin
-from .mockapp.models import MockPerson, MockAllField
-from .mockapp.search_indexes import MockPersonIndex, MockPetIndex, MockAllFieldIndex
+from .mockapp.models import MockAllField, MockPerson
+from .mockapp.search_indexes import MockAllFieldIndex, MockPersonIndex, MockPetIndex
factory = APIRequestFactory()
@@ -51,7 +52,6 @@ class Meta:
# Faceting stuff
class SearchPersonFSerializer(HaystackSerializer):
-
class Meta:
index_classes = [MockPersonIndex]
fields = ["firstname", "lastname", "full_name"]
@@ -70,8 +70,8 @@ class Meta:
"start_date": datetime.now() - timedelta(days=10 * 365),
"end_date": datetime.now(),
"gap_by": "month",
- "gap_amount": 1
- }
+ "gap_amount": 1,
+ },
}
@@ -87,13 +87,10 @@ class Meta:
router.register("search-person-mlt", viewset=SearchPersonMLTViewSet, basename="search-person-mlt")
router.register("search-person-facet", viewset=SearchPersonFacetViewSet, basename="search-person-facet")
-urlpatterns = [
- path(r"^", include(router.urls))
-]
+urlpatterns = [path(r"^", include(router.urls))]
class HaystackSerializerTestCase(WarningTestCaseMixin, TestCase):
-
fixtures = ["mockperson", "mockpet"]
def setUp(self):
@@ -101,7 +98,6 @@ def setUp(self):
MockPetIndex().reindex()
class Serializer1(HaystackSerializer):
-
integer_field = serializers.IntegerField()
city = serializers.SerializerMethodField()
@@ -116,20 +112,17 @@ def get_city(self, instance):
return "Declared overriding field"
class Serializer2(HaystackSerializer):
-
class Meta:
index_classes = [MockPersonIndex]
exclude = ["firstname"]
class Serializer3(HaystackSerializer):
-
class Meta:
index_classes = [MockPersonIndex]
fields = ["text", "firstname", "lastname", "autocomplete"]
ignore_fields = ["autocomplete"]
class Serializer7(HaystackSerializer):
-
class Meta:
index_classes = [MockPetIndex]
@@ -150,16 +143,19 @@ def tearDown(self):
def test_serializer_raise_without_meta_class(self):
try:
+
class Serializer(HaystackSerializer):
pass
+
self.fail("Did not fail when defining a Serializer without a Meta class")
except ImproperlyConfigured as e:
self.assertEqual(str(e), "%s must implement a Meta class or have the property _abstract" % "Serializer")
def test_serializer_gets_default_instance(self):
serializer = self.serializer1(instance=None)
- self.assertIsInstance(serializer.instance, SearchQuerySet,
- "Did not get default instance of type SearchQuerySet")
+ self.assertIsInstance(
+ serializer.instance, SearchQuerySet, "Did not get default instance of type SearchQuerySet"
+ )
def test_serializer_get_fields(self):
obj = SearchQuerySet().filter(lastname="Foreman")[0]
@@ -205,11 +201,10 @@ def test_serializer_declared_field_overrides(self):
obj = SearchQuerySet().filter(lastname="Foreman")[0]
serializer = self.serializer1(instance=obj)
- self.assertEqual(serializer.data['city'], "Declared overriding field")
+ self.assertEqual(serializer.data["city"], "Declared overriding field")
class HaystackSerializerAllFieldsTestCase(TestCase):
-
fixtures = ["mockallfield"]
def setUp(self):
@@ -218,28 +213,34 @@ def setUp(self):
class Serializer1(HaystackSerializer):
class Meta:
index_classes = [MockAllFieldIndex]
- fields = ["charfield", "integerfield", "floatfield",
- "decimalfield", "boolfield", "datefield",
- "datetimefield", "multivaluefield"]
+ fields = [
+ "charfield",
+ "integerfield",
+ "floatfield",
+ "decimalfield",
+ "boolfield",
+ "datefield",
+ "datetimefield",
+ "multivaluefield",
+ ]
self.serializer1 = Serializer1
def test_serialize_field_is_correct_type(self):
- obj = SearchQuerySet().models(MockAllField).latest('datetimefield')
+ obj = SearchQuerySet().models(MockAllField).latest("datetimefield")
serializer = self.serializer1(instance=obj, many=False)
- self.assertIsInstance(serializer.fields['charfield'], fields.HaystackCharField)
- self.assertIsInstance(serializer.fields['integerfield'], fields.HaystackIntegerField)
- self.assertIsInstance(serializer.fields['floatfield'], fields.HaystackFloatField)
- self.assertIsInstance(serializer.fields['decimalfield'], fields.HaystackDecimalField)
- self.assertIsInstance(serializer.fields['boolfield'], fields.HaystackBooleanField)
- self.assertIsInstance(serializer.fields['datefield'], fields.HaystackDateField)
- self.assertIsInstance(serializer.fields['datetimefield'], fields.HaystackDateTimeField)
- self.assertIsInstance(serializer.fields['multivaluefield'], fields.HaystackMultiValueField)
+ self.assertIsInstance(serializer.fields["charfield"], fields.HaystackCharField)
+ self.assertIsInstance(serializer.fields["integerfield"], fields.HaystackIntegerField)
+ self.assertIsInstance(serializer.fields["floatfield"], fields.HaystackFloatField)
+ self.assertIsInstance(serializer.fields["decimalfield"], fields.HaystackDecimalField)
+ self.assertIsInstance(serializer.fields["boolfield"], fields.HaystackBooleanField)
+ self.assertIsInstance(serializer.fields["datefield"], fields.HaystackDateField)
+ self.assertIsInstance(serializer.fields["datetimefield"], fields.HaystackDateTimeField)
+ self.assertIsInstance(serializer.fields["multivaluefield"], fields.HaystackMultiValueField)
class HaystackSerializerMultipleIndexTestCase(WarningTestCaseMixin, TestCase):
-
fixtures = ["mockperson", "mockpet"]
def setUp(self):
@@ -259,6 +260,7 @@ class Serializer2(HaystackSerializer):
"""
Multiple index serializer with declared fields
"""
+
_MockPersonIndex__hair_color = serializers.SerializerMethodField()
extra = serializers.SerializerMethodField()
@@ -280,9 +282,7 @@ class Serializer3(HaystackSerializer):
class Meta:
index_classes = [MockPersonIndex, MockPetIndex]
exclude = ["firstname"]
- index_aliases = {
- 'mockapp.MockPersonIndex': 'People'
- }
+ index_aliases = {"mockapp.MockPersonIndex": "People"}
class ViewSet1(HaystackViewSet):
serializer_class = Serializer1
@@ -344,7 +344,6 @@ def test_serializer_multiple_index_declared_fields(self):
class HaystackSerializerHighlighterMixinTestCase(WarningTestCaseMixin, TestCase):
-
fixtures = ["mockperson"]
def setUp(self):
@@ -382,10 +381,13 @@ def test_serializer_highlighting(self):
self.assertTrue("highlighted" in result)
self.assertEqual(
result["highlighted"],
- " ".join(('<{tag} class="{css_class}">Jeremy{tag}>'.format(
- tag=self.view1.serializer_class.highlighter_html_tag,
- css_class=self.view1.serializer_class.highlighter_css_class
- ), "%s" % "is a nice chap!"))
+ " ".join((
+ '<{tag} class="{css_class}">Jeremy{tag}>'.format(
+ tag=self.view1.serializer_class.highlighter_html_tag,
+ css_class=self.view1.serializer_class.highlighter_css_class,
+ ),
+ "%s" % "is a nice chap!",
+ )),
)
def test_serializer_highlighter_raise_no_highlighter_class(self):
@@ -397,13 +399,12 @@ def test_serializer_highlighter_raise_no_highlighter_class(self):
self.assertEqual(
str(e),
"%(cls)s is missing a highlighter_class. Define %(cls)s.highlighter_class, "
- "or override %(cls)s.get_highlighter()." % {"cls": self.view2.serializer_class.__name__}
+ "or override %(cls)s.get_highlighter()." % {"cls": self.view2.serializer_class.__name__},
)
@override_settings(ROOT_URLCONF="tests.test_serializers")
class HaystackSerializerMoreLikeThisTestCase(APITestCase):
-
fixtures = ["mockperson"]
def setUp(self):
@@ -414,33 +415,28 @@ def tearDown(self):
def test_serializer_more_like_this_link(self):
response = self.client.get(
- path="/search-person-mlt/",
- data={"firstname": "odysseus", "lastname": "cooley"},
- format="json"
+ path="/search-person-mlt/", data={"firstname": "odysseus", "lastname": "cooley"}, format="json"
)
self.assertEqual(
response.data,
- [{
- "lastname": "Cooley",
- "full_name": "Odysseus Cooley",
- "firstname": "Odysseus",
- "more_like_this": "http://testserver/search-person-mlt/18/more-like-this/"
- }]
+ [
+ {
+ "lastname": "Cooley",
+ "full_name": "Odysseus Cooley",
+ "firstname": "Odysseus",
+ "more_like_this": "http://testserver/search-person-mlt/18/more-like-this/",
+ }
+ ],
)
@override_settings(ROOT_URLCONF="tests.test_serializers")
class HaystackFacetSerializerTestCase(TestCase):
-
fixtures = ["mockperson"]
def setUp(self):
MockPersonIndex().reindex()
- self.response = self.client.get(
- path="/search-person-facet/facets/",
- data={},
- format="json"
- )
+ self.response = self.client.get(path="/search-person-facet/facets/", data={}, format="json")
def tearDown(self):
MockPersonIndex().clear()
@@ -459,8 +455,9 @@ def is_paginated_facet_response(response):
Returns True if the response.data seems like a faceted result.
Only works for responses created with the test client.
"""
- return "objects" in response.data and \
- all([k in response.data["objects"] for k in ("count", "next", "previous", "results")])
+ return "objects" in response.data and all([
+ k in response.data["objects"] for k in ("count", "next", "previous", "results")
+ ])
def test_serializer_facet_top_level_structure(self):
for key in ("fields", "dates", "queries"):
@@ -476,17 +473,18 @@ def test_serializer_facet_field_result(self):
self.assertTrue({"text", "count", "narrow_url"} <= set(firstname))
self.assertEqual(
firstname["narrow_url"],
- self.build_absolute_uri("/search-person-facet/facets/?selected_facets=firstname_exact%3A{term}".format(
- term=firstname["text"]))
+ self.build_absolute_uri(
+ "/search-person-facet/facets/?selected_facets=firstname_exact%3A{term}".format(term=firstname["text"])
+ ),
)
lastname = fields["lastname"][0]
self.assertTrue({"text", "count", "narrow_url"} <= set(lastname))
self.assertEqual(
lastname["narrow_url"],
- self.build_absolute_uri("/search-person-facet/facets/?selected_facets=lastname_exact%3A{term}".format(
- term=lastname["text"]
- ))
+ self.build_absolute_uri(
+ "/search-person-facet/facets/?selected_facets=lastname_exact%3A{term}".format(term=lastname["text"])
+ ),
)
def test_serializer_facet_date_result(self):
@@ -500,7 +498,9 @@ def test_serializer_facet_date_result(self):
self.assertEqual(created["count"], 100)
self.assertEqual(
created["narrow_url"],
- self.build_absolute_uri("/search-person-facet/facets/?selected_facets=created_exact%3A2015-05-01+00%3A00%3A00")
+ self.build_absolute_uri(
+ "/search-person-facet/facets/?selected_facets=created_exact%3A2015-05-01+00%3A00%3A00"
+ ),
)
def test_serializer_facet_queries_result(self):
@@ -511,7 +511,7 @@ def test_serializer_facet_narrow(self):
response = self.client.get(
path="/search-person-facet/facets/",
data=QueryDict("selected_facets=firstname_exact:John&selected_facets=lastname_exact:McClane"),
- format="json"
+ format="json",
)
self.assertEqual(response.data["queries"], {})
@@ -522,8 +522,10 @@ def test_serializer_facet_narrow(self):
self.assertEqual(response.data["fields"]["firstname"][0]["count"], 1)
self.assertEqual(
response.data["fields"]["firstname"][0]["narrow_url"],
- self.build_absolute_uri("/search-person-facet/facets/?selected_facets=firstname_exact%3AJohn"
- "&selected_facets=lastname_exact%3AMcClane")
+ self.build_absolute_uri(
+ "/search-person-facet/facets/?selected_facets=firstname_exact%3AJohn"
+ "&selected_facets=lastname_exact%3AMcClane"
+ ),
)
self.assertEqual(len(response.data["fields"]["lastname"]), 1)
@@ -531,8 +533,10 @@ def test_serializer_facet_narrow(self):
self.assertEqual(response.data["fields"]["lastname"][0]["count"], 1)
self.assertEqual(
response.data["fields"]["lastname"][0]["narrow_url"],
- self.build_absolute_uri("/search-person-facet/facets/?selected_facets=firstname_exact%3AJohn"
- "&selected_facets=lastname_exact%3AMcClane")
+ self.build_absolute_uri(
+ "/search-person-facet/facets/?selected_facets=firstname_exact%3AJohn"
+ "&selected_facets=lastname_exact%3AMcClane"
+ ),
)
self.assertTrue("created" in response.data["dates"])
@@ -541,22 +545,26 @@ def test_serializer_facet_narrow(self):
self.assertEqual(response.data["dates"]["created"][0]["count"], 1)
self.assertEqual(
response.data["dates"]["created"][0]["narrow_url"],
- self.build_absolute_uri("/search-person-facet/facets/?selected_facets=created_exact%3A2015-05-01+00%3A00%3A00"
- "&selected_facets=firstname_exact%3AJohn&selected_facets=lastname_exact%3AMcClane"
- )
+ self.build_absolute_uri(
+ "/search-person-facet/facets/?selected_facets=created_exact%3A2015-05-01+00%3A00%3A00"
+ "&selected_facets=firstname_exact%3AJohn&selected_facets=lastname_exact%3AMcClane"
+ ),
)
def test_serializer_raise_without_meta_class(self):
try:
+
class FacetSerializer(HaystackFacetSerializer):
pass
+
self.fail("Did not fail when defining a Serializer without a Meta class")
except ImproperlyConfigured as e:
- self.assertEqual(str(e), "%s must implement a Meta class or have the property _abstract" % "FacetSerializer")
+ self.assertEqual(
+ str(e), "%s must implement a Meta class or have the property _abstract" % "FacetSerializer"
+ )
class HaystackSerializerMixinTestCase(WarningTestCaseMixin, TestCase):
-
fixtures = ["mockperson"]
def setUp(self):
@@ -565,12 +573,14 @@ def setUp(self):
class MockPersonSerializer(serializers.ModelSerializer):
class Meta:
model = MockPerson
- fields = ('id', 'firstname', 'lastname', 'created', 'updated')
- read_only_fields = ('created', 'updated')
+ fields = ("id", "firstname", "lastname", "created", "updated")
+ read_only_fields = ("created", "updated")
class Serializer1(HaystackSerializerMixin, MockPersonSerializer):
class Meta(MockPersonSerializer.Meta):
- search_fields = ['text', ]
+ search_fields = [
+ "text",
+ ]
class Viewset1(HaystackViewSet):
serializer_class = Serializer1
@@ -586,18 +596,19 @@ def test_serializer_mixin(self):
serializer = self.serializer1(instance=objs, many=True)
self.assertEqual(
json.loads(json.dumps(serializer.data)),
- [{
- "id": 1,
- "firstname": "Abel",
- "lastname": "Foreman",
- "created": "2015-05-19T10:48:08.686000Z",
- "updated": "2016-04-24T16:02:59.378000Z"
- }]
+ [
+ {
+ "id": 1,
+ "firstname": "Abel",
+ "lastname": "Foreman",
+ "created": "2015-05-19T10:48:08.686000Z",
+ "updated": "2016-04-24T16:02:59.378000Z",
+ }
+ ],
)
class HaystackMultiSerializerTestCase(WarningTestCaseMixin, TestCase):
-
fixtures = ["mockperson", "mockpet"]
def setUp(self):
@@ -607,19 +618,16 @@ def setUp(self):
class MockPersonSerializer(HaystackSerializer):
class Meta:
index_classes = [MockPersonIndex]
- fields = ('text', 'firstname', 'lastname', 'description')
+ fields = ("text", "firstname", "lastname", "description")
class MockPetSerializer(HaystackSerializer):
class Meta:
index_classes = [MockPetIndex]
- exclude = ('description', 'autocomplete')
+ exclude = ("description", "autocomplete")
class Serializer1(HaystackSerializer):
class Meta:
- serializers = {
- MockPersonIndex: MockPersonSerializer,
- MockPetIndex: MockPetSerializer
- }
+ serializers = {MockPersonIndex: MockPersonSerializer, MockPetIndex: MockPetSerializer}
self.serializer1 = Serializer1
@@ -632,23 +640,19 @@ def test_multi_serializer(self):
serializer = self.serializer1(instance=objs, many=True)
self.assertEqual(
json.loads(json.dumps(serializer.data)),
- [{
- "has_rabies": True,
- "text": "Zane",
- "name": "Zane",
- "species": "Dog"
- },
- {
- "text": "Zane Griffith\n",
- "firstname": "Zane",
- "lastname": "Griffith",
- "description": "Zane is a nice chap!"
- }]
+ [
+ {"has_rabies": True, "text": "Zane", "name": "Zane", "species": "Dog"},
+ {
+ "text": "Zane Griffith\n",
+ "firstname": "Zane",
+ "lastname": "Griffith",
+ "description": "Zane is a nice chap!",
+ },
+ ],
)
class TestHaystackSerializerMeta(SimpleTestCase):
-
def test_abstract_not_inherited(self):
class Base(serializers.Serializer, metaclass=HaystackSerializerMeta):
_abstract = True
@@ -661,21 +665,21 @@ class Sub(HaystackSerializer):
class TestMeta(SimpleTestCase):
-
def test_inheritance(self):
"""
Tests that Meta fields are correctly overriden by subclasses.
"""
+
class Serializer(HaystackSerializer):
class Meta:
- fields = ('overriden_fields',)
+ fields = ("overriden_fields",)
- self.assertEqual(Serializer.Meta.fields, ('overriden_fields',))
+ self.assertEqual(Serializer.Meta.fields, ("overriden_fields",))
def test_default_attrs(self):
class Serializer(HaystackSerializer):
class Meta:
- fields = ('overriden_fields',)
+ fields = ("overriden_fields",)
self.assertEqual(Serializer.Meta.exclude, tuple())
@@ -683,8 +687,9 @@ def test_raises_if_fields_and_exclude_defined(self):
def create_subclass():
class Serializer(HaystackSerializer):
class Meta:
- fields = ('include_field',)
- exclude = ('exclude_field',)
+ fields = ("include_field",)
+ exclude = ("exclude_field",)
+
return Serializer
self.assertRaises(ImproperlyConfigured, create_subclass)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index c225687..0814a45 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -4,48 +4,40 @@
class MergeDictTestCase(TestCase):
-
def setUp(self):
self.dict_a = {
- "person": {
- "lastname": "Holmes",
- "combat_proficiency": [
- "Pistol",
- "boxing"
- ]
- },
+ "person": {"lastname": "Holmes", "combat_proficiency": ["Pistol", "boxing"]},
}
self.dict_b = {
"person": {
"gender": "male",
"firstname": "Sherlock",
- "location": {
- "address": "221B Baker Street"
- },
+ "location": {"address": "221B Baker Street"},
"combat_proficiency": [
"sword",
"Martial arts",
- ]
+ ],
}
}
def test_utils_merge_dict(self):
- self.assertEqual(merge_dict(self.dict_a, self.dict_b), {
- "person": {
- "gender": "male",
- "firstname": "Sherlock",
- "lastname": "Holmes",
- "location": {
- "address": "221B Baker Street"
- },
- "combat_proficiency": [
- "Martial arts",
- "Pistol",
- "boxing",
- "sword",
- ]
- }
- })
+ self.assertEqual(
+ merge_dict(self.dict_a, self.dict_b),
+ {
+ "person": {
+ "gender": "male",
+ "firstname": "Sherlock",
+ "lastname": "Holmes",
+ "location": {"address": "221B Baker Street"},
+ "combat_proficiency": [
+ "Martial arts",
+ "Pistol",
+ "boxing",
+ "sword",
+ ],
+ }
+ },
+ )
def test_utils_merge_dict_invalid_input(self):
self.assertEqual(merge_dict(self.dict_a, "I'm not a dict!"), "I'm not a dict!")
diff --git a/tests/test_viewsets.py b/tests/test_viewsets.py
index bcfd302..e46f563 100644
--- a/tests/test_viewsets.py
+++ b/tests/test_viewsets.py
@@ -6,31 +6,27 @@
import json
from unittest import skipIf
-from django.test import TestCase
from django.contrib.auth.models import User
-
+from django.test import TestCase
from haystack.query import SearchQuerySet
-
from rest_framework import status
from rest_framework.pagination import PageNumberPagination
from rest_framework.routers import SimpleRouter
from rest_framework.serializers import Serializer
-from rest_framework.test import force_authenticate, APIRequestFactory
+from rest_framework.test import APIRequestFactory, force_authenticate
+from drf_haystack.mixins import FacetMixin, MoreLikeThisMixin
+from drf_haystack.serializers import HaystackFacetSerializer, HaystackSerializer
from drf_haystack.viewsets import HaystackViewSet
-from drf_haystack.serializers import HaystackSerializer, HaystackFacetSerializer
-from drf_haystack.mixins import MoreLikeThisMixin, FacetMixin
from . import restframework_version
from .mockapp.models import MockPerson, MockPet
from .mockapp.search_indexes import MockPersonIndex, MockPetIndex
-
factory = APIRequestFactory()
class HaystackViewSetTestCase(TestCase):
-
fixtures = ["mockperson", "mockpet"]
def setUp(self):
@@ -39,7 +35,6 @@ def setUp(self):
self.router = SimpleRouter()
class FacetSerializer(HaystackFacetSerializer):
-
class Meta:
fields = ["firstname", "lastname", "created"]
@@ -101,10 +96,7 @@ def test_viewset_get_obj_raise_404(self):
def test_viewset_get_object_invalid_lookup_field(self):
request = factory.get(path="/", data="", content_type="application/json")
- self.assertRaises(
- AttributeError,
- self.view1.as_view(actions={"get": "retrieve"}), request, invalid_lookup=1
- )
+ self.assertRaises(AttributeError, self.view1.as_view(actions={"get": "retrieve"}), request, invalid_lookup=1)
def test_viewset_get_obj_override_lookup_field(self):
setattr(self.view1, "lookup_field", "custom_lookup")
@@ -130,7 +122,6 @@ def test_viewset_facets_action_route(self):
class HaystackViewSetPermissionsTestCase(TestCase):
-
fixtures = ["mockperson"]
def setUp(self):
@@ -155,7 +146,8 @@ def test_viewset_get_queryset_with_no_permsission(self):
def test_viewset_get_queryset_with_AllowAny_permission(self):
from rest_framework.permissions import AllowAny
- setattr(self.view, "permission_classes", (AllowAny, ))
+
+ setattr(self.view, "permission_classes", (AllowAny,))
request = factory.get(path="/", data="", content_type="application/json")
response = self.view.as_view(actions={"get": "list"})(request)
@@ -163,7 +155,8 @@ def test_viewset_get_queryset_with_AllowAny_permission(self):
def test_viewset_get_queryset_with_IsAuthenticated_permission(self):
from rest_framework.permissions import IsAuthenticated
- setattr(self.view, "permission_classes", (IsAuthenticated, ))
+
+ setattr(self.view, "permission_classes", (IsAuthenticated,))
request = factory.get(path="/", data="", content_type="application/json")
response = self.view.as_view(actions={"get": "list"})(request)
@@ -175,6 +168,7 @@ def test_viewset_get_queryset_with_IsAuthenticated_permission(self):
def test_viewset_get_queryset_with_IsAdminUser_permission(self):
from rest_framework.permissions import IsAdminUser
+
setattr(self.view, "permission_classes", (IsAdminUser,))
request = factory.get(path="/", data="", content_type="application/json")
@@ -188,6 +182,7 @@ def test_viewset_get_queryset_with_IsAdminUser_permission(self):
def test_viewset_get_queryset_with_IsAuthenticatedOrReadOnly_permission(self):
from rest_framework.permissions import IsAuthenticatedOrReadOnly
+
setattr(self.view, "permission_classes", (IsAuthenticatedOrReadOnly,))
# Unauthenticated GET requests should pass
@@ -207,6 +202,7 @@ def test_viewset_get_queryset_with_IsAuthenticatedOrReadOnly_permission(self):
@skipIf(not restframework_version < (3, 7), "Skipped due to fix in django-rest-framework > 3.6")
def test_viewset_get_queryset_with_DjangoModelPermissions_permission(self):
from rest_framework.permissions import DjangoModelPermissions
+
setattr(self.view, "permission_classes", (DjangoModelPermissions,))
# The `DjangoModelPermissions` is not supported and should raise an
@@ -214,17 +210,23 @@ def test_viewset_get_queryset_with_DjangoModelPermissions_permission(self):
request = factory.get(path="/", data="", content_type="application/json")
try:
self.view.as_view(actions={"get": "list"})(request)
- self.fail("Did not fail with AssertionError or AttributeError "
- "when calling HaystackView with DjangoModelPermissions")
+ self.fail(
+ "Did not fail with AssertionError or AttributeError "
+ "when calling HaystackView with DjangoModelPermissions"
+ )
except (AttributeError, AssertionError) as e:
if isinstance(e, AttributeError):
self.assertEqual(str(e), "'SearchQuerySet' object has no attribute 'model'")
else:
- self.assertEqual(str(e), "Cannot apply DjangoModelPermissions on a view that does "
- "not have `.model` or `.queryset` property.")
+ self.assertEqual(
+ str(e),
+ "Cannot apply DjangoModelPermissions on a view that does "
+ "not have `.model` or `.queryset` property.",
+ )
def test_viewset_get_queryset_with_DjangoModelPermissionsOrAnonReadOnly_permission(self):
from rest_framework.permissions import DjangoModelPermissionsOrAnonReadOnly
+
setattr(self.view, "permission_classes", (DjangoModelPermissionsOrAnonReadOnly,))
# The `DjangoModelPermissionsOrAnonReadOnly` is not supported and should raise an
@@ -232,18 +234,23 @@ def test_viewset_get_queryset_with_DjangoModelPermissionsOrAnonReadOnly_permissi
request = factory.get(path="/", data="", content_type="application/json")
try:
self.view.as_view(actions={"get": "list"})(request)
- self.fail("Did not fail with AssertionError when calling HaystackView "
- "with DjangoModelPermissionsOrAnonReadOnly")
+ self.fail(
+ "Did not fail with AssertionError when calling HaystackView with DjangoModelPermissionsOrAnonReadOnly"
+ )
except (AttributeError, AssertionError) as e:
if isinstance(e, AttributeError):
self.assertEqual(str(e), "'SearchQuerySet' object has no attribute 'model'")
else:
- self.assertEqual(str(e), "Cannot apply DjangoModelPermissions on a view that does "
- "not have `.model` or `.queryset` property.")
+ self.assertEqual(
+ str(e),
+ "Cannot apply DjangoModelPermissions on a view that does "
+ "not have `.model` or `.queryset` property.",
+ )
@skipIf(not restframework_version < (3, 7), "Skipped due to fix in django-rest-framework > 3.6")
def test_viewset_get_queryset_with_DjangoObjectPermissions_permission(self):
from rest_framework.permissions import DjangoObjectPermissions
+
setattr(self.view, "permission_classes", (DjangoObjectPermissions,))
# The `DjangoObjectPermissions` is a subclass of `DjangoModelPermissions` and
@@ -256,12 +263,14 @@ def test_viewset_get_queryset_with_DjangoObjectPermissions_permission(self):
if isinstance(e, AttributeError):
self.assertEqual(str(e), "'SearchQuerySet' object has no attribute 'model'")
else:
- self.assertEqual(str(e), "Cannot apply DjangoModelPermissions on a view that does "
- "not have `.model` or `.queryset` property.")
+ self.assertEqual(
+ str(e),
+ "Cannot apply DjangoModelPermissions on a view that does "
+ "not have `.model` or `.queryset` property.",
+ )
class PaginatedHaystackViewSetTestCase(TestCase):
-
fixtures = ["mockperson"]
def setUp(self):
@@ -269,17 +278,14 @@ def setUp(self):
MockPersonIndex().reindex()
class Serializer1(HaystackSerializer):
-
class Meta:
fields = ["firstname", "lastname"]
index_classes = [MockPersonIndex]
class NumberPagination(PageNumberPagination):
-
page_size = 5
class ViewSet1(HaystackViewSet):
-
index_models = [MockPerson]
serializer_class = Serializer1
pagination_class = NumberPagination
diff --git a/tests/urls.py b/tests/urls.py
index e1e3424..95edb3f 100644
--- a/tests/urls.py
+++ b/tests/urls.py
@@ -1,5 +1,4 @@
from django.urls import include, path
-
from rest_framework import routers
from tests.mockapp.views import SearchPersonFacetViewSet, SearchPersonMLTViewSet
@@ -8,6 +7,4 @@
router.register("search-person-facet", viewset=SearchPersonFacetViewSet, basename="search-person-facet")
router.register("search-person-mlt", viewset=SearchPersonMLTViewSet, basename="search-person-mlt")
-urlpatterns = [
- path(r"^", include(router.urls))
-]
+urlpatterns = [path(r"^", include(router.urls))]
diff --git a/tests/wsgi.py b/tests/wsgi.py
index 4f38a59..af8c2a6 100644
--- a/tests/wsgi.py
+++ b/tests/wsgi.py
@@ -8,7 +8,9 @@
"""
import os
+
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "tests.settings")
from django.core.wsgi import get_wsgi_application
+
application = get_wsgi_application()