Skip to content

Commit ae0be80

Browse files
authored
Merge branch 'main' into mitali/add-avzone-de
2 parents f0ef207 + 9806cf1 commit ae0be80

File tree

5 files changed

+184
-33
lines changed

5 files changed

+184
-33
lines changed

src/together/abstract/api_requestor.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -619,14 +619,29 @@ def _interpret_response(
619619
) -> Tuple[TogetherResponse | Iterator[TogetherResponse], bool]:
620620
"""Returns the response(s) and a bool indicating whether it is a stream."""
621621
content_type = result.headers.get("Content-Type", "")
622+
622623
if stream and "text/event-stream" in content_type:
624+
# SSE format streaming
623625
return (
624626
self._interpret_response_line(
625627
line, result.status_code, result.headers, stream=True
626628
)
627629
for line in parse_stream(result.iter_lines())
628630
), True
631+
elif stream and content_type in [
632+
"audio/wav",
633+
"audio/mpeg",
634+
"application/octet-stream",
635+
]:
636+
# Binary audio streaming - return chunks as binary data
637+
def binary_stream_generator() -> Iterator[TogetherResponse]:
638+
for chunk in result.iter_content(chunk_size=8192):
639+
if chunk: # Skip empty chunks
640+
yield TogetherResponse(chunk, dict(result.headers))
641+
642+
return binary_stream_generator(), True
629643
else:
644+
# Non-streaming response
630645
if content_type in ["application/octet-stream", "audio/wav", "audio/mpeg"]:
631646
content = result.content
632647
else:
@@ -648,23 +663,49 @@ async def _interpret_async_response(
648663
| tuple[TogetherResponse, bool]
649664
):
650665
"""Returns the response(s) and a bool indicating whether it is a stream."""
651-
if stream and "text/event-stream" in result.headers.get("Content-Type", ""):
666+
content_type = result.headers.get("Content-Type", "")
667+
668+
if stream and "text/event-stream" in content_type:
669+
# SSE format streaming
652670
return (
653671
self._interpret_response_line(
654672
line, result.status, result.headers, stream=True
655673
)
656674
async for line in parse_stream_async(result.content)
657675
), True
676+
elif stream and content_type in [
677+
"audio/wav",
678+
"audio/mpeg",
679+
"application/octet-stream",
680+
]:
681+
# Binary audio streaming - return chunks as binary data
682+
async def binary_stream_generator() -> (
683+
AsyncGenerator[TogetherResponse, None]
684+
):
685+
async for chunk in result.content.iter_chunked(8192):
686+
if chunk: # Skip empty chunks
687+
yield TogetherResponse(chunk, dict(result.headers))
688+
689+
return binary_stream_generator(), True
658690
else:
691+
# Non-streaming response
659692
try:
660-
await result.read()
693+
content = await result.read()
661694
except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
662695
raise error.Timeout("Request timed out") from e
663696
except aiohttp.ClientError as e:
664697
utils.log_warn(e, body=result.content)
698+
699+
if content_type in ["application/octet-stream", "audio/wav", "audio/mpeg"]:
700+
# Binary content - keep as bytes
701+
response_content: str | bytes = content
702+
else:
703+
# Text content - decode to string
704+
response_content = content.decode("utf-8")
705+
665706
return (
666707
self._interpret_response_line(
667-
(await result.read()).decode("utf-8"),
708+
response_content,
668709
result.status,
669710
result.headers,
670711
stream=False,

src/together/resources/audio/speech.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def create(
3030
response_format: str = "wav",
3131
language: str = "en",
3232
response_encoding: str = "pcm_f32le",
33-
sample_rate: int = 44100,
33+
sample_rate: int | None = None,
3434
stream: bool = False,
3535
**kwargs: Any,
3636
) -> AudioSpeechStreamResponse:
@@ -49,14 +49,20 @@ def create(
4949
response_encoding (str, optional): Audio encoding of response.
5050
Defaults to "pcm_f32le".
5151
sample_rate (int, optional): Sampling rate to use for the output audio.
52-
Defaults to 44100.
52+
Defaults to None. If not provided, the default sampling rate for the model will be used.
5353
stream (bool, optional): If true, output is streamed for several characters at a time.
5454
Defaults to False.
5555
5656
Returns:
5757
Union[bytes, Iterator[AudioSpeechStreamChunk]]: The generated audio as bytes or an iterator over audio stream chunks.
5858
"""
5959

60+
if sample_rate is None:
61+
if "cartesia" in model:
62+
sample_rate = 44100
63+
else:
64+
sample_rate = 24000
65+
6066
requestor = api_requestor.APIRequestor(
6167
client=self._client,
6268
)

src/together/resources/audio/transcriptions.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def create(
3030
timestamp_granularities: Optional[
3131
Union[str, AudioTimestampGranularities]
3232
] = None,
33+
diarize: bool = False,
3334
**kwargs: Any,
3435
) -> Union[AudioTranscriptionResponse, AudioTranscriptionVerboseResponse]:
3536
"""
@@ -52,7 +53,11 @@ def create(
5253
timestamp_granularities: The timestamp granularities to populate for this
5354
transcription. response_format must be set to verbose_json to use timestamp
5455
granularities. Either or both of these options are supported: word, or segment.
55-
56+
diarize: Whether to enable speaker diarization. When enabled, you will get the speaker id for each word in the transcription.
57+
In the response, in the words array, you will get the speaker id for each word.
58+
In addition, we return the speaker_segments array, which contains the speaker id, the start and end time, and all the words for each speaker segment.
59+
You can use the speaker_id to group the words by speaker.
60+
You can use the speaker_segments to get the start and end time of each speaker segment.
5661
Returns:
5762
The transcribed text in the requested format.
5863
"""
@@ -103,6 +108,9 @@ def create(
103108
else timestamp_granularities
104109
)
105110

111+
if diarize:
112+
params_data["diarize"] = diarize
113+
106114
# Add any additional kwargs
107115
# Convert boolean values to lowercase strings for proper form encoding
108116
for key, value in kwargs.items():
@@ -135,6 +143,7 @@ def create(
135143
if (
136144
response_format == "verbose_json"
137145
or response_format == AudioTranscriptionResponseFormat.VERBOSE_JSON
146+
or diarize
138147
):
139148
# Create response with model validation that preserves extra fields
140149
return AudioTranscriptionVerboseResponse.model_validate(response.data)
@@ -158,6 +167,7 @@ async def create(
158167
timestamp_granularities: Optional[
159168
Union[str, AudioTimestampGranularities]
160169
] = None,
170+
diarize: bool = False,
161171
**kwargs: Any,
162172
) -> Union[AudioTranscriptionResponse, AudioTranscriptionVerboseResponse]:
163173
"""
@@ -180,7 +190,11 @@ async def create(
180190
timestamp_granularities: The timestamp granularities to populate for this
181191
transcription. response_format must be set to verbose_json to use timestamp
182192
granularities. Either or both of these options are supported: word, or segment.
183-
193+
diarize: Whether to enable speaker diarization. When enabled, you will get the speaker id for each word in the transcription.
194+
In the response, in the words array, you will get the speaker id for each word.
195+
In addition, we return the speaker_segments array, which contains the speaker id, the start and end time, and all the words for each speaker segment.
196+
You can use the speaker_id to group the words by speaker.
197+
You can use the speaker_segments to get the start and end time of each speaker segment.
184198
Returns:
185199
The transcribed text in the requested format.
186200
"""
@@ -239,6 +253,9 @@ async def create(
239253
)
240254
)
241255

256+
if diarize:
257+
params_data["diarize"] = diarize
258+
242259
# Add any additional kwargs
243260
# Convert boolean values to lowercase strings for proper form encoding
244261
for key, value in kwargs.items():
@@ -271,6 +288,7 @@ async def create(
271288
if (
272289
response_format == "verbose_json"
273290
or response_format == AudioTranscriptionResponseFormat.VERBOSE_JSON
291+
or diarize
274292
):
275293
# Create response with model validation that preserves extra fields
276294
return AudioTranscriptionVerboseResponse.model_validate(response.data)

src/together/types/audio_speech.py

Lines changed: 112 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -82,27 +82,126 @@ class AudioSpeechStreamResponse(BaseModel):
8282

8383
model_config = ConfigDict(arbitrary_types_allowed=True)
8484

85-
def stream_to_file(self, file_path: str) -> None:
85+
def stream_to_file(
86+
self, file_path: str, response_format: AudioResponseFormat | str | None = None
87+
) -> None:
88+
"""
89+
Save the audio response to a file.
90+
91+
For non-streaming responses, writes the complete file as received.
92+
For streaming responses, collects binary chunks and constructs a valid
93+
file format based on the response_format parameter.
94+
95+
Args:
96+
file_path: Path where the audio file should be saved.
97+
response_format: Format of the audio (wav, mp3, or raw). If not provided,
98+
will attempt to infer from file extension or default to wav.
99+
"""
100+
# Determine response format
101+
if response_format is None:
102+
# Infer from file extension
103+
ext = file_path.lower().split(".")[-1] if "." in file_path else ""
104+
if ext in ["wav"]:
105+
response_format = AudioResponseFormat.WAV
106+
elif ext in ["mp3", "mpeg"]:
107+
response_format = AudioResponseFormat.MP3
108+
elif ext in ["raw", "pcm"]:
109+
response_format = AudioResponseFormat.RAW
110+
else:
111+
# Default to WAV if unknown
112+
response_format = AudioResponseFormat.WAV
113+
114+
if isinstance(response_format, str):
115+
response_format = AudioResponseFormat(response_format)
116+
86117
if isinstance(self.response, TogetherResponse):
87-
# save response to file
118+
# Non-streaming: save complete file
88119
with open(file_path, "wb") as f:
89120
f.write(self.response.data)
90121

91122
elif isinstance(self.response, Iterator):
123+
# Streaming: collect binary chunks
124+
audio_chunks = []
125+
for chunk in self.response:
126+
if isinstance(chunk.data, bytes):
127+
audio_chunks.append(chunk.data)
128+
elif isinstance(chunk.data, dict):
129+
# SSE format with JSON/base64
130+
try:
131+
stream_event = AudioSpeechStreamEventResponse(
132+
response={"data": chunk.data}
133+
)
134+
if isinstance(stream_event.response, StreamSentinel):
135+
break
136+
audio_chunks.append(
137+
base64.b64decode(stream_event.response.data.b64)
138+
)
139+
except Exception:
140+
continue # Skip malformed chunks
141+
142+
if not audio_chunks:
143+
raise ValueError("No audio data received in streaming response")
144+
145+
# Concatenate all chunks
146+
audio_data = b"".join(audio_chunks)
147+
92148
with open(file_path, "wb") as f:
93-
for chunk in self.response:
94-
# Try to parse as stream chunk
95-
stream_event_response = AudioSpeechStreamEventResponse(
96-
response={"data": chunk.data}
149+
if response_format == AudioResponseFormat.WAV:
150+
if audio_data.startswith(b"RIFF"):
151+
# Already a valid WAV file
152+
f.write(audio_data)
153+
else:
154+
# Raw PCM - add WAV header
155+
self._write_wav_header(f, audio_data)
156+
elif response_format == AudioResponseFormat.MP3:
157+
# MP3 format: Check if data is actually MP3 or raw PCM
158+
# MP3 files start with ID3 tag or sync word (0xFF 0xFB/0xFA/0xF3/0xF2)
159+
is_mp3 = audio_data.startswith(b"ID3") or (
160+
len(audio_data) > 0
161+
and audio_data[0:1] == b"\xff"
162+
and len(audio_data) > 1
163+
and audio_data[1] & 0xE0 == 0xE0
97164
)
98165

99-
if isinstance(stream_event_response.response, StreamSentinel):
100-
break
101-
102-
# decode base64
103-
audio = base64.b64decode(stream_event_response.response.data.b64)
104-
105-
f.write(audio)
166+
if is_mp3:
167+
f.write(audio_data)
168+
else:
169+
raise ValueError("Invalid MP3 data received.")
170+
else:
171+
# RAW format: write PCM data as-is
172+
f.write(audio_data)
173+
174+
@staticmethod
175+
def _write_wav_header(file_handle: BinaryIO, audio_data: bytes) -> None:
176+
"""
177+
Write WAV file header for raw PCM audio data.
178+
179+
Uses default TTS parameters: 16-bit PCM, mono, 24000 Hz sample rate.
180+
"""
181+
import struct
182+
183+
sample_rate = 24000
184+
num_channels = 1
185+
bits_per_sample = 16
186+
byte_rate = sample_rate * num_channels * bits_per_sample // 8
187+
block_align = num_channels * bits_per_sample // 8
188+
data_size = len(audio_data)
189+
190+
# Write WAV header
191+
file_handle.write(b"RIFF")
192+
file_handle.write(struct.pack("<I", 36 + data_size)) # File size - 8
193+
file_handle.write(b"WAVE")
194+
file_handle.write(b"fmt ")
195+
file_handle.write(struct.pack("<I", 16)) # fmt chunk size
196+
file_handle.write(struct.pack("<H", 1)) # Audio format (1 = PCM)
197+
file_handle.write(struct.pack("<H", num_channels))
198+
file_handle.write(struct.pack("<I", sample_rate))
199+
file_handle.write(struct.pack("<I", byte_rate))
200+
file_handle.write(struct.pack("<H", block_align))
201+
file_handle.write(struct.pack("<H", bits_per_sample))
202+
file_handle.write(b"data")
203+
file_handle.write(struct.pack("<I", data_size))
204+
file_handle.write(audio_data)
106205

107206

108207
class AudioTranscriptionResponseFormat(str, Enum):

tests/integration/resources/test_transcriptions.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,6 @@ def validate_diarization_response(response_dict):
3636
assert "end" in word
3737
assert "speaker_id" in word
3838

39-
# Validate top-level words field
40-
assert "words" in response_dict
41-
assert isinstance(response_dict["words"], list)
42-
assert len(response_dict["words"]) > 0
43-
44-
# Validate each word in top-level words
45-
for word in response_dict["words"]:
46-
assert "id" in word
47-
assert "word" in word
48-
assert "start" in word
49-
assert "end" in word
50-
assert "speaker_id" in word
51-
5239

5340
class TestTogetherTranscriptions:
5441
@pytest.fixture

0 commit comments

Comments
 (0)