@@ -86,6 +86,8 @@ def __init__(
86
86
databricks_index_name : str ,
87
87
databricks_endpoint : Optional [str ] = None ,
88
88
databricks_token : Optional [str ] = None ,
89
+ databricks_client_id : Optional [str ] = None ,
90
+ databricks_client_secret : Optional [str ] = None ,
89
91
columns : Optional [List [str ]] = None ,
90
92
filters_json : Optional [str ] = None ,
91
93
k : int = 3 ,
@@ -105,6 +107,10 @@ def __init__(
105
107
when querying the Vector Search Index. Defaults to the value of the
106
108
``DATABRICKS_TOKEN`` environment variable. If unspecified, the Databricks SDK is
107
109
used to identify the token based on the current environment.
110
+ databricks_client_id (str): Databricks service principal id. If not specified,
111
+ the client id is resolved from the current environment (DATABRICKS_CLIENT_ID).
112
+ databricks_client_secret (str): Databricks service principal secret. If not specified,
113
+ the client secret is resolved from the current environment (DATABRICKS_CLIENT_SECRET).
108
114
columns (Optional[List[str]]): Extra column names to include in response,
109
115
in addition to the document id and text columns specified by
110
116
``docs_id_column_name`` and ``text_column_name``.
@@ -127,12 +133,21 @@ def __init__(
127
133
self .databricks_endpoint = (
128
134
databricks_endpoint if databricks_endpoint is not None else os .environ .get ("DATABRICKS_HOST" )
129
135
)
136
+ self .databricks_client_id = (
137
+ databricks_client_id if databricks_client_id is not None else os .environ .get ("DATABRICKS_CLIENT_ID" )
138
+ )
139
+ self .databricks_client_secret = (
140
+ databricks_client_secret
141
+ if databricks_client_secret is not None
142
+ else os .environ .get ("DATABRICKS_CLIENT_SECRET" )
143
+ )
130
144
if not _databricks_sdk_installed and (self .databricks_token , self .databricks_endpoint ).count (None ) > 0 :
131
145
raise ValueError (
132
146
"To retrieve documents with Databricks Vector Search, you must install the"
133
147
" databricks-sdk Python library, supply the databricks_token and"
134
148
" databricks_endpoint parameters, or set the DATABRICKS_TOKEN and DATABRICKS_HOST"
135
- " environment variables."
149
+ " environment variables. You may also supply a service principal the databricks_client_id and"
150
+ " databricks_client_secret parameters, or set the DATABRICKS_CLIENT_ID and DATABRICKS_CLIENT_SECRET"
136
151
)
137
152
self .databricks_index_name = databricks_index_name
138
153
self .columns = list ({docs_id_column_name , text_column_name , * (columns or [])})
@@ -245,6 +260,8 @@ def forward(
245
260
query_vector = query_vector ,
246
261
databricks_token = self .databricks_token ,
247
262
databricks_endpoint = self .databricks_endpoint ,
263
+ databricks_client_id = self .databricks_client_id ,
264
+ databricks_client_secret = self .databricks_client_secret ,
248
265
filters_json = filters_json or self .filters_json ,
249
266
)
250
267
else :
@@ -315,6 +332,8 @@ def _query_via_databricks_sdk(
315
332
query_vector : Optional [List [float ]],
316
333
databricks_token : Optional [str ],
317
334
databricks_endpoint : Optional [str ],
335
+ databricks_client_id : Optional [str ],
336
+ databricks_client_secret : Optional [str ],
318
337
filters_json : Optional [str ],
319
338
) -> Dict [str , Any ]:
320
339
"""
@@ -334,15 +353,36 @@ def _query_via_databricks_sdk(
334
353
the token is resolved from the current environment.
335
354
databricks_endpoint (str): Databricks index endpoint url. If not specified,
336
355
the endpoint is resolved from the current environment.
356
+ databricks_client_id (str): Databricks service principal id. If not specified,
357
+ the client id is resolved from the current environment (DATABRICKS_CLIENT_ID).
358
+ databricks_client_secret (str): Databricks service principal secret. If not specified,
359
+ the client secret is resolved from the current environment (DATABRICKS_CLIENT_SECRET).
360
+
337
361
Returns:
338
362
Dict[str, Any]: Parsed JSON response from the Databricks Vector Search Index query.
339
363
"""
364
+
340
365
from databricks .sdk import WorkspaceClient
341
366
342
367
if (query_text , query_vector ).count (None ) != 1 :
343
368
raise ValueError ("Exactly one of query_text or query_vector must be specified." )
344
369
345
- databricks_client = WorkspaceClient (host = databricks_endpoint , token = databricks_token )
370
+ if databricks_client_secret and databricks_client_id :
371
+ # Use client ID and secret for authentication if they are provided
372
+ databricks_client = WorkspaceClient (
373
+ client_id = databricks_client_id ,
374
+ client_secret = databricks_client_secret ,
375
+ )
376
+ print ("Creating Databricks workspace client using service principal authentication." )
377
+
378
+ else :
379
+ # Fallback for token-based authentication
380
+ databricks_client = WorkspaceClient (
381
+ host = databricks_endpoint ,
382
+ token = databricks_token ,
383
+ )
384
+ print ("Creating Databricks workspace client using token authentication." )
385
+
346
386
return databricks_client .vector_search_indexes .query_index (
347
387
index_name = index_name ,
348
388
query_type = query_type ,
0 commit comments