Skip to content

Commit 9fa63dd

Browse files
feat(core): shell completion for sessions (#3450)
1 parent 41927a1 commit 9fa63dd

File tree

25 files changed

+311
-196
lines changed

25 files changed

+311
-196
lines changed

conftest.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
# -*- coding: utf-8 -*-
2-
#
3-
# Copyright 2017-2023 Swiss Data Science Center (SDSC)
4-
# A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and
1+
# Copyright Swiss Data Science Center (SDSC). A partnership between
2+
# École Polytechnique Fédérale de Lausanne (EPFL) and
53
# Eidgenössische Technische Hochschule Zürich (ETHZ).
64
#
75
# Licensed under the Apache License, Version 2.0 (the "License");

design/003-interactive-session/003-interactive-session.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,9 @@ class ISessionProvider:
160160
"""
161161
pass
162162

163-
def session_list(self, project_name: str, config: Optional[Dict[str, Any]]) -> List[Session]:
163+
def session_list(self, project_name: str) -> List[Session]:
164164
"""Lists all the sessions currently running by the given session provider.
165165
:param project_name: Renku project name.
166-
:param config: Path to the session provider specific configuration YAML.
167166
:returns: a list of sessions.
168167
"""
169168
pass

docs/how-to-guides/shell-integration.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ To activate tab completion for your supported shell run the following command af
3535
3636
$ eval "$(_RENKU_COMPLETE=zsh_source renku)"
3737
38+
You can put the same command in your shell's startup script to enable completion by default.
3839
After this not only sub-commands of ``renku`` will be auto-completed using tab, but for example
3940
in case of ``renku workflow execute`` the available ``Plans`` are going to be listed.
4041

renku/command/session.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
#
2-
# Copyright 2018-2023- Swiss Data Science Center (SDSC)
3-
# A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and
4-
# Eidgenössische Technische Hochschule Zürich (ETHZ).
1+
# Copyright Swiss Data Science Center (SDSC). A partnership between
2+
# École Polytechnique Fédérale de Lausanne (EPFL) and
3+
# Eidgenössische Technische Hochschule Zürich (ETHZ).
54
#
65
# Licensed under the Apache License, Version 2.0 (the "License");
76
# you may not use this file except in compliance with the License.
@@ -16,9 +15,26 @@
1615
# limitations under the License.
1716
"""Renku session commands."""
1817

19-
2018
from renku.command.command_builder.command import Command
21-
from renku.core.session.session import session_list, session_open, session_start, session_stop, ssh_setup
19+
from renku.core.session.session import (
20+
search_session_providers,
21+
search_sessions,
22+
session_list,
23+
session_open,
24+
session_start,
25+
session_stop,
26+
ssh_setup,
27+
)
28+
29+
30+
def search_sessions_command():
31+
"""Get all the session names that match a pattern."""
32+
return Command().command(search_sessions).require_migration().with_database(write=False)
33+
34+
35+
def search_session_providers_command():
36+
"""Get all the session provider names that match a pattern."""
37+
return Command().command(search_session_providers).require_migration().with_database(write=False)
2238

2339

2440
def session_list_command():

renku/core/plugin/session.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,6 @@ def get_supported_session_providers() -> List[ISessionProvider]:
3939
from renku.core.plugin.pluginmanager import get_plugin_manager
4040

4141
pm = get_plugin_manager()
42-
return pm.hook.session_provider()
42+
providers = pm.hook.session_provider()
43+
44+
return sorted(providers, key=lambda p: p.priority)

renku/core/session/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
#
2-
# Copyright 2018-2023 - Swiss Data Science Center (SDSC)
3-
# A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and
4-
# Eidgenössische Technische Hochschule Zürich (ETHZ).
1+
# Copyright Swiss Data Science Center (SDSC). A partnership between
2+
# École Polytechnique Fédérale de Lausanne (EPFL) and
3+
# Eidgenössische Technische Hochschule Zürich (ETHZ).
54
#
65
# Licensed under the Apache License, Version 2.0 (the "License");
76
# you may not use this file except in compliance with the License.

renku/core/session/docker.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
#
2-
# Copyright 2018-2023 - Swiss Data Science Center (SDSC)
3-
# A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and
4-
# Eidgenössische Technische Hochschule Zürich (ETHZ).
1+
# Copyright Swiss Data Science Center (SDSC). A partnership between
2+
# École Polytechnique Fédérale de Lausanne (EPFL) and
3+
# Eidgenössische Technische Hochschule Zürich (ETHZ).
54
#
65
# Licensed under the Apache License, Version 2.0 (the "License");
76
# you may not use this file except in compliance with the License.
@@ -33,7 +32,7 @@
3332
from renku.core.plugin import hookimpl
3433
from renku.core.util import communication
3534
from renku.domain_model.project_context import project_context
36-
from renku.domain_model.session import ISessionProvider, Session
35+
from renku.domain_model.session import ISessionProvider, Session, SessionStopStatus
3736

3837
if TYPE_CHECKING:
3938
from renku.core.dataset.providers.models import ProviderParameter
@@ -43,7 +42,7 @@ class DockerSessionProvider(ISessionProvider):
4342
"""A docker based interactive session provider."""
4443

4544
JUPYTER_PORT = 8888
46-
# NOTE: Give the docker provider a higher priority so that it's checked first
45+
# NOTE: Give the docker provider the highest priority so that it's checked first
4746
priority: ProviderPriority = ProviderPriority.HIGHEST
4847

4948
def __init__(self):
@@ -54,7 +53,7 @@ def docker_client(self) -> docker.client.DockerClient:
5453
5554
Note:
5655
This is not a @property, even though it should be, because ``pluggy``
57-
will call it in that case in unrelated parts of the code that will
56+
will call it in that case in unrelated parts of the code.
5857
Raises:
5958
errors.DockerError: Exception when docker is not available.
6059
Returns:
@@ -133,7 +132,7 @@ def get_open_parameters(self) -> List["ProviderParameter"]:
133132
"""Returns parameters that can be set for session open."""
134133
return []
135134

136-
def session_list(self, project_name: str, config: Optional[Dict[str, Any]]) -> List[Session]:
135+
def session_list(self, project_name: str) -> List[Session]:
137136
"""Lists all the sessions currently running by the given session provider.
138137
139138
Returns:
@@ -297,29 +296,36 @@ def session_start_helper(consider_disk_request: bool):
297296
else:
298297
return result, ""
299298

300-
def session_stop(self, project_name: str, session_name: Optional[str], stop_all: bool) -> bool:
299+
def session_stop(self, project_name: str, session_name: Optional[str], stop_all: bool) -> SessionStopStatus:
301300
"""Stops all or a given interactive session."""
302301
try:
303302
docker_containers = (
304303
self._get_docker_containers(project_name)
305304
if stop_all
306305
else self.docker_client().containers.list(filters={"id": session_name})
306+
if session_name
307+
else self.docker_client().containers.list()
307308
)
308309

309-
if len(docker_containers) == 0:
310-
return False
310+
n_docker_containers = len(docker_containers)
311+
312+
if n_docker_containers == 0:
313+
return SessionStopStatus.FAILED if session_name else SessionStopStatus.NO_ACTIVE_SESSION
314+
elif not session_name and len(docker_containers) > 1:
315+
return SessionStopStatus.NAME_NEEDED
311316

312317
[c.stop() for c in docker_containers]
313-
return True
314318
except docker.errors.APIError as error:
315319
raise errors.DockerError(error.msg)
320+
else:
321+
return SessionStopStatus.SUCCESSFUL
316322

317-
def session_open(self, project_name: str, session_name: str, **kwargs) -> bool:
323+
def session_open(self, project_name: str, session_name: Optional[str], **kwargs) -> bool:
318324
"""Open a given interactive session.
319325
320326
Args:
321327
project_name(str): Renku project name.
322-
session_name(str): The unique id of the interactive session.
328+
session_name(Optional[str]): The unique id of the interactive session.
323329
"""
324330
url = self.session_url(session_name)
325331

@@ -329,10 +335,14 @@ def session_open(self, project_name: str, session_name: str, **kwargs) -> bool:
329335
webbrowser.open(url)
330336
return True
331337

332-
def session_url(self, session_name: str) -> Optional[str]:
338+
def session_url(self, session_name: Optional[str]) -> Optional[str]:
333339
"""Get the URL of the interactive session."""
334-
for c in self.docker_client().containers.list():
335-
if c.short_id == session_name and f"{DockerSessionProvider.JUPYTER_PORT}/tcp" in c.ports:
340+
sessions = self.docker_client().containers.list()
341+
342+
for c in sessions:
343+
if (
344+
c.short_id == session_name or (not session_name and len(sessions) == 1)
345+
) and f"{DockerSessionProvider.JUPYTER_PORT}/tcp" in c.ports:
336346
host = c.ports[f"{DockerSessionProvider.JUPYTER_PORT}/tcp"][0]
337347
return f'http://{host["HostIp"]}:{host["HostPort"]}/?token={c.labels["jupyter_token"]}'
338348
return None

renku/core/session/renkulab.py

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
#
2-
# Copyright 2018-2023 - Swiss Data Science Center (SDSC)
3-
# A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and
4-
# Eidgenössische Technische Hochschule Zürich (ETHZ).
1+
# Copyright Swiss Data Science Center (SDSC). A partnership between
2+
# École Polytechnique Fédérale de Lausanne (EPFL) and
3+
# Eidgenössische Technische Hochschule Zürich (ETHZ).
54
#
65
# Licensed under the Apache License, Version 2.0 (the "License");
76
# you may not use this file except in compliance with the License.
@@ -26,6 +25,7 @@
2625

2726
from renku.core import errors
2827
from renku.core.config import get_value
28+
from renku.core.constant import ProviderPriority
2929
from renku.core.login import read_renku_token
3030
from renku.core.plugin import hookimpl
3131
from renku.core.session.utils import get_renku_project_name, get_renku_url
@@ -34,7 +34,7 @@
3434
from renku.core.util.jwt import is_token_expired
3535
from renku.core.util.ssh import SystemSSHConfig
3636
from renku.domain_model.project_context import project_context
37-
from renku.domain_model.session import ISessionProvider, Session
37+
from renku.domain_model.session import ISessionProvider, Session, SessionStopStatus
3838

3939
if TYPE_CHECKING:
4040
from renku.core.dataset.providers.models import ProviderParameter
@@ -44,6 +44,8 @@ class RenkulabSessionProvider(ISessionProvider):
4444
"""A session provider that uses the notebook service API to launch sessions."""
4545

4646
DEFAULT_TIMEOUT_SECONDS = 300
47+
# NOTE: Give the renkulab provider the lowest priority so that it's checked last
48+
priority: ProviderPriority = ProviderPriority.LOWEST
4749

4850
def __init__(self):
4951
self.__renku_url: Optional[str] = None
@@ -187,7 +189,7 @@ def _cleanup_ssh_connection_configs(
187189
gotten from the server.
188190
"""
189191
if not running_sessions:
190-
running_sessions = self.session_list("", None, ssh_garbage_collection=False)
192+
running_sessions = self.session_list(project_name="", ssh_garbage_collection=False)
191193

192194
system_config = SystemSSHConfig()
193195

@@ -199,7 +201,8 @@ def _cleanup_ssh_connection_configs(
199201
if path not in session_config_paths:
200202
path.unlink()
201203

202-
def _remote_head_hexsha(self):
204+
@staticmethod
205+
def _remote_head_hexsha():
203206
remote = get_remote(repository=project_context.repository)
204207

205208
if remote is None:
@@ -221,7 +224,8 @@ def _send_renku_request(self, req_type: str, *args, **kwargs):
221224
)
222225
return res
223226

224-
def _project_name_from_full_project_name(self, project_name: str) -> str:
227+
@staticmethod
228+
def _project_name_from_full_project_name(project_name: str) -> str:
225229
"""Get just project name of project name if in owner/name form."""
226230
if "/" not in project_name:
227231
return project_name
@@ -282,9 +286,7 @@ def get_open_parameters(self) -> List["ProviderParameter"]:
282286
ProviderParameter("ssh", help="Open a remote terminal through SSH.", is_flag=True),
283287
]
284288

285-
def session_list(
286-
self, project_name: str, config: Optional[Dict[str, Any]], ssh_garbage_collection: bool = True
287-
) -> List[Session]:
289+
def session_list(self, project_name: str, ssh_garbage_collection: bool = True) -> List[Session]:
288290
"""Lists all the sessions currently running by the given session provider.
289291
290292
Returns:
@@ -398,45 +400,67 @@ def session_start(
398400
)
399401
raise errors.RenkulabSessionError("Cannot start session via the notebook service because " + res.text)
400402

401-
def session_stop(self, project_name: str, session_name: Optional[str], stop_all: bool) -> bool:
403+
def session_stop(self, project_name: str, session_name: Optional[str], stop_all: bool) -> SessionStopStatus:
402404
"""Stops all sessions (for the given project) or a specific interactive session."""
403405
responses = []
406+
sessions = self.session_list(project_name=project_name)
407+
n_sessions = len(sessions)
408+
409+
if n_sessions == 0:
410+
return SessionStopStatus.NO_ACTIVE_SESSION
411+
404412
if stop_all:
405-
sessions = self.session_list(project_name=project_name, config=None)
406413
for session in sessions:
407414
responses.append(
408415
self._send_renku_request(
409416
"delete", f"{self._notebooks_url()}/servers/{session.id}", headers=self._auth_header()
410417
)
411418
)
412419
self._wait_for_session_status(session.id, "stopping")
413-
else:
420+
elif session_name:
414421
responses.append(
415422
self._send_renku_request(
416423
"delete", f"{self._notebooks_url()}/servers/{session_name}", headers=self._auth_header()
417424
)
418425
)
419426
self._wait_for_session_status(session_name, "stopping")
427+
elif n_sessions == 1:
428+
responses.append(
429+
self._send_renku_request(
430+
"delete", f"{self._notebooks_url()}/servers/{sessions[0].id}", headers=self._auth_header()
431+
)
432+
)
433+
self._wait_for_session_status(sessions[0].id, "stopping")
434+
else:
435+
return SessionStopStatus.NAME_NEEDED
420436

421437
self._cleanup_ssh_connection_configs(project_name)
422438

423-
return all([response.status_code == 204 for response in responses]) if responses else False
439+
n_successfully_stopped = len([r for r in responses if r.status_code == 204])
424440

425-
def session_open(self, project_name: str, session_name: str, ssh: bool = False, **kwargs) -> bool:
441+
return SessionStopStatus.SUCCESSFUL if n_successfully_stopped == n_sessions else SessionStopStatus.FAILED
442+
443+
def session_open(self, project_name: str, session_name: Optional[str], ssh: bool = False, **kwargs) -> bool:
426444
"""Open a given interactive session.
427445
428446
Args:
429447
project_name(str): Renku project name.
430-
session_name(str): The unique id of the interactive session.
448+
session_name(Optional[str]): The unique id of the interactive session.
431449
ssh(bool): Whether to open an SSH connection or a normal browser interface.
432450
"""
433-
sessions = self.session_list("", None)
451+
sessions = self.session_list(project_name="")
434452
system_config = SystemSSHConfig()
435453
name = self._project_name_from_full_project_name(project_name)
436454
ssh_prefix = f"{system_config.renku_host}-{name}-"
437455

456+
if not session_name:
457+
if len(sessions) == 1:
458+
session_name = sessions[0].id
459+
else:
460+
return False
461+
438462
if session_name.startswith(ssh_prefix):
439-
# NOTE: use passed in ssh connection name instead of session id by accident
463+
# NOTE: User passed in ssh connection name instead of session id by accident
440464
session_name = session_name.replace(ssh_prefix, "", 1)
441465

442466
if not any(s.id == session_name for s in sessions):

0 commit comments

Comments
 (0)