1
- #
2
- # Copyright 2018-2023 - Swiss Data Science Center (SDSC)
3
- # A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and
4
- # Eidgenössische Technische Hochschule Zürich (ETHZ).
1
+ # Copyright Swiss Data Science Center (SDSC). A partnership between
2
+ # École Polytechnique Fédérale de Lausanne (EPFL) and
3
+ # Eidgenössische Technische Hochschule Zürich (ETHZ).
5
4
#
6
5
# Licensed under the Apache License, Version 2.0 (the "License");
7
6
# you may not use this file except in compliance with the License.
26
25
27
26
from renku .core import errors
28
27
from renku .core .config import get_value
28
+ from renku .core .constant import ProviderPriority
29
29
from renku .core .login import read_renku_token
30
30
from renku .core .plugin import hookimpl
31
31
from renku .core .session .utils import get_renku_project_name , get_renku_url
34
34
from renku .core .util .jwt import is_token_expired
35
35
from renku .core .util .ssh import SystemSSHConfig
36
36
from renku .domain_model .project_context import project_context
37
- from renku .domain_model .session import ISessionProvider , Session
37
+ from renku .domain_model .session import ISessionProvider , Session , SessionStopStatus
38
38
39
39
if TYPE_CHECKING :
40
40
from renku .core .dataset .providers .models import ProviderParameter
@@ -44,6 +44,8 @@ class RenkulabSessionProvider(ISessionProvider):
44
44
"""A session provider that uses the notebook service API to launch sessions."""
45
45
46
46
DEFAULT_TIMEOUT_SECONDS = 300
47
+ # NOTE: Give the renkulab provider the lowest priority so that it's checked last
48
+ priority : ProviderPriority = ProviderPriority .LOWEST
47
49
48
50
def __init__ (self ):
49
51
self .__renku_url : Optional [str ] = None
@@ -187,7 +189,7 @@ def _cleanup_ssh_connection_configs(
187
189
gotten from the server.
188
190
"""
189
191
if not running_sessions :
190
- running_sessions = self .session_list ("" , None , ssh_garbage_collection = False )
192
+ running_sessions = self .session_list (project_name = "" , ssh_garbage_collection = False )
191
193
192
194
system_config = SystemSSHConfig ()
193
195
@@ -199,7 +201,8 @@ def _cleanup_ssh_connection_configs(
199
201
if path not in session_config_paths :
200
202
path .unlink ()
201
203
202
- def _remote_head_hexsha (self ):
204
+ @staticmethod
205
+ def _remote_head_hexsha ():
203
206
remote = get_remote (repository = project_context .repository )
204
207
205
208
if remote is None :
@@ -221,7 +224,8 @@ def _send_renku_request(self, req_type: str, *args, **kwargs):
221
224
)
222
225
return res
223
226
224
- def _project_name_from_full_project_name (self , project_name : str ) -> str :
227
+ @staticmethod
228
+ def _project_name_from_full_project_name (project_name : str ) -> str :
225
229
"""Get just project name of project name if in owner/name form."""
226
230
if "/" not in project_name :
227
231
return project_name
@@ -282,9 +286,7 @@ def get_open_parameters(self) -> List["ProviderParameter"]:
282
286
ProviderParameter ("ssh" , help = "Open a remote terminal through SSH." , is_flag = True ),
283
287
]
284
288
285
- def session_list (
286
- self , project_name : str , config : Optional [Dict [str , Any ]], ssh_garbage_collection : bool = True
287
- ) -> List [Session ]:
289
+ def session_list (self , project_name : str , ssh_garbage_collection : bool = True ) -> List [Session ]:
288
290
"""Lists all the sessions currently running by the given session provider.
289
291
290
292
Returns:
@@ -398,45 +400,67 @@ def session_start(
398
400
)
399
401
raise errors .RenkulabSessionError ("Cannot start session via the notebook service because " + res .text )
400
402
401
- def session_stop (self , project_name : str , session_name : Optional [str ], stop_all : bool ) -> bool :
403
+ def session_stop (self , project_name : str , session_name : Optional [str ], stop_all : bool ) -> SessionStopStatus :
402
404
"""Stops all sessions (for the given project) or a specific interactive session."""
403
405
responses = []
406
+ sessions = self .session_list (project_name = project_name )
407
+ n_sessions = len (sessions )
408
+
409
+ if n_sessions == 0 :
410
+ return SessionStopStatus .NO_ACTIVE_SESSION
411
+
404
412
if stop_all :
405
- sessions = self .session_list (project_name = project_name , config = None )
406
413
for session in sessions :
407
414
responses .append (
408
415
self ._send_renku_request (
409
416
"delete" , f"{ self ._notebooks_url ()} /servers/{ session .id } " , headers = self ._auth_header ()
410
417
)
411
418
)
412
419
self ._wait_for_session_status (session .id , "stopping" )
413
- else :
420
+ elif session_name :
414
421
responses .append (
415
422
self ._send_renku_request (
416
423
"delete" , f"{ self ._notebooks_url ()} /servers/{ session_name } " , headers = self ._auth_header ()
417
424
)
418
425
)
419
426
self ._wait_for_session_status (session_name , "stopping" )
427
+ elif n_sessions == 1 :
428
+ responses .append (
429
+ self ._send_renku_request (
430
+ "delete" , f"{ self ._notebooks_url ()} /servers/{ sessions [0 ].id } " , headers = self ._auth_header ()
431
+ )
432
+ )
433
+ self ._wait_for_session_status (sessions [0 ].id , "stopping" )
434
+ else :
435
+ return SessionStopStatus .NAME_NEEDED
420
436
421
437
self ._cleanup_ssh_connection_configs (project_name )
422
438
423
- return all ([ response . status_code == 204 for response in responses ]) if responses else False
439
+ n_successfully_stopped = len ([ r for r in responses if r . status_code == 204 ])
424
440
425
- def session_open (self , project_name : str , session_name : str , ssh : bool = False , ** kwargs ) -> bool :
441
+ return SessionStopStatus .SUCCESSFUL if n_successfully_stopped == n_sessions else SessionStopStatus .FAILED
442
+
443
+ def session_open (self , project_name : str , session_name : Optional [str ], ssh : bool = False , ** kwargs ) -> bool :
426
444
"""Open a given interactive session.
427
445
428
446
Args:
429
447
project_name(str): Renku project name.
430
- session_name(str): The unique id of the interactive session.
448
+ session_name(Optional[ str] ): The unique id of the interactive session.
431
449
ssh(bool): Whether to open an SSH connection or a normal browser interface.
432
450
"""
433
- sessions = self .session_list ("" , None )
451
+ sessions = self .session_list (project_name = "" )
434
452
system_config = SystemSSHConfig ()
435
453
name = self ._project_name_from_full_project_name (project_name )
436
454
ssh_prefix = f"{ system_config .renku_host } -{ name } -"
437
455
456
+ if not session_name :
457
+ if len (sessions ) == 1 :
458
+ session_name = sessions [0 ].id
459
+ else :
460
+ return False
461
+
438
462
if session_name .startswith (ssh_prefix ):
439
- # NOTE: use passed in ssh connection name instead of session id by accident
463
+ # NOTE: User passed in ssh connection name instead of session id by accident
440
464
session_name = session_name .replace (ssh_prefix , "" , 1 )
441
465
442
466
if not any (s .id == session_name for s in sessions ):
0 commit comments