diff --git a/django_tasks/backends/base.py b/django_tasks/backends/base.py index c33658ab..191f996f 100644 --- a/django_tasks/backends/base.py +++ b/django_tasks/backends/base.py @@ -4,6 +4,7 @@ from typing import Any, TypeVar from asgiref.sync import sync_to_async +from django.conf import settings from django.core.checks import messages from django.db import connections from django.utils import timezone @@ -83,7 +84,11 @@ def validate_task(self, task: Task) -> None: if not self.supports_defer and task.run_after is not None: raise InvalidTaskError("Backend does not support run_after") - if task.run_after is not None and not timezone.is_aware(task.run_after): + if ( + settings.USE_TZ + and task.run_after is not None + and not timezone.is_aware(task.run_after) + ): raise InvalidTaskError("run_after must be an aware datetime") if self.queues and task.queue_name not in self.queues: diff --git a/django_tasks/backends/database/migrations/0016_remove_dbtaskresult_django_task_new_ordering_idx_and_more.py b/django_tasks/backends/database/migrations/0016_remove_dbtaskresult_django_task_new_ordering_idx_and_more.py index d425371c..bd601c3f 100644 --- a/django_tasks/backends/database/migrations/0016_remove_dbtaskresult_django_task_new_ordering_idx_and_more.py +++ b/django_tasks/backends/database/migrations/0016_remove_dbtaskresult_django_task_new_ordering_idx_and_more.py @@ -54,9 +54,7 @@ class Migration(migrations.Migration): model_name="dbtaskresult", name="run_after", field=models.DateTimeField( - default=datetime.datetime( - 9999, 1, 1, 0, 0, tzinfo=datetime.timezone.utc - ), + default=datetime.datetime(9999, 1, 1, 0, 0), verbose_name="run after", ), preserve_default=False, diff --git a/django_tasks/backends/database/models.py b/django_tasks/backends/database/models.py index 16063d9c..4824a145 100644 --- a/django_tasks/backends/database/models.py +++ b/django_tasks/backends/database/models.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar import django +from django.conf import settings from django.core.exceptions import SuspiciousOperation from django.db import models from django.db.models import F, Q @@ -48,7 +49,12 @@ def __class_getitem__(cls, _): return cls -DATE_MAX = datetime.datetime(9999, 1, 1, tzinfo=datetime.timezone.utc) +DATE_MAX = datetime.datetime(9999, 1, 1) +DATE_MAXS = (datetime.datetime(9999, 1, 1, tzinfo=datetime.timezone.utc), DATE_MAX) + + +def get_date_max() -> datetime.datetime: + return DATE_MAXS[0] if settings.USE_TZ else DATE_MAXS[1] class DBTaskResultQuerySet(models.QuerySet): @@ -58,7 +64,9 @@ def ready(self) -> "DBTaskResultQuerySet": """ return self.filter( status=ResultStatus.READY, - ).filter(models.Q(run_after=DATE_MAX) | models.Q(run_after__lte=timezone.now())) + ).filter( + models.Q(run_after=get_date_max()) | models.Q(run_after__lte=timezone.now()) + ) def succeeded(self) -> "DBTaskResultQuerySet": return self.filter(status=ResultStatus.SUCCEEDED) @@ -157,7 +165,7 @@ def task(self) -> Task[P, T]: return task.using( priority=self.priority, queue_name=self.queue_name, - run_after=None if self.run_after == DATE_MAX else self.run_after, + run_after=None if self.run_after in DATE_MAXS else self.run_after, backend=self.backend_name, ) diff --git a/django_tasks/backends/database/signal_handlers.py b/django_tasks/backends/database/signal_handlers.py index 09c0d3b1..dc9eb71d 100644 --- a/django_tasks/backends/database/signal_handlers.py +++ b/django_tasks/backends/database/signal_handlers.py @@ -3,10 +3,10 @@ from django.db.models.signals import pre_save from django.dispatch import receiver -from .models import DATE_MAX, DBTaskResult +from .models import DBTaskResult, get_date_max @receiver(pre_save, sender=DBTaskResult) def set_run_after(sender: Any, instance: DBTaskResult, **kwargs: Any) -> None: if instance.run_after is None: - instance.run_after = DATE_MAX + instance.run_after = get_date_max() diff --git a/tests/tests/test_database_backend.py b/tests/tests/test_database_backend.py index b1ce8990..0c518639 100644 --- a/tests/tests/test_database_backend.py +++ b/tests/tests/test_database_backend.py @@ -7,9 +7,10 @@ import sys import time import uuid +import warnings from collections import Counter from contextlib import redirect_stderr -from datetime import timedelta +from datetime import datetime, timedelta from functools import partial from io import StringIO from typing import Any, List, Optional, Sequence, Union, cast @@ -436,6 +437,36 @@ def test_index_scan_for_ready(self) -> None: else: self.fail("Unknown database engine") + def test_run_after_tz(self) -> None: + for use_tz in [True, False]: + with self.subTest(use_tz=use_tz): + with override_settings(USE_TZ=use_tz): + result = test_tasks.noop_task.enqueue() + self.assertIsNone( + DBTaskResult.objects.get(id=result.id).task.run_after + ) + + def test_run_after_null_0016_migration(self) -> None: + for use_tz in [True, False]: + with self.subTest(use_tz=use_tz): + with override_settings(USE_TZ=use_tz): + result = test_tasks.noop_task.enqueue() + + db_result = DBTaskResult.objects.get(id=result.id) + + # Literal taken from migration + db_result.run_after = datetime(9999, 1, 1, 0, 0) + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", module="django.db", category=RuntimeWarning + ) + db_result.save() + + self.assertIsNone( + DBTaskResult.objects.get(id=result.id).task.run_after + ) + @override_settings( TASKS={