Skip to content

Commit

Permalink
Merge pull request #763 from WesternFriend/search-tests
Browse files Browse the repository at this point in the history
Add search tests
  • Loading branch information
brylie committed Jul 7, 2023
2 parents 7951827 + a6cf222 commit 49ffc57
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 1 deletion.
22 changes: 22 additions & 0 deletions common/factories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Any
import factory
from wagtail.models import Page


class PageFactory(factory.django.DjangoModelFactory):
class Meta:
model = Page

title = factory.Sequence(lambda n: f"Test Page {n}")

@factory.post_generation
def add_to_root(
obj: "PageFactory",
create: bool,
extracted: Any | None,
**kwargs: Any,
) -> None:
if create:
root_page = Page.objects.first()
root_page.add_child(instance=obj)
obj.save()
66 changes: 66 additions & 0 deletions search/tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from django.test import TestCase, Client
from django.urls import reverse
from wagtail.models import Page
from wagtail.search.models import Query
from unittest.mock import Mock, patch


class SearchViewTestCase(TestCase):
def setUp(self) -> None:
self.client = Client()

# get the root page
root_page = Page.objects.first()

# add test pages as children of the root
self.page1 = Page(title="Test Page 1")
root_page.add_child(instance=self.page1)

self.page2 = Page(title="Test Page 2")
root_page.add_child(instance=self.page2)

self.page3 = Page(title="Test Page 3")
root_page.add_child(instance=self.page3)

self.page1.save()
self.page2.save()
self.page3.save()

def test_search_no_query(self) -> None:
response = self.client.get(reverse("search"))
self.assertEqual(response.status_code, 200)
self.assertEqual(len(response.context["search_results"]), 0)

def test_search_query(self) -> None:
response = self.client.get(
reverse("search"),
{"query": "Test"},
)
self.assertEqual(response.status_code, 200)
self.assertEqual(len(response.context["search_results"]), 3)

def test_search_pagination_non_existant_page_default_first(self) -> None:
response = self.client.get(
reverse("search"),
{
"query": "Test",
"page": 2,
},
)
self.assertEqual(response.status_code, 200)
self.assertEqual(len(response.context["search_results"]), 3)

@patch.object(Query, "add_hit")
def test_search_query_hit(self, mock_add_hit: Mock) -> None:
self.client.get("/search/?query=Test")
mock_add_hit.assert_called_once()

def test_search_pagination_invalid_page(self) -> None:
response = self.client.get("/search/?query=Test&page=abc")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.context["search_results"].number, 1)

def test_search_pagination_out_of_range(self) -> None:
response = self.client.get("/search/?query=Test&page=100")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.context["search_results"].number, 1)
3 changes: 2 additions & 1 deletion search/views.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from django.core.paginator import EmptyPage, PageNotAnInteger, Paginator
from django.http import HttpRequest, HttpResponse
from django.shortcuts import render
from wagtail.models import Page
from wagtail.search.models import Query


def search(request):
def search(request: HttpRequest) -> HttpResponse:
search_query = request.GET.get("query", None)
page = request.GET.get("page", 1)

Expand Down

0 comments on commit 49ffc57

Please sign in to comment.