Skip to content

Commit 65570ab

Browse files
committed
resolved comment
1 parent 99386f8 commit 65570ab

File tree

2 files changed

+95
-52
lines changed

2 files changed

+95
-52
lines changed

sdk/ai/azure-ai-projects/assets.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
"AssetsRepo": "Azure/azure-sdk-assets",
33
"AssetsRepoPrefixPath": "python",
44
"TagPrefix": "python/ai/azure-ai-projects",
5-
"Tag": "python/ai/azure-ai-projects_d44710c465"
5+
"Tag": "python/ai/azure-ai-projects_93d1dc0fe7"
66
}

sdk/ai/azure-ai-projects/tests/samples/test_samples.py

Lines changed: 94 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import csv, os, pytest, re, inspect, sys, json
77
import importlib.util
88
import unittest.mock as mock
9-
from typing import cast
9+
from typing import Union, cast, overload
1010
from azure.core.credentials import TokenCredential
1111
from azure.core.credentials_async import AsyncTokenCredential
1212
from azure.core.exceptions import HttpResponseError
@@ -83,7 +83,7 @@ def _get_mock_credential(self, is_async: bool):
8383

8484
return mock.patch(patch_target, new=mock_credential_class)
8585

86-
def execute(self):
86+
def execute(self, enable_llm_validation: bool = True):
8787
"""Execute a synchronous sample with proper mocking and environment setup."""
8888

8989
with (
@@ -101,9 +101,10 @@ def execute(self):
101101
raise ImportError(f"Could not load module {self.spec.name} from {self.sample_path}")
102102
self.spec.loader.exec_module(self.module)
103103

104-
self._validate_output()
104+
if enable_llm_validation:
105+
self._validate_output()
105106

106-
async def execute_async(self):
107+
async def execute_async(self, enable_llm_validation: bool = True):
107108
"""Execute an asynchronous sample with proper mocking and environment setup."""
108109

109110
with (
@@ -122,7 +123,8 @@ async def execute_async(self):
122123
self.spec.loader.exec_module(self.module)
123124
await self.module.main()
124125

125-
await self._validate_output_async()
126+
if enable_llm_validation:
127+
await self._validate_output_async()
126128

127129
def _get_validation_request_params(self) -> dict:
128130
"""Get common parameters for validation request."""
@@ -223,66 +225,97 @@ def _wrapper_sync(test_class, sample_path, **kwargs):
223225
return _wrapper_sync
224226

225227

226-
def _get_tools_sample_paths():
227-
# Get the path to the samples folder
228-
current_dir = os.path.dirname(os.path.abspath(__file__))
229-
samples_folder_path = os.path.normpath(os.path.join(current_dir, os.pardir, os.pardir))
230-
tools_folder = os.path.join(samples_folder_path, "samples", "agents", "tools")
231-
232-
# Whitelist of samples to test
233-
tools_samples_to_test = [
234-
"sample_agent_ai_search.py",
235-
"sample_agent_code_interpreter.py",
236-
"sample_agent_file_search.py",
237-
"sample_agent_file_search_in_stream.py",
238-
"sample_agent_function_tool.py",
239-
"sample_agent_image_generation.py",
240-
"sample_agent_mcp.py",
241-
"sample_agent_openapi.py",
242-
"sample_agent_sharepoint.py",
243-
"sample_agent_web_search.py",
244-
]
245-
samples = []
228+
@overload
229+
def _get_sample_paths(sub_folder: str, *, samples_to_test: list[str]) -> list:
230+
"""Get sample paths for testing (whitelist mode).
246231
247-
for filename in tools_samples_to_test:
248-
sample_path = os.path.join(tools_folder, filename)
249-
if os.path.exists(sample_path):
250-
test_id = filename.replace(".py", "")
251-
samples.append(pytest.param(sample_path, id=test_id))
232+
Args:
233+
sub_folder: Relative path to the samples subfolder (e.g., "agents/tools")
234+
samples_to_test: Whitelist of sample filenames to include
252235
253-
return samples
236+
Returns:
237+
List of pytest.param objects with sample paths and test IDs
238+
"""
239+
...
240+
241+
242+
@overload
243+
def _get_sample_paths(sub_folder: str, *, samples_to_skip: list[str], is_async: bool) -> list:
244+
"""Get sample paths for testing (blacklist mode).
245+
246+
Args:
247+
sub_folder: Relative path to the samples subfolder (e.g., "agents/tools")
248+
samples_to_skip: Blacklist of sample filenames to exclude (auto-discovers all samples)
249+
is_async: Whether to filter for async samples (_async.py suffix)
250+
251+
Returns:
252+
List of pytest.param objects with sample paths and test IDs
253+
"""
254+
...
254255

255256

256-
def _get_tools_sample_paths_async():
257+
def _get_sample_paths(
258+
sub_folder: str,
259+
*,
260+
samples_to_skip: Union[list[str], None] = None,
261+
is_async: Union[bool, None] = None,
262+
samples_to_test: Union[list[str], None] = None,
263+
) -> list:
257264
# Get the path to the samples folder
258265
current_dir = os.path.dirname(os.path.abspath(__file__))
259266
samples_folder_path = os.path.normpath(os.path.join(current_dir, os.pardir, os.pardir))
260-
tools_folder = os.path.join(samples_folder_path, "samples", "agents", "tools")
261-
262-
# Whitelist of async samples to test
263-
tools_samples_to_test_async = [
264-
"sample_agent_code_interpreter_async.py",
265-
"sample_agent_computer_use_async.py",
266-
"sample_agent_file_search_in_stream_async.py",
267-
"sample_agent_function_tool_async.py",
268-
"sample_agent_image_generation_async.py",
269-
"sample_agent_mcp_async.py",
270-
]
271-
samples = []
267+
target_folder = os.path.join(samples_folder_path, "samples", *sub_folder.split("/"))
268+
269+
if not os.path.exists(target_folder):
270+
raise ValueError(f"Target folder does not exist: {target_folder}")
271+
272+
# Discover all sample files in the folder
273+
all_files = [f for f in os.listdir(target_folder) if f.startswith("sample_") and f.endswith(".py")]
274+
275+
# Filter by async suffix only when using samples_to_skip
276+
if samples_to_skip is not None and is_async is not None:
277+
if is_async:
278+
all_files = [f for f in all_files if f.endswith("_async.py")]
279+
else:
280+
all_files = [f for f in all_files if not f.endswith("_async.py")]
272281

273-
for filename in tools_samples_to_test_async:
274-
sample_path = os.path.join(tools_folder, filename)
275-
if os.path.exists(sample_path):
276-
test_id = filename.replace(".py", "")
277-
samples.append(pytest.param(sample_path, id=test_id))
282+
# Apply whitelist or blacklist
283+
if samples_to_test is not None:
284+
files_to_test = [f for f in all_files if f in samples_to_test]
285+
else: # samples_to_skip is not None
286+
assert samples_to_skip is not None
287+
files_to_test = [f for f in all_files if f not in samples_to_skip]
288+
289+
# Create pytest.param objects
290+
samples = []
291+
for filename in sorted(files_to_test):
292+
sample_path = os.path.join(target_folder, filename)
293+
test_id = filename.replace(".py", "")
294+
samples.append(pytest.param(sample_path, id=test_id))
278295

279296
return samples
280297

281298

282299
class TestSamples(AzureRecordedTestCase):
283300

284301
@servicePreparer()
285-
@pytest.mark.parametrize("sample_path", _get_tools_sample_paths())
302+
@pytest.mark.parametrize(
303+
"sample_path",
304+
_get_sample_paths(
305+
"agents/tools",
306+
samples_to_skip=[
307+
"sample_agent_bing_custom_search.py",
308+
"sample_agent_bing_grounding.py",
309+
"sample_agent_browser_automation.py",
310+
"sample_agent_fabric.py",
311+
"sample_agent_mcp_with_project_connection.py",
312+
"sample_agent_memory_search.py",
313+
"sample_agent_openapi_with_project_connection.py",
314+
"sample_agent_to_agent.py",
315+
],
316+
is_async=False,
317+
),
318+
)
286319
@SamplePathPasser()
287320
@recorded_by_proxy(RecordedTransport.AZURE_CORE, RecordedTransport.HTTPX)
288321
def test_agent_tools_samples(self, sample_path: str, **kwargs) -> None:
@@ -291,7 +324,17 @@ def test_agent_tools_samples(self, sample_path: str, **kwargs) -> None:
291324
executor.execute()
292325

293326
@servicePreparer()
294-
@pytest.mark.parametrize("sample_path", _get_tools_sample_paths_async())
327+
@pytest.mark.parametrize(
328+
"sample_path",
329+
_get_sample_paths(
330+
"agents/tools",
331+
samples_to_skip=[
332+
"sample_agent_mcp_with_project_connection_async.py",
333+
"sample_agent_memory_search_async.py",
334+
],
335+
is_async=True,
336+
),
337+
)
295338
@SamplePathPasser()
296339
@recorded_by_proxy_async(RecordedTransport.AZURE_CORE, RecordedTransport.HTTPX)
297340
async def test_agent_tools_samples_async(self, sample_path: str, **kwargs) -> None:

0 commit comments

Comments
 (0)