Skip to content

Commit

Permalink
Merge pull request #717 from josh-bristow/patch-2
Browse files Browse the repository at this point in the history
Update tests.py
  • Loading branch information
brylie committed Jun 16, 2023
2 parents 5f24c0d + c164f53 commit 413cd19
Show file tree
Hide file tree
Showing 5 changed files with 216 additions and 20 deletions.
55 changes: 38 additions & 17 deletions cart/cart.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from decimal import Decimal
from collections.abc import Generator

from django.conf import settings
from django.http import HttpRequest

from shipping.calculator import get_book_shipping_cost
from store.models import Product


class Cart:
def __init__(self, request):
def __init__(self, request: HttpRequest) -> None:
"""Initialize the cart."""
self.session = request.session

Expand All @@ -19,9 +21,14 @@ def __init__(self, request):

self.cart = cart

def add(self, product, quantity=1, update_quantity=False):
def add(
self,
product: Product,
quantity: int = 1,
update_quantity: bool = False,
) -> None:
"""Add a product to the cart or update its quantity."""
product_id = str(product.id)
product_id = str(product.id) # type: ignore

if product_id not in self.cart:
self.cart[product_id] = {
Expand All @@ -37,64 +44,78 @@ def add(self, product, quantity=1, update_quantity=False):

self.save()

def save(self):
def save(self) -> None:
# mark the session as "modified"
# to make sure it gets saved

self.session.modified = True

def remove(self, product):
def remove(self, product: Product) -> None:
"""Remove a product from the cart."""
product_id = str(product.id)
product_id = str(product.id) # type: ignore

if product_id in self.cart:
del self.cart[product_id]

self.save()

def get_cart_products(self):
def get_cart_products(self) -> list[Product]:
product_ids = self.cart.keys()

# get the product objects and add them to the cart
return Product.objects.filter(id__in=product_ids)

def get_total_price(self):
return sum([self.get_subtotal_price(), self.get_shipping_cost()])
def get_total_price(self) -> Decimal:
int_sum = sum(
[
self.get_subtotal_price(),
self.get_shipping_cost(),
]
)
return Decimal(int_sum).quantize(Decimal("0.01"))

def get_subtotal_price(self):
return sum(
def get_subtotal_price(self) -> Decimal:
totals = [
Decimal(item["price"]) * item["quantity"] for item in self.cart.values()
)
]
product_sum = sum(totals)
return Decimal(product_sum).quantize(Decimal("0.01"))

def get_shipping_cost(self):
def get_shipping_cost(self) -> Decimal:
book_quantity = sum(item["quantity"] for item in self.cart.values())

return get_book_shipping_cost(book_quantity)

def clear(self):
def clear(self) -> None:
# remove cart from session
del self.session[settings.CART_SESSION_ID]

self.save()

def __iter__(self):
def __iter__(self) -> Generator:
"""Get cart products from the database."""
# get the product objects and add them to the cart
products = self.get_cart_products()

cart = self.cart.copy()

for product in products:
cart[str(product.id)]["product"] = product
if str(product.id) not in cart: # type: ignore
continue

cart[str(product.id)]["product"] = product # type: ignore

for item in cart.values():
item["price"] = Decimal(item["price"])
item["total_price"] = item["price"] * item["quantity"]

yield item

def __len__(self):
def __len__(self) -> int:
"""Count all items in the cart."""

# TODO: determine whether this should count the number of products
# or the total quantity of products
item_quantities = [item["quantity"] for item in self.cart.values()]

return sum(item_quantities)
148 changes: 147 additions & 1 deletion cart/tests.py
Original file line number Diff line number Diff line change
@@ -1 +1,147 @@
# Create your tests here.
from decimal import Decimal
from unittest.mock import Mock, patch
from django.test import RequestFactory, TestCase
from django.contrib.sessions.middleware import SessionMiddleware
from wagtail.models import Page

from .cart import Cart
from home.models import HomePage
from store.models import Product, ProductIndexPage, StoreIndexPage


class CartTestCase(TestCase):
def setUp(self) -> None:
self.request = RequestFactory().get("/")

# Add session middleware to the request
middleware = SessionMiddleware(Mock())
middleware.process_request(self.request)
self.request.session.save()

# get Site Root
root_page = Page.objects.get(id=1)

# Create HomePage
home_page = HomePage(
title="Welcome",
)

root_page.add_child(instance=home_page)
# root_page.save()

# Create StoreIndexPage
store_index_page = StoreIndexPage(
title="Bookstore",
show_in_menus=True,
)
home_page.add_child(instance=store_index_page)

# Create ProductIndexPage
product_index_page = ProductIndexPage(
title="Products",
)
store_index_page.add_child(instance=product_index_page)

self.product1 = Product(
title="Product 1",
price=Decimal("9.99"),
)
self.product2 = Product(
title="Product 2",
price=Decimal("19.99"),
)
product_index_page.add_child(instance=self.product1)
product_index_page.add_child(instance=self.product2)

def test_cart_initialization(self) -> None:
cart = Cart(self.request)

self.assertEqual(len(cart), 0)
self.assertEqual(cart.get_subtotal_price(), Decimal("0"))

def test_add_product(self) -> None:
cart = Cart(self.request)

cart.add(self.product1)
self.assertEqual(len(cart), 1)
self.assertEqual(cart.get_subtotal_price(), Decimal("9.99"))

cart.add(self.product1, quantity=2)
self.assertEqual(len(cart), 3)
self.assertEqual(cart.get_subtotal_price(), Decimal("29.97"))

def test_save_cart(self) -> None:
cart = Cart(self.request)
cart.save()

self.assertTrue(cart.session.modified)

def test_remove_product(self) -> None:
cart = Cart(self.request)

cart.add(self.product1)

self.assertEqual(len(cart), 1)
cart.remove(self.product1)
self.assertEqual(len(cart), 0)

def test_get_cart_products(self) -> None:
cart = Cart(self.request)

cart.add(self.product1)
cart.add(self.product2)

cart_products = cart.get_cart_products()

self.assertEqual(len(cart_products), 2)
self.assertIn(self.product1, cart_products)
self.assertIn(self.product2, cart_products)

def test_get_subtotal_price(self) -> None:
cart = Cart(self.request)

cart.add(self.product1)
cart.add(self.product2, quantity=2)

subtotal_price = cart.get_subtotal_price()

self.assertEqual(subtotal_price, Decimal("49.97"))

def test_get_shipping_cost(self) -> None:
cart = Cart(self.request)

# Mocking the get_shipping_cost method
expected_shipping_cost = Decimal("10.00")
with patch.object(
cart,
"get_shipping_cost",
return_value=expected_shipping_cost,
):
shipping_cost = cart.get_shipping_cost()

self.assertEqual(shipping_cost, expected_shipping_cost)

def test_cart_iteration(self) -> None:
cart = Cart(self.request)

cart.add(self.product1)
cart.add(self.product2, quantity=2)

cart_items = list(cart)

self.assertEqual(len(cart_items), 2)
self.assertEqual(cart_items[0]["product"], self.product1)
self.assertEqual(cart_items[0]["quantity"], 1)
self.assertEqual(cart_items[0]["price"], Decimal("9.99"))
self.assertEqual(cart_items[0]["total_price"], Decimal("9.99"))

self.assertEqual(cart_items[1]["product"], self.product2)
self.assertEqual(cart_items[1]["quantity"], 2)
self.assertEqual(cart_items[1]["price"], Decimal("19.99"))
self.assertEqual(cart_items[1]["total_price"], Decimal("39.98"))

def tearDown(self) -> None:
# delete all pages
Page.objects.all().delete()

return super().tearDown()
2 changes: 1 addition & 1 deletion shipping/calculator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from decimal import Decimal


def get_book_shipping_cost(book_quantity=1):
def get_book_shipping_cost(book_quantity: int = 1) -> Decimal:
"""Calculate shipping costs for books in a cart/order.
The shipping rules are flat rate for each book, with discounts
Expand Down
25 changes: 25 additions & 0 deletions store/migrations/0002_alter_product_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Generated by Django 4.2.1 on 2023-06-16 07:34

from django.db import migrations, models
import django.db.models.deletion


class Migration(migrations.Migration):
dependencies = [
("wagtailimages", "0025_alter_image_file_alter_rendition_file"),
("store", "0001_initial"),
]

operations = [
migrations.AlterField(
model_name="product",
name="image",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="+",
to="wagtailimages.image",
),
),
]
6 changes: 5 additions & 1 deletion store/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ class ProductIndexPage(Page):

class Product(Page):
image = models.ForeignKey(
"wagtailimages.Image", on_delete=models.SET_NULL, null=True, related_name="+"
"wagtailimages.Image",
on_delete=models.SET_NULL,
null=True,
blank=True, # note, making this required will break the tests
related_name="+",
)
description = RichTextField(blank=True)
price = models.DecimalField(max_digits=10, decimal_places=2)
Expand Down

0 comments on commit 413cd19

Please sign in to comment.