Skip to content

Commit ecccbef

Browse files
committed
Enabled context manager in client.
Added copyright headers on new files Signed-off-by: Patrick Assuied <[email protected]>
1 parent eac4b82 commit ecccbef

File tree

6 files changed

+116
-98
lines changed

6 files changed

+116
-98
lines changed

durabletask/aio/client.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Copyright (c) The Dapr Authors.
2+
# Licensed under the MIT License.
3+
14
import logging
25
import uuid
36
from datetime import datetime
@@ -47,6 +50,13 @@ def __init__(self, *,
4750
async def aclose(self):
4851
await self._channel.close()
4952

53+
async def __aenter__(self):
54+
return self
55+
56+
async def __aexit__(self, exc_type, exc_val, exc_tb):
57+
await self.aclose()
58+
return False
59+
5060
async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
5161
input: Optional[TInput] = None,
5262
instance_id: Optional[str] = None,

durabletask/aio/internal/grpc_interceptor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Copyright (c) The Dapr Authors.
2+
# Licensed under the MIT License.
3+
14
from collections import namedtuple
25

36
from grpc import aio as grpc_aio

durabletask/aio/internal/shared.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Copyright (c) The Dapr Authors.
2+
# Licensed under the MIT License.
3+
14
from typing import Optional, Sequence, Union
25

36
import grpc

tests/durabletask/test_client_async.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Copyright (c) The Dapr Authors.
2+
# Licensed under the MIT License.
3+
14
from unittest.mock import ANY, patch
25

36
from durabletask.aio.internal.grpc_interceptor import DefaultClientInterceptorImpl

tests/durabletask/test_orchestration_e2e_async.py

Lines changed: 93 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Copyright (c) The Dapr Authors.
2+
# Licensed under the MIT License.
3+
14
import asyncio
25
import json
36
import threading
@@ -235,34 +238,33 @@ def orchestrator(ctx: task.OrchestrationContext, _):
235238
# there could be a race condition if the workflow is scheduled before orchestrator is started
236239
await asyncio.sleep(0.2)
237240

238-
client = AsyncTaskHubGrpcClient()
239-
id = await client.schedule_new_orchestration(orchestrator)
240-
state = await client.wait_for_orchestration_start(id, timeout=30)
241-
assert state is not None
242-
243-
# Suspend the orchestration and wait for it to go into the SUSPENDED state
244-
await client.suspend_orchestration(id)
245-
while state.runtime_status == OrchestrationStatus.RUNNING:
246-
await asyncio.sleep(0.1)
247-
state = await client.get_orchestration_state(id)
241+
async with AsyncTaskHubGrpcClient() as client:
242+
id = await client.schedule_new_orchestration(orchestrator)
243+
state = await client.wait_for_orchestration_start(id, timeout=30)
248244
assert state is not None
249-
assert state.runtime_status == OrchestrationStatus.SUSPENDED
250245

251-
# Raise an event to the orchestration and confirm that it does NOT complete
252-
await client.raise_orchestration_event(id, "my_event", data=42)
253-
try:
254-
state = await client.wait_for_orchestration_completion(id, timeout=3)
255-
assert False, "Orchestration should not have completed"
256-
except TimeoutError:
257-
pass
258-
259-
# Resume the orchestration and wait for it to complete
260-
await client.resume_orchestration(id)
261-
state = await client.wait_for_orchestration_completion(id, timeout=30)
262-
assert state is not None
263-
assert state.runtime_status == OrchestrationStatus.COMPLETED
264-
assert state.serialized_output == json.dumps(42)
265-
await client.aclose()
246+
# Suspend the orchestration and wait for it to go into the SUSPENDED state
247+
await client.suspend_orchestration(id)
248+
while state.runtime_status == OrchestrationStatus.RUNNING:
249+
await asyncio.sleep(0.1)
250+
state = await client.get_orchestration_state(id)
251+
assert state is not None
252+
assert state.runtime_status == OrchestrationStatus.SUSPENDED
253+
254+
# Raise an event to the orchestration and confirm that it does NOT complete
255+
await client.raise_orchestration_event(id, "my_event", data=42)
256+
try:
257+
state = await client.wait_for_orchestration_completion(id, timeout=3)
258+
assert False, "Orchestration should not have completed"
259+
except TimeoutError:
260+
pass
261+
262+
# Resume the orchestration and wait for it to complete
263+
await client.resume_orchestration(id)
264+
state = await client.wait_for_orchestration_completion(id, timeout=30)
265+
assert state is not None
266+
assert state.runtime_status == OrchestrationStatus.COMPLETED
267+
assert state.serialized_output == json.dumps(42)
266268

267269

268270
async def test_terminate():
@@ -275,18 +277,17 @@ def orchestrator(ctx: task.OrchestrationContext, _):
275277
w.add_orchestrator(orchestrator)
276278
w.start()
277279

278-
client = AsyncTaskHubGrpcClient()
279-
id = await client.schedule_new_orchestration(orchestrator)
280-
state = await client.wait_for_orchestration_start(id, timeout=30)
281-
assert state is not None
282-
assert state.runtime_status == OrchestrationStatus.RUNNING
280+
async with AsyncTaskHubGrpcClient() as client:
281+
id = await client.schedule_new_orchestration(orchestrator)
282+
state = await client.wait_for_orchestration_start(id, timeout=30)
283+
assert state is not None
284+
assert state.runtime_status == OrchestrationStatus.RUNNING
283285

284-
await client.terminate_orchestration(id, output="some reason for termination")
285-
state = await client.wait_for_orchestration_completion(id, timeout=30)
286-
assert state is not None
287-
assert state.runtime_status == OrchestrationStatus.TERMINATED
288-
assert state.serialized_output == json.dumps("some reason for termination")
289-
await client.aclose()
286+
await client.terminate_orchestration(id, output="some reason for termination")
287+
state = await client.wait_for_orchestration_completion(id, timeout=30)
288+
assert state is not None
289+
assert state.runtime_status == OrchestrationStatus.TERMINATED
290+
assert state.serialized_output == json.dumps("some reason for termination")
290291

291292

292293
async def test_terminate_recursive():
@@ -304,27 +305,26 @@ def child(ctx: task.OrchestrationContext, _):
304305
w.add_orchestrator(child)
305306
w.start()
306307

307-
client = AsyncTaskHubGrpcClient()
308-
id = await client.schedule_new_orchestration(root)
309-
state = await client.wait_for_orchestration_start(id, timeout=30)
310-
assert state is not None
311-
assert state.runtime_status == OrchestrationStatus.RUNNING
308+
async with AsyncTaskHubGrpcClient() as client:
309+
id = await client.schedule_new_orchestration(root)
310+
state = await client.wait_for_orchestration_start(id, timeout=30)
311+
assert state is not None
312+
assert state.runtime_status == OrchestrationStatus.RUNNING
312313

313-
# Terminate root orchestration(recursive set to True by default)
314-
await client.terminate_orchestration(id, output="some reason for termination")
315-
state = await client.wait_for_orchestration_completion(id, timeout=30)
316-
assert state is not None
317-
assert state.runtime_status == OrchestrationStatus.TERMINATED
314+
# Terminate root orchestration(recursive set to True by default)
315+
await client.terminate_orchestration(id, output="some reason for termination")
316+
state = await client.wait_for_orchestration_completion(id, timeout=30)
317+
assert state is not None
318+
assert state.runtime_status == OrchestrationStatus.TERMINATED
318319

319-
# Verify that child orchestration is also terminated
320-
await client.wait_for_orchestration_completion(id, timeout=30)
321-
assert state is not None
322-
assert state.runtime_status == OrchestrationStatus.TERMINATED
320+
# Verify that child orchestration is also terminated
321+
await client.wait_for_orchestration_completion(id, timeout=30)
322+
assert state is not None
323+
assert state.runtime_status == OrchestrationStatus.TERMINATED
323324

324-
await client.purge_orchestration(id)
325-
state = await client.get_orchestration_state(id)
326-
assert state is None
327-
await client.aclose()
325+
await client.purge_orchestration(id)
326+
state = await client.get_orchestration_state(id)
327+
assert state is None
328328

329329

330330
async def test_continue_as_new():
@@ -347,21 +347,20 @@ def orchestrator(ctx: task.OrchestrationContext, input: int):
347347
w.add_orchestrator(orchestrator)
348348
w.start()
349349

350-
client = AsyncTaskHubGrpcClient()
351-
id = await client.schedule_new_orchestration(orchestrator, input=0)
352-
await client.raise_orchestration_event(id, "my_event", data=1)
353-
await client.raise_orchestration_event(id, "my_event", data=2)
354-
await client.raise_orchestration_event(id, "my_event", data=3)
355-
await client.raise_orchestration_event(id, "my_event", data=4)
356-
await client.raise_orchestration_event(id, "my_event", data=5)
350+
async with AsyncTaskHubGrpcClient() as client:
351+
id = await client.schedule_new_orchestration(orchestrator, input=0)
352+
await client.raise_orchestration_event(id, "my_event", data=1)
353+
await client.raise_orchestration_event(id, "my_event", data=2)
354+
await client.raise_orchestration_event(id, "my_event", data=3)
355+
await client.raise_orchestration_event(id, "my_event", data=4)
356+
await client.raise_orchestration_event(id, "my_event", data=5)
357357

358-
state = await client.wait_for_orchestration_completion(id, timeout=30)
359-
assert state is not None
360-
assert state.runtime_status == OrchestrationStatus.COMPLETED
361-
assert state.serialized_output == json.dumps(all_results)
362-
assert state.serialized_input == json.dumps(4)
363-
assert all_results == [1, 2, 3, 4, 5]
364-
await client.aclose()
358+
state = await client.wait_for_orchestration_completion(id, timeout=30)
359+
assert state is not None
360+
assert state.runtime_status == OrchestrationStatus.COMPLETED
361+
assert state.serialized_output == json.dumps(all_results)
362+
assert state.serialized_input == json.dumps(4)
363+
assert all_results == [1, 2, 3, 4, 5]
365364

366365

367366
async def test_retry_policies():
@@ -405,19 +404,18 @@ def throw_activity_with_retry(ctx: task.ActivityContext, _):
405404
w.add_activity(throw_activity_with_retry)
406405
w.start()
407406

408-
client = AsyncTaskHubGrpcClient()
409-
id = await client.schedule_new_orchestration(parent_orchestrator_with_retry)
410-
state = await client.wait_for_orchestration_completion(id, timeout=30)
411-
assert state is not None
412-
assert state.runtime_status == OrchestrationStatus.FAILED
413-
assert state.failure_details is not None
414-
assert state.failure_details.error_type == "TaskFailedError"
415-
assert state.failure_details.message.startswith("Sub-orchestration task #1 failed:")
416-
assert state.failure_details.message.endswith("Activity task #1 failed: Kah-BOOOOM!!!")
417-
assert state.failure_details.stack_trace is not None
418-
assert throw_activity_counter == 9
419-
assert child_orch_counter == 3
420-
await client.aclose()
407+
async with AsyncTaskHubGrpcClient() as client:
408+
id = await client.schedule_new_orchestration(parent_orchestrator_with_retry)
409+
state = await client.wait_for_orchestration_completion(id, timeout=30)
410+
assert state is not None
411+
assert state.runtime_status == OrchestrationStatus.FAILED
412+
assert state.failure_details is not None
413+
assert state.failure_details.error_type == "TaskFailedError"
414+
assert state.failure_details.message.startswith("Sub-orchestration task #1 failed:")
415+
assert state.failure_details.message.endswith("Activity task #1 failed: Kah-BOOOOM!!!")
416+
assert state.failure_details.stack_trace is not None
417+
assert throw_activity_counter == 9
418+
assert child_orch_counter == 3
421419

422420

423421
async def test_retry_timeout():
@@ -446,17 +444,16 @@ def throw_activity(ctx: task.ActivityContext, _):
446444
w.add_activity(throw_activity)
447445
w.start()
448446

449-
client = AsyncTaskHubGrpcClient()
450-
id = await client.schedule_new_orchestration(mock_orchestrator)
451-
state = await client.wait_for_orchestration_completion(id, timeout=30)
452-
assert state is not None
453-
assert state.runtime_status == OrchestrationStatus.FAILED
454-
assert state.failure_details is not None
455-
assert state.failure_details.error_type == "TaskFailedError"
456-
assert state.failure_details.message.endswith("Activity task #1 failed: Kah-BOOOOM!!!")
457-
assert state.failure_details.stack_trace is not None
458-
assert throw_activity_counter == 4
459-
await client.aclose()
447+
async with AsyncTaskHubGrpcClient() as client:
448+
id = await client.schedule_new_orchestration(mock_orchestrator)
449+
state = await client.wait_for_orchestration_completion(id, timeout=30)
450+
assert state is not None
451+
assert state.runtime_status == OrchestrationStatus.FAILED
452+
assert state.failure_details is not None
453+
assert state.failure_details.error_type == "TaskFailedError"
454+
assert state.failure_details.message.endswith("Activity task #1 failed: Kah-BOOOOM!!!")
455+
assert state.failure_details.stack_trace is not None
456+
assert throw_activity_counter == 4
460457

461458

462459
async def test_custom_status():
@@ -469,10 +466,9 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _):
469466
w.add_orchestrator(empty_orchestrator)
470467
w.start()
471468

472-
c = AsyncTaskHubGrpcClient()
473-
id = await c.schedule_new_orchestration(empty_orchestrator)
474-
state = await c.wait_for_orchestration_completion(id, timeout=30)
475-
await c.aclose()
469+
async with AsyncTaskHubGrpcClient() as client:
470+
id = await client.schedule_new_orchestration(empty_orchestrator)
471+
state = await client.wait_for_orchestration_completion(id, timeout=30)
476472

477473
assert state is not None
478474
assert state.name == task.get_name(empty_orchestrator)

tests/durabletask/test_task.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Copyright (c) The Dapr Authors.
2+
# Licensed under the MIT License.
3+
14
"""Unit tests for durabletask.task primitives."""
25

36
from durabletask import task
@@ -10,6 +13,7 @@ def test_when_all_empty_returns_successfully():
1013
assert when_all_task.is_complete
1114
assert when_all_task.get_result() == []
1215

16+
1317
def test_when_any_empty_returns_successfully():
1418
"""task.when_any([]) should complete immediately and return an empty list."""
1519
when_any_task = task.when_any([])
@@ -64,4 +68,3 @@ def test_when_any_happy_path_returns_winner_task_and_completes_on_first():
6468
a.complete("A")
6569

6670
assert any_task.get_result() is b
67-

0 commit comments

Comments
 (0)