# Source: test_llm_config.py (144 lines, 125 loc, 4.67 KB)
# NOTE: GitHub web-UI navigation text and the rendered line-number gutter
# from the original page scrape were removed here; the file content follows.
"""
LLMConfig unit tests.
Covers:
- Default values
- from_dict parsing from config dictionary
- get_models
- Edge cases (missing keys, partial config)
All console output must be in English only (no emoji, no Chinese).
"""
import pytest
from video_transcript_api.llm.core.config import LLMConfig
class TestLLMConfigDefaults:
    """Verify the default values carried by a freshly constructed LLMConfig."""

    @staticmethod
    def _make_minimal() -> LLMConfig:
        # Build a config supplying only the four required fields so that
        # every optional field falls back to its declared default.
        return LLMConfig(
            api_key="k",
            base_url="u",
            calibrate_model="m",
            summary_model="s",
        )

    def test_required_fields(self):
        """Should require api_key, base_url, calibrate_model, summary_model."""
        cfg = LLMConfig(
            api_key="key",
            base_url="https://api.test.com",
            calibrate_model="model-a",
            summary_model="model-b",
        )
        assert cfg.api_key == "key"
        assert cfg.calibrate_model == "model-a"

    def test_default_retry_values(self):
        """Default retry config should be sensible."""
        cfg = self._make_minimal()
        assert cfg.max_retries == 3
        assert cfg.retry_delay == 5

    def test_default_segment_sizes(self):
        """Default segmentation config should be set."""
        cfg = self._make_minimal()
        assert cfg.segment_size == 2000
        assert cfg.max_segment_size == 3000
        assert cfg.enable_threshold == 5000

    def test_default_quality_weights(self):
        """Default quality score weights should sum to 1.0."""
        cfg = self._make_minimal()
        # Weights are floats, so compare the sum with a small tolerance.
        weight_total = sum(cfg.quality_score_weights.values())
        assert abs(weight_total - 1.0) < 0.01
class TestLLMConfigFromDict:
    """Exercise LLMConfig.from_dict against assorted config dictionaries."""

    @staticmethod
    def _wrap(**llm_fields) -> dict:
        # from_dict expects the settings nested under a top-level "llm" key;
        # this helper builds that envelope from keyword arguments.
        return {"llm": dict(llm_fields)}

    def test_basic_config(self):
        """Should parse basic config dict."""
        cfg = LLMConfig.from_dict(self._wrap(
            api_key="test-key",
            base_url="https://api.test.com",
            calibrate_model="deepseek-v4-flash",
            summary_model="deepseek-v4-pro",
        ))
        assert cfg.api_key == "test-key"
        assert cfg.calibrate_model == "deepseek-v4-flash"
        assert cfg.summary_model == "deepseek-v4-pro"

    def test_old_risk_fields_ignored(self):
        """Old risk model fields in config should be silently ignored."""
        cfg = LLMConfig.from_dict(self._wrap(
            api_key="k",
            base_url="u",
            calibrate_model="normal",
            summary_model="normal-summary",
            risk_calibrate_model="risk-model",
            enable_risk_model_selection=True,
        ))
        # The legacy keys must neither override the normal models nor
        # appear as attributes on the parsed config.
        assert cfg.calibrate_model == "normal"
        assert not hasattr(cfg, "risk_calibrate_model")

    def test_segmentation_config(self):
        """Should parse segmentation config."""
        cfg = LLMConfig.from_dict(self._wrap(
            api_key="k",
            base_url="u",
            calibrate_model="m",
            summary_model="s",
            segmentation={"segment_size": 1500, "max_segment_size": 2500},
        ))
        assert cfg.segment_size == 1500
        assert cfg.max_segment_size == 2500

    def test_missing_optional_fields_use_defaults(self):
        """Missing optional fields should use defaults."""
        cfg = LLMConfig.from_dict(self._wrap(
            api_key="k",
            base_url="u",
            calibrate_model="m",
            summary_model="s",
        ))
        assert cfg.max_retries == 3
        assert cfg.concurrent_workers == 10
class TestLLMConfigGetModels:
    """Behavior of LLMConfig.get_models."""

    def test_returns_configured_models(self):
        """get_models should return all configured models."""
        cfg = LLMConfig(
            api_key="k",
            base_url="u",
            calibrate_model="cal-model",
            summary_model="sum-model",
        )
        returned = cfg.get_models()
        assert returned["calibrate_model"] == "cal-model"
        assert returned["summary_model"] == "sum-model"

    def test_no_has_risk_in_result(self):
        """get_models result should not contain has_risk field."""
        cfg = LLMConfig(
            api_key="k",
            base_url="u",
            calibrate_model="m",
            summary_model="s",
        )
        # Legacy risk flag must be absent from the returned mapping.
        assert "has_risk" not in cfg.get_models()