diff --git a/bases/renku_data_services/data_api/app.py b/bases/renku_data_services/data_api/app.py index e1c7d8082..451053a25 100644 --- a/bases/renku_data_services/data_api/app.py +++ b/bases/renku_data_services/data_api/app.py @@ -115,7 +115,12 @@ def register_all_handlers(app: Sanic, dm: DependencyManager) -> Sanic: storage_repo=dm.storage_repo, authenticator=dm.gitlab_authenticator, ) - storage_schema = StorageSchemaBP(name="storage_schema", url_prefix=url_prefix) + storage_schema = StorageSchemaBP( + name="storage_schema", + url_prefix=url_prefix, + data_source_repo=dm.data_source_repo, + authenticator=dm.authenticator, + ) user_preferences = UserPreferencesBP( name="user_preferences", url_prefix=url_prefix, @@ -185,6 +190,7 @@ def register_all_handlers(app: Sanic, dm: DependencyManager) -> Sanic: connected_services_repo=dm.connected_services_repo, oauth_client_factory=dm.oauth_http_client_factory, authenticator=dm.authenticator, + nb_config=dm.config.nb_config, ) repositories = RepositoriesBP( name="repositories", @@ -202,6 +208,7 @@ def register_all_handlers(app: Sanic, dm: DependencyManager) -> Sanic: data_connector_repo=dm.data_connector_repo, data_connector_secret_repo=dm.data_connector_secret_repo, git_provider_helper=dm.git_provider_helper, + data_source_repo=dm.data_source_repo, image_check_repo=dm.image_check_repo, internal_gitlab_authenticator=dm.gitlab_authenticator, metrics=dm.metrics, diff --git a/bases/renku_data_services/data_api/dependencies.py b/bases/renku_data_services/data_api/dependencies.py index 20e856abb..62145bcb6 100644 --- a/bases/renku_data_services/data_api/dependencies.py +++ b/bases/renku_data_services/data_api/dependencies.py @@ -49,6 +49,7 @@ from renku_data_services.notebooks.api.classes.data_service import DummyGitProviderHelper, GitProviderHelper from renku_data_services.notebooks.config import GitProviderHelperProto, get_clusters from renku_data_services.notebooks.constants import AMALTHEA_SESSION_GVK, JUPYTER_SESSION_GVK 
+from renku_data_services.notebooks.data_sources import DataSourceRepository from renku_data_services.notebooks.image_check import ImageCheckRepository from renku_data_services.notifications.db import NotificationsRepository from renku_data_services.platform.db import PlatformRepository, UrlRedirectRepository @@ -144,6 +145,7 @@ class DependencyManager: data_connector_repo: DataConnectorRepository data_connector_secret_repo: DataConnectorSecretRepository cluster_repo: ClusterRepository + data_source_repo: DataSourceRepository image_check_repo: ImageCheckRepository metrics_repo: MetricsRepository metrics: StagingMetricsService @@ -382,6 +384,11 @@ def from_env(cls) -> DependencyManager: secret_service_public_key=config.secrets.public_key, authz=authz, ) + data_source_repo = DataSourceRepository( + nb_config=config.nb_config, + connected_services_repo=connected_services_repo, + oauth_client_factory=oauth_http_client_factory, + ) image_check_repo = ImageCheckRepository( nb_config=config.nb_config, connected_services_repo=connected_services_repo, @@ -429,6 +436,7 @@ def from_env(cls) -> DependencyManager: data_connector_repo=data_connector_repo, data_connector_secret_repo=data_connector_secret_repo, cluster_repo=cluster_repo, + data_source_repo=data_source_repo, image_check_repo=image_check_repo, metrics_repo=metrics_repo, metrics=metrics, diff --git a/components/renku_data_services/connected_services/api.spec.yaml b/components/renku_data_services/connected_services/api.spec.yaml index d783c68b3..ed4680449 100644 --- a/components/renku_data_services/connected_services/api.spec.yaml +++ b/components/renku_data_services/connected_services/api.spec.yaml @@ -426,12 +426,11 @@ components: ProviderKind: type: string enum: - - "gitlab" - - "github" - - "drive" - - "onedrive" - "dropbox" - "generic_oidc" + - "github" + - "gitlab" + - "google" example: "gitlab" ApplicationSlug: description: | diff --git a/components/renku_data_services/connected_services/apispec.py 
b/components/renku_data_services/connected_services/apispec.py index f6f9c5417..7e187633a 100644 --- a/components/renku_data_services/connected_services/apispec.py +++ b/components/renku_data_services/connected_services/apispec.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: api.spec.yaml -# timestamp: 2025-09-05T11:16:18+00:00 +# timestamp: 2026-01-21T11:52:43+00:00 from __future__ import annotations @@ -29,12 +29,11 @@ class AppInstallation(BaseAPISpec): class ProviderKind(Enum): - gitlab = "gitlab" - github = "github" - drive = "drive" - onedrive = "onedrive" dropbox = "dropbox" generic_oidc = "generic_oidc" + github = "github" + gitlab = "gitlab" + google = "google" class ConnectionStatus(Enum): diff --git a/components/renku_data_services/connected_services/apispec_extras.py b/components/renku_data_services/connected_services/apispec_extras.py new file mode 100644 index 000000000..252e650c8 --- /dev/null +++ b/components/renku_data_services/connected_services/apispec_extras.py @@ -0,0 +1,64 @@ +"""Extra definitions for the API spec.""" + +from __future__ import annotations + +import base64 +from enum import StrEnum +from typing import Self + +from pydantic import ConfigDict, Field + +from renku_data_services.connected_services.apispec_base import BaseAPISpec + + +class PostTokenGrantType(StrEnum): +    """Grant type for token refresh.""" + +    refresh_token = "refresh_token"  # nosec B105 + + +class PostTokenRequest(BaseAPISpec): +    """Body for a refresh token request.""" + +    model_config = ConfigDict( +        extra="allow", +    ) +    grant_type: PostTokenGrantType +    refresh_token: str +    client_id: str | None = Field(None) +    client_secret: str | None = Field(None) + + +class PostTokenResponse(BaseAPISpec): +    """Response for a refresh token request.""" + +    model_config = ConfigDict( +        extra="allow", +    ) +    access_token: str +    token_type: str +    expires_in: int +    refresh_token: str +    refresh_expires_in: int | None = Field(None) +    scope: str | None = Field(None) + + +class 
RenkuTokens(BaseAPISpec): + """Represents a set of authentication tokens used in Renku.""" + + model_config = ConfigDict( + extra="forbid", + ) + access_token: str + refresh_token: str + + def encode(self) -> str: + """Encode the Renku tokens as a single URL-safe string.""" + as_json = self.model_dump_json() + return base64.urlsafe_b64encode(as_json.encode("utf-8")).decode("utf-8") + + @classmethod + def decode(cls, encoded: str) -> Self: + """Decode a single string into a set of Renku tokens.""" + json_raw = base64.urlsafe_b64decode(encoded.encode("utf-8")) + return cls.model_validate_json(json_raw) diff --git a/components/renku_data_services/connected_services/blueprints.py b/components/renku_data_services/connected_services/blueprints.py index 73d871a5d..74e244306 100644 --- a/components/renku_data_services/connected_services/blueprints.py +++ b/components/renku_data_services/connected_services/blueprints.py @@ -1,9 +1,13 @@ """Connected services blueprint.""" +import math from dataclasses import dataclass -from typing import Any +from datetime import UTC, datetime +from typing import Any, cast from urllib.parse import unquote, urlparse, urlunparse +import httpx +import jwt from sanic import HTTPResponse, Request, empty, json, redirect from sanic.response import JSONResponse from sanic_ext import validate @@ -17,7 +21,7 @@ from renku_data_services.base_api.misc import validate_query from renku_data_services.base_api.pagination import PaginationRequest, paginate from renku_data_services.base_models.validation import validate_and_dump, validated_json -from renku_data_services.connected_services import apispec +from renku_data_services.connected_services import apispec, apispec_extras from renku_data_services.connected_services.apispec_base import AuthorizeParams, CallbackParams from renku_data_services.connected_services.core import validate_oauth2_client_patch, validate_unsaved_oauth2_client from renku_data_services.connected_services.db import 
ConnectedServicesRepository @@ -26,6 +30,7 @@ OAuthHttpError, OAuthHttpFactoryError, ) +from renku_data_services.notebooks.config import NotebooksConfig logger = logging.getLogger(__name__) @@ -159,6 +164,7 @@ class OAuth2ConnectionsBP(CustomBlueprint): connected_services_repo: ConnectedServicesRepository oauth_client_factory: OAuthHttpClientFactory authenticator: base_models.Authenticator + nb_config: NotebooksConfig def get_all(self) -> BlueprintFactoryResponse: """List all OAuth2 connections.""" @@ -202,7 +208,7 @@ async def _get_account(_: Request, user: base_models.APIUser, connection_id: ULI account = await client.get_connected_account() match account: case OAuthHttpError() as err: - raise errors.InvalidTokenError(message=f"OAuth error getting the connected accoun: {err}") + raise errors.InvalidTokenError(message=f"OAuth error getting the connected account: {err}") case account: return validated_json(apispec.ConnectedAccount, account) @@ -245,3 +251,122 @@ async def _get_installations( return body, installations_list.total_count return "/oauth2/connections//installations", ["GET"], _get_installations + + def post_token_endpoint(self) -> BlueprintFactoryResponse: + """OAuth 2.0 token endpoint to support applications running in sessions. + + Details: + 1. Decode the refresh_token value into an instance of RenkuTokens + 2. Validate the access_token + -> if the access_token is invalid (expired), use the renku refresh_token + to get a fresh set of tokens + 3. 
Send back the refreshed OAuth 2.0 access token and the encoded value + of the current RenkuTokens + """ + + @validate(form=apispec_extras.PostTokenRequest) + async def _post_token_endpoint( + request: Request, body: apispec_extras.PostTokenRequest, connection_id: ULID + ) -> JSONResponse: + renku_tokens = apispec_extras.RenkuTokens.decode(body.refresh_token) + # NOTE: inject the access token in the headers so that we can use `self.authenticator` + request.headers[self.authenticator.token_field] = renku_tokens.access_token + + user: base_models.APIUser | None = None + try: + _user = cast( + base_models.APIUser, + await self.authenticator.authenticate( + access_token=renku_tokens.access_token or "", request=request + ), + ) + if _user.is_authenticated and _user.access_token: + user = _user + except Exception as err: + logger.error(f"Got authenticate error: {err.__class__}.") + raise + + # Try to refresh the Renku access token + if user is None and renku_tokens.refresh_token: + renku_base_url = "https://" + self.nb_config.sessions.ingress.host + renku_base_url = renku_base_url.rstrip("/") + renku_realm = self.nb_config.keycloak_realm + renku_auth_token_uri = f"{renku_base_url}/auth/realms/{renku_realm}/protocol/openid-connect/token" + + async with httpx.AsyncClient(timeout=10) as http: + auth = ( + self.nb_config.sessions.git_proxy.renku_client_id, + self.nb_config.sessions.git_proxy.renku_client_secret, + ) + payload = { + "grant_type": "refresh_token", + "refresh_token": renku_tokens.refresh_token, + } + response = await http.post(renku_auth_token_uri, auth=auth, data=payload, follow_redirects=True) + if 200 <= response.status_code < 300: + try: + parsed_response = apispec_extras.PostTokenResponse.model_validate_json(response.content) + except Exception as err: + logger.error(f"Failed to parse refreshed Renku tokens: {err.__class__}.") + raise + try: + renku_tokens.access_token = parsed_response.access_token + renku_tokens.refresh_token = 
parsed_response.refresh_token + request.headers[self.authenticator.token_field] = renku_tokens.access_token + _user = cast( + base_models.APIUser, + await self.authenticator.authenticate( + access_token=renku_tokens.access_token or "", request=request + ), + ) + if _user.is_authenticated and _user.access_token: + user = _user + except Exception as err: + logger.error(f"Got authenticate error: {err.__class__}.") + raise + else: + logger.error( + f"Got error from refreshing Renku tokens: HTTP {response.status_code}; {response.json()}." + ) + raise errors.UnauthorizedError() + + if user is None or not user.is_authenticated: + raise errors.UnauthorizedError() + + client = await self.oauth_client_factory.for_user_connection_raise(user, connection_id) + oauth_token = await client.get_token() + access_token = oauth_token.access_token + if access_token is None: + raise errors.ProgrammingError(message="Unexpected error: access token not present.") + result: dict[str, str | int] = { + "access_token": access_token, + "token_type": str(oauth_token.get("token_type") or "Bearer"), + "refresh_token": renku_tokens.encode(), + } + if oauth_token.get("scope"): + result["scope"] = oauth_token["scope"] + # NOTE: Set "expires_in" according to whichever of the OAuth 2.0 access token or the Renku refresh + # token expires first. 
+ try: + refresh_decoded: dict[str, Any] = jwt.decode( + renku_tokens.refresh_token, options={"verify_signature": False} + ) + refresh_exp: int | None = refresh_decoded.get("exp") + if refresh_exp is not None and refresh_exp > 0: + exp = datetime.fromtimestamp(refresh_exp, UTC) + expires_in = exp - datetime.now(UTC) + result["expires_in"] = math.ceil(expires_in.total_seconds()) + except Exception as err: + logger.error(f"Could not parse Renku refresh token; cannot determine its expiration: {err.__class__}.") + if oauth_token.expires_at: + exp = datetime.fromtimestamp(oauth_token.expires_at, UTC) + expires_in = exp - datetime.now(UTC) + result_expires_in = result.get("expires_in") + if isinstance(result_expires_in, int) and result_expires_in > 0: + result["expires_in"] = min(result_expires_in, math.ceil(expires_in.total_seconds())) + else: + result["expires_in"] = math.ceil(expires_in.total_seconds()) + + return validated_json(apispec_extras.PostTokenResponse, result) + + return "/oauth2/connections//token_endpoint", ["POST"], _post_token_endpoint diff --git a/components/renku_data_services/connected_services/db.py b/components/renku_data_services/connected_services/db.py index bfb773f0d..cba3afd94 100644 --- a/components/renku_data_services/connected_services/db.py +++ b/components/renku_data_services/connected_services/db.py @@ -5,6 +5,7 @@ from sqlalchemy import and_, select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from ulid import ULID import renku_data_services.base_models as base_models @@ -241,7 +242,7 @@ async def get_oauth2_connection(self, connection_id: ULID, user: base_models.API return connection async def get_provider_for_image(self, user: APIUser, image: Image) -> models.ImageProvider | None: - """Find a provider supporting the given an image.""" + """Find a provider supporting the given image.""" registry_urls = [f"http://{image.hostname}", f"https://{image.hostname}"] async with self.session_maker() as 
session: stmt = ( @@ -273,6 +274,51 @@ async def get_provider_for_image(self, user: APIUser, image: Image) -> models.Im str(row.OAuth2ClientORM.image_registry_url), # above query makes it non-nil ) + async def get_provider_for_kind( + self, user: APIUser, provider_kind: models.ProviderKind + ) -> models.ServiceProvider | None: + """Find a service provider of a given kind.""" + async with self.session_maker() as session: + # First, match an established connection if it exists + stmt = ( + select(schemas.OAuth2ConnectionORM) + .join(schemas.OAuth2ClientORM) + .where(schemas.OAuth2ConnectionORM.user_id == user.id) + .where(schemas.OAuth2ConnectionORM.status == models.ConnectionStatus.connected.value) + .where(schemas.OAuth2ClientORM.kind == provider_kind.value) + .options(selectinload(schemas.OAuth2ConnectionORM.client)) + .limit(1) + ) + res = await session.scalars(stmt) + connection = res.one_or_none() + if connection is not None: + return models.ServiceProvider( + provider=connection.client.dump(), + connected_user=models.ConnectedUser(connection=connection.dump(), user=user), + ) + # Otherwise, match the first suitable provider + provider_stmt = ( + select(schemas.OAuth2ClientORM).where(schemas.OAuth2ClientORM.kind == provider_kind.value).limit(1) + ) + provider_res = await session.scalars(provider_stmt) + provider = provider_res.one_or_none() + if provider is not None: + return models.ServiceProvider( + provider=provider.dump(), + connected_user=None, + ) + return None + + async def get_token_set(self, user: APIUser, connection_id: ULID) -> models.OAuth2TokenSet | None: + """Returns the token set from a given OAuth2 connection.""" + client_or_error = await self.oauth_client_factory.for_user_connection(user=user, connection_id=connection_id) + match client_or_error: + case OAuthHttpFactoryError() as err: + logger.info(f"Error getting oauth client for user={user} connection={connection_id}: {err}") + return None + case client: + return await client.get_token() + 
async def get_image_repo_client(self, image_provider: models.ImageProvider) -> ImageRepoDockerAPI: """Create a image repository client for the given user and image provider.""" url = urlparse(image_provider.registry_url) @@ -282,15 +328,9 @@ async def get_image_repo_client(self, image_provider: models.ImageProvider) -> I user = image_provider.connected_user.user conn = image_provider.connected_user.connection access_token: str | None = None - client_or_error = await self.oauth_client_factory.for_user_connection(user, conn.id) - match client_or_error: - case OAuthHttpFactoryError() as err: - logger.info(f"Error getting oauth client for user={user} connection={conn.id}: {err}") - - case client: - token_set = await client.get_token() - access_token = token_set.access_token - + token_set = await self.get_token_set(user=user, connection_id=conn.id) + if token_set is not None: + access_token = token_set.access_token if access_token: logger.debug(f"Use connection {conn.id} to {image_provider.provider.id} for user {user.id}") repo_api = repo_api.with_oauth2_token(access_token) diff --git a/components/renku_data_services/connected_services/external_models.py b/components/renku_data_services/connected_services/external_models.py index 19cb51b5a..3cf81b74e 100644 --- a/components/renku_data_services/connected_services/external_models.py +++ b/components/renku_data_services/connected_services/external_models.py @@ -63,43 +63,38 @@ def to_app_installation_list(self) -> models.AppInstallationList: ) -class GoogleDriveConnectedAccount(BaseModel): - """OAuth2 connected account model for google drive.""" +class GoogleConnectedAccount(BaseModel): + """OAuth2 connected account model for Google.""" - name: str email: str def to_connected_account(self) -> models.ConnectedAccount: """Returns the corresponding ConnectedAccount object.""" - return models.ConnectedAccount(username=self.name, web_url=f"mailto:{self.email}") + return models.ConnectedAccount(username=self.email, web_url="") 
-class OneDriveConnectedAccount(BaseModel): - """OAuth2 connected account model for onedrive.""" +# class OneDriveConnectedAccount(BaseModel): +# """OAuth2 connected account model for onedrive.""" - givenname: str - familyname: str - email: str +# givenname: str +# familyname: str +# email: str - def to_connected_account(self) -> models.ConnectedAccount: - """Returns the corresponding ConnectedAccount object.""" - return models.ConnectedAccount( - username=" ".join(filter(None, [self.givenname, self.familyname])), web_url=f"mailto:{self.email}" - ) +# def to_connected_account(self) -> models.ConnectedAccount: +# """Returns the corresponding ConnectedAccount object.""" +# return models.ConnectedAccount( +# username=" ".join(filter(None, [self.givenname, self.familyname])), web_url=f"mailto:{self.email}" +# ) class DropboxConnectedAccount(BaseModel): """OAuth2 connected account model for dropbox.""" - family_name: str | None - given_name: str | None email: str def to_connected_account(self) -> models.ConnectedAccount: """Returns the corresponding ConnectedAccount object.""" - return models.ConnectedAccount( - username=" ".join(filter(None, [self.given_name, self.family_name])), web_url=f"mailto:{self.email}" - ) + return models.ConnectedAccount(username=self.email, web_url="") class GenericOIDCConnectedAccount(BaseModel): diff --git a/components/renku_data_services/connected_services/models.py b/components/renku_data_services/connected_services/models.py index c7f3768a0..72f0f6247 100644 --- a/components/renku_data_services/connected_services/models.py +++ b/components/renku_data_services/connected_services/models.py @@ -13,12 +13,11 @@ class ProviderKind(StrEnum): """The kind of platform we connnect to.""" - gitlab = "gitlab" - github = "github" - drive = "drive" - onedrive = "onedrive" dropbox = "dropbox" generic_oidc = "generic_oidc" + github = "github" + gitlab = "gitlab" + google = "google" class ConnectionStatus(StrEnum): @@ -99,7 +98,7 @@ class 
ConnectedAccount: web_url: str -class OAuth2TokenSet(dict): +class OAuth2TokenSet(dict[str, Any]): """OAuth2 token set model.""" @classmethod @@ -171,12 +170,11 @@ def is_connected(self) -> bool: @dataclass(frozen=True, eq=True) -class ImageProvider: - """Result when retrieving provider information for an image.""" +class ServiceProvider: + """Result when retrieving provider information for a connected service.""" provider: OAuth2Client connected_user: ConnectedUser | None - registry_url: str def is_connected(self) -> bool: """Returns whether the connection exists and is in status 'connected'.""" @@ -185,10 +183,18 @@ def is_connected(self) -> bool: @property def connection(self) -> OAuth2Connection | None: """Return the connection if present.""" - if self.connected_user: - return self.connected_user.connection - else: - return None + return self.connected_user.connection if self.connected_user else None + + def __str__(self) -> str: + conn = f"connection={self.connection.id}" if self.connection else "connection=None" + return f"ServiceProvider(provider={self.provider.id}/{self.provider.kind}, {conn})" + + +@dataclass(frozen=True, eq=True) +class ImageProvider(ServiceProvider): + """Result when retrieving provider information for an image.""" + + registry_url: str def __str__(self) -> str: conn = f"connection={self.connection.id}" if self.connection else "connection=None" diff --git a/components/renku_data_services/connected_services/oauth_http.py b/components/renku_data_services/connected_services/oauth_http.py index 436306868..60f1226ee 100644 --- a/components/renku_data_services/connected_services/oauth_http.py +++ b/components/renku_data_services/connected_services/oauth_http.py @@ -41,7 +41,7 @@ from renku_data_services.users.db import APIUser from renku_data_services.utils import cryptography as crypt -logger = logging.getLogger(__file__) +logger = logging.getLogger(__name__) class OAuthHttpFactoryError(StrEnum): @@ -274,7 +274,14 @@ async def 
get_connected_account(self) -> OAuthHttpError | models.ConnectedAccoun request_url = urljoin(self.adapter.api_url, self.adapter.user_info_endpoint) try: if self.adapter.user_info_method == "POST": - response = await self._delegate.post(request_url, headers=self.adapter.api_common_headers) + # NOTE: we need to remove "Content-Type" from the headers (empty post) + headers: dict[str, str] | None = None + if self.adapter.api_common_headers: + headers = dict() + for key, value in self.adapter.api_common_headers.items(): + if key.lower() != "content-type": + headers[key] = value + response = await self._delegate.post(request_url, headers=headers) else: response = await self.get(request_url, headers=self.adapter.api_common_headers) except OAuthError as e: diff --git a/components/renku_data_services/connected_services/provider_adapters.py b/components/renku_data_services/connected_services/provider_adapters.py index b49fea2a2..6de3db9c4 100644 --- a/components/renku_data_services/connected_services/provider_adapters.py +++ b/components/renku_data_services/connected_services/provider_adapters.py @@ -1,6 +1,5 @@ """Adapters for each kind of OAuth2 client.""" -import logging from abc import ABC, abstractmethod from typing import Any from urllib.parse import urljoin, urlparse, urlunparse @@ -11,8 +10,6 @@ from renku_data_services.connected_services import external_models, models from renku_data_services.connected_services import orm as schemas -logger = logging.getLogger(__name__) - class ProviderAdapter(ABC): """Defines the functionality of OAuth2 client adapters.""" @@ -120,8 +117,8 @@ def api_validate_app_installations_response(self, response: Response) -> models. 
return external_models.GitHubAppInstallationList.model_validate(response.json()).to_app_installation_list() -class GoogleDriveAdapter(ProviderAdapter): - """Adapter for Google Drive OAuth2 clients.""" +class GoogleAdapter(ProviderAdapter): + """Adapter for Google OAuth2 clients.""" user_info_endpoint = "userinfo" @@ -133,7 +130,7 @@ def authorization_url(self) -> str: @property def authorization_url_extra_params(self) -> dict[str, str]: """Extra parameters to add to the auth url.""" - return {"access_type": "offline"} + return {"access_type": "offline", "prompt": "consent"} @property def token_endpoint_url(self) -> str: @@ -155,45 +152,45 @@ def api_common_headers(self) -> dict[str, str] | None: def api_validate_account_response(self, response: Response) -> models.ConnectedAccount: """Validates and returns the connected account response from the Resource Server.""" - return external_models.GoogleDriveConnectedAccount.model_validate(response.json()).to_connected_account() + return external_models.GoogleConnectedAccount.model_validate(response.json()).to_connected_account() -class OneDriveAdapter(ProviderAdapter): - """Adapter for One Drive OAuth2 clients.""" +# class OneDriveAdapter(ProviderAdapter): +# """Adapter for One Drive OAuth2 clients.""" - user_info_endpoint = "userinfo" +# user_info_endpoint = "userinfo" - @property - def authorization_url(self) -> str: - """The authorization URL for the OAuth2 protocol.""" - return "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" +# @property +# def authorization_url(self) -> str: +# """The authorization URL for the OAuth2 protocol.""" +# return "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" - @property - def authorization_url_extra_params(self) -> dict[str, str]: - """Extra parameters to add to the auth url.""" - return {"access_type": "offline"} +# @property +# def authorization_url_extra_params(self) -> dict[str, str]: +# """Extra parameters to add to the auth url.""" +# return 
{"access_type": "offline"} - @property - def token_endpoint_url(self) -> str: - """The token endpoint URL for the OAuth2 protocol.""" - return "https://login.microsoftonline.com/common/oauth2/v2.0/token" +# @property +# def token_endpoint_url(self) -> str: +# """The token endpoint URL for the OAuth2 protocol.""" +# return "https://login.microsoftonline.com/common/oauth2/v2.0/token" - @property - def api_url(self) -> str: - """The URL used for API calls on the Resource Server.""" - return "https://graph.microsoft.com/oidc/" +# @property +# def api_url(self) -> str: +# """The URL used for API calls on the Resource Server.""" +# return "https://graph.microsoft.com/oidc/" - @property - def api_common_headers(self) -> dict[str, str] | None: - """The HTTP headers used for API calls on the Resource Server.""" - return { - "Accept": "application/json", - "Content-Type": "application/json", - } +# @property +# def api_common_headers(self) -> dict[str, str] | None: +# """The HTTP headers used for API calls on the Resource Server.""" +# return { +# "Accept": "application/json", +# "Content-Type": "application/json", +# } - def api_validate_account_response(self, response: Response) -> models.ConnectedAccount: - """Validates and returns the connected account response from the Resource Server.""" - return external_models.OneDriveConnectedAccount.model_validate(response.json()).to_connected_account() +# def api_validate_account_response(self, response: Response) -> models.ConnectedAccount: +# """Validates and returns the connected account response from the Resource Server.""" +# return external_models.OneDriveConnectedAccount.model_validate(response.json()).to_connected_account() class DropboxAdapter(ProviderAdapter): @@ -210,7 +207,7 @@ def authorization_url(self) -> str: @property def authorization_url_extra_params(self) -> dict[str, str]: """Extra parameters to add to the auth url.""" - return {"access_type": "offline"} + return {"token_access_type": "offline"} @property def 
token_endpoint_url(self) -> str: @@ -314,12 +311,11 @@ def __get_httpx_client(cls) -> Client: _adapter_map: dict[models.ProviderKind, type[ProviderAdapter]] = { - models.ProviderKind.gitlab: GitLabAdapter, - models.ProviderKind.github: GitHubAdapter, - models.ProviderKind.drive: GoogleDriveAdapter, - models.ProviderKind.onedrive: OneDriveAdapter, models.ProviderKind.dropbox: DropboxAdapter, models.ProviderKind.generic_oidc: GenericOidcAdapter, + models.ProviderKind.github: GitHubAdapter, + models.ProviderKind.gitlab: GitLabAdapter, + models.ProviderKind.google: GoogleAdapter, } @@ -330,5 +326,7 @@ def get_provider_adapter(client: schemas.OAuth2ClientORM) -> ProviderAdapter: if not client.url: raise errors.ValidationError(message=f"URL not defined for provider {client.id}.") - adapter_class = _adapter_map[client.kind] + adapter_class = _adapter_map.get(client.kind) + if adapter_class is None: + raise errors.ProgrammingError(message=f"Provider adapter not implemented for kind {client.kind}.") return adapter_class(client_url=client.url, oidc_issuer_url=client.oidc_issuer_url) diff --git a/components/renku_data_services/k8s/client_interfaces.py b/components/renku_data_services/k8s/client_interfaces.py index 9bdaf05b7..c038e85ba 100644 --- a/components/renku_data_services/k8s/client_interfaces.py +++ b/components/renku_data_services/k8s/client_interfaces.py @@ -51,6 +51,10 @@ async def patch_resource_quota( class SecretClient(Protocol): """Methods to manipulate Secret kubernetes resources.""" + async def get_secret(self, secret: K8sObjectMeta) -> K8sSecret | None: + """Get a secret.""" + ... + async def create_secret(self, secret: K8sSecret) -> K8sSecret: """Create a secret.""" ... 
diff --git a/components/renku_data_services/k8s/clients.py b/components/renku_data_services/k8s/clients.py index 7285d1304..73cf7637a 100644 --- a/components/renku_data_services/k8s/clients.py +++ b/components/renku_data_services/k8s/clients.py @@ -107,6 +107,11 @@ class K8sSecretClient(SecretClient): def __init__(self, client: K8sClient) -> None: self.__client = client + async def get_secret(self, secret: K8sObjectMeta) -> K8sSecret | None: + """Get a secret.""" + res = await self.__client.get(secret) + return K8sSecret.from_k8s_object(res) if res is not None else None + async def create_secret(self, secret: K8sSecret) -> K8sSecret: """Create a secret.""" @@ -211,6 +216,10 @@ async def patch_resource_quota( """Update a resource quota.""" raise NotImplementedError() + async def get_secret(self, secret: K8sObjectMeta) -> K8sSecret | None: + """Get a secret.""" + raise NotImplementedError() + async def create_secret(self, secret: K8sSecret) -> K8sSecret: """Create a secret.""" raise NotImplementedError() diff --git a/components/renku_data_services/migrations/versions/0bfc18c91b05_removeme.py b/components/renku_data_services/migrations/versions/0bfc18c91b05_removeme.py new file mode 100644 index 000000000..24bf432dd --- /dev/null +++ b/components/renku_data_services/migrations/versions/0bfc18c91b05_removeme.py @@ -0,0 +1,21 @@ +"""removeme: merge migrations heads + +Revision ID: 0bfc18c91b05 +Revises: 287879848fb3, fddfe7960a8b +Create Date: 2026-02-16 12:37:55.425957 + +""" + +# revision identifiers, used by Alembic. 
+revision = "0bfc18c91b05" +down_revision = ("287879848fb3", "fddfe7960a8b") +branch_labels = None +depends_on = None + + +def upgrade() -> None: + pass + + +def downgrade() -> None: + pass diff --git a/components/renku_data_services/migrations/versions/58ad5426c2f3_upgrade_oauth_provider_kind_enum.py b/components/renku_data_services/migrations/versions/58ad5426c2f3_upgrade_oauth_provider_kind_enum.py new file mode 100644 index 000000000..8c38fd8c9 --- /dev/null +++ b/components/renku_data_services/migrations/versions/58ad5426c2f3_upgrade_oauth_provider_kind_enum.py @@ -0,0 +1,51 @@ +"""upgrade oauth provider kind enum + +Revision ID: 58ad5426c2f3 +Revises: 9b18adb58e63 +Create Date: 2026-01-14 14:35:29.539830 + +""" + +from alembic import op +from sqlalchemy.exc import OperationalError + +from renku_data_services.app_config import logging + +# revision identifiers, used by Alembic. +revision = "58ad5426c2f3" +down_revision = "9b18adb58e63" +branch_labels = None +depends_on = None + +logger = logging.getLogger(__name__) + +# NOTE: Postgres does not allow removing values from an enum + + +def upgrade() -> None: + connection = op.get_bind() + with connection.begin_nested() as tx: + try: + op.execute("DELETE FROM connected_services.oauth2_clients WHERE kind = 'drive'") + op.execute("DELETE FROM connected_services.oauth2_clients WHERE kind = 'onedrive'") + op.execute("DELETE FROM connected_services.oauth2_clients WHERE kind = 'dropbox'") + tx.commit() + except OperationalError as err: + logger.debug(f"Skipped DELETE section from migration of the connected_services.oauth2_clients table: {err}") + tx.rollback() + op.execute("ALTER TYPE providerkind RENAME TO providerkind_old") + op.execute("CREATE TYPE providerkind AS ENUM ('gitlab', 'github', 'google', 'generic_oidc')") + op.execute( + "ALTER TABLE connected_services.oauth2_clients ALTER COLUMN kind SET DATA TYPE providerkind USING kind::text::providerkind" + ) + op.execute("DROP TYPE providerkind_old CASCADE") + + +def 
downgrade() -> None: + op.execute("DELETE FROM connected_services.oauth2_clients WHERE kind = 'google'") + op.execute("ALTER TYPE providerkind RENAME TO providerkind_old") + op.execute("CREATE TYPE providerkind AS ENUM ('gitlab', 'github', 'drive', 'onedrive', 'dropbox', 'generic_oidc')") + op.execute( + "ALTER TABLE connected_services.oauth2_clients ALTER COLUMN kind SET DATA TYPE providerkind USING kind::text::providerkind" + ) + op.execute("DROP TYPE providerkind_old CASCADE") diff --git a/components/renku_data_services/migrations/versions/fddfe7960a8b_squash_me.py b/components/renku_data_services/migrations/versions/fddfe7960a8b_squash_me.py new file mode 100644 index 000000000..147ac40fc --- /dev/null +++ b/components/renku_data_services/migrations/versions/fddfe7960a8b_squash_me.py @@ -0,0 +1,23 @@ +"""squash me + +Revision ID: fddfe7960a8b +Revises: 58ad5426c2f3 +Create Date: 2026-01-21 11:52:05.169734 + +""" + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "fddfe7960a8b" +down_revision = "58ad5426c2f3" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.execute("ALTER TYPE providerkind ADD VALUE 'dropbox'") + + +def downgrade() -> None: + pass diff --git a/components/renku_data_services/notebooks/api/classes/k8s_client.py b/components/renku_data_services/notebooks/api/classes/k8s_client.py index f48aee437..955f1c490 100644 --- a/components/renku_data_services/notebooks/api/classes/k8s_client.py +++ b/components/renku_data_services/notebooks/api/classes/k8s_client.py @@ -386,6 +386,10 @@ async def patch_image_pull_secret(self, session_name: str, gitlab_token: GitlabT await secret.patch(patch, type="json") + async def get_secret(self, secret: K8sObjectMeta) -> K8sSecret | None: + """Get a secret.""" + return await self.__secrets_client.get_secret(secret) + async def create_secret(self, secret: K8sSecret) -> K8sSecret: """Create a secret.""" diff --git 
a/components/renku_data_services/notebooks/blueprints.py b/components/renku_data_services/notebooks/blueprints.py index d5f84439a..a5ed094bb 100644 --- a/components/renku_data_services/notebooks/blueprints.py +++ b/components/renku_data_services/notebooks/blueprints.py @@ -28,6 +28,7 @@ start_session, validate_session_post_request, ) +from renku_data_services.notebooks.data_sources import DataSourceRepository from renku_data_services.notebooks.image_check import ImageCheckRepository from renku_data_services.project.db import ProjectRepository, ProjectSessionSecretRepository from renku_data_services.session.db import SessionRepository @@ -49,6 +50,7 @@ class NotebooksNewBP(CustomBlueprint): data_connector_secret_repo: DataConnectorSecretRepository git_provider_helper: GitProviderHelperProto oauth_client_factory: OAuthHttpClientFactory + data_source_repo: DataSourceRepository image_check_repo: ImageCheckRepository project_repo: ProjectRepository project_session_secret_repo: ProjectSessionSecretRepository @@ -86,6 +88,7 @@ async def _handler( user_repo=self.user_repo, metrics=self.metrics, image_check_repo=self.image_check_repo, + data_source_repo=self.data_source_repo, ) status = 201 if created else 200 return json(session.as_apispec().model_dump(exclude_none=True, mode="json"), status) @@ -134,25 +137,28 @@ def patch(self) -> BlueprintFactoryResponse: @authenticate_2(self.authenticator, self.internal_gitlab_authenticator) @validate(json=apispec.SessionPatchRequest) async def _handler( - _: Request, + request: Request, user: AuthenticatedAPIUser | AnonymousAPIUser, internal_gitlab_user: APIUser, session_id: str, body: apispec.SessionPatchRequest, ) -> HTTPResponse: new_session = await patch_session( + request=request, body=body, session_id=session_id, user=user, internal_gitlab_user=internal_gitlab_user, nb_config=self.nb_config, git_provider_helper=self.git_provider_helper, + data_connector_secret_repo=self.data_connector_secret_repo, project_repo=self.project_repo, 
project_session_secret_repo=self.project_session_secret_repo, rp_repo=self.rp_repo, session_repo=self.session_repo, metrics=self.metrics, image_check_repo=self.image_check_repo, + data_source_repo=self.data_source_repo, ) return json(new_session.as_apispec().model_dump(exclude_none=True, mode="json")) diff --git a/components/renku_data_services/notebooks/core_sessions.py b/components/renku_data_services/notebooks/core_sessions.py index 61c68004e..3409421ba 100644 --- a/components/renku_data_services/notebooks/core_sessions.py +++ b/components/renku_data_services/notebooks/core_sessions.py @@ -38,7 +38,7 @@ ) from renku_data_services.data_connectors.models import DataConnectorSecret, DataConnectorWithSecrets from renku_data_services.errors import ValidationError, errors -from renku_data_services.k8s.models import K8sSecret, sanitizer +from renku_data_services.k8s.models import ClusterConnection, K8sSecret, sanitizer from renku_data_services.notebooks import apispec, core from renku_data_services.notebooks.api.amalthea_patches import git_proxy, init_containers from renku_data_services.notebooks.api.amalthea_patches.init_containers import user_secrets_extras @@ -84,6 +84,7 @@ Storage, TlsSecret, ) +from renku_data_services.notebooks.data_sources import DataSourceRepository from renku_data_services.notebooks.image_check import ImageCheckRepository from renku_data_services.notebooks.models import ( ExtraSecret, @@ -260,6 +261,7 @@ async def __get_gitlab_image_pull_secret( async def get_data_sources( + request: Request, nb_config: NotebooksConfig, user: AnonymousAPIUser | AuthenticatedAPIUser, server_name: str, @@ -267,6 +269,7 @@ async def get_data_sources( work_dir: PurePosixPath, data_connectors_overrides: list[SessionDataConnectorOverride], user_repo: UserRepo, + data_source_repo: DataSourceRepository, ) -> SessionExtraResources: """Generate cloud storage related resources.""" data_sources: list[DataSource] = [] @@ -275,6 +278,11 @@ async def get_data_sources( 
dcs_secrets: dict[str, list[DataConnectorSecret]] = {} user_secret_key: str | None = None async for dc in data_connectors_stream: + configuration = await data_source_repo.handle_configuration( + request=request, user=user, data_connector=dc.data_connector + ) + if configuration is None: + continue mount_folder = ( dc.data_connector.storage.target_path if PurePosixPath(dc.data_connector.storage.target_path).is_absolute() @@ -283,7 +291,7 @@ async def get_data_sources( dcs[str(dc.data_connector.id)] = RCloneStorage( source_path=dc.data_connector.storage.source_path, mount_folder=mount_folder, - configuration=dc.data_connector.storage.configuration, + configuration=configuration, readonly=dc.data_connector.storage.readonly, name=dc.data_connector.name, secrets={str(secret.secret_id): secret.name for secret in dc.secrets}, @@ -348,6 +356,83 @@ async def get_data_sources( ) +async def patch_data_sources( + request: Request, + user: AnonymousAPIUser | AuthenticatedAPIUser, + session: AmaltheaSessionV1Alpha1, + cluster: ClusterConnection, + nb_config: NotebooksConfig, + data_connectors_stream: AsyncIterator[DataConnectorWithSecrets], + data_source_repo: DataSourceRepository, +) -> SessionExtraResources: + """Handle updating data sources definitions when resuming a session. + + This touches data connectors which use OAuth2 tokens for access. + Other data connectors are left untouched. 
+ """ + secrets: list[ExtraSecret] = [] + server_name = session.metadata.name + secret_prefix = f"{server_name}-ds-" + dss = session.spec.dataSources or [] + mounted_dcs: list[tuple[ULID, str]] = [] + for ds in dss: + if ds.secretRef is not None: + name = ds.secretRef.name + if name.startswith(secret_prefix): + ulid = name[len(secret_prefix) :] + try: + mounted_dcs.append((ULID.from_str(ulid.upper()), name)) + except ValueError: + logger.warning(f"Could not parse {ulid.upper()} as a ULID.") + async for dc in data_connectors_stream: + if not data_source_repo.is_patching_enabled(dc.data_connector): + continue + dc_id = dc.data_connector.id + mounted_dc = next(filter(lambda tup: tup[0] == dc_id, mounted_dcs), None) + if mounted_dc is None: + continue + _, secret_name = mounted_dc + logger.debug(f"Patching DC secret {secret_name} for data connector {str(dc_id)}.") + k8s_secret = await nb_config.k8s_v2_client.get_secret( + K8sSecret.from_v1_secret(V1Secret(metadata=V1ObjectMeta(name=secret_name)), cluster) + ) + if k8s_secret is None: + logger.warning(f"Could not read secret {secret_name} for patching, skipping!") + continue + v1_secret = k8s_secret.to_v1_secret() + secret_data: dict[str, str] = v1_secret.data + config_data_raw = secret_data.get("configData") + if not config_data_raw: + logger.warning(f"Field 'configData' not found for data connector {str(dc_id)}, skipping!") + continue + existing_config_data: str = "" + try: + existing_config_data = base64.b64decode(config_data_raw).decode("utf-8") + except Exception as err: + logger.warning(f"Error decoding 'configData' for data connector {str(dc_id)}, skipping! {err}") + continue + new_config_data = await data_source_repo.handle_patching_configuration( + request=request, user=user, data_connector=dc.data_connector, config_data=existing_config_data + ) + if not new_config_data: + continue + # We re-create the secret for the data connector, with the updated configuration. 
+ metadata = v1_secret.metadata + new_secret = V1Secret( + api_version="v1", + kind="Secret", + metadata=V1ObjectMeta( + name=metadata.name, + namespace=metadata.namespace, + ), + data=secret_data, + ) + new_secret.data["configData"] = base64.b64encode(new_config_data.encode("utf-8")).decode("utf-8") + secrets.append(ExtraSecret(new_secret)) + + return SessionExtraResources(secrets=secrets) + + async def request_dc_secret_creation( user: AuthenticatedAPIUser | AnonymousAPIUser, nb_config: NotebooksConfig, @@ -743,6 +828,7 @@ async def start_session( user_repo: UserRepo, metrics: MetricsService, image_check_repo: ImageCheckRepository, + data_source_repo: DataSourceRepository, ) -> tuple[AmaltheaSessionV1Alpha1, bool]: """Start an Amalthea session. @@ -819,6 +905,7 @@ async def start_session( # Data connectors session_extras = session_extras.concat( await get_data_sources( + request=request, nb_config=nb_config, server_name=server_name, user=user, @@ -826,6 +913,7 @@ async def start_session( work_dir=work_dir, data_connectors_overrides=launch_request.data_connectors_overrides or [], user_repo=user_repo, + data_source_repo=data_source_repo, ) ) @@ -1031,17 +1119,20 @@ async def start_session( async def patch_session( + request: Request, body: apispec.SessionPatchRequest, session_id: str, user: AnonymousAPIUser | AuthenticatedAPIUser, internal_gitlab_user: APIUser, nb_config: NotebooksConfig, git_provider_helper: GitProviderHelperProto, + data_connector_secret_repo: DataConnectorSecretRepository, project_repo: ProjectRepository, project_session_secret_repo: ProjectSessionSecretRepository, rp_repo: ResourcePoolRepository, session_repo: SessionRepository, image_check_repo: ImageCheckRepository, + data_source_repo: DataSourceRepository, metrics: MetricsService, ) -> AmaltheaSessionV1Alpha1: """Patch an Amalthea session.""" @@ -1134,6 +1225,7 @@ async def patch_session( session_secrets = await project_session_secret_repo.get_all_session_secrets_from_project( user=user, 
project_id=project.id ) + data_connectors_stream = data_connector_secret_repo.get_data_connectors_with_secrets(user, project.id) git_providers = await git_provider_helper.get_providers(user=user) repositories = repositories_from_project(project, git_providers) @@ -1155,6 +1247,17 @@ async def patch_session( # TODO: but that we do not save these overrides (e.g. as annotations) means that # TODO: we cannot patch data connectors upon resume. # TODO: If we did, we would lose the user's provided overrides (e.g. unsaved credentials). + session_extras = session_extras.concat( + await patch_data_sources( + request=request, + user=user, + session=session, + cluster=cluster, + nb_config=nb_config, + data_connectors_stream=data_connectors_stream, + data_source_repo=data_source_repo, + ) + ) # More init containers session_extras = session_extras.concat( diff --git a/components/renku_data_services/notebooks/data_sources.py b/components/renku_data_services/notebooks/data_sources.py new file mode 100644 index 000000000..347cd580b --- /dev/null +++ b/components/renku_data_services/notebooks/data_sources.py @@ -0,0 +1,222 @@ +"""Handling of data sources which require an OAuth2 connection.""" + +import json +from configparser import ConfigParser +from dataclasses import dataclass +from io import StringIO +from typing import TYPE_CHECKING, Any + +from sanic import Request + +from renku_data_services.app_config import logging +from renku_data_services.base_models.core import APIUser +from renku_data_services.connected_services.apispec_extras import RenkuTokens +from renku_data_services.connected_services.db import ConnectedServicesRepository +from renku_data_services.connected_services.models import ProviderKind +from renku_data_services.connected_services.oauth_http import ( + OAuthHttpClientFactory, +) +from renku_data_services.data_connectors.models import DataConnector, GlobalDataConnector +from renku_data_services.notebooks.config import NotebooksConfig + +if TYPE_CHECKING: + 
from renku_data_services.storage.models import RCloneConfig
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True, eq=True, kw_only=True)
+class _OAuth2ConfigPartial:
+    """Partial configuration; contains OAuth2 fields."""
+
+    token: str
+    token_url: str
+
+
+class DataSourceRepository:
+    """Repository for handling data sources which require an OAuth2 connection."""
+
+    def __init__(
+        self,
+        nb_config: NotebooksConfig,
+        connected_services_repo: ConnectedServicesRepository,
+        oauth_client_factory: OAuthHttpClientFactory,
+    ) -> None:
+        self.nb_config = nb_config
+        self.connected_services_repo = connected_services_repo
+        self.oauth_client_factory = oauth_client_factory
+
+    async def handle_configuration(
+        self, request: Request, user: APIUser, data_connector: DataConnector | GlobalDataConnector
+    ) -> dict[str, Any] | None:
+        """Adjusts the configuration of the input data connector if it requires an OAuth2 connection.
+
+        Returns either an rclone configuration or None if the data connector should be skipped.
+ """ + # NOTE: do not handle global data connectors + if data_connector.namespace is None: + return data_connector.storage.configuration + + provider_kind = self._get_oauth2_provider_kind(data_connector=data_connector) + if provider_kind is None: + return data_connector.storage.configuration + + oauth2_part = await self._get_oauth2_configuration_part( + request=request, user=user, data_connector=data_connector + ) + if oauth2_part is None: + return None + + logger.info(f"Adjusting rclone configuration for data connector {str(data_connector.id)}.") + configuration = data_connector.storage.configuration + if provider_kind == ProviderKind.google: + configuration["scope"] = configuration.get("scope") or "drive" + configuration["token"] = oauth2_part.token + configuration["token_url"] = oauth2_part.token_url + return configuration + + def is_patching_enabled(self, data_connector: DataConnector | GlobalDataConnector) -> bool: + """Returns true iff the data connector can be patched.""" + # NOTE: do not handle global data connectors + if data_connector.namespace is None: + return False + provider_kind = self._get_oauth2_provider_kind(data_connector=data_connector) + return provider_kind is not None + + async def handle_patching_configuration( + self, request: Request, user: APIUser, data_connector: DataConnector | GlobalDataConnector, config_data: str + ) -> str | None: + """Handles patching the configuration of a data connector when a session is resumed. + + This method updates the "token" and the "token_url" fields and no other part of the configuration. + + Returns either a new configuration (INI form) or None if the configuration should be left untouched. 
+        """
+        # NOTE: do not handle global data connectors
+        if data_connector.namespace is None:
+            return None
+
+        parser = ConfigParser(interpolation=None)
+        try:
+            parser.read_string(config_data)
+        except Exception as err:
+            logger.error(f"Failed to parse existing data connector configuration: {err}")
+            return None
+        main_section = next(filter(lambda s: s, parser.sections()), "")
+        if not main_section:
+            logger.error("Failed to parse existing data connector configuration: no main section.")
+            return None
+        items = parser.items(main_section)
+        configuration = dict(items)
+        if configuration.get("type") != data_connector.storage.configuration.get("type"):
+            logger.warning(
+                f"Data connector type changed to {data_connector.storage.configuration.get("type")}, skipping!"
+            )
+            return None
+
+        oauth2_part = await self._get_oauth2_configuration_part(
+            request=request, user=user, data_connector=data_connector
+        )
+        if oauth2_part is None:
+            return None
+
+        logger.info(f"Patching rclone configuration for data connector {str(data_connector.id)}.")
+        parser.set(main_section, "token", oauth2_part.token)
+        parser.set(main_section, "token_url", oauth2_part.token_url)
+        stringio = StringIO()
+        parser.write(stringio)
+        return stringio.getvalue()
+
+    async def handle_configuration_for_test(
+        self, user: APIUser, configuration: "RCloneConfig | dict[str, Any]"
+    ) -> "RCloneConfig | dict[str, Any] | None":
+        """Adjusts the input configuration if it requires an OAuth2 connection.
+
+        Returns either an rclone configuration or None if the data connector should be skipped.
+ """ + provider_kind: ProviderKind | None = None + match configuration.get("type"): + case "drive": + provider_kind = ProviderKind.google + case "dropbox": + provider_kind = ProviderKind.dropbox + if provider_kind is None: + return configuration + + provider = await self.connected_services_repo.get_provider_for_kind(user=user, provider_kind=provider_kind) + if provider is None: + return None + connection = provider.connected_user.connection if provider.connected_user else None + if connection is None: + return None + token_set = await self.connected_services_repo.get_token_set(user=user, connection_id=connection.id) + if not token_set or not token_set.access_token: + return None + token_config = { + "access_token": token_set.access_token, + "token_type": "Bearer", + } + if provider_kind == ProviderKind.google: + configuration["scope"] = configuration.get("scope") or "drive" + if token_set.expires_at_iso: + token_config["expiry"] = token_set.expires_at_iso + configuration["token"] = json.dumps(token_config) + return configuration + + def _get_oauth2_provider_kind(self, data_connector: DataConnector | GlobalDataConnector) -> ProviderKind | None: + """Returns the provider kind for data connectors which require an OAuth2 configuration.""" + match data_connector.storage.configuration["type"]: + case "drive": + return ProviderKind.google + case "dropbox": + return ProviderKind.dropbox + case _: + return None + + async def _get_oauth2_configuration_part( + self, request: Request, user: APIUser, data_connector: DataConnector + ) -> _OAuth2ConfigPartial | None: + """Get the OAuth2 configuration fields.""" + provider_kind = self._get_oauth2_provider_kind(data_connector=data_connector) + if provider_kind is None: + return None + + provider = await self.connected_services_repo.get_provider_for_kind(user=user, provider_kind=provider_kind) + if provider is None: + logger.info( + f"Skipping data connector {str(data_connector.id)} of type " + 
f"{data_connector.storage.configuration["type"]} " + f"because no provider of kind {provider_kind.value} was found." + ) + return None + connection = provider.connected_user.connection if provider.connected_user else None + if connection is None: + logger.info( + f"Skipping data connector {str(data_connector.id)} of type " + f"{data_connector.storage.configuration["type"]} " + f"because no active connection was found; user needs to connect with {provider.provider.id}." + ) + return None + token_set = await self.connected_services_repo.get_token_set(user=user, connection_id=connection.id) + if not token_set or not token_set.access_token: + logger.info( + f"Skipping data connector {str(data_connector.id)} of type " + f"{data_connector.storage.configuration["type"]} " + f"because the connection is not active; user needs to re-connect with {provider.provider.id}." + ) + return None + token_config = { + "access_token": token_set.access_token, + "token_type": "Bearer", + } + if user.access_token and user.refresh_token: + renku_tokens = RenkuTokens( + access_token=user.access_token, + refresh_token=user.refresh_token, + ) + token_config["refresh_token"] = renku_tokens.encode() + if token_set.expires_at_iso: + token_config["expiry"] = token_set.expires_at_iso + token = json.dumps(token_config) + token_url = request.url_for("oauth2_connections.post_token_endpoint", connection_id=connection.id) + return _OAuth2ConfigPartial(token=token, token_url=token_url) diff --git a/components/renku_data_services/storage/blueprints.py b/components/renku_data_services/storage/blueprints.py index 270adb602..d1539908b 100644 --- a/components/renku_data_services/storage/blueprints.py +++ b/components/renku_data_services/storage/blueprints.py @@ -15,6 +15,7 @@ from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint from renku_data_services.base_api.misc import validate_query from renku_data_services.base_models.validation import validated_json +from 
renku_data_services.notebooks.data_sources import DataSourceRepository from renku_data_services.storage import apispec, models from renku_data_services.storage.db import StorageRepository from renku_data_services.storage.rclone import RCloneValidator @@ -193,6 +194,9 @@ async def _delete(request: Request, user: base_models.APIUser, storage_id: ULID) class StorageSchemaBP(CustomBlueprint): """Handler for getting RClone storage schema.""" + data_source_repo: DataSourceRepository + authenticator: base_models.Authenticator + def get(self) -> BlueprintFactoryResponse: """Get cloud storage for a repository.""" @@ -204,12 +208,18 @@ async def _get(_: Request, validator: RCloneValidator) -> JSONResponse: def test_connection(self) -> BlueprintFactoryResponse: """Validate an RClone config.""" + @authenticate(self.authenticator) @validate(json=apispec.StorageSchemaTestConnectionPostRequest) async def _test_connection( - request: Request, validator: RCloneValidator, body: apispec.StorageSchemaTestConnectionPostRequest + request: Request, + user: base_models.APIUser, + validator: RCloneValidator, + body: apispec.StorageSchemaTestConnectionPostRequest, ) -> HTTPResponse: validator.validate(body.configuration, keep_sensitive=True) - result = await validator.test_connection(body.configuration, body.source_path) + result = await validator.test_connection( + body.configuration, body.source_path, user=user, data_source_repo=self.data_source_repo + ) if not result.success: raise errors.ValidationError(message=result.error) return empty(204) diff --git a/components/renku_data_services/storage/rclone.py b/components/renku_data_services/storage/rclone.py index d6c6f2876..3a4086354 100644 --- a/components/renku_data_services/storage/rclone.py +++ b/components/renku_data_services/storage/rclone.py @@ -23,6 +23,8 @@ logger = logging.getLogger(__name__) if TYPE_CHECKING: + from renku_data_services import base_models + from renku_data_services.notebooks.data_sources import 
DataSourceRepository from renku_data_services.storage.models import RCloneConfig @@ -88,7 +90,11 @@ def get_real_configuration(self, configuration: Union[RCloneConfig, dict[str, An return real_config async def test_connection( - self, configuration: Union[RCloneConfig, dict[str, Any]], source_path: str + self, + configuration: Union[RCloneConfig, dict[str, Any]], + source_path: str, + user: base_models.APIUser | None = None, + data_source_repo: DataSourceRepository | None = None, ) -> ConnectionResult: """Tests connecting with an RClone config.""" try: @@ -101,6 +107,14 @@ async def test_connection( transformed_config = self.inject_default_values(self.transform_polybox_switchdriver_config(obscured_config)) transformed_config = self.transform_envidat_config(transformed_config) + # Handle testing with Renku integrations + if user is not None and data_source_repo is not None: + with_oauth2_config = await data_source_repo.handle_configuration_for_test( + user=user, configuration=transformed_config + ) + if with_oauth2_config is not None: + transformed_config = with_oauth2_config + with tempfile.NamedTemporaryFile(mode="w+", delete=False, encoding="utf-8") as f: config = "\n".join(f"{k}={v}" for k, v in transformed_config.items()) f.write(f"[temp]\n{config}") diff --git a/test/components/renku_data_services/connected_services/oauth_test_script.py b/test/components/renku_data_services/connected_services/oauth_test_script.py index ef29f62b7..336341068 100644 --- a/test/components/renku_data_services/connected_services/oauth_test_script.py +++ b/test/components/renku_data_services/connected_services/oauth_test_script.py @@ -66,7 +66,7 @@ ### --------------------------------------------------------------------- deps = DependencyManager.from_env() -logger = logging.getLogger(__file__) +logger = logging.getLogger(__name__) factory = DefaultOAuthHttpClientFactory(deps.config.secrets.encryption_key, deps.config.db.async_session_maker) diff --git 
a/test/components/renku_data_services/connected_services/test_provider_adapters.py b/test/components/renku_data_services/connected_services/test_provider_adapters.py new file mode 100644 index 000000000..fa55250d4 --- /dev/null +++ b/test/components/renku_data_services/connected_services/test_provider_adapters.py @@ -0,0 +1,34 @@ +"""Tests for provider adapters.""" + +from datetime import UTC, datetime + +import pytest +from ulid import ULID + +from renku_data_services.connected_services import models +from renku_data_services.connected_services import orm as schemas +from renku_data_services.connected_services.provider_adapters import get_provider_adapter + + +@pytest.mark.parametrize("provider_kind", list(models.ProviderKind)) +def test_get_provider_adapter_maps_all_providers(provider_kind: models.ProviderKind) -> None: + client = schemas.OAuth2ClientORM( + id=ULID(), + client_id=f"c-{provider_kind.value}", + display_name=provider_kind.value, + created_by_id="", + kind=provider_kind, + scope="", + url="https://dev.renku.ch", + use_pkce=False, + app_slug="", + client_secret=None, + creation_date=datetime.now(UTC), + updated_at=datetime.now(UTC), + image_registry_url=None, + oidc_issuer_url=None, + ) + + adapter = get_provider_adapter(client) + + assert adapter is not None