diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 825304db836d..bc34e39bc24c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,6 +27,7 @@ repos: args: [--check-untyped-defs] exclude: ^superset-extensions-cli/ additional_dependencies: [ + types-cachetools, types-simplejson, types-python-dateutil, types-requests, diff --git a/docs/static/feature-flags.json b/docs/static/feature-flags.json index bfe0955ea89a..95738dd913fc 100644 --- a/docs/static/feature-flags.json +++ b/docs/static/feature-flags.json @@ -114,6 +114,12 @@ "lifecycle": "testing", "description": "Allow users to export full CSV of table viz type. Warning: Could cause server memory/compute issues with large datasets." }, + { + "name": "AWS_DATABASE_IAM_AUTH", + "default": false, + "lifecycle": "testing", + "description": "Enable AWS IAM authentication for database connections (Aurora, Redshift). Allows cross-account role assumption via STS AssumeRole. Security note: When enabled, ensure Superset's IAM role has restricted sts:AssumeRole permissions to prevent unauthorized access." + }, { "name": "CACHE_IMPERSONATION", "default": false, diff --git a/pyproject.toml b/pyproject.toml index fc37dbe89c16..87496d5e9726 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -204,6 +204,7 @@ ydb = ["ydb-sqlalchemy>=0.1.2"] development = [ # no bounds for apache-superset-extensions-cli until a stable version "apache-superset-extensions-cli", + "boto3", "docker", "flask-testing", "freezegun", diff --git a/requirements/development.txt b/requirements/development.txt index c91b6a3646ec..d26c1c78b91b 100644 --- a/requirements/development.txt +++ b/requirements/development.txt @@ -76,6 +76,12 @@ blinker==1.9.0 # via # -c requirements/base-constraint.txt # flask +boto3==1.42.39 + # via apache-superset +botocore==1.42.39 + # via + # boto3 + # s3transfer bottleneck==1.5.0 # via # -c requirements/base-constraint.txt @@ -460,6 +466,10 @@ jinja2==3.1.6 # apache-superset-extensions-cli # flask # flask-babel +jmespath==1.1.0 + # via + # boto3 + # botocore jsonpath-ng==1.7.0 # via # -c requirements/base-constraint.txt @@ -812,6 +822,7 @@ python-dateutil==2.9.0.post0 # via # -c requirements/base-constraint.txt # apache-superset + # botocore # celery # croniter # flask-appbuilder @@ -915,6 +926,8 @@ rsa==4.9.1 # google-auth ruff==0.9.7 # via apache-superset +s3transfer==0.16.0 + # via boto3 secretstorage==3.5.0 # via keyring selenium==4.32.0 @@ -1066,6 +1079,7 @@ url-normalize==2.2.1 urllib3==2.6.3 # via # -c requirements/base-constraint.txt + # botocore # docker # requests # requests-cache diff --git a/superset-frontend/packages/superset-ui-core/src/utils/index.ts b/superset-frontend/packages/superset-ui-core/src/utils/index.ts index 426af7e9f3fb..4d6e869cd0ca 100644 --- a/superset-frontend/packages/superset-ui-core/src/utils/index.ts +++ b/superset-frontend/packages/superset-ui-core/src/utils/index.ts @@ -25,6 +25,7 @@ export { default as isEqualArray } from './isEqualArray'; export { default as makeSingleton } from './makeSingleton'; export { default as promiseTimeout } from './promiseTimeout'; export { default as removeDuplicates } from './removeDuplicates'; +export { default as withLabel } from './withLabel'; export { lruCache } from './lruCache'; export { getSelectedText } from './getSelectedText'; export * from './featureFlags'; diff --git a/superset-frontend/packages/superset-ui-core/src/utils/withLabel.ts b/superset-frontend/packages/superset-ui-core/src/utils/withLabel.ts new file mode 100644 index 
000000000000..59da9612b3d4 --- /dev/null +++ b/superset-frontend/packages/superset-ui-core/src/utils/withLabel.ts @@ -0,0 +1,43 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import type { ValidatorFunction } from '../validator'; + +/** + * Wraps a validator function to prepend a label to its error message. + * + * @param validator - The validator function to wrap + * @param label - The label to prepend to error messages + * @returns A new validator function that includes the label in error messages + * + * @example + * validators: [ + * withLabel(validateInteger, t('Row limit')), + * ] + * // Returns: "Row limit is expected to be an integer" + */ +export default function withLabel<V = unknown, S = unknown>( + validator: ValidatorFunction<V, S>, + label: string, +): ValidatorFunction<V, S> { + return (value: V, state?: S): string | false => { + const error = validator(value, state); + return error ? `${label} ${error}` : false; + }; +} diff --git a/superset-frontend/packages/superset-ui-core/src/validator/index.ts b/superset-frontend/packages/superset-ui-core/src/validator/index.ts index df8dc10a70f9..1bc68d2929e9 100644 --- a/superset-frontend/packages/superset-ui-core/src/validator/index.ts +++ b/superset-frontend/packages/superset-ui-core/src/validator/index.ts @@ -17,6 +17,7 @@ * under the License.
*/ +export * from './types'; export { default as legacyValidateInteger } from './legacyValidateInteger'; export { default as legacyValidateNumber } from './legacyValidateNumber'; export { default as validateInteger } from './validateInteger'; diff --git a/superset-frontend/packages/superset-ui-core/src/validator/legacyValidateInteger.ts b/superset-frontend/packages/superset-ui-core/src/validator/legacyValidateInteger.ts index 972fdf855dac..e2b3303506b3 100644 --- a/superset-frontend/packages/superset-ui-core/src/validator/legacyValidateInteger.ts +++ b/superset-frontend/packages/superset-ui-core/src/validator/legacyValidateInteger.ts @@ -23,7 +23,7 @@ import { t } from '@apache-superset/core'; * formerly called integer() * @param v */ -export default function legacyValidateInteger(v: unknown) { +export default function legacyValidateInteger(v: unknown): string | false { if ( v && (Number.isNaN(Number(v)) || parseInt(v as string, 10) !== Number(v)) diff --git a/superset-frontend/packages/superset-ui-core/src/validator/legacyValidateNumber.ts b/superset-frontend/packages/superset-ui-core/src/validator/legacyValidateNumber.ts index d6a7b337f0d0..e075500c76d6 100644 --- a/superset-frontend/packages/superset-ui-core/src/validator/legacyValidateNumber.ts +++ b/superset-frontend/packages/superset-ui-core/src/validator/legacyValidateNumber.ts @@ -23,7 +23,7 @@ import { t } from '@apache-superset/core'; * formerly called numeric() * @param v */ -export default function numeric(v: unknown) { +export default function numeric(v: unknown): string | false { if (v && Number.isNaN(Number(v))) { return t('is expected to be a number'); } diff --git a/superset-frontend/packages/superset-ui-core/src/validator/types.ts b/superset-frontend/packages/superset-ui-core/src/validator/types.ts new file mode 100644 index 000000000000..3313af30fbe8 --- /dev/null +++ b/superset-frontend/packages/superset-ui-core/src/validator/types.ts @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * Type definition for a validator function. + * Returns an error message string if validation fails, or false if validation passes. 
+ */ +export type ValidatorFunction<V = unknown, S = unknown> = ( + value: V, + state?: S, +) => string | false; diff --git a/superset-frontend/packages/superset-ui-core/src/validator/validateInteger.ts b/superset-frontend/packages/superset-ui-core/src/validator/validateInteger.ts index bea18dca9d78..4aa4af1d21bd 100644 --- a/superset-frontend/packages/superset-ui-core/src/validator/validateInteger.ts +++ b/superset-frontend/packages/superset-ui-core/src/validator/validateInteger.ts @@ -19,7 +19,7 @@ import { t } from '@apache-superset/core'; -export default function validateInteger(v: unknown) { +export default function validateInteger(v: unknown): string | false { if ( (typeof v === 'string' && v.trim().length > 0 && diff --git a/superset-frontend/packages/superset-ui-core/src/validator/validateMapboxStylesUrl.ts b/superset-frontend/packages/superset-ui-core/src/validator/validateMapboxStylesUrl.ts index facbc149aefc..66c474d5ad49 100644 --- a/superset-frontend/packages/superset-ui-core/src/validator/validateMapboxStylesUrl.ts +++ b/superset-frontend/packages/superset-ui-core/src/validator/validateMapboxStylesUrl.ts @@ -25,7 +25,7 @@ const VALIDE_OSM_URLS = ['https://tile.osm', 'https://tile.openstreetmap']; * Validate a [Mapbox styles URL](https://docs.mapbox.com/help/glossary/style-url/) * @param v */ -export default function validateMapboxStylesUrl(v: unknown) { +export default function validateMapboxStylesUrl(v: unknown): string | false { if (typeof v === 'string') { const trimmed_v = v.trim(); if ( diff --git a/superset-frontend/packages/superset-ui-core/src/validator/validateMaxValue.ts b/superset-frontend/packages/superset-ui-core/src/validator/validateMaxValue.ts index bb7e6c052b3a..353e14315c64 100644 --- a/superset-frontend/packages/superset-ui-core/src/validator/validateMaxValue.ts +++ b/superset-frontend/packages/superset-ui-core/src/validator/validateMaxValue.ts @@ -18,7 +18,10 @@ */ import { t } from '@apache-superset/core'; -export default function validateMaxValue(v: unknown, max: number) { +export default function validateMaxValue( + v: unknown, + max: number, +): string | false { if (Number(v) > +max) { return t('Value cannot exceed %s', max); } diff --git a/superset-frontend/packages/superset-ui-core/src/validator/validateNonEmpty.ts b/superset-frontend/packages/superset-ui-core/src/validator/validateNonEmpty.ts index 835c433fe2e6..1d8a525c631f 100644 --- a/superset-frontend/packages/superset-ui-core/src/validator/validateNonEmpty.ts +++ b/superset-frontend/packages/superset-ui-core/src/validator/validateNonEmpty.ts @@ -19,7 +19,7 @@ import { t } from '@apache-superset/core'; -export default function validateNonEmpty(v: unknown) { +export default function validateNonEmpty(v: unknown): string | false { if ( v === null || typeof v === 'undefined' || diff --git a/superset-frontend/packages/superset-ui-core/src/validator/validateNumber.ts b/superset-frontend/packages/superset-ui-core/src/validator/validateNumber.ts index ce8db32cd28b..524d27d4b581 100644 --- a/superset-frontend/packages/superset-ui-core/src/validator/validateNumber.ts +++ b/superset-frontend/packages/superset-ui-core/src/validator/validateNumber.ts @@ -19,7 +19,7 @@ import { t } from '@apache-superset/core'; -export default function validateInteger(v: any) { +export default function validateNumber(v: unknown): string | false { if ( (typeof v === 'string' && v.trim().length > 0 && diff --git a/superset-frontend/packages/superset-ui-core/src/validator/validateServerPagination.ts
b/superset-frontend/packages/superset-ui-core/src/validator/validateServerPagination.ts index 1907a8198c5f..4fde0b12aa93 100644 --- a/superset-frontend/packages/superset-ui-core/src/validator/validateServerPagination.ts +++ b/superset-frontend/packages/superset-ui-core/src/validator/validateServerPagination.ts @@ -23,7 +23,7 @@ export default function validateServerPagination( serverPagination: boolean, maxValueWithoutServerPagination: number, maxServer: number, -) { +): string | false { if ( Number(v) > +maxValueWithoutServerPagination && Number(v) <= maxServer && diff --git a/superset-frontend/packages/superset-ui-core/src/validator/validateTimeComparisonRangeValues.ts b/superset-frontend/packages/superset-ui-core/src/validator/validateTimeComparisonRangeValues.ts index b362757db12b..2c8ccc5cb980 100644 --- a/superset-frontend/packages/superset-ui-core/src/validator/validateTimeComparisonRangeValues.ts +++ b/superset-frontend/packages/superset-ui-core/src/validator/validateTimeComparisonRangeValues.ts @@ -22,13 +22,13 @@ import { t } from '@apache-superset/core'; import { ensureIsArray } from '../utils'; export const validateTimeComparisonRangeValues = ( - timeRangeValue?: any, - controlValue?: any, -) => { + timeRangeValue?: unknown, + controlValue?: unknown, +): string[] => { const isCustomTimeRange = timeRangeValue === ComparisonTimeRangeType.Custom; - const isCustomControlEmpty = controlValue?.every( - (val: any) => ensureIsArray(val).length === 0, - ); + const isCustomControlEmpty = + Array.isArray(controlValue) && + controlValue.every((val: unknown) => ensureIsArray(val).length === 0); return isCustomTimeRange && isCustomControlEmpty ? [t('Filters for comparison must have a value')] : []; diff --git a/superset-frontend/packages/superset-ui-core/test/validator/validateMaxValue.test.ts b/superset-frontend/packages/superset-ui-core/test/validator/validateMaxValue.test.ts index 6a8ed1642e7b..3a6698f06a82 100644 --- a/superset-frontend/packages/superset-ui-core/test/validator/validateMaxValue.test.ts +++ b/superset-frontend/packages/superset-ui-core/test/validator/validateMaxValue.test.ts @@ -20,13 +20,13 @@ import { validateMaxValue } from '@superset-ui/core'; import './setup'; -test('validateInteger returns the warning message if invalid', () => { +test('validateMaxValue returns the warning message if invalid', () => { expect(validateMaxValue(10.1, 10)).toBeTruthy(); expect(validateMaxValue(1, 0)).toBeTruthy(); expect(validateMaxValue('2', 1)).toBeTruthy(); }); -test('validateInteger returns false if the input is valid', () => { +test('validateMaxValue returns false if the input is valid', () => { expect(validateMaxValue(0, 1)).toBeFalsy(); expect(validateMaxValue(10, 10)).toBeFalsy(); expect(validateMaxValue(undefined, 1)).toBeFalsy(); diff --git a/superset-frontend/plugins/plugin-chart-ag-grid-table/src/controlPanel.tsx b/superset-frontend/plugins/plugin-chart-ag-grid-table/src/controlPanel.tsx index a098674c65c4..fd549558d06f 100644 --- a/superset-frontend/plugins/plugin-chart-ag-grid-table/src/controlPanel.tsx +++ b/superset-frontend/plugins/plugin-chart-ag-grid-table/src/controlPanel.tsx @@ -51,6 +51,7 @@ import { SMART_DATE_ID, validateMaxValue, validateServerPagination, + withLabel, } from '@superset-ui/core'; import { GenericDataType } from '@apache-superset/core/api/core'; import { isEmpty, last } from 'lodash'; @@ -384,7 +385,7 @@ const config: ControlPanelConfig = { description: t('Rows per page, 0 means no pagination'), visibility: ({ controls }: 
ControlPanelsContainerProps) => Boolean(controls?.server_pagination?.value), - validators: [validateInteger], + validators: [withLabel(validateInteger, t('Server Page Length'))], }, }, ], @@ -403,7 +404,7 @@ const config: ControlPanelConfig = { state?.common?.conf?.SQL_MAX_ROW, }), validators: [ - validateInteger, + withLabel(validateInteger, t('Row limit')), (v, state) => validateMaxValue( v, diff --git a/superset-frontend/plugins/plugin-chart-echarts/src/Histogram/controlPanel.tsx b/superset-frontend/plugins/plugin-chart-echarts/src/Histogram/controlPanel.tsx index 978b1427d6d9..839a2af354f7 100644 --- a/superset-frontend/plugins/plugin-chart-echarts/src/Histogram/controlPanel.tsx +++ b/superset-frontend/plugins/plugin-chart-echarts/src/Histogram/controlPanel.tsx @@ -17,7 +17,11 @@ * under the License. */ import { t } from '@apache-superset/core'; -import { validateInteger, validateNonEmpty } from '@superset-ui/core'; +import { + validateInteger, + validateNonEmpty, + withLabel, +} from '@superset-ui/core'; import { GenericDataType } from '@apache-superset/core/api/core'; import { ControlPanelConfig, @@ -66,7 +70,7 @@ const config: ControlPanelConfig = { default: 5, choices: formatSelectOptionsForRange(5, 20, 5), description: t('The number of bins for the histogram'), - validators: [validateInteger], + validators: [withLabel(validateInteger, t('Bins'))], }, }, ], diff --git a/superset-frontend/plugins/plugin-chart-table/src/controlPanel.tsx b/superset-frontend/plugins/plugin-chart-table/src/controlPanel.tsx index 274ed5de69d4..aed287d66e8b 100644 --- a/superset-frontend/plugins/plugin-chart-table/src/controlPanel.tsx +++ b/superset-frontend/plugins/plugin-chart-table/src/controlPanel.tsx @@ -52,6 +52,7 @@ import { SMART_DATE_ID, validateMaxValue, validateServerPagination, + withLabel, } from '@superset-ui/core'; import { GenericDataType } from '@apache-superset/core/api/core'; import { isEmpty, last } from 'lodash'; @@ -407,7 +408,7 @@ const config: ControlPanelConfig = { description: t('Rows per page, 0 means no pagination'), visibility: ({ controls }: ControlPanelsContainerProps) => Boolean(controls?.server_pagination?.value), - validators: [validateInteger], + validators: [withLabel(validateInteger, t('Server Page Length'))], }, }, ], @@ -426,7 +427,7 @@ const config: ControlPanelConfig = { state?.common?.conf?.SQL_MAX_ROW, }), validators: [ - validateInteger, + withLabel(validateInteger, t('Row limit')), (v, state) => validateMaxValue( v, @@ -448,9 +449,6 @@ const config: ControlPanelConfig = { 'Limits the number of the rows that are computed in the query that is the source of the data used for this chart.', ), }, - override: { - default: 1000, - }, }, ], [ diff --git a/superset-frontend/src/explore/controlUtils/getControlState.ts b/superset-frontend/src/explore/controlUtils/getControlState.ts index 4a9e139ec4d6..a50d39313c89 100644 --- a/superset-frontend/src/explore/controlUtils/getControlState.ts +++ b/superset-frontend/src/explore/controlUtils/getControlState.ts @@ -122,7 +122,8 @@ export function applyMapStateToPropsToControl( } } // If no current value, set it as default - if (state.default && value === undefined) { + // Use loose equality to catch both null and undefined + if (state.default != null && value == null) { value = state.default; } // If a choice control went from multi=false to true, wrap value in array diff --git a/superset/config.py b/superset/config.py index 5e66fe40f48e..b048fc2f8d54 100644 --- a/superset/config.py +++ b/superset/config.py @@ -656,6 +656,12 @@ 
class D3TimeFormat(TypedDict, total=False): # @lifecycle: testing # @docs: https://superset.apache.org/docs/configuration/setup-ssh-tunneling "SSH_TUNNELING": False, + # Enable AWS IAM authentication for database connections (Aurora, Redshift). + # Allows cross-account role assumption via STS AssumeRole. + # Security note: When enabled, ensure Superset's IAM role has restricted + # sts:AssumeRole permissions to prevent unauthorized access. + # @lifecycle: testing + "AWS_DATABASE_IAM_AUTH": False, # Use analogous colors in charts # @lifecycle: testing "USE_ANALOGOUS_COLORS": False, diff --git a/superset/db_engine_specs/aurora.py b/superset/db_engine_specs/aurora.py index 6dcbe6e1c0fd..bac6274f271c 100644 --- a/superset/db_engine_specs/aurora.py +++ b/superset/db_engine_specs/aurora.py @@ -54,3 +54,29 @@ class AuroraPostgresDataAPI(PostgresEngineSpec): "secret_arn={secret_arn}&" "region_name={region_name}" ) + + +class AuroraMySQLEngineSpec(MySQLEngineSpec): + """ + Aurora MySQL engine spec. + + IAM authentication is handled by the parent MySQLEngineSpec via + the aws_iam config in encrypted_extra. + """ + + engine = "mysql" + engine_name = "Aurora MySQL" + default_driver = "mysqldb" + + +class AuroraPostgresEngineSpec(PostgresEngineSpec): + """ + Aurora PostgreSQL engine spec. + + IAM authentication is handled by the parent PostgresEngineSpec via + the aws_iam config in encrypted_extra. + """ + + engine = "postgresql" + engine_name = "Aurora PostgreSQL" + default_driver = "psycopg2" diff --git a/superset/db_engine_specs/aws_iam.py b/superset/db_engine_specs/aws_iam.py new file mode 100644 index 000000000000..ce2960c3d283 --- /dev/null +++ b/superset/db_engine_specs/aws_iam.py @@ -0,0 +1,660 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +AWS IAM Authentication Mixin for database engine specs. + +This mixin provides cross-account IAM authentication support for AWS databases +(Aurora PostgreSQL, Aurora MySQL, Redshift). 
It handles: +- Assuming IAM roles via STS AssumeRole +- Generating RDS IAM auth tokens +- Generating Redshift Serverless credentials +- Configuring SSL (required for IAM auth) +- Caching STS credentials to reduce API calls +""" + +from __future__ import annotations + +import logging +import threading +from typing import Any, TYPE_CHECKING, TypedDict + +from cachetools import TTLCache + +from superset.databases.utils import make_url_safe +from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.exceptions import SupersetSecurityException + +if TYPE_CHECKING: + from superset.models.core import Database + +logger = logging.getLogger(__name__) + +# Default session duration for STS AssumeRole (1 hour) +DEFAULT_SESSION_DURATION = 3600 + +# Default ports +DEFAULT_POSTGRES_PORT = 5432 +DEFAULT_MYSQL_PORT = 3306 +DEFAULT_REDSHIFT_PORT = 5439 + +# Cache STS credentials: key = (role_arn, region, external_id), TTL = 10 min +# Using a TTL shorter than the minimum supported session duration (900s) avoids +# reusing expired STS credentials when a short session_duration is configured. +_credentials_cache: TTLCache[tuple[str, str, str | None], dict[str, Any]] = TTLCache( + maxsize=100, ttl=600 +) +_credentials_lock = threading.RLock() + + +class AWSIAMConfig(TypedDict, total=False): + """Configuration for AWS IAM authentication.""" + + enabled: bool + role_arn: str + external_id: str + region: str + db_username: str + session_duration: int + # Redshift Serverless fields + workgroup_name: str + db_name: str + # Redshift provisioned cluster fields + cluster_identifier: str + + +class AWSIAMAuthMixin: + """ + Mixin that provides AWS IAM authentication for database connections. + + This mixin can be used with database engine specs that support IAM + authentication (Aurora PostgreSQL, Aurora MySQL, Redshift). + + Configuration is provided via the database's encrypted_extra JSON: + + { + "aws_iam": { + "enabled": true, + "role_arn": "arn:aws:iam::222222222222:role/SupersetDatabaseAccess", + "external_id": "superset-prod-12345", # optional + "region": "us-east-1", + "db_username": "superset_iam_user", + "session_duration": 3600 # optional, defaults to 3600 + } + } + """ + + # AWS error patterns for actionable error messages + aws_iam_custom_errors: dict[str, tuple[SupersetErrorType, str]] = { + "AccessDenied": ( + SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, + "Unable to assume IAM role. Verify the role ARN and trust policy " + "allow access from Superset's IAM role.", + ), + "InvalidIdentityToken": ( + SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, + "Invalid IAM credentials. Ensure Superset has a valid IAM role " + "with permissions to assume the target role.", + ), + "MalformedPolicyDocument": ( + SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR, + "Invalid IAM role ARN format. Please verify the role ARN.", + ), + "ExpiredTokenException": ( + SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, + "AWS credentials have expired. Please refresh the connection.", + ), + } + + @classmethod + def get_iam_credentials( + cls, + role_arn: str, + region: str, + external_id: str | None = None, + session_duration: int = DEFAULT_SESSION_DURATION, + ) -> dict[str, Any]: + """ + Assume cross-account IAM role via STS AssumeRole with credential caching. + + Credentials are cached by (role_arn, region, external_id) with a 10-minute + TTL to reduce STS API calls while ensuring tokens are refreshed before they + expire, even for the minimum 15-minute session duration.
+ + :param role_arn: The ARN of the IAM role to assume + :param region: AWS region for the STS client + :param external_id: External ID for the role assumption (optional) + :param session_duration: Duration of the session in seconds + :returns: Dictionary with AccessKeyId, SecretAccessKey, SessionToken + :raises SupersetSecurityException: If role assumption fails + """ + cache_key = (role_arn, region, external_id) + + with _credentials_lock: + cached = _credentials_cache.get(cache_key) + if cached is not None: + return cached + + try: + # Lazy import to avoid errors when boto3 is not installed + import boto3 + from botocore.exceptions import ClientError + except ImportError as ex: + raise SupersetSecurityException( + SupersetError( + message="boto3 is required for AWS IAM authentication. " + "Install it with: pip install boto3", + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + level=ErrorLevel.ERROR, + ) + ) from ex + + try: + sts_client = boto3.client("sts", region_name=region) + + assume_role_kwargs: dict[str, Any] = { + "RoleArn": role_arn, + "RoleSessionName": "superset-iam-session", + "DurationSeconds": session_duration, + } + if external_id: + assume_role_kwargs["ExternalId"] = external_id + + response = sts_client.assume_role(**assume_role_kwargs) + credentials = response["Credentials"] + + with _credentials_lock: + _credentials_cache[cache_key] = credentials + + return credentials + + except ClientError as ex: + error_code = ex.response.get("Error", {}).get("Code", "") + error_message = ex.response.get("Error", {}).get("Message", "") + + # Handle ExternalId mismatch (shows as AccessDenied with specific message) + # Check this first before generic AccessDenied handling + if "external id" in error_message.lower(): + raise SupersetSecurityException( + SupersetError( + message="External ID mismatch. Verify the external_id " + "configuration matches the trust policy.", + error_type=SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, + level=ErrorLevel.ERROR, + ) + ) from ex + + if error_code in cls.aws_iam_custom_errors: + error_type, message = cls.aws_iam_custom_errors[error_code] + raise SupersetSecurityException( + SupersetError( + message=message, + error_type=error_type, + level=ErrorLevel.ERROR, + ) + ) from ex + + raise SupersetSecurityException( + SupersetError( + message=f"Failed to assume IAM role: {ex}", + error_type=SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, + level=ErrorLevel.ERROR, + ) + ) from ex + + @classmethod + def generate_rds_auth_token( + cls, + credentials: dict[str, Any], + hostname: str, + port: int, + username: str, + region: str, + ) -> str: + """ + Generate RDS IAM auth token using temporary credentials. 
+ + :param credentials: STS credentials from assume_role + :param hostname: RDS/Aurora endpoint hostname + :param port: Database port + :param username: Database username configured for IAM auth + :param region: AWS region + :returns: IAM auth token to use as database password + :raises SupersetSecurityException: If token generation fails + """ + try: + import boto3 + from botocore.exceptions import ClientError + except ImportError as ex: + raise SupersetSecurityException( + SupersetError( + message="boto3 is required for AWS IAM authentication.", + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + level=ErrorLevel.ERROR, + ) + ) from ex + + try: + rds_client = boto3.client( + "rds", + region_name=region, + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + + token = rds_client.generate_db_auth_token( + DBHostname=hostname, + Port=port, + DBUsername=username, + ) + return token + + except ClientError as ex: + raise SupersetSecurityException( + SupersetError( + message=f"Failed to generate RDS auth token: {ex}", + error_type=SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, + level=ErrorLevel.ERROR, + ) + ) from ex + + @classmethod + def generate_redshift_credentials( + cls, + credentials: dict[str, Any], + workgroup_name: str, + db_name: str, + region: str, + ) -> tuple[str, str]: + """ + Generate Redshift Serverless credentials using temporary STS credentials. + + :param credentials: STS credentials from assume_role + :param workgroup_name: Redshift Serverless workgroup name + :param db_name: Redshift database name + :param region: AWS region + :returns: Tuple of (username, password) for Redshift connection + :raises SupersetSecurityException: If credential generation fails + """ + try: + import boto3 + from botocore.exceptions import ClientError + except ImportError as ex: + raise SupersetSecurityException( + SupersetError( + message="boto3 is required for AWS IAM authentication.", + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + level=ErrorLevel.ERROR, + ) + ) from ex + + try: + client = boto3.client( + "redshift-serverless", + region_name=region, + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + + response = client.get_credentials( + workgroupName=workgroup_name, + dbName=db_name, + ) + return response["dbUser"], response["dbPassword"] + + except ClientError as ex: + raise SupersetSecurityException( + SupersetError( + message=f"Failed to get Redshift Serverless credentials: {ex}", + error_type=SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, + level=ErrorLevel.ERROR, + ) + ) from ex + + @classmethod + def generate_redshift_cluster_credentials( + cls, + credentials: dict[str, Any], + cluster_identifier: str, + db_user: str, + db_name: str, + region: str, + auto_create: bool = False, + ) -> tuple[str, str]: + """ + Generate credentials for a provisioned Redshift cluster using temporary + STS credentials. 
+ + :param credentials: STS credentials from assume_role + :param cluster_identifier: Redshift cluster identifier + :param db_user: Database username to get credentials for + :param db_name: Redshift database name + :param region: AWS region + :param auto_create: Whether to auto-create the database user if it doesn't exist + :returns: Tuple of (username, password) for Redshift connection + :raises SupersetSecurityException: If credential generation fails + """ + try: + import boto3 + from botocore.exceptions import ClientError + except ImportError as ex: + raise SupersetSecurityException( + SupersetError( + message="boto3 is required for AWS IAM authentication.", + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + level=ErrorLevel.ERROR, + ) + ) from ex + + try: + client = boto3.client( + "redshift", + region_name=region, + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + + response = client.get_cluster_credentials( + ClusterIdentifier=cluster_identifier, + DbUser=db_user, + DbName=db_name, + AutoCreate=auto_create, + ) + return response["DbUser"], response["DbPassword"] + + except ClientError as ex: + raise SupersetSecurityException( + SupersetError( + message=f"Failed to get Redshift cluster credentials: {ex}", + error_type=SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, + level=ErrorLevel.ERROR, + ) + ) from ex + + @classmethod + def _apply_iam_authentication( + cls, + database: Database, + params: dict[str, Any], + iam_config: AWSIAMConfig, + ssl_args: dict[str, str] | None = None, + default_port: int = DEFAULT_POSTGRES_PORT, + ) -> None: + """ + Apply IAM authentication to the connection parameters. + + Full flow: assume role -> generate token -> update connect_args -> enable SSL. + + :param database: Database model instance + :param params: Engine parameters dict to modify + :param iam_config: IAM configuration from encrypted_extra + :param ssl_args: SSL args to apply (defaults to sslmode=require) + :param default_port: Default port if not specified in URI + :raises SupersetSecurityException: If any step fails + """ + from superset import feature_flag_manager + + if not feature_flag_manager.is_feature_enabled("AWS_DATABASE_IAM_AUTH"): + raise SupersetSecurityException( + SupersetError( + message="AWS IAM database authentication is not enabled. 
" + "Set the AWS_DATABASE_IAM_AUTH feature flag to True in your " + "Superset configuration to enable this feature.", + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + if ssl_args is None: + ssl_args = {"sslmode": "require"} + + # Extract configuration + role_arn = iam_config.get("role_arn") + region = iam_config.get("region") + db_username = iam_config.get("db_username") + external_id = iam_config.get("external_id") + session_duration = iam_config.get("session_duration", DEFAULT_SESSION_DURATION) + + # Validate required fields + missing_fields = [] + if not role_arn: + missing_fields.append("role_arn") + if not region: + missing_fields.append("region") + if not db_username: + missing_fields.append("db_username") + + if missing_fields: + raise SupersetSecurityException( + SupersetError( + message="AWS IAM configuration missing required fields: " + f"{', '.join(missing_fields)}", + error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + # Type assertions after validation (mypy doesn't narrow types from list check) + assert role_arn is not None + assert region is not None + assert db_username is not None + + # Get hostname and port from the database URI + uri = make_url_safe(database.sqlalchemy_uri_decrypted) + hostname = uri.host + port = uri.port or default_port + + if not hostname: + raise SupersetSecurityException( + SupersetError( + message=( + "Database URI must include a hostname for IAM authentication" + ), + error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + logger.debug( + "Applying IAM authentication for %s:%d as user %s", + hostname, + port, + db_username, + ) + + # Step 1: Assume the IAM role + credentials = cls.get_iam_credentials( + role_arn=role_arn, + region=region, + external_id=external_id, + session_duration=session_duration, + ) + + # Step 2: Generate the RDS auth token + token = cls.generate_rds_auth_token( + credentials=credentials, + hostname=hostname, + port=port, + username=db_username, + region=region, + ) + + # Step 3: Update connection parameters + connect_args = params.setdefault("connect_args", {}) + + # Set the IAM token as the password + connect_args["password"] = token + + # Override username if different from URI + connect_args["user"] = db_username + + # Step 4: Enable SSL (required for IAM authentication) + connect_args.update(ssl_args) + + logger.debug("IAM authentication configured successfully") + + @classmethod + def _apply_redshift_iam_authentication( + cls, + database: Database, + params: dict[str, Any], + iam_config: AWSIAMConfig, + ) -> None: + """ + Apply Redshift IAM authentication to connection parameters. + + Supports both Redshift Serverless (workgroup_name) and provisioned + clusters (cluster_identifier). The method auto-detects which type + based on the configuration provided. + + Flow: assume role -> get Redshift credentials -> update connect_args -> SSL. + + :param database: Database model instance + :param params: Engine parameters dict to modify + :param iam_config: IAM configuration from encrypted_extra + :raises SupersetSecurityException: If any step fails + """ + from superset import feature_flag_manager + + if not feature_flag_manager.is_feature_enabled("AWS_DATABASE_IAM_AUTH"): + raise SupersetSecurityException( + SupersetError( + message="AWS IAM database authentication is not enabled. 
" + "Set the AWS_DATABASE_IAM_AUTH feature flag to True in your " + "Superset configuration to enable this feature.", + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + # Extract configuration + role_arn = iam_config.get("role_arn") + region = iam_config.get("region") + external_id = iam_config.get("external_id") + session_duration = iam_config.get("session_duration", DEFAULT_SESSION_DURATION) + + # Serverless fields + workgroup_name = iam_config.get("workgroup_name") + + # Provisioned cluster fields + cluster_identifier = iam_config.get("cluster_identifier") + db_username = iam_config.get("db_username") + + # Common field + db_name = iam_config.get("db_name") + + # Determine deployment type + is_serverless = bool(workgroup_name) + is_provisioned = bool(cluster_identifier) + + if is_serverless and is_provisioned: + raise SupersetSecurityException( + SupersetError( + message="AWS IAM configuration cannot have both workgroup_name " + "(Serverless) and cluster_identifier (provisioned). " + "Please specify only one.", + error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + if not is_serverless and not is_provisioned: + raise SupersetSecurityException( + SupersetError( + message="AWS IAM configuration must include either workgroup_name " + "(for Redshift Serverless) or cluster_identifier " + "(for provisioned Redshift clusters).", + error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + # Validate common required fields + missing_fields = [] + if not role_arn: + missing_fields.append("role_arn") + if not region: + missing_fields.append("region") + if not db_name: + missing_fields.append("db_name") + + # Validate provisioned cluster specific fields + if is_provisioned and not db_username: + missing_fields.append("db_username") + + if missing_fields: + raise SupersetSecurityException( + SupersetError( + message="AWS IAM configuration missing required fields: " + f"{', '.join(missing_fields)}", + error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + # Type assertions after validation + assert role_arn is not None + assert region is not None + assert db_name is not None + + # Step 1: Assume the IAM role + credentials = cls.get_iam_credentials( + role_arn=role_arn, + region=region, + external_id=external_id, + session_duration=session_duration, + ) + + # Step 2: Get Redshift credentials based on deployment type + if is_serverless: + assert workgroup_name is not None + logger.debug( + "Applying Redshift Serverless IAM authentication for workgroup %s", + workgroup_name, + ) + db_user, db_password = cls.generate_redshift_credentials( + credentials=credentials, + workgroup_name=workgroup_name, + db_name=db_name, + region=region, + ) + else: + assert cluster_identifier is not None + assert db_username is not None + logger.debug( + "Applying Redshift provisioned cluster IAM authentication for %s", + cluster_identifier, + ) + db_user, db_password = cls.generate_redshift_cluster_credentials( + credentials=credentials, + cluster_identifier=cluster_identifier, + db_user=db_username, + db_name=db_name, + region=region, + ) + + # Step 3: Update connection parameters + connect_args = params.setdefault("connect_args", {}) + connect_args["password"] = db_password + connect_args["user"] = db_user + + # Step 4: Enable SSL (required for Redshift IAM authentication) + connect_args["sslmode"] = "verify-ca" + + 
logger.debug("Redshift IAM authentication configured successfully") diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 396c3805acda..ac355c936630 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -610,7 +610,10 @@ def start_oauth2_dance(cls, database: Database) -> None: re-run the query after authorization. """ tab_id = str(uuid4()) - default_redirect_uri = url_for("DatabaseRestApi.oauth2", _external=True) + default_redirect_uri = app.config.get( + "DATABASE_OAUTH2_REDIRECT_URI", + url_for("DatabaseRestApi.oauth2", _external=True), + ) # The state is passed to the OAuth2 provider, and sent back to Superset after # the user authorizes the access. The redirect endpoint in Superset can then diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py index a36e95b92da0..86eadf6c5ab6 100644 --- a/superset/db_engine_specs/gsheets.py +++ b/superset/db_engine_specs/gsheets.py @@ -30,9 +30,11 @@ from marshmallow import fields, Schema from marshmallow.exceptions import ValidationError from requests import Session +from requests.exceptions import HTTPError from shillelagh.adapters.api.gsheets.lib import SCOPES from shillelagh.exceptions import UnauthenticatedError from sqlalchemy.engine import create_engine +from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL from superset import db, security_manager @@ -41,7 +43,9 @@ from superset.db_engine_specs.shillelagh import ShillelaghEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetException +from superset.superset_typing import OAuth2TokenResponse from superset.utils import json +from superset.utils.oauth2 import get_oauth2_access_token if TYPE_CHECKING: from superset.models.core import Database @@ -83,14 +87,16 @@ class GSheetsParametersSchema(Schema): ) -class GSheetsParametersType(TypedDict): +class GSheetsParametersType(TypedDict, total=False): service_account_info: str catalog: dict[str, str] | None + oauth2_client_info: dict[str, str] | None -class GSheetsPropertiesType(TypedDict): +class GSheetsPropertiesType(TypedDict, total=False): parameters: GSheetsParametersType catalog: dict[str, str] + masked_encrypted_extra: str class GSheetsEngineSpec(ShillelaghEngineSpec): @@ -123,7 +129,10 @@ class GSheetsEngineSpec(ShillelaghEngineSpec): # when editing the database, mask this field in `encrypted_extra` # pylint: disable=invalid-name - encrypted_extra_sensitive_fields = {"$.service_account_info.private_key"} + encrypted_extra_sensitive_fields = { + "$.service_account_info.private_key", + "$.oauth2_client_info.secret", + } custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { SYNTAX_ERROR_REGEX: ( @@ -179,6 +188,47 @@ def get_oauth2_authorization_uri( } return urljoin(uri, "?" + urlencode(params)) + @classmethod + def needs_oauth2(cls, ex: Exception) -> bool: + """ + Check if the exception is one that indicates OAuth2 is needed. + + In case the token was manually revoked on Google side, `google-auth` will + try to automatically refresh credentials, but it fails since it only has the + access token. This override catches this scenario as well. 
+ """ + return ( + g + and hasattr(g, "user") + and ( + isinstance(ex, cls.oauth2_exception) + or "credentials do not contain the necessary fields" in str(ex) + ) + ) + + @classmethod + def get_oauth2_fresh_token( + cls, + config: OAuth2ClientConfig, + refresh_token: str, + ) -> OAuth2TokenResponse: + """ + Refresh an OAuth2 access token that has expired. + + When trying to refresh an expired token that was revoked on Google side, + the request fails with 400 status code. + """ + try: + return super().get_oauth2_fresh_token(config, refresh_token) + except HTTPError as ex: + if ex.response is not None and ex.response.status_code == 400: + error_data = ex.response.json() + if error_data.get("error") == "invalid_grant": + raise UnauthenticatedError( + error_data.get("error_description", "Token has been revoked") + ) from ex + raise + @classmethod def impersonate_user( cls, @@ -198,6 +248,28 @@ def impersonate_user( return url, engine_kwargs + @classmethod + def get_table_names( + cls, + database: Database, + inspector: Inspector, + schema: str | None, + ) -> set[str]: + """ + Get all sheets added to the connection. + + For OAuth2 connections, force the OAuth2 dance in case the user + doesn't have a token yet to avoid showing table names berofe auth. + """ + if database.is_oauth2_enabled() and not get_oauth2_access_token( + database.get_oauth2_config(), + database.id, + g.user.id, + database.db_engine_spec, + ): + database.start_oauth2_dance() + return super().get_table_names(database, inspector, schema) + @classmethod def get_extra_table_metadata( cls, @@ -311,6 +383,14 @@ def validate_parameters( conn = engine.connect() idx = 0 + # Check for OAuth2 config. Skip URL access for OAuth2 connections (user + # might not have a token, or admin adding a sheet they don't have access to) + oauth2_config_in_params = parameters.get("oauth2_client_info") + oauth2_config_in_secure_extra = json.loads( + properties.get("masked_encrypted_extra", "{}") + ).get("oauth2_client_info") + is_oauth2_conn = bool(oauth2_config_in_params or oauth2_config_in_secure_extra) + for name, url in table_catalog.items(): if not name: errors.append( @@ -334,7 +414,11 @@ def validate_parameters( ) return errors + if is_oauth2_conn: + continue + try: + url = url.replace('"', '""') results = conn.execute(f'SELECT * FROM "{url}" LIMIT 1') # noqa: S608 results.fetchall() except Exception: # pylint: disable=broad-except diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py index 19554d5b9c8c..b6cba3906a6e 100644 --- a/superset/db_engine_specs/mysql.py +++ b/superset/db_engine_specs/mysql.py @@ -14,12 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from __future__ import annotations + import contextlib +import logging import re from datetime import datetime from decimal import Decimal from re import Pattern -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, TYPE_CHECKING from urllib import parse from flask_babel import gettext as __ @@ -46,8 +49,14 @@ ) from superset.errors import SupersetErrorType from superset.models.sql_lab import Query +from superset.utils import json from superset.utils.core import GenericDataType +if TYPE_CHECKING: + from superset.models.core import Database + +logger = logging.getLogger(__name__) + # Regular expressions to catch custom errors CONNECTION_ACCESS_DENIED_REGEX = re.compile( "Access denied for user '(?P<username>.*?)'@'(?P<hostname>.*?)'" @@ -294,6 +303,54 @@ class MySQLEngineSpec(BasicParametersMixin, BaseEngineSpec): "mysqlconnector": {"allow_local_infile": 0}, } + # Sensitive fields that should be masked in encrypted_extra. + # This follows the pattern used by other engine specs (bigquery, snowflake, etc.) + # that specify exact paths rather than using the base class's catch-all "$.*". + encrypted_extra_sensitive_fields = { + "$.aws_iam.external_id", + "$.aws_iam.role_arn", + } + + @staticmethod + def update_params_from_encrypted_extra( + database: Database, + params: dict[str, Any], + ) -> None: + """ + Extract sensitive parameters from encrypted_extra. + + Handles AWS IAM authentication if configured, then merges any + remaining encrypted_extra keys into params. + """ + if not database.encrypted_extra: + return + + try: + encrypted_extra = json.loads(database.encrypted_extra) + except json.JSONDecodeError as ex: + logger.error(ex, exc_info=True) + raise + + # Handle AWS IAM auth: pop the key so it doesn't reach create_engine() + iam_config = encrypted_extra.pop("aws_iam", None) + if iam_config and iam_config.get("enabled"): + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + AWSIAMAuthMixin._apply_iam_authentication( + database, + params, + iam_config, + # MySQL drivers (mysqlclient) use 'ssl' dict, not 'ssl_mode'. + # SSL is typically configured via the database's extra settings, + # so we pass empty ssl_args here to avoid driver compatibility issues. + ssl_args={}, + default_port=3306, + ) + + # Standard behavior: merge remaining keys into params + if encrypted_extra: + params.update(encrypted_extra) + @classmethod def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index 10458f15c600..8ae844ff4a29 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -359,6 +359,14 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec): max_column_name_length = 63 try_remove_schema_from_table_name = False # pylint: disable=invalid-name + # Sensitive fields that should be masked in encrypted_extra. + # This follows the pattern used by other engine specs (bigquery, snowflake, etc.) + # that specify exact paths rather than using the base class's catch-all "$.*". + encrypted_extra_sensitive_fields = { + "$.aws_iam.external_id", + "$.aws_iam.role_arn", + } + column_type_mappings = ( ( re.compile(r"^double precision", re.IGNORECASE), @@ -461,6 +469,51 @@ def adjust_engine_params( return uri, connect_args + @staticmethod + def update_params_from_encrypted_extra( + database: Database, + params: dict[str, Any], + ) -> None: + """ + Extract sensitive parameters from encrypted_extra.
+ + Handles AWS IAM authentication if configured, then merges any + remaining encrypted_extra keys into params (standard behavior). + """ + if not database.encrypted_extra: + return + + try: + encrypted_extra = json.loads(database.encrypted_extra) + except json.JSONDecodeError as ex: + logger.error(ex, exc_info=True) + raise + + # Handle AWS IAM auth: pop the key so it doesn't reach create_engine() + iam_config = encrypted_extra.pop("aws_iam", None) + if iam_config and iam_config.get("enabled"): + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + # Preserve a stricter existing sslmode (e.g. verify-full) if present + connect_args = params.get("connect_args") or {} + previous_sslmode = connect_args.get("sslmode") + + AWSIAMAuthMixin._apply_iam_authentication( + database, + params, + iam_config, + ssl_args={"sslmode": "require"}, + default_port=5432, + ) + + # Restore stricter sslmode if it was previously configured + if previous_sslmode in ("verify-ca", "verify-full"): + params.setdefault("connect_args", {})["sslmode"] = previous_sslmode + + # Standard behavior: merge remaining keys into params + if encrypted_extra: + params.update(encrypted_extra) + @classmethod def get_default_catalog(cls, database: Database) -> str: """ diff --git a/superset/db_engine_specs/redshift.py b/superset/db_engine_specs/redshift.py index ea49c479dea7..fcdfab16967e 100644 --- a/superset/db_engine_specs/redshift.py +++ b/superset/db_engine_specs/redshift.py @@ -31,6 +31,7 @@ from superset.models.core import Database from superset.models.sql_lab import Query from superset.sql.parse import Table +from superset.utils import json logger = logging.getLogger() @@ -201,6 +202,47 @@ def normalize_table_name_for_upload( schema_name.lower() if schema_name else None, ) + # Sensitive fields that should be masked in encrypted_extra. + # This follows the pattern used by other engine specs (bigquery, snowflake, etc.) + # that specify exact paths rather than using the base class's catch-all "$.*". + encrypted_extra_sensitive_fields = { + "$.aws_iam.external_id", + "$.aws_iam.role_arn", + } + + @staticmethod + def update_params_from_encrypted_extra( + database: Database, + params: dict[str, Any], + ) -> None: + """ + Extract sensitive parameters from encrypted_extra. + + Handles AWS IAM authentication for Redshift Serverless if configured, + then merges any remaining encrypted_extra keys into params. 
+ """ + if not database.encrypted_extra: + return + + try: + encrypted_extra = json.loads(database.encrypted_extra) + except json.JSONDecodeError as ex: + logger.error(ex, exc_info=True) + raise + + # Handle AWS IAM auth: pop the key so it doesn't reach create_engine() + iam_config = encrypted_extra.pop("aws_iam", None) + if iam_config and iam_config.get("enabled"): + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + AWSIAMAuthMixin._apply_redshift_iam_authentication( + database, params, iam_config + ) + + # Standard behavior: merge remaining keys into params + if encrypted_extra: + params.update(encrypted_extra) + @classmethod def df_to_sql( cls, diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index 4a8fc4e729f2..f177d9e83b3b 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -51,29 +51,27 @@ def get_default_instructions(branding: str = "Apache Superset") -> str: Dashboard Management: - list_dashboards: List dashboards with advanced filters (1-based pagination) - get_dashboard_info: Get detailed dashboard information by ID -- generate_dashboard: Automatically create a dashboard from datasets with AI +- generate_dashboard: Create a dashboard from chart IDs - add_chart_to_existing_dashboard: Add a chart to an existing dashboard Dataset Management: - list_datasets: List datasets with advanced filters (1-based pagination) -- get_dataset_info: Get detailed dataset information by ID +- get_dataset_info: Get detailed dataset information by ID (includes columns/metrics) Chart Management: - list_charts: List charts with advanced filters (1-based pagination) - get_chart_info: Get detailed chart information by ID - get_chart_preview: Get a visual preview of a chart with image URL - get_chart_data: Get underlying chart data in text-friendly format -- generate_chart: Create a new chart with AI assistance -- update_chart: Update existing chart configuration -- update_chart_preview: Update chart and get preview in one operation +- generate_chart: Create and save a new chart permanently +- generate_explore_link: Create an interactive explore URL (preferred for exploration) +- update_chart: Update existing saved chart configuration +- update_chart_preview: Update cached chart preview without saving SQL Lab Integration: -- execute_sql: Execute SQL queries and get results +- execute_sql: Execute SQL queries and get results (requires database_id) - open_sql_lab_with_context: Generate SQL Lab URL with pre-filled query -Explore & Analysis: -- generate_explore_link: Create pre-configured explore URL with dataset/metrics/filters - Schema Discovery: - get_schema: Get schema metadata for chart/dataset/dashboard (columns, filters) @@ -82,42 +80,49 @@ def get_default_instructions(branding: str = "Apache Superset") -> str: - health_check: Simple health check tool (takes NO parameters, call without arguments) Available Resources: -- instance/metadata: Access instance configuration and metadata -- chart/templates: Access chart configuration templates +- instance://metadata: Instance configuration, stats, and available dataset IDs +- chart://configs: Valid chart configuration examples and best practices Available Prompts: - quickstart: Interactive guide for getting started with the MCP service - create_chart_guided: Step-by-step chart creation wizard -Common Chart Types (viz_type) and Behaviors: - -Interactive Charts (support sorting, filtering, drill-down): -- table: Standard table view with sorting and filtering -- pivot_table_v2: Pivot table with 
grouping and aggregations -- echarts_timeseries_line: Time series line chart -- echarts_timeseries_bar: Time series bar chart -- echarts_timeseries_area: Time series area chart -- echarts_timeseries_scatter: Time series scatter plot -- mixed_timeseries: Combined line/bar time series - -Common Visualization Types: -- big_number: Single metric display -- big_number_total: Total value display -- pie: Pie chart for proportions -- echarts_timeseries: Generic time series chart -- funnel: Funnel chart for conversion analysis -- gauge_chart: Gauge/speedometer visualization -- heatmap_v2: Heat map for correlation analysis -- sankey_v2: Sankey diagram for flow visualization -- sunburst_v2: Sunburst chart for hierarchical data -- treemap_v2: Tree map for hierarchical proportions -- word_cloud: Word cloud visualization -- world_map: Geographic world map -- box_plot: Box plot for distribution analysis -- bubble: Bubble chart for 3-dimensional data +Recommended Workflows: + +To create a chart: +1. list_datasets -> find a dataset +2. get_dataset_info(id) -> examine columns and metrics +3. generate_explore_link(dataset_id, config) -> preview interactively +4. generate_chart(dataset_id, config, save_chart=True) -> save permanently + +To explore data with SQL: +1. get_instance_info -> find database_id +2. execute_sql(database_id, sql) -> run query +3. open_sql_lab_with_context(database_id) -> open SQL Lab UI + +generate_explore_link vs generate_chart: +- Use generate_explore_link for exploration (no permanent chart created) +- Use generate_chart with save_chart=True only when user wants to save permanently + +Chart Types You Can CREATE with generate_chart/generate_explore_link: +- chart_type="xy", kind="line": Line chart for time series and trends +- chart_type="xy", kind="bar": Bar chart for category comparison +- chart_type="xy", kind="area": Area chart for volume visualization +- chart_type="xy", kind="scatter": Scatter plot for correlation analysis +- chart_type="table": Data table for detailed views +- chart_type="table", viz_type="ag-grid-table": Interactive AG Grid table + +Time grain for temporal x-axis (time_grain parameter): +- PT1H (hourly), P1D (daily), P1W (weekly), P1M (monthly), P1Y (yearly) + +Chart Types in Existing Charts (viewable via list_charts/get_chart_info): +- pie, big_number, big_number_total, funnel, gauge_chart +- echarts_timeseries_line, echarts_timeseries_bar, echarts_timeseries_area +- pivot_table_v2, heatmap_v2, sankey_v2, sunburst_v2, treemap_v2 +- word_cloud, world_map, box_plot, bubble, mixed_timeseries Query Examples: -- List all interactive tables: +- List all tables: filters=[{{"col": "viz_type", "opr": "in", "value": ["table", "pivot_table_v2"]}}] - List time series charts: filters=[{{"col": "viz_type", "opr": "sw", "value": "echarts_timeseries"}}] diff --git a/superset/mcp_service/chart/prompts/create_chart_guided.py b/superset/mcp_service/chart/prompts/create_chart_guided.py index 6010a71bd401..06bca71bbd0c 100644 --- a/superset/mcp_service/chart/prompts/create_chart_guided.py +++ b/superset/mcp_service/chart/prompts/create_chart_guided.py @@ -19,175 +19,132 @@ Chart prompts for visualization guidance """ -import logging - from superset_core.mcp import prompt -logger = logging.getLogger(__name__) - @prompt("create_chart_guided") async def create_chart_guided_prompt( chart_type: str = "auto", business_goal: str = "exploration" ) -> str: """ - AI-powered chart creation guide following Anthropic's agent design principles. 
- - This prompt implements: - - Transparency: Clear reasoning at each step - - Proactive Intelligence: Suggests insights before being asked - - Context Awareness: Maintains conversational flow - - Business Focus: Translates data into actionable insights - - Validation: Verifies choices before proceeding - - Natural Interaction: Conversational, not configuration-driven + Guided chart creation with step-by-step workflow. Args: - chart_type: Preferred chart type (auto, line, bar, pie, table, scatter, area) + chart_type: Preferred chart type (auto, line, bar, table, scatter, area) business_goal: Purpose (exploration, reporting, monitoring, presentation) """ - # Enhanced chart intelligence with business context chart_intelligence = { "line": { - "description": "Time series visualization for trend analysis", - "best_for": "Tracking performance over time, identifying patterns", - "business_value": "Reveals growth trends, seasonality, and patterns", + "description": "Time series trends", "data_requirements": "Temporal column + continuous metrics", }, "bar": { - "description": "Category comparison visualization", - "best_for": "Ranking, comparisons, and performance by category", - "business_value": "Identifies top performers, bottlenecks, and gaps", + "description": "Category comparison", "data_requirements": "Categorical dimensions + aggregatable metrics", }, "scatter": { - "description": "Correlation and relationship analysis", - "best_for": "Finding relationships, outlier detection, clustering", - "business_value": "Uncovers hidden correlations and identifies anomalies", + "description": "Correlation analysis", "data_requirements": "Two continuous variables, optional grouping", }, "table": { - "description": "Detailed data exploration and exact values", - "best_for": "Detailed analysis, data validation, precise values", - "business_value": "Provides granular insights and detailed reporting", + "description": "Detailed data view", "data_requirements": "Any combination of dimensions and metrics", }, "area": { - "description": "Volume and composition over time", - "best_for": "Showing cumulative effects, stacked comparisons", - "business_value": "Visualizes contribution and total volume trends", + "description": "Volume over time", "data_requirements": "Temporal dimension + stackable metrics", }, "auto": { - "description": "AI-powered visualization recommendation", - "best_for": "When you're not sure what chart type to use", - "business_value": "Optimizes chart choice based on data characteristics", - "data_requirements": "I'll analyze your data and recommend the best type", + "description": "Recommend based on data", + "data_requirements": "Any - will analyze columns to determine best type", }, } - # Business context intelligence - goal_intelligence = { - "exploration": { - "approach": "Interactive discovery and pattern finding", - "features": "Filters, drill-downs, multiple perspectives", - "outcome": "Uncover hidden insights and generate hypotheses", - }, - "reporting": { - "approach": "Clear, professional, and consistent presentation", - "features": "Clean design, appropriate aggregation, clear labels", - "outcome": "Reliable, repeatable business reporting", - }, - "monitoring": { - "approach": "Real-time tracking with clear thresholds", - "features": "Alert conditions, trend indicators, key metrics", - "outcome": "Proactive issue detection and performance tracking", - }, - "presentation": { - "approach": "Compelling visual storytelling", - "features": "Engaging colors, clear messaging, 
audience-appropriate detail", - "outcome": "Persuasive data-driven presentations for stakeholders", - }, + goal_context = { + "exploration": "interactive discovery with filters and drill-downs", + "reporting": "clean, professional presentation with clear labels", + "monitoring": "real-time tracking with key metrics highlighted", + "presentation": "compelling visual storytelling for stakeholders", } selected_chart = chart_intelligence.get(chart_type, chart_intelligence["auto"]) - selected_goal = goal_intelligence.get( - business_goal, goal_intelligence["exploration"] - ) - - return f"""🎯 **AI-Powered Chart Creation Assistant** - -I'm your intelligent data visualization partner! Let me help you create charts. - -**Your Visualization Goal:** -📊 **Chart Focus**: {chart_type.title()} - {selected_chart["description"]} -🎯 **Business Purpose**: {business_goal.title()} - {selected_goal["approach"]} -💡 **Expected Value**: {selected_chart["business_value"]} - ---- - -## 🚀 My Intelligent Approach - -### **Phase 1: Data Intelligence** 📊 -I'll automatically analyze your dataset to understand: -- **Data characteristics** (types, distributions, quality) -- **Business relationships** (correlations, hierarchies, trends) -- **Visualization opportunities** (what stories your data can tell) -- **Performance considerations** (size, complexity, aggregation needs) + selected_goal = goal_context.get(business_goal, goal_context["exploration"]) + valid_kinds = ("line", "bar", "area", "scatter") + kind = chart_type if chart_type in valid_kinds else "line" -*Why this matters: The right chart depends on your data's unique characteristics* + return f"""**Guided Chart Creation** -### **Phase 2: Smart Recommendations** 🧠 -Based on your data analysis, I'll: -- **Recommend optimal chart types** with confidence scores and reasoning -- **Suggest meaningful metrics** that align with your business goal -- **Identify interesting patterns** you might want to highlight -- **Propose filters** to focus on what matters most - -*Why this matters: I'll spot opportunities you might miss and save you time* - -### **Phase 3: Intelligent Configuration** ⚙️ -I'll configure your chart with: -- **Business-appropriate aggregations** (daily, weekly, monthly for time series) -- **Meaningful labels and formatting** (currency, percentages, readable names) -- **Performance optimizations** (appropriate limits, caching strategies) -- **Visual best practices** (colors, scales, legends that enhance understanding) - -*Why this matters: Proper configuration makes charts both beautiful and actionable* - -### **Phase 4: Validation & Refinement** 🎯 -Before finalizing, I'll: -- **Verify the chart answers your business question** -- **Check data quality and completeness** -- **Suggest improvements** based on visualization best practices -- **Provide preview** so you can see exactly what you're getting - -*Why this matters: Great charts require iteration and validation* - ---- - -## 🎬 Let's Begin Your Data Story - -I'm ready to be your proactive data exploration partner. Here's how we can start: - -**Option 1: Quick Start** ⚡ -Tell me: *"What business question are you trying to answer?"* -(e.g., "How are our sales trending?" 
or "Which products perform best?") - -**Option 2: Dataset Exploration** 🔍 -I can show you available datasets: `list_datasets` -Or explore a specific one: `get_dataset_info [dataset_id]` - -**Option 3: Visual Inspiration** 🎨 -Browse pre-built chart configurations: `superset://chart/configs` resource -Perfect for when you want to see examples of great charts! - -**Option 4: Autonomous Discovery** 🤖 -Just point me to a dataset and say *"Find something interesting"* -I'll explore autonomously and surface the most compelling insights! +Chart type: {chart_type} - {selected_chart["description"]} +Data needs: {selected_chart["data_requirements"]} +Goal: {business_goal} - {selected_goal} --- -💡 **Pro Tip**: Great charts combine business intuition with data analysis! - -**What's your data challenge today?** 🚀""" +## Step-by-Step Workflow + +Follow these steps in order: + +### Step 1: Find a Dataset +Call `list_datasets` to see available datasets. + +### Step 2: Examine Columns +Call `get_dataset_info(dataset_id)` to see columns, types, and metrics. + +### Step 3: Choose Chart Configuration +Based on column types: +- Temporal x-axis + numeric y -> line or area chart +- Categorical x-axis + numeric y -> bar chart +- Two numeric columns -> scatter plot +- Any columns for detail -> table + +### Step 4: Create the Chart +Use `generate_explore_link` for interactive preview (preferred), or +`generate_chart` with `save_chart=True` to save permanently. + +Example XY chart config: +```json +{{ + "dataset_id": , + "config": {{ + "chart_type": "xy", + "kind": "{kind}", + "x": {{"name": ""}}, + "y": [{{"name": "", "aggregate": "SUM"}}], + "time_grain": "P1D" + }} +}} +``` + +Example table config: +```json +{{ + "dataset_id": , + "config": {{ + "chart_type": "table", + "columns": [ + {{"name": ""}}, + {{"name": "", "aggregate": "SUM", "label": "Total"}} + ] + }} +}} +``` + +### Step 5: Validate Results +- If you get a column validation error, call `get_dataset_info` to check + the exact column names available +- If data is empty, check if filters are too restrictive +- If the chart type doesn't suit the data, try a different kind + +## Available Aggregations +SUM, COUNT, AVG, MIN, MAX, COUNT_DISTINCT, STDDEV, VAR, MEDIAN + +## Time Grain Options (for temporal x-axis) +PT1H (hourly), P1D (daily), P1W (weekly), P1M (monthly), P3M (quarterly), P1Y (yearly) + +## Additional Options +- group_by: Add a dimension to split data into series +- filters: [{{"column": "col", "op": "=", "value": "x"}}] +- stacked: true (for bar/area charts) +- legend: {{"show": true, "position": "right"}} +- x_axis/y_axis: {{"title": "Label", "format": "$,.0f"}}""" diff --git a/superset/mcp_service/chart/resources/chart_configs.py b/superset/mcp_service/chart/resources/chart_configs.py index 277404f70891..ea40a89fdc62 100644 --- a/superset/mcp_service/chart/resources/chart_configs.py +++ b/superset/mcp_service/chart/resources/chart_configs.py @@ -40,67 +40,71 @@ def get_chart_configs_resource() -> str: - Best practices for each chart type configuration """ - # Valid XYChartConfig examples - these match the exact schema + # XY chart examples covering all chart kinds and features xy_chart_configs = { "line_chart": { - "description": "Basic line chart for time series analysis", + "description": "Line chart with daily time grain", "config": { "chart_type": "xy", "kind": "line", - "x": {"name": "created_on", "label": "Date Created"}, + "x": {"name": "order_date", "label": "Date"}, "y": [ { - "name": "count_metric", - "aggregate": "COUNT", - 
"label": "Total Count", + "name": "revenue", + "aggregate": "SUM", + "label": "Daily Revenue", } ], + "time_grain": "P1D", }, - "use_cases": [ - "Time series trends", - "Historical analysis", - "Growth tracking", - ], + "use_cases": ["Time series trends", "Growth tracking"], }, "bar_chart": { - "description": "Bar chart for category comparison", + "description": "Bar chart for category comparison with axis formatting", "config": { "chart_type": "xy", "kind": "bar", "x": {"name": "category", "label": "Category"}, "y": [{"name": "sales", "aggregate": "SUM", "label": "Total Sales"}], - "x_axis": {"title": "Product Categories", "scale": "linear"}, - "y_axis": { - "title": "Revenue ($)", - "format": "$,.0f", - "scale": "linear", - }, + "x_axis": {"title": "Product Categories"}, + "y_axis": {"title": "Revenue ($)", "format": "$,.0f"}, }, - "use_cases": ["Category comparison", "Rankings", "Performance metrics"], + "use_cases": ["Category comparison", "Rankings"], + }, + "stacked_bar": { + "description": "Stacked bar chart with group_by dimension", + "config": { + "chart_type": "xy", + "kind": "bar", + "x": {"name": "quarter", "label": "Quarter"}, + "y": [ + {"name": "revenue", "aggregate": "SUM", "label": "Revenue"}, + ], + "group_by": {"name": "region", "label": "Region"}, + "stacked": True, + "legend": {"show": True, "position": "right"}, + }, + "use_cases": ["Composition analysis", "Regional breakdown"], }, "multi_metric_line": { - "description": "Multi-metric line chart with grouping", + "description": "Multi-metric line chart with filters and monthly grain", "config": { "chart_type": "xy", "kind": "line", - "x": {"name": "date_column", "label": "Date"}, + "x": {"name": "order_date", "label": "Date"}, "y": [ {"name": "revenue", "aggregate": "SUM", "label": "Revenue"}, { - "name": "users", + "name": "customer_id", "aggregate": "COUNT_DISTINCT", - "label": "Unique Users", + "label": "Unique Customers", }, ], - "group_by": {"name": "region", "label": "Region"}, - "legend": {"show": True, "position": "right"}, + "time_grain": "P1M", + "legend": {"show": True, "position": "top"}, "filters": [{"column": "status", "op": "=", "value": "active"}], }, - "use_cases": [ - "Multi-dimensional analysis", - "Regional comparisons", - "KPI tracking", - ], + "use_cases": ["KPI tracking", "Multi-dimensional analysis"], }, "scatter_plot": { "description": "Scatter plot for correlation analysis", @@ -108,7 +112,7 @@ def get_chart_configs_resource() -> str: "chart_type": "xy", "kind": "scatter", "x": { - "name": "advertising_spend", + "name": "ad_spend", "aggregate": "AVG", "label": "Avg Ad Spend", }, @@ -119,56 +123,44 @@ def get_chart_configs_resource() -> str: "label": "Avg Conversion Rate", } ], - "group_by": {"name": "campaign_type", "label": "Campaign Type"}, - "x_axis": {"title": "Average Advertising Spend", "format": "$,.0f"}, - "y_axis": {"title": "Conversion Rate", "format": ".2%"}, + "group_by": {"name": "campaign_type", "label": "Campaign"}, + "x_axis": {"format": "$,.0f"}, + "y_axis": {"format": ".2%"}, }, - "use_cases": [ - "Correlation analysis", - "Outlier detection", - "Performance relationships", - ], + "use_cases": ["Correlation analysis", "Outlier detection"], }, - "area_chart": { - "description": "Area chart for volume visualization", + "stacked_area": { + "description": "Stacked area chart for volume composition over time", "config": { "chart_type": "xy", "kind": "area", - "x": {"name": "month", "label": "Month"}, - "y": [ - {"name": "signups", "aggregate": "SUM", "label": "Monthly Signups"} - 
], - "filters": [ - {"column": "year", "op": ">=", "value": 2023}, - {"column": "active", "op": "=", "value": True}, - ], + "x": {"name": "order_date", "label": "Date"}, + "y": [{"name": "signups", "aggregate": "SUM", "label": "Signups"}], + "group_by": {"name": "channel", "label": "Channel"}, + "stacked": True, + "time_grain": "P1W", }, - "use_cases": ["Volume trends", "Cumulative metrics", "Stacked comparisons"], + "use_cases": ["Volume trends", "Channel attribution"], }, } - # Valid TableChartConfig examples - these match the exact schema + # Table chart examples table_chart_configs = { "basic_table": { - "description": "Basic data table with multiple columns", + "description": "Standard table with dimensions and aggregated metrics", "config": { "chart_type": "table", "columns": [ - {"name": "name", "label": "Customer Name"}, - {"name": "email", "label": "Email Address"}, + {"name": "customer_name", "label": "Customer"}, {"name": "orders", "aggregate": "COUNT", "label": "Total Orders"}, {"name": "revenue", "aggregate": "SUM", "label": "Total Revenue"}, ], "sort_by": ["Total Revenue"], }, - "use_cases": [ - "Detailed data views", - "Customer lists", - "Transaction records", - ], + "use_cases": ["Detail views", "Customer lists"], }, "aggregated_table": { - "description": "Table with aggregated metrics and filters", + "description": "Table with multiple aggregations and filters", "config": { "chart_type": "table", "columns": [ @@ -190,171 +182,64 @@ def get_chart_configs_resource() -> str: }, ], "filters": [ - {"column": "sale_date", "op": ">=", "value": "2024-01-01"}, {"column": "status", "op": "!=", "value": "cancelled"}, ], - "sort_by": ["Total Sales", "Sales Region"], + "sort_by": ["Total Sales"], }, - "use_cases": ["Summary reports", "Regional analysis", "Performance tables"], + "use_cases": ["Summary reports", "Regional analysis"], }, - } - - # Schema reference for developers - schema_reference = { - "ChartConfig": { - "description": "Union type - XYChartConfig or TableChartConfig by type", - "discriminator": "chart_type", - "types": ["xy", "table"], - }, - "XYChartConfig": { - "required_fields": ["chart_type", "x", "y"], - "optional_fields": [ - "kind", - "group_by", - "x_axis", - "y_axis", - "legend", - "filters", - ], - "chart_type": "xy", - "kind_options": ["line", "bar", "area", "scatter"], - "validation_rules": [ - "All column labels must be unique across x, y, and group_by", - "Y-axis must have at least one column", - "Column names must match pattern: ^[a-zA-Z0-9_][a-zA-Z0-9_\\s\\-\\.]*$", - ], - }, - "TableChartConfig": { - "required_fields": ["chart_type", "columns"], - "optional_fields": ["filters", "sort_by"], - "chart_type": "table", - "validation_rules": [ - "Must have at least one column", - "All column labels must be unique", - "Column names must match pattern: ^[a-zA-Z0-9_][a-zA-Z0-9_\\s\\-\\.]*$", - ], - }, - "ColumnRef": { - "required_fields": ["name"], - "optional_fields": ["label", "dtype", "aggregate"], - "aggregate_options": [ - "SUM", - "COUNT", - "AVG", - "MIN", - "MAX", - "COUNT_DISTINCT", - "STDDEV", - "VAR", - "MEDIAN", - "PERCENTILE", - ], - "validation_rules": [ - "Name cannot be empty and must follow pattern", - "Labels are HTML-escaped to prevent XSS", - "Aggregates are validated against allowed functions", - ], - }, - "FilterConfig": { - "required_fields": ["column", "op", "value"], - "operator_options": ["=", ">", "<", ">=", "<=", "!="], - "value_types": ["string", "number", "boolean"], - "validation_rules": [ - "Column names are sanitized to prevent 
injection", - "Values are checked for malicious patterns", - "String values are HTML-escaped", + "ag_grid_table": { + "description": "Interactive AG Grid table with advanced features", + "config": { + "chart_type": "table", + "viz_type": "ag-grid-table", + "columns": [ + {"name": "product_name", "label": "Product"}, + {"name": "category", "label": "Category"}, + {"name": "quantity", "aggregate": "SUM", "label": "Qty Sold"}, + {"name": "revenue", "aggregate": "SUM", "label": "Revenue"}, + ], + }, + "use_cases": [ + "Interactive exploration", + "Large datasets with client-side sorting/filtering", ], }, - "AxisConfig": { - "optional_fields": ["title", "scale", "format"], - "scale_options": ["linear", "log"], - "format_examples": ["$,.2f", ".2%", ",.0f", ".1f"], - }, - "LegendConfig": { - "optional_fields": ["show", "position"], - "show_default": True, - "position_options": ["top", "bottom", "left", "right"], - "position_default": "right", - }, } - # Best practices for each configuration type + # Best practices best_practices = { "xy_charts": [ - "Use descriptive labels for axes and metrics", - "Choose appropriate aggregation functions for your data", - "Limit the number of Y-axis metrics (3-5 maximum)", - "Use filters to focus on relevant data", - "Configure axis formatting for better readability", - "Consider grouping when comparing categories", - "Use chart kinds: line for trends, bar for comparisons, scatter plots", + "Use time_grain for temporal x-axis columns (P1D, P1W, P1M, P1Y)", + "Limit Y-axis metrics to 3-5 maximum for readability", + "Use group_by to split data into series for comparison", + "Use stacked=true for bar/area charts showing composition", + "Configure axis format for readability ($,.0f for currency, .2% for pct)", ], "table_charts": [ - "Include essential columns only to avoid clutter", - "Use meaningful column labels", - "Apply sorting to highlight important data", - "Use filters to limit result sets", - "Mix dimensions and aggregated metrics appropriately", - "Ensure unique labels to avoid conflicts", - "Consider performance with large datasets", + "Include only essential columns to avoid clutter", + "Use meaningful labels different from raw column names", + "Apply sort_by to highlight important data", + "Use ag-grid-table viz_type for large interactive datasets", ], "general": [ - "Always specify chart_type as the first field", - "Use consistent naming conventions for columns", - "Validate column names exist in your dataset", - "Test configurations with actual data", - "Consider caching for frequently accessed charts", - "Apply security best practices - avoid user input in column names", + "Always verify column names with get_dataset_info before charting", + "Use generate_explore_link for preview, generate_chart for saving", + "Each column label must be unique across the entire configuration", + "Column names must match: ^[a-zA-Z0-9_][a-zA-Z0-9_ \\-\\.]*$", ], } - # Common patterns and examples - common_patterns = { - "time_series": { - "description": "Standard time-based analysis", - "x_column_types": ["date", "datetime", "timestamp"], - "recommended_aggregations": ["SUM", "COUNT", "AVG"], - "best_chart_types": ["line", "area", "bar"], - }, - "categorical_analysis": { - "description": "Comparing discrete categories", - "x_column_types": ["string", "category", "enum"], - "recommended_aggregations": ["SUM", "COUNT", "COUNT_DISTINCT", "AVG"], - "best_chart_types": ["bar", "table"], - }, - "correlation_analysis": { - "description": "Finding relationships between 
variables", - "requirements": ["Two numerical metrics"], - "recommended_aggregations": ["AVG", "SUM", "MEDIAN"], - "best_chart_types": ["scatter"], - }, - } - resource_data = { "xy_chart_configs": xy_chart_configs, "table_chart_configs": table_chart_configs, - "schema_reference": schema_reference, "best_practices": best_practices, - "common_patterns": common_patterns, - "metadata": { - "version": "1.0", - "schema_version": "ChartConfig v1.0", - "last_updated": "2025-08-07", - "usage_notes": [ - "All examples are valid ChartConfig objects that pass validation", - "Copy these configurations directly into generate_chart requests", - "Modify column names and labels to match your actual dataset", - "Test configurations with get_dataset_info to verify columns", - "All examples follow security best practices and input validation", - ], - "validation_info": [ - "Column names must match: ^[a-zA-Z0-9_][a-zA-Z0-9_\\s\\-\\.]*$", - "Labels are automatically HTML-escaped for security", - "Filter values are sanitized to prevent injection attacks", - "All field lengths are validated against schema limits", - "Duplicate labels are automatically detected and rejected", - ], - }, + "usage_notes": [ + "All examples are valid ChartConfig objects that pass validation", + "Modify column names and labels to match your actual dataset", + "Use get_dataset_info to verify column names before charting", + "For complete schema details, see the generate_chart tool parameters", + ], } from superset.utils import json diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py index 928858aea123..b813bc4ebc50 100644 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -848,9 +848,7 @@ class GenerateChartRequest(QueryCacheControl): default=True, description="Whether to generate a preview image", ) - preview_formats: List[ - Literal["url", "interactive", "ascii", "vega_lite", "table", "base64"] - ] = Field( + preview_formats: List[Literal["url", "ascii", "vega_lite", "table"]] = Field( default_factory=lambda: ["url"], description="List of preview formats to generate", ) @@ -896,9 +894,7 @@ class UpdateChartRequest(QueryCacheControl): default=True, description="Whether to generate a preview after updating", ) - preview_formats: List[ - Literal["url", "interactive", "ascii", "vega_lite", "table", "base64"] - ] = Field( + preview_formats: List[Literal["url", "ascii", "vega_lite", "table"]] = Field( default_factory=lambda: ["url"], description="List of preview formats to generate", ) @@ -973,9 +969,7 @@ class UpdateChartPreviewRequest(FormDataCacheControl): default=True, description="Whether to generate a preview after updating", ) - preview_formats: List[ - Literal["url", "interactive", "ascii", "vega_lite", "table", "base64"] - ] = Field( + preview_formats: List[Literal["url", "ascii", "vega_lite", "table"]] = Field( default_factory=lambda: ["url"], description="List of preview formats to generate", ) @@ -1063,11 +1057,11 @@ class GetChartPreviewRequest(QueryCacheControl): """Request for chart preview with cache control.""" identifier: int | str = Field(description="Chart identifier (ID, UUID)") - format: Literal["url", "ascii", "table", "base64", "vega_lite"] = Field( + format: Literal["url", "ascii", "table", "vega_lite"] = Field( default="url", description=( "Preview format: 'url' for image URL, 'ascii' for text art, " - "'table' for data table, 'base64' for embedded image, " + "'table' for data table, " "'vega_lite' for interactive JSON specification" ), 
) diff --git a/superset/mcp_service/chart/tool/generate_chart.py b/superset/mcp_service/chart/tool/generate_chart.py index 9cd6e6225848..7b7f3a17895a 100644 --- a/superset/mcp_service/chart/tool/generate_chart.py +++ b/superset/mcp_service/chart/tool/generate_chart.py @@ -369,10 +369,6 @@ async def generate_chart( # noqa: C901 await ctx.debug( "Processing preview format: format=%s" % (format_type,) ) - # Skip base64 format - we never return base64 - if format_type == "base64": - logger.info("Skipping base64 format - not supported") - continue if chart_id: # For saved charts, use the existing preview generation diff --git a/superset/mcp_service/chart/tool/get_chart_preview.py b/superset/mcp_service/chart/tool/get_chart_preview.py index 2b3ecf353e7a..9e0015403962 100644 --- a/superset/mcp_service/chart/tool/get_chart_preview.py +++ b/superset/mcp_service/chart/tool/get_chart_preview.py @@ -85,9 +85,6 @@ def generate(self) -> URLPreview | ChartError: ) -# Base64 preview support removed - we never return base64 data - - class ASCIIPreviewStrategy(PreviewFormatStrategy): """Generate ASCII art preview.""" diff --git a/superset/mcp_service/common/schema_discovery.py b/superset/mcp_service/common/schema_discovery.py index 21da80287353..530d26c38754 100644 --- a/superset/mcp_service/common/schema_discovery.py +++ b/superset/mcp_service/common/schema_discovery.py @@ -119,6 +119,63 @@ def _get_sqlalchemy_type_name(col_type: Any) -> str: return "str" # Default fallback +# Descriptions for common model columns that SQLAlchemy models don't document. +# Used as a fallback when the model column has no doc/comment attribute. +_COLUMN_DESCRIPTIONS: dict[str, str] = { + # Common across models + "id": "Unique integer identifier", + "uuid": "Unique UUID identifier", + "created_on": "Timestamp when the resource was created", + "changed_on": "Timestamp when the resource was last modified", + "created_by_fk": "User ID of the creator", + "changed_by_fk": "User ID of the last modifier", + "description": "User-provided description text", + "cache_timeout": "Cache timeout override in seconds", + "perm": "Permission string for access control", + "schema_perm": "Schema-level permission string", + "catalog_perm": "Catalog-level permission string", + "is_managed_externally": "Whether managed by an external system", + "external_url": "URL of the external management system", + "certified_by": "Name of the person who certified this resource", + "certification_details": "Details about the certification", + # Chart-specific + "slice_name": "Chart display name", + "datasource_id": "ID of the underlying dataset", + "datasource_type": "Type of data source (e.g., table)", + "viz_type": "Visualization type (e.g., echarts_timeseries_line, table)", + "params": "JSON string of chart parameters/configuration", + "query_context": "JSON string of the query context for data fetching", + "last_saved_at": "Timestamp of the last explicit save", + "last_saved_by_fk": "User ID who last saved this chart", + # Dataset-specific + "table_name": "Name of the database table or view", + "schema": "Database schema name", + "catalog": "Database catalog name", + "database_id": "ID of the database connection", + "sql": "Custom SQL expression (for virtual datasets)", + "main_dttm_col": "Primary datetime column for time-series queries", + "is_sqllab_view": "Whether this dataset was created from SQL Lab", + "template_params": "Jinja template parameters as JSON", + "extra": "Extra configuration as JSON", + "filter_select_enabled": "Whether filter select is 
enabled", + "normalize_columns": "Whether to normalize column names", + "always_filter_main_dttm": "Whether to always filter on the main datetime column", + "fetch_values_predicate": "SQL predicate for fetching filter values", + "default_endpoint": "Default endpoint URL", + "offset": "Row offset for queries", + "is_featured": "Whether this dataset is featured", + "currency_code_column": "Column containing currency codes", + # Dashboard-specific + "dashboard_title": "Dashboard display title", + "slug": "URL-friendly identifier for the dashboard", + "published": "Whether the dashboard is published and visible", + "position_json": "JSON layout of dashboard components", + "json_metadata": "JSON metadata including filters and settings", + "css": "Custom CSS for the dashboard", + "theme_id": "Theme ID for dashboard styling", +} + + def get_columns_from_model( model_cls: Type[Any], default_columns: list[str], @@ -141,8 +198,12 @@ def get_columns_from_model( for col in mapper.columns: col_name = col.key col_type = _get_sqlalchemy_type_name(col.type) - # Get description from column doc or comment - description = getattr(col, "doc", None) or getattr(col, "comment", None) + # Get description from column doc, comment, or fallback mapping + description = ( + getattr(col, "doc", None) + or getattr(col, "comment", None) + or _COLUMN_DESCRIPTIONS.get(col_name) + ) columns.append( ColumnMetadata( diff --git a/superset/mcp_service/system/prompts/quickstart.py b/superset/mcp_service/system/prompts/quickstart.py index 3955cbc594d8..ec56400b3a44 100644 --- a/superset/mcp_service/system/prompts/quickstart.py +++ b/superset/mcp_service/system/prompts/quickstart.py @@ -19,13 +19,9 @@ System prompts for general guidance """ -import logging - from flask import current_app from superset_core.mcp import prompt -logger = logging.getLogger(__name__) - def _get_app_name() -> str: """Get the application name from Flask config.""" @@ -43,61 +39,57 @@ async def quickstart_prompt( """ Guide new users through their first experience with the platform. - This prompt helps users: - 1. Understand what data is available - 2. Create their first visualization - 3. Build a simple dashboard - 4. Learn key Superset concepts - Args: user_type: Type of user (analyst, executive, developer) focus_area: Area of interest (sales, marketing, operations, general) """ - # Build personalized prompt based on user type - intro_messages = { - "analyst": "I see you're an analyst. Let's explore the data and build some " - "detailed visualizations.", - "executive": "Welcome! Let's create a high-level dashboard with key business " - "metrics.", - "developer": "Great to have a developer here! Let's explore both the UI and " - "API capabilities.", - } - - focus_examples = { - "sales": "Since you're interested in sales, we'll focus on revenue, customer, " - "and product metrics.", - "marketing": "For marketing analytics, we'll look at campaigns, conversions, " - "and customer acquisition.", - "operations": "Let's explore operational efficiency, inventory, and process " - "metrics.", - "general": "We'll explore various datasets to find what's most relevant to " - "you.", - } - - intro = intro_messages.get(user_type, intro_messages["analyst"]) - focus = focus_examples.get(focus_area, focus_examples["general"]) app_name = _get_app_name() - return f"""Welcome to {app_name}! I'll guide you through creating your first - dashboard. + # Workflow varies by user type + workflows = { + "analyst": f"""**Workflow for Analysts:** + +1. 
Call `get_instance_info` to see what's available in this {app_name} instance +2. Call `list_datasets` to find datasets relevant to {focus_area} +3. Call `get_dataset_info(id)` to examine columns and metrics +4. Call `generate_explore_link` to create interactive chart previews +5. Iterate on chart configuration until the visualization answers your question +6. Call `generate_chart(save_chart=True)` to save charts you want to keep +7. Call `generate_dashboard` with your saved chart IDs to build a dashboard""", + "executive": f"""**Workflow for Executives:** + +1. Call `get_instance_info` to see available dashboards and charts +2. Call `list_dashboards` to find existing dashboards relevant to {focus_area} +3. Call `get_dashboard_info(id)` to view dashboard details and chart list +4. To create a new KPI dashboard: + a. Call `list_datasets` to find data sources + b. Create charts with `generate_chart` (line/bar/table) + c. Call `generate_dashboard` with chart IDs""", + "developer": """**Workflow for Developers:** + +1. Call `get_instance_info` to understand the instance +2. Call `get_schema(model_type)` to discover columns and filters +3. Use `execute_sql(database_id, sql)` to run queries +4. Use `open_sql_lab_with_context` for SQL Lab URLs +5. Use `list_datasets`/`list_charts`/`list_dashboards` with filters +6. Use `generate_explore_link` for chart previews without saving""", + } -{intro} {focus} + selected_workflow = workflows.get(user_type, workflows["analyst"]) -I'll help you through these steps: -1. **Explore Available Data** - See what datasets you can work with -2. **Understand Your Data** - Examine columns, metrics, and sample data -3. **Create Visualizations** - Build charts that tell a story -4. **Design a Dashboard** - Combine charts into an interactive dashboard -5. **Learn Advanced Features** - Discover filters, SQL Lab, and more + return f"""**{app_name} Quickstart Guide** -To get started, I'll use these tools: -- `get_instance_info` - Overview of your {app_name} instance -- `list_datasets` - Find available datasets -- `get_dataset_info` - Explore dataset details -- `generate_chart` - Create visualizations -- `generate_dashboard` - Build your dashboard +{selected_workflow} -Let me begin by checking what's available in your {app_name} instance. I'll first get -an overview, then show you the datasets filtered by your interest in {focus_area}. +**Available Tools Summary:** +- `get_instance_info` - Instance overview (databases, dataset count, chart count) +- `list_datasets` / `get_dataset_info` - Find and examine data sources +- `list_charts` / `get_chart_info` - Browse existing charts +- `list_dashboards` / `get_dashboard_info` - Browse existing dashboards +- `generate_explore_link` - Create interactive chart preview (no save) +- `generate_chart` - Create and save a chart permanently +- `generate_dashboard` - Create a dashboard from chart IDs +- `execute_sql` - Run SQL queries against a database +- `get_schema` - Discover filterable/sortable columns for list tools -Would you like me to start by showing you what data you can work with?""" +Start by calling `get_instance_info` to see what data is available.""" diff --git a/superset/mcp_service/system/resources/instance_metadata.py b/superset/mcp_service/system/resources/instance_metadata.py index 55d4b67f75d8..29e223f06e1b 100644 --- a/superset/mcp_service/system/resources/instance_metadata.py +++ b/superset/mcp_service/system/resources/instance_metadata.py @@ -16,11 +16,17 @@ # under the License. 
""" -System resources for providing instance configuration and stats +System resources for providing instance configuration and stats. + +This resource differs from the get_instance_info tool by also including +available dataset IDs and database IDs, so LLMs can immediately call +get_dataset_info or execute_sql without an extra list call. """ import logging +from sqlalchemy.exc import SQLAlchemyError + from superset.mcp_service.app import mcp from superset.mcp_service.auth import mcp_auth_hook @@ -31,19 +37,15 @@ @mcp_auth_hook def get_instance_metadata_resource() -> str: """ - Provide comprehensive metadata about the instance. + Provide instance metadata with available dataset and database IDs. This resource gives LLMs context about: - - Available datasets and their popularity + - Instance summary stats (counts of dashboards, charts, datasets) + - Available database connections with their IDs (for execute_sql) + - Available datasets with IDs and table names (for get_dataset_info) - Dashboard and chart statistics - - Database connections - - Popular queries and usage patterns - - Available visualization types - - Feature flags and configuration """ try: - # Import the shared core and DAOs at runtime - # Create a shared core instance for the resource from typing import Any, cast, Type from superset.daos.base import BaseDAO @@ -62,6 +64,7 @@ def get_instance_metadata_resource() -> str: calculate_popular_content, calculate_recent_activity, ) + from superset.utils import json instance_info_core = InstanceInfoCore( dao_classes={ @@ -88,12 +91,57 @@ def get_instance_metadata_resource() -> str: logger=logger, ) - # Use the shared core's resource method - return instance_info_core.get_resource() + # Get base instance info + base_result = json.loads(instance_info_core.get_resource()) + + # Remove empty popular_content if it has no useful data + popular = base_result.get("popular_content", {}) + if popular and not any(popular.get(k) for k in popular): + del base_result["popular_content"] + + # Add available datasets (top 20 by most recent modification) + dataset_dao = instance_info_core.dao_classes["datasets"] + try: + datasets = dataset_dao.find_all() + # Convert to string to avoid TypeError when comparing datetime with None + sorted_datasets = sorted( + datasets, + key=lambda d: str(getattr(d, "changed_on", "") or ""), + reverse=True, + )[:20] + base_result["available_datasets"] = [ + { + "id": ds.id, + "table_name": ds.table_name, + "schema": getattr(ds, "schema", None), + "database_id": getattr(ds, "database_id", None), + } + for ds in sorted_datasets + ] + except (SQLAlchemyError, AttributeError) as e: + logger.warning("Could not fetch datasets for metadata: %s", e) + base_result["available_datasets"] = [] + + # Add available databases (for execute_sql) + database_dao = instance_info_core.dao_classes["databases"] + try: + databases = database_dao.find_all() + base_result["available_databases"] = [ + { + "id": db.id, + "database_name": db.database_name, + "backend": getattr(db, "backend", None), + } + for db in databases + ] + except (SQLAlchemyError, AttributeError) as e: + logger.warning("Could not fetch databases for metadata: %s", e) + base_result["available_databases"] = [] + + return json.dumps(base_result, indent=2) - except Exception as e: + except (SQLAlchemyError, AttributeError, KeyError, ValueError) as e: logger.error("Error generating instance metadata: %s", e) - # Return minimal metadata on error from superset.utils import json return json.dumps( diff --git 
a/superset/utils/oauth2.py b/superset/utils/oauth2.py index cd1a2a14d9e5..0124f5730875 100644 --- a/superset/utils/oauth2.py +++ b/superset/utils/oauth2.py @@ -17,6 +17,7 @@ from __future__ import annotations +import logging from contextlib import contextmanager from datetime import datetime, timedelta, timezone from typing import Any, Iterator, TYPE_CHECKING @@ -37,6 +38,8 @@ JWT_EXPIRATION = timedelta(minutes=5) +logger = logging.getLogger(__name__) + @backoff.on_exception( backoff.expo, @@ -96,10 +99,28 @@ def refresh_oauth2_token( user_id=user_id, database_id=database_id, ): - token_response = db_engine_spec.get_oauth2_fresh_token( - config, - token.refresh_token, - ) + try: + token_response = db_engine_spec.get_oauth2_fresh_token( + config, + token.refresh_token, + ) + except db_engine_spec.oauth2_exception: + # OAuth token is no longer valid, delete it and start OAuth2 dance + logger.warning( + "OAuth2 token refresh failed for user=%s db=%s, deleting invalid token", + user_id, + database_id, + ) + db.session.delete(token) + raise + except Exception: + # non-OAuth related failure, log the exception + logger.warning( + "OAuth2 token refresh failed for user=%s db=%s", + user_id, + database_id, + ) + raise # store new access token; note that the refresh token might be revoked, in which # case there would be no access token in the response diff --git a/tests/unit_tests/db_engine_specs/test_aurora.py b/tests/unit_tests/db_engine_specs/test_aurora.py new file mode 100644 index 000000000000..9b979c278bd6 --- /dev/null +++ b/tests/unit_tests/db_engine_specs/test_aurora.py @@ -0,0 +1,317 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
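The Aurora tests below all exercise the same `aws_iam` block stored in a database's `encrypted_extra`. A rough sketch of that payload (the values mirror the placeholders used in the tests, not a real account or role):

```json
{
  "aws_iam": {
    "enabled": true,
    "role_arn": "arn:aws:iam::123456789012:role/TestRole",
    "external_id": "superset-prod-12345",
    "region": "us-east-1",
    "db_username": "superset_iam_user",
    "session_duration": 3600
  }
}
```

The engine spec pops this key before merging the remaining `encrypted_extra` entries into the engine parameters, and the `AWS_DATABASE_IAM_AUTH` feature flag must be enabled for the authentication to be applied.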
+ +# pylint: disable=import-outside-toplevel + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from superset.utils import json +from tests.unit_tests.conftest import with_feature_flags + + +def test_aurora_postgres_engine_spec_properties() -> None: + from superset.db_engine_specs.aurora import AuroraPostgresEngineSpec + + assert AuroraPostgresEngineSpec.engine == "postgresql" + assert AuroraPostgresEngineSpec.engine_name == "Aurora PostgreSQL" + assert AuroraPostgresEngineSpec.default_driver == "psycopg2" + + +def test_update_params_from_encrypted_extra_without_iam() -> None: + from superset.db_engine_specs.postgres import PostgresEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps({}) + database.sqlalchemy_uri_decrypted = ( + "postgresql://user:password@mydb.us-east-1.rds.amazonaws.com:5432/mydb" + ) + + params: dict[str, Any] = {} + PostgresEngineSpec.update_params_from_encrypted_extra(database, params) + + # No modifications should be made + assert params == {} + + +def test_update_params_from_encrypted_extra_iam_disabled() -> None: + from superset.db_engine_specs.postgres import PostgresEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": False, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_user", + } + } + ) + database.sqlalchemy_uri_decrypted = ( + "postgresql://user:password@mydb.us-east-1.rds.amazonaws.com:5432/mydb" + ) + + params: dict[str, Any] = {} + PostgresEngineSpec.update_params_from_encrypted_extra(database, params) + + # No modifications should be made when IAM is disabled + assert params == {} + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_update_params_from_encrypted_extra_with_iam() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + from superset.db_engine_specs.postgres import PostgresEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_iam_user", + } + } + ) + database.sqlalchemy_uri_decrypted = ( + "postgresql://user@mydb.cluster-xyz.us-east-1.rds.amazonaws.com:5432/mydb" + ) + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ), + patch.object( + AWSIAMAuthMixin, + "generate_rds_auth_token", + return_value="iam-auth-token", + ), + ): + PostgresEngineSpec.update_params_from_encrypted_extra(database, params) + + assert "connect_args" in params + assert params["connect_args"]["password"] == "iam-auth-token" # noqa: S105 + assert params["connect_args"]["user"] == "superset_iam_user" + assert params["connect_args"]["sslmode"] == "require" + + +def test_update_params_merges_remaining_encrypted_extra() -> None: + from superset.db_engine_specs.postgres import PostgresEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": {"enabled": False}, + "pool_size": 10, + } + ) + database.sqlalchemy_uri_decrypted = ( + "postgresql://user:password@mydb.us-east-1.rds.amazonaws.com:5432/mydb" + ) + + params: dict[str, Any] = {} + PostgresEngineSpec.update_params_from_encrypted_extra(database, params) + + # aws_iam should be consumed, pool_size should be 
merged + assert "aws_iam" not in params + assert params["pool_size"] == 10 + + +def test_update_params_from_encrypted_extra_no_encrypted_extra() -> None: + from superset.db_engine_specs.postgres import PostgresEngineSpec + + database = MagicMock() + database.encrypted_extra = None + + params: dict[str, Any] = {} + PostgresEngineSpec.update_params_from_encrypted_extra(database, params) + + # No modifications should be made + assert params == {} + + +def test_update_params_from_encrypted_extra_invalid_json() -> None: + from superset.db_engine_specs.postgres import PostgresEngineSpec + + database = MagicMock() + database.encrypted_extra = "not-valid-json" + + params: dict[str, Any] = {} + + with pytest.raises(json.JSONDecodeError): + PostgresEngineSpec.update_params_from_encrypted_extra(database, params) + + +def test_encrypted_extra_sensitive_fields() -> None: + from superset.db_engine_specs.postgres import PostgresEngineSpec + + # Verify sensitive fields are properly defined + assert ( + "$.aws_iam.external_id" in PostgresEngineSpec.encrypted_extra_sensitive_fields + ) + assert "$.aws_iam.role_arn" in PostgresEngineSpec.encrypted_extra_sensitive_fields + + +def test_mask_encrypted_extra() -> None: + from superset.db_engine_specs.postgres import PostgresEngineSpec + + encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/SecretRole", + "external_id": "secret-external-id-12345", + "region": "us-east-1", + "db_username": "superset_user", + } + } + ) + + masked = PostgresEngineSpec.mask_encrypted_extra(encrypted_extra) + assert masked is not None + + masked_config = json.loads(masked) + + # role_arn and external_id should be masked + assert ( + masked_config["aws_iam"]["role_arn"] + != "arn:aws:iam::123456789012:role/SecretRole" + ) + assert masked_config["aws_iam"]["external_id"] != "secret-external-id-12345" + + # Non-sensitive fields should remain unchanged + assert masked_config["aws_iam"]["enabled"] is True + assert masked_config["aws_iam"]["region"] == "us-east-1" + assert masked_config["aws_iam"]["db_username"] == "superset_user" + + +def test_aurora_postgres_inherits_from_postgres() -> None: + from superset.db_engine_specs.aurora import AuroraPostgresEngineSpec + from superset.db_engine_specs.postgres import PostgresEngineSpec + + # Verify inheritance + assert issubclass(AuroraPostgresEngineSpec, PostgresEngineSpec) + + # Verify it inherits PostgreSQL capabilities + assert AuroraPostgresEngineSpec.supports_dynamic_schema is True + assert AuroraPostgresEngineSpec.supports_catalog is True + + +def test_aurora_mysql_engine_spec_properties() -> None: + from superset.db_engine_specs.aurora import AuroraMySQLEngineSpec + + assert AuroraMySQLEngineSpec.engine == "mysql" + assert AuroraMySQLEngineSpec.engine_name == "Aurora MySQL" + assert AuroraMySQLEngineSpec.default_driver == "mysqldb" + + +def test_aurora_mysql_inherits_from_mysql() -> None: + from superset.db_engine_specs.aurora import AuroraMySQLEngineSpec + from superset.db_engine_specs.mysql import MySQLEngineSpec + + assert issubclass(AuroraMySQLEngineSpec, MySQLEngineSpec) + assert AuroraMySQLEngineSpec.supports_dynamic_schema is True + + +def test_aurora_mysql_has_iam_support() -> None: + from superset.db_engine_specs.aurora import AuroraMySQLEngineSpec + + # Verify it inherits encrypted_extra_sensitive_fields + assert ( + "$.aws_iam.external_id" + in AuroraMySQLEngineSpec.encrypted_extra_sensitive_fields + ) + assert ( + "$.aws_iam.role_arn" in 
AuroraMySQLEngineSpec.encrypted_extra_sensitive_fields + ) + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_aurora_mysql_update_params_from_encrypted_extra_with_iam() -> None: + from superset.db_engine_specs.aurora import AuroraMySQLEngineSpec + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_iam_user", + } + } + ) + database.sqlalchemy_uri_decrypted = ( + "mysql://user@mydb.cluster-xyz.us-east-1.rds.amazonaws.com:3306/mydb" + ) + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ), + patch.object( + AWSIAMAuthMixin, + "generate_rds_auth_token", + return_value="iam-auth-token", + ), + ): + AuroraMySQLEngineSpec.update_params_from_encrypted_extra(database, params) + + assert "connect_args" in params + assert params["connect_args"]["password"] == "iam-auth-token" # noqa: S105 + assert params["connect_args"]["user"] == "superset_iam_user" + # Note: ssl_mode is not set because MySQL drivers don't support it. + # SSL should be configured via the database's extra settings. + + +def test_aurora_data_api_classes_unchanged() -> None: + from superset.db_engine_specs.aurora import ( + AuroraMySQLDataAPI, + AuroraPostgresDataAPI, + ) + + # Verify Data API classes are still available and unchanged + assert AuroraMySQLDataAPI.engine == "mysql" + assert AuroraMySQLDataAPI.default_driver == "auroradataapi" + assert AuroraMySQLDataAPI.engine_name == "Aurora MySQL (Data API)" + + assert AuroraPostgresDataAPI.engine == "postgresql" + assert AuroraPostgresDataAPI.default_driver == "auroradataapi" + assert AuroraPostgresDataAPI.engine_name == "Aurora PostgreSQL (Data API)" diff --git a/tests/unit_tests/db_engine_specs/test_aws_iam.py b/tests/unit_tests/db_engine_specs/test_aws_iam.py new file mode 100644 index 000000000000..602bd76f68fd --- /dev/null +++ b/tests/unit_tests/db_engine_specs/test_aws_iam.py @@ -0,0 +1,1045 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
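The tests below mock the two boto3 calls behind IAM authentication: an STS `AssumeRole` to obtain temporary credentials, followed by signing a short-lived RDS auth token with those credentials. A minimal sketch of that flow, using the same placeholder role, region, host, and username as the tests (an illustration of the mocked calls, not the `AWSIAMAuthMixin` implementation itself):

```python
import boto3

# Assume the configured (possibly cross-account) role via STS to obtain
# temporary credentials.
sts = boto3.client("sts", region_name="us-east-1")
creds = sts.assume_role(
    RoleArn="arn:aws:iam::123456789012:role/TestRole",
    RoleSessionName="superset-iam-session",
    DurationSeconds=3600,
)["Credentials"]

# Sign a short-lived RDS/Aurora authentication token with those credentials.
rds = boto3.client(
    "rds",
    region_name="us-east-1",
    aws_access_key_id=creds["AccessKeyId"],
    aws_secret_access_key=creds["SecretAccessKey"],
    aws_session_token=creds["SessionToken"],
)
token = rds.generate_db_auth_token(
    DBHostname="mydb.cluster-xyz.us-east-1.rds.amazonaws.com",
    Port=5432,
    DBUsername="superset_iam_user",
)
# The token is then passed to the driver as the connection password via
# connect_args, as the assertions below verify.
```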
+ +# pylint: disable=import-outside-toplevel, protected-access + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from superset.exceptions import SupersetSecurityException +from tests.unit_tests.conftest import with_feature_flags + + +def test_get_iam_credentials_success() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + mock_credentials = { + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + "Expiration": "2025-01-01T00:00:00Z", + } + + with patch("boto3.client") as mock_boto3_client: + mock_sts = MagicMock() + mock_sts.assume_role.return_value = {"Credentials": mock_credentials} + mock_boto3_client.return_value = mock_sts + + credentials = AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::123456789012:role/TestRole", + region="us-east-1", + ) + + assert credentials == mock_credentials + mock_boto3_client.assert_called_once_with("sts", region_name="us-east-1") + mock_sts.assume_role.assert_called_once_with( + RoleArn="arn:aws:iam::123456789012:role/TestRole", + RoleSessionName="superset-iam-session", + DurationSeconds=3600, + ) + + +def test_get_iam_credentials_with_external_id() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + mock_credentials = { + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + } + + with patch("boto3.client") as mock_boto3_client: + mock_sts = MagicMock() + mock_sts.assume_role.return_value = {"Credentials": mock_credentials} + mock_boto3_client.return_value = mock_sts + + credentials = AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::123456789012:role/TestRole", + region="us-west-2", + external_id="external-id-12345", + session_duration=900, + ) + + assert credentials == mock_credentials + mock_sts.assume_role.assert_called_once_with( + RoleArn="arn:aws:iam::123456789012:role/TestRole", + RoleSessionName="superset-iam-session", + DurationSeconds=900, + ExternalId="external-id-12345", + ) + + +def test_get_iam_credentials_access_denied() -> None: + from botocore.exceptions import ClientError + + from superset.db_engine_specs.aws_iam import ( + _credentials_cache, + _credentials_lock, + AWSIAMAuthMixin, + ) + + with _credentials_lock: + _credentials_cache.clear() + + with patch("boto3.client") as mock_boto3_client: + mock_sts = MagicMock() + mock_sts.assume_role.side_effect = ClientError( + {"Error": {"Code": "AccessDenied", "Message": "Access Denied"}}, + "AssumeRole", + ) + mock_boto3_client.return_value = mock_sts + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::123456789012:role/TestRole", + region="us-east-1", + ) + + assert "Unable to assume IAM role" in str(exc_info.value) + + +def test_get_iam_credentials_external_id_mismatch() -> None: + from botocore.exceptions import ClientError + + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + with patch("boto3.client") as mock_boto3_client: + mock_sts = MagicMock() + mock_sts.assume_role.side_effect = ClientError( + { + "Error": { + "Code": "AccessDenied", + "Message": "The external id does not match", + } + }, + "AssumeRole", + ) + mock_boto3_client.return_value = mock_sts + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::123456789012:role/TestRole", + region="us-east-1", + external_id="wrong-id", + ) + + assert 
"External ID mismatch" in str(exc_info.value) + + +def test_generate_rds_auth_token() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + credentials = { + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + } + + with patch("boto3.client") as mock_boto3_client: + mock_rds = MagicMock() + mock_rds.generate_db_auth_token.return_value = "iam-token-12345" + mock_boto3_client.return_value = mock_rds + + token = AWSIAMAuthMixin.generate_rds_auth_token( + credentials=credentials, + hostname="mydb.cluster-xyz.us-east-1.rds.amazonaws.com", + port=5432, + username="superset_user", + region="us-east-1", + ) + + assert token == "iam-token-12345" # noqa: S105 + mock_boto3_client.assert_called_once_with( + "rds", + region_name="us-east-1", + aws_access_key_id="ASIA...", + aws_secret_access_key="secret...", # noqa: S106 + aws_session_token="token...", # noqa: S106 + ) + mock_rds.generate_db_auth_token.assert_called_once_with( + DBHostname="mydb.cluster-xyz.us-east-1.rds.amazonaws.com", + Port=5432, + DBUsername="superset_user", + ) + + +def test_apply_iam_authentication_feature_flag_disabled() -> None: + """Test that IAM auth is blocked when feature flag is disabled.""" + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = ( + "postgresql://user@mydb.cluster-xyz.us-east-1.rds.amazonaws.com:5432/mydb" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_iam_user", + } + + params: dict[str, Any] = {} + + # Feature flag is disabled by default + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin._apply_iam_authentication( + mock_database, + params, + iam_config, + ) + + assert "AWS IAM database authentication is not enabled" in str(exc_info.value) + assert "AWS_DATABASE_IAM_AUTH" in str(exc_info.value) + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_iam_authentication() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = ( + "postgresql://user@mydb.cluster-xyz.us-east-1.rds.amazonaws.com:5432/mydb" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_iam_user", + } + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ) as mock_get_creds, + patch.object( + AWSIAMAuthMixin, + "generate_rds_auth_token", + return_value="iam-auth-token", + ) as mock_gen_token, + ): + AWSIAMAuthMixin._apply_iam_authentication(mock_database, params, iam_config) + + mock_get_creds.assert_called_once_with( + role_arn="arn:aws:iam::123456789012:role/TestRole", + region="us-east-1", + external_id=None, + session_duration=3600, + ) + + mock_gen_token.assert_called_once() + token_call_kwargs = mock_gen_token.call_args[1] + assert ( + token_call_kwargs["hostname"] == "mydb.cluster-xyz.us-east-1.rds.amazonaws.com" + ) + assert token_call_kwargs["port"] == 5432 + assert token_call_kwargs["username"] == "superset_iam_user" + + assert params["connect_args"]["password"] == "iam-auth-token" # noqa: S105 + assert params["connect_args"]["user"] == 
"superset_iam_user" + assert params["connect_args"]["sslmode"] == "require" + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_iam_authentication_with_external_id() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = ( + "postgresql://user@mydb.us-west-2.rds.amazonaws.com:5432/mydb" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::222222222222:role/CrossAccountRole", + "external_id": "superset-prod-12345", + "region": "us-west-2", + "db_username": "iam_user", + "session_duration": 1800, + } + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ) as mock_get_creds, + patch.object( + AWSIAMAuthMixin, + "generate_rds_auth_token", + return_value="iam-auth-token", + ), + ): + AWSIAMAuthMixin._apply_iam_authentication(mock_database, params, iam_config) + + mock_get_creds.assert_called_once_with( + role_arn="arn:aws:iam::222222222222:role/CrossAccountRole", + region="us-west-2", + external_id="superset-prod-12345", + session_duration=1800, + ) + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_iam_authentication_missing_role_arn() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = ( + "postgresql://user@mydb.us-east-1.rds.amazonaws.com:5432/mydb" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "region": "us-east-1", + "db_username": "superset_iam_user", + } + + params: dict[str, Any] = {} + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin._apply_iam_authentication(mock_database, params, iam_config) + + assert "role_arn" in str(exc_info.value) + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_iam_authentication_missing_db_username() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = ( + "postgresql://user@mydb.us-east-1.rds.amazonaws.com:5432/mydb" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + } + + params: dict[str, Any] = {} + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin._apply_iam_authentication(mock_database, params, iam_config) + + assert "db_username" in str(exc_info.value) + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_iam_authentication_default_port() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + # URI without explicit port + mock_database.sqlalchemy_uri_decrypted = ( + "postgresql://user@mydb.us-east-1.rds.amazonaws.com/mydb" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_iam_user", + } + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ), + patch.object( + AWSIAMAuthMixin, + "generate_rds_auth_token", + return_value="iam-auth-token", + ) as mock_gen_token, + ): + 
AWSIAMAuthMixin._apply_iam_authentication(mock_database, params, iam_config) + + # Should use default port 5432 + token_call_kwargs = mock_gen_token.call_args[1] + assert token_call_kwargs["port"] == 5432 + + +def test_get_iam_credentials_boto3_not_installed() -> None: + import builtins + + from superset.db_engine_specs.aws_iam import ( + _credentials_cache, + _credentials_lock, + AWSIAMAuthMixin, + ) + + with _credentials_lock: + _credentials_cache.clear() + + # Patch the import mechanism to simulate boto3 not being installed + real_import = builtins.__import__ + + def fake_import(name: str, *args: Any, **kwargs: Any) -> Any: + if name == "boto3" or name.startswith("boto3."): + raise ImportError("No module named 'boto3'") + return real_import(name, *args, **kwargs) + + with patch.object(builtins, "__import__", side_effect=fake_import): + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::123456789012:role/TestRole", + region="us-east-1", + ) + + assert "boto3 is required" in str(exc_info.value) + + +def test_get_iam_credentials_caching() -> None: + from superset.db_engine_specs.aws_iam import ( + _credentials_cache, + _credentials_lock, + AWSIAMAuthMixin, + ) + + mock_credentials = { + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + } + + # Clear cache before test + with _credentials_lock: + _credentials_cache.clear() + + with patch("boto3.client") as mock_boto3_client: + mock_sts = MagicMock() + mock_sts.assume_role.return_value = {"Credentials": mock_credentials} + mock_boto3_client.return_value = mock_sts + + # First call should hit STS + result1 = AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::123456789012:role/CachedRole", + region="us-east-1", + ) + + # Second call should use cache + result2 = AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::123456789012:role/CachedRole", + region="us-east-1", + ) + + assert result1 == mock_credentials + assert result2 == mock_credentials + # STS should only be called once + mock_sts.assume_role.assert_called_once() + + # Clean up + with _credentials_lock: + _credentials_cache.clear() + + +def test_get_iam_credentials_cache_different_keys() -> None: + from superset.db_engine_specs.aws_iam import ( + _credentials_cache, + _credentials_lock, + AWSIAMAuthMixin, + ) + + creds_role1 = { + "AccessKeyId": "ASIA_ROLE1", + "SecretAccessKey": "secret1", + "SessionToken": "token1", + } + creds_role2 = { + "AccessKeyId": "ASIA_ROLE2", + "SecretAccessKey": "secret2", + "SessionToken": "token2", + } + + # Clear cache before test + with _credentials_lock: + _credentials_cache.clear() + + with patch("boto3.client") as mock_boto3_client: + mock_sts = MagicMock() + mock_sts.assume_role.side_effect = [ + {"Credentials": creds_role1}, + {"Credentials": creds_role2}, + ] + mock_boto3_client.return_value = mock_sts + + result1 = AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::111111111111:role/Role1", + region="us-east-1", + ) + result2 = AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::222222222222:role/Role2", + region="us-east-1", + ) + + assert result1 == creds_role1 + assert result2 == creds_role2 + # Both calls should hit STS (different cache keys) + assert mock_sts.assume_role.call_count == 2 + + # Clean up + with _credentials_lock: + _credentials_cache.clear() + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_iam_authentication_custom_ssl_args() -> None: + from 
superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = ( + "mysql://user@mydb.cluster-xyz.us-east-1.rds.amazonaws.com:3306/mydb" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_iam_user", + } + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ), + patch.object( + AWSIAMAuthMixin, + "generate_rds_auth_token", + return_value="iam-auth-token", + ), + ): + AWSIAMAuthMixin._apply_iam_authentication( + mock_database, + params, + iam_config, + ssl_args={"ssl_mode": "REQUIRED"}, + default_port=3306, + ) + + assert params["connect_args"]["ssl_mode"] == "REQUIRED" + assert "sslmode" not in params["connect_args"] + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_iam_authentication_custom_default_port() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + # URI without explicit port + mock_database.sqlalchemy_uri_decrypted = ( + "mysql://user@mydb.cluster-xyz.us-east-1.rds.amazonaws.com/mydb" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_iam_user", + } + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ), + patch.object( + AWSIAMAuthMixin, + "generate_rds_auth_token", + return_value="iam-auth-token", + ) as mock_gen_token, + ): + AWSIAMAuthMixin._apply_iam_authentication( + mock_database, + params, + iam_config, + default_port=3306, + ) + + token_call_kwargs = mock_gen_token.call_args[1] + assert token_call_kwargs["port"] == 3306 + + +def test_generate_redshift_credentials() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + credentials = { + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + } + + with patch("boto3.client") as mock_boto3_client: + mock_redshift = MagicMock() + mock_redshift.get_credentials.return_value = { + "dbUser": "IAM:admin", + "dbPassword": "redshift-temp-password", + } + mock_boto3_client.return_value = mock_redshift + + db_user, db_password = AWSIAMAuthMixin.generate_redshift_credentials( + credentials=credentials, + workgroup_name="my-workgroup", + db_name="dev", + region="us-east-1", + ) + + assert db_user == "IAM:admin" + assert db_password == "redshift-temp-password" # noqa: S105 + mock_boto3_client.assert_called_once_with( + "redshift-serverless", + region_name="us-east-1", + aws_access_key_id="ASIA...", + aws_secret_access_key="secret...", # noqa: S106 + aws_session_token="token...", # noqa: S106 + ) + mock_redshift.get_credentials.assert_called_once_with( + workgroupName="my-workgroup", + dbName="dev", + ) + + +def test_generate_redshift_credentials_client_error() -> None: + from botocore.exceptions import ClientError + + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + credentials = { + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + } + + with patch("boto3.client") as mock_boto3_client: + mock_redshift = 
MagicMock() + mock_redshift.get_credentials.side_effect = ClientError( + {"Error": {"Code": "AccessDenied", "Message": "Access Denied"}}, + "GetCredentials", + ) + mock_boto3_client.return_value = mock_redshift + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin.generate_redshift_credentials( + credentials=credentials, + workgroup_name="my-workgroup", + db_name="dev", + region="us-east-1", + ) + + assert "Failed to get Redshift Serverless credentials" in str(exc_info.value) + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_redshift_iam_authentication() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = ( + "redshift+psycopg2://user@my-workgroup.123456789012.us-east-1" + ".redshift-serverless.amazonaws.com:5439/dev" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/RedshiftRole", + "region": "us-east-1", + "workgroup_name": "my-workgroup", + "db_name": "dev", + } + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ) as mock_get_creds, + patch.object( + AWSIAMAuthMixin, + "generate_redshift_credentials", + return_value=("IAM:admin", "redshift-temp-password"), + ) as mock_gen_creds, + ): + AWSIAMAuthMixin._apply_redshift_iam_authentication( + mock_database, params, iam_config + ) + + mock_get_creds.assert_called_once_with( + role_arn="arn:aws:iam::123456789012:role/RedshiftRole", + region="us-east-1", + external_id=None, + session_duration=3600, + ) + + mock_gen_creds.assert_called_once_with( + credentials={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + workgroup_name="my-workgroup", + db_name="dev", + region="us-east-1", + ) + + assert params["connect_args"]["password"] == "redshift-temp-password" # noqa: S105 + assert params["connect_args"]["user"] == "IAM:admin" + assert params["connect_args"]["sslmode"] == "verify-ca" + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_redshift_iam_authentication_missing_workgroup() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = "redshift+psycopg2://user@host:5439/dev" + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/RedshiftRole", + "region": "us-east-1", + "db_name": "dev", + } + + params: dict[str, Any] = {} + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin._apply_redshift_iam_authentication( + mock_database, params, iam_config + ) + + assert "workgroup_name" in str(exc_info.value) + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_redshift_iam_authentication_missing_db_name() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = "redshift+psycopg2://user@host:5439/dev" + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/RedshiftRole", + "region": "us-east-1", + "workgroup_name": "my-workgroup", + } + + params: dict[str, Any] = {} + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin._apply_redshift_iam_authentication( + 
mock_database, params, iam_config + ) + + assert "db_name" in str(exc_info.value) + + +def test_generate_redshift_cluster_credentials() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + credentials = { + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + } + + with patch("boto3.client") as mock_boto3_client: + mock_redshift = MagicMock() + mock_redshift.get_cluster_credentials.return_value = { + "DbUser": "IAM:superset_user", + "DbPassword": "redshift-cluster-temp-password", + } + mock_boto3_client.return_value = mock_redshift + + db_user, db_password = AWSIAMAuthMixin.generate_redshift_cluster_credentials( + credentials=credentials, + cluster_identifier="my-redshift-cluster", + db_user="superset_user", + db_name="analytics", + region="us-east-1", + ) + + assert db_user == "IAM:superset_user" + assert db_password == "redshift-cluster-temp-password" # noqa: S105 + mock_boto3_client.assert_called_once_with( + "redshift", + region_name="us-east-1", + aws_access_key_id="ASIA...", + aws_secret_access_key="secret...", # noqa: S106 + aws_session_token="token...", # noqa: S106 + ) + mock_redshift.get_cluster_credentials.assert_called_once_with( + ClusterIdentifier="my-redshift-cluster", + DbUser="superset_user", + DbName="analytics", + AutoCreate=False, + ) + + +def test_generate_redshift_cluster_credentials_with_auto_create() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + credentials = { + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + } + + with patch("boto3.client") as mock_boto3_client: + mock_redshift = MagicMock() + mock_redshift.get_cluster_credentials.return_value = { + "DbUser": "IAM:new_user", + "DbPassword": "temp-password", + } + mock_boto3_client.return_value = mock_redshift + + AWSIAMAuthMixin.generate_redshift_cluster_credentials( + credentials=credentials, + cluster_identifier="my-cluster", + db_user="new_user", + db_name="dev", + region="us-west-2", + auto_create=True, + ) + + mock_redshift.get_cluster_credentials.assert_called_once_with( + ClusterIdentifier="my-cluster", + DbUser="new_user", + DbName="dev", + AutoCreate=True, + ) + + +def test_generate_redshift_cluster_credentials_client_error() -> None: + from botocore.exceptions import ClientError + + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + credentials = { + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + } + + with patch("boto3.client") as mock_boto3_client: + mock_redshift = MagicMock() + mock_redshift.get_cluster_credentials.side_effect = ClientError( + {"Error": {"Code": "ClusterNotFound", "Message": "Cluster not found"}}, + "GetClusterCredentials", + ) + mock_boto3_client.return_value = mock_redshift + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin.generate_redshift_cluster_credentials( + credentials=credentials, + cluster_identifier="nonexistent-cluster", + db_user="superset_user", + db_name="dev", + region="us-east-1", + ) + + assert "Failed to get Redshift cluster credentials" in str(exc_info.value) + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_redshift_iam_authentication_provisioned_cluster() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = ( + "redshift+psycopg2://user@my-cluster.abc123.us-east-1" + ".redshift.amazonaws.com:5439/analytics" + ) + + 
iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/RedshiftRole", + "region": "us-east-1", + "cluster_identifier": "my-cluster", + "db_username": "superset_user", + "db_name": "analytics", + } + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ) as mock_get_creds, + patch.object( + AWSIAMAuthMixin, + "generate_redshift_cluster_credentials", + return_value=("IAM:superset_user", "cluster-temp-password"), + ) as mock_gen_creds, + ): + AWSIAMAuthMixin._apply_redshift_iam_authentication( + mock_database, params, iam_config + ) + + mock_get_creds.assert_called_once_with( + role_arn="arn:aws:iam::123456789012:role/RedshiftRole", + region="us-east-1", + external_id=None, + session_duration=3600, + ) + + mock_gen_creds.assert_called_once_with( + credentials={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + cluster_identifier="my-cluster", + db_user="superset_user", + db_name="analytics", + region="us-east-1", + ) + + assert params["connect_args"]["password"] == "cluster-temp-password" # noqa: S105 + assert params["connect_args"]["user"] == "IAM:superset_user" + assert params["connect_args"]["sslmode"] == "verify-ca" + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_redshift_iam_authentication_provisioned_missing_db_username() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = "redshift+psycopg2://user@host:5439/dev" + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/RedshiftRole", + "region": "us-east-1", + "cluster_identifier": "my-cluster", + "db_name": "dev", + # Missing db_username - required for provisioned clusters + } + + params: dict[str, Any] = {} + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin._apply_redshift_iam_authentication( + mock_database, params, iam_config + ) + + assert "db_username" in str(exc_info.value) + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_redshift_iam_authentication_both_workgroup_and_cluster() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = "redshift+psycopg2://user@host:5439/dev" + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/RedshiftRole", + "region": "us-east-1", + "workgroup_name": "my-workgroup", + "cluster_identifier": "my-cluster", + "db_name": "dev", + } + + params: dict[str, Any] = {} + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin._apply_redshift_iam_authentication( + mock_database, params, iam_config + ) + + assert "cannot have both" in str(exc_info.value) + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_redshift_iam_authentication_neither_workgroup_nor_cluster() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = "redshift+psycopg2://user@host:5439/dev" + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/RedshiftRole", + "region": "us-east-1", + "db_name": "dev", + # Missing both workgroup_name and 
cluster_identifier + } + + params: dict[str, Any] = {} + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin._apply_redshift_iam_authentication( + mock_database, params, iam_config + ) + + assert "must include either workgroup_name" in str(exc_info.value) diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index bff4c9311710..ccfe2b337f77 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -23,18 +23,27 @@ import re from textwrap import dedent from typing import Any +from urllib.parse import parse_qs, urlparse import pytest from pytest_mock import MockerFixture -from sqlalchemy import types +from sqlalchemy import Boolean, Column, Integer, types from sqlalchemy.dialects import sqlite from sqlalchemy.engine.url import make_url, URL from sqlalchemy.sql import sqltypes +from superset.db_engine_specs.base import BaseEngineSpec, convert_inspector_columns from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.exceptions import OAuth2RedirectError from superset.sql.parse import Table -from superset.superset_typing import ResultSetColumnType, SQLAColumnType -from superset.utils.core import GenericDataType +from superset.superset_typing import ( + OAuth2ClientConfig, + OAuth2State, + ResultSetColumnType, + SQLAColumnType, +) +from superset.utils.core import FilterOperator, GenericDataType +from superset.utils.oauth2 import decode_oauth2_state from tests.unit_tests.db_engine_specs.utils import assert_column_spec @@ -68,9 +77,6 @@ def test_get_text_clause_with_colon() -> None: """ Make sure text clauses are correctly escaped """ - - from superset.db_engine_specs.base import BaseEngineSpec - text_clause = BaseEngineSpec.get_text_clause( "SELECT foo FROM tbl WHERE foo = '123:456')" ) @@ -90,8 +96,6 @@ def mock_validate(sqlalchemy_uri: URL) -> None: {"DB_SQLA_URI_VALIDATOR": mock_validate}, ) - from superset.db_engine_specs.base import BaseEngineSpec - with pytest.raises(ValueError): # noqa: PT011 BaseEngineSpec.validate_database_uri(URL.create("sqlite")) @@ -130,8 +134,6 @@ def mock_validate(sqlalchemy_uri: URL) -> None: ], ) def test_cte_query_parsing(original: types.TypeEngine, expected: str) -> None: - from superset.db_engine_specs.base import BaseEngineSpec - actual = BaseEngineSpec.get_cte_query(original) assert actual == expected @@ -197,8 +199,6 @@ def test_get_column_spec( def test_convert_inspector_columns( cols: list[SQLAColumnType], expected_result: list[ResultSetColumnType] ): - from superset.db_engine_specs.base import convert_inspector_columns - assert convert_inspector_columns(cols) == expected_result @@ -206,8 +206,6 @@ def test_select_star(mocker: MockerFixture) -> None: """ Test the ``select_star`` method. """ - from superset.db_engine_specs.base import BaseEngineSpec - cols: list[ResultSetColumnType] = [ { "column_name": "a", @@ -249,7 +247,6 @@ def test_extra_table_metadata(mocker: MockerFixture) -> None: """ Test the deprecated `extra_table_metadata` method. """ - from superset.db_engine_specs.base import BaseEngineSpec from superset.models.core import Database class ThirdPartyDBEngineSpec(BaseEngineSpec): @@ -285,8 +282,6 @@ def test_get_default_catalog(mocker: MockerFixture) -> None: """ Test the `get_default_catalog` method. 
""" - from superset.db_engine_specs.base import BaseEngineSpec - database = mocker.MagicMock() assert BaseEngineSpec.get_default_catalog(database) is None @@ -295,7 +290,6 @@ def test_quote_table() -> None: """ Test the `quote_table` function. """ - from superset.db_engine_specs.base import BaseEngineSpec dialect = sqlite.dialect() @@ -318,8 +312,6 @@ def test_mask_encrypted_extra() -> None: """ Test that the private key is masked when the database is edited. """ - from superset.db_engine_specs.base import BaseEngineSpec - config = json.dumps( { "foo": "bar", @@ -342,8 +334,6 @@ def test_unmask_encrypted_extra() -> None: """ Test that the private key can be reused from the previous `encrypted_extra`. """ - from superset.db_engine_specs.base import BaseEngineSpec - old = json.dumps( { "foo": "bar", @@ -375,8 +365,6 @@ def test_impersonate_user_backwards_compatible(mocker: MockerFixture) -> None: """ Test that the `impersonate_user` method calls the original methods it replaced. """ - from superset.db_engine_specs.base import BaseEngineSpec - database = mocker.MagicMock() url = make_url("sqlite://foo.db") new_url = make_url("sqlite://bar.db") @@ -417,8 +405,6 @@ def test_impersonate_user_no_database(mocker: MockerFixture) -> None: """ Test `impersonate_user` when `update_impersonation_config` has an old signature. """ - from superset.db_engine_specs.base import BaseEngineSpec - database = mocker.MagicMock() url = make_url("sqlite://foo.db") new_url = make_url("sqlite://bar.db") @@ -457,10 +443,6 @@ def test_handle_boolean_filter_default_behavior() -> None: """ Test that BaseEngineSpec uses IS operators for boolean filters by default. """ - from sqlalchemy import Boolean, Column - - from superset.db_engine_specs.base import BaseEngineSpec - # Create a mock SQLAlchemy column bool_col = Column("test_col", Boolean) @@ -479,9 +461,6 @@ def test_handle_boolean_filter_with_equality() -> None: """ Test that BaseEngineSpec can use equality operators when configured. """ - from sqlalchemy import Boolean, Column - - from superset.db_engine_specs.base import BaseEngineSpec # Create a test engine spec that uses equality class TestEngineSpec(BaseEngineSpec): @@ -502,15 +481,9 @@ def test_handle_null_filter() -> None: """ Test null/not null filter handling. """ - from sqlalchemy import Boolean, Column - - from superset.db_engine_specs.base import BaseEngineSpec - bool_col = Column("test_col", Boolean) # Test IS_NULL - use actual FilterOperator values - from superset.utils.core import FilterOperator - result_null = BaseEngineSpec.handle_null_filter(bool_col, FilterOperator.IS_NULL) assert hasattr(result_null, "left") assert hasattr(result_null, "right") @@ -531,15 +504,9 @@ def test_handle_comparison_filter() -> None: """ Test comparison filter handling for all operators. """ - from sqlalchemy import Column, Integer - - from superset.db_engine_specs.base import BaseEngineSpec - int_col = Column("test_col", Integer) # Test all comparison operators - use actual FilterOperator values - from superset.utils.core import FilterOperator - operators_and_values = [ (FilterOperator.EQUALS, 5), (FilterOperator.NOT_EQUALS, 5), @@ -563,8 +530,6 @@ def test_use_equality_for_boolean_filters_property() -> None: """ Test that BaseEngineSpec has the correct default value for boolean filter property. 
""" - from superset.db_engine_specs.base import BaseEngineSpec - # Default should be False (use IS operators) assert BaseEngineSpec.use_equality_for_boolean_filters is False @@ -573,9 +538,6 @@ def test_extract_errors(mocker: MockerFixture) -> None: """ Test that error is extracted correctly when no custom error message is provided. """ - - from superset.db_engine_specs.base import BaseEngineSpec - mocker.patch( "flask.current_app.config", {}, @@ -597,8 +559,6 @@ def test_extract_errors_from_config(mocker: MockerFixture) -> None: using database_name. """ - from superset.db_engine_specs.base import BaseEngineSpec - class TestEngineSpec(BaseEngineSpec): engine_name = "ExampleEngine" @@ -632,8 +592,6 @@ def test_extract_errors_only_to_specified_database(mocker: MockerFixture) -> Non Test that custom error messages are only applied to the specified database_name. """ - from superset.db_engine_specs.base import BaseEngineSpec - class TestEngineSpec(BaseEngineSpec): engine_name = "ExampleEngine" @@ -669,8 +627,6 @@ def test_extract_errors_from_config_with_regex(mocker: MockerFixture) -> None: and show_issue_info are extracted correctly from config. """ - from superset.db_engine_specs.base import BaseEngineSpec - class TestEngineSpec(BaseEngineSpec): engine_name = "ExampleEngine" @@ -740,7 +696,6 @@ def test_extract_errors_with_non_dict_custom_errors(mocker: MockerFixture): Test that extract_errors doesn't fail when custom database errors are in wrong format. """ - from superset.db_engine_specs.base import BaseEngineSpec class TestEngineSpec(BaseEngineSpec): engine_name = "ExampleEngine" @@ -765,7 +720,6 @@ def test_extract_errors_with_non_dict_engine_custom_errors(mocker: MockerFixture Test that extract_errors doesn't fail when database-specific custom errors are in wrong format. """ - from superset.db_engine_specs.base import BaseEngineSpec class TestEngineSpec(BaseEngineSpec): engine_name = "ExampleEngine" @@ -790,7 +744,6 @@ def test_extract_errors_with_empty_custom_error_message(mocker: MockerFixture): Test that when the custom error message is empty, the original error message is preserved. """ - from superset.db_engine_specs.base import BaseEngineSpec class TestEngineSpec(BaseEngineSpec): engine_name = "ExampleEngine" @@ -824,7 +777,6 @@ def test_extract_errors_matches_database_name_selection(mocker: MockerFixture) - """ Test that custom error messages are matched by database_name. """ - from superset.db_engine_specs.base import BaseEngineSpec class TestEngineSpec(BaseEngineSpec): engine_name = "ExampleEngine" @@ -866,7 +818,6 @@ def test_extract_errors_no_match_falls_back(mocker: MockerFixture) -> None: """ Test that when database_name has no match, the original error message is preserved. """ - from superset.db_engine_specs.base import BaseEngineSpec class TestEngineSpec(BaseEngineSpec): engine_name = "ExampleEngine" @@ -901,12 +852,6 @@ def test_get_oauth2_authorization_uri_standard_params(mocker: MockerFixture) -> Test that BaseEngineSpec.get_oauth2_authorization_uri uses standard OAuth 2.0 parameters only and does not include provider-specific params like prompt=consent. 
""" - from urllib.parse import parse_qs, urlparse - - from superset.db_engine_specs.base import BaseEngineSpec - from superset.superset_typing import OAuth2ClientConfig, OAuth2State - from superset.utils.oauth2 import decode_oauth2_state - config: OAuth2ClientConfig = { "id": "client-id", "secret": "client-secret", @@ -943,3 +888,81 @@ def test_get_oauth2_authorization_uri_standard_params(mocker: MockerFixture) -> assert "prompt" not in query assert "access_type" not in query assert "include_granted_scopes" not in query + + +def test_start_oauth2_dance_uses_config_redirect_uri(mocker: MockerFixture) -> None: + """ + Test that start_oauth2_dance uses DATABASE_OAUTH2_REDIRECT_URI config if set. + """ + custom_redirect_uri = "https://proxy.example.com/oauth2/" + + mocker.patch( + "flask.current_app.config", + { + "DATABASE_OAUTH2_REDIRECT_URI": custom_redirect_uri, + "SECRET_KEY": "test-secret-key", + "DATABASE_OAUTH2_JWT_ALGORITHM": "HS256", + }, + ) + + g = mocker.patch("superset.db_engine_specs.base.g") + g.user.id = 1 + + database = mocker.MagicMock() + database.id = 1 + database.get_oauth2_config.return_value = { + "id": "client-id", + "secret": "client-secret", + "scope": "read write", + "redirect_uri": "https://another-link.com", + "authorization_request_uri": "https://oauth.example.com/authorize", + "token_request_uri": "https://oauth.example.com/token", + } + + with pytest.raises(OAuth2RedirectError) as exc_info: + BaseEngineSpec.start_oauth2_dance(database) + + error = exc_info.value.error + + assert error.extra["redirect_uri"] == custom_redirect_uri + + +def test_start_oauth2_dance_falls_back_to_url_for(mocker: MockerFixture) -> None: + """ + Test that start_oauth2_dance falls back to url_for when no config is set. + """ + fallback_uri = "http://localhost:8088/api/v1/database/oauth2/" + + mocker.patch( + "flask.current_app.config", + { + "SECRET_KEY": "test-secret-key", + "DATABASE_OAUTH2_JWT_ALGORITHM": "HS256", + }, + ) + mocker.patch( + "superset.db_engine_specs.base.url_for", + return_value=fallback_uri, + ) + + g = mocker.patch("superset.db_engine_specs.base.g") + g.user.id = 1 + + database = mocker.MagicMock() + database.id = 1 + database.get_oauth2_config.return_value = { + "id": "client-id", + "secret": "client-secret", + "scope": "read write", + "redirect_uri": "https://another-link.com", + "authorization_request_uri": "https://oauth.example.com/authorize", + "token_request_uri": "https://oauth.example.com/token", + "request_content_type": "json", + } + + with pytest.raises(OAuth2RedirectError) as exc_info: + BaseEngineSpec.start_oauth2_dance(database) + + error = exc_info.value.error + + assert error.extra["redirect_uri"] == fallback_uri diff --git a/tests/unit_tests/db_engine_specs/test_gsheets.py b/tests/unit_tests/db_engine_specs/test_gsheets.py index 2ed796c32d24..a2406c18fc32 100644 --- a/tests/unit_tests/db_engine_specs/test_gsheets.py +++ b/tests/unit_tests/db_engine_specs/test_gsheets.py @@ -23,6 +23,8 @@ import pandas as pd import pytest from pytest_mock import MockerFixture +from requests.exceptions import HTTPError +from shillelagh.exceptions import UnauthenticatedError from sqlalchemy.engine.url import make_url from superset.errors import ErrorLevel, SupersetError, SupersetErrorType @@ -737,3 +739,253 @@ def test_update_params_from_encrypted_extra(mocker: MockerFixture) -> None: GSheetsEngineSpec.update_params_from_encrypted_extra(database, params) assert params == {"foo": "bar"} + + +def test_needs_oauth2_with_credentials_error(mocker: MockerFixture) -> None: 
+ """ + Test that needs_oauth2 returns True for google-auth credentials error. + + When a token is manually revoked on Google side, google-auth tries to + refresh credentials but fails with this message. + """ + from superset.db_engine_specs.gsheets import GSheetsEngineSpec + + g = mocker.patch("superset.db_engine_specs.gsheets.g") + g.user = mocker.MagicMock() + + ex = Exception("credentials do not contain the necessary fields") + assert GSheetsEngineSpec.needs_oauth2(ex) is True + + +def test_needs_oauth2_with_other_error(mocker: MockerFixture) -> None: + """ + Test that needs_oauth2 returns False for other errors. + """ + from superset.db_engine_specs.gsheets import GSheetsEngineSpec + + g = mocker.patch("superset.db_engine_specs.gsheets.g") + g.user = mocker.MagicMock() + + ex = Exception("Some other error") + assert GSheetsEngineSpec.needs_oauth2(ex) is False + + +def test_get_oauth2_fresh_token_success( + mocker: MockerFixture, + oauth2_config: OAuth2ClientConfig, +) -> None: + """ + Test that get_oauth2_fresh_token returns token on success. + """ + from superset.db_engine_specs.gsheets import GSheetsEngineSpec + + requests = mocker.patch("superset.db_engine_specs.base.requests") + requests.post().json.return_value = { + "access_token": "new-access-token", + "expires_in": 3600, + } + + result = GSheetsEngineSpec.get_oauth2_fresh_token(oauth2_config, "refresh-token") + assert result == { + "access_token": "new-access-token", + "expires_in": 3600, + } + + +def test_get_oauth2_fresh_token_invalid_grant( + mocker: MockerFixture, + oauth2_config: OAuth2ClientConfig, +) -> None: + """ + Test that get_oauth2_fresh_token raises UnauthenticatedError for invalid_grant. + + When a token is revoked on Google side, the refresh request returns 400 + with error=invalid_grant. + """ + from superset.db_engine_specs.gsheets import GSheetsEngineSpec + + mock_response = mocker.MagicMock() + mock_response.status_code = 400 + mock_response.json.return_value = { + "error": "invalid_grant", + "error_description": "Token has been expired or revoked.", + } + http_error = HTTPError() + http_error.response = mock_response + + requests = mocker.patch("superset.db_engine_specs.base.requests") + requests.post().raise_for_status.side_effect = http_error + + with pytest.raises(UnauthenticatedError): + GSheetsEngineSpec.get_oauth2_fresh_token(oauth2_config, "refresh-token") + + +def test_get_oauth2_fresh_token_other_http_error( + mocker: MockerFixture, + oauth2_config: OAuth2ClientConfig, +) -> None: + """ + Test that get_oauth2_fresh_token re-raises non-invalid_grant HTTP errors. + """ + from superset.db_engine_specs.gsheets import GSheetsEngineSpec + + mock_response = mocker.MagicMock() + mock_response.status_code = 500 + mock_response.json.return_value = {"error": "server_error"} + + http_error = HTTPError() + http_error.response = mock_response + + requests = mocker.patch("superset.db_engine_specs.base.requests") + requests.post().raise_for_status.side_effect = http_error + + with pytest.raises(HTTPError): + GSheetsEngineSpec.get_oauth2_fresh_token(oauth2_config, "refresh-token") + + +def test_get_table_names_triggers_oauth2_dance(mocker: MockerFixture) -> None: + """ + Test that get_table_names triggers OAuth2 dance when no token exists. 
+ """ + from superset.db_engine_specs.gsheets import GSheetsEngineSpec + + g = mocker.patch("superset.db_engine_specs.gsheets.g") + g.user.id = 1 + + get_oauth2_access_token = mocker.patch( + "superset.db_engine_specs.gsheets.get_oauth2_access_token", + return_value=None, + ) + + database = mocker.MagicMock() + database.id = 1 + database.is_oauth2_enabled.return_value = True + database.get_oauth2_config.return_value = {"id": "client-id"} + database.db_engine_spec = GSheetsEngineSpec + + inspector = mocker.MagicMock() + + GSheetsEngineSpec.get_table_names(database, inspector, None) + + database.start_oauth2_dance.assert_called_once() + get_oauth2_access_token.assert_called_once() + + +def test_get_table_names_does_not_trigger_oauth2_when_token_exists( + mocker: MockerFixture, +) -> None: + """ + Test that get_table_names does not trigger OAuth2 dance when token exists. + """ + from superset.db_engine_specs.gsheets import GSheetsEngineSpec + + g = mocker.patch("superset.db_engine_specs.gsheets.g") + g.user.id = 1 + + get_oauth2_access_token = mocker.patch( + "superset.db_engine_specs.gsheets.get_oauth2_access_token", + return_value="valid-token", + ) + + mocker.patch( + "superset.db_engine_specs.shillelagh.ShillelaghEngineSpec.get_table_names", + return_value={"sheet1", "sheet2"}, + ) + + database = mocker.MagicMock() + database.id = 1 + database.is_oauth2_enabled.return_value = True + database.get_oauth2_config.return_value = {"id": "client-id"} + database.db_engine_spec = GSheetsEngineSpec + + inspector = mocker.MagicMock() + + result = GSheetsEngineSpec.get_table_names(database, inspector, None) + + database.start_oauth2_dance.assert_not_called() + get_oauth2_access_token.assert_called_once() + assert result == {"sheet1", "sheet2"} + + +def test_validate_parameters_skips_oauth2_connections_with_parameters( + mocker: MockerFixture, +) -> None: + """ + Test that validate_parameters skips validation for OAuth2 connections. + + When oauth2_client_info is present in parameters, the validation should + skip URL checks since the user will authenticate via OAuth2. + """ + from superset.db_engine_specs.gsheets import ( + GSheetsEngineSpec, + GSheetsPropertiesType, + ) + + g = mocker.patch("superset.db_engine_specs.gsheets.g") + g.user.email = "admin@example.org" + + create_engine = mocker.patch("superset.db_engine_specs.gsheets.create_engine") + conn = create_engine.return_value.connect.return_value + results = conn.execute.return_value + results.fetchall.side_effect = ProgrammingError( + "The caller does not have permission" + ) + + properties: GSheetsPropertiesType = { + "parameters": { + "service_account_info": "", + "catalog": {}, + "oauth2_client_info": {"id": "client-id", "secret": "client-secret"}, + }, + "catalog": { + "sheet1": "https://docs.google.com/spreadsheets/d/1/edit", + }, + } + errors = GSheetsEngineSpec.validate_parameters(properties) + + assert errors == [] + conn.execute.assert_not_called() + + +def test_validate_parameters_skips_oauth2_connections_with_masked_encrypted_extra( + mocker: MockerFixture, +) -> None: + """ + Test validate_parameters skips validation for OAuth2 via masked_encrypted_extra. + + When oauth2_client_info is present in masked_encrypted_extra (used during + create/update), the validation should skip URL checks. 
+ """ + from superset.db_engine_specs.gsheets import ( + GSheetsEngineSpec, + GSheetsPropertiesType, + ) + + g = mocker.patch("superset.db_engine_specs.gsheets.g") + g.user.email = "admin@example.org" + + create_engine = mocker.patch("superset.db_engine_specs.gsheets.create_engine") + conn = create_engine.return_value.connect.return_value + results = conn.execute.return_value + results.fetchall.side_effect = ProgrammingError( + "The caller does not have permission" + ) + + properties: GSheetsPropertiesType = { + "parameters": { + "service_account_info": "", + "catalog": {}, + }, + "catalog": { + "sheet1": "https://docs.google.com/spreadsheets/d/1/edit", + }, + "masked_encrypted_extra": json.dumps( + { + "oauth2_client_info": {"id": "client-id", "secret": "XXXXXXXXXX"}, + } + ), + } + errors = GSheetsEngineSpec.validate_parameters(properties) + + assert errors == [] + conn.execute.assert_not_called() diff --git a/tests/unit_tests/db_engine_specs/test_mysql_iam.py b/tests/unit_tests/db_engine_specs/test_mysql_iam.py new file mode 100644 index 000000000000..9b5c25b53cf0 --- /dev/null +++ b/tests/unit_tests/db_engine_specs/test_mysql_iam.py @@ -0,0 +1,236 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# pylint: disable=import-outside-toplevel + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from superset.utils import json +from tests.unit_tests.conftest import with_feature_flags + + +def test_mysql_encrypted_extra_sensitive_fields() -> None: + from superset.db_engine_specs.mysql import MySQLEngineSpec + + assert "$.aws_iam.external_id" in MySQLEngineSpec.encrypted_extra_sensitive_fields + assert "$.aws_iam.role_arn" in MySQLEngineSpec.encrypted_extra_sensitive_fields + + +def test_mysql_update_params_no_encrypted_extra() -> None: + from superset.db_engine_specs.mysql import MySQLEngineSpec + + database = MagicMock() + database.encrypted_extra = None + + params: dict[str, Any] = {} + MySQLEngineSpec.update_params_from_encrypted_extra(database, params) + + assert params == {} + + +def test_mysql_update_params_empty_encrypted_extra() -> None: + from superset.db_engine_specs.mysql import MySQLEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps({}) + + params: dict[str, Any] = {} + MySQLEngineSpec.update_params_from_encrypted_extra(database, params) + + assert params == {} + + +def test_mysql_update_params_iam_disabled() -> None: + from superset.db_engine_specs.mysql import MySQLEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": False, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_user", + } + } + ) + + params: dict[str, Any] = {} + MySQLEngineSpec.update_params_from_encrypted_extra(database, params) + + assert params == {} + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_mysql_update_params_with_iam() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + from superset.db_engine_specs.mysql import MySQLEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_iam_user", + } + } + ) + database.sqlalchemy_uri_decrypted = ( + "mysql://user@mydb.cluster-xyz.us-east-1.rds.amazonaws.com:3306/mydb" + ) + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ), + patch.object( + AWSIAMAuthMixin, + "generate_rds_auth_token", + return_value="iam-auth-token", + ), + ): + MySQLEngineSpec.update_params_from_encrypted_extra(database, params) + + assert "connect_args" in params + assert params["connect_args"]["password"] == "iam-auth-token" # noqa: S105 + assert params["connect_args"]["user"] == "superset_iam_user" + # Note: ssl_mode is not set because MySQL drivers don't support it. + # SSL should be configured via the database's extra settings. 
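+    # As an illustration only (config shape assumed here, not asserted by this
+    # test), such an extra could look like:
+    #   {"engine_params": {"connect_args": {"ssl": {"ca": "/path/to/rds-ca-bundle.pem"}}}}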
+ + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_mysql_update_params_iam_uses_mysql_port() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + from superset.db_engine_specs.mysql import MySQLEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_iam_user", + } + } + ) + # URI without explicit port + database.sqlalchemy_uri_decrypted = ( + "mysql://user@mydb.cluster-xyz.us-east-1.rds.amazonaws.com/mydb" + ) + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ), + patch.object( + AWSIAMAuthMixin, + "generate_rds_auth_token", + return_value="iam-auth-token", + ) as mock_gen_token, + ): + MySQLEngineSpec.update_params_from_encrypted_extra(database, params) + + # Should use default MySQL port 3306 + token_call_kwargs = mock_gen_token.call_args[1] + assert token_call_kwargs["port"] == 3306 + + +def test_mysql_update_params_merges_remaining_encrypted_extra() -> None: + from superset.db_engine_specs.mysql import MySQLEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": {"enabled": False}, + "pool_size": 10, + } + ) + + params: dict[str, Any] = {} + MySQLEngineSpec.update_params_from_encrypted_extra(database, params) + + assert "aws_iam" not in params + assert params["pool_size"] == 10 + + +def test_mysql_update_params_invalid_json() -> None: + from superset.db_engine_specs.mysql import MySQLEngineSpec + + database = MagicMock() + database.encrypted_extra = "not-valid-json" + + params: dict[str, Any] = {} + + with pytest.raises(json.JSONDecodeError): + MySQLEngineSpec.update_params_from_encrypted_extra(database, params) + + +def test_mysql_mask_encrypted_extra() -> None: + from superset.db_engine_specs.mysql import MySQLEngineSpec + + encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/SecretRole", + "external_id": "secret-external-id-12345", + "region": "us-east-1", + "db_username": "superset_user", + } + } + ) + + masked = MySQLEngineSpec.mask_encrypted_extra(encrypted_extra) + assert masked is not None + + masked_config = json.loads(masked) + + # role_arn and external_id should be masked + assert ( + masked_config["aws_iam"]["role_arn"] + != "arn:aws:iam::123456789012:role/SecretRole" + ) + assert masked_config["aws_iam"]["external_id"] != "secret-external-id-12345" + + # Non-sensitive fields should remain unchanged + assert masked_config["aws_iam"]["enabled"] is True + assert masked_config["aws_iam"]["region"] == "us-east-1" + assert masked_config["aws_iam"]["db_username"] == "superset_user" diff --git a/tests/unit_tests/db_engine_specs/test_redshift_iam.py b/tests/unit_tests/db_engine_specs/test_redshift_iam.py new file mode 100644 index 000000000000..49657bad8911 --- /dev/null +++ b/tests/unit_tests/db_engine_specs/test_redshift_iam.py @@ -0,0 +1,387 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=import-outside-toplevel + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from superset.utils import json +from tests.unit_tests.conftest import with_feature_flags + + +def test_redshift_encrypted_extra_sensitive_fields() -> None: + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + assert ( + "$.aws_iam.external_id" in RedshiftEngineSpec.encrypted_extra_sensitive_fields + ) + assert "$.aws_iam.role_arn" in RedshiftEngineSpec.encrypted_extra_sensitive_fields + + +def test_redshift_update_params_no_encrypted_extra() -> None: + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + database = MagicMock() + database.encrypted_extra = None + + params: dict[str, Any] = {} + RedshiftEngineSpec.update_params_from_encrypted_extra(database, params) + + assert params == {} + + +def test_redshift_update_params_empty_encrypted_extra() -> None: + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps({}) + + params: dict[str, Any] = {} + RedshiftEngineSpec.update_params_from_encrypted_extra(database, params) + + assert params == {} + + +def test_redshift_update_params_iam_disabled() -> None: + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": False, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "workgroup_name": "my-workgroup", + "db_name": "dev", + } + } + ) + + params: dict[str, Any] = {} + RedshiftEngineSpec.update_params_from_encrypted_extra(database, params) + + assert params == {} + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_redshift_update_params_with_iam() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/RedshiftRole", + "region": "us-east-1", + "workgroup_name": "my-workgroup", + "db_name": "dev", + } + } + ) + database.sqlalchemy_uri_decrypted = ( + "redshift+psycopg2://user@my-workgroup.123456789012.us-east-1" + ".redshift-serverless.amazonaws.com:5439/dev" + ) + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ), + patch.object( + AWSIAMAuthMixin, + "generate_redshift_credentials", + return_value=("IAM:admin", "redshift-temp-password"), + ), + ): + RedshiftEngineSpec.update_params_from_encrypted_extra(database, params) + + assert "connect_args" in params + assert params["connect_args"]["password"] == "redshift-temp-password" # noqa: S105 + assert 
params["connect_args"]["user"] == "IAM:admin" + assert params["connect_args"]["sslmode"] == "verify-ca" + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_redshift_update_params_with_external_id() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::222222222222:role/CrossAccountRedshift", + "external_id": "superset-prod-12345", + "region": "us-west-2", + "workgroup_name": "prod-workgroup", + "db_name": "analytics", + "session_duration": 1800, + } + } + ) + database.sqlalchemy_uri_decrypted = ( + "redshift+psycopg2://user@prod-workgroup.222222222222.us-west-2" + ".redshift-serverless.amazonaws.com:5439/analytics" + ) + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ) as mock_get_creds, + patch.object( + AWSIAMAuthMixin, + "generate_redshift_credentials", + return_value=("IAM:admin", "redshift-temp-password"), + ), + ): + RedshiftEngineSpec.update_params_from_encrypted_extra(database, params) + + mock_get_creds.assert_called_once_with( + role_arn="arn:aws:iam::222222222222:role/CrossAccountRedshift", + region="us-west-2", + external_id="superset-prod-12345", + session_duration=1800, + ) + + +def test_redshift_update_params_merges_remaining_encrypted_extra() -> None: + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": {"enabled": False}, + "pool_size": 5, + } + ) + + params: dict[str, Any] = {} + RedshiftEngineSpec.update_params_from_encrypted_extra(database, params) + + assert "aws_iam" not in params + assert params["pool_size"] == 5 + + +def test_redshift_update_params_invalid_json() -> None: + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + database = MagicMock() + database.encrypted_extra = "not-valid-json" + + params: dict[str, Any] = {} + + with pytest.raises(json.JSONDecodeError): + RedshiftEngineSpec.update_params_from_encrypted_extra(database, params) + + +def test_redshift_mask_encrypted_extra() -> None: + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/SecretRole", + "external_id": "secret-external-id-12345", + "region": "us-east-1", + "workgroup_name": "my-workgroup", + "db_name": "dev", + } + } + ) + + masked = RedshiftEngineSpec.mask_encrypted_extra(encrypted_extra) + assert masked is not None + + masked_config = json.loads(masked) + + # role_arn and external_id should be masked + assert ( + masked_config["aws_iam"]["role_arn"] + != "arn:aws:iam::123456789012:role/SecretRole" + ) + assert masked_config["aws_iam"]["external_id"] != "secret-external-id-12345" + + # Non-sensitive fields should remain unchanged + assert masked_config["aws_iam"]["enabled"] is True + assert masked_config["aws_iam"]["region"] == "us-east-1" + assert masked_config["aws_iam"]["workgroup_name"] == "my-workgroup" + assert masked_config["aws_iam"]["db_name"] == "dev" + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_redshift_update_params_with_iam_provisioned_cluster() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + from 
superset.db_engine_specs.redshift import RedshiftEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/RedshiftRole", + "region": "us-east-1", + "cluster_identifier": "my-redshift-cluster", + "db_username": "superset_user", + "db_name": "analytics", + } + } + ) + database.sqlalchemy_uri_decrypted = ( + "redshift+psycopg2://user@my-redshift-cluster.abc123.us-east-1" + ".redshift.amazonaws.com:5439/analytics" + ) + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ), + patch.object( + AWSIAMAuthMixin, + "generate_redshift_cluster_credentials", + return_value=("IAM:superset_user", "cluster-temp-password"), + ), + ): + RedshiftEngineSpec.update_params_from_encrypted_extra(database, params) + + assert "connect_args" in params + assert params["connect_args"]["password"] == "cluster-temp-password" # noqa: S105 + assert params["connect_args"]["user"] == "IAM:superset_user" + assert params["connect_args"]["sslmode"] == "verify-ca" + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_redshift_update_params_provisioned_cluster_with_external_id() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::222222222222:role/CrossAccountRedshift", + "external_id": "superset-prod-12345", + "region": "us-west-2", + "cluster_identifier": "prod-cluster", + "db_username": "analytics_user", + "db_name": "prod_db", + "session_duration": 1800, + } + } + ) + database.sqlalchemy_uri_decrypted = ( + "redshift+psycopg2://user@prod-cluster.xyz789.us-west-2" + ".redshift.amazonaws.com:5439/prod_db" + ) + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ) as mock_get_creds, + patch.object( + AWSIAMAuthMixin, + "generate_redshift_cluster_credentials", + return_value=("IAM:analytics_user", "cluster-temp-password"), + ), + ): + RedshiftEngineSpec.update_params_from_encrypted_extra(database, params) + + mock_get_creds.assert_called_once_with( + role_arn="arn:aws:iam::222222222222:role/CrossAccountRedshift", + region="us-west-2", + external_id="superset-prod-12345", + session_duration=1800, + ) + + +def test_redshift_mask_encrypted_extra_provisioned_cluster() -> None: + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/SecretRole", + "external_id": "secret-external-id-12345", + "region": "us-east-1", + "cluster_identifier": "my-cluster", + "db_username": "superset_user", + "db_name": "analytics", + } + } + ) + + masked = RedshiftEngineSpec.mask_encrypted_extra(encrypted_extra) + assert masked is not None + + masked_config = json.loads(masked) + + # role_arn and external_id should be masked + assert ( + masked_config["aws_iam"]["role_arn"] + != "arn:aws:iam::123456789012:role/SecretRole" + ) + assert masked_config["aws_iam"]["external_id"] != "secret-external-id-12345" + + # Non-sensitive fields should remain unchanged + assert 
masked_config["aws_iam"]["enabled"] is True + assert masked_config["aws_iam"]["region"] == "us-east-1" + assert masked_config["aws_iam"]["cluster_identifier"] == "my-cluster" + assert masked_config["aws_iam"]["db_username"] == "superset_user" + assert masked_config["aws_iam"]["db_name"] == "analytics" diff --git a/tests/unit_tests/utils/oauth2_tests.py b/tests/unit_tests/utils/oauth2_tests.py index e9aa283b1acb..fc3ed7a651d9 100644 --- a/tests/unit_tests/utils/oauth2_tests.py +++ b/tests/unit_tests/utils/oauth2_tests.py @@ -18,11 +18,16 @@ # pylint: disable=invalid-name, disallowed-name from datetime import datetime +from typing import cast +import pytest from freezegun import freeze_time from pytest_mock import MockerFixture -from superset.utils.oauth2 import get_oauth2_access_token +from superset.superset_typing import OAuth2ClientConfig +from superset.utils.oauth2 import get_oauth2_access_token, refresh_oauth2_token + +DUMMY_OAUTH2_CONFIG = cast(OAuth2ClientConfig, {}) def test_get_oauth2_access_token_base_no_token(mocker: MockerFixture) -> None: @@ -93,3 +98,82 @@ def test_get_oauth2_access_token_base_no_refresh(mocker: MockerFixture) -> None: # check that token was deleted db.session.delete.assert_called_with(token) + + +def test_refresh_oauth2_token_deletes_token_on_oauth2_exception( + mocker: MockerFixture, +) -> None: + """ + Test that refresh_oauth2_token deletes the token on OAuth2-specific exception. + + When the token refresh fails with an OAuth2-specific exception (e.g., token + was revoked), the invalid token should be deleted and the exception re-raised. + """ + db = mocker.patch("superset.utils.oauth2.db") + mocker.patch("superset.utils.oauth2.KeyValueDistributedLock") + + class OAuth2ExceptionError(Exception): + pass + + db_engine_spec = mocker.MagicMock() + db_engine_spec.oauth2_exception = OAuth2ExceptionError + db_engine_spec.get_oauth2_fresh_token.side_effect = OAuth2ExceptionError( + "Token revoked" + ) + token = mocker.MagicMock() + token.refresh_token = "refresh-token" # noqa: S105 + + with pytest.raises(OAuth2ExceptionError): + refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token) + + db.session.delete.assert_called_with(token) + + +def test_refresh_oauth2_token_keeps_token_on_other_exception( + mocker: MockerFixture, +) -> None: + """ + Test that refresh_oauth2_token keeps the token on non-OAuth2 exceptions. + + When the token refresh fails with a transient error (e.g., network issue), + the token should be kept (refresh token may still be valid) and the + exception re-raised. + """ + db = mocker.patch("superset.utils.oauth2.db") + mocker.patch("superset.utils.oauth2.KeyValueDistributedLock") + + class OAuth2ExceptionError(Exception): + pass + + db_engine_spec = mocker.MagicMock() + db_engine_spec.oauth2_exception = OAuth2ExceptionError + db_engine_spec.get_oauth2_fresh_token.side_effect = Exception("Network error") + token = mocker.MagicMock() + token.refresh_token = "refresh-token" # noqa: S105 + + with pytest.raises(Exception, match="Network error"): + refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token) + + db.session.delete.assert_not_called() + + +def test_refresh_oauth2_token_no_access_token_in_response( + mocker: MockerFixture, +) -> None: + """ + Test that refresh_oauth2_token returns None when no access_token in response. + + This can happen when the refresh token was revoked. 
+ """ + mocker.patch("superset.utils.oauth2.db") + mocker.patch("superset.utils.oauth2.KeyValueDistributedLock") + db_engine_spec = mocker.MagicMock() + db_engine_spec.get_oauth2_fresh_token.return_value = { + "error": "invalid_grant", + } + token = mocker.MagicMock() + token.refresh_token = "refresh-token" # noqa: S105 + + result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token) + + assert result is None