Skip to content

Commit 59e73bd

Browse files
authored
fix: issue with parsing nested lists/tuples for extra inputs (#318)
1 parent 1d8656e commit 59e73bd

File tree

3 files changed

+152
-18
lines changed

3 files changed

+152
-18
lines changed

src/aiperf/common/config/config_validators.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def parse_str_or_dict_as_tuple_list(input: Any | None) -> list[tuple[str, Any]]
118118
into key and value, trims any whitespace, and coerces the value to the correct type.
119119
- If the input is a dictionary, it is converted to a list of tuples by key and value pairs.
120120
- If the input is a list, it recursively calls this function on each item, and aggregates the results.
121+
- If a list item is already a 2-element list or tuple (key-value pair), it is converted directly to a tuple: the key is cast to str and the value is type-coerced.
121122
- Otherwise, a ValueError is raised.
122123
123124
Args:
@@ -133,9 +134,14 @@ def parse_str_or_dict_as_tuple_list(input: Any | None) -> list[tuple[str, Any]]
133134
if isinstance(input, list | tuple | set):
134135
output = []
135136
for item in input:
136-
res = parse_str_or_dict_as_tuple_list(item)
137-
if res is not None:
138-
output.extend(res)
137+
# If the item is already a 2-element list or tuple (key-value pair), convert it directly to a tuple
138+
if isinstance(item, list | tuple) and len(item) == 2:
139+
key, value = item
140+
output.append((str(key), coerce_value(value)))
141+
else:
142+
res = parse_str_or_dict_as_tuple_list(item)
143+
if res is not None:
144+
output.extend(res)
139145
return output
140146

141147
if isinstance(input, dict):
@@ -150,11 +156,16 @@ def parse_str_or_dict_as_tuple_list(input: Any | None) -> list[tuple[str, Any]]
150156
f"User Config: {input} - must be a valid JSON string"
151157
) from e
152158
else:
153-
return [
154-
(key.strip(), coerce_value(value.strip()))
155-
for item in input.split(",")
156-
for key, value in [item.split(":")]
157-
]
159+
result = []
160+
for item in input.split(","):
161+
parts = item.split(":", 1)
162+
if len(parts) != 2:
163+
raise ValueError(
164+
f"User Config: {input} - each item must be in 'key:value' format"
165+
)
166+
key, value = parts
167+
result.append((key.strip(), coerce_value(value.strip())))
168+
return result
158169

159170
raise ValueError(f"User Config: {input} - must be a valid string, list, or dict")
160171

tests/config/test_config_validators.py

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,6 @@ def test_invalid_json_string_raises_value_error(self, invalid_json):
156156
[
157157
["key1_no_colon"], # Missing colon
158158
["key1:value1", "key2_no_colon"], # One valid, one invalid
159-
["key1:value1:extra"], # Too many colons
160159
],
161160
)
162161
def test_invalid_list_format_raises_value_error(self, invalid_list):
@@ -190,15 +189,36 @@ def test_invalid_input_type_raises_value_error(self, invalid_input):
190189
with pytest.raises(ValueError, match="must be a valid string, list, or dict"):
191190
parse_str_or_dict_as_tuple_list(invalid_input)
192191

193-
def test_string_with_multiple_colons_raises_value_error(self):
194-
"""Test that strings with multiple colons raise ValueError."""
195-
with pytest.raises(ValueError):
196-
parse_str_or_dict_as_tuple_list("key1:value1:extra,key2:value2")
197-
198-
def test_list_with_multiple_colons_raises_value_error(self):
199-
"""Test that list items with multiple colons raise ValueError."""
200-
with pytest.raises(ValueError):
201-
parse_str_or_dict_as_tuple_list(["key1:value1:extra", "key2:value2"])
192+
@pytest.mark.parametrize(
193+
"input_value,expected",
194+
[
195+
# String with multiple colons
196+
(
197+
"key1:value1:extra,key2:value2",
198+
[("key1", "value1:extra"), ("key2", "value2")],
199+
),
200+
# List with multiple colons
201+
(
202+
["key1:value1:extra", "key2:value2"],
203+
[("key1", "value1:extra"), ("key2", "value2")],
204+
),
205+
# URL with port
206+
("url:http://example.com:8080", [("url", "http://example.com:8080")]),
207+
# Multiple entries with colons in values (timestamps, ports, etc)
208+
(
209+
"server:localhost:8080,time:12:30:45,status:active",
210+
[
211+
("server", "localhost:8080"),
212+
("time", "12:30:45"),
213+
("status", "active"),
214+
],
215+
),
216+
],
217+
)
218+
def test_values_can_contain_colons(self, input_value, expected):
219+
"""Test that values can contain colons (URLs, timestamps, etc)."""
220+
result = parse_str_or_dict_as_tuple_list(input_value)
221+
assert result == expected
202222

203223
def test_whitespace_handling_in_string_input(self):
204224
"""Test that whitespace is properly trimmed in string input."""
@@ -248,6 +268,31 @@ def test_none_input_returns_none(self):
248268
result = parse_str_or_dict_as_tuple_list(None)
249269
assert result is None
250270

271+
@pytest.mark.parametrize(
272+
"input_list,expected",
273+
[
274+
(
275+
[["temperature", 0.1], ["max_tokens", 150]],
276+
[("temperature", 0.1), ("max_tokens", 150)],
277+
),
278+
(
279+
[("temperature", 0.1), ("max_tokens", 150)],
280+
[("temperature", 0.1), ("max_tokens", 150)],
281+
),
282+
(
283+
[("key1", "value1"), ("key2", 123), ("key3", True)],
284+
[("key1", "value1"), ("key2", 123), ("key3", True)],
285+
),
286+
],
287+
)
288+
def test_list_of_key_value_pairs_input(self, input_list, expected):
289+
"""Test that a list of key-value pairs (lists/tuples) is converted correctly to a list of tuples."""
290+
result = parse_str_or_dict_as_tuple_list(input_list)
291+
assert result == expected
292+
# Make sure that the result is the same when parsed again.
293+
result2 = parse_str_or_dict_as_tuple_list(result)
294+
assert result2 == expected
295+
251296

252297
class TestParseStrOrListOfPositiveValues:
253298
"""Test suite for the parse_str_or_list_of_positive_values function."""

tests/config/test_user_config.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,95 @@
66
import pytest
77

88
from aiperf.common.config import (
9+
ConversationConfig,
910
EndpointConfig,
1011
EndpointDefaults,
1112
InputConfig,
1213
LoadGeneratorConfig,
1314
OutputConfig,
1415
TokenizerConfig,
16+
TurnConfig,
17+
TurnDelayConfig,
1518
UserConfig,
1619
)
1720
from aiperf.common.enums import EndpointType
21+
from aiperf.common.enums.dataset_enums import CustomDatasetType
1822
from aiperf.common.enums.timing_enums import TimingMode
1923

24+
"""
25+
Test suite for the UserConfig class.
26+
"""
27+
28+
29+
class TestUserConfig:
30+
"""Test suite for the UserConfig class."""
31+
32+
def test_user_config_serialization_to_json_string(self):
33+
"""Test the serialization and deserialization of a UserConfig object to and from a JSON string."""
34+
config = UserConfig(
35+
endpoint=EndpointConfig(
36+
model_names=["model1", "model2"],
37+
type=EndpointType.CHAT,
38+
custom_endpoint="custom_endpoint",
39+
streaming=True,
40+
url="http://custom-url",
41+
extra=[
42+
("key1", "value1"),
43+
("key2", "value2"),
44+
("key3", "value3"),
45+
],
46+
headers=[
47+
("Authorization", "Bearer token"),
48+
("Content-Type", "application/json"),
49+
],
50+
api_key="test_api_key",
51+
ssl_options={"verify": False},
52+
timeout=10,
53+
),
54+
conversation_config=ConversationConfig(
55+
num=10,
56+
turn=TurnConfig(
57+
mean=10,
58+
stddev=10,
59+
delay=TurnDelayConfig(
60+
mean=10,
61+
stddev=10,
62+
),
63+
),
64+
),
65+
input=InputConfig(
66+
custom_dataset_type=CustomDatasetType.SINGLE_TURN,
67+
),
68+
output=OutputConfig(
69+
artifact_directory="test_artifacts",
70+
),
71+
tokenizer=TokenizerConfig(
72+
model_name="test_tokenizer",
73+
),
74+
loadgen=LoadGeneratorConfig(
75+
concurrency=10,
76+
request_rate=10,
77+
),
78+
verbose=True,
79+
template_filename="test_template.yaml",
80+
cli_command="test_cli_command",
81+
)
82+
83+
# NOTE: Currently, some of our validation logic depends on whether a field was explicitly set by the user, so
84+
# exclude_unset must be used when serializing. Serializing with exclude_defaults should also work.
85+
assert (
86+
UserConfig.model_validate_json(
87+
config.model_dump_json(indent=4, exclude_unset=True)
88+
)
89+
== config
90+
)
91+
assert (
92+
UserConfig.model_validate_json(
93+
config.model_dump_json(indent=4, exclude_defaults=True)
94+
)
95+
== config
96+
)
97+
2098

2199
def test_user_config_serialization_to_file():
22100
"""

0 commit comments

Comments
 (0)