@@ -163,10 +163,11 @@ def _get_request_timeout(settings):
163
163
164
164
165
165
class EndpointOptions (object ):
166
- __slots__ = ("ssl_target_name_override" ,)
166
+ __slots__ = ("ssl_target_name_override" , "node_id" )
167
167
168
- def __init__ (self , ssl_target_name_override = None ):
168
+ def __init__ (self , ssl_target_name_override = None , node_id = None ):
169
169
self .ssl_target_name_override = ssl_target_name_override
170
+ self .node_id = node_id
170
171
171
172
172
173
def _construct_channel_options (driver_config , endpoint_options = None ):
@@ -223,16 +224,18 @@ class _RpcState(object):
223
224
"endpoint" ,
224
225
"rendezvous" ,
225
226
"metadata_kv" ,
227
+ "endpoint_key" ,
226
228
)
227
229
228
- def __init__ (self , stub_instance , rpc_name , endpoint ):
230
+ def __init__ (self , stub_instance , rpc_name , endpoint , endpoint_key ):
229
231
"""Stores all RPC related data"""
230
232
self .rpc_name = rpc_name
231
233
self .rpc = getattr (stub_instance , rpc_name )
232
234
self .request_id = uuid .uuid4 ()
233
235
self .endpoint = endpoint
234
236
self .rendezvous = None
235
237
self .metadata_kv = None
238
+ self .endpoint_key = endpoint_key
236
239
237
240
def __str__ (self ):
238
241
return "RpcState(%s, %s, %s)" % (self .rpc_name , self .request_id , self .endpoint )
@@ -318,6 +321,14 @@ def channel_factory(
318
321
)
319
322
320
323
324
+ class EndpointKey (object ):
325
+ __slots__ = ("endpoint" , "node_id" )
326
+
327
+ def __init__ (self , endpoint , node_id ):
328
+ self .endpoint = endpoint
329
+ self .node_id = node_id
330
+
331
+
321
332
class Connection (object ):
322
333
__slots__ = (
323
334
"endpoint" ,
@@ -330,6 +341,8 @@ class Connection(object):
330
341
"lock" ,
331
342
"calls" ,
332
343
"closing" ,
344
+ "endpoint_key" ,
345
+ "node_id" ,
333
346
)
334
347
335
348
def __init__ (self , endpoint , driver_config = None , endpoint_options = None ):
@@ -341,6 +354,10 @@ def __init__(self, endpoint, driver_config=None, endpoint_options=None):
341
354
"""
342
355
global _stubs_list
343
356
self .endpoint = endpoint
357
+ self .node_id = getattr (endpoint_options , "node_id" , None )
358
+ self .endpoint_key = EndpointKey (
359
+ endpoint , getattr (endpoint_options , "node_id" , None )
360
+ )
344
361
self ._channel = channel_factory (
345
362
self .endpoint , driver_config , endpoint_options = endpoint_options
346
363
)
@@ -368,7 +385,9 @@ def _prepare_call(self, stub, rpc_name, request, settings):
368
385
)
369
386
_set_server_timeouts (request , settings , timeout )
370
387
self ._prepare_stub_instance (stub )
371
- rpc_state = _RpcState (self ._stub_instances [stub ], rpc_name , self .endpoint )
388
+ rpc_state = _RpcState (
389
+ self ._stub_instances [stub ], rpc_name , self .endpoint , self .endpoint_key
390
+ )
372
391
logger .debug ("%s: creating call state" , rpc_state )
373
392
with self .lock :
374
393
if self .closing :
0 commit comments