Skip to content

Commit c05ca1a

Browse files
he-james and AssemblyAI authored
chore: sync sdk code with DeepLearning repo (#190)
Co-authored-by: AssemblyAI <engineering.sdk@assemblyai.com>
1 parent 4217759 commit c05ca1a

4 files changed

Lines changed: 192 additions & 1 deletion

File tree

assemblyai/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
1-
__version__ = "0.59.0"
1+
__version__ = "0.61.0"

assemblyai/types.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -645,6 +645,9 @@ class SpeakerIdentificationRequest(BaseModel):
645645
known_values: Optional[List[str]] = None
646646
"Known speaker values (required when speaker_type is 'role')"
647647

648+
speakers: Optional[List[Dict[str, Any]]] = None
649+
"Known speaker definitions with optional descriptions for improved accuracy"
650+
648651

649652
class TranslationRequest(BaseModel):
650653
"""
@@ -2380,13 +2383,22 @@ class TranscriptRequest(BaseTranscript):
23802383
"""
23812384

23822385

2386+
class TranscriptWarning(BaseModel):
    """A warning about the transcription."""

    message: str
    """The warning message."""
2391+
2392+
23832393
class TranscriptMetadata(BaseModel):
    """Metadata returned from the transcription API."""

    domain_used: Optional[str] = None
    """The domain that was actually used for the transcription."""

    warning: Optional[str] = None
    """An optional warning message, if applicable."""

    warnings: Optional[List[TranscriptWarning]] = None
    """A list of warnings about the transcription."""
23902402

23912403

23922404
class TranscriptResponse(BaseTranscript):
Lines changed: 117 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,117 @@
1+
import assemblyai as aai
2+
3+
4+
def test_speaker_identification_request_with_known_values_role():
    """A role-typed request built from `known_values` leaves `speakers` unset."""
    role_type = aai.SpeakerType.role
    request = aai.SpeakerIdentificationRequest(
        known_values=["Agent", "Customer"],
        speaker_type=role_type,
    )

    assert request.speaker_type == role_type
    assert request.known_values == ["Agent", "Customer"]
    assert request.speakers is None
12+
13+
14+
def test_speaker_identification_request_with_known_values_name():
    """A name-typed request built from `known_values` leaves `speakers` unset."""
    name_type = aai.SpeakerType.name
    request = aai.SpeakerIdentificationRequest(
        known_values=["Alice", "Bob"],
        speaker_type=name_type,
    )

    assert request.speaker_type == name_type
    assert request.known_values == ["Alice", "Bob"]
    assert request.speakers is None
22+
23+
24+
def test_speaker_identification_request_with_speakers_role():
    """A role-typed request keeps the `speakers` dicts it was given, in order."""
    operator = {
        "role": "Operador",
        "description": "Human agent who starts the call with a standard greeting",
    }
    ivr = {
        "role": "IVR",
        "description": "Automated system playing recorded messages",
    }
    customer = {
        "role": "Customer",
        "description": "The person who called the service center",
    }

    request = aai.SpeakerIdentificationRequest(
        speaker_type=aai.SpeakerType.role,
        speakers=[operator, ivr, customer],
    )

    assert request.speaker_type == aai.SpeakerType.role
    assert request.known_values is None
    assert len(request.speakers) == 3
    # Check the roles one by one so a failure points at the offending index.
    for index, expected_role in enumerate(("Operador", "IVR", "Customer")):
        assert request.speakers[index]["role"] == expected_role
    first_description = request.speakers[0]["description"]
    assert (
        first_description
        == "Human agent who starts the call with a standard greeting"
    )
52+
53+
54+
def test_speaker_identification_request_with_speakers_name():
    """A name-typed request keeps the `speakers` dicts it was given, in order."""
    host = {
        "name": "Michel Martin",
        "description": "Hosts the program and interviews the guests",
    }
    guest = {
        "name": "Peter DeCarlo",
        "description": "Answers questions from the interview",
    }

    request = aai.SpeakerIdentificationRequest(
        speaker_type=aai.SpeakerType.name,
        speakers=[host, guest],
    )

    assert request.speaker_type == aai.SpeakerType.name
    assert request.known_values is None
    assert len(request.speakers) == 2
    first, second = request.speakers[0], request.speakers[1]
    assert first["name"] == "Michel Martin"
    assert second["name"] == "Peter DeCarlo"
73+
74+
75+
def test_speaker_identification_request_with_speakers_custom_properties():
    """Extra keys in a `speakers` entry are kept on the request unchanged."""
    speaker_entry = {
        "name": "Michel Martin",
        "description": "Hosts the program",
        "company": "NPR",
        "title": "Host Morning Edition",
    }

    request = aai.SpeakerIdentificationRequest(
        speaker_type=aai.SpeakerType.name,
        speakers=[speaker_entry],
    )

    stored = request.speakers[0]
    assert stored["company"] == "NPR"
    assert stored["title"] == "Host Morning Edition"
89+
90+
91+
def test_speaker_identification_in_speech_understanding():
    """A speaker identification request is reachable through a speech understanding request."""
    speaker_definitions = [
        {
            "role": "Operador",
            "description": "Human agent who starts the call with a standard greeting",
        },
        {
            "role": "IVR",
            "description": "Automated system playing recorded messages",
        },
        {
            "role": "Customer",
            "description": "The person who called the service center",
        },
    ]

    # Build the nested request inside-out instead of in one nested expression.
    identification = aai.SpeakerIdentificationRequest(
        speaker_type=aai.SpeakerType.role,
        speakers=speaker_definitions,
    )
    feature_requests = aai.SpeechUnderstandingFeatureRequests(
        speaker_identification=identification
    )
    config_args = {
        "speech_understanding": aai.SpeechUnderstandingRequest(
            request=feature_requests
        ),
    }

    si = config_args["speech_understanding"].request.speaker_identification
    assert si.speaker_type == aai.SpeakerType.role
    assert len(si.speakers) == 3
    assert si.speakers[0]["role"] == "Operador"

tests/unit/test_transcript.py

Lines changed: 62 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -497,3 +497,65 @@ def test_speech_model_used_field_missing():
497497

498498
# The field should be None when not present
499499
assert transcript_response.speech_model_used is None
500+
501+
502+
def test_metadata_warnings_present():
    """
    Tests that a `warnings` list inside `metadata` is deserialized into
    objects exposing `.message`.
    """
    make_payload = factories.generate_dict_factory(
        factories.TranscriptCompletedResponseFactory
    )
    payload = make_payload()
    warning_entry = {
        "message": "Skipped medical-v1 correction because the language is not supported"
    }
    payload["metadata"] = {"domain_used": None, "warnings": [warning_entry]}

    response = aai.types.TranscriptResponse(**payload)

    assert response.metadata is not None
    assert response.metadata.warnings is not None
    assert len(response.metadata.warnings) == 1
    first_warning = response.metadata.warnings[0]
    assert (
        first_warning.message
        == "Skipped medical-v1 correction because the language is not supported"
    )
    assert response.metadata.domain_used is None
528+
529+
530+
def test_metadata_warnings_key_missing():
    """
    Tests that `metadata` lacking a `warnings` key still deserializes,
    with `warnings` left as None.
    """
    make_payload = factories.generate_dict_factory(
        factories.TranscriptCompletedResponseFactory
    )
    payload = make_payload()
    payload["metadata"] = {"domain_used": None}

    response = aai.types.TranscriptResponse(**payload)

    assert response.metadata is not None
    assert response.metadata.warnings is None
    assert response.metadata.domain_used is None
546+
547+
548+
def test_metadata_not_present():
    """
    Tests that a response payload with no `metadata` key deserializes,
    with `metadata` left as None.
    """
    make_payload = factories.generate_dict_factory(
        factories.TranscriptCompletedResponseFactory
    )
    payload = make_payload()

    # Drop the key if the factory produced one; do nothing otherwise.
    payload.pop("metadata", None)

    response = aai.types.TranscriptResponse(**payload)

    assert response.metadata is None

0 commit comments

Comments
 (0)