diff --git a/feature-registry-app/.gitignore b/feature-registry-app/.gitignore new file mode 100644 index 00000000..0086b149 --- /dev/null +++ b/feature-registry-app/.gitignore @@ -0,0 +1,11 @@ +# Databricks +.databricks/ +.databricks.sync-snapshots + +# Python +__pycache__/ +.pytest_cache/ +.coverage + +# Environment & Config +deploy_config.sh diff --git a/feature-registry-app/README.md b/feature-registry-app/README.md index 11bbbd28..c610b49c 100644 --- a/feature-registry-app/README.md +++ b/feature-registry-app/README.md @@ -7,13 +7,14 @@ date: 2025-08-05 # 🚀 Feature Registry Application -This application provides a modern interface for discovering and managing features with seamless integration to Unity Catalog. +This is a modern web application that allows users to interact with the Databricks Feature Registry. The app provides a user-friendly interface for exploring existing features in Unity Catalog. Additionally, users can generate code for creating feature specs and training sets to train machine learning models and deploy features as Feature Serving Endpoints. ## ✨ Features -- 🔍 List and search for features +- 🔍 List and search for features in Unity Catalog - 🔒 On-behalf-of-user authentication - ⚙️ Code-gen for creating feature specs and training sets +- 📋 Configurable catalog allow-listing for access control ## 🏗️ Architecture @@ -25,11 +26,63 @@ The application is built with: ![Feature Registry Interface](./images/feature-registry-interface.png) +## 🚀 Deployment + +### Create an App +1. Log into your destination Databricks workspace and navigate to "Compute > Apps" +2. Click on "Create App" and select "Create a custom app" +3. Enter an app name and click "Create app" + +### Customization +1. 
Create a file named `deploy_config.sh` in the root folder with the following variables: + ```sh + # Path to a destination folder in default Databricks workspace where source code will be sync'ed + export DEST=/Workspace/Users/Path/To/App/Code + # Name of the App to deploy + export APP_NAME=your-app-name + ``` + Or simply run `./deploy.sh` - it will create a template file if it doesn't exist + +2. Update `deploy_config.sh` with the config for your environment + +3. Ensure the Databricks CLI is installed and configured on your machine. The "DEFAULT" profile should point to the destination workspace where the app will be deployed. You can find instructions here for [AWS](https://docs.databricks.com/dev-tools/cli/index.html) / [Azure](https://learn.microsoft.com/en-us/azure/databricks/dev-tools/cli/) + +### Deploy the App +1. Navigate to the app directory +2. Run `./deploy.sh` shell command. This will sync the app code to the destination workspace location and deploy the app +3. Navigate to the Databricks workspace and access the app via "Compute > Apps" + +## 🔐 Access Control + +### Catalog Allow-Listing + +By default, the Feature Registry App will show all the catalogs to which the user has read access. You can restrict which Unity Catalog catalogs users can explore for features. This is useful for: +- Limiting feature discovery to production-ready catalogs +- Ensuring data scientists only access approved feature sets +- Organizing features by teams or projects + +#### Setting Up Allow-Listed Catalogs + +1. Edit the `src/uc_catalogs_allowlist.yaml` file +2. Uncomment and add the catalog names you want to allow: + + ```yaml + # List catalogs that should be accessible in the Feature Registry App + - production_features + - team_a_catalog + - ml_features_catalog + ``` + +3. If the file is empty or all entries are commented out, the app will show all catalogs available to the user +4. 
Deploy the app with the updated configuration + +**Note:** Users will still need appropriate permissions in Unity Catalog to access the data within these catalogs. The allow-list acts as an additional filter on top of existing permissions. + ## 🔑 Requirements The application requires the following scopes: -- `catalog.catalogs` -- `catalog.schemas` -- `catalog.tables` +- `catalog.catalogs:read` +- `catalog.schemas:read` +- `catalog.tables:read` -The app owner needs to grant other users `Can Use` permission for the app itself, along with the access to the underlying Datarbricks resources. +The app owner needs to grant other users `Can Use` permission for the app itself, along with access to the underlying Databricks resources. diff --git a/feature-registry-app/deploy.sh b/feature-registry-app/deploy.sh new file mode 100755 index 00000000..0b9d9143 --- /dev/null +++ b/feature-registry-app/deploy.sh @@ -0,0 +1,22 @@ +echo_red() { + echo "\033[1;31m$*\033[0m" +} + +# Validate the current folder +[[ -d "./src" && -f "./src/app.yaml" ]] || { echo_red "Error: Couldn't find app.yaml. \nPlease run this script from the //sandbox/feature-registry-app directory."; exit 1; } + +# Users: Make sure you have a ./deploy_config.sh file that sets the necessary variables for this script. +[ -f "./deploy_config.sh" ] || { +cat <<EOF > deploy_config.sh +# Path to a folder in the workspace. E.g. /Workspace/Users/Path/To/App/Code +export DEST="" +# Name of the App to deploy. E.g. your-app-name +export APP_NAME="" +EOF +echo_red "Please update deploy_config.sh and run again."
+exit 1; +} +source ./deploy_config.sh + +databricks sync --full ./src $DEST +databricks apps deploy $APP_NAME --source-code-path $DEST diff --git a/feature-registry-app/pytest.ini b/feature-registry-app/pytest.ini new file mode 100644 index 00000000..8b937927 --- /dev/null +++ b/feature-registry-app/pytest.ini @@ -0,0 +1,24 @@ +[pytest] +testpaths = tests +python_files = test_*.py +python_classes = Test* +python_functions = test_* + +# Add src directory to Python path +pythonpath = src + +# Coverage settings +[coverage:run] +source = src +omit = + */__pycache__/* + */tests/* + +[coverage:report] +exclude_lines = + pragma: no cover + def __repr__ + raise NotImplementedError + if __name__ == .__main__.: + pass + raise ImportError diff --git a/feature-registry-app/src/app.yaml b/feature-registry-app/src/app.yaml new file mode 100644 index 00000000..040a110d --- /dev/null +++ b/feature-registry-app/src/app.yaml @@ -0,0 +1,11 @@ +command: [ + "streamlit", + "run", + "feature_registry.py" +] + +env: + - name: STREAMLIT_BROWSER_GATHER_USAGE_STATS + value: "false" + - name: UC_CATALOGS_ALLOWLIST # Set this to the path of the yaml file that contains the allow-listed UC catalogs. The Feature Registry App will restrict the search of features only to this list of catalogs. 
+ value: "uc_catalogs_allowlist.yaml" diff --git a/feature-registry-app/src/clients/uc_client.py b/feature-registry-app/src/clients/uc_client.py new file mode 100644 index 00000000..81d946aa --- /dev/null +++ b/feature-registry-app/src/clients/uc_client.py @@ -0,0 +1,24 @@ +from databricks.sdk import WorkspaceClient + + +class UcClient: + def __init__(self, user_access_token: str): + self.w = WorkspaceClient(token=user_access_token, auth_type="pat") + + def get_catalogs(self): + return self.w.catalogs.list(include_browse=False) + + def get_schemas(self, catalog_name: str): + return self.w.schemas.list(catalog_name=catalog_name) + + def get_tables(self, catalog_name: str, schema_name: str): + return self.w.tables.list(catalog_name=catalog_name, schema_name=schema_name) + + def get_table(self, full_name: str): + return self.w.tables.get(full_name=full_name) + + def get_functions(self, catalog_name: str, schema_name: str): + return self.w.functions.list(catalog_name=catalog_name, schema_name=schema_name) + + def get_function(self, full_name: str): + return self.w.functions.get(name=full_name) diff --git a/feature-registry-app/src/entities/__init__.py b/feature-registry-app/src/entities/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/feature-registry-app/src/entities/features.py b/feature-registry-app/src/entities/features.py new file mode 100644 index 00000000..38968065 --- /dev/null +++ b/feature-registry-app/src/entities/features.py @@ -0,0 +1,72 @@ +from typing import Any, Dict, List, Optional, Tuple + +from pydantic import BaseModel + +from .tables import Table + + +class MaterializedInfo(BaseModel): + schema_name: str + table_name: str + primary_keys: List[str] + timeseries_columns: List[str] + + +class Feature: + def __init__( + self, name: str, table: Table, pks: List[str], ts: Optional[List[str]] = None + ): + self.name = name + self.table = table + self.pks = pks + self.ts = ts or [] + + def get_materialized_info(self) -> 
MaterializedInfo: + return MaterializedInfo( + schema_name=self.table.schema(), + table_name=self.table.name(), + primary_keys=self.pks or [], + timeseries_columns=self.ts or [], + ) + + def description(self) -> str: + for column in self.table.uc_table.columns: + if column.name == self.name: + return column.comment + return "" + + def components(self) -> Tuple[str, str, str]: + return self.name, self.table.full_name(), ", ".join(self.pks) + + def metadata(self) -> Dict[str, Any]: + return { + "Table Name": self.table.full_name(), + "Primary Keys": self.pks, + "Timeseries Columns": self.ts, + "# of Features": len(self.table.uc_table.columns) - len(self.pks), + "Table Type": self.table.uc_table.table_type.name, + } + + def inputs(self) -> Dict[str, str] | None: + return None + + def outputs(self) -> Dict[str, str] | None: + return None + + def code(self) -> str: + return self.table.uc_table.view_definition + + def table_name(self) -> str: + return self.table.full_name() + + def full_name(self) -> str: + return f"{self.table.full_name()}.{self.name}" + + +class SelectableFeature: + def __init__(self, feature: Feature, selected: bool = False): + self.feature = feature + self.selected = selected + + def components(self) -> Tuple[bool, str, str, str]: + return (self.selected,) + self.feature.components() diff --git a/feature-registry-app/src/entities/functions.py b/feature-registry-app/src/entities/functions.py new file mode 100644 index 00000000..7aff41c1 --- /dev/null +++ b/feature-registry-app/src/entities/functions.py @@ -0,0 +1,30 @@ +from typing import Any, Dict, Tuple + +from databricks import sdk +from pydantic import BaseModel + + +class FeatureFunction(BaseModel): + function: sdk.service.catalog.FunctionInfo + + def full_name(self) -> str: + return self.function.full_name + + def components(self) -> Tuple[str, str, Any, Any]: + return self.full_name(), "feature spec", None, None + + def metadata(self) -> Dict[str, Any] | None: + return None + + def inputs(self) 
-> Dict[str, str] | None: + if self.function.input_params and self.function.input_params.parameters: + return {p.name: p.type_text for p in self.function.input_params.parameters} + return None + + def outputs(self) -> Dict[str, str] | None: + if self.function.return_params and self.function.return_params.parameters: + return {p.name: p.type_text for p in self.function.return_params.parameters} + return None + + def code(self) -> str: + return self.function.routine_definition diff --git a/feature-registry-app/src/entities/tables.py b/feature-registry-app/src/entities/tables.py new file mode 100644 index 00000000..adee5680 --- /dev/null +++ b/feature-registry-app/src/entities/tables.py @@ -0,0 +1,20 @@ +from typing import Tuple + +from databricks import sdk + + +class Table: + def __init__(self, uc_table: sdk.service.catalog.TableInfo): + self.uc_table = uc_table + + def full_name(self) -> str: + return self.uc_table.full_name + + def name(self) -> str: + return self.uc_table.name + + def schema(self) -> str: + return self.uc_table.schema_name + + def components(self) -> Tuple[str, str, str]: + return self.uc_table.catalog_name, self.uc_table.schema_name, self.uc_table.name diff --git a/feature-registry-app/src/feature_registry.py b/feature-registry-app/src/feature_registry.py new file mode 100644 index 00000000..c1b2156c --- /dev/null +++ b/feature-registry-app/src/feature_registry.py @@ -0,0 +1,6 @@ +import streamlit as st + +pages = [st.Page("navigator/explore/features.py", title="Features")] + +pg = st.navigation(pages=pages) +pg.run() diff --git a/feature-registry-app/src/navigator/__init__.py b/feature-registry-app/src/navigator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/feature-registry-app/src/navigator/explore/__init__.py b/feature-registry-app/src/navigator/explore/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/feature-registry-app/src/navigator/explore/features.py 
b/feature-registry-app/src/navigator/explore/features.py new file mode 100644 index 00000000..39273725 --- /dev/null +++ b/feature-registry-app/src/navigator/explore/features.py @@ -0,0 +1,335 @@ +from typing import Dict, List + +import pandas as pd +import streamlit as st +import yaml + +from entities.features import SelectableFeature +from services.caches import CachedUcData +from services.filters import FeatureFilters +from session_manager import SessionFeatures +from utils.code_gen import ( + generate_create_feature_spec_code, + generate_create_training_set_code, +) +from databricks.sdk.errors import PermissionDenied + +# Constants for dataframe columns +SELECTED_COLUMN = "Is Selected" +FEATURE_NAME = "Feature" +DESCRIPTION = "Description" +SCHEMA_COLUMN = "Schema Name" +TABLE_NAME = "Table Name" +PRIMARY_KEYS = "Primary Keys" +TIMESERIES_COLUMNS = "Timeseries Columns" + +# Required scopes to use this app +REQUIRED_SCOPES = [ + "catalog.catalogs:read", + "catalog.schemas:read", + "catalog.tables:read", +] + +SessionFeatures.initialize_state() + + +def create_features_dataframe( + features: List[SelectableFeature], show_schema_column: bool = False +) -> pd.DataFrame: + data = [] + for f in features: + m_info = f.feature.get_materialized_info() + # extract data in the required order + row = ( + f.selected, + f.feature.name, + f.feature.description(), + *((m_info.schema_name,) if show_schema_column else ()), + m_info.table_name, + ", ".join(m_info.primary_keys), + ", ".join(m_info.timeseries_columns) if m_info.timeseries_columns else "", + ) + data.append(tuple(row)) + non_editable_columns = ( + [FEATURE_NAME, DESCRIPTION] + + ([SCHEMA_COLUMN] if show_schema_column else []) + + [TABLE_NAME, PRIMARY_KEYS, TIMESERIES_COLUMNS] + ) + columns = [SELECTED_COLUMN] + non_editable_columns + return pd.DataFrame(data=data, columns=columns), non_editable_columns + + +filters = FeatureFilters() +uc_data = CachedUcData() + +st.set_page_config(layout="wide") + +try: + # Reduce 
padding on the top of the page above the title. + st.markdown( + """ + + """, + unsafe_allow_html=True, + ) + st.title("Explore features from Unity Catalog") + + catalog_schema_choices, filter_selection = st.columns( + [2, 3], gap="large", vertical_alignment="top" + ) + with catalog_schema_choices: + st.subheader("Select catalog and schema:") + catalog_ui_col, schema_ui_col = st.columns( + [1, 1], gap="large", vertical_alignment="top" + ) + with catalog_ui_col: + catalog = st.selectbox( + label="Catalog", options=(uc_data.get_catalogs()) + ) + with schema_ui_col: + # By default, search through all schemas in the catalog. + # If a schema is selected, the app will rerun and search only in that schema. + schema = st.selectbox( + label="Schema", + options=uc_data.get_schemas(catalog), + index=None, + placeholder="All schemas", + ) + with st.container(): + st.markdown("") + search_msg = st.empty() + with filter_selection: + with st.expander("Filters"): + st.subheader("Filter features:") + with st.form("Filters", border=False): + st.text_input( + label="by name", + key="name_filter", + placeholder="Enter partial feature name.
For example: 'max' or '7_days'.", + ) + st.text_input( + label="by table", + key="table_filter", + placeholder="Enter partial table name.", + ) + st.text_input( + label="by description", + key="description_filter", + placeholder="Enter partial description", + ) + b1, b2, _ = st.columns([0.18, 0.18, 0.64]) + b1.form_submit_button("Filter", type="primary") + if b2.form_submit_button("Clear"): + filters.clear() + + # update session state with current selected catalog and schema names + SessionFeatures.set_current_catalog(catalog) + SessionFeatures.set_current_schema(schema) + + # Generate a list of features for the selected catalog and schema (or all schemas) + # Save searched features + all_features = [] + schemas = [schema] if schema else uc_data.get_schemas(catalog) + for schema in schemas: + # check if features for this schema are already saved + features_list_by_table = SessionFeatures.get_saved_features(catalog, schema) + if features_list_by_table is None: + # fetch features for this catalog/schema + tables_data = uc_data.get_tables(catalog, schema) + feature_data = uc_data.get_features(tables_data) + features_list_by_table = { + table_name: [ + SelectableFeature(feature=feature, selected=False) + for feature in features + ] + for table_name, features in feature_data.items() + } + SessionFeatures.update_saved_features( + catalog, schema, features_list_by_table + ) + num_features = sum( + [len(features) for features in features_list_by_table.values()] + ) + search_msg.markdown( + f":hourglass_flowing_sand: Found **{num_features}** features in **`{catalog}.{schema}`**" + ) + all_features += [ + feature + for features in features_list_by_table.values() + for feature in features + ] + search_msg.markdown( + f":white_check_mark: Search complete! 
Found **{len(all_features)}** features in **`{catalog}`**" + ) + + DATA_EDITOR_UPDATES_KEY = "data_editor_updates" + + def change_state(ordered_feature_list): + """This callback is called when user selects or unselects a feature.""" + + # edits made by the user that triggered the callback + edited = st.session_state[DATA_EDITOR_UPDATES_KEY] + # create a change log of all updated feature selected by schema. Record their *new* selected/unselected state + changes_by_schema: Dict[str, Dict[str, bool]] = {} + for index, selection in edited.get("edited_rows", {}).items(): + # identify the feature that was changed using index from the ordered feature list + if selection.get(SELECTED_COLUMN) is not None: + feature = ordered_feature_list[index] + schema = feature.feature.table.schema() + if schema not in changes_by_schema: + changes_by_schema[schema] = {} + changes_by_schema[schema][feature.feature.full_name()] = selection[ + SELECTED_COLUMN + ] + + # iterate through the changes and update the selected state + for schema, changes in changes_by_schema.items(): + # fetch current state for this catalog and schema + features_in_schema = SessionFeatures.get_saved_features( + SessionFeatures.current_catalog(), schema + ) + + # create an updated list of features for each table in current schema + updated_all_features = { + table_name: [ + SelectableFeature( + feature=f.feature, + selected=changes.get(f.feature.full_name(), f.selected), + ) + for f in features + ] + for table_name, features in features_in_schema.items() + } + # update the session state + SessionFeatures.update_saved_features( + SessionFeatures.current_catalog(), schema, updated_all_features + ) + + # Filter the features in this catalog/schema based on selected filters + if filters.enabled(): + ordered_feature_list = [] + total_features, matching_features = 0, 0 + for feature in all_features: + total_features += 1 + if filters.include(feature.feature): + matching_features += 1 + ordered_feature_list.append(feature) 
+ search_msg.markdown( + f":white_check_mark: Search complete! " + f"Found {matching_features} of {total_features} features matching filters." + ) + else: + ordered_feature_list = all_features + + # show column for schema name only if no schema was selected from dropdown + show_schema_column = SessionFeatures.current_schema() is None + features_df, non_editable_column_names = create_features_dataframe( + ordered_feature_list, show_schema_column + ) + + # use a data editor to allow user to select features. User can only select not change any other values. + st.data_editor( + features_df, + disabled=non_editable_column_names, + column_config={ + SELECTED_COLUMN: st.column_config.CheckboxColumn( + label="Selected Features", width="small" + ) + }, + use_container_width=True, + hide_index=True, + key=DATA_EDITOR_UPDATES_KEY, + on_change=change_state, + args=(ordered_feature_list,), + ) + + # show selected features + if SessionFeatures.total_selected_features(): + ( + create_training_set_tab, + create_feature_spec_tab, + selected_features_tab, + ) = st.tabs( + [ + "Create Training Set", + "Create Feature Spec", + "Selected Features", + ] + ) + + # Show code gen for create training set + with create_training_set_tab: + st.markdown( + "Copy the code below to your notebook to create a training set with selected features." + ) + code = generate_create_training_set_code( + SessionFeatures.get_all_selected_features_by_table() + ) + st.code(code, language="python") + + # Show code gen for feature spec + with create_feature_spec_tab: + st.markdown( + "Copy the code below to your notebook to create a feature spec with selected features." + ) + code = generate_create_feature_spec_code( + SessionFeatures.get_all_selected_features_by_table() + ) + st.code(code, language="python") + + # Create an expander to show selected features. Automatically expand if there are selected features. 
+ with selected_features_tab: + curr_catalog_col, other_catalog_col = st.columns( + [1, 1], gap="large", vertical_alignment="top" + ) + selected = SessionFeatures.all_selected_features() + with curr_catalog_col: + # Show selected features for the current catalog + st.markdown(f"Current catalog `{SessionFeatures.current_catalog()}`") + features_by_table_by_schema = selected.get( + SessionFeatures.current_catalog(), {} + ) + if any( + [ + f + for s_data in features_by_table_by_schema.values() + for features in s_data.values() + for f in features + ] + ): + formatted_features = yaml.dump( + dict(features_by_table_by_schema), sort_keys=False + ) + st.code(formatted_features, language="yaml") + with other_catalog_col: + # Show selected features for other catalogs + st.markdown(f"Other selected features") + if SessionFeatures.current_catalog() in selected: + # exclude current catalog + del selected[SessionFeatures.current_catalog()] + if any( + [ + f + for c_data in selected.values() + for s_data in c_data.values() + for features in s_data.values() + for f in features + ] + ): + formatted_features = yaml.dump(selected, sort_keys=False) + st.code(formatted_features, language="yaml") + +except PermissionDenied as e: + st.error(f"Permission Denied: The following scopes are required to use this app: {', '.join([f'`{s}`' for s in REQUIRED_SCOPES])}", icon="🚫") +except Exception as e: + st.error(f"An error occurred: {e}") + st.exception(e) diff --git a/feature-registry-app/src/requirements.txt b/feature-registry-app/src/requirements.txt new file mode 100644 index 00000000..11545ab6 --- /dev/null +++ b/feature-registry-app/src/requirements.txt @@ -0,0 +1 @@ +databricks-sdk>=0.58.0 diff --git a/feature-registry-app/src/services/__init__.py b/feature-registry-app/src/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/feature-registry-app/src/services/caches.py b/feature-registry-app/src/services/caches.py new file mode 100644 index 00000000..a1d4c45e 
--- /dev/null +++ b/feature-registry-app/src/services/caches.py @@ -0,0 +1,85 @@ +from typing import Dict, List + +import streamlit as st +import os +import yaml + +from clients.uc_client import UcClient +from entities.features import Feature +from entities.tables import Table +from utils.auth import get_user_access_token + +INTERNAL_SCHEMAS = {"information_schema"} + + +class CachedUcData: + def __init__(self): + # Reuse uc client + if "uc_client" not in st.session_state: + st.session_state.uc_client = UcClient(get_user_access_token()) + # Load and cache catalogs allowlist + # the cache only contains names of allowlisted catalogs. In order to access data from the catalogs, the app will use the caller's read access tokens. + if "catalogs_allowlist" not in st.session_state: + st.session_state.catalogs_allowlist = self._load_catalogs_allowlist_from_env() + + @staticmethod + def get_tables(catalog_name: str, schema_name: str) -> List[Table]: + return [ + Table(uc_table=t) + for t in st.session_state.uc_client.get_tables( + catalog_name=catalog_name, schema_name=schema_name + ) + ] + + @staticmethod + def get_features(tables: List[Table]) -> Dict[str, List[Feature]]: + """Returns a dictionary of features for each table""" + features: Dict[str, List[Feature]] = {} + for t in tables: + tab = st.session_state.uc_client.get_table(full_name=t.full_name()) + if tab.table_constraints: + for tc in tab.table_constraints: + if tc.primary_key_constraint: + all_pks = tc.primary_key_constraint.child_columns + ts = tc.primary_key_constraint.timeseries_columns or [] + pks = [pk for pk in all_pks if pk not in ts] + these_features = [ + Feature(name=c.name, table=t, pks=pks, ts=ts) + for c in tab.columns + if c.name not in all_pks + ] + features[t.name()] = these_features + break + return features + + @staticmethod + def get_catalogs() -> List[str]: + catalogs_allowlist = st.session_state.get("catalogs_allowlist") + # if allowlist is not empty, return the allowlisted catalogs + if 
catalogs_allowlist: + return catalogs_allowlist + # else return all catalogs available to the user + return [c.name for c in st.session_state.uc_client.get_catalogs()] + + @staticmethod + def get_schemas(catalog_name: str) -> List[str]: + # skip internal schemas + return [ + s.name + for s in st.session_state.uc_client.get_schemas(catalog_name=catalog_name) + if s.name not in INTERNAL_SCHEMAS + ] + + def _load_catalogs_allowlist_from_env(self) -> List[str]: + env_path = os.getenv("UC_CATALOGS_ALLOWLIST") + if env_path: + try: + with open(env_path) as f: + data = yaml.safe_load(f) + if isinstance(data, list) and all( + isinstance(item, str) for item in data + ): + return data + except Exception: + st.error(f"Error loading catalogs allowlist from {env_path}") + return [] diff --git a/feature-registry-app/src/services/filters.py b/feature-registry-app/src/services/filters.py new file mode 100644 index 00000000..99d3a7f5 --- /dev/null +++ b/feature-registry-app/src/services/filters.py @@ -0,0 +1,56 @@ +import streamlit as st + +from entities.features import Feature + + +class FeatureFilters: + def __init__(self): + # - If app is run for the first time, initialize the filters + # - If app is rerun after user clears form, reset the state of the filters + if "reset_filters_on_init" not in st.session_state: + st.session_state.reset_filters_on_init = False + if ( + st.session_state.reset_filters_on_init + or "name_filter" not in st.session_state + ): + st.session_state.name_filter = "" + if ( + st.session_state.reset_filters_on_init + or "table_filter" not in st.session_state + ): + st.session_state.table_filter = "" + if ( + st.session_state.reset_filters_on_init + or "description_filter" not in st.session_state + ): + st.session_state.description_filter = "" + st.session_state.reset_filters_on_init = False + + @staticmethod + def substring_match(haystack: str, needle: str) -> bool: + """Check if 'needle' can be found in 'haystack', ignoring case.""" + return not needle or 
needle.lower() in haystack.lower() + + def clear(self): + # User hit clear on the form. Filters should be cleared when the app is rerun + st.session_state.reset_filters_on_init = True + st.rerun() + + def enabled(self): + return bool( + st.session_state.name_filter + or st.session_state.table_filter + or st.session_state.description_filter + ) + + def include(self, feature: Feature) -> bool: + """Check if feature should be included based on filters.""" + return ( + self.substring_match(feature.name, st.session_state.name_filter) + and self.substring_match( + feature.table.name(), st.session_state.table_filter + ) + and self.substring_match( + feature.description() or "", st.session_state.description_filter + ) + ) diff --git a/feature-registry-app/src/session_manager.py b/feature-registry-app/src/session_manager.py new file mode 100644 index 00000000..2bba3314 --- /dev/null +++ b/feature-registry-app/src/session_manager.py @@ -0,0 +1,134 @@ +from typing import Dict, List + +import pandas as pd +import streamlit as st + +from entities.features import SelectableFeature + + +class SessionFeatures: + @staticmethod + def initialize_state(): + if "catalog" not in st.session_state: + st.session_state.catalog = None + if "schema" not in st.session_state: + st.session_state.schema = None + if "features" not in st.session_state: + # features: Dict[str, Dict[str, Dict[str, List[SelectableFeature]]]] = {} + # catalog_name -> schema_name -> table name -> list of features + st.session_state.features = {} + + @staticmethod + def current_catalog() -> str: + return st.session_state.catalog + + @staticmethod + def set_current_catalog(catalog: str) -> None: + st.session_state.catalog = catalog + + @staticmethod + def current_schema() -> str: + return st.session_state.schema + + @staticmethod + def set_current_schema(schema: str) -> None: + st.session_state.schema = schema + + @staticmethod + def update_saved_features( + catalog: str, schema: str, features: Dict[str, List[SelectableFeature]] + ) ->
None: + # Updates the saved features by table name for the given catalog and schema. + if catalog not in st.session_state.features: + st.session_state.features[catalog] = {} + st.session_state.features[catalog][schema] = features + + @staticmethod + def get_saved_features( + catalog: str, schema: str + ) -> Dict[str, List[SelectableFeature]]: + # Returns the saved features by table name for the given catalog and schema. + if catalog not in st.session_state.features: + return None + if schema not in st.session_state.features[catalog]: + return None + return st.session_state.features[catalog][schema] + + @staticmethod + def total_selected_features() -> int: + selected_features = 0 + for catalog_to_data in st.session_state.features.values(): + for table_to_data in catalog_to_data.values(): + for features in table_to_data.values(): + selected_features += len( + list(filter(lambda f: f.selected, features)) + ) + return selected_features + + @staticmethod + def all_selected_features() -> Dict[str, Dict[str, Dict[str, List[str]]]]: + """ + Retrieve all selected features organized in a nested dictionary structure. + + The return value is a dictionary with the following hierarchy: + - Catalog (str): The top-level key representing the catalog name. + - Schema (str): A nested key representing the schema name within the catalog. + - Table (str): A nested key representing the table name within the schema. + - Feature Names (List[str]): A list of feature names (strings) that are selected. + + Example: + { + "catalog1": { + "schema1": { + "table1": ["feature1", "feature2"], + "table2": ["feature3"] + }, + "schema2": { + "table3": ["feature4"] + } + }, + "catalog2": { + "schema3": { + "table4": ["feature5", "feature6"] + } + } + } + + Returns: + Dict[str, Dict[str, Dict[str, List[str]]]]: A nested dictionary of selected features. 
+ """ + selected_features = {} + for catalog, catalog_to_data in st.session_state.features.items(): + for schema, schema_to_data in catalog_to_data.items(): + for table, features in schema_to_data.items(): + for feature in features: + if feature.selected: + # only create hierarchy if any feature is selected + if catalog not in selected_features: + selected_features[catalog] = {} + if schema not in selected_features[catalog]: + selected_features[catalog][schema] = {} + if table not in selected_features[catalog][schema]: + selected_features[catalog][schema][table] = [] + selected_features[catalog][schema][table].append( + feature.feature.name + ) + return selected_features + + @staticmethod + def get_all_selected_features_by_table() -> Dict[str, List[SelectableFeature]]: + """ + Retrieve all selected features grouped by table. + + The return value is a dictionary where the keys are table full names and the values are lists of selected features. + """ + + selected_features_by_table = {} + for catalog, catalog_to_data in st.session_state.features.items(): + for schema, schema_to_data in catalog_to_data.items(): + for table, features in schema_to_data.items(): + selected_features = [f for f in features if f.selected] + if selected_features: + full_table_name = f"{catalog}.{schema}.{table}" + selected_features_by_table[full_table_name] = selected_features + return selected_features_by_table diff --git a/feature-registry-app/src/uc_catalogs_allowlist.yaml b/feature-registry-app/src/uc_catalogs_allowlist.yaml new file mode 100644 index 00000000..9691569a --- /dev/null +++ b/feature-registry-app/src/uc_catalogs_allowlist.yaml @@ -0,0 +1,6 @@ +# Use this file to allowlist only those catalogs from Unity Catalog metastore where data scientists can share or explore existing features. +# The Feature Registry App will only explore features in the listed catalogs. +# If not specified, all catalogs available to the user will be shown. 
+ +# - catalog1 +# - catalog2 diff --git a/feature-registry-app/src/utils/__init__.py b/feature-registry-app/src/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/feature-registry-app/src/utils/auth.py b/feature-registry-app/src/utils/auth.py new file mode 100644 index 00000000..f096ce7a --- /dev/null +++ b/feature-registry-app/src/utils/auth.py @@ -0,0 +1,10 @@ +import streamlit as st + +def get_user_access_token(): + if "user_access_token" not in st.session_state: + # The token is sent in header as x-forwarded-access-token by Databricks when the app is deployed to a Databricks workspace. + st.session_state.user_access_token = st.context.headers.get('x-forwarded-access-token') + if not st.session_state.user_access_token: + st.error("The app must be deployed to a Databricks workspace. Refer to the Databricks documentation for custom app deployment.") + st.stop() + return st.session_state.user_access_token diff --git a/feature-registry-app/src/utils/code_gen.py b/feature-registry-app/src/utils/code_gen.py new file mode 100644 index 00000000..3bbebc71 --- /dev/null +++ b/feature-registry-app/src/utils/code_gen.py @@ -0,0 +1,97 @@ +from collections import defaultdict +from typing import Dict, List, Optional + +from entities.features import SelectableFeature + + +def generate_create_training_set_code( + features_by_table: Dict[str, List[SelectableFeature]] +) -> str: + + return f"""from databricks.feature_engineering import FeatureEngineeringClient, FeatureLookup + +feature_lookups = [{_generate_feature_lookups_code(features_by_table)} +] + +fe = FeatureEngineeringClient() + +training_set = fe.create_training_set( + df=training_df, # ⚠️ Make sure this is your labeled DataFrame! 
+ feature_lookups=feature_lookups, + label='label' # 🏷️ Replace 'label' with your actual label column name +) +""" + + +def generate_create_feature_spec_code( + features_by_table: Dict[str, List[SelectableFeature]] +) -> str: + + return f"""from databricks.feature_engineering import FeatureEngineeringClient, FeatureLookup +features = [{_generate_feature_lookups_code(features_by_table)} +] + +fe = FeatureEngineeringClient() + +# Create a `FeatureSpec` with the features defined above. +# The `FeatureSpec` can be accessed in Unity Catalog as a function. +# They can be used to create training sets or feature serving endpoints +fe.create_feature_spec( + name="..", # 📚 Replace with your actual catalog, schema, and feature spec name + features=features, +) +""" + + +def _generate_feature_lookups_code( + features_by_table: Dict[str, List[SelectableFeature]] +) -> str: + + primary_keys_by_table = {} + timeseries_columns_by_table = {} + features_names_by_table = defaultdict(list) + + # Group features by table + for table_name, features in features_by_table.items(): + features_names_by_table[table_name] = [ + feature.feature.name for feature in features + ] + primary_keys_by_table[table_name] = features[0].feature.pks if features else [] + timeseries_columns_by_table[table_name] = ( + features[0].feature.ts if features and features[0].feature.ts else None + ) + + feature_lookup_codes = [ + _generate_feature_lookup_code( + table_name, + primary_keys_by_table[table_name], + features_names_by_table[table_name], + timeseries_columns_by_table[table_name], + ) + for table_name in features_by_table + ] + return ",".join(feature_lookup_codes) + + +def _generate_feature_lookup_code( + table_name: str, + primary_keys: List[str], + features_by_table: List[str], + timeseries_columns: Optional[str] = None, +) -> str: + return ( + f""" + FeatureLookup( + table_name="{table_name}", + feature_names={features_by_table}, + lookup_keys={primary_keys}, # 🔑 Replace with your actual lookup keys if 
needed + )""" + if not timeseries_columns + else f""" + FeatureLookup( + table_name="{table_name}", + feature_names={features_by_table}, + lookup_keys={primary_keys}, # 🔑 Replace with your actual lookup keys if needed + timeseries_columns='{timeseries_columns[0]}' # ⏱️ Replace with your actual timeseries columns if needed + )""" + ) diff --git a/feature-registry-app/test-requirements.txt b/feature-registry-app/test-requirements.txt new file mode 100644 index 00000000..be9921a7 --- /dev/null +++ b/feature-registry-app/test-requirements.txt @@ -0,0 +1,10 @@ +# Test requirements for Feature Registry + +-r src/requirements.txt + +pytest>=7.0.0 +pytest-mock>=3.10.0 +pytest-cov>=4.1.0 +mock>=5.0.0 +pydantic>=1.10.0 # Required for entity models +coverage>=7.0.0 # For test coverage reports diff --git a/feature-registry-app/tests/README.md b/feature-registry-app/tests/README.md new file mode 100644 index 00000000..73c3932b --- /dev/null +++ b/feature-registry-app/tests/README.md @@ -0,0 +1,44 @@ +# Feature Registry Tests + +This directory contains unit tests for the Feature Registry Streamlit application. + +## Setup + +To install the required test dependencies: + +```bash +pip install -r ../test-requirements.txt +``` + +## Running Tests + +To run all tests: + +```bash +python -m pytest +``` + +To run tests with coverage: + +```bash +python -m pytest --cov=src +``` + +To run tests for a specific module: + +```bash +python -m pytest tests/entities/test_features.py +``` + +## Test Structure + +The tests are organized to mirror the structure of the main application: + +- `entities/`: Tests for entity classes + - `test_features.py`: Tests for feature-related classes + - `test_tables.py`: Tests for table-related classes + - `test_functions.py`: Tests for function-related classes + +## Notes + +These tests use mocks for the databricks-sdk components since they cannot be called directly from the test environment. 
class TestUcClient(unittest.TestCase):
    """Checks that each UcClient method forwards to the matching WorkspaceClient API."""

    def setUp(self):
        # Patch the WorkspaceClient name as looked up inside uc_client.
        self.workspace_client_patcher = patch('src.clients.uc_client.WorkspaceClient')
        self.mock_workspace_client_class = self.workspace_client_patcher.start()

        # Instance returned whenever the patched class is called.
        self.mock_workspace_client = MagicMock()
        self.mock_workspace_client_class.return_value = self.mock_workspace_client

        # One mock per resource API, reachable both from the test case
        # (self.mock_<resource>) and from the fake workspace client.
        for resource in ('catalogs', 'schemas', 'tables', 'functions'):
            api_mock = MagicMock()
            setattr(self, f'mock_{resource}', api_mock)
            setattr(self.mock_workspace_client, resource, api_mock)

        self.user_access_token = "test_user_access_token"

        # Object under test.
        self.uc_client = UcClient(self.user_access_token)

    def tearDown(self):
        self.workspace_client_patcher.stop()

    def test_init(self):
        """UcClient builds a WorkspaceClient from the user token and keeps it as ``w``."""
        self.mock_workspace_client_class.assert_called_with(
            token=self.user_access_token, auth_type="pat"
        )
        self.assertEqual(self.uc_client.w, self.mock_workspace_client)

    def test_get_catalogs(self):
        """get_catalogs returns whatever catalogs.list(include_browse=False) yields."""
        expected = [MagicMock(name='catalog1'), MagicMock(name='catalog2')]
        self.mock_catalogs.list.return_value = expected

        actual = self.uc_client.get_catalogs()

        self.assertEqual(actual, expected)
        self.mock_catalogs.list.assert_called_once_with(include_browse=False)

    def test_get_schemas(self):
        """get_schemas forwards the catalog name to schemas.list."""
        expected = [MagicMock(name='schema1'), MagicMock(name='schema2')]
        self.mock_schemas.list.return_value = expected

        actual = self.uc_client.get_schemas('test_catalog')

        self.assertEqual(actual, expected)
        self.mock_schemas.list.assert_called_once_with(catalog_name='test_catalog')

    def test_get_tables(self):
        """get_tables forwards catalog and schema names to tables.list."""
        expected = [MagicMock(name='table1'), MagicMock(name='table2')]
        self.mock_tables.list.return_value = expected

        actual = self.uc_client.get_tables('test_catalog', 'test_schema')

        self.assertEqual(actual, expected)
        self.mock_tables.list.assert_called_once_with(
            catalog_name='test_catalog',
            schema_name='test_schema'
        )

    def test_get_table(self):
        """get_table forwards the full table name to tables.get."""
        expected = MagicMock(name='table1')
        self.mock_tables.get.return_value = expected

        actual = self.uc_client.get_table('test_catalog.test_schema.test_table')

        self.assertEqual(actual, expected)
        self.mock_tables.get.assert_called_once_with(
            full_name='test_catalog.test_schema.test_table'
        )

    def test_get_functions(self):
        """get_functions forwards catalog and schema names to functions.list."""
        expected = [MagicMock(name='function1'), MagicMock(name='function2')]
        self.mock_functions.list.return_value = expected

        actual = self.uc_client.get_functions('test_catalog', 'test_schema')

        self.assertEqual(actual, expected)
        self.mock_functions.list.assert_called_once_with(
            catalog_name='test_catalog',
            schema_name='test_schema'
        )

    def test_get_function(self):
        """get_function forwards the full function name to functions.get."""
        expected = MagicMock(name='function1')
        self.mock_functions.get.return_value = expected

        actual = self.uc_client.get_function('test_catalog.test_schema.test_function')

        self.assertEqual(actual, expected)
        self.mock_functions.get.assert_called_once_with(
            name='test_catalog.test_schema.test_function'
        )
self.mock_uc_table.catalog_name = 'catalog1' + self.mock_uc_table.schema_name = 'schema1' + self.mock_uc_table.table_type.name = 'MANAGED' + + # Configure the mock TableInfo class to recognize our mock as an instance + self.mock_table_info_class.side_effect = lambda **kwargs: self.mock_uc_table + # The following line was added by llm agent but does not work with MagicMocks + # ToDo: Identify the right fix + # self.mock_table_info_class.__instancecheck__.return_value = True + + # Create mock columns + self.mock_column1 = MagicMock() + self.mock_column1.name = 'id' + self.mock_column1.comment = 'Primary key' + + self.mock_column2 = MagicMock() + self.mock_column2.name = 'feature1' + self.mock_column2.comment = 'Feature 1 description' + + self.mock_column3 = MagicMock() + self.mock_column3.name = 'feature2' + self.mock_column3.comment = 'Feature 2 description' + + self.mock_uc_table.columns = [self.mock_column1, self.mock_column2, self.mock_column3] + self.mock_uc_table.view_definition = 'SELECT * FROM table1' + + # Create mock table + self.mock_table = Table(uc_table=self.mock_uc_table) + + # Create feature + self.feature = Feature(name='feature1', table=self.mock_table, pks=['id']) + + def tearDown(self): + self.table_info_patcher.stop() + + def test_get_materialized_info(self): + """Test that get_materialized_info returns the correct MaterializedInfo object.""" + result = self.feature.get_materialized_info() + + # Verify the result + self.assertIsInstance(result, MaterializedInfo) + self.assertEqual(result.schema_name, 'schema1') + self.assertEqual(result.table_name, 'table1') + self.assertEqual(result.primary_keys, ['id']) + self.assertEqual(result.timeseries_columns, []) + + def test_description(self): + """Test that description returns the correct comment from the column.""" + result = self.feature.description() + + # Verify the result + self.assertEqual(result, 'Feature 1 description') + + # Test with a feature name that doesn't match any column + 
class TestSelectableFeature(unittest.TestCase):
    """Tests for SelectableFeature: construction and ``components()``."""

    def setUp(self):
        # Table needs TableInfo, so patch it for the duration of each test.
        self.table_info_patcher = patch('entities.tables.sdk.service.catalog.TableInfo')
        self.mock_table_info_class = self.table_info_patcher.start()

        # Stand-in for the Unity Catalog table; configure_mock is used so
        # that ``name`` becomes a plain attribute of the mock.
        self.mock_uc_table = MagicMock()
        self.mock_uc_table.configure_mock(
            full_name='catalog1.schema1.table1',
            name='table1',
            catalog_name='catalog1',
            schema_name='schema1',
        )

        def _build_table_info(**kwargs):
            # Any TableInfo(...) call yields the prepared mock table.
            return self.mock_uc_table

        self.mock_table_info_class.side_effect = _build_table_info

        # Real Table and Feature instances wrapped around the mock.
        self.table = Table(uc_table=self.mock_uc_table)
        self.feature = Feature(name='feature1', table=self.table, pks=['id'])

    def tearDown(self):
        self.table_info_patcher.stop()

    def test_initialization(self):
        """SelectableFeature keeps the feature and the selection flag (default False)."""
        for flag in (True, False):
            selectable = SelectableFeature(feature=self.feature, selected=flag)
            self.assertEqual(selectable.feature, self.feature)
            if flag:
                self.assertTrue(selectable.selected)
            else:
                self.assertFalse(selectable.selected)

        # Omitting ``selected`` leaves the feature unselected.
        unselected_by_default = SelectableFeature(feature=self.feature)
        self.assertFalse(unselected_by_default.selected)

    def test_components(self):
        """components() prepends the selection flag to the feature's components."""
        for flag in (True, False):
            selectable = SelectableFeature(feature=self.feature, selected=flag)
            self.assertEqual(
                selectable.components(),
                (flag, 'feature1', 'catalog1.schema1.table1', 'id')
            )
unittest.mock import patch, MagicMock + +import sys +import os + +# Add the src directory to the Python path +sys.path.insert(0, os.path.abspath('./src')) + +from src.entities.tables import Table + + +class TestTable(unittest.TestCase): + + def setUp(self): + # Create a patch for the TableInfo class + self.table_info_patcher = patch('entities.tables.sdk.service.catalog.TableInfo') + self.mock_table_info_class = self.table_info_patcher.start() + + # Create mock uc_table with the proper class type + self.mock_uc_table = MagicMock() + self.mock_uc_table.full_name = 'catalog1.schema1.table1' + self.mock_uc_table.name = 'table1' + self.mock_uc_table.catalog_name = 'catalog1' + self.mock_uc_table.schema_name = 'schema1' + + # Configure the mock TableInfo class to recognize our mock as an instance + self.mock_table_info_class.side_effect = lambda **kwargs: self.mock_uc_table + # The following line was added by llm agent but does not work with MagicMocks + # ToDo: Identify the right fix + # self.mock_table_info_class.__instancecheck__.return_value = True + + # Create table + self.table = Table(uc_table=self.mock_uc_table) + + def tearDown(self): + self.table_info_patcher.stop() + + def test_full_name(self): + """Test that full_name returns the correct full name from the uc_table.""" + result = self.table.full_name() + + # Verify the result + self.assertEqual(result, 'catalog1.schema1.table1') + + def test_name(self): + """Test that name returns the correct table name from the uc_table.""" + result = self.table.name() + + # Verify the result + self.assertEqual(result, 'table1') + + def test_schema(self): + """Test that schema returns the correct schema name from the uc_table.""" + result = self.table.schema() + + # Verify the result + self.assertEqual(result, 'schema1') + + def test_components(self): + """Test that components returns a tuple with catalog, schema, and table names.""" + result = self.table.components() + + # Verify the result + self.assertEqual(result, 
class SessionStateMock(UserDict):
    """Dict-backed stand-in for ``st.session_state``.

    Values can be read, written and deleted either as items
    (``state['key']``) or as attributes (``state.key``); both views share the
    same ``UserDict.data`` storage.
    """

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If ``data`` itself
        # is missing (instance created through __new__ during copy/pickle),
        # reading self.data below would re-enter this method without end.
        if name == "data":
            raise AttributeError(name)
        try:
            return self.data[name]
        except KeyError:
            # Missing keys must look like missing attributes so that
            # hasattr()/getattr(default) keep working.
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        if name == "data":
            # The backing dict is a real instance attribute.
            super().__setattr__(name, value)
        else:
            self.data[name] = value

    def __delattr__(self, name):
        if name == "data":
            super().__delattr__(name)
            return
        try:
            del self.data[name]
        except KeyError:
            raise AttributeError(name) from None
class TestInternalSchemas(unittest.TestCase):
    """Test the INTERNAL_SCHEMAS constant from CachedUcData."""

    def test_internal_schemas_contains_information_schema(self):
        """Test that INTERNAL_SCHEMAS contains 'information_schema'."""
        self.assertIn('information_schema', INTERNAL_SCHEMAS)

    def test_internal_schemas_is_a_set(self):
        """Test that INTERNAL_SCHEMAS is a set."""
        self.assertIsInstance(INTERNAL_SCHEMAS, set)


class TestCachedUcData(unittest.TestCase):
    """
    Tests for the CachedUcData class.

    Note: Due to the complex dependencies of this class, we're testing
    specific functionality in isolation rather than the entire class.
    """

    def setUp(self):
        # Imported locally; note that the patches below are started after
        # this import, so they do not affect import-time behaviour.
        from src.services.caches import CachedUcData
        self.CachedUcData = CachedUcData

        # Mock the streamlit session state
        self.mock_session_state = SessionStateMock()
        self.patcher = patch('src.services.caches.st.session_state', self.mock_session_state)
        self.mock_st_session_state = self.patcher.start()

        # Mock the streamlit cache_data decorator
        self.cache_data_patcher = patch('src.services.caches.st.cache_data')
        self.mock_cache_data = self.cache_data_patcher.start()
        # Make the cache_data decorator just return the function unchanged
        self.mock_cache_data.side_effect = lambda ttl: lambda func: func

    def tearDown(self):
        self.patcher.stop()
        self.cache_data_patcher.stop()

    def test_init_creates_uc_client_if_not_exists(self):
        """Test that __init__ creates a UcClient if it doesn't exist in session state."""
        # Setup: ensure uc_client is not in session state
        self.mock_session_state.clear()

        mock_uc_client_instance = MagicMock()
        with patch('src.services.caches.UcClient') as mock_uc_client:
            mock_uc_client.return_value = mock_uc_client_instance

            # Act: only the constructor's side effect matters here
            self.CachedUcData()

            # Assert: UcClient was created and stored in session state
            mock_uc_client.assert_called_once()
            self.assertIs(self.mock_session_state['uc_client'], mock_uc_client_instance)

    def test_init_reuses_existing_uc_client(self):
        """Test that __init__ reuses an existing UcClient if it exists in session state."""
        # Setup: put a mock UcClient in session state
        existing_client = MagicMock()
        self.mock_session_state['uc_client'] = existing_client

        with patch('src.services.caches.UcClient') as mock_uc_client:
            # Act: only the constructor's side effect matters here
            self.CachedUcData()

            # Assert: UcClient was not created again
            mock_uc_client.assert_not_called()

        # Assert: the uc_client in session state has not changed
        self.assertIs(self.mock_session_state['uc_client'], existing_client)

    def test_get_tables_calls_uc_client(self):
        """Test that get_tables calls the UcClient's get_tables method with the correct parameters."""
        # Setup: put a mock UcClient in session state
        mock_uc_client_instance = MagicMock()
        self.mock_session_state['uc_client'] = mock_uc_client_instance

        # Setup mock return value for get_tables
        uc_table_1, uc_table_2 = MagicMock(), MagicMock()
        mock_uc_client_instance.get_tables.return_value = [uc_table_1, uc_table_2]

        result = self.CachedUcData.get_tables('test_catalog', 'test_schema')

        # Assert: UcClient's get_tables was called with correct parameters
        mock_uc_client_instance.get_tables.assert_called_once_with(
            catalog_name='test_catalog',
            schema_name='test_schema'
        )

        # Assert: each returned entry wraps the corresponding UC table
        self.assertEqual(len(result), 2)
        self.assertEqual(result[0].uc_table, uc_table_1)
        self.assertEqual(result[1].uc_table, uc_table_2)

    def test_get_catalogs_returns_catalog_names(self):
        """Test that get_catalogs returns a list of catalog names."""
        # Setup: put a mock UcClient in session state
        mock_uc_client_instance = MagicMock()
        self.mock_session_state['uc_client'] = mock_uc_client_instance

        # ``name`` is assigned after construction: MagicMock(name=...) would
        # only set the mock's repr name, not a ``.name`` attribute.
        mock_catalog1 = MagicMock()
        mock_catalog1.name = 'catalog1'
        mock_catalog2 = MagicMock()
        mock_catalog2.name = 'catalog2'
        mock_uc_client_instance.get_catalogs.return_value = [mock_catalog1, mock_catalog2]

        # Act: call get_catalogs
        result = self.CachedUcData.get_catalogs()

        # Assert: UcClient's get_catalogs was called
        mock_uc_client_instance.get_catalogs.assert_called_once()

        # Assert: get_catalogs returned the list of catalog names
        self.assertEqual(result, ['catalog1', 'catalog2'])

    def test_get_schemas_filters_internal_schemas(self):
        """Test that get_schemas filters out internal schemas."""
        # Setup: put a mock UcClient in session state
        mock_uc_client_instance = MagicMock()
        self.mock_session_state['uc_client'] = mock_uc_client_instance

        # Setup mock return value for get_schemas
        mock_schema1 = MagicMock()
        mock_schema1.name = 'schema1'
        mock_schema2 = MagicMock()
        mock_schema2.name = 'schema2'
        mock_info_schema = MagicMock()
        mock_info_schema.name = 'information_schema'
        mock_uc_client_instance.get_schemas.return_value = [mock_schema1, mock_schema2, mock_info_schema]

        # Act: call get_schemas
        result = self.CachedUcData.get_schemas('test_catalog')

        # Assert: UcClient's get_schemas was called with correct parameters
        mock_uc_client_instance.get_schemas.assert_called_once_with(catalog_name='test_catalog')

        # Assert: get_schemas returned the schema names, excluding information_schema
        self.assertEqual(result, ['schema1', 'schema2'])
        self.assertNotIn('information_schema', result)
class TestFeatureFilters(unittest.TestCase):
    """Tests for FeatureFilters against a fake Streamlit session state."""

    FILTER_KEYS = ('name_filter', 'table_filter', 'description_filter')

    def setUp(self):
        # Fake session state, patched where the filters module reads it.
        self.mock_session_state = SessionStateMock()
        self.patcher = patch('src.services.filters.st.session_state', self.mock_session_state)
        self.mock_st_session_state = self.patcher.start()

        # st.rerun is replaced by a mock so calls to it can be asserted.
        self.rerun_patcher = patch('src.services.filters.st.rerun')
        self.mock_rerun = self.rerun_patcher.start()

    def tearDown(self):
        self.patcher.stop()
        self.rerun_patcher.stop()

    def _set_filters(self, name, table, description):
        # Write all three filter values into the fake session state.
        for key, value in zip(self.FILTER_KEYS, (name, table, description)):
            self.mock_session_state[key] = value

    def test_init_first_time(self):
        """Constructing with an empty session state creates blank filters."""
        self.mock_session_state.clear()

        FeatureFilters()

        for key in self.FILTER_KEYS + ('reset_filters_on_init',):
            self.assertIn(key, self.mock_session_state)

        for key in self.FILTER_KEYS:
            self.assertEqual(self.mock_session_state[key], '')
        self.assertEqual(self.mock_session_state['reset_filters_on_init'], False)

    def test_init_with_reset_flag(self):
        """Constructing while the reset flag is set blanks the filters and clears the flag."""
        self.mock_session_state['reset_filters_on_init'] = True
        self._set_filters('old_name', 'old_table', 'old_description')

        FeatureFilters()

        for key in self.FILTER_KEYS:
            self.assertEqual(self.mock_session_state[key], '')
        self.assertEqual(self.mock_session_state['reset_filters_on_init'], False)

    def test_substring_match(self):
        """substring_match is a case-insensitive containment check; an empty needle matches."""
        filters = FeatureFilters()

        # Empty needle
        self.assertTrue(filters.substring_match('haystack', ''))

        # Needle at the start, at the end and in the middle
        for needle in ('hay', 'stack', 'yst'):
            self.assertTrue(filters.substring_match('haystack', needle))

        # Needle that does not occur
        self.assertFalse(filters.substring_match('haystack', 'needle'))

        # Case differences on either side
        self.assertTrue(filters.substring_match('haystack', 'HAY'))
        self.assertTrue(filters.substring_match('HAYSTACK', 'hay'))

    def test_clear(self):
        """clear() raises the reset flag and triggers exactly one rerun."""
        self.mock_session_state['reset_filters_on_init'] = False

        filters = FeatureFilters()
        filters.clear()

        self.assertTrue(self.mock_session_state['reset_filters_on_init'])
        self.mock_rerun.assert_called_once()

    def test_enabled(self):
        """enabled() is true as soon as any single filter is non-empty."""
        filters = FeatureFilters()

        # Nothing set
        self._set_filters('', '', '')
        self.assertFalse(filters.enabled())

        # Exactly one filter set at a time: name, table, description
        single_filter_states = (
            ('test', '', ''),
            ('', 'test', ''),
            ('', '', 'test'),
        )
        for state in single_filter_states:
            self._set_filters(*state)
            self.assertTrue(filters.enabled())

    def test_include(self):
        """include() accepts a feature only if every non-empty filter matches."""
        filters = FeatureFilters()

        # Feature double whose name, table name and description all contain "test"
        mock_table = MagicMock()
        mock_table.name.return_value = 'test_table'
        mock_feature = MagicMock()
        mock_feature.name = 'test_feature'
        mock_feature.table = mock_table
        mock_feature.description.return_value = 'test description'

        # (name filter, table filter, description filter, expected result)
        scenarios = (
            ('', '', '', True),                        # no filters
            ('test', '', '', True),                    # matching name
            ('nonexistent', '', '', False),            # non-matching name
            ('', 'test', '', True),                    # matching table
            ('', 'nonexistent', '', False),            # non-matching table
            ('', '', 'test', True),                    # matching description
            ('', '', 'nonexistent', False),            # non-matching description
            ('test', 'test', 'test', True),            # all filters match
            ('test', 'nonexistent', 'test', False),    # one filter fails
        )
        for name, table, description, expected in scenarios:
            self._set_filters(name, table, description)
            if expected:
                self.assertTrue(filters.include(mock_feature))
            else:
                self.assertFalse(filters.include(mock_feature))


if __name__ == '__main__':
    unittest.main()