diff --git a/Makefile b/Makefile index 69daa40..be0a317 100644 --- a/Makefile +++ b/Makefile @@ -7,6 +7,13 @@ test-unit: test-e2e: pytest -m e2e --verbose +coverage-clean: + rm -f .coverage .coverage.* coverage.xml + +coverage-all: coverage-clean + pytest -m "not e2e" --durations=0 --cov=durabletask --cov-branch --cov-report=term-missing --cov-report=xml + pytest -m e2e --durations=0 --cov=durabletask --cov-branch --cov-report=term-missing --cov-report=xml --cov-append + install: python3 -m pip install . @@ -18,4 +25,4 @@ gen-proto: python3 -m grpc_tools.protoc --proto_path=. --python_out=. --pyi_out=. --grpc_python_out=. ./durabletask/internal/orchestrator_service.proto rm durabletask/internal/*.proto -.PHONY: init test-unit test-e2e gen-proto install +.PHONY: init test-unit test-e2e coverage-clean coverage-all gen-proto install diff --git a/README.md b/README.md index 40a4e6e..d4604e0 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,38 @@ This repo contains a Python client SDK for use with the [Durable Task Framework > Note that this project is **not** currently affiliated with the [Durable Functions](https://docs.microsoft.com/azure/azure-functions/durable/durable-functions-overview) project for Azure Functions. If you are looking for a Python SDK for Durable Functions, please see [this repo](https://github.com/Azure/azure-functions-durable-python). +## Minimal worker setup + +To execute orchestrations and activities you must run a worker that connects to the Dapr Workflow sidecar and dispatches work on background threads: + +```python +from durabletask.worker import TaskHubGrpcWorker + +worker = TaskHubGrpcWorker(host_address="localhost:4001") + +worker.add_orchestrator(say_hello) +worker.add_activity(hello_activity) + +try: + worker.start() + # Worker runs in the background and processes work until stopped +finally: + worker.stop() +``` + +Always stop the worker when you're finished. The worker keeps polling threads alive; if you skip `stop()` they continue running and can prevent your process from shutting down cleanly after failures. You can rely on the context manager form to guarantee cleanup: + +```python +from durabletask.worker import TaskHubGrpcWorker + +with TaskHubGrpcWorker(host_address="localhost:4001") as worker: + worker.add_orchestrator(say_hello) + worker.add_activity(hello_activity) + worker.start() + # worker.stop() is called automatically on exit +``` + + ## Supported patterns The following orchestration patterns are currently supported. diff --git a/durabletask/deterministic.py b/durabletask/deterministic.py new file mode 100644 index 0000000..2943783 --- /dev/null +++ b/durabletask/deterministic.py @@ -0,0 +1,224 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +""" +Deterministic utilities for Durable Task workflows (async and generator). + +This module provides deterministic alternatives to non-deterministic Python +functions, ensuring workflow replay consistency across different executions. 
+It is shared by both the asyncio authoring model and the generator-based model. +""" + +import hashlib +import random +import string as _string +import uuid +from collections.abc import Sequence +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Optional, TypeVar + + +@dataclass +class DeterminismSeed: + """Seed data for deterministic operations.""" + + instance_id: str + orchestration_unix_ts: int + + def to_int(self) -> int: + """Convert seed to integer for PRNG initialization.""" + combined = f"{self.instance_id}:{self.orchestration_unix_ts}" + hash_bytes = hashlib.sha256(combined.encode("utf-8")).digest() + return int.from_bytes(hash_bytes[:8], byteorder="big") + + +def derive_seed(instance_id: str, orchestration_time: datetime) -> int: + """ + Derive a deterministic seed from instance ID and orchestration time. + """ + ts = int(orchestration_time.timestamp()) + return DeterminismSeed(instance_id=instance_id, orchestration_unix_ts=ts).to_int() + + +def deterministic_random(instance_id: str, orchestration_time: datetime) -> random.Random: + """ + Create a deterministic random number generator. + """ + seed = derive_seed(instance_id, orchestration_time) + return random.Random(seed) + + +def deterministic_uuid4(rnd: random.Random) -> uuid.UUID: + """ + Generate a deterministic UUID4 using the provided random generator. + + Note: This is deprecated in favor of deterministic_uuid_v5 which matches + the .NET implementation for cross-language compatibility. + """ + bytes_ = bytes(rnd.randrange(0, 256) for _ in range(16)) + bytes_list = list(bytes_) + bytes_list[6] = (bytes_list[6] & 0x0F) | 0x40 # Version 4 + bytes_list[8] = (bytes_list[8] & 0x3F) | 0x80 # Variant bits + return uuid.UUID(bytes=bytes(bytes_list)) + + +def deterministic_uuid_v5(instance_id: str, current_datetime: datetime, counter: int) -> uuid.UUID: + """ + Generate a deterministic UUID v5 matching the .NET implementation. + + This implementation matches the durabletask-dotnet NewGuid() method: + https://github.com/microsoft/durabletask-dotnet/blob/main/src/Worker/Core/Shims/TaskOrchestrationContextWrapper.cs + + Args: + instance_id: The orchestration instance ID. + current_datetime: The current orchestration datetime (frozen during replay). + counter: The per-call counter (starts at 0 on each replay). + + Returns: + A deterministic UUID v5 that will be the same across replays. + """ + # DNS namespace UUID - same as .NET DnsNamespaceValue + namespace = uuid.UUID("9e952958-5e33-4daf-827f-2fa12937b875") + + # Build name matching .NET format: instanceId_datetime_counter + # Using isoformat() which produces ISO 8601 format similar to .NET's ToString("o") + name = f"{instance_id}_{current_datetime.isoformat()}_{counter}" + + # Generate UUID v5 (SHA-1 based, matching .NET) + return uuid.uuid5(namespace, name) + + +class DeterministicContextMixin: + """ + Mixin providing deterministic helpers for workflow contexts. + + Assumes the inheriting class exposes `instance_id` and `current_utc_datetime` attributes. + + This implementation matches the .NET durabletask SDK approach with an explicit + counter for UUID generation that resets on each replay. 
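+
+    Example (generator model; `greet` is an illustrative activity name and the
+    runtime orchestration context is assumed to mix in this class):
+
+        def my_orchestrator(ctx, _):
+            request_id = ctx.new_guid()      # stable across replays
+            token = ctx.random_string(8)     # seeded from instance id + orchestration time
+            greeting = yield ctx.call_activity(greet, input=token)
+            return {"request_id": str(request_id), "greeting": greeting}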
+ """ + + def __init__(self, *args, **kwargs): + """Initialize the mixin with UUID and timestamp counters.""" + super().__init__(*args, **kwargs) + # Counter for deterministic UUID generation (matches .NET newGuidCounter) + # This counter resets to 0 on each replay, ensuring determinism + self._uuid_counter: int = 0 + # Counter for deterministic timestamp sequencing (resets on replay) + self._timestamp_counter: int = 0 + + def now(self) -> datetime: + """Alias for deterministic current_utc_datetime.""" + return self.current_utc_datetime # type: ignore[attr-defined] + + def random(self) -> random.Random: + """Return a PRNG seeded deterministically from instance id and orchestration time.""" + rnd = deterministic_random( + self.instance_id, # type: ignore[attr-defined] + self.current_utc_datetime, # type: ignore[attr-defined] + ) + # Mark as deterministic for asyncio sandbox detector whitelisting of bound methods (randint, random) + try: + rnd._dt_deterministic = True + except Exception: + pass + return rnd + + def uuid4(self) -> uuid.UUID: + """ + Return a deterministically generated UUID v5 with explicit counter. + https://www.sohamkamani.com/uuid-versions-explained/#v5-non-random-uuids + + This matches the .NET implementation's NewGuid() method which uses: + - Instance ID + - Current UTC datetime (frozen during replay) + - Per-call counter (resets to 0 on each replay) + + The counter ensures multiple calls produce different UUIDs while maintaining + determinism across replays. + """ + # Lazily initialize counter if not set by __init__ (for compatibility) + if not hasattr(self, "_uuid_counter"): + self._uuid_counter = 0 + + result = deterministic_uuid_v5( + self.instance_id, # type: ignore[attr-defined] + self.current_utc_datetime, # type: ignore[attr-defined] + self._uuid_counter, + ) + self._uuid_counter += 1 + return result + + def new_guid(self) -> uuid.UUID: + """Alias for uuid4 for API parity with other SDKs.""" + return self.uuid4() + + def random_string(self, length: int, *, alphabet: Optional[str] = None) -> str: + """Return a deterministically generated random string of the given length.""" + if length < 0: + raise ValueError("length must be non-negative") + chars = alphabet if alphabet is not None else (_string.ascii_letters + _string.digits) + if not chars: + raise ValueError("alphabet must not be empty") + rnd = self.random() + size = len(chars) + return "".join(chars[rnd.randrange(0, size)] for _ in range(length)) + + def random_int(self, min_value: int = 0, max_value: int = 2**31 - 1) -> int: + """Return a deterministic random integer in the specified range.""" + if min_value > max_value: + raise ValueError("min_value must be <= max_value") + rnd = self.random() + return rnd.randint(min_value, max_value) + + T = TypeVar("T") + + def random_choice(self, sequence: Sequence[T]) -> T: + """Return a deterministic random element from a non-empty sequence.""" + if not sequence: + raise IndexError("Cannot choose from empty sequence") + rnd = self.random() + return rnd.choice(sequence) + + def now_with_sequence(self) -> datetime: + """ + Return deterministic timestamp with microsecond increment per call. + + Each call returns: current_utc_datetime + (counter * 1 microsecond) + + This provides ordered, unique timestamps for tracing/telemetry while maintaining + determinism across replays. The counter resets to 0 on each replay (similar to + _uuid_counter pattern). + + Perfect for preserving event ordering within a workflow without requiring activities. 
+ + Returns: + datetime: Deterministic timestamp that increments on each call + + Example: + ```python + def workflow(ctx): + t1 = ctx.now_with_sequence() # 2024-01-01 12:00:00.000000 + result = yield ctx.call_activity(some_activity, input="data") + t2 = ctx.now_with_sequence() # 2024-01-01 12:00:00.000001 + # t1 < t2, preserving order for telemetry + ``` + """ + offset = timedelta(microseconds=self._timestamp_counter) + self._timestamp_counter += 1 + return self.current_utc_datetime + offset # type: ignore[attr-defined] + + def current_utc_datetime_with_sequence(self): + """Alias for now_with_sequence for API parity with other SDKs.""" + return self.now_with_sequence() diff --git a/durabletask/worker.py b/durabletask/worker.py index 8fcc763..29d67fc 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -19,7 +19,7 @@ import durabletask.internal.orchestrator_service_pb2 as pb import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared as shared -from durabletask import task +from durabletask import deterministic, task from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl TInput = TypeVar("TInput") @@ -605,11 +605,14 @@ def _execute_activity( ) -class _RuntimeOrchestrationContext(task.OrchestrationContext): +class _RuntimeOrchestrationContext( + task.OrchestrationContext, deterministic.DeterministicContextMixin +): _generator: Optional[Generator[task.Task, Any, Any]] _previous_task: Optional[task.Task] def __init__(self, instance_id: str): + super().__init__() self._generator = None self._is_replaying = True self._is_complete = False diff --git a/tests/durabletask/test_deterministic.py b/tests/durabletask/test_deterministic.py new file mode 100644 index 0000000..f8f3acf --- /dev/null +++ b/tests/durabletask/test_deterministic.py @@ -0,0 +1,455 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import random +import uuid +from datetime import datetime, timezone + +import pytest + +from durabletask.deterministic import ( + DeterminismSeed, + derive_seed, + deterministic_random, + deterministic_uuid4, + deterministic_uuid_v5, +) +from durabletask.worker import _RuntimeOrchestrationContext + + +class TestDeterminismSeed: + """Test DeterminismSeed dataclass and its methods.""" + + def test_to_int_produces_consistent_result(self): + """Test that to_int produces the same result for same inputs.""" + seed1 = DeterminismSeed(instance_id="test-123", orchestration_unix_ts=1234567890) + seed2 = DeterminismSeed(instance_id="test-123", orchestration_unix_ts=1234567890) + assert seed1.to_int() == seed2.to_int() + + def test_to_int_produces_different_results_for_different_instance_ids(self): + """Test that different instance IDs produce different seeds.""" + seed1 = DeterminismSeed(instance_id="test-123", orchestration_unix_ts=1234567890) + seed2 = DeterminismSeed(instance_id="test-456", orchestration_unix_ts=1234567890) + assert seed1.to_int() != seed2.to_int() + + def test_to_int_produces_different_results_for_different_timestamps(self): + """Test that different timestamps produce different seeds.""" + seed1 = DeterminismSeed(instance_id="test-123", orchestration_unix_ts=1234567890) + seed2 = DeterminismSeed(instance_id="test-123", orchestration_unix_ts=1234567891) + assert seed1.to_int() != seed2.to_int() + + def test_to_int_returns_positive_integer(self): + """Test that to_int returns a positive integer.""" + seed = DeterminismSeed(instance_id="test-123", orchestration_unix_ts=1234567890) + result = seed.to_int() + assert isinstance(result, int) + assert result >= 0 + + +class TestDeriveSeed: + """Test derive_seed function.""" + + def test_derive_seed_is_deterministic(self): + """Test that derive_seed produces consistent results.""" + instance_id = "test-instance" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + seed1 = derive_seed(instance_id, dt) + seed2 = derive_seed(instance_id, dt) + assert seed1 == seed2 + + def test_derive_seed_different_for_different_times(self): + """Test that different times produce different seeds.""" + instance_id = "test-instance" + dt1 = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + dt2 = datetime(2025, 1, 1, 12, 0, 1, tzinfo=timezone.utc) + seed1 = derive_seed(instance_id, dt1) + seed2 = derive_seed(instance_id, dt2) + assert seed1 != seed2 + + def test_derive_seed_handles_timezone_aware_datetime(self): + """Test that derive_seed works with timezone-aware datetimes.""" + instance_id = "test-instance" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + seed = derive_seed(instance_id, dt) + assert isinstance(seed, int) + + +class TestDeterministicRandom: + """Test deterministic_random function.""" + + def test_deterministic_random_returns_random_object(self): + """Test that deterministic_random returns a Random instance.""" + instance_id = "test-instance" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + rnd = deterministic_random(instance_id, dt) + assert isinstance(rnd, random.Random) + + def test_deterministic_random_produces_same_sequence(self): + """Test that same inputs produce same random sequence.""" + instance_id = "test-instance" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + rnd1 = deterministic_random(instance_id, dt) + rnd2 = deterministic_random(instance_id, dt) + + sequence1 = [rnd1.random() for _ in range(10)] + sequence2 = [rnd2.random() for _ in range(10)] + assert sequence1 == 
sequence2 + + def test_deterministic_random_different_for_different_inputs(self): + """Test that different inputs produce different sequences.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + rnd1 = deterministic_random("instance-1", dt) + rnd2 = deterministic_random("instance-2", dt) + + val1 = rnd1.random() + val2 = rnd2.random() + assert val1 != val2 + + +class TestDeterministicUuid4: + """Test deterministic_uuid4 function.""" + + def test_deterministic_uuid4_returns_valid_uuid(self): + """Test that deterministic_uuid4 returns a valid UUID4.""" + rnd = random.Random(42) + result = deterministic_uuid4(rnd) + assert isinstance(result, uuid.UUID) + assert result.version == 4 + + def test_deterministic_uuid4_is_deterministic(self): + """Test that same random state produces same UUID.""" + rnd1 = random.Random(42) + rnd2 = random.Random(42) + uuid1 = deterministic_uuid4(rnd1) + uuid2 = deterministic_uuid4(rnd2) + assert uuid1 == uuid2 + + def test_deterministic_uuid4_different_for_different_seeds(self): + """Test that different seeds produce different UUIDs.""" + rnd1 = random.Random(42) + rnd2 = random.Random(43) + uuid1 = deterministic_uuid4(rnd1) + uuid2 = deterministic_uuid4(rnd2) + assert uuid1 != uuid2 + + +class TestDeterministicUuidV5: + """Test deterministic_uuid_v5 function (matching .NET implementation).""" + + def test_deterministic_uuid_v5_returns_valid_uuid(self): + """Test that deterministic_uuid_v5 returns a valid UUID v5.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + result = deterministic_uuid_v5("test-instance", dt, 0) + assert isinstance(result, uuid.UUID) + assert result.version == 5 + + def test_deterministic_uuid_v5_is_deterministic(self): + """Test that same inputs produce same UUID.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + uuid1 = deterministic_uuid_v5("test-instance", dt, 0) + uuid2 = deterministic_uuid_v5("test-instance", dt, 0) + assert uuid1 == uuid2 + + def test_deterministic_uuid_v5_different_for_different_counters(self): + """Test that different counters produce different UUIDs.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + uuid1 = deterministic_uuid_v5("test-instance", dt, 0) + uuid2 = deterministic_uuid_v5("test-instance", dt, 1) + assert uuid1 != uuid2 + + def test_deterministic_uuid_v5_different_for_different_instance_ids(self): + """Test that different instance IDs produce different UUIDs.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + uuid1 = deterministic_uuid_v5("instance-1", dt, 0) + uuid2 = deterministic_uuid_v5("instance-2", dt, 0) + assert uuid1 != uuid2 + + def test_deterministic_uuid_v5_different_for_different_datetimes(self): + """Test that different datetimes produce different UUIDs.""" + dt1 = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + dt2 = datetime(2025, 1, 1, 12, 0, 1, tzinfo=timezone.utc) + uuid1 = deterministic_uuid_v5("test-instance", dt1, 0) + uuid2 = deterministic_uuid_v5("test-instance", dt2, 0) + assert uuid1 != uuid2 + + def test_deterministic_uuid_v5_matches_expected_format(self): + """Test that UUID v5 uses the correct namespace.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + result = deterministic_uuid_v5("test-instance", dt, 0) + # Should be deterministic - same inputs always produce same output + expected = deterministic_uuid_v5("test-instance", dt, 0) + assert result == expected + + def test_deterministic_uuid_v5_counter_sequence(self): + """Test that incrementing counter produces different UUIDs in 
sequence.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + uuids = [deterministic_uuid_v5("test-instance", dt, i) for i in range(5)] + # All should be different + assert len(set(uuids)) == 5 + # But calling with same counter should produce same UUID + assert uuids[0] == deterministic_uuid_v5("test-instance", dt, 0) + assert uuids[4] == deterministic_uuid_v5("test-instance", dt, 4) + + +def mock_deterministic_context( + instance_id: str, current_utc_datetime: datetime +) -> _RuntimeOrchestrationContext: + """Mock context for testing DeterministicContextMixin.""" + ctx = _RuntimeOrchestrationContext(instance_id) + ctx.current_utc_datetime = current_utc_datetime + return ctx + + +class TestDeterministicContextMixin: + """Test DeterministicContextMixin methods.""" + + def test_now_returns_current_utc_datetime(self): + """Test that now() returns the orchestration time.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + assert ctx.now() == dt + + def test_random_returns_deterministic_prng(self): + """Test that random() returns a deterministic PRNG.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + rnd1 = ctx.random() + rnd2 = ctx.random() + + # Both should produce same sequence + assert isinstance(rnd1, random.Random) + assert isinstance(rnd2, random.Random) + assert rnd1.random() == rnd2.random() + + def test_random_has_deterministic_marker(self): + """Test that random() sets _dt_deterministic marker.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + rnd = ctx.random() + assert hasattr(rnd, "_dt_deterministic") + assert rnd._dt_deterministic is True + + def test_uuid4_generates_deterministic_uuid(self): + """Test that uuid4() generates deterministic UUIDs v5 with counter.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx1 = mock_deterministic_context("test-instance", dt) + ctx2 = mock_deterministic_context("test-instance", dt) + + uuid1 = ctx1.uuid4() + uuid2 = ctx2.uuid4() + + assert isinstance(uuid1, uuid.UUID) + assert uuid1.version == 5 # Now using UUID v5 like .NET + assert uuid1 == uuid2 # Same counter (0) produces same UUID + + def test_uuid4_increments_counter(self): + """Test that uuid4() increments counter producing different UUIDs.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + uuid1 = ctx.uuid4() # counter=0 + uuid2 = ctx.uuid4() # counter=1 + uuid3 = ctx.uuid4() # counter=2 + + # All should be different due to counter + assert uuid1 != uuid2 + assert uuid2 != uuid3 + assert uuid1 != uuid3 + + def test_uuid4_counter_resets_on_replay(self): + """Test that counter resets on new context (simulating replay).""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + + # First execution + ctx1 = mock_deterministic_context("test-instance", dt) + uuid1_first = ctx1.uuid4() # counter=0 + uuid1_second = ctx1.uuid4() # counter=1 + + # Replay - new context, counter resets + ctx2 = mock_deterministic_context("test-instance", dt) + uuid2_first = ctx2.uuid4() # counter=0 + uuid2_second = ctx2.uuid4() # counter=1 + + # Same counter values produce same UUIDs (determinism!) 
+ assert uuid1_first == uuid2_first + assert uuid1_second == uuid2_second + + def test_new_guid_is_alias_for_uuid4(self): + """Test that new_guid() is an alias for uuid4().""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + guid1 = ctx.new_guid() # counter=0 + guid2 = ctx.uuid4() # counter=1 + + # Both should be v5 UUIDs, but different due to counter increment + assert isinstance(guid1, uuid.UUID) + assert isinstance(guid2, uuid.UUID) + assert guid1.version == 5 + assert guid2.version == 5 + assert guid1 != guid2 # Different due to counter + + # Verify determinism - same counter produces same UUID + ctx2 = mock_deterministic_context("test-instance", dt) + guid3 = ctx2.new_guid() # counter=0 + assert guid3 == guid1 # Same as first call + + def test_random_string_generates_string_of_correct_length(self): + """Test that random_string() generates string of specified length.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + s = ctx.random_string(10) + assert len(s) == 10 + + def test_random_string_is_deterministic(self): + """Test that random_string() produces consistent results.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx1 = mock_deterministic_context("test-instance", dt) + ctx2 = mock_deterministic_context("test-instance", dt) + + s1 = ctx1.random_string(20) + s2 = ctx2.random_string(20) + assert s1 == s2 + + def test_random_string_uses_default_alphabet(self): + """Test that random_string() uses alphanumeric characters by default.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + s = ctx.random_string(100) + assert all(c.isalnum() for c in s) + + def test_random_string_uses_custom_alphabet(self): + """Test that random_string() respects custom alphabet.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + s = ctx.random_string(50, alphabet="ABC") + assert all(c in "ABC" for c in s) + + def test_random_string_raises_on_negative_length(self): + """Test that random_string() raises ValueError for negative length.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + with pytest.raises(ValueError, match="length must be non-negative"): + ctx.random_string(-1) + + def test_random_string_raises_on_empty_alphabet(self): + """Test that random_string() raises ValueError for empty alphabet.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + with pytest.raises(ValueError, match="alphabet must not be empty"): + ctx.random_string(10, alphabet="") + + def test_random_string_handles_zero_length(self): + """Test that random_string() handles zero length correctly.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + s = ctx.random_string(0) + assert s == "" + + def test_random_int_generates_int_in_range(self): + """Test that random_int() generates integer in specified range.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + for _ in range(10): + val = ctx.random_int(10, 20) + assert 10 <= val <= 20 + + def test_random_int_is_deterministic(self): + """Test that random_int() produces consistent results.""" + dt = datetime(2025, 1, 1, 12, 
0, 0, tzinfo=timezone.utc) + ctx1 = mock_deterministic_context("test-instance", dt) + ctx2 = mock_deterministic_context("test-instance", dt) + + val1 = ctx1.random_int(0, 1000) + val2 = ctx2.random_int(0, 1000) + assert val1 == val2 + + def test_random_int_uses_default_range(self): + """Test that random_int() uses default range when not specified.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + val = ctx.random_int() + assert 0 <= val <= 2**31 - 1 + + def test_random_int_raises_on_invalid_range(self): + """Test that random_int() raises ValueError when min > max.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + with pytest.raises(ValueError, match="min_value must be <= max_value"): + ctx.random_int(20, 10) + + def test_random_int_handles_same_min_and_max(self): + """Test that random_int() handles case where min equals max.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + val = ctx.random_int(42, 42) + assert val == 42 + + def test_random_choice_picks_from_sequence(self): + """Test that random_choice() picks element from sequence.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + choices = ["a", "b", "c", "d", "e"] + result = ctx.random_choice(choices) + assert result in choices + + def test_random_choice_is_deterministic(self): + """Test that random_choice() produces consistent results.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx1 = mock_deterministic_context("test-instance", dt) + ctx2 = mock_deterministic_context("test-instance", dt) + + choices = list(range(100)) + result1 = ctx1.random_choice(choices) + result2 = ctx2.random_choice(choices) + assert result1 == result2 + + def test_random_choice_raises_on_empty_sequence(self): + """Test that random_choice() raises IndexError for empty sequence.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + with pytest.raises(IndexError, match="Cannot choose from empty sequence"): + ctx.random_choice([]) + + def test_random_choice_works_with_different_sequence_types(self): + """Test that random_choice() works with various sequence types.""" + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + ctx = mock_deterministic_context("test-instance", dt) + + # List + result = ctx.random_choice([1, 2, 3]) + assert result in [1, 2, 3] + + # Reset context for deterministic behavior + ctx = mock_deterministic_context("test-instance", dt) + # Tuple + result = ctx.random_choice((1, 2, 3)) + assert result in (1, 2, 3) + + # Reset context for deterministic behavior + ctx = mock_deterministic_context("test-instance", dt) + # String + result = ctx.random_choice("abc") + assert result in "abc" diff --git a/tests/durabletask/test_orchestration_e2e.py b/tests/durabletask/test_orchestration_e2e.py index 9debf39..181d71d 100644 --- a/tests/durabletask/test_orchestration_e2e.py +++ b/tests/durabletask/test_orchestration_e2e.py @@ -594,3 +594,182 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): assert state.serialized_input is None assert state.serialized_output is None assert state.serialized_custom_status == '"foobaz"' + + +def test_now_with_sequence_ordering(): + """ + Test that now_with_sequence() maintains strict ordering across workflow execution. 
+ + This verifies: + 1. Timestamps increment sequentially + 2. Order is preserved across activity calls + 3. Deterministic behavior (timestamps are consistent on replay) + """ + + def simple_activity(ctx, input_val: str): + return f"activity_{input_val}_done" + + def timestamp_ordering_workflow(ctx: task.OrchestrationContext, _): + timestamps = [] + + # First timestamp before any activities + t1 = ctx.now_with_sequence() + timestamps.append(("t1_before_activities", t1.isoformat())) + + # Call first activity + result1 = yield ctx.call_activity(simple_activity, input="first") + timestamps.append(("activity_1_result", result1)) + + # Timestamp after first activity + t2 = ctx.now_with_sequence() + timestamps.append(("t2_after_activity_1", t2.isoformat())) + + # Call second activity + result2 = yield ctx.call_activity(simple_activity, input="second") + timestamps.append(("activity_2_result", result2)) + + # Timestamp after second activity + t3 = ctx.now_with_sequence() + timestamps.append(("t3_after_activity_2", t3.isoformat())) + + # A few more rapid timestamps to test counter incrementing + t4 = ctx.now_with_sequence() + timestamps.append(("t4_rapid", t4.isoformat())) + + t5 = ctx.now_with_sequence() + timestamps.append(("t5_rapid", t5.isoformat())) + + # Return all timestamps for verification + return { + "timestamps": timestamps, + "t1": t1.isoformat(), + "t2": t2.isoformat(), + "t3": t3.isoformat(), + "t4": t4.isoformat(), + "t5": t5.isoformat(), + } + + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: + w.add_orchestrator(timestamp_ordering_workflow) + w.add_activity(simple_activity) + w.start() + + with client.TaskHubGrpcClient() as c: + instance_id = c.schedule_new_orchestration(timestamp_ordering_workflow) + state = c.wait_for_orchestration_completion( + instance_id, timeout=30, fetch_payloads=True + ) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.failure_details is None + + # Parse result + result = json.loads(state.serialized_output) + assert result is not None + + # Verify all timestamps are present + assert "t1" in result + assert "t2" in result + assert "t3" in result + assert "t4" in result + assert "t5" in result + + # Parse timestamps back to datetime objects for comparison + from datetime import datetime + + t1 = datetime.fromisoformat(result["t1"]) + t2 = datetime.fromisoformat(result["t2"]) + t3 = datetime.fromisoformat(result["t3"]) + t4 = datetime.fromisoformat(result["t4"]) + t5 = datetime.fromisoformat(result["t5"]) + + # Verify strict ordering: t1 < t2 < t3 < t4 < t5 + # This is the key guarantee - timestamps must maintain order for tracing + assert t1 < t2, f"t1 ({t1}) should be < t2 ({t2})" + assert t2 < t3, f"t2 ({t2}) should be < t3 ({t3})" + assert t3 < t4, f"t3 ({t3}) should be < t4 ({t4})" + assert t4 < t5, f"t4 ({t4}) should be < t5 ({t5})" + + # Verify that timestamps called in rapid succession (t3, t4, t5 with no activities between) + # have exactly 1 microsecond deltas. These happen within the same replay execution. + delta_t3_t4 = (t4 - t3).total_seconds() * 1_000_000 + delta_t4_t5 = (t5 - t4).total_seconds() * 1_000_000 + + assert delta_t3_t4 == 1.0, f"t3 to t4 should be 1 microsecond, got {delta_t3_t4}" + assert delta_t4_t5 == 1.0, f"t4 to t5 should be 1 microsecond, got {delta_t4_t5}" + + # Note: We don't check exact deltas for t1->t2 or t2->t3 because they span + # activity calls. During replay, current_utc_datetime changes based on event + # timestamps, so the base time shifts. 
However, ordering is still guaranteed. + + +def test_cannot_add_orchestrator_while_running(): + """Test that orchestrators cannot be added while the worker is running.""" + + def orchestrator(ctx: task.OrchestrationContext, _): + return "done" + + def another_orchestrator(ctx: task.OrchestrationContext, _): + return "another" + + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: + w.add_orchestrator(orchestrator) + w.start() + + # Try to add another orchestrator while running + with pytest.raises( + RuntimeError, match="Orchestrators cannot be added while the worker is running" + ): + w.add_orchestrator(another_orchestrator) + + +def test_cannot_add_activity_while_running(): + """Test that activities cannot be added while the worker is running.""" + + def activity(ctx: task.ActivityContext, input): + return input + + def another_activity(ctx: task.ActivityContext, input): + return input * 2 + + def orchestrator(ctx: task.OrchestrationContext, _): + return "done" + + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: + w.add_orchestrator(orchestrator) + w.add_activity(activity) + w.start() + + # Try to add another activity while running + with pytest.raises( + RuntimeError, match="Activities cannot be added while the worker is running" + ): + w.add_activity(another_activity) + + +def test_can_add_functions_after_stop(): + """Test that orchestrators/activities can be added after stopping the worker.""" + + def orchestrator1(ctx: task.OrchestrationContext, _): + return "done" + + def orchestrator2(ctx: task.OrchestrationContext, _): + return "done2" + + def activity(ctx: task.ActivityContext, input): + return input + + with worker.TaskHubGrpcWorker(stop_timeout=2.0) as w: + w.add_orchestrator(orchestrator1) + w.start() + + c = client.TaskHubGrpcClient() + id = c.schedule_new_orchestration(orchestrator1) + state = c.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + + # Should be able to add after stop + w.add_orchestrator(orchestrator2) + w.add_activity(activity) diff --git a/tests/durabletask/test_registry.py b/tests/durabletask/test_registry.py new file mode 100644 index 0000000..743330c --- /dev/null +++ b/tests/durabletask/test_registry.py @@ -0,0 +1,164 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +"""Unit tests for the _Registry class validation logic.""" + +import pytest + +from durabletask import worker + + +def test_registry_add_orchestrator_none(): + """Test that adding a None orchestrator raises ValueError.""" + registry = worker._Registry() + + with pytest.raises(ValueError, match="An orchestrator function argument is required"): + registry.add_orchestrator(None) + + +def test_registry_add_named_orchestrator_empty_name(): + """Test that adding an orchestrator with empty name raises ValueError.""" + registry = worker._Registry() + + def dummy_orchestrator(ctx, input): + return "done" + + with pytest.raises(ValueError, match="A non-empty orchestrator name is required"): + registry.add_named_orchestrator("", dummy_orchestrator) + + +def test_registry_add_orchestrator_duplicate(): + """Test that adding a duplicate orchestrator raises ValueError.""" + registry = worker._Registry() + + def dummy_orchestrator(ctx, input): + return "done" + + name = "test_orchestrator" + registry.add_named_orchestrator(name, dummy_orchestrator) + + with pytest.raises(ValueError, match=f"A '{name}' orchestrator already exists"): + registry.add_named_orchestrator(name, dummy_orchestrator) + + +def test_registry_add_activity_none(): + """Test that adding a None activity raises ValueError.""" + registry = worker._Registry() + + with pytest.raises(ValueError, match="An activity function argument is required"): + registry.add_activity(None) + + +def test_registry_add_named_activity_empty_name(): + """Test that adding an activity with empty name raises ValueError.""" + registry = worker._Registry() + + def dummy_activity(ctx, input): + return "done" + + with pytest.raises(ValueError, match="A non-empty activity name is required"): + registry.add_named_activity("", dummy_activity) + + +def test_registry_add_activity_duplicate(): + """Test that adding a duplicate activity raises ValueError.""" + registry = worker._Registry() + + def dummy_activity(ctx, input): + return "done" + + name = "test_activity" + registry.add_named_activity(name, dummy_activity) + + with pytest.raises(ValueError, match=f"A '{name}' activity already exists"): + registry.add_named_activity(name, dummy_activity) + + +def test_registry_get_orchestrator_exists(): + """Test retrieving an existing orchestrator.""" + registry = worker._Registry() + + def dummy_orchestrator(ctx, input): + return "done" + + name = registry.add_orchestrator(dummy_orchestrator) + retrieved = registry.get_orchestrator(name) + + assert retrieved is dummy_orchestrator + + +def test_registry_get_orchestrator_not_exists(): + """Test retrieving a non-existent orchestrator returns None.""" + registry = worker._Registry() + + retrieved = registry.get_orchestrator("non_existent") + + assert retrieved is None + + +def test_registry_get_activity_exists(): + """Test retrieving an existing activity.""" + registry = worker._Registry() + + def dummy_activity(ctx, input): + return "done" + + name = registry.add_activity(dummy_activity) + retrieved = registry.get_activity(name) + + assert retrieved is dummy_activity + + +def test_registry_get_activity_not_exists(): + """Test retrieving a non-existent activity returns None.""" + registry = worker._Registry() + + retrieved = registry.get_activity("non_existent") + + assert retrieved is None + + +def test_registry_add_multiple_orchestrators(): + """Test adding multiple different orchestrators.""" + registry = worker._Registry() + + def orchestrator1(ctx, input): + return "one" + + def orchestrator2(ctx, input): + return "two" 
+ + name1 = registry.add_orchestrator(orchestrator1) + name2 = registry.add_orchestrator(orchestrator2) + + assert name1 != name2 + assert registry.get_orchestrator(name1) is orchestrator1 + assert registry.get_orchestrator(name2) is orchestrator2 + + +def test_registry_add_multiple_activities(): + """Test adding multiple different activities.""" + registry = worker._Registry() + + def activity1(ctx, input): + return "one" + + def activity2(ctx, input): + return "two" + + name1 = registry.add_activity(activity1) + name2 = registry.add_activity(activity2) + + assert name1 != name2 + assert registry.get_activity(name1) is activity1 + assert registry.get_activity(name2) is activity2 diff --git a/tests/durabletask/test_serialization.py b/tests/durabletask/test_serialization.py new file mode 100644 index 0000000..68f7f14 --- /dev/null +++ b/tests/durabletask/test_serialization.py @@ -0,0 +1,87 @@ +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from collections import namedtuple +from dataclasses import dataclass +from types import SimpleNamespace + +from durabletask.internal.shared import AUTO_SERIALIZED, from_json, to_json + + +@dataclass +class SamplePayload: + count: int + name: str + + +def test_to_json_roundtrip_dataclass(): + payload = SamplePayload(count=5, name="widgets") + encoded = to_json(payload) + + assert AUTO_SERIALIZED in encoded + + decoded = from_json(encoded) + assert isinstance(decoded, SimpleNamespace) + assert decoded.count == 5 + assert decoded.name == "widgets" + + +def test_to_json_roundtrip_simplenamespace(): + payload = SimpleNamespace(foo="bar", baz=42) + encoded = to_json(payload) + + assert AUTO_SERIALIZED in encoded + + decoded = from_json(encoded) + assert isinstance(decoded, SimpleNamespace) + assert decoded.foo == "bar" + assert decoded.baz == 42 + + +def test_to_json_plain_dict_passthrough(): + payload = {"foo": "bar", "baz": 1} + encoded = to_json(payload) + + assert AUTO_SERIALIZED not in encoded + + decoded = from_json(encoded) + assert isinstance(decoded, dict) + assert decoded == {"foo": "bar", "baz": 1} + + +def test_to_json_namedtuple_roundtrip(): + Point = namedtuple("Point", ["x", "y"]) + payload = Point(3, 4) + encoded = to_json(payload) + + assert AUTO_SERIALIZED in encoded + + decoded = from_json(encoded) + assert isinstance(decoded, SimpleNamespace) + assert decoded.x == 3 + assert decoded.y == 4 + + +def test_to_json_nested_dataclass_collection(): + payload = [ + SamplePayload(count=1, name="first"), + SamplePayload(count=2, name="second"), + ] + encoded = to_json(payload) + + assert encoded.count(AUTO_SERIALIZED) >= 2 + + decoded = from_json(encoded) + assert isinstance(decoded, list) + assert [item.count for item in decoded] == [1, 2] + assert [item.name for item in decoded] == ["first", "second"]