Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update tests.py #717

Merged
merged 6 commits into from
Jun 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 38 additions & 17 deletions cart/cart.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from decimal import Decimal
from collections.abc import Generator

from django.conf import settings
from django.http import HttpRequest

from shipping.calculator import get_book_shipping_cost
from store.models import Product


class Cart:
def __init__(self, request):
def __init__(self, request: HttpRequest) -> None:
"""Initialize the cart."""
self.session = request.session

Expand All @@ -19,9 +21,14 @@ def __init__(self, request):

self.cart = cart

def add(self, product, quantity=1, update_quantity=False):
def add(
self,
product: Product,
quantity: int = 1,
update_quantity: bool = False,
) -> None:
"""Add a product to the cart or update its quantity."""
product_id = str(product.id)
product_id = str(product.id) # type: ignore

if product_id not in self.cart:
self.cart[product_id] = {
Expand All @@ -37,64 +44,78 @@ def add(self, product, quantity=1, update_quantity=False):

self.save()

def save(self):
def save(self) -> None:
# mark the session as "modified"
# to make sure it gets saved

self.session.modified = True

def remove(self, product):
def remove(self, product: Product) -> None:
"""Remove a product from the cart."""
product_id = str(product.id)
product_id = str(product.id) # type: ignore

if product_id in self.cart:
del self.cart[product_id]

self.save()

def get_cart_products(self):
def get_cart_products(self) -> list[Product]:
product_ids = self.cart.keys()

# get the product objects and add them to the cart
return Product.objects.filter(id__in=product_ids)

def get_total_price(self):
return sum([self.get_subtotal_price(), self.get_shipping_cost()])
def get_total_price(self) -> Decimal:
int_sum = sum(
[
self.get_subtotal_price(),
self.get_shipping_cost(),
]
)
return Decimal(int_sum).quantize(Decimal("0.01"))

def get_subtotal_price(self):
return sum(
def get_subtotal_price(self) -> Decimal:
totals = [
Decimal(item["price"]) * item["quantity"] for item in self.cart.values()
)
]
product_sum = sum(totals)
return Decimal(product_sum).quantize(Decimal("0.01"))

def get_shipping_cost(self):
def get_shipping_cost(self) -> Decimal:
book_quantity = sum(item["quantity"] for item in self.cart.values())

return get_book_shipping_cost(book_quantity)

def clear(self):
def clear(self) -> None:
# remove cart from session
del self.session[settings.CART_SESSION_ID]

self.save()

def __iter__(self):
def __iter__(self) -> Generator:
"""Get cart products from the database."""
# get the product objects and add them to the cart
products = self.get_cart_products()

cart = self.cart.copy()

for product in products:
cart[str(product.id)]["product"] = product
if str(product.id) not in cart: # type: ignore
continue

cart[str(product.id)]["product"] = product # type: ignore

for item in cart.values():
item["price"] = Decimal(item["price"])
item["total_price"] = item["price"] * item["quantity"]

yield item

def __len__(self):
def __len__(self) -> int:
"""Count all items in the cart."""

# TODO: determine whether this should count the number of products
# or the total quantity of products
item_quantities = [item["quantity"] for item in self.cart.values()]

return sum(item_quantities)
148 changes: 147 additions & 1 deletion cart/tests.py
Original file line number Diff line number Diff line change
@@ -1 +1,147 @@
# Create your tests here.
from decimal import Decimal
from unittest.mock import Mock, patch
from django.test import RequestFactory, TestCase
from django.contrib.sessions.middleware import SessionMiddleware
from wagtail.models import Page

from .cart import Cart
from home.models import HomePage
from store.models import Product, ProductIndexPage, StoreIndexPage


class CartTestCase(TestCase):
def setUp(self) -> None:
self.request = RequestFactory().get("/")

# Add session middleware to the request
middleware = SessionMiddleware(Mock())
middleware.process_request(self.request)
self.request.session.save()

# get Site Root
root_page = Page.objects.get(id=1)

# Create HomePage
home_page = HomePage(
title="Welcome",
)

root_page.add_child(instance=home_page)
# root_page.save()

# Create StoreIndexPage
store_index_page = StoreIndexPage(
title="Bookstore",
show_in_menus=True,
)
home_page.add_child(instance=store_index_page)

# Create ProductIndexPage
product_index_page = ProductIndexPage(
title="Products",
)
store_index_page.add_child(instance=product_index_page)

self.product1 = Product(
title="Product 1",
price=Decimal("9.99"),
)
self.product2 = Product(
title="Product 2",
price=Decimal("19.99"),
)
product_index_page.add_child(instance=self.product1)
product_index_page.add_child(instance=self.product2)

def test_cart_initialization(self) -> None:
cart = Cart(self.request)

self.assertEqual(len(cart), 0)
self.assertEqual(cart.get_subtotal_price(), Decimal("0"))

def test_add_product(self) -> None:
cart = Cart(self.request)

cart.add(self.product1)
self.assertEqual(len(cart), 1)
self.assertEqual(cart.get_subtotal_price(), Decimal("9.99"))

cart.add(self.product1, quantity=2)
self.assertEqual(len(cart), 3)
self.assertEqual(cart.get_subtotal_price(), Decimal("29.97"))

def test_save_cart(self) -> None:
cart = Cart(self.request)
cart.save()

self.assertTrue(cart.session.modified)

def test_remove_product(self) -> None:
cart = Cart(self.request)

cart.add(self.product1)

self.assertEqual(len(cart), 1)
cart.remove(self.product1)
self.assertEqual(len(cart), 0)

def test_get_cart_products(self) -> None:
cart = Cart(self.request)

cart.add(self.product1)
cart.add(self.product2)

cart_products = cart.get_cart_products()

self.assertEqual(len(cart_products), 2)
self.assertIn(self.product1, cart_products)
self.assertIn(self.product2, cart_products)

def test_get_subtotal_price(self) -> None:
cart = Cart(self.request)

cart.add(self.product1)
cart.add(self.product2, quantity=2)

subtotal_price = cart.get_subtotal_price()

self.assertEqual(subtotal_price, Decimal("49.97"))

def test_get_shipping_cost(self) -> None:
cart = Cart(self.request)

# Mocking the get_shipping_cost method
expected_shipping_cost = Decimal("10.00")
with patch.object(
cart,
"get_shipping_cost",
return_value=expected_shipping_cost,
):
shipping_cost = cart.get_shipping_cost()

self.assertEqual(shipping_cost, expected_shipping_cost)

def test_cart_iteration(self) -> None:
cart = Cart(self.request)

cart.add(self.product1)
cart.add(self.product2, quantity=2)

cart_items = list(cart)

self.assertEqual(len(cart_items), 2)
self.assertEqual(cart_items[0]["product"], self.product1)
self.assertEqual(cart_items[0]["quantity"], 1)
self.assertEqual(cart_items[0]["price"], Decimal("9.99"))
self.assertEqual(cart_items[0]["total_price"], Decimal("9.99"))

self.assertEqual(cart_items[1]["product"], self.product2)
self.assertEqual(cart_items[1]["quantity"], 2)
self.assertEqual(cart_items[1]["price"], Decimal("19.99"))
self.assertEqual(cart_items[1]["total_price"], Decimal("39.98"))

def tearDown(self) -> None:
# delete all pages
Page.objects.all().delete()

return super().tearDown()
2 changes: 1 addition & 1 deletion shipping/calculator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from decimal import Decimal


def get_book_shipping_cost(book_quantity=1):
def get_book_shipping_cost(book_quantity: int = 1) -> Decimal:
"""Calculate shipping costs for books in a cart/order.

The shipping rules are flat rate for each book, with discounts
Expand Down
25 changes: 25 additions & 0 deletions store/migrations/0002_alter_product_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Generated by Django 4.2.1 on 2023-06-16 07:34

from django.db import migrations, models
import django.db.models.deletion


class Migration(migrations.Migration):
dependencies = [
("wagtailimages", "0025_alter_image_file_alter_rendition_file"),
("store", "0001_initial"),
]

operations = [
migrations.AlterField(
model_name="product",
name="image",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="+",
to="wagtailimages.image",
),
),
]
6 changes: 5 additions & 1 deletion store/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ class ProductIndexPage(Page):

class Product(Page):
image = models.ForeignKey(
"wagtailimages.Image", on_delete=models.SET_NULL, null=True, related_name="+"
"wagtailimages.Image",
on_delete=models.SET_NULL,
null=True,
blank=True, # note, making this required will break the tests
related_name="+",
)
description = RichTextField(blank=True)
price = models.DecimalField(max_digits=10, decimal_places=2)
Expand Down