-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathtest_multi_user.py
More file actions
266 lines (218 loc) · 9.24 KB
/
test_multi_user.py
File metadata and controls
266 lines (218 loc) · 9.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
"""
多用户功能测试模块
测试用户管理、鉴权和审计功能。
"""
import os
import sys
import tempfile
import json
import unittest
from unittest.mock import patch, MagicMock
# Modules under test; the video_transcript_api package must be importable.
from video_transcript_api.utils.accounts import UserManager
from video_transcript_api.utils.logging import AuditLogger
class TestMultiUser(unittest.TestCase):
    """Tests for the multi-user features: user management, auth and audit logging."""

    def setUp(self):
        """Create per-test fixtures: a temp dir, a users config file and a fallback config."""
        # Each test gets its own directory, so the users file and the audit
        # database never leak from one test into the next.
        self.temp_dir = tempfile.mkdtemp()
        self.users_config_path = os.path.join(self.temp_dir, "users.json")
        self.audit_db_path = os.path.join(self.temp_dir, "audit.db")

        # Users keyed by API token: two enabled accounts and one disabled account.
        self.test_users_config = {
            "users": {
                "sk-test001-abcdefghij": {
                    "user_id": "test_user_001",
                    "name": "测试用户001",
                    "wechat_webhook": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test1",
                    "created_at": "2025-01-01T00:00:00Z",
                    "enabled": True,
                },
                "sk-test002-klmnopqrst": {
                    "user_id": "test_user_002",
                    "name": "测试用户002",
                    "wechat_webhook": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test2",
                    "created_at": "2025-01-01T00:00:00Z",
                    "enabled": True,
                },
                "sk-test003-uvwxyz1234": {
                    "user_id": "test_user_003",
                    "name": "已禁用用户",
                    "wechat_webhook": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test3",
                    "created_at": "2025-01-01T00:00:00Z",
                    "enabled": False,
                },
            }
        }

        # Persist the users config so UserManager can load it from disk.
        with open(self.users_config_path, "w", encoding="utf-8") as f:
            json.dump(self.test_users_config, f, indent=2, ensure_ascii=False)

        # Legacy single-token configuration handed to UserManager as the fallback.
        self.fallback_config = {
            "api": {
                "auth_token": "legacy-token-123456",
            },
            "wechat": {
                "webhook": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=legacy",
            },
        }

    def tearDown(self):
        """Remove the temp directory created by setUp."""
        import shutil

        if os.path.exists(self.temp_dir):
            shutil.rmtree(self.temp_dir)

    def _make_user_manager(self, users_config_path=None):
        """Build a UserManager wired to this test's fixtures.

        Args:
            users_config_path: Users config file to pass to UserManager.
                Defaults to the file written by setUp.

        Returns:
            A UserManager created with the chosen config path and the
            fallback config from setUp.
        """
        if users_config_path is None:
            users_config_path = self.users_config_path
        return UserManager(
            users_config_path=users_config_path,
            fallback_config=self.fallback_config,
        )

    def test_user_manager_initialization(self):
        """A manager built from the users file is in multi-user mode with three users."""
        user_manager = self._make_user_manager()

        self.assertTrue(user_manager.is_multi_user_mode())
        self.assertEqual(user_manager.get_user_count(), 3)

    def test_valid_token_authentication(self):
        """Tokens of enabled users validate and return that user's info."""
        user_manager = self._make_user_manager()

        # First enabled user: check every field the test fixture defines.
        user_info = user_manager.validate_token("sk-test001-abcdefghij")
        self.assertIsNotNone(user_info)
        self.assertEqual(user_info["user_id"], "test_user_001")
        self.assertEqual(user_info["name"], "测试用户001")
        self.assertTrue(user_info["enabled"])

        # Second enabled user resolves to its own user_id.
        user_info2 = user_manager.validate_token("sk-test002-klmnopqrst")
        self.assertIsNotNone(user_info2)
        self.assertEqual(user_info2["user_id"], "test_user_002")

    def test_disabled_user_authentication(self):
        """The token of a disabled user is rejected."""
        user_manager = self._make_user_manager()

        user_info = user_manager.validate_token("sk-test003-uvwxyz1234")
        self.assertIsNone(user_info)

    def test_invalid_token_authentication(self):
        """An unknown token is rejected."""
        user_manager = self._make_user_manager()

        user_info = user_manager.validate_token("invalid-token")
        self.assertIsNone(user_info)

    def test_fallback_token_authentication(self):
        """Without a users file the manager falls back to the single legacy token."""
        # Point the manager at a users file that does not exist.
        non_existent_path = os.path.join(self.temp_dir, "non_existent_users.json")
        user_manager = self._make_user_manager(non_existent_path)

        # Single-token fallback mode: not multi-user, exactly one user.
        self.assertFalse(user_manager.is_multi_user_mode())
        self.assertEqual(user_manager.get_user_count(), 1)

        # The legacy token from the fallback config authenticates as the legacy user.
        user_info = user_manager.validate_token("legacy-token-123456")
        self.assertIsNotNone(user_info)
        self.assertEqual(user_info["user_id"], "legacy_user")
        self.assertTrue(user_info["is_legacy"])

    def test_get_user_webhook(self):
        """A valid token yields the user's webhook; an invalid token yields None."""
        user_manager = self._make_user_manager()

        webhook = user_manager.get_user_webhook("sk-test001-abcdefghij")
        self.assertEqual(webhook, "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test1")

        webhook = user_manager.get_user_webhook("invalid-token")
        self.assertIsNone(webhook)

    def test_audit_logger_initialization(self):
        """Constructing an AuditLogger creates the database file."""
        AuditLogger(db_path=self.audit_db_path)

        self.assertTrue(os.path.exists(self.audit_db_path))

    def test_audit_log_api_call(self):
        """Logging a fully populated API call reports success."""
        audit_logger = AuditLogger(db_path=self.audit_db_path)

        success = audit_logger.log_api_call(
            api_key="sk-test001-abcdefghij",
            user_id="test_user_001",
            endpoint="/api/transcribe",
            video_url="https://example.com/video",
            processing_time_ms=1500,
            status_code=202,
            task_id="task_123",
            user_agent="TestAgent/1.0",
            remote_ip="192.168.1.100",
        )
        self.assertTrue(success)

    def test_audit_get_user_stats(self):
        """User stats reflect the number of calls logged for that user."""
        audit_logger = AuditLogger(db_path=self.audit_db_path)

        # Log five identical calls for the same user.
        for _ in range(5):
            audit_logger.log_api_call(
                api_key="sk-test001-abcdefghij",
                user_id="test_user_001",
                endpoint="/api/transcribe",
                status_code=200,
            )

        stats = audit_logger.get_user_stats("test_user_001", 30)
        self.assertEqual(stats["user_id"], "test_user_001")
        self.assertEqual(stats["total_calls"], 5)
        self.assertGreaterEqual(stats["active_days"], 1)

    def test_audit_get_recent_calls(self):
        """Recent calls can be fetched for one user or, with None, for all users."""
        audit_logger = AuditLogger(db_path=self.audit_db_path)

        # One call per user so the per-user filter has something to exclude.
        audit_logger.log_api_call(
            api_key="sk-test001-abcdefghij",
            user_id="test_user_001",
            endpoint="/api/transcribe",
            video_url="https://example.com/video1",
            status_code=200,
        )
        audit_logger.log_api_call(
            api_key="sk-test002-klmnopqrst",
            user_id="test_user_002",
            endpoint="/api/task/123",
            status_code=200,
        )

        # Filtered by user: only that user's single call comes back.
        calls = audit_logger.get_recent_calls("test_user_001", 10)
        self.assertEqual(len(calls), 1)
        self.assertEqual(calls[0]["user_id"], "test_user_001")
        self.assertEqual(calls[0]["endpoint"], "/api/transcribe")

        # No user filter: both calls come back.
        all_calls = audit_logger.get_recent_calls(None, 10)
        self.assertEqual(len(all_calls), 2)

    def test_api_key_masking(self):
        """API keys are masked for display; short and empty keys collapse to '****'."""
        user_manager = self._make_user_manager()

        # The key is 21 characters long; the expected value keeps the first 4
        # and last 4 characters and masks the middle with '*'.
        masked = user_manager._mask_api_key("sk-test001-abcdefghij")
        self.assertEqual(masked, "sk-t*************ghij")

        # A short key is expected to be fully masked.
        masked_short = user_manager._mask_api_key("short")
        self.assertEqual(masked_short, "****")

        # An empty key is expected to produce the same placeholder.
        masked_empty = user_manager._mask_api_key("")
        self.assertEqual(masked_empty, "****")
def run_tests():
    """Run the tests through unittest's command-line entry point.

    Delegates to unittest.main(), which by default exits the process with
    the test result once the run finishes, so this call normally does not
    return.
    """
    unittest.main()
if __name__ == "__main__":
    # Run the suite only when this file is executed as a script, not on import.
    run_tests()