Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions superset/commands/database/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@
from datetime import datetime, timedelta
from functools import partial
from typing import cast
from uuid import UUID

from superset.commands.base import BaseCommand
from superset.commands.database.exceptions import DatabaseNotFoundError
from superset.daos.database import DatabaseUserOAuth2TokensDAO
from superset.daos.key_value import KeyValueDAO
from superset.databases.schemas import OAuth2ProviderResponseSchema
from superset.exceptions import OAuth2Error
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
from superset.models.core import Database, DatabaseUserOAuth2Tokens
from superset.superset_typing import OAuth2State
from superset.utils.decorators import on_error, transaction
Expand All @@ -50,9 +53,28 @@ def run(self) -> DatabaseUserOAuth2Tokens:
if oauth2_config is None:
raise OAuth2Error("No configuration found for OAuth2")

# Look up PKCE code_verifier from KV store (RFC 7636)
code_verifier = None
tab_id = self._state["tab_id"]
try:
tab_uuid = UUID(tab_id)
except ValueError:
tab_uuid = None

if tab_uuid:
kv_value = KeyValueDAO.get_value(
resource=KeyValueResource.PKCE_CODE_VERIFIER,
key=tab_uuid,
codec=JsonKeyValueCodec(),
)
if kv_value:
code_verifier = kv_value.get("code_verifier")
KeyValueDAO.delete_entry(KeyValueResource.PKCE_CODE_VERIFIER, tab_uuid)

token_response = self._database.db_engine_spec.get_oauth2_token(
oauth2_config,
self._parameters["code"],
code_verifier=code_verifier,
)

# delete old tokens
Expand Down
66 changes: 58 additions & 8 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import logging
import re
import warnings
from datetime import datetime
from datetime import datetime, timedelta
from inspect import signature
from re import Match, Pattern
from typing import (
Expand All @@ -36,7 +36,7 @@
Union,
)
from urllib.parse import urlencode, urljoin
from uuid import uuid4
from uuid import UUID, uuid4

import pandas as pd
import requests
Expand All @@ -63,6 +63,7 @@
from superset.databases.utils import get_table_metadata, make_url_safe
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
from superset.sql.parse import (
BaseSQLStatement,
LimitMethod,
Expand All @@ -83,7 +84,11 @@
from superset.utils.hashing import hash_from_str
from superset.utils.json import redact_sensitive, reveal_sensitive
from superset.utils.network import is_hostname_valid, is_port_open
from superset.utils.oauth2 import encode_oauth2_state
from superset.utils.oauth2 import (
encode_oauth2_state,
generate_code_challenge,
generate_code_verifier,
)

if TYPE_CHECKING:
from superset.connectors.sqla.models import TableColumn
Expand Down Expand Up @@ -608,13 +613,38 @@ def start_oauth2_dance(cls, database: Database) -> None:
tab sends a message to the original tab informing that authorization was
successful (or not), and then closes. The original tab will automatically
re-run the query after authorization.

PKCE (RFC 7636) is used to protect against authorization code interception
attacks. A code_verifier is generated and stored server-side in the KV store,
while the code_challenge (derived from the verifier) is sent to the
authorization server.
"""
# Prevent circular import.
from superset.daos.key_value import KeyValueDAO

tab_id = str(uuid4())
default_redirect_uri = app.config.get(
"DATABASE_OAUTH2_REDIRECT_URI",
url_for("DatabaseRestApi.oauth2", _external=True),
)

# Generate PKCE code verifier (RFC 7636)
code_verifier = generate_code_verifier()

# Store the code_verifier server-side in the KV store, keyed by tab_id.
# This avoids exposing it in the URL/browser history via the JWT state.
KeyValueDAO.delete_expired_entries(KeyValueResource.PKCE_CODE_VERIFIER)
KeyValueDAO.create_entry(
resource=KeyValueResource.PKCE_CODE_VERIFIER,
value={"code_verifier": code_verifier},
codec=JsonKeyValueCodec(),
key=UUID(tab_id),
expires_on=datetime.now() + timedelta(minutes=5),
)
# We need to commit here because we're going to raise an exception, which will
# revert any uncommitted changes.
db.session.commit()

# The state is passed to the OAuth2 provider, and sent back to Superset after
# the user authorizes the access. The redirect endpoint in Superset can then
# inspect the state to figure out to which user/database the access token
Expand All @@ -641,7 +671,11 @@ def start_oauth2_dance(cls, database: Database) -> None:
if oauth2_config is None:
raise OAuth2Error("No configuration found for OAuth2")

oauth_url = cls.get_oauth2_authorization_uri(oauth2_config, state)
oauth_url = cls.get_oauth2_authorization_uri(
oauth2_config,
state,
code_verifier=code_verifier,
)

raise OAuth2RedirectError(oauth_url, tab_id, default_redirect_uri)

Expand Down Expand Up @@ -685,41 +719,57 @@ def get_oauth2_authorization_uri(
cls,
config: OAuth2ClientConfig,
state: OAuth2State,
code_verifier: str | None = None,
) -> str:
"""
Return URI for initial OAuth2 request.

Uses standard OAuth 2.0 parameters only. Subclasses can override
to add provider-specific parameters (e.g., Google's prompt=consent).
Uses standard OAuth 2.0 parameters plus PKCE (RFC 7636) parameters.
Subclasses can override to add provider-specific parameters
(e.g., Google's prompt=consent).
"""
uri = config["authorization_request_uri"]
params = {
params: dict[str, str] = {
"scope": config["scope"],
"response_type": "code",
"state": encode_oauth2_state(state),
"redirect_uri": config["redirect_uri"],
"client_id": config["id"],
}

# Add PKCE parameters (RFC 7636) if code_verifier is provided
if code_verifier:
params["code_challenge"] = generate_code_challenge(code_verifier)
params["code_challenge_method"] = "S256"

return urljoin(uri, "?" + urlencode(params))

@classmethod
def get_oauth2_token(
cls,
config: OAuth2ClientConfig,
code: str,
code_verifier: str | None = None,
) -> OAuth2TokenResponse:
"""
Exchange authorization code for refresh/access tokens.

If code_verifier is provided (PKCE flow), it will be included in the
token request per RFC 7636.
"""
timeout = app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
uri = config["token_request_uri"]
req_body = {
req_body: dict[str, str] = {
"code": code,
"client_id": config["id"],
"client_secret": config["secret"],
"redirect_uri": config["redirect_uri"],
"grant_type": "authorization_code",
}
# Add PKCE code_verifier if present (RFC 7636)
if code_verifier:
req_body["code_verifier"] = code_verifier

response = (
requests.post(uri, data=req_body, timeout=timeout)
if config["request_content_type"] == "data"
Expand Down
11 changes: 9 additions & 2 deletions superset/db_engine_specs/gsheets.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def get_oauth2_authorization_uri(
cls,
config: "OAuth2ClientConfig",
state: "OAuth2State",
code_verifier: str | None = None,
) -> str:
"""
Return URI for initial OAuth2 request with Google-specific parameters.
Expand All @@ -172,10 +173,10 @@ def get_oauth2_authorization_uri(
"""
from urllib.parse import urlencode, urljoin

from superset.utils.oauth2 import encode_oauth2_state
from superset.utils.oauth2 import encode_oauth2_state, generate_code_challenge

uri = config["authorization_request_uri"]
params = {
params: dict[str, str] = {
"scope": config["scope"],
"response_type": "code",
"state": encode_oauth2_state(state),
Expand All @@ -186,6 +187,12 @@ def get_oauth2_authorization_uri(
"include_granted_scopes": "false",
"prompt": "consent",
}

# Add PKCE parameters (RFC 7636) if code_verifier is provided
if code_verifier:
params["code_challenge"] = generate_code_challenge(code_verifier)
params["code_challenge_method"] = "S256"

return urljoin(uri, "?" + urlencode(params))

@classmethod
Expand Down
1 change: 1 addition & 0 deletions superset/key_value/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class KeyValueResource(StrEnum):
EXPLORE_PERMALINK = "explore_permalink"
METASTORE_CACHE = "superset_metastore_cache"
LOCK = "lock"
PKCE_CODE_VERIFIER = "pkce_code_verifier"
SQLLAB_PERMALINK = "sqllab_permalink"


Expand Down
2 changes: 1 addition & 1 deletion superset/superset_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ class OAuth2TokenResponse(TypedDict, total=False):
refresh_token: str


class OAuth2State(TypedDict):
class OAuth2State(TypedDict, total=False):
"""
Type for the state passed during OAuth2.
"""
Expand Down
49 changes: 42 additions & 7 deletions superset/utils/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

from __future__ import annotations

import base64
import hashlib
import logging
import secrets
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from typing import Any, Iterator, TYPE_CHECKING
Expand All @@ -40,6 +43,37 @@

logger = logging.getLogger(__name__)

# Number of random *bytes* used to build the PKCE code verifier. The verifier
# itself is the unpadded URL-safe base64 encoding of those bytes, so 64 bytes
# yield an 86-character string, which falls inside the 43-128 character range
# that RFC 7636 requires for a code verifier.
PKCE_CODE_VERIFIER_LENGTH = 64


def generate_code_verifier() -> str:
    """
    Generate a PKCE code verifier (RFC 7636).

    The verifier is built from ``PKCE_CODE_VERIFIER_LENGTH`` cryptographically
    secure random bytes, encoded as URL-safe base64 with the ``=`` padding
    removed. The result therefore only contains characters from
    [A-Z] / [a-z] / [0-9] / "-" / "_", a subset of the unreserved characters
    allowed by the RFC, and is 86 characters long for the default of 64 bytes.

    :returns: a fresh, random code verifier string
    """
    # ``secrets.token_urlsafe(n)`` is exactly "n random bytes -> URL-safe
    # base64 -> strip padding -> ASCII text", so there is no need to spell
    # those steps out by hand.
    return secrets.token_urlsafe(PKCE_CODE_VERIFIER_LENGTH)


def generate_code_challenge(code_verifier: str) -> str:
"""
Generate a PKCE code challenge from a code verifier (RFC 7636).

Uses the S256 method: BASE64URL(SHA256(code_verifier))
"""
# Compute SHA-256 hash of the code verifier
digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
# Encode as URL-safe base64 without padding
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
return code_challenge


@backoff.on_exception(
backoff.expo,
Expand Down Expand Up @@ -140,13 +174,14 @@ def encode_oauth2_state(state: OAuth2State) -> str:
"""
Encode the OAuth2 state.
"""
payload = {
payload: dict[str, Any] = {
"exp": datetime.now(tz=timezone.utc) + JWT_EXPIRATION,
"database_id": state["database_id"],
"user_id": state["user_id"],
"default_redirect_uri": state["default_redirect_uri"],
"tab_id": state["tab_id"],
}

encoded_state = jwt.encode(
payload=payload,
key=app.config["SECRET_KEY"],
Expand All @@ -172,12 +207,12 @@ def make_oauth2_state(
data: dict[str, Any],
**kwargs: Any,
) -> OAuth2State:
return OAuth2State(
database_id=data["database_id"],
user_id=data["user_id"],
default_redirect_uri=data["default_redirect_uri"],
tab_id=data["tab_id"],
)
return {
"database_id": data["database_id"],
"user_id": data["user_id"],
"default_redirect_uri": data["default_redirect_uri"],
"tab_id": data["tab_id"],
}

class Meta: # pylint: disable=too-few-public-methods
# ignore `exp`
Expand Down
Loading
Loading