1- from unittest .mock import patch , ANY
1+ from unittest .mock import patch , ANY , Mock
22
3+ from durabletask .client import TaskHubGrpcClient
34from durabletask .internal .shared import (DefaultClientInterceptorImpl ,
45 get_default_host_address ,
56 get_grpc_channel )
7+ import pytest
68
79HOST_ADDRESS = 'localhost:50051'
810METADATA = [('key1' , 'value1' ), ('key2' , 'value2' )]
@@ -85,4 +87,61 @@ def test_grpc_channel_with_host_name_protocol_stripping():
8587
8688 prefix = ""
8789 get_grpc_channel (prefix + host_name , METADATA , True )
88- mock_secure_channel .assert_called_with (host_name , ANY )
90+ mock_secure_channel .assert_called_with (host_name , ANY )
91+
92+
93+ @pytest .mark .parametrize ("timeout" , [None , 0 , 5 ])
94+ def test_wait_for_orchestration_start_timeout (timeout ):
95+ instance_id = "test-instance"
96+
97+ from durabletask .internal .orchestrator_service_pb2 import GetInstanceResponse , \
98+ OrchestrationState , ORCHESTRATION_STATUS_RUNNING
99+
100+ response = GetInstanceResponse ()
101+ state = OrchestrationState ()
102+ state .instanceId = instance_id
103+ state .orchestrationStatus = ORCHESTRATION_STATUS_RUNNING
104+ response .orchestrationState .CopyFrom (state )
105+
106+ c = TaskHubGrpcClient ()
107+ c ._stub = Mock ()
108+ c ._stub .WaitForInstanceStart .return_value = response
109+
110+ grpc_timeout = None if timeout is None else timeout
111+ c .wait_for_orchestration_start (instance_id , timeout = grpc_timeout )
112+
113+ # Verify WaitForInstanceStart was called with timeout=None
114+ c ._stub .WaitForInstanceStart .assert_called_once ()
115+ _ , kwargs = c ._stub .WaitForInstanceStart .call_args
116+ if timeout is None or timeout == 0 :
117+ assert kwargs .get ('timeout' ) is None
118+ else :
119+ assert kwargs .get ('timeout' ) == timeout
120+
121+ @pytest .mark .parametrize ("timeout" , [None , 0 , 5 ])
122+ def test_wait_for_orchestration_completion_timeout (timeout ):
123+ instance_id = "test-instance"
124+
125+ from durabletask .internal .orchestrator_service_pb2 import GetInstanceResponse , \
126+ OrchestrationState , ORCHESTRATION_STATUS_COMPLETED
127+
128+ response = GetInstanceResponse ()
129+ state = OrchestrationState ()
130+ state .instanceId = instance_id
131+ state .orchestrationStatus = ORCHESTRATION_STATUS_COMPLETED
132+ response .orchestrationState .CopyFrom (state )
133+
134+ c = TaskHubGrpcClient ()
135+ c ._stub = Mock ()
136+ c ._stub .WaitForInstanceCompletion .return_value = response
137+
138+ grpc_timeout = None if timeout is None else timeout
139+ c .wait_for_orchestration_completion (instance_id , timeout = grpc_timeout )
140+
141+ # Verify WaitForInstanceStart was called with timeout=None
142+ c ._stub .WaitForInstanceCompletion .assert_called_once ()
143+ _ , kwargs = c ._stub .WaitForInstanceCompletion .call_args
144+ if timeout is None or timeout == 0 :
145+ assert kwargs .get ('timeout' ) is None
146+ else :
147+ assert kwargs .get ('timeout' ) == timeout
0 commit comments