diff --git a/superset/commands/database/oauth2.py b/superset/commands/database/oauth2.py index f7259077bc4c..8355bc0098ec 100644 --- a/superset/commands/database/oauth2.py +++ b/superset/commands/database/oauth2.py @@ -18,12 +18,15 @@ from datetime import datetime, timedelta from functools import partial from typing import cast +from uuid import UUID from superset.commands.base import BaseCommand from superset.commands.database.exceptions import DatabaseNotFoundError from superset.daos.database import DatabaseUserOAuth2TokensDAO +from superset.daos.key_value import KeyValueDAO from superset.databases.schemas import OAuth2ProviderResponseSchema from superset.exceptions import OAuth2Error +from superset.key_value.types import JsonKeyValueCodec, KeyValueResource from superset.models.core import Database, DatabaseUserOAuth2Tokens from superset.superset_typing import OAuth2State from superset.utils.decorators import on_error, transaction @@ -50,9 +53,28 @@ def run(self) -> DatabaseUserOAuth2Tokens: if oauth2_config is None: raise OAuth2Error("No configuration found for OAuth2") + # Look up PKCE code_verifier from KV store (RFC 7636) + code_verifier = None + tab_id = self._state["tab_id"] + try: + tab_uuid = UUID(tab_id) + except ValueError: + tab_uuid = None + + if tab_uuid: + kv_value = KeyValueDAO.get_value( + resource=KeyValueResource.PKCE_CODE_VERIFIER, + key=tab_uuid, + codec=JsonKeyValueCodec(), + ) + if kv_value: + code_verifier = kv_value.get("code_verifier") + KeyValueDAO.delete_entry(KeyValueResource.PKCE_CODE_VERIFIER, tab_uuid) + token_response = self._database.db_engine_spec.get_oauth2_token( oauth2_config, self._parameters["code"], + code_verifier=code_verifier, ) # delete old tokens diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index ac355c936630..ac9c397e5b4e 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -21,7 +21,7 @@ import logging import re import warnings -from datetime 
import datetime +from datetime import datetime, timedelta from inspect import signature from re import Match, Pattern from typing import ( @@ -36,7 +36,7 @@ Union, ) from urllib.parse import urlencode, urljoin -from uuid import uuid4 +from uuid import UUID, uuid4 import pandas as pd import requests @@ -63,6 +63,7 @@ from superset.databases.utils import get_table_metadata, make_url_safe from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import OAuth2Error, OAuth2RedirectError +from superset.key_value.types import JsonKeyValueCodec, KeyValueResource from superset.sql.parse import ( BaseSQLStatement, LimitMethod, @@ -83,7 +84,11 @@ from superset.utils.hashing import hash_from_str from superset.utils.json import redact_sensitive, reveal_sensitive from superset.utils.network import is_hostname_valid, is_port_open -from superset.utils.oauth2 import encode_oauth2_state +from superset.utils.oauth2 import ( + encode_oauth2_state, + generate_code_challenge, + generate_code_verifier, +) if TYPE_CHECKING: from superset.connectors.sqla.models import TableColumn @@ -608,13 +613,38 @@ def start_oauth2_dance(cls, database: Database) -> None: tab sends a message to the original tab informing that authorization was successful (or not), and then closes. The original tab will automatically re-run the query after authorization. + + PKCE (RFC 7636) is used to protect against authorization code interception + attacks. A code_verifier is generated and stored server-side in the KV store, + while the code_challenge (derived from the verifier) is sent to the + authorization server. """ + # Prevent circular import. 
+ from superset.daos.key_value import KeyValueDAO + tab_id = str(uuid4()) default_redirect_uri = app.config.get( "DATABASE_OAUTH2_REDIRECT_URI", url_for("DatabaseRestApi.oauth2", _external=True), ) + # Generate PKCE code verifier (RFC 7636) + code_verifier = generate_code_verifier() + + # Store the code_verifier server-side in the KV store, keyed by tab_id. + # This avoids exposing it in the URL/browser history via the JWT state. + KeyValueDAO.delete_expired_entries(KeyValueResource.PKCE_CODE_VERIFIER) + KeyValueDAO.create_entry( + resource=KeyValueResource.PKCE_CODE_VERIFIER, + value={"code_verifier": code_verifier}, + codec=JsonKeyValueCodec(), + key=UUID(tab_id), + expires_on=datetime.now() + timedelta(minutes=5), + ) + # We need to commit here because we're going to raise an exception, which will + # revert any uncommitted changes. + db.session.commit() + # The state is passed to the OAuth2 provider, and sent back to Superset after # the user authorizes the access. The redirect endpoint in Superset can then # inspect the state to figure out to which user/database the access token @@ -641,7 +671,11 @@ def start_oauth2_dance(cls, database: Database) -> None: if oauth2_config is None: raise OAuth2Error("No configuration found for OAuth2") - oauth_url = cls.get_oauth2_authorization_uri(oauth2_config, state) + oauth_url = cls.get_oauth2_authorization_uri( + oauth2_config, + state, + code_verifier=code_verifier, + ) raise OAuth2RedirectError(oauth_url, tab_id, default_redirect_uri) @@ -685,21 +719,29 @@ def get_oauth2_authorization_uri( cls, config: OAuth2ClientConfig, state: OAuth2State, + code_verifier: str | None = None, ) -> str: """ Return URI for initial OAuth2 request. - Uses standard OAuth 2.0 parameters only. Subclasses can override - to add provider-specific parameters (e.g., Google's prompt=consent). + Uses standard OAuth 2.0 parameters plus PKCE (RFC 7636) parameters. 
+ Subclasses can override to add provider-specific parameters + (e.g., Google's prompt=consent). """ uri = config["authorization_request_uri"] - params = { + params: dict[str, str] = { "scope": config["scope"], "response_type": "code", "state": encode_oauth2_state(state), "redirect_uri": config["redirect_uri"], "client_id": config["id"], } + + # Add PKCE parameters (RFC 7636) if code_verifier is provided + if code_verifier: + params["code_challenge"] = generate_code_challenge(code_verifier) + params["code_challenge_method"] = "S256" + return urljoin(uri, "?" + urlencode(params)) @classmethod @@ -707,19 +749,27 @@ def get_oauth2_token( cls, config: OAuth2ClientConfig, code: str, + code_verifier: str | None = None, ) -> OAuth2TokenResponse: """ Exchange authorization code for refresh/access tokens. + + If code_verifier is provided (PKCE flow), it will be included in the + token request per RFC 7636. """ timeout = app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds() uri = config["token_request_uri"] - req_body = { + req_body: dict[str, str] = { "code": code, "client_id": config["id"], "client_secret": config["secret"], "redirect_uri": config["redirect_uri"], "grant_type": "authorization_code", } + # Add PKCE code_verifier if present (RFC 7636) + if code_verifier: + req_body["code_verifier"] = code_verifier + response = ( requests.post(uri, data=req_body, timeout=timeout) if config["request_content_type"] == "data" diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py index 86eadf6c5ab6..780f92cc750f 100644 --- a/superset/db_engine_specs/gsheets.py +++ b/superset/db_engine_specs/gsheets.py @@ -161,6 +161,7 @@ def get_oauth2_authorization_uri( cls, config: "OAuth2ClientConfig", state: "OAuth2State", + code_verifier: str | None = None, ) -> str: """ Return URI for initial OAuth2 request with Google-specific parameters. 
@@ -172,10 +173,10 @@ def get_oauth2_authorization_uri( """ from urllib.parse import urlencode, urljoin - from superset.utils.oauth2 import encode_oauth2_state + from superset.utils.oauth2 import encode_oauth2_state, generate_code_challenge uri = config["authorization_request_uri"] - params = { + params: dict[str, str] = { "scope": config["scope"], "response_type": "code", "state": encode_oauth2_state(state), @@ -186,6 +187,12 @@ def get_oauth2_authorization_uri( "include_granted_scopes": "false", "prompt": "consent", } + + # Add PKCE parameters (RFC 7636) if code_verifier is provided + if code_verifier: + params["code_challenge"] = generate_code_challenge(code_verifier) + params["code_challenge_method"] = "S256" + return urljoin(uri, "?" + urlencode(params)) @classmethod diff --git a/superset/key_value/types.py b/superset/key_value/types.py index 3b2da06493c9..2cc025e42650 100644 --- a/superset/key_value/types.py +++ b/superset/key_value/types.py @@ -45,6 +45,7 @@ class KeyValueResource(StrEnum): EXPLORE_PERMALINK = "explore_permalink" METASTORE_CACHE = "superset_metastore_cache" LOCK = "lock" + PKCE_CODE_VERIFIER = "pkce_code_verifier" SQLLAB_PERMALINK = "sqllab_permalink" diff --git a/superset/superset_typing.py b/superset/superset_typing.py index 4d409398d1c2..02e294a08cfb 100644 --- a/superset/superset_typing.py +++ b/superset/superset_typing.py @@ -356,7 +356,7 @@ class OAuth2TokenResponse(TypedDict, total=False): refresh_token: str -class OAuth2State(TypedDict): +class OAuth2State(TypedDict, total=False): """ Type for the state passed during OAuth2. 
""" diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py index 0124f5730875..9a24f0c09581 100644 --- a/superset/utils/oauth2.py +++ b/superset/utils/oauth2.py @@ -17,7 +17,10 @@ from __future__ import annotations +import base64 +import hashlib import logging +import secrets from contextlib import contextmanager from datetime import datetime, timedelta, timezone from typing import Any, Iterator, TYPE_CHECKING @@ -40,6 +43,37 @@ logger = logging.getLogger(__name__) +# PKCE verifier entropy in bytes; encodes to 86 chars (RFC 7636 requires 43-128) +PKCE_CODE_VERIFIER_LENGTH = 64 + + +def generate_code_verifier() -> str: + """ + Generate a PKCE code verifier (RFC 7636). + + The code verifier is a high-entropy cryptographic random string using + unreserved characters [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~", + with a minimum length of 43 characters and a maximum length of 128. + """ + # Generate random bytes and encode as URL-safe base64 + random_bytes = secrets.token_bytes(PKCE_CODE_VERIFIER_LENGTH) + # Use URL-safe base64 encoding without padding + code_verifier = base64.urlsafe_b64encode(random_bytes).rstrip(b"=").decode("ascii") + return code_verifier + + +def generate_code_challenge(code_verifier: str) -> str: + """ + Generate a PKCE code challenge from a code verifier (RFC 7636). + + Uses the S256 method: BASE64URL(SHA256(code_verifier)) + """ + # Compute SHA-256 hash of the code verifier + digest = hashlib.sha256(code_verifier.encode("ascii")).digest() + # Encode as URL-safe base64 without padding + code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") + return code_challenge + @backoff.on_exception( backoff.expo, @@ -140,13 +174,14 @@ def encode_oauth2_state(state: OAuth2State) -> str: """ Encode the OAuth2 state. 
""" - payload = { + payload: dict[str, Any] = { "exp": datetime.now(tz=timezone.utc) + JWT_EXPIRATION, "database_id": state["database_id"], "user_id": state["user_id"], "default_redirect_uri": state["default_redirect_uri"], "tab_id": state["tab_id"], } + encoded_state = jwt.encode( payload=payload, key=app.config["SECRET_KEY"], @@ -172,12 +207,12 @@ def make_oauth2_state( data: dict[str, Any], **kwargs: Any, ) -> OAuth2State: - return OAuth2State( - database_id=data["database_id"], - user_id=data["user_id"], - default_redirect_uri=data["default_redirect_uri"], - tab_id=data["tab_id"], - ) + return { + "database_id": data["database_id"], + "user_id": data["user_id"], + "default_redirect_uri": data["default_redirect_uri"], + "tab_id": data["tab_id"], + } class Meta: # pylint: disable=too-few-public-methods # ignore `exp` diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py index 7b77f5099d96..244d75f7e28d 100644 --- a/tests/unit_tests/databases/api_test.py +++ b/tests/unit_tests/databases/api_test.py @@ -255,7 +255,7 @@ def test_database_connection( "service_account_info": { "type": "service_account", "project_id": "black-sanctum-314419", - "private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173", + "private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173", # noqa: E501 "private_key": "XXXXXXXXXX", "client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com", # noqa: E501 "client_id": "114567578578109757129", @@ -621,6 +621,10 @@ def test_oauth2_happy_path( "expires_in": 3600, "refresh_token": "ZZZ", } + mocker.patch( + "superset.commands.database.oauth2.KeyValueDAO.get_value", + return_value=None, + ) state: OAuth2State = { "user_id": 1, @@ -641,7 +645,11 @@ def test_oauth2_happy_path( ) assert response.status_code == 200 - get_oauth2_token.assert_called_with({"id": "one", "secret": "two"}, "XXX") + get_oauth2_token.assert_called_with( + {"id": "one", "secret": "two"}, + "XXX", + 
code_verifier=None, + ) token = db.session.query(DatabaseUserOAuth2Tokens).one() assert token.user_id == 1 @@ -689,6 +697,10 @@ def test_oauth2_permissions( "expires_in": 3600, "refresh_token": "ZZZ", } + mocker.patch( + "superset.commands.database.oauth2.KeyValueDAO.get_value", + return_value=None, + ) state: OAuth2State = { "user_id": 1, @@ -709,7 +721,11 @@ def test_oauth2_permissions( ) assert response.status_code == 200 - get_oauth2_token.assert_called_with({"id": "one", "secret": "two"}, "XXX") + get_oauth2_token.assert_called_with( + {"id": "one", "secret": "two"}, + "XXX", + code_verifier=None, + ) token = db.session.query(DatabaseUserOAuth2Tokens).one() assert token.user_id == 1 @@ -762,6 +778,10 @@ def test_oauth2_multiple_tokens( "refresh_token": "ZZZ2", }, ] + mocker.patch( + "superset.commands.database.oauth2.KeyValueDAO.get_value", + return_value=None, + ) state: OAuth2State = { "user_id": 1, diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index ccfe2b337f77..5d8c3304102f 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -889,6 +889,124 @@ def test_get_oauth2_authorization_uri_standard_params(mocker: MockerFixture) -> assert "access_type" not in query assert "include_granted_scopes" not in query + # Verify PKCE parameters are NOT included when code_verifier is not provided + assert "code_challenge" not in query + assert "code_challenge_method" not in query + + +def test_get_oauth2_authorization_uri_with_pkce(mocker: MockerFixture) -> None: + """ + Test that BaseEngineSpec.get_oauth2_authorization_uri includes PKCE parameters + when code_verifier is passed as a parameter (RFC 7636). 
+ """ + from urllib.parse import parse_qs, urlparse + + from superset.db_engine_specs.base import BaseEngineSpec + from superset.utils.oauth2 import generate_code_challenge, generate_code_verifier + + config: OAuth2ClientConfig = { + "id": "client-id", + "secret": "client-secret", + "scope": "read write", + "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/", + "authorization_request_uri": "https://oauth.example.com/authorize", + "token_request_uri": "https://oauth.example.com/token", + "request_content_type": "json", + } + + code_verifier = generate_code_verifier() + state: OAuth2State = { + "database_id": 1, + "user_id": 1, + "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/", + "tab_id": "1234", + } + + url = BaseEngineSpec.get_oauth2_authorization_uri( + config, state, code_verifier=code_verifier + ) + parsed = urlparse(url) + query = parse_qs(parsed.query) + + # Verify PKCE parameters are included (RFC 7636) + assert "code_challenge" in query + assert query["code_challenge_method"][0] == "S256" + # Verify the code_challenge matches the expected value + expected_challenge = generate_code_challenge(code_verifier) + assert query["code_challenge"][0] == expected_challenge + + +def test_get_oauth2_token_without_pkce(mocker: MockerFixture) -> None: + """ + Test that BaseEngineSpec.get_oauth2_token works without PKCE code_verifier. 
+ """ + from superset.db_engine_specs.base import BaseEngineSpec + + mocker.patch( + "flask.current_app.config", + {"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda: 30)}, + ) + mock_post = mocker.patch("superset.db_engine_specs.base.requests.post") + mock_post.return_value.json.return_value = { + "access_token": "test-access-token", # noqa: S105 + "expires_in": 3600, + } + + config: OAuth2ClientConfig = { + "id": "client-id", + "secret": "client-secret", + "scope": "read write", + "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/", + "authorization_request_uri": "https://oauth.example.com/authorize", + "token_request_uri": "https://oauth.example.com/token", + "request_content_type": "json", + } + + result = BaseEngineSpec.get_oauth2_token(config, "auth-code") + + assert result["access_token"] == "test-access-token" # noqa: S105 + # Verify code_verifier is NOT in the request body + call_kwargs = mock_post.call_args + request_body = call_kwargs.kwargs.get("json") or call_kwargs.kwargs.get("data") + assert "code_verifier" not in request_body + + +def test_get_oauth2_token_with_pkce(mocker: MockerFixture) -> None: + """ + Test BaseEngineSpec.get_oauth2_token includes code_verifier when provided. 
+ """ + from superset.db_engine_specs.base import BaseEngineSpec + from superset.utils.oauth2 import generate_code_verifier + + mocker.patch( + "flask.current_app.config", + {"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda: 30)}, + ) + mock_post = mocker.patch("superset.db_engine_specs.base.requests.post") + mock_post.return_value.json.return_value = { + "access_token": "test-access-token", # noqa: S105 + "expires_in": 3600, + } + + config: OAuth2ClientConfig = { + "id": "client-id", + "secret": "client-secret", + "scope": "read write", + "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/", + "authorization_request_uri": "https://oauth.example.com/authorize", + "token_request_uri": "https://oauth.example.com/token", + "request_content_type": "json", + } + + code_verifier = generate_code_verifier() + result = BaseEngineSpec.get_oauth2_token(config, "auth-code", code_verifier) + + assert result["access_token"] == "test-access-token" # noqa: S105 + # Verify code_verifier IS in the request body (PKCE) + call_kwargs = mock_post.call_args + request_body = call_kwargs.kwargs.get("json") or call_kwargs.kwargs.get("data") + assert request_body["code_verifier"] == code_verifier + def test_start_oauth2_dance_uses_config_redirect_uri(mocker: MockerFixture) -> None: """ @@ -904,6 +1022,8 @@ def test_start_oauth2_dance_uses_config_redirect_uri(mocker: MockerFixture) -> N "DATABASE_OAUTH2_JWT_ALGORITHM": "HS256", }, ) + mocker.patch("superset.daos.key_value.KeyValueDAO") + mocker.patch("superset.db_engine_specs.base.db") g = mocker.patch("superset.db_engine_specs.base.g") g.user.id = 1 @@ -944,6 +1064,8 @@ def test_start_oauth2_dance_falls_back_to_url_for(mocker: MockerFixture) -> None "superset.db_engine_specs.base.url_for", return_value=fallback_uri, ) + mocker.patch("superset.daos.key_value.KeyValueDAO") + mocker.patch("superset.db_engine_specs.base.db") g = mocker.patch("superset.db_engine_specs.base.g") g.user.id = 1 diff --git 
a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py index 6e221e1494b6..9fd3a0ac0e24 100644 --- a/tests/unit_tests/sql_lab_test.py +++ b/tests/unit_tests/sql_lab_test.py @@ -18,6 +18,7 @@ import json # noqa: TID251 from unittest.mock import MagicMock +from urllib.parse import parse_qs, urlparse from uuid import UUID import pytest @@ -201,6 +202,13 @@ def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None: "superset.db_engine_specs.base.uuid4", return_value=UUID("fb11f528-6eba-4a8a-837e-6b0d39ee9187"), ) + mocker.patch( + "superset.db_engine_specs.base.generate_code_verifier", + return_value="xkBPVZoFChVcy3VZ2l5u7d0FZPTU-olO7HtsAOok2IUGigyoZ62tG_oldy2xg9_HdqPKrWUmKZLmU-CUqz_SQ", + ) + mocker.patch("superset.daos.key_value.KeyValueDAO.delete_expired_entries") + mocker.patch("superset.daos.key_value.KeyValueDAO.create_entry") + mocker.patch("superset.db_engine_specs.base.db.session.commit") g = mocker.patch("superset.db_engine_specs.base.g") g.user = mocker.MagicMock() @@ -222,22 +230,39 @@ def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None: mocker.patch("superset.sql_lab.get_query", return_value=query) payload = get_sql_results(query_id=1, rendered_query="SELECT 1") - assert payload == { - "status": QueryStatus.FAILED, - "error": "You don't have permission to access the data.", - "errors": [ - { - "message": "You don't have permission to access the data.", - "error_type": SupersetErrorType.OAUTH2_REDIRECT, - "level": ErrorLevel.WARNING, - "extra": { - "url": 
"https://abcd1234.snowflakecomputing.com/oauth/authorize?scope=refresh_token+session%3Arole%3AUSERADMIN&response_type=code&state=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9%252EeyJleHAiOjE2MTcyMzU1MDAsImRhdGFiYXNlX2lkIjoxLCJ1c2VyX2lkIjo0MiwiZGVmYXVsdF9yZWRpcmVjdF91cmkiOiJodHRwOi8vbG9jYWxob3N0L2FwaS92MS9kYXRhYmFzZS9vYXV0aDIvIiwidGFiX2lkIjoiZmIxMWY1MjgtNmViYS00YThhLTgzN2UtNmIwZDM5ZWU5MTg3In0%252E7nLkei6-V8sVk_Pgm8cFhk0tnKRKayRE1Vc7RxuM9mw&redirect_uri=http%3A%2F%2Flocalhost%2Fapi%2Fv1%2Fdatabase%2Foauth2%2F&client_id=my_client_id", - "tab_id": "fb11f528-6eba-4a8a-837e-6b0d39ee9187", - "redirect_uri": "http://localhost/api/v1/database/oauth2/", - }, - } - ], - } + assert payload["status"] == QueryStatus.FAILED + assert payload["error"] == "You don't have permission to access the data." + assert len(payload["errors"]) == 1 + + error = payload["errors"][0] + assert error["message"] == "You don't have permission to access the data." + assert error["error_type"] == SupersetErrorType.OAUTH2_REDIRECT + assert error["level"] == ErrorLevel.WARNING + assert error["extra"]["tab_id"] == "fb11f528-6eba-4a8a-837e-6b0d39ee9187" + assert error["extra"]["redirect_uri"] == "http://localhost/api/v1/database/oauth2/" + + # Parse the OAuth2 authorization URL and verify components individually, + # since the JWT state and PKCE code_challenge are computed deterministically + # from mocked inputs but their exact encoding depends on library internals. 
+ url = urlparse(error["extra"]["url"]) + assert url.scheme == "https" + assert url.netloc == "abcd1234.snowflakecomputing.com" + assert url.path == "/oauth/authorize" + + params = parse_qs(url.query) + assert params["scope"] == ["refresh_token session:role:USERADMIN"] + assert params["response_type"] == ["code"] + assert params["redirect_uri"] == ["http://localhost/api/v1/database/oauth2/"] + assert params["client_id"] == ["my_client_id"] + assert params["code_challenge_method"] == ["S256"] + + # Verify PKCE code_challenge matches the mocked code_verifier + from superset.utils.oauth2 import generate_code_challenge + + expected_code_challenge = generate_code_challenge( + "xkBPVZoFChVcy3VZ2l5u7d0FZPTU-olO7HtsAOok2IUGigyoZ62tG_oldy2xg9_HdqPKrWUmKZLmU-CUqz_SQ" + ) + assert params["code_challenge"] == [expected_code_challenge] def test_apply_rls(mocker: MockerFixture) -> None: diff --git a/tests/unit_tests/utils/oauth2_tests.py b/tests/unit_tests/utils/oauth2_tests.py index fc3ed7a651d9..33a0c0c26630 100644 --- a/tests/unit_tests/utils/oauth2_tests.py +++ b/tests/unit_tests/utils/oauth2_tests.py @@ -17,6 +17,8 @@ # pylint: disable=invalid-name, disallowed-name +import base64 +import hashlib from datetime import datetime from typing import cast @@ -25,7 +27,14 @@ from pytest_mock import MockerFixture from superset.superset_typing import OAuth2ClientConfig -from superset.utils.oauth2 import get_oauth2_access_token, refresh_oauth2_token +from superset.utils.oauth2 import ( + decode_oauth2_state, + encode_oauth2_state, + generate_code_challenge, + generate_code_verifier, + get_oauth2_access_token, + refresh_oauth2_token, +) DUMMY_OAUTH2_CONFIG = cast(OAuth2ClientConfig, {}) @@ -177,3 +186,96 @@ def test_refresh_oauth2_token_no_access_token_in_response( result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token) assert result is None + + +def test_generate_code_verifier_length() -> None: + """ + Test that generate_code_verifier produces a string of valid 
length (RFC 7636). + """ + code_verifier = generate_code_verifier() + # RFC 7636 requires 43-128 characters + assert 43 <= len(code_verifier) <= 128 + + +def test_generate_code_verifier_uniqueness() -> None: + """ + Test that generate_code_verifier produces unique values. + """ + verifiers = {generate_code_verifier() for _ in range(100)} + # All generated verifiers should be unique + assert len(verifiers) == 100 + + +def test_generate_code_verifier_valid_characters() -> None: + """ + Test that generate_code_verifier only uses valid characters (RFC 7636). + """ + code_verifier = generate_code_verifier() + # RFC 7636 allows: [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~" + # URL-safe base64 uses: [A-Z] / [a-z] / [0-9] / "-" / "_" + valid_chars = set( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_" + ) + assert all(char in valid_chars for char in code_verifier) + + +def test_generate_code_challenge_s256() -> None: + """ + Test that generate_code_challenge produces correct S256 challenge. + """ + # Use a known code_verifier to verify the challenge computation + code_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + + # Compute expected challenge manually + digest = hashlib.sha256(code_verifier.encode("ascii")).digest() + expected_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") + + code_challenge = generate_code_challenge(code_verifier) + assert code_challenge == expected_challenge + + +def test_generate_code_challenge_rfc_example() -> None: + """ + Test PKCE code challenge against RFC 7636 Appendix B example. 
+ + See: https://datatracker.ietf.org/doc/html/rfc7636#appendix-B + """ + # RFC 7636 example code_verifier (Appendix B) + code_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + # RFC 7636 expected code_challenge for S256 method + expected_challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM" + + code_challenge = generate_code_challenge(code_verifier) + assert code_challenge == expected_challenge + + +def test_encode_decode_oauth2_state( + mocker: MockerFixture, +) -> None: + """ + Test that encode/decode cycle preserves state fields. + """ + from superset.superset_typing import OAuth2State + + mocker.patch( + "flask.current_app.config", + { + "SECRET_KEY": "test-secret-key", + "DATABASE_OAUTH2_JWT_ALGORITHM": "HS256", + }, + ) + + state: OAuth2State = { + "database_id": 1, + "user_id": 2, + "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/", + "tab_id": "test-tab-id", + } + + with freeze_time("2024-01-01"): + encoded = encode_oauth2_state(state) + decoded = decode_oauth2_state(encoded) + + assert "code_verifier" not in decoded + assert decoded["database_id"] == 1 + assert decoded["user_id"] == 2