Skip to content

✅ Refactor tests, consolidate into a single test file for multiple variants #1409

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
31 changes: 19 additions & 12 deletions tests/test_advanced/test_decimal/test_tutorial001.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import importlib
import types
from decimal import Decimal
from unittest.mock import patch

import pytest
from sqlmodel import create_engine

from ...conftest import get_testing_print_function
from ...conftest import PrintMock, needs_py310 # Import PrintMock for type hint

expected_calls = [
[
Expand All @@ -30,15 +32,20 @@
]


def test_tutorial():
from docs_src.advanced.decimal import tutorial001 as mod

mod.sqlite_url = "sqlite://"
mod.engine = create_engine(mod.sqlite_url)
calls = []
@pytest.fixture(
name="module",
params=[
"tutorial001",
pytest.param("tutorial001_py310", marks=needs_py310),
],
)
def get_module(request: pytest.FixtureRequest):
module_name = request.param
return importlib.import_module(f"docs_src.advanced.decimal.{module_name}")

new_print = get_testing_print_function(calls)

with patch("builtins.print", new=new_print):
mod.main()
assert calls == expected_calls
def test_tutorial(print_mock: PrintMock, module: types.ModuleType):
module.sqlite_url = "sqlite://"
module.engine = create_engine(module.sqlite_url)
module.main()
assert print_mock.calls == expected_calls # Use .calls instead of .mock_calls
45 changes: 0 additions & 45 deletions tests/test_advanced/test_decimal/test_tutorial001_py310.py

This file was deleted.

38 changes: 25 additions & 13 deletions tests/test_advanced/test_uuid/test_tutorial001.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,42 @@
from unittest.mock import patch
import importlib
import types

import pytest
from dirty_equals import IsUUID
from sqlmodel import create_engine

from ...conftest import get_testing_print_function
from ...conftest import PrintMock, needs_py310


def test_tutorial() -> None:
from docs_src.advanced.uuid import tutorial001 as mod
@pytest.fixture(
name="module",
params=[
"tutorial001",
pytest.param("tutorial001_py310", marks=needs_py310),
],
)
def get_module(request: pytest.FixtureRequest):
module_name = request.param
return importlib.import_module(f"docs_src.advanced.uuid.{module_name}")

mod.sqlite_url = "sqlite://"
mod.engine = create_engine(mod.sqlite_url)
calls = []

new_print = get_testing_print_function(calls)
def test_tutorial(print_mock: PrintMock, module: types.ModuleType) -> None:
module.sqlite_url = "sqlite://"
module.engine = create_engine(module.sqlite_url)

with patch("builtins.print", new=new_print):
mod.main()
first_uuid = calls[1][0]["id"]
module.main()

# Extract UUIDs from actual calls recorded by print_mock
first_uuid = print_mock.calls[1][0]["id"]
assert first_uuid == IsUUID(4)

second_uuid = calls[7][0]["id"]
second_uuid = print_mock.calls[7][0]["id"]
assert second_uuid == IsUUID(4)

assert first_uuid != second_uuid

assert calls == [
# Construct expected_calls using the extracted UUIDs
expected_calls = [
["The hero before saving in the DB"],
[
{
Expand Down Expand Up @@ -69,3 +80,4 @@ def test_tutorial() -> None:
["Selected hero ID:"],
[second_uuid],
]
assert print_mock.calls == expected_calls
72 changes: 0 additions & 72 deletions tests/test_advanced/test_uuid/test_tutorial001_py310.py

This file was deleted.

38 changes: 25 additions & 13 deletions tests/test_advanced/test_uuid/test_tutorial002.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,42 @@
from unittest.mock import patch
import importlib
import types

import pytest
from dirty_equals import IsUUID
from sqlmodel import create_engine

from ...conftest import get_testing_print_function
from ...conftest import PrintMock, needs_py310


def test_tutorial() -> None:
from docs_src.advanced.uuid import tutorial002 as mod
@pytest.fixture(
name="module",
params=[
"tutorial002",
pytest.param("tutorial002_py310", marks=needs_py310),
],
)
def get_module(request: pytest.FixtureRequest):
module_name = request.param
return importlib.import_module(f"docs_src.advanced.uuid.{module_name}")

mod.sqlite_url = "sqlite://"
mod.engine = create_engine(mod.sqlite_url)
calls = []

new_print = get_testing_print_function(calls)
def test_tutorial(print_mock: PrintMock, module: types.ModuleType) -> None:
module.sqlite_url = "sqlite://"
module.engine = create_engine(module.sqlite_url)

with patch("builtins.print", new=new_print):
mod.main()
first_uuid = calls[1][0]["id"]
module.main()

# Extract UUIDs from actual calls recorded by print_mock
first_uuid = print_mock.calls[1][0]["id"]
assert first_uuid == IsUUID(4)

second_uuid = calls[7][0]["id"]
second_uuid = print_mock.calls[7][0]["id"]
assert second_uuid == IsUUID(4)

assert first_uuid != second_uuid

assert calls == [
# Construct expected_calls using the extracted UUIDs
expected_calls = [
["The hero before saving in the DB"],
[
{
Expand Down Expand Up @@ -69,3 +80,4 @@ def test_tutorial() -> None:
["Selected hero ID:"],
[second_uuid],
]
assert print_mock.calls == expected_calls
72 changes: 0 additions & 72 deletions tests/test_advanced/test_uuid/test_tutorial002_py310.py

This file was deleted.

Loading
Loading