Skip to content

Commit 48b08be

Browse files
authored
feat: added filter by connection names to operations (#6)
* feat: added filter by connection names to operations * chore: code review
1 parent 7d2c718 commit 48b08be

File tree

2 files changed

+223
-40
lines changed

2 files changed

+223
-40
lines changed

datashield/api.py

Lines changed: 87 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -129,31 +129,38 @@ def open(self, restore: str = None, failSafe: bool = False) -> None:
129129
for name in self.errors:
130130
logging.error(f"Connection to {name} has failed")
131131

132-
def close(self, save: str = None, conn_names: list[str] = None) -> None:
    """
    Close connections with remote servers.

    Closing is best effort: a failure to save the workspace or to disconnect from one server
    does not prevent the other selected connections from being closed.

    :param save: The name of the workspace to save before closing the connections.
    :param conn_names: The optional list of connection names to close. If not defined, all opened connections are closed.
    """
    self.errors = {}
    if not self.conns:
        return
    selected_conns = self._get_selected_connections(conn_names)
    selected_names = {conn.get_name() for conn in selected_conns}
    for conn in selected_conns:
        if save:
            try:
                conn.save_workspace(f"{conn.get_name()}:{save}")
            except DSError as e:
                # a failed save must not skip the disconnection below, otherwise the
                # connection would be dropped from self.conns while still being opened
                logging.warning(f"Workspace {save} could not be saved on {conn.get_name()}: {e}")
        try:
            conn.disconnect()
        except DSError:
            # silently fail
            pass
    if conn_names is None:
        self.conns = None
    else:
        # keep only the connections that were not selected for closing
        self.conns = [conn for conn in self.conns if conn.get_name() not in selected_names]
149156

150157
def has_connections(self) -> bool:
    """
    Check if some connections were opened.

    :return: True if some connections were opened, False otherwise
    """
    # self.conns can be None or an empty list when nothing is opened; bool() maps both
    # to False, whereas `self.conns and len(self.conns) > 0` would return None or []
    # and break the declared bool return type.
    return bool(self.conns)
157164

158165
def get_connection_names(self) -> list[str]:
159166
"""
@@ -186,27 +193,29 @@ def get_errors(self) -> dict:
186193
# Environment
187194
#
188195

189-
def tables(self) -> dict:
196+
def tables(self, conn_names: list[str] = None) -> dict:
190197
"""
191198
List available table names from the data repository.
192199
200+
:param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
193201
:return: The available table names from the data repository, per remote server name
194202
"""
195203
rval = {}
196-
for conn in self.conns:
204+
for conn in self._get_selected_connections(conn_names):
197205
rval[conn.get_name()] = conn.list_tables()
198206
return rval
199207

200-
def variables(self, table: str = None, tables: dict = None) -> dict:
208+
def variables(self, table: str = None, tables: dict = None, conn_names: list[str] = None) -> dict:
201209
"""
202210
List available variables from the data repository, for a given table.
203211
204212
:param table: The default name of the table to list variables for
205213
:param tables: The name of the table to list variables for, per server name. If not defined, 'table' is used.
214+
:param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
206215
:return: The available variables from the data repository, for a given table, per remote server name
207216
"""
208217
rval = {}
209-
for conn in self.conns:
218+
for conn in self._get_selected_connections(conn_names):
210219
name = table
211220
if tables and conn.get_name() in tables:
212221
name = tables[conn.get_name()]
@@ -216,120 +225,130 @@ def variables(self, table: str = None, tables: dict = None) -> dict:
216225
rval[conn.get_name()] = None
217226
return rval
218227

219-
def taxonomies(self) -> dict:
228+
def taxonomies(self, conn_names: list[str] = None) -> dict:
220229
"""
221230
List available taxonomies from the data repository. A taxonomy is a hierarchical structure of vocabulary
222231
terms that can be used to annotate variables in the data repository.
223232
Depending on the data repository's capabilities, taxonomies can be used to perform structured
224233
queries when searching for variables.
225234
235+
:param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
226236
:return: The available taxonomies from the data repository, per remote server name
227237
"""
228238
rval = {}
229-
for conn in self.conns:
239+
for conn in self._get_selected_connections(conn_names):
230240
rval[conn.get_name()] = conn.list_taxonomies()
231241
return rval
232242

233-
def search_variables(self, query: str) -> dict:
243+
def search_variables(self, query: str, conn_names: list[str] = None) -> dict:
234244
"""
235245
Search for variable names matching a given query across all tables in the data repository.
236246
237247
:param query: The query to search for in variable names, e.g., a full-text search and/or structured
238248
query (based on taxonomy terms), depending on the data repository's capabilities
249+
:param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
239250
:return: The matching variable names from the data repository, per remote server name
240251
"""
241252
rval = {}
242-
for conn in self.conns:
253+
for conn in self._get_selected_connections(conn_names):
243254
rval[conn.get_name()] = conn.search_variables(query)
244255
return rval
245256

246-
def resources(self) -> dict:
257+
def resources(self, conn_names: list[str] = None) -> dict:
247258
"""
248259
List available resource names from the data repository.
249260
261+
:param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
250262
:return: The available resource names from the data repository, per remote server name
251263
"""
252264
rval = {}
253-
for conn in self.conns:
265+
for conn in self._get_selected_connections(conn_names):
254266
rval[conn.get_name()] = conn.list_resources()
255267
return rval
256268

257-
def profiles(self) -> dict:
269+
def profiles(self, conn_names: list[str] = None) -> dict:
258270
"""
259271
List available DataSHIELD profile names in the data repository.
260272
273+
:param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
261274
:return: The available DataSHIELD profile names in the data repository, per remote server name
262275
"""
263276
rval = {}
264-
for conn in self.conns:
277+
for conn in self._get_selected_connections(conn_names):
265278
rval[conn.get_name()] = conn.list_profiles()
266279
return rval
267280

268-
def packages(self) -> dict:
281+
def packages(self, conn_names: list[str] = None) -> dict:
269282
"""
270283
Get the list of DataSHIELD packages with their version, that have been configured on the remote data repository.
271284
285+
:param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
272286
:return: The list of DataSHIELD packages with their version, that have been configured on the remote data repository, per remote server name
273287
"""
274288
rval = {}
275-
for conn in self.conns:
289+
for conn in self._get_selected_connections(conn_names):
276290
rval[conn.get_name()] = conn.list_packages()
277291
return rval
278292

279-
def methods(self, type: str = "aggregate") -> dict:
293+
def methods(self, type: str = "aggregate", conn_names: list[str] = None) -> dict:
280294
"""
281295
Get the list of DataSHIELD methods that have been configured on the remote data repository.
282296
283297
:param type: The type of method, either "aggregate" (default) or "assign"
298+
:param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
284299
:return: The list of DataSHIELD methods that have been configured on the remote data repository, per remote server name
285300
"""
286301
rval = {}
287-
for conn in self.conns:
302+
for conn in self._get_selected_connections(conn_names):
288303
rval[conn.get_name()] = conn.list_methods(type)
289304
return rval
290305

291306
#
292307
# Workspaces
293308
#
294309

295-
def workspaces(self) -> dict:
310+
def workspaces(self, conn_names: list[str] = None) -> dict:
296311
"""
297312
Get the list of DataSHIELD workspaces, that have been saved on the remote data repository.
298313
314+
:param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
299315
:return: The list of DataSHIELD workspaces, that have been saved on the remote data repository, per remote server name
300316
"""
301317
rval = {}
302-
for conn in self.conns:
318+
for conn in self._get_selected_connections(conn_names):
303319
rval[conn.get_name()] = conn.list_workspaces()
304320
return rval
305321

306-
def workspace_save(self, name: str) -> None:
322+
def workspace_save(self, name: str, conn_names: list[str] = None) -> None:
307323
"""
308324
Save the DataSHIELD R session in a workspace on the remote data repository.
309325
310326
:param name: The name of the workspace
327+
:param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used.
311328
"""
312-
for conn in self.conns:
329+
for conn in self._get_selected_connections(conn_names):
313330
conn.save_workspace(f"{conn.get_name()}:{name}")
314331

315-
def workspace_restore(self, name: str) -> None:
332+
def workspace_restore(self, name: str, conn_names: list[str] = None) -> None:
316333
"""
317334
Restore a saved DataSHIELD R session from the remote data repository. When restoring a workspace,
318335
any existing symbol or file with same name will be overridden.
319336
320337
:param name: The name of the workspace
338+
:param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used.
321339
"""
322-
for conn in self.conns:
340+
for conn in self._get_selected_connections(conn_names):
323341
conn.restore_workspace(f"{conn.get_name()}:{name}")
324342

325-
def workspace_rm(self, name: str) -> None:
343+
def workspace_rm(self, name: str, conn_names: list[str] = None) -> None:
326344
"""
327345
Remove a DataSHIELD workspace from the remote data repository. Ignored if no
328346
such workspace exists.
329347
330348
:param name: The name of the workspace
349+
:param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used.
331350
"""
332-
for conn in self.conns:
351+
for conn in self._get_selected_connections(conn_names):
333352
conn.rm_workspace(f"{conn.get_name()}:{name}")
334353

335354
#
@@ -358,6 +377,9 @@ def sessions(self) -> dict:
358377
"""
359378
rval = {}
360379
self._init_errors()
380+
if not self.conns or len(self.conns) == 0:
381+
return rval
382+
361383
started_conns = []
362384
excluded_conns = []
363385

@@ -409,7 +431,7 @@ def sessions(self) -> dict:
409431
raise DSError("No sessions could be started successfully.")
410432
return rval
411433

412-
def ls(self) -> dict:
434+
def ls(self, conn_names: list[str] = None) -> dict:
413435
"""
414436
After assignments have been performed, list the symbols that live in the DataSHIELD R session on the server side.
415437
@@ -418,7 +440,7 @@ def ls(self) -> dict:
418440
# ensure sessions are started and available
419441
self.sessions()
420442
rval = {}
421-
for conn in self.conns:
443+
for conn in self._get_selected_connections(conn_names):
422444
try:
423445
rval[conn.get_name()] = conn.list_symbols()
424446
except Exception as e:
@@ -427,15 +449,16 @@ def ls(self) -> dict:
427449
self._check_errors()
428450
return rval
429451

430-
def rm(self, symbol: str) -> None:
452+
def rm(self, symbol: str, conn_names: list[str] = None) -> None:
431453
"""
432454
Remove a symbol from remote servers.
433455
434456
:param symbol: The name of the symbol to remove
457+
:param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used.
435458
"""
436459
# ensure sessions are started and available
437460
self.sessions()
438-
for conn in self.conns:
461+
for conn in self._get_selected_connections(conn_names):
439462
try:
440463
conn.rm_symbol(symbol)
441464
except Exception as e:
@@ -452,6 +475,7 @@ def assign_table(
452475
identifiers: str = None,
453476
id_name: str = None,
454477
asynchronous: bool = True,
478+
conn_names: list[str] = None,
455479
) -> None:
456480
"""
457481
Assign a data table from the data repository to a symbol in the DataSHIELD R session.
@@ -460,11 +484,12 @@ def assign_table(
460484
:param table: The default name of the table to assign
461485
:param tables: The name of the table to assign, per server name. If not defined, 'table' is used.
462486
:param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server)
487+
:param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used.
463488
"""
464489
# ensure sessions are started and available
465490
self.sessions()
466491
cmd = {}
467-
for conn in self.conns:
492+
for conn in self._get_selected_connections(conn_names):
468493
name = table
469494
if tables and conn.get_name() in tables:
470495
name = tables[conn.get_name()]
@@ -478,7 +503,12 @@ def assign_table(
478503
self._check_errors()
479504

480505
def assign_resource(
481-
self, symbol: str, resource: str = None, resources: dict = None, asynchronous: bool = True
506+
self,
507+
symbol: str,
508+
resource: str = None,
509+
resources: dict = None,
510+
asynchronous: bool = True,
511+
conn_names: list[str] = None,
482512
) -> None:
483513
"""
484514
Assign a resource from the data repository to a symbol in the DataSHIELD R session.
@@ -487,11 +517,12 @@ def assign_resource(
487517
:param resource: The default name of the resource to assign
488518
:param resources: The name of the resource to assign, per server name. If not defined, 'resource' is used.
489519
:param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server)
520+
:param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used.
490521
"""
491522
# ensure sessions are started and available
492523
self.sessions()
493524
cmd = {}
494-
for conn in self.conns:
525+
for conn in self._get_selected_connections(conn_names):
495526
name = resource
496527
if resources and conn.get_name() in resources:
497528
name = resources[conn.get_name()]
@@ -504,18 +535,19 @@ def assign_resource(
504535
self._do_wait(cmd)
505536
self._check_errors()
506537

507-
def assign_expr(self, symbol: str, expr: str, asynchronous: bool = True) -> None:
538+
def assign_expr(self, symbol: str, expr: str, asynchronous: bool = True, conn_names: list[str] = None) -> None:
508539
"""
509540
Assign the result of the evaluation of an expression to a symbol in the DataSHIELD R session.
510541
511542
:param symbol: The name of the destination symbol
512543
:param expr: The R expression to evaluate and which result will be assigned
513544
:param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server)
545+
:param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used.
514546
"""
515547
# ensure sessions are started and available
516548
self.sessions()
517549
cmd = {}
518-
for conn in self.conns:
550+
for conn in self._get_selected_connections(conn_names):
519551
try:
520552
res = conn.assign_expr(symbol, expr, asynchronous)
521553
cmd[conn.get_name()] = res
@@ -524,20 +556,21 @@ def assign_expr(self, symbol: str, expr: str, asynchronous: bool = True) -> None
524556
self._do_wait(cmd)
525557
self._check_errors()
526558

527-
def aggregate(self, expr: str, asynchronous: bool = True) -> dict:
559+
def aggregate(self, expr: str, asynchronous: bool = True, conn_names: list[str] = None) -> dict:
528560
"""
529561
Aggregate some data from the DataSHIELD R session using a valid R expression. The
530562
aggregation expression must satisfy the data repository's DataSHIELD configuration.
531563
532564
:param expr: The R expression to evaluate and which result will be returned
533565
:param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server)
566+
:param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used.
534567
:return: The result of the aggregation expression evaluation, per remote server name
535568
"""
536569
# ensure sessions are started and available
537570
self.sessions()
538571
cmd = {}
539572
rval = {}
540-
for conn in self.conns:
573+
for conn in self._get_selected_connections(conn_names):
541574
try:
542575
res = conn.aggregate(expr, asynchronous)
543576
cmd[conn.get_name()] = res
@@ -573,6 +606,20 @@ def _do_wait(self, cmd: dict) -> dict:
573606
time.sleep(0.1)
574607
return rval
575608

609+
def _get_selected_connections(self, conn_names: list[str] = None) -> list[DSConnection]:
610+
"""
611+
Get the list of opened connections, optionally filtered by connection names.
612+
613+
:param conn_names: The optional list of connection names to select.
614+
:return: The list of selected opened connections
615+
"""
616+
if not self.conns:
617+
return []
618+
if conn_names is None:
619+
return self.conns
620+
selected_names = set(conn_names)
621+
return [conn for conn in self.conns if conn.get_name() in selected_names]
622+
576623
def _init_errors(self) -> None:
577624
"""
578625
Prepare for storing errors.

0 commit comments

Comments
 (0)