
Commit 3c1dc16

Support for Service Principal Auth for Databricks Retrieve with DSPy (#8293)
* Updated the DatabricksRM class to use Databricks service principals.
  - Updated the base class to include the optional parameters.
  - Updated logic in _query_via_databricks_sdk to use the SP credentials if they exist, otherwise fall back to PAT.
* Updated comments for the usage of service principals for the Databricks SDK client.
* Updated print statements to acknowledge the auth method.
* format

---------

Co-authored-by: chenmoneygithub <[email protected]>
1 parent 8aa8b9c commit 3c1dc16
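
As a rough usage sketch (not part of this commit): assuming DatabricksRM is importable from the touched module dspy/retrieve/databricks_rm.py, that the index name and credentials below are placeholders, and that the retriever is invoked with a plain text query as with other DSPy retrievers, the new parameters can be passed directly when constructing the retriever:

    from dspy.retrieve.databricks_rm import DatabricksRM

    # Hypothetical placeholder values; substitute your own index and service principal.
    rm = DatabricksRM(
        databricks_index_name="catalog.schema.my_index",
        databricks_client_id="<service-principal-client-id>",
        databricks_client_secret="<service-principal-client-secret>",
        k=3,
    )

    # When both client credentials are present, _query_via_databricks_sdk builds the
    # WorkspaceClient from the service principal; otherwise it falls back to the PAT.
    results = rm("What is Databricks Vector Search?")

Note that the service-principal branch below does not pass host= to WorkspaceClient, so the workspace URL is resolved by the Databricks SDK from the environment (e.g. DATABRICKS_HOST).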

File tree

1 file changed: +42 -2 lines changed


dspy/retrieve/databricks_rm.py

Lines changed: 42 additions & 2 deletions
@@ -86,6 +86,8 @@ def __init__(
         databricks_index_name: str,
         databricks_endpoint: Optional[str] = None,
         databricks_token: Optional[str] = None,
+        databricks_client_id: Optional[str] = None,
+        databricks_client_secret: Optional[str] = None,
         columns: Optional[List[str]] = None,
         filters_json: Optional[str] = None,
         k: int = 3,
@@ -105,6 +107,10 @@ def __init__(
                 when querying the Vector Search Index. Defaults to the value of the
                 ``DATABRICKS_TOKEN`` environment variable. If unspecified, the Databricks SDK is
                 used to identify the token based on the current environment.
+            databricks_client_id (str): Databricks service principal id. If not specified,
+                the token is resolved from the current environment (DATABRICKS_CLIENT_ID).
+            databricks_client_secret (str): Databricks service principal secret. If not specified,
+                the endpoint is resolved from the current environment (DATABRICKS_CLIENT_SECRET).
             columns (Optional[List[str]]): Extra column names to include in response,
                 in addition to the document id and text columns specified by
                 ``docs_id_column_name`` and ``text_column_name``.
@@ -127,12 +133,21 @@ def __init__(
         self.databricks_endpoint = (
             databricks_endpoint if databricks_endpoint is not None else os.environ.get("DATABRICKS_HOST")
         )
+        self.databricks_client_id = (
+            databricks_client_id if databricks_client_id is not None else os.environ.get("DATABRICKS_CLIENT_ID")
+        )
+        self.databricks_client_secret = (
+            databricks_client_secret
+            if databricks_client_secret is not None
+            else os.environ.get("DATABRICKS_CLIENT_SECRET")
+        )
         if not _databricks_sdk_installed and (self.databricks_token, self.databricks_endpoint).count(None) > 0:
             raise ValueError(
                 "To retrieve documents with Databricks Vector Search, you must install the"
                 " databricks-sdk Python library, supply the databricks_token and"
                 " databricks_endpoint parameters, or set the DATABRICKS_TOKEN and DATABRICKS_HOST"
-                " environment variables."
+                " environment variables. You may also supply a service principal the databricks_client_id and"
+                " databricks_client_secret parameters, or set the DATABRICKS_CLIENT_ID and DATABRICKS_CLIENT_SECRET"
             )
         self.databricks_index_name = databricks_index_name
         self.columns = list({docs_id_column_name, text_column_name, *(columns or [])})
@@ -245,6 +260,8 @@ def forward(
                 query_vector=query_vector,
                 databricks_token=self.databricks_token,
                 databricks_endpoint=self.databricks_endpoint,
+                databricks_client_id=self.databricks_client_id,
+                databricks_client_secret=self.databricks_client_secret,
                 filters_json=filters_json or self.filters_json,
             )
         else:
@@ -315,6 +332,8 @@ def _query_via_databricks_sdk(
     query_vector: Optional[List[float]],
     databricks_token: Optional[str],
     databricks_endpoint: Optional[str],
+    databricks_client_id: Optional[str],
+    databricks_client_secret: Optional[str],
     filters_json: Optional[str],
 ) -> Dict[str, Any]:
     """
@@ -334,15 +353,36 @@ def _query_via_databricks_sdk(
             the token is resolved from the current environment.
         databricks_endpoint (str): Databricks index endpoint url. If not specified,
             the endpoint is resolved from the current environment.
+        databricks_client_id (str): Databricks service principal id. If not specified,
+            the token is resolved from the current environment (DATABRICKS_CLIENT_ID).
+        databricks_client_secret (str): Databricks service principal secret. If not specified,
+            the endpoint is resolved from the current environment (DATABRICKS_CLIENT_SECRET).
+    Returns:
     Returns:
         Dict[str, Any]: Parsed JSON response from the Databricks Vector Search Index query.
     """
+
     from databricks.sdk import WorkspaceClient

     if (query_text, query_vector).count(None) != 1:
         raise ValueError("Exactly one of query_text or query_vector must be specified.")

-    databricks_client = WorkspaceClient(host=databricks_endpoint, token=databricks_token)
+    if databricks_client_secret and databricks_client_id:
+        # Use client ID and secret for authentication if they are provided
+        databricks_client = WorkspaceClient(
+            client_id=databricks_client_id,
+            client_secret=databricks_client_secret,
+        )
+        print("Creating Databricks workspace client using service principal authentication.")
+
+    else:
+        # Fallback for token-based authentication
+        databricks_client = WorkspaceClient(
+            host=databricks_endpoint,
+            token=databricks_token,
+        )
+        print("Creating Databricks workspace client using token authentication.")
+
     return databricks_client.vector_search_indexes.query_index(
         index_name=index_name,
         query_type=query_type,
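
For completeness, a hedged sketch of the environment-variable path the constructor falls back to (the variable names come from the diff above; the host value, index name, and import path are assumed placeholders):

    import os

    # Assumed placeholder values; __init__ reads these when the arguments are None.
    os.environ["DATABRICKS_HOST"] = "https://<your-workspace>.cloud.databricks.com"
    os.environ["DATABRICKS_CLIENT_ID"] = "<service-principal-client-id>"
    os.environ["DATABRICKS_CLIENT_SECRET"] = "<service-principal-client-secret>"

    from dspy.retrieve.databricks_rm import DatabricksRM

    # No credentials passed explicitly: the client id/secret come from the environment,
    # so _query_via_databricks_sdk takes the service-principal branch instead of the
    # DATABRICKS_TOKEN fallback.
    rm = DatabricksRM(databricks_index_name="catalog.schema.my_index")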
