diff --git a/bluesky_httpserver/authentication.py b/bluesky_httpserver/_authentication.py similarity index 98% rename from bluesky_httpserver/authentication.py rename to bluesky_httpserver/_authentication.py index 9772974..0375794 100644 --- a/bluesky_httpserver/authentication.py +++ b/bluesky_httpserver/_authentication.py @@ -1,5 +1,4 @@ import asyncio -import enum import hashlib import secrets import uuid as uuid_module @@ -55,11 +54,6 @@ def utcnow(): return datetime.utcnow().replace(microsecond=0) -class Mode(enum.Enum): - password = "password" - external = "external" - - class Token(BaseModel): access_token: str token_type: str @@ -455,7 +449,8 @@ async def auth_code( api_access_manager=Depends(get_api_access_manager), ): request.state.endpoint = "auth" - username = await authenticator.authenticate(request) + user_session_state = await authenticator.authenticate(request) + username = user_session_state.user_name if user_session_state else None if username and api_access_manager.is_user_known(username): scopes = api_access_manager.get_user_scopes(username) @@ -484,7 +479,10 @@ async def handle_credentials( api_access_manager=Depends(get_api_access_manager), ): request.state.endpoint = "auth" - username = await authenticator.authenticate(username=form_data.username, password=form_data.password) + user_session_state = await authenticator.authenticate( + username=form_data.username, password=form_data.password + ) + username = user_session_state.user_name if user_session_state else None err_msg = None if not username: diff --git a/bluesky_httpserver/app.py b/bluesky_httpserver/app.py index f09acb3..9a8420a 100644 --- a/bluesky_httpserver/app.py +++ b/bluesky_httpserver/app.py @@ -15,7 +15,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.openapi.utils import get_openapi -from .authentication import Mode +from .authentication import ExternalAuthenticator, InternalAuthenticator from .console_output import CollectPublishedConsoleOutput, ConsoleOutputStream, SystemInfoStream from .core import PatchedStreamingResponse from .database.core import purge_expired @@ -179,12 +179,11 @@ def build_app(authentication=None, api_access=None, resource_access=None, server for spec in authentication["providers"]: provider = spec["provider"] authenticator = spec["authenticator"] - mode = authenticator.mode - if mode == Mode.password: + if isinstance(authenticator, InternalAuthenticator): authentication_router.post(f"/provider/{provider}/token")( build_handle_credentials_route(authenticator, provider) ) - elif mode == Mode.external: + elif isinstance(authenticator, ExternalAuthenticator): authentication_router.get(f"/provider/{provider}/code")( build_auth_code_route(authenticator, provider) ) @@ -192,7 +191,7 @@ def build_app(authentication=None, api_access=None, resource_access=None, server build_auth_code_route(authenticator, provider) ) else: - raise ValueError(f"unknown authentication mode {mode}") + raise ValueError(f"unknown authenticator type {type(authenticator)}") for custom_router in getattr(authenticator, "include_routers", []): authentication_router.include_router(custom_router, prefix=f"/provider/{provider}") diff --git a/bluesky_httpserver/authentication/__init__.py b/bluesky_httpserver/authentication/__init__.py new file mode 100644 index 0000000..fc35cdd --- /dev/null +++ b/bluesky_httpserver/authentication/__init__.py @@ -0,0 +1,25 @@ +from .._authentication import ( + base_authentication_router, + build_auth_code_route, + build_handle_credentials_route, + get_current_principal, + get_current_principal_websocket, + oauth2_scheme, +) +from .authenticator_base import ( + ExternalAuthenticator, + InternalAuthenticator, + UserSessionState, +) + +__all__ = [ + "ExternalAuthenticator", + "InternalAuthenticator", + "UserSessionState", + "get_current_principal", + "get_current_principal_websocket", + "base_authentication_router", + "build_auth_code_route", + "build_handle_credentials_route", + "oauth2_scheme", +] diff --git a/bluesky_httpserver/authentication/authenticator_base.py b/bluesky_httpserver/authentication/authenticator_base.py new file mode 100644 index 0000000..af103c5 --- /dev/null +++ b/bluesky_httpserver/authentication/authenticator_base.py @@ -0,0 +1,37 @@ +from abc import ABC +from dataclasses import dataclass +from typing import Optional + +from fastapi import Request + + +@dataclass +class UserSessionState: + """Data transfer class to communicate custom session state information.""" + + user_name: str + state: dict = None + + +class InternalAuthenticator(ABC): + """ + Base class for authenticators that use username/password credentials. + + Subclasses must implement the authenticate method which takes a username + and password and returns a UserSessionState on success or None on failure. + """ + + async def authenticate(self, username: str, password: str) -> Optional[UserSessionState]: + raise NotImplementedError + + +class ExternalAuthenticator(ABC): + """ + Base class for authenticators that use external identity providers. + + Subclasses must implement the authenticate method which takes a FastAPI + Request object and returns a UserSessionState on success or None on failure. + """ + + async def authenticate(self, request: Request) -> Optional[UserSessionState]: + raise NotImplementedError diff --git a/bluesky_httpserver/authenticators.py b/bluesky_httpserver/authenticators.py index 61c2da4..78b6cf1 100644 --- a/bluesky_httpserver/authenticators.py +++ b/bluesky_httpserver/authenticators.py @@ -1,21 +1,32 @@ import asyncio +import base64 import functools import logging import re import secrets from collections.abc import Iterable +from datetime import timedelta +from typing import Any, List, Mapping, Optional, cast +import httpx +from cachetools import TTLCache, cached from fastapi import APIRouter, Request -from jose import JWTError, jwk, jwt +from fastapi.security import OAuth2, OAuth2AuthorizationCodeBearer +from jose import JWTError, jwt +from pydantic import Secret from starlette.responses import RedirectResponse -from .authentication import Mode -from .utils import modules_available +from .authentication import ( + ExternalAuthenticator, + InternalAuthenticator, + UserSessionState, +) +from .utils import get_root_url, modules_available logger = logging.getLogger(__name__) -class DummyAuthenticator: +class DummyAuthenticator(InternalAuthenticator): """ For test and demo purposes only! @@ -23,26 +34,20 @@ class DummyAuthenticator: """ - mode = Mode.password + def __init__(self, confirmation_message: str = ""): + self.confirmation_message = confirmation_message - async def authenticate(self, username: str, password: str): - return username + async def authenticate(self, username: str, password: str) -> UserSessionState: + return UserSessionState(username, {}) -class DictionaryAuthenticator: +class DictionaryAuthenticator(InternalAuthenticator): """ For test and demo purposes only! Check passwords from a dictionary of usernames mapped to passwords. - - Parameters - ---------- - - users_to_passwords: dict(str, str) - Mapping of usernames to passwords. """ - mode = Mode.password configuration_schema = """ $schema": http://json-schema.org/draft-07/schema# type: object @@ -50,25 +55,28 @@ class DictionaryAuthenticator: properties: users_to_password: type: object - description: | - Mapping usernames to password. Environment variable expansion should be - used to avoid placing passwords directly in configuration. + description: | + Mapping usernames to password. Environment variable expansion should be + used to avoid placing passwords directly in configuration. + confirmation_message: + type: string + description: May be displayed by client after successful login. """ - def __init__(self, users_to_passwords): + def __init__(self, users_to_passwords: Mapping[str, str], confirmation_message: str = ""): self._users_to_passwords = users_to_passwords + self.confirmation_message = confirmation_message - async def authenticate(self, username: str, password: str): + async def authenticate(self, username: str, password: str) -> Optional[UserSessionState]: true_password = self._users_to_passwords.get(username) if not true_password: # Username is not valid. - return + return None if secrets.compare_digest(true_password, password): - return username + return UserSessionState(username, {}) -class PAMAuthenticator: - mode = Mode.password +class PAMAuthenticator(InternalAuthenticator): configuration_schema = """ $schema": http://json-schema.org/draft-07/schema# type: object @@ -77,90 +85,139 @@ class PAMAuthenticator: service: type: string description: PAM service. Default is 'login'. + confirmation_message: + type: string + description: May be displayed by client after successful login. """ - def __init__(self, service="login"): + def __init__(self, service: str = "login", confirmation_message: str = ""): if not modules_available("pamela"): raise ModuleNotFoundError("This PAMAuthenticator requires the module 'pamela' to be installed.") self.service = service + self.confirmation_message = confirmation_message # TODO Try to open a PAM session. - async def authenticate(self, username: str, password: str): + async def authenticate(self, username: str, password: str) -> Optional[UserSessionState]: import pamela try: pamela.authenticate(username, password, service=self.service) + return UserSessionState(username, {}) except pamela.PAMError: # Authentication failed. - return - else: - return username + return None -class OIDCAuthenticator: - mode = Mode.external +class OIDCAuthenticator(ExternalAuthenticator): configuration_schema = """ $schema": http://json-schema.org/draft-07/schema# type: object additionalProperties: false properties: + audience: + type: string client_id: type: string client_secret: type: string - redirect_uri: + well_known_uri: type: string - token_uri: + confirmation_message: type: string - authorization_endpoint: + redirect_on_success: + type: string + redirect_on_failure: type: string - public_keys: - type: array - item: - type: object - properties: - - alg: - type: string - - e - type: string - - kid - type: string - - kty - type: string - - n - type: string - - use - type: string - required: - - alg - - e - - kid - - kty - - n - - use """ def __init__( self, - client_id, - client_secret, - redirect_uri, - public_keys, - token_uri, - authorization_endpoint, - confirmation_message, + audience: str, + client_id: str, + client_secret: str, + well_known_uri: str, + confirmation_message: str = "", + redirect_on_success: Optional[str] = None, + redirect_on_failure: Optional[str] = None, ): - self.client_id = client_id - self.client_secret = client_secret + self._audience = audience + self._client_id = client_id + self._client_secret = Secret(client_secret) + self._well_known_url = well_known_uri self.confirmation_message = confirmation_message - self.redirect_uri = redirect_uri - self.public_keys = public_keys - self.token_uri = token_uri - self.authorization_endpoint = authorization_endpoint.format(client_id=client_id, redirect_uri=redirect_uri) - - async def authenticate(self, request): - code = request.query_params["code"] - response = await exchange_code(self.token_uri, code, self.client_id, self.client_secret, self.redirect_uri) + self.redirect_on_success = redirect_on_success + self.redirect_on_failure = redirect_on_failure + + @functools.cached_property + def _config_from_oidc_url(self) -> dict[str, Any]: + response: httpx.Response = httpx.get(self._well_known_url) + response.raise_for_status() + return response.json() + + @functools.cached_property + def client_id(self) -> str: + return self._client_id + + @functools.cached_property + def id_token_signing_alg_values_supported(self) -> list[str]: + return cast( + list[str], + self._config_from_oidc_url.get("id_token_signing_alg_values_supported"), + ) + + @functools.cached_property + def issuer(self) -> str: + return cast(str, self._config_from_oidc_url.get("issuer")) + + @functools.cached_property + def jwks_uri(self) -> str: + return cast(str, self._config_from_oidc_url.get("jwks_uri")) + + @functools.cached_property + def token_endpoint(self) -> str: + return cast(str, self._config_from_oidc_url.get("token_endpoint")) + + @functools.cached_property + def authorization_endpoint(self) -> httpx.URL: + return httpx.URL(cast(str, self._config_from_oidc_url.get("authorization_endpoint"))) + + @functools.cached_property + def device_authorization_endpoint(self) -> str: + return cast(str, self._config_from_oidc_url.get("device_authorization_endpoint")) + + @functools.cached_property + def end_session_endpoint(self) -> str: + return cast(str, self._config_from_oidc_url.get("end_session_endpoint")) + + @cached(TTLCache(maxsize=1, ttl=timedelta(days=7).total_seconds())) + def keys(self) -> List[str]: + return httpx.get(self.jwks_uri).raise_for_status().json().get("keys", []) + + def decode_token(self, token: str) -> dict[str, Any]: + return jwt.decode( + token, + key=self.keys(), + algorithms=self.id_token_signing_alg_values_supported, + audience=self._audience, + issuer=self.issuer, + ) + + async def authenticate(self, request: Request) -> Optional[UserSessionState]: + code = request.query_params.get("code") + if not code: + logger.warning("Authentication failed: No authorization code parameter provided.") + return None + # A proxy in the middle may make the request into something like + # 'http://localhost:8000/...' so we fix the first part but keep + # the original URI path. + redirect_uri = f"{get_root_url(request)}{request.url.path}" + response = await exchange_code( + self.token_endpoint, + code, + self._client_id, + self._client_secret.get_secret_value(), + redirect_uri, + ) response_body = response.json() if response.is_error: logger.error("Authentication error: %r", response_body) @@ -168,63 +225,84 @@ async def authenticate(self, request): response_body = response.json() id_token = response_body["id_token"] access_token = response_body["access_token"] - # Match the kid in id_token to a key in the list of public_keys. - key = find_key(id_token, self.public_keys) try: - verified_body = jwt.decode(id_token, key, access_token=access_token, audience=self.client_id) + verified_body = self.decode_token(access_token) except JWTError: logger.exception( "Authentication error. Unverified token: %r", jwt.get_unverified_claims(id_token), ) return None - return verified_body["sub"] - - -class KeyNotFoundError(Exception): - pass - - -def find_key(token, keys): - """ - Find a key from the configured keys based on the kid claim of the token + return UserSessionState(verified_body["sub"], {}) - Parameters - ---------- - token : token to search for the kid from - keys: list of keys - Raises - ------ - KeyNotFoundError: - returned if the token does not have a kid claim - - Returns - ------ - key: found key object - """ +class ProxiedOIDCAuthenticator(OIDCAuthenticator): + configuration_schema = """ +$schema": http://json-schema.org/draft-07/schema# +type: object +additionalProperties: false +properties: + audience: + type: string + client_id: + type: string + well_known_uri: + type: string + scopes: + type: array + items: + type: string + description: | + Optional list of OAuth2 scopes to request. If provided, authorization + should be enforced by an external policy agent (for example ExternalPolicyDecisionPoint) + rather than by this authenticator. + device_flow_client_id: + type: string + confirmation_message: + type: string +""" - unverified = jwt.get_unverified_header(token) - kid = unverified.get("kid") - if not kid: - raise KeyNotFoundError("No 'kid' in token") + def __init__( + self, + audience: str, + client_id: str, + well_known_uri: str, + device_flow_client_id: str, + scopes: Optional[List[str]] = None, + confirmation_message: str = "", + ): + super().__init__( + audience=audience, + client_id=client_id, + client_secret="", + well_known_uri=well_known_uri, + confirmation_message=confirmation_message, + ) + self.scopes = scopes + self.device_flow_client_id = device_flow_client_id + self._oidc_bearer = OAuth2AuthorizationCodeBearer( + authorizationUrl=str(self.authorization_endpoint), + tokenUrl=self.token_endpoint, + ) - for key in keys: - if key["kid"] == kid: - return jwk.construct(key) - return KeyNotFoundError(f"Token specifies {kid} but we have {[k['kid'] for k in keys]}") + @property + def oauth2_schema(self) -> OAuth2: + return self._oidc_bearer -async def exchange_code(token_uri, auth_code, client_id, client_secret, redirect_uri): +async def exchange_code( + token_uri: str, + auth_code: str, + client_id: str, + client_secret: str, + redirect_uri: str, +) -> httpx.Response: """Method that talks to an IdP to exchange a code for an access_token and/or id_token Args: token_url ([type]): [description] auth_code ([type]): [description] """ - if not modules_available("httpx"): - raise ModuleNotFoundError("This authenticator requires 'httpx'. (pip install httpx)") - import httpx - + auth_value = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode() response = httpx.post( url=token_uri, data={ @@ -234,18 +312,18 @@ async def exchange_code(token_uri, auth_code, client_id, client_secret, redirect "code": auth_code, "client_secret": client_secret, }, + headers={"Authorization": f"Basic {auth_value}"}, ) return response -class SAMLAuthenticator: - mode = Mode.external +class SAMLAuthenticator(ExternalAuthenticator): def __init__( self, saml_settings, # See EXAMPLE_SAML_SETTINGS below. - attribute_name, # which SAML attribute to use as 'id' for Idenity - confirmation_message=None, + attribute_name: str, # which SAML attribute to use as 'id' for Identity + confirmation_message: str = "", ): self.saml_settings = saml_settings self.attribute_name = attribute_name @@ -263,23 +341,15 @@ def __init__( from onelogin.saml2.auth import OneLogin_Saml2_Auth @router.get("/login") - async def saml_login(request: Request): + async def saml_login(request: Request) -> RedirectResponse: req = await prepare_saml_from_fastapi_request(request) auth = OneLogin_Saml2_Auth(req, self.saml_settings) - # saml_settings = auth.get_settings() - # metadata = saml_settings.get_sp_metadata() - # errors = saml_settings.validate_metadata(metadata) - # if len(errors) == 0: - # print(metadata) - # else: - # print("Error found on Metadata: %s" % (', '.join(errors))) callback_url = auth.login() - response = RedirectResponse(url=callback_url) - return response + return RedirectResponse(url=callback_url) self.include_routers = [router] - async def authenticate(self, request): + async def authenticate(self, request: Request) -> Optional[UserSessionState]: if not modules_available("onelogin"): raise ModuleNotFoundError("This SAMLAuthenticator requires the module 'oneline' to be installed.") from onelogin.saml2.auth import OneLogin_Saml2_Auth @@ -297,12 +367,12 @@ async def authenticate(self, request): attribute_as_list = auth.get_attributes()[self.attribute_name] # Confused in what situation this would have more than one item.... assert len(attribute_as_list) == 1 - return attribute_as_list[0] + return UserSessionState(attribute_as_list[0], {}) else: return None -async def prepare_saml_from_fastapi_request(request, debug=False): +async def prepare_saml_from_fastapi_request(request: Request) -> Mapping[str, str]: form_data = await request.form() rv = { "http_host": request.client.host, @@ -328,7 +398,7 @@ async def prepare_saml_from_fastapi_request(request, debug=False): return rv -class LDAPAuthenticator: +class LDAPAuthenticator(InternalAuthenticator): """ LDAP authenticator. The authenticator code is based on https://github.com/jupyterhub/ldapauthenticator @@ -472,6 +542,8 @@ class LDAPAuthenticator: This can be useful in an heterogeneous environment, when supplying a UNIX username to authenticate against AD. + confirmation_message: str + May be displayed by client after successful login. Examples -------- @@ -510,8 +582,6 @@ class LDAPAuthenticator: id: user02 """ - mode = Mode.password - def __init__( self, server_address, @@ -536,6 +606,7 @@ def __init__( attributes=None, auth_state_attributes=None, use_lookup_dn_username=True, + confirmation_message="", ): self.use_ssl = use_ssl self.use_tls = use_tls @@ -571,6 +642,7 @@ def __init__( self.server_address_list = server_address_list self.server_port = server_port if server_port is not None else self._server_port_default() + self.confirmation_message = confirmation_message def _server_port_default(self): if self.use_ssl: @@ -655,7 +727,7 @@ async def resolve_username(self, username_supplied_by_user): def get_connection(self, userdn, password): import ldap3 - # NOTE: setting 'acitve=False' essentially disables exclusion of inactive servers from the pool. + # NOTE: setting 'active=False' essentially disables exclusion of inactive servers from the pool. # It probably does not matter if the pool contains only one server, but it could have implications # when there are multiple servers in the pool. It is not clear what those implications are. # But using the default 'activate=True' results in the thread being blocked indefinitely @@ -675,14 +747,21 @@ def get_connection(self, userdn, password): server_port = self.server_port server = ldap3.Server( - server_addr, port=server_port, use_ssl=self.use_ssl, connect_timeout=self.connect_timeout + server_addr, + port=server_port, + use_ssl=self.use_ssl, + connect_timeout=self.connect_timeout, ) server_pool.add(server) auto_bind_no_ssl = ldap3.AUTO_BIND_TLS_BEFORE_BIND if self.use_tls else ldap3.AUTO_BIND_NO_TLS auto_bind = ldap3.AUTO_BIND_NO_TLS if self.use_ssl else auto_bind_no_ssl conn = ldap3.Connection( - server_pool, user=userdn, password=password, auto_bind=auto_bind, receive_timeout=self.receive_timeout + server_pool, + user=userdn, + password=password, + auto_bind=auto_bind, + receive_timeout=self.receive_timeout, ) return conn @@ -690,14 +769,17 @@ async def get_user_attributes(self, conn, userdn): attrs = {} if self.auth_state_attributes: search_func = functools.partial( - conn.search, userdn, "(objectClass=*)", attributes=self.auth_state_attributes + conn.search, + userdn, + "(objectClass=*)", + attributes=self.auth_state_attributes, ) found = await asyncio.get_running_loop().run_in_executor(None, search_func) if found: attrs = conn.entries[0].entry_attributes_as_dict return attrs - async def authenticate(self, username: str, password: str): + async def authenticate(self, username: str, password: str) -> Optional[UserSessionState]: import ldap3 username_saved = username # Save the user name passed as a parameter @@ -826,5 +908,6 @@ async def authenticate(self, username: str, password: str): user_info = await self.get_user_attributes(conn, userdn) if user_info: logger.debug("username:%s attributes:%s", username, user_info) - return {"name": username, "auth_state": user_info} - return username + # this path might never have been worked out...is it ever hit? + return UserSessionState(username, user_info) + return UserSessionState(username, {}) diff --git a/bluesky_httpserver/tests/test_authenticators.py b/bluesky_httpserver/tests/test_authenticators.py index cc2984c..183ce75 100644 --- a/bluesky_httpserver/tests/test_authenticators.py +++ b/bluesky_httpserver/tests/test_authenticators.py @@ -3,7 +3,7 @@ import pytest # fmt: off -from ..authenticators import LDAPAuthenticator +from ..authenticators import LDAPAuthenticator, UserSessionState @pytest.mark.parametrize("ldap_server_address, ldap_server_port", [ @@ -35,8 +35,8 @@ def test_LDAPAuthenticator_01(use_tls, use_ssl, ldap_server_address, ldap_server ) async def testing(): - assert await authenticator.authenticate("user01", "password1") == "user01" - assert await authenticator.authenticate("user02", "password2") == "user02" + assert await authenticator.authenticate("user01", "password1") == UserSessionState("user01", {}) + assert await authenticator.authenticate("user02", "password2") == UserSessionState("user02", {}) assert await authenticator.authenticate("user02a", "password2") is None assert await authenticator.authenticate("user02", "password2a") is None diff --git a/requirements.txt b/requirements.txt index f465abd..818362f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ alembic bluesky-queueserver bluesky-queueserver-api +cachetools fastapi ldap3 orjson