diff --git a/uc-quickstart/utils/abac-agent/.gitignore b/uc-quickstart/utils/abac-agent/.gitignore new file mode 100644 index 00000000..622dc875 --- /dev/null +++ b/uc-quickstart/utils/abac-agent/.gitignore @@ -0,0 +1,302 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
+# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be added to the global gitignore or merged into this project gitignore. For a PyCharm +# project, it is recommended to ignore entire .idea folder. +.idea/ + +# Visual Studio Code +.vscode/ +*.code-workspace + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix + +# Streamlit +.streamlit/ +streamlit_cache/ + +# MLflow +mlruns/ +.mlflow/ + +# Databricks +.databricks/ +databricks-cli +*.databricks +.db/ +databricks.cfg + +# Unity Catalog +*.unity + +# Logs +*.log +logs/ +.log/ + +# Temporary files +*.tmp +*.temp +tmp/ +temp/ + +# OS generated files +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# Windows +*.lnk + +# Linux +*~ + +# macOS +.AppleDouble +.LSOverride +Icon? 
+.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + +# Config files with secrets +config.ini +config.yaml +secrets.json +.secrets +.env.local +.env.production +.env.staging + +# Data files +*.csv +*.parquet +*.json +*.xml +data/ +datasets/ + +# Model files +models/ +*.pkl +*.model +*.h5 +*.pb + +# Checkpoints +checkpoints/ +*.ckpt + +# Jupyter notebook checkpoints +.ipynb_checkpoints/ + +# Pytest +.pytest_cache/ + +# Coverage reports +htmlcov/ +.coverage + +# Backup files +*.bak +*.backup +*.orig + +# Generated documentation +docs/build/ +site/ + +# IDE files +*.swp +*.swo +*~ + +# Local development +local/ +local_* +dev/ +development/ + +# Deployment artifacts +deploy/ +deployment/ + +# Cache directories +.cache/ +cache/ + +# Lock files +*.lock + +# Demo setup +*.sql diff --git a/uc-quickstart/utils/abac-agent/README.md b/uc-quickstart/utils/abac-agent/README.md new file mode 100644 index 00000000..10efd4fb --- /dev/null +++ b/uc-quickstart/utils/abac-agent/README.md @@ -0,0 +1,212 @@ +# šŸ›”ļø ABAC Policy Assistant + +An AI-powered Unity Catalog Attribute-Based Access Control (ABAC) policy generation assistant built with Databricks Agent Framework and Streamlit. 
+ +## šŸš€ Features + +- **Intelligent Table Analysis** - Automatically examines Unity Catalog table structures, columns, and metadata +- **ABAC Policy Generation** - Creates ROW FILTER and COLUMN MASK policy recommendations +- **Tag-Based Conditions** - Generates MATCH COLUMNS and FOR TABLES conditions using `hasTag()` and `hasTagValue()` +- **Real-time Streaming** - Provides streaming responses with tool call visualization +- **Professional UI** - Clean, Databricks-branded interface with responsive design +- **Unity Catalog Integration** - Direct integration with UC functions for metadata retrieval + +## šŸ—ļø Architecture + +``` +ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +│ Streamlit Chat UI │ +ā”œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¤ +│ Databricks Agent Framework │ +ā”œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¤ +│ Unity Catalog Tools │ LLM Endpoint │ +│ • describe_extended_table │ • Claude Sonnet 4 │ +│ • get_table_tags │ • Streaming │ +│ • get_column_tags │ • Tool Calling │ +│ • list_row_filter_column_masking │ │ +│ • list_uc_tables │ │ +ā”œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¼ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”¤ +│ Unity Catalog Metastore │ +ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ +``` + +## šŸ› ļø 
Prerequisites + +- Databricks workspace with Unity Catalog enabled +- Model serving endpoint with agent capabilities +- Python 3.8+ +- Required Unity Catalog functions deployed + +## šŸ“¦ Installation + +1. **Clone the repository** + ```bash + git clone + cd e2e-chatbot-app + ``` + +2. **Install dependencies** + ```bash + pip install -r requirements.txt + ``` + +3. **Set up Unity Catalog functions** + + Deploy the following functions to your Unity Catalog: + - `enterprise_gov.gov_admin.describe_extended_table` + - `enterprise_gov.gov_admin.get_table_tags` + - `enterprise_gov.gov_admin.get_column_tags` + - `enterprise_gov.gov_admin.list_row_filter_column_masking` + - `enterprise_gov.gov_admin.list_uc_tables` + +4. **Configure environment variables** + ```bash + export SERVING_ENDPOINT="your-agent-endpoint-name" + ``` + +## šŸš€ Usage + +### Local Development + +```bash +streamlit run app.py +``` + +### Databricks Apps Deployment + +1. **Create app.yaml** (already configured) + ```yaml + command: ["streamlit", "run", "app.py"] + env: + - name: STREAMLIT_BROWSER_GATHER_USAGE_STATS + value: "false" + - name: "SERVING_ENDPOINT" + valueFrom: "serving-endpoint" + ``` + +2. **Deploy to Databricks** + ```bash + databricks apps create your-app-name + databricks apps deploy your-app-name --source-dir . + ``` + +## šŸ’¬ Example Queries + +Ask the assistant natural language questions about your Unity Catalog tables: + +``` +"Suggest ABAC policies for enterprise_gov.hr_finance.customers" + +"What table-level access controls should I implement for sensitive customer data?" + +"Generate tag-based ABAC policies with MATCH conditions for the customers table" + +"Analyze my table schema and recommend governance policies" + +"What are the recommended ABAC FOR table conditions for PII data?" 
+``` + +## šŸ”§ Configuration + +### Agent Configuration (agent.py) + +- **LLM Endpoint**: Configure in `LLM_ENDPOINT_NAME` +- **System Prompt**: Customize ABAC policy generation behavior +- **UC Tools**: Add/remove Unity Catalog functions as needed +- **Vector Search**: Optional integration for document retrieval + +### UI Configuration (app.py) + +- **Databricks Branding**: Colors and styling in CSS +- **Page Layout**: Streamlit page configuration +- **Chat Interface**: Message rendering and interaction flow + +## šŸ“‹ Available Tools + +| Tool Name | Description | +|-----------|-------------| +| `describe_extended_table` | Get detailed table schema and metadata | +| `get_table_tags` | Retrieve table-level tag information | +| `get_column_tags` | Retrieve column-level tag information | +| `list_row_filter_column_masking` | Review existing ABAC policies | +| `list_uc_tables` | Discover tables in catalogs and schemas | + +## šŸ” ABAC Policy Types Supported + +- **ROW FILTER** policies for row-level security +- **COLUMN MASK** policies for column-level protection +- **Tag-based conditions** using `hasTag()` and `hasTagValue()` +- **Multi-table policies** with FOR TABLES conditions +- **Principal-specific** policies with TO/EXCEPT clauses + +## šŸ“Š Example Policy Output + +```sql +CREATE POLICY hide_sensitive_customers +ON SCHEMA enterprise_gov.hr_finance +COMMENT 'Hide rows with sensitive customer data from general analysts' +ROW FILTER filter_sensitive_data +TO general_analysts +FOR TABLES +WHEN hasTag('sensitivity_level') +MATCH COLUMNS + hasTagValue('data_classification', 'sensitive') AS sensitive_col +USING COLUMNS (sensitive_col); +``` + +## šŸ” Troubleshooting + +### Common Issues + +1. **Endpoint Connection Errors** + - Verify `SERVING_ENDPOINT` environment variable + - Check model serving endpoint permissions + - Ensure endpoint supports agent/chat completions + +2. 
**Unity Catalog Function Errors** + - Verify UC functions are deployed and accessible + - Check function permissions (CAN_EXECUTE) + - Validate function signatures match expected format + +3. **UI Rendering Issues** + - Clear browser cache + - Check Streamlit version compatibility + - Verify CSS styling in different browsers + +### Debug Mode + +Enable debug logging: +```python +import logging +logging.basicConfig(level=logging.DEBUG) +``` + +## šŸ¤ Contributing + +1. Fork the repository +2. Create a feature branch +3. Make your changes +4. Add tests if applicable +5. Submit a pull request + +## šŸ“„ License + +This project is licensed under the MIT License - see the LICENSE file for details. + +## šŸ†˜ Support + +For support and questions: +- Check Databricks documentation: [Agent Framework](https://docs.databricks.com/generative-ai/agent-framework/) +- Review Unity Catalog ABAC docs: [ABAC Policies](https://docs.databricks.com/data-governance/unity-catalog/abac/) +- Open an issue in this repository + +## šŸ“š Additional Resources + +- [Databricks Agent Framework](https://docs.databricks.com/generative-ai/agent-framework/) +- [Unity Catalog ABAC Tutorial](https://docs.databricks.com/data-governance/unity-catalog/abac/tutorial) +- [Streamlit Documentation](https://docs.streamlit.io/) +- [MLflow Agent Evaluation](https://docs.databricks.com/generative-ai/agent-evaluation/) + +--- + +Built with ā¤ļø using Databricks Agent Framework diff --git a/uc-quickstart/utils/abac-agent/app.py b/uc-quickstart/utils/abac-agent/app.py new file mode 100644 index 00000000..41e539a9 --- /dev/null +++ b/uc-quickstart/utils/abac-agent/app.py @@ -0,0 +1,493 @@ +import logging +import os +import streamlit as st +from model_serving_utils import ( + endpoint_supports_feedback, + query_endpoint, + query_endpoint_stream, + _get_endpoint_task_type, +) +from collections import OrderedDict +from messages import UserMessage, AssistantResponse, render_message + +# Configure page 
+st.set_page_config( + page_title="ABAC Policy Assistant", + page_icon="šŸ›”ļø", + layout="centered", + initial_sidebar_state="auto" +) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +SERVING_ENDPOINT = os.getenv('SERVING_ENDPOINT') +assert SERVING_ENDPOINT, \ + ("Unable to determine serving endpoint to use for chatbot app. If developing locally, " + "set the SERVING_ENDPOINT environment variable to the name of your serving endpoint. If " + "deploying to a Databricks app, include a serving endpoint resource named " + "'serving_endpoint' with CAN_QUERY permissions, as described in " + "https://docs.databricks.com/aws/en/generative-ai/agent-framework/chat-app#deploy-the-databricks-app") + +ENDPOINT_SUPPORTS_FEEDBACK = endpoint_supports_feedback(SERVING_ENDPOINT) + +def reduce_chat_agent_chunks(chunks): + """ + Reduce a list of ChatAgentChunk objects corresponding to a particular + message into a single ChatAgentMessage + """ + deltas = [chunk.delta for chunk in chunks] + first_delta = deltas[0] + result_msg = first_delta + msg_contents = [] + + # Accumulate tool calls properly + tool_call_map = {} # Map call_id to tool call for accumulation + + for delta in deltas: + # Handle content + if delta.content: + msg_contents.append(delta.content) + + # Handle tool calls + if hasattr(delta, 'tool_calls') and delta.tool_calls: + for tool_call in delta.tool_calls: + call_id = getattr(tool_call, 'id', None) + tool_type = getattr(tool_call, 'type', "function") + function_info = getattr(tool_call, 'function', None) + if function_info: + func_name = getattr(function_info, 'name', "") + func_args = getattr(function_info, 'arguments', "") + else: + func_name = "" + func_args = "" + + if call_id: + if call_id not in tool_call_map: + # New tool call + tool_call_map[call_id] = { + "id": call_id, + "type": tool_type, + "function": { + "name": func_name, + "arguments": func_args + } + } + else: + # Accumulate arguments for existing tool call + 
existing_args = tool_call_map[call_id]["function"]["arguments"] + tool_call_map[call_id]["function"]["arguments"] = existing_args + func_args + + # Update function name if provided + if func_name: + tool_call_map[call_id]["function"]["name"] = func_name + + # Handle tool call IDs (for tool response messages) + if hasattr(delta, 'tool_call_id') and delta.tool_call_id: + result_msg = result_msg.model_copy(update={"tool_call_id": delta.tool_call_id}) + + # Convert tool call map back to list + if tool_call_map: + accumulated_tool_calls = list(tool_call_map.values()) + result_msg = result_msg.model_copy(update={"tool_calls": accumulated_tool_calls}) + + result_msg = result_msg.model_copy(update={"content": "".join(msg_contents)}) + return result_msg + + + +# --- Init state --- +if "history" not in st.session_state: + st.session_state.history = [] + +# Databricks-themed CSS styling +st.markdown(""" + +""", unsafe_allow_html=True) + +# Clean, professional header +st.markdown(""" +
+

šŸ›”ļø ABAC Policy Assistant

+

Generate Unity Catalog access control policies with AI

+
+""", unsafe_allow_html=True) + +# Clean sidebar +with st.sidebar: + st.markdown("### šŸ“‹ Quick Guide") + + st.markdown("**How to use:**") + st.markdown(""" + 1. Enter a Unity Catalog table name + 2. Ask for policy recommendations + 3. Review the generated ABAC policies + """) + + + st.markdown("---") + + if st.button("šŸ—‘ļø Clear History"): + st.session_state.history = [] + st.rerun() + + st.markdown("---") + st.caption(f"Endpoint: {SERVING_ENDPOINT}") + +# Simple example section +if len(st.session_state.history) == 0: + st.markdown(""" +
+

Get Started

+

Ask me to analyze any Unity Catalog table and generate ABAC policies. For example:

+
+ "Suggest ABAC policies for catalog.schema.table" +
+
+ """, unsafe_allow_html=True) + + + +# --- Render chat history --- +for i, element in enumerate(st.session_state.history): + element.render(i) + +def query_endpoint_and_render(task_type, input_messages): + """Handle streaming response based on task type.""" + if task_type == "agent/v1/responses": + return query_responses_endpoint_and_render(input_messages) + elif task_type == "agent/v2/chat": + return query_chat_agent_endpoint_and_render(input_messages) + else: # chat/completions + return query_chat_completions_endpoint_and_render(input_messages) + + +def query_chat_completions_endpoint_and_render(input_messages): + """Handle ChatCompletions streaming format.""" + with st.chat_message("assistant"): + response_area = st.empty() + response_area.markdown("_Thinking..._") + + accumulated_content = "" + request_id = None + + try: + for chunk in query_endpoint_stream( + endpoint_name=SERVING_ENDPOINT, + messages=input_messages, + return_traces=ENDPOINT_SUPPORTS_FEEDBACK + ): + if "choices" in chunk and chunk["choices"]: + delta = chunk["choices"][0].get("delta", {}) + content = delta.get("content", "") + if content: + accumulated_content += content + response_area.markdown(accumulated_content) + + if "databricks_output" in chunk: + req_id = chunk["databricks_output"].get("databricks_request_id") + if req_id: + request_id = req_id + + return AssistantResponse( + messages=[{"role": "assistant", "content": accumulated_content}], + request_id=request_id + ) + except Exception: + response_area.markdown("_Ran into an error. 
Retrying without streaming..._") + messages, request_id = query_endpoint( + endpoint_name=SERVING_ENDPOINT, + messages=input_messages, + return_traces=ENDPOINT_SUPPORTS_FEEDBACK + ) + response_area.empty() + with response_area.container(): + for message in messages: + render_message(message) + return AssistantResponse(messages=messages, request_id=request_id) + + +def query_chat_agent_endpoint_and_render(input_messages): + """Handle ChatAgent streaming format.""" + from mlflow.types.agent import ChatAgentChunk + + with st.chat_message("assistant"): + response_area = st.empty() + response_area.markdown("_Thinking..._") + + message_buffers = OrderedDict() + request_id = None + + try: + for raw_chunk in query_endpoint_stream( + endpoint_name=SERVING_ENDPOINT, + messages=input_messages, + return_traces=ENDPOINT_SUPPORTS_FEEDBACK + ): + response_area.empty() + chunk = ChatAgentChunk.model_validate(raw_chunk) + delta = chunk.delta + message_id = delta.id + + req_id = raw_chunk.get("databricks_output", {}).get("databricks_request_id") + if req_id: + request_id = req_id + if message_id not in message_buffers: + message_buffers[message_id] = { + "chunks": [], + "render_area": st.empty(), + } + message_buffers[message_id]["chunks"].append(chunk) + + partial_message = reduce_chat_agent_chunks(message_buffers[message_id]["chunks"]) + render_area = message_buffers[message_id]["render_area"] + message_content = partial_message.model_dump_compat(exclude_none=True) + with render_area.container(): + render_message(message_content) + + messages = [] + for msg_id, msg_info in message_buffers.items(): + messages.append(reduce_chat_agent_chunks(msg_info["chunks"])) + + return AssistantResponse( + messages=[message.model_dump_compat(exclude_none=True) for message in messages], + request_id=request_id + ) + except Exception: + response_area.markdown("_Ran into an error. 
Retrying without streaming..._") + messages, request_id = query_endpoint( + endpoint_name=SERVING_ENDPOINT, + messages=input_messages, + return_traces=ENDPOINT_SUPPORTS_FEEDBACK + ) + response_area.empty() + with response_area.container(): + for message in messages: + render_message(message) + return AssistantResponse(messages=messages, request_id=request_id) + + +def query_responses_endpoint_and_render(input_messages): + """Handle ResponsesAgent streaming format using MLflow types.""" + from mlflow.types.responses import ResponsesAgentStreamEvent + + with st.chat_message("assistant"): + response_area = st.empty() + response_area.markdown("_Thinking..._") + + # Track all the messages that need to be rendered in order + all_messages = [] + request_id = None + + try: + for raw_event in query_endpoint_stream( + endpoint_name=SERVING_ENDPOINT, + messages=input_messages, + return_traces=ENDPOINT_SUPPORTS_FEEDBACK + ): + # Extract databricks_output for request_id + if "databricks_output" in raw_event: + req_id = raw_event["databricks_output"].get("databricks_request_id") + if req_id: + request_id = req_id + + # Parse using MLflow streaming event types, similar to ChatAgentChunk + if "type" in raw_event: + event = ResponsesAgentStreamEvent.model_validate(raw_event) + + if hasattr(event, 'item') and event.item: + item = event.item # This is a dict, not a parsed object + + if item.get("type") == "message": + # Extract text content from message if present + content_parts = item.get("content", []) + for content_part in content_parts: + if content_part.get("type") == "output_text": + text = content_part.get("text", "") + if text: + all_messages.append({ + "role": "assistant", + "content": text + }) + + elif item.get("type") == "function_call": + # Tool call + call_id = item.get("call_id") + function_name = item.get("name") + arguments = item.get("arguments", "") + + # Add to messages for history + all_messages.append({ + "role": "assistant", + "content": "", + "tool_calls": [{ 
+ "id": call_id, + "type": "function", + "function": { + "name": function_name, + "arguments": arguments + } + }] + }) + + elif item.get("type") == "function_call_output": + # Tool call output/result + call_id = item.get("call_id") + output = item.get("output", "") + + # Add to messages for history + all_messages.append({ + "role": "tool", + "content": output, + "tool_call_id": call_id + }) + + # Update the display by rendering all accumulated messages + if all_messages: + with response_area.container(): + for msg in all_messages: + render_message(msg) + + return AssistantResponse(messages=all_messages, request_id=request_id) + except Exception: + response_area.markdown("_Ran into an error. Retrying without streaming..._") + messages, request_id = query_endpoint( + endpoint_name=SERVING_ENDPOINT, + messages=input_messages, + return_traces=ENDPOINT_SUPPORTS_FEEDBACK + ) + response_area.empty() + with response_area.container(): + for message in messages: + render_message(message) + return AssistantResponse(messages=messages, request_id=request_id) + + + + +# --- Chat input (must run BEFORE rendering messages) --- +prompt = st.chat_input("Enter a table name or ask about ABAC policies...") +if prompt: + # Get the task type for this endpoint + task_type = _get_endpoint_task_type(SERVING_ENDPOINT) + + # Add user message to chat history + user_msg = UserMessage(content=prompt) + st.session_state.history.append(user_msg) + user_msg.render(len(st.session_state.history) - 1) + + # Convert history to standard chat message format for the query methods + input_messages = [msg for elem in st.session_state.history for msg in elem.to_input_messages()] + + # Handle the response using the appropriate handler + assistant_response = query_endpoint_and_render(task_type, input_messages) + + # Add assistant response to history + st.session_state.history.append(assistant_response) diff --git a/uc-quickstart/utils/abac-agent/app.yaml b/uc-quickstart/utils/abac-agent/app.yaml new file mode 
100644 index 00000000..d81e38a3 --- /dev/null +++ b/uc-quickstart/utils/abac-agent/app.yaml @@ -0,0 +1,11 @@ +command: [ + "streamlit", + "run", + "app.py" +] + +env: + - name: STREAMLIT_BROWSER_GATHER_USAGE_STATS + value: "false" + - name: "SERVING_ENDPOINT" + valueFrom: "serving-endpoint" diff --git a/uc-quickstart/utils/abac-agent/driver.py b/uc-quickstart/utils/abac-agent/driver.py new file mode 100644 index 00000000..dd91684c --- /dev/null +++ b/uc-quickstart/utils/abac-agent/driver.py @@ -0,0 +1,456 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC #Tool-calling Agent +# MAGIC +# MAGIC This is an auto-generated notebook created by an AI playground export. In this notebook, you will: +# MAGIC - Author a tool-calling [MLflow's `ResponsesAgent`](https://mlflow.org/docs/latest/api_reference/python_api/mlflow.pyfunc.html#mlflow.pyfunc.ResponsesAgent) that uses the OpenAI client +# MAGIC - Manually test the agent's output +# MAGIC - Evaluate the agent with Mosaic AI Agent Evaluation +# MAGIC - Log and deploy the agent +# MAGIC +# MAGIC This notebook should be run on serverless or a cluster with DBR<17. +# MAGIC +# MAGIC **_NOTE:_** This notebook uses the OpenAI SDK, but AI Agent Framework is compatible with any agent authoring framework, including LlamaIndex or LangGraph. To learn more, see the [Authoring Agents](https://docs.databricks.com/generative-ai/agent-framework/author-agent) Databricks documentation. +# MAGIC +# MAGIC ## Prerequisites +# MAGIC +# MAGIC - Address all `TODO`s in this notebook. + +# COMMAND ---------- + +# MAGIC %pip install -U -qqqq backoff databricks-openai uv databricks-agents mlflow-skinny[databricks] +# MAGIC dbutils.library.restartPython() + +# COMMAND ---------- + +# MAGIC %md ## Define the agent in code +# MAGIC Below we define our agent code in a single cell, enabling us to easily write it to a local Python file for subsequent logging and deployment using the `%%writefile` magic command. 
+# MAGIC +# MAGIC For more examples of tools to add to your agent, see [docs](https://docs.databricks.com/generative-ai/agent-framework/agent-tool.html). + +# COMMAND ---------- + +# MAGIC %%writefile agent.py +# MAGIC import json +# MAGIC from typing import Any, Callable, Generator, Optional +# MAGIC from uuid import uuid4 +# MAGIC import warnings +# MAGIC +# MAGIC import backoff +# MAGIC import mlflow +# MAGIC import openai +# MAGIC from databricks.sdk import WorkspaceClient +# MAGIC from databricks_openai import UCFunctionToolkit, VectorSearchRetrieverTool +# MAGIC from mlflow.entities import SpanType +# MAGIC from mlflow.pyfunc import ResponsesAgent +# MAGIC from mlflow.types.responses import ( +# MAGIC ResponsesAgentRequest, +# MAGIC ResponsesAgentResponse, +# MAGIC ResponsesAgentStreamEvent, +# MAGIC ) +# MAGIC from openai import OpenAI +# MAGIC from pydantic import BaseModel +# MAGIC from unitycatalog.ai.core.base import get_uc_function_client +# MAGIC +# MAGIC ############################################ +# MAGIC # Define your LLM endpoint and system prompt +# MAGIC ############################################ +# MAGIC LLM_ENDPOINT_NAME = "databricks-claude-sonnet-4" +# MAGIC +# MAGIC SYSTEM_PROMPT = """You are an agent to review the Unity catalog schema tables and suggest the Attribute-based access control(ABAC) policies using the metadata obtained using the tool functions and the documentation as below. 
Feel free to research additional documentation to understand the concepts and policy generation aspects +# MAGIC +# MAGIC https://docs.databricks.com/aws/en/data-governance/unity-catalog/abac/ +# MAGIC +# MAGIC https://docs.databricks.com/aws/en/data-governance/unity-catalog/abac/tutorial +# MAGIC +# MAGIC https://docs.databricks.com/aws/en/data-governance/unity-catalog/abac/policies +# MAGIC +# MAGIC – +# MAGIC The following is the general syntax for creating a policy: +# MAGIC SQL +# MAGIC +# MAGIC CREATE POLICY +# MAGIC ON +# MAGIC COMMENT '' +# MAGIC -- One of the following: +# MAGIC ROW FILTER +# MAGIC | COLUMN MASK ON COLUMN +# MAGIC TO [, , ...] +# MAGIC [EXCEPT [, , ...]] +# MAGIC FOR TABLES +# MAGIC [WHEN hasTag('') OR hasTagValue('', '')] +# MAGIC MATCH COLUMNS hasTag('') OR hasTagValue('', '') AS +# MAGIC USING COLUMNS [, , ...]; +# MAGIC This example defines a row filter policy that excludes rows for European customers from queries by US-based analysts: +# MAGIC SQL +# MAGIC +# MAGIC CREATE POLICY hide_eu_customers +# MAGIC ON SCHEMA prod.customers +# MAGIC COMMENT 'Hide rows with European customers from sensitive tables' +# MAGIC ROW FILTER non_eu_region +# MAGIC TO us_analysts +# MAGIC FOR TABLES +# MAGIC MATCH COLUMNS +# MAGIC hasTag('geo_region') AS region +# MAGIC USING COLUMNS (region); +# MAGIC This example defines a column mask policy that hides social security numbers from US analysts, except for those with in the admins group: +# MAGIC SQL +# MAGIC +# MAGIC CREATE POLICY mask_SSN +# MAGIC ON SCHEMA prod.customers +# MAGIC COMMENT 'Mask social security numbers' +# MAGIC COLUMN MASK mask_SSN +# MAGIC TO us_analysts +# MAGIC EXCEPT admins +# MAGIC FOR TABLES +# MAGIC MATCH COLUMNS +# MAGIC hasTagValue('pii', 'ssn') AS ssn +# MAGIC ON COLUMN ssn; +# MAGIC +# MAGIC +# MAGIC Find additional examples in the documentation - https://docs.databricks.com/aws/en/data-governance/unity-catalog/abac/policies?language=SQL +# MAGIC +# MAGIC Usually table name 
is given as catalog_name.schem_name.table_name. +# MAGIC Considering the table metadata, column metadata, and corresponding tag details, +# MAGIC """ +# MAGIC +# MAGIC +# MAGIC ############################################################################### +# MAGIC ## Define tools for your agent, enabling it to retrieve data or take actions +# MAGIC ## beyond text generation +# MAGIC ## To create and see usage examples of more tools, see +# MAGIC ## https://docs.databricks.com/generative-ai/agent-framework/agent-tool.html +# MAGIC ############################################################################### +# MAGIC class ToolInfo(BaseModel): +# MAGIC """ +# MAGIC Class representing a tool for the agent. +# MAGIC - "name" (str): The name of the tool. +# MAGIC - "spec" (dict): JSON description of the tool (matches OpenAI Responses format) +# MAGIC - "exec_fn" (Callable): Function that implements the tool logic +# MAGIC """ +# MAGIC +# MAGIC name: str +# MAGIC spec: dict +# MAGIC exec_fn: Callable +# MAGIC +# MAGIC +# MAGIC def create_tool_info(tool_spec, exec_fn_param: Optional[Callable] = None): +# MAGIC tool_spec["function"].pop("strict", None) +# MAGIC tool_name = tool_spec["function"]["name"] +# MAGIC udf_name = tool_name.replace("__", ".") +# MAGIC +# MAGIC # Define a wrapper that accepts kwargs for the UC tool call, +# MAGIC # then passes them to the UC tool execution client +# MAGIC def exec_fn(**kwargs): +# MAGIC function_result = uc_function_client.execute_function(udf_name, kwargs) +# MAGIC if function_result.error is not None: +# MAGIC return function_result.error +# MAGIC else: +# MAGIC return function_result.value +# MAGIC return ToolInfo(name=tool_name, spec=tool_spec, exec_fn=exec_fn_param or exec_fn) +# MAGIC +# MAGIC +# MAGIC TOOL_INFOS = [] +# MAGIC +# MAGIC # You can use UDFs in Unity Catalog as agent tools +# MAGIC # TODO: Add additional tools +# MAGIC UC_TOOL_NAMES = ["enterprise_gov.gov_admin.describe_extended_table", 
"enterprise_gov.gov_admin.list_uc_tables", "enterprise_gov.gov_admin.list_row_filter_column_masking", "enterprise_gov.gov_admin.get_table_tags", "enterprise_gov.gov_admin.get_column_tags"] +# MAGIC +# MAGIC uc_toolkit = UCFunctionToolkit(function_names=UC_TOOL_NAMES) +# MAGIC uc_function_client = get_uc_function_client() +# MAGIC for tool_spec in uc_toolkit.tools: +# MAGIC TOOL_INFOS.append(create_tool_info(tool_spec)) +# MAGIC +# MAGIC +# MAGIC # Use Databricks vector search indexes as tools +# MAGIC # See [docs](https://docs.databricks.com/generative-ai/agent-framework/unstructured-retrieval-tools.html) for details +# MAGIC +# MAGIC # # (Optional) Use Databricks vector search indexes as tools +# MAGIC # # See https://docs.databricks.com/generative-ai/agent-framework/unstructured-retrieval-tools.html +# MAGIC # # for details +# MAGIC VECTOR_SEARCH_TOOLS = [] +# MAGIC # # TODO: Add vector search indexes as tools or delete this block +# MAGIC # VECTOR_SEARCH_TOOLS.append( +# MAGIC # VectorSearchRetrieverTool( +# MAGIC # index_name="", +# MAGIC # # filters="..." 
+# MAGIC # ) +# MAGIC # ) +# MAGIC for vs_tool in VECTOR_SEARCH_TOOLS: +# MAGIC TOOL_INFOS.append(create_tool_info(vs_tool.tool, vs_tool.execute)) +# MAGIC +# MAGIC +# MAGIC +# MAGIC class ToolCallingAgent(ResponsesAgent): +# MAGIC """ +# MAGIC Class representing a tool-calling Agent +# MAGIC """ +# MAGIC +# MAGIC def __init__(self, llm_endpoint: str, tools: list[ToolInfo]): +# MAGIC """Initializes the ToolCallingAgent with tools.""" +# MAGIC self.llm_endpoint = llm_endpoint +# MAGIC self.workspace_client = WorkspaceClient() +# MAGIC self.model_serving_client: OpenAI = ( +# MAGIC self.workspace_client.serving_endpoints.get_open_ai_client() +# MAGIC ) +# MAGIC self._tools_dict = {tool.name: tool for tool in tools} +# MAGIC +# MAGIC def get_tool_specs(self) -> list[dict]: +# MAGIC """Returns tool specifications in the format OpenAI expects.""" +# MAGIC return [tool_info.spec for tool_info in self._tools_dict.values()] +# MAGIC +# MAGIC @mlflow.trace(span_type=SpanType.TOOL) +# MAGIC def execute_tool(self, tool_name: str, args: dict) -> Any: +# MAGIC """Executes the specified tool with the given arguments.""" +# MAGIC return self._tools_dict[tool_name].exec_fn(**args) +# MAGIC +# MAGIC def call_llm(self, messages: list[dict[str, Any]]) -> Generator[dict[str, Any], None, None]: +# MAGIC with warnings.catch_warnings(): +# MAGIC warnings.filterwarnings("ignore", message="PydanticSerializationUnexpectedValue") +# MAGIC for chunk in self.model_serving_client.chat.completions.create( +# MAGIC model=self.llm_endpoint, +# MAGIC messages=self.prep_msgs_for_cc_llm(messages), +# MAGIC tools=self.get_tool_specs(), +# MAGIC stream=True, +# MAGIC ): +# MAGIC yield chunk.to_dict() +# MAGIC +# MAGIC def handle_tool_call( +# MAGIC self, +# MAGIC tool_call: dict[str, Any], +# MAGIC messages: list[dict[str, Any]], +# MAGIC ) -> ResponsesAgentStreamEvent: +# MAGIC """ +# MAGIC Execute tool calls, add them to the running message history, and return a ResponsesStreamEvent w/ tool output +# 
MAGIC """ +# MAGIC args = json.loads(tool_call["arguments"]) +# MAGIC result = str(self.execute_tool(tool_name=tool_call["name"], args=args)) +# MAGIC +# MAGIC tool_call_output = self.create_function_call_output_item(tool_call["call_id"], result) +# MAGIC messages.append(tool_call_output) +# MAGIC return ResponsesAgentStreamEvent(type="response.output_item.done", item=tool_call_output) +# MAGIC +# MAGIC def call_and_run_tools( +# MAGIC self, +# MAGIC messages: list[dict[str, Any]], +# MAGIC max_iter: int = 10, +# MAGIC ) -> Generator[ResponsesAgentStreamEvent, None, None]: +# MAGIC for _ in range(max_iter): +# MAGIC last_msg = messages[-1] +# MAGIC if last_msg.get("role", None) == "assistant": +# MAGIC return +# MAGIC elif last_msg.get("type", None) == "function_call": +# MAGIC yield self.handle_tool_call(last_msg, messages) +# MAGIC else: +# MAGIC yield from self.output_to_responses_items_stream( +# MAGIC chunks=self.call_llm(messages), aggregator=messages +# MAGIC ) +# MAGIC +# MAGIC yield ResponsesAgentStreamEvent( +# MAGIC type="response.output_item.done", +# MAGIC item=self.create_text_output_item("Max iterations reached. 
Stopping.", str(uuid4())), +# MAGIC ) +# MAGIC +# MAGIC def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse: +# MAGIC outputs = [ +# MAGIC event.item +# MAGIC for event in self.predict_stream(request) +# MAGIC if event.type == "response.output_item.done" +# MAGIC ] +# MAGIC return ResponsesAgentResponse(output=outputs, custom_outputs=request.custom_inputs) +# MAGIC +# MAGIC def predict_stream( +# MAGIC self, request: ResponsesAgentRequest +# MAGIC ) -> Generator[ResponsesAgentStreamEvent, None, None]: +# MAGIC messages = self.prep_msgs_for_cc_llm([i.model_dump() for i in request.input]) +# MAGIC if SYSTEM_PROMPT: +# MAGIC messages.insert(0, {"role": "system", "content": SYSTEM_PROMPT}) +# MAGIC yield from self.call_and_run_tools(messages=messages) +# MAGIC +# MAGIC +# MAGIC # Log the model using MLflow +# MAGIC mlflow.openai.autolog() +# MAGIC AGENT = ToolCallingAgent(llm_endpoint=LLM_ENDPOINT_NAME, tools=TOOL_INFOS) +# MAGIC mlflow.models.set_model(AGENT) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Test the agent +# MAGIC +# MAGIC Interact with the agent to test its output. Since we manually traced methods within `ResponsesAgent`, you can view the trace for each step the agent takes, with any LLM calls made via the OpenAI SDK automatically traced by autologging. +# MAGIC +# MAGIC Replace this placeholder input with an appropriate domain-specific example for your agent. 
+ +# COMMAND ---------- + +dbutils.library.restartPython() + +# COMMAND ---------- + +from agent import AGENT + +AGENT.predict({"input": [{"role": "user", "content": "what is 4*3 in python"}]}) + +# COMMAND ---------- + +for chunk in AGENT.predict_stream( + {"input": [{"role": "user", "content": "What is 4*3 in Python?"}]} +): + print(chunk.model_dump(exclude_none=True)) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Log the `agent` as an MLflow model +# MAGIC Determine Databricks resources to specify for automatic auth passthrough at deployment time +# MAGIC - **TODO**: If your Unity Catalog Function queries a [vector search index](https://docs.databricks.com/generative-ai/agent-framework/unstructured-retrieval-tools.html) or leverages [external functions](https://docs.databricks.com/generative-ai/agent-framework/external-connection-tools.html), you need to include the dependent vector search index and UC connection objects, respectively, as resources. See [docs](https://docs.databricks.com/generative-ai/agent-framework/log-agent.html#specify-resources-for-automatic-authentication-passthrough) for more details. +# MAGIC +# MAGIC Log the agent as code from the `agent.py` file. See [MLflow - Models from Code](https://mlflow.org/docs/latest/models.html#models-from-code). + +# COMMAND ---------- + +# Determine Databricks resources to specify for automatic auth passthrough at deployment time +import mlflow +from agent import UC_TOOL_NAMES, VECTOR_SEARCH_TOOLS, LLM_ENDPOINT_NAME +from mlflow.models.resources import DatabricksFunction, DatabricksServingEndpoint +from pkg_resources import get_distribution + +resources = [DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME)] +for tool in VECTOR_SEARCH_TOOLS: + resources.extend(tool.resources) +for tool_name in UC_TOOL_NAMES: + # TODO: If the UC function includes dependencies like external connection or vector search, please include them manually. + # See the TODO in the markdown above for more information. 
+ resources.append(DatabricksFunction(function_name=tool_name)) + +input_example = { + "input": [ + { + "role": "user", + "content": "What is an LLM agent?" + } + ] +} + +with mlflow.start_run(): + logged_agent_info = mlflow.pyfunc.log_model( + name="agent", + python_model="agent.py", + input_example=input_example, + pip_requirements=[ + "databricks-openai", + "backoff", + f"databricks-connect=={get_distribution('databricks-connect').version}", + ], + #resources=resources, + ) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Evaluate the agent with [Agent Evaluation](https://docs.databricks.com/mlflow3/genai/eval-monitor) +# MAGIC +# MAGIC You can edit the requests or expected responses in your evaluation dataset and run evaluation as you iterate your agent, leveraging mlflow to track the computed quality metrics. +# MAGIC +# MAGIC Evaluate your agent with one of our [predefined LLM scorers](https://docs.databricks.com/mlflow3/genai/eval-monitor/predefined-judge-scorers), or try adding [custom metrics](https://docs.databricks.com/mlflow3/genai/eval-monitor/custom-scorers). + +# COMMAND ---------- + +import mlflow +from mlflow.genai.scorers import RelevanceToQuery, Safety, RetrievalRelevance, RetrievalGroundedness + +eval_dataset = [ + { + "inputs": { + "input": [ + { + "role": "system", + "content": "You are an agent to review the Unity catalog schema tables and suggest the Attribute-based access control(ABAC) policies using the metadata obtained using the tool functions and the documentation as below. 
Feel free to research additional documentation to understand the concepts and policy generation aspects\n\nhttps://docs.databricks.com/aws/en/data-governance/unity-catalog/abac/\n\nhttps://docs.databricks.com/aws/en/data-governance/unity-catalog/abac/tutorial\n\nhttps://docs.databricks.com/aws/en/data-governance/unity-catalog/abac/policies\n\n–\nThe following is the general syntax for creating a policy:\nSQL\n\nCREATE POLICY \nON \nCOMMENT ''\n-- One of the following:\n ROW FILTER \n | COLUMN MASK ON COLUMN \nTO [, , ...]\n[EXCEPT [, , ...]]\nFOR TABLES\n[WHEN hasTag('') OR hasTagValue('', '')]\nMATCH COLUMNS hasTag('') OR hasTagValue('', '') AS \nUSING COLUMNS [, , ...];\nThis example defines a row filter policy that excludes rows for European customers from queries by US-based analysts:\nSQL\n\nCREATE POLICY hide_eu_customers\nON SCHEMA prod.customers\nCOMMENT 'Hide rows with European customers from sensitive tables'\nROW FILTER non_eu_region\nTO us_analysts\nFOR TABLES\nMATCH COLUMNS\n hasTag('geo_region') AS region\nUSING COLUMNS (region);\nThis example defines a column mask policy that hides social security numbers from US analysts, except for those with in the admins group:\nSQL\n\nCREATE POLICY mask_SSN\nON SCHEMA prod.customers\nCOMMENT 'Mask social security numbers'\nCOLUMN MASK mask_SSN\nTO us_analysts\nEXCEPT admins\nFOR TABLES\nMATCH COLUMNS\n hasTagValue('pii', 'ssn') AS ssn\nON COLUMN ssn;\n\n\nFind additional examples in the documentation - https://docs.databricks.com/aws/en/data-governance/unity-catalog/abac/policies?language=SQL\n\nUsually table name is given as catalog_name.schem_name.table_name. \nConsidering the table metadata, column metadata, and corresponding tag details, \n" + }, + { + "role": "user", + "content": "What are some of the available functions for masking and filtering data in the enterprise_gov.hr_finance schema, and what do they do?" 
+ } + ] + }, + "expected_response": None + } +] + +eval_results = mlflow.genai.evaluate( + data=eval_dataset, + predict_fn=lambda input: AGENT.predict({"input": input}), + scorers=[RelevanceToQuery(), Safety()], # add more scorers here if they're applicable +) + +# Review the evaluation results in the MLflow UI (see console output) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Perform pre-deployment validation of the agent +# MAGIC Before registering and deploying the agent, we perform pre-deployment checks via the [mlflow.models.predict()](https://mlflow.org/docs/latest/python_api/mlflow.models.html#mlflow.models.predict) API. See [documentation](https://docs.databricks.com/machine-learning/model-serving/model-serving-debug.html#validate-inputs) for details + +# COMMAND ---------- + +mlflow.models.predict( + model_uri=f"runs:/{logged_agent_info.run_id}/agent", + input_data={"input": [{"role": "user", "content": "Hello!"}]}, + env_manager="uv", +) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Register the model to Unity Catalog +# MAGIC +# MAGIC Update the `catalog`, `schema`, and `model_name` below to register the MLflow model to Unity Catalog. 
+ +# COMMAND ---------- + +mlflow.set_registry_uri("databricks-uc") + +# TODO: define the catalog, schema, and model name for your UC model +catalog = "enterprise_gov" +schema = "gov_admin" +model_name = "governance_abac_model" +UC_MODEL_NAME = f"{catalog}.{schema}.{model_name}" + +# register the model to UC +uc_registered_model_info = mlflow.register_model( + model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME +) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Deploy the agent + +# COMMAND ---------- + +secret_scope = 'david_scope' +client_secret_key = 'DATABRICKS_CLIENT_SECRET' +client_id_key = 'DATABRICKS_CLIENT_ID' + + +# COMMAND ---------- + +from databricks import agents + +deployment_info = agents.deploy( + UC_MODEL_NAME, + uc_registered_model_info.version, + environment_vars={ + "DATABRICKS_HOST": "https://dbc-a612b3a4-f0ff.cloud.databricks.com", + "DATABRICKS_CLIENT_ID": dbutils.secrets.get(scope=secret_scope, key=client_id_key), + "DATABRICKS_CLIENT_SECRET": dbutils.secrets.get(scope=secret_scope, key=client_secret_key), + }, +) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Next steps +# MAGIC +# MAGIC After your agent is deployed, you can chat with it in AI playground to perform additional checks, share it with SMEs in your organization for feedback, or embed it in a production application. See [docs](https://docs.databricks.com/generative-ai/deploy-agent.html) for details \ No newline at end of file diff --git a/uc-quickstart/utils/abac-agent/messages.py b/uc-quickstart/utils/abac-agent/messages.py new file mode 100644 index 00000000..7fcc569e --- /dev/null +++ b/uc-quickstart/utils/abac-agent/messages.py @@ -0,0 +1,124 @@ +""" +Message classes for the chatbot application. + +This module contains the message classes used throughout the app. +By keeping them in a separate module, they remain stable across +Streamlit app reruns, avoiding isinstance comparison issues. 
+""" +import streamlit as st +from abc import ABC, abstractmethod + + +class Message(ABC): + def __init__(self): + pass + + @abstractmethod + def to_input_messages(self): + """Convert this message into a list of dicts suitable for the model API.""" + pass + + @abstractmethod + def render(self, idx): + """Render the message in the Streamlit app.""" + pass + + +class UserMessage(Message): + def __init__(self, content): + super().__init__() + self.content = content + + def to_input_messages(self): + return [{ + "role": "user", + "content": self.content + }] + + def render(self, _): + with st.chat_message("user"): + st.markdown(self.content) + + +class AssistantResponse(Message): + def __init__(self, messages, request_id): + super().__init__() + self.messages = messages + # Request ID tracked to enable submitting feedback on assistant responses via the feedback endpoint + self.request_id = request_id + + def to_input_messages(self): + return self.messages + + def render(self, idx): + with st.chat_message("assistant"): + for msg in self.messages: + render_message(msg) + + if self.request_id is not None: + render_assistant_message_feedback(idx, self.request_id) + + +def render_message(msg): + """Render a single message with enhanced formatting for ABAC content.""" + if msg["role"] == "assistant": + # Render content first if it exists + if msg.get("content"): + st.markdown(msg["content"]) + + # Then render tool calls if they exist + if "tool_calls" in msg and msg["tool_calls"]: + for call in msg["tool_calls"]: + fn_name = call["function"]["name"] + args = call["function"]["arguments"] + + # Databricks-themed display for function calls + st.markdown(f""" +
+
+ šŸ” {fn_name.replace('enterprise_gov__gov_admin__', '').replace('_', ' ').title()} +
+
+ """, unsafe_allow_html=True) + + # Only show parameters if they're not empty + try: + import json + parsed_args = json.loads(args) + if parsed_args: + with st.expander("View parameters", expanded=False): + st.json(parsed_args) + except: + pass + + elif msg["role"] == "tool": + # Clean, minimal tool response display + try: + import json + parsed = json.loads(msg["content"]) + + # Show results in an expandable section for cleaner UI + with st.expander("šŸ“Š View Results", expanded=True): + st.json(parsed) + except: + # If not JSON, show as code + with st.expander("šŸ“Š View Results", expanded=True): + st.text(msg["content"]) + + +@st.fragment +def render_assistant_message_feedback(i, request_id): + """Render feedback UI for assistant messages.""" + from model_serving_utils import submit_feedback + import os + + def save_feedback(index): + serving_endpoint = os.getenv('SERVING_ENDPOINT') + if serving_endpoint: + submit_feedback( + endpoint=serving_endpoint, + request_id=request_id, + rating=st.session_state[f"feedback_{index}"] + ) + + st.feedback("thumbs", key=f"feedback_{i}", on_change=save_feedback, args=[i]) \ No newline at end of file diff --git a/uc-quickstart/utils/abac-agent/model_serving_utils.py b/uc-quickstart/utils/abac-agent/model_serving_utils.py new file mode 100644 index 00000000..e25f61e6 --- /dev/null +++ b/uc-quickstart/utils/abac-agent/model_serving_utils.py @@ -0,0 +1,267 @@ +from mlflow.deployments import get_deploy_client +from databricks.sdk import WorkspaceClient +import json +import uuid + +import logging + +logging.basicConfig( + format="%(levelname)s [%(asctime)s] %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.DEBUG +) + +def _get_endpoint_task_type(endpoint_name: str) -> str: + """Get the task type of a serving endpoint.""" + try: + w = WorkspaceClient() + ep = w.serving_endpoints.get(endpoint_name) + return ep.task if ep.task else "chat/completions" + except Exception: + return "chat/completions" + +def 
_convert_to_responses_format(messages): + """Convert chat messages to ResponsesAgent API format.""" + input_messages = [] + for msg in messages: + if msg["role"] == "user": + input_messages.append({"role": "user", "content": msg["content"]}) + elif msg["role"] == "assistant": + # Handle assistant messages with tool calls + if msg.get("tool_calls"): + # Add function calls + for tool_call in msg["tool_calls"]: + input_messages.append({ + "type": "function_call", + "id": tool_call["id"], + "call_id": tool_call["id"], + "name": tool_call["function"]["name"], + "arguments": tool_call["function"]["arguments"] + }) + # Add assistant message if it has content + if msg.get("content"): + input_messages.append({ + "type": "message", + "id": msg.get("id", str(uuid.uuid4())), + "content": [{"type": "output_text", "text": msg["content"]}], + "role": "assistant" + }) + else: + # Regular assistant message + input_messages.append({ + "type": "message", + "id": msg.get("id", str(uuid.uuid4())), + "content": [{"type": "output_text", "text": msg["content"]}], + "role": "assistant" + }) + elif msg["role"] == "tool": + input_messages.append({ + "type": "function_call_output", + "call_id": msg.get("tool_call_id"), + "output": msg["content"] + }) + return input_messages + +def _throw_unexpected_endpoint_format(): + raise Exception("This app can only run against ChatModel, ChatAgent, or ResponsesAgent endpoints") + +def query_endpoint_stream(endpoint_name: str, messages: list[dict[str, str]], return_traces: bool): + task_type = _get_endpoint_task_type(endpoint_name) + + if task_type == "agent/v1/responses": + return _query_responses_endpoint_stream(endpoint_name, messages, return_traces) + else: + return _query_chat_endpoint_stream(endpoint_name, messages, return_traces) + +def _query_chat_endpoint_stream(endpoint_name: str, messages: list[dict[str, str]], return_traces: bool): + """Invoke an endpoint that implements either chat completions or ChatAgent and stream the response""" + client 
= get_deploy_client("databricks") + + # Prepare input payload + inputs = { + "messages": messages, + } + if return_traces: + inputs["databricks_options"] = {"return_trace": True} + + for chunk in client.predict_stream(endpoint=endpoint_name, inputs=inputs): + if "choices" in chunk: + yield chunk + elif "delta" in chunk: + yield chunk + else: + _throw_unexpected_endpoint_format() + +def _query_responses_endpoint_stream(endpoint_name: str, messages: list[dict[str, str]], return_traces: bool): + """Stream responses from agent/v1/responses endpoints using MLflow deployments client.""" + client = get_deploy_client("databricks") + + input_messages = _convert_to_responses_format(messages) + + # Prepare input payload for ResponsesAgent + inputs = { + "input": input_messages, + "context": {}, + "stream": True + } + if return_traces: + inputs["databricks_options"] = {"return_trace": True} + + for event_data in client.predict_stream(endpoint=endpoint_name, inputs=inputs): + # Just yield the raw event data, let app.py handle the parsing + yield event_data + +def query_endpoint(endpoint_name, messages, return_traces): + """ + Query an endpoint, returning the string message content and request + ID for feedback + """ + task_type = _get_endpoint_task_type(endpoint_name) + + if task_type == "agent/v1/responses": + return _query_responses_endpoint(endpoint_name, messages, return_traces) + else: + return _query_chat_endpoint(endpoint_name, messages, return_traces) + +def _query_chat_endpoint(endpoint_name, messages, return_traces): + """Calls a model serving endpoint with chat/completions format.""" + inputs = {'messages': messages} + if return_traces: + inputs['databricks_options'] = {'return_trace': True} + + res = get_deploy_client('databricks').predict( + endpoint=endpoint_name, + inputs=inputs, + ) + request_id = res.get("databricks_output", {}).get("databricks_request_id") + if "messages" in res: + return res["messages"], request_id + elif "choices" in res: + choice_message = 
res["choices"][0]["message"] + choice_content = choice_message.get("content") + + # Case 1: The content is a list of structured objects + if isinstance(choice_content, list): + combined_content = "".join([part.get("text", "") for part in choice_content if part.get("type") == "text"]) + reformatted_message = { + "role": choice_message.get("role"), + "content": combined_content + } + return [reformatted_message], request_id + + # Case 2: The content is a simple string + elif isinstance(choice_content, str): + return [choice_message], request_id + + _throw_unexpected_endpoint_format() + +def _query_responses_endpoint(endpoint_name, messages, return_traces): + """Query agent/v1/responses endpoints using MLflow deployments client.""" + client = get_deploy_client("databricks") + + input_messages = _convert_to_responses_format(messages) + + # Prepare input payload for ResponsesAgent + inputs = { + "input": input_messages, + "context": {} + } + if return_traces: + inputs["databricks_options"] = {"return_trace": True} + + # Make the prediction call + response = client.predict(endpoint=endpoint_name, inputs=inputs) + + # Extract messages from the response + result_messages = [] + request_id = response.get("databricks_output", {}).get("databricks_request_id") + + # Process the output items from ResponsesAgent response + output_items = response.get("output", []) + + for item in output_items: + item_type = item.get("type") + + if item_type == "message": + # Extract text content from message + text_content = "" + content_parts = item.get("content", []) + + for content_part in content_parts: + if content_part.get("type") == "output_text": + text_content += content_part.get("text", "") + + if text_content: + result_messages.append({ + "role": "assistant", + "content": text_content + }) + + elif item_type == "function_call": + # Handle function calls + call_id = item.get("call_id") + function_name = item.get("name") + arguments = item.get("arguments", "") + + tool_calls = [{ + 
"id": call_id, + "type": "function", + "function": { + "name": function_name, + "arguments": arguments + } + }] + result_messages.append({ + "role": "assistant", + "content": "", + "tool_calls": tool_calls + }) + + elif item_type == "function_call_output": + # Handle function call output/result + call_id = item.get("call_id") + output_content = item.get("output", "") + + result_messages.append({ + "role": "tool", + "content": output_content, + "tool_call_id": call_id + }) + + return result_messages or [{"role": "assistant", "content": "No response found"}], request_id + +def submit_feedback(endpoint, request_id, rating): + """Submit feedback to the agent.""" + rating_string = "positive" if rating == 1 else "negative" + text_assessments = [] if rating is None else [{ + "ratings": { + "answer_correct": {"value": rating_string}, + }, + "free_text_comment": None + }] + + proxy_payload = { + "dataframe_records": [ + { + "source": json.dumps({ + "id": "e2e-chatbot-app", # Or extract from auth + "type": "human" + }), + "request_id": request_id, + "text_assessments": json.dumps(text_assessments), + "retrieval_assessments": json.dumps([]), + } + ] + } + w = WorkspaceClient() + return w.api_client.do( + method='POST', + path=f"/serving-endpoints/{endpoint}/served-models/feedback/invocations", + body=proxy_payload, + ) + + +def endpoint_supports_feedback(endpoint_name): + w = WorkspaceClient() + endpoint = w.serving_endpoints.get(endpoint_name) + return "feedback" in [entity.name for entity in endpoint.config.served_entities] diff --git a/uc-quickstart/utils/abac-agent/requirements.txt b/uc-quickstart/utils/abac-agent/requirements.txt new file mode 100644 index 00000000..6dc0a944 --- /dev/null +++ b/uc-quickstart/utils/abac-agent/requirements.txt @@ -0,0 +1,2 @@ +mlflow>=2.21.2 +streamlit==1.44.1