-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathtest_key_info.py
More file actions
136 lines (119 loc) · 4.43 KB
/
test_key_info.py
File metadata and controls
136 lines (119 loc) · 4.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""
KeyInfo unit tests.
Covers:
- KeyInfo dataclass creation and serialization
- from_dict with missing fields
- format_for_prompt output format
- Empty KeyInfo formatting
All console output must be in English only (no emoji, no Chinese).
"""
import pytest
from video_transcript_api.llm.core.key_info_extractor import KeyInfo
class TestKeyInfo:
    """Unit tests for constructing and (de)serializing ``KeyInfo``.

    The tests build ``KeyInfo`` either directly or through ``from_dict`` and
    verify every one of the seven list fields, so a regression in any single
    field is caught rather than only the ones spot-checked.
    """

    # Field names exercised by the tests below; ``test_to_dict`` expects
    # ``to_dict`` to return exactly these seven keys.
    FIELD_NAMES = (
        "names",
        "places",
        "technical_terms",
        "brands",
        "abbreviations",
        "foreign_terms",
        "other_entities",
    )

    def test_create_with_all_fields(self):
        """Should create KeyInfo with all fields."""
        expected = {
            "names": ["Zhang San"],
            "places": ["Beijing"],
            "technical_terms": ["GPT-4"],
            "brands": ["OpenAI"],
            "abbreviations": ["AI"],
            "foreign_terms": ["machine learning"],
            "other_entities": ["2024"],
        }
        ki = KeyInfo(**expected)
        for field_name, value in expected.items():
            assert getattr(ki, field_name) == value, field_name

    def test_to_dict(self):
        """to_dict should return all fields."""
        ki = KeyInfo(
            names=["A"],
            places=["B"],
            technical_terms=["C"],
            brands=["D"],
            abbreviations=["E"],
            foreign_terms=["F"],
            other_entities=["G"],
        )
        d = ki.to_dict()
        assert len(d) == 7
        # Compare the whole mapping so that a missing, renamed or swapped key
        # fails the test, not just "names" and "other_entities".
        assert d == {
            "names": ["A"],
            "places": ["B"],
            "technical_terms": ["C"],
            "brands": ["D"],
            "abbreviations": ["E"],
            "foreign_terms": ["F"],
            "other_entities": ["G"],
        }

    def test_from_dict_complete(self):
        """from_dict with all fields should work."""
        data = {
            "names": ["X"],
            "places": ["Y"],
            "technical_terms": ["Z"],
            "brands": ["W"],
            "abbreviations": ["V"],
            "foreign_terms": ["U"],
            "other_entities": ["T"],
        }
        ki = KeyInfo.from_dict(data)
        for field_name, value in data.items():
            assert getattr(ki, field_name) == value, field_name

    def test_from_dict_missing_fields(self):
        """from_dict with missing fields should use empty lists."""
        ki = KeyInfo.from_dict({"names": ["Alice"]})
        assert ki.names == ["Alice"]
        # Every field absent from the input must default to an empty list.
        for field_name in self.FIELD_NAMES:
            if field_name == "names":
                continue
            assert getattr(ki, field_name) == [], field_name

    def test_from_dict_empty(self):
        """from_dict with empty dict should have all empty lists."""
        ki = KeyInfo.from_dict({})
        for field_name in self.FIELD_NAMES:
            assert getattr(ki, field_name) == [], field_name

    def test_roundtrip(self):
        """to_dict -> from_dict should preserve data."""
        original = KeyInfo(
            names=["A", "B"],
            places=["C"],
            technical_terms=[],
            brands=["D"],
            abbreviations=[],
            foreign_terms=["E"],
            other_entities=[],
        )
        restored = KeyInfo.from_dict(original.to_dict())
        # Check populated and empty fields alike: an empty list must come back
        # as an empty list, not be dropped or replaced.
        for field_name in self.FIELD_NAMES:
            assert getattr(restored, field_name) == getattr(original, field_name), field_name
class TestKeyInfoFormatForPrompt:
    """Tests for the text produced by ``KeyInfo.format_for_prompt``.

    The assertions are substring checks on the returned text; they do not pin
    the exact layout of the output.
    """

    def test_all_fields_populated(self):
        """All populated fields should appear in output."""
        ki = KeyInfo(
            names=["Alice", "Bob"],
            places=["Tokyo"],
            technical_terms=["PyTorch"],
            brands=["Apple"],
            abbreviations=["AI"],
            foreign_terms=["deep learning"],
            other_entities=["2024"],
        )
        text = ki.format_for_prompt()
        # Check a value from every field, including abbreviations, foreign
        # terms and other entities, so no populated field can be dropped
        # silently.
        expected_values = (
            "Alice",
            "Bob",
            "Tokyo",
            "PyTorch",
            "Apple",
            "AI",
            "deep learning",
            "2024",
        )
        for value in expected_values:
            assert value in text, value

    def test_partial_fields(self):
        """Only populated fields should appear."""
        ki = KeyInfo(
            names=["Alice"],
            places=[],
            technical_terms=["GPT"],
            brands=[],
            abbreviations=[],
            foreign_terms=[],
            other_entities=[],
        )
        text = ki.format_for_prompt()
        assert "Alice" in text
        assert "GPT" in text
        # NOTE(review): the labels asserted in
        # test_format_contains_correct_labels are Chinese, so this English
        # "brand" check may never be able to fail. Confirm the label that
        # format_for_prompt uses for brands and assert on that instead.
        assert "brand" not in text.lower()  # Empty field not shown

    def test_empty_returns_default(self):
        """All empty fields should return default message."""
        ki = KeyInfo(
            names=[],
            places=[],
            technical_terms=[],
            brands=[],
            abbreviations=[],
            foreign_terms=[],
            other_entities=[],
        )
        text = ki.format_for_prompt()
        # ``None != ""`` is true, so also require a real string; otherwise a
        # missing return value would pass this test.
        assert isinstance(text, str)
        assert text != ""  # Should return default text

    def test_format_contains_correct_labels(self):
        """Output should contain Chinese labels."""
        ki = KeyInfo(
            names=["X"],
            places=["Y"],
            technical_terms=[],
            brands=[],
            abbreviations=[],
            foreign_terms=[],
            other_entities=[],
        )
        text = ki.format_for_prompt()
        assert "人名" in text
        assert "地名" in text