Skip to content

Commit 510d630

Browse files
committed
update
1 parent fb87bc5 commit 510d630

File tree

1 file changed

+45
-28
lines changed

1 file changed

+45
-28
lines changed

sdk/ai/azure-ai-projects/tests/samples/test_samples.py

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
import csv, os, pytest, re, inspect, sys, json
77
import importlib.util
88
import unittest.mock as mock
9+
from typing import cast
10+
from azure.core.credentials import TokenCredential
11+
from azure.core.credentials_async import AsyncTokenCredential
912
from azure.core.exceptions import HttpResponseError
1013
from devtools_testutils.aio import recorded_by_proxy_async
1114
from devtools_testutils import AzureRecordedTestCase, recorded_by_proxy, RecordedTransport
@@ -131,7 +134,6 @@ def _get_validation_request_params(self) -> dict:
131134
- Error messages or exception text
132135
- Empty or null results where data is expected
133136
- Malformed or corrupted data
134-
- HTTP error codes (4xx, 5xx)
135137
- Timeout or connection errors
136138
- Warning messages indicating failures
137139
- Failure to retrieve or process data
@@ -180,9 +182,11 @@ def _assert_validation_result(self, test_report: dict) -> None:
180182

181183
def _validate_output(self):
182184
"""Validate sample output using synchronous OpenAI client."""
185+
credential = self.test_instance.get_credential(AIProjectClient, is_async=False)
183186
with (
184-
DefaultAzureCredential() as credential,
185-
AIProjectClient(endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], credential=credential) as project_client,
187+
AIProjectClient(
188+
endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], credential=cast(TokenCredential, credential)
189+
) as project_client,
186190
project_client.get_openai_client() as openai_client,
187191
):
188192
response = openai_client.responses.create(**self._get_validation_request_params())
@@ -191,10 +195,10 @@ def _validate_output(self):
191195

192196
async def _validate_output_async(self):
193197
"""Validate sample output using asynchronous OpenAI client."""
198+
credential = self.test_instance.get_credential(AIProjectClient, is_async=True)
194199
async with (
195-
AsyncDefaultAzureCredential() as credential,
196200
AsyncAIProjectClient(
197-
endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], credential=credential
201+
endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], credential=cast(AsyncTokenCredential, credential)
198202
) as project_client,
199203
):
200204
async with project_client.get_openai_client() as openai_client:
@@ -225,23 +229,29 @@ def _get_tools_sample_paths():
225229
samples_folder_path = os.path.normpath(os.path.join(current_dir, os.pardir, os.pardir))
226230
tools_folder = os.path.join(samples_folder_path, "samples", "agents", "tools")
227231

228-
tools_samples_to_skip = [
229-
"sample_agent_bing_custom_search.py",
230-
"sample_agent_bing_grounding.py",
231-
"sample_agent_browser_automation.py",
232-
"sample_agent_fabric.py",
233-
"sample_agent_mcp_with_project_connection.py",
234-
"sample_agent_memory_search.py",
235-
"sample_agent_openapi_with_project_connection.py",
236-
"sample_agent_to_agent.py",
232+
# Whitelist of samples to test
233+
tools_samples_to_test = [
234+
"sample_agent_ai_search.py",
235+
"sample_agent_code_interpreter.py",
236+
"sample_agent_file_search.py",
237+
"sample_agent_file_search_in_stream.py",
238+
"sample_agent_function_tool.py",
239+
"sample_agent_image_generation.py",
240+
"sample_agent_mcp.py",
241+
"sample_agent_openapi.py",
242+
"sample_agent_sharepoint.py",
243+
"sample_agent_web_search.py",
237244
]
238245
samples = []
239246

240-
for filename in sorted(os.listdir(tools_folder)):
241-
# Only include .py files, exclude __pycache__ and utility files
242-
if "sample_" in filename and "_async" not in filename and filename not in tools_samples_to_skip:
243-
sample_path = os.path.join(tools_folder, filename)
244-
samples.append(pytest.param(sample_path, id=filename.replace(".py", "")))
247+
for filename in tools_samples_to_test:
248+
sample_path = os.path.join(tools_folder, filename)
249+
if os.path.exists(sample_path):
250+
# Get relative path from samples folder and convert to test ID format
251+
rel_path = os.path.relpath(sample_path, samples_folder_path)
252+
# Remove 'samples\' prefix and convert to forward slashes
253+
test_id = rel_path.replace("samples\\", "").replace("\\", "/").replace(".py", "")
254+
samples.append(pytest.param(sample_path, id=test_id))
245255

246256
return samples
247257

@@ -252,18 +262,25 @@ def _get_tools_sample_paths_async():
252262
samples_folder_path = os.path.normpath(os.path.join(current_dir, os.pardir, os.pardir))
253263
tools_folder = os.path.join(samples_folder_path, "samples", "agents", "tools")
254264

255-
# Skip async samples that are not yet ready for testing
256-
tools_samples_to_skip = [
257-
"sample_agent_mcp_with_project_connection_async.py",
258-
"sample_agent_memory_search_async.py",
265+
# Whitelist of async samples to test
266+
tools_samples_to_test_async = [
267+
"sample_agent_code_interpreter_async.py",
268+
"sample_agent_computer_use_async.py",
269+
"sample_agent_file_search_in_stream_async.py",
270+
"sample_agent_function_tool_async.py",
271+
"sample_agent_image_generation_async.py",
272+
"sample_agent_mcp_async.py",
259273
]
260274
samples = []
261275

262-
for filename in sorted(os.listdir(tools_folder)):
263-
# Only include async .py files, exclude __pycache__ and utility files
264-
if "sample_" in filename and "_async" in filename and filename not in tools_samples_to_skip:
265-
sample_path = os.path.join(tools_folder, filename)
266-
samples.append(pytest.param(sample_path, id=filename.replace(".py", "")))
276+
for filename in tools_samples_to_test_async:
277+
sample_path = os.path.join(tools_folder, filename)
278+
if os.path.exists(sample_path):
279+
# Get relative path from samples folder and convert to test ID format
280+
rel_path = os.path.relpath(sample_path, samples_folder_path)
281+
# Remove 'samples\' prefix and convert to forward slashes
282+
test_id = rel_path.replace("samples\\", "").replace("\\", "/").replace(".py", "")
283+
samples.append(pytest.param(sample_path, id=test_id))
267284

268285
return samples
269286

0 commit comments

Comments
 (0)