Skip to content

Commit 40cd801

Browse files
committed
[text-to-audio-generator] Add new function to hub
1 parent 8dd33c3 commit 40cd801

File tree

7 files changed

+886
-0
lines changed

7 files changed

+886
-0
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Client: I love MLRun!
2+
Agent: Me too!

text_to_audio_generator/function.yaml

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
kind: job
2+
metadata:
3+
name: text-to-audio-generator
4+
tag: ''
5+
hash: f36d56d620c6a69f414c9cb90e42ec012847a607
6+
project: ''
7+
labels:
8+
author: yonatans
9+
categories:
10+
- data-preparation
11+
- machine-learning
12+
spec:
13+
command: ''
14+
args: []
15+
image: ''
16+
build:
17+
functionSourceCode: # Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import pathlib
import random
from typing import Dict, List, Optional, Tuple, Union

import bark
import numpy as np
import pandas as pd
import torch
import torchaudio
import tqdm

# Get the global logger:
_LOGGER = logging.getLogger()


def generate_multi_speakers_audio(
    data_path: str,
    output_directory: str,
    speakers: Union[List[str], Dict[str, int]],
    available_voices: List[str],
    use_gpu: bool = True,
    use_small_models: bool = False,
    offload_cpu: bool = False,
    sample_rate: int = 16000,
    file_format: str = "wav",
    verbose: bool = True,
    bits_per_sample: Optional[int] = None,
) -> Tuple[str, pd.DataFrame, dict]:
    """

    :param data_path:           Path to the text file or directory containing the text files to generate audio from.
    :param output_directory:    Path to the directory to save the generated audio files to.
    :param speakers:            List / Dict of speakers to generate audio for.
                                If a list is given, the speakers will be assigned to channels in the order given.
                                If dictionary, the keys will be the speakers and the values will be the channels.
    :param available_voices:    List of available voices to use for the generation.
                        See here for the available voices:
                        https://suno-ai.notion.site/8b8e8749ed514b0cbf3f699013548683?v=bc67cff786b04b50b3ceb756fd05f68c
    :param use_gpu:             Whether to use the GPU for the generation.
    :param use_small_models:    Whether to use the small models for the generation.
    :param offload_cpu:         TODO: What does this do?
    :param sample_rate:         The sampling rate of the generated audio.
    :param file_format:         The format of the generated audio files.
    :param verbose:             Whether to print the progress of the generation.
    :param bits_per_sample:     Changes the bit depth for the supported formats.
                                Supported only in "wav" or "flac" formats.

    :returns:                   A tuple of:
                                - The output directory path.
                                - The generated audio files dataframe.
                                - The errors dictionary.
    """

    global _LOGGER
    _LOGGER = _get_logger()
    # Get the input text files to turn to audio:
    data_path = pathlib.Path(data_path).absolute()
    text_files = _get_text_files(data_path=data_path)

    # Load the bark models according to the given configurations:
    bark.preload_models(
        text_use_gpu=use_gpu,
        text_use_small=use_small_models,
        coarse_use_gpu=use_gpu,
        coarse_use_small=use_small_models,
        fine_use_gpu=use_gpu,
        fine_use_small=use_small_models,
        codec_use_gpu=use_gpu,
        force_reload=offload_cpu,
    )

    # Check for per channel generation:
    if isinstance(speakers, dict):
        speaker_per_channel = True
        # Sort the given speakers by channels:
        speakers = {
            speaker: channel
            for speaker, channel in sorted(speakers.items(), key=lambda item: item[1])
        }
    else:
        speaker_per_channel = False

    # Prepare the resampling module:
    resampler = torchaudio.transforms.Resample(
        orig_freq=bark.SAMPLE_RATE, new_freq=sample_rate, dtype=torch.float32
    )

    # Prepare the gap between each speaker:
    gap_between_speakers = np.zeros(int(0.5 * bark.SAMPLE_RATE))

    # Prepare the successes dataframe and errors dictionary to be returned:
    successes = []
    errors = {}

    # Create the output directory:
    output_directory = pathlib.Path(output_directory)
    output_directory.mkdir(exist_ok=True)

    # Start generating audio:
    # Go over the audio files and transcribe:
    for text_file in tqdm.tqdm(
        text_files, desc="Generating", unit="file", disable=not verbose
    ):

        try:
            # Randomize voices for each speaker:
            chosen_voices = {}
            available_voices_copy = available_voices.copy()
            for speaker in speakers:
                voice = random.choice(available_voices_copy)
                chosen_voices[speaker] = voice
                available_voices_copy.remove(voice)
            # Read text:
            with open(text_file, "r") as fp:
                text = fp.read()
            # Prepare a holder for all the generated pieces (if per channel each speaker will have its own):
            audio_pieces = (
                {speaker: [] for speaker in speakers}
                if speaker_per_channel
                else {"all": []}
            )

            # Generate audio per line:
            for line in text.splitlines():
                # Validate line is in correct speaker format:

                if ": " not in line:
                    if verbose:
                        _LOGGER.warning(f"Skipping line: {line}")
                    continue
                # Split line to speaker and his words:
                current_speaker, sentences = line.split(": ", 1)
                # Validate speaker is known:
                if current_speaker not in speakers:
                    raise ValueError(
                        f"Unknown speaker: {current_speaker}. Given speakers are: {speakers}"
                    )
                for sentence in _split_line(line=sentences):
                    # Generate words audio:
                    audio = bark.generate_audio(
                        sentence,
                        history_prompt=chosen_voices[current_speaker],
                        silent=True,
                    )
                    if speaker_per_channel:
                        silence = np.zeros_like(audio)
                        for speaker in audio_pieces.keys():
                            if speaker == current_speaker:
                                audio_pieces[speaker] += [audio, gap_between_speakers]
                            else:
                                audio_pieces[speaker] += [silence, gap_between_speakers]
                    else:
                        audio_pieces["all"] += [audio, gap_between_speakers]
            # Construct a single audio array from all the pieces and channels:

            audio = np.vstack(
                [np.concatenate(audio_pieces[speaker]) for speaker in speakers]
            ).astype(dtype=np.float32)
            # Resample:
            audio = torch.from_numpy(audio)
            audio = resampler(audio)
            # Save to audio file:
            audio_file = output_directory / f"{text_file.stem}.{file_format}"

            torchaudio.save(
                uri=str(audio_file),
                src=audio,
                sample_rate=sample_rate,
                format=file_format,
                bits_per_sample=bits_per_sample,
            )

            # Collect to the successes:
            successes.append([text_file.name, audio_file.name])
        except Exception as exception:
            # Note the exception as error in the dictionary:
            if verbose:
                _LOGGER.warning(f"Error in file: '{text_file.name}'")
            print(exception)
            errors[text_file.name] = str(exception)

    # Construct the translations dataframe:
    successes = pd.DataFrame(
        successes,
        columns=["text_file", "audio_file"],
    )

    # Print the head of the produced dataframe and return:
    if verbose:
        _LOGGER.info(
            f"Done ({successes.shape[0]}/{len(text_files)})\n"
            f"Translations summary:\n"
            f"{successes.head()}"
        )
    return str(output_directory), successes, errors


def _get_text_files(
    data_path: pathlib.Path,
) -> List[pathlib.Path]:
    # Check if the path is of a directory or a file:
    if data_path.is_dir():
        # Get all files inside the directory:
        text_files = list(data_path.glob("*.*"))
    elif data_path.is_file():
        text_files = [data_path]
    else:
        raise ValueError(
            f"Unrecognized data path. The parameter `data_path` must be either a directory path or a file path. "
            f"Given: {str(data_path)} "
        )

    return text_files


def _split_line(line: str, max_length: int = 250) -> List[str]:
    if len(line) < max_length:
        return [line]

    sentences = [
        f"{sentence.strip()}." for sentence in line.split(".") if sentence.strip()
    ]

    splits = []
    current_length = len(sentences[0])
    split = sentences[0]
    for sentence in sentences[1:]:
        if current_length + len(sentence) > max_length:
            splits.append(split)
            split = sentence
            current_length = len(sentence)
        else:
            current_length += len(sentence)
            split += " " + sentence
    if split:
        splits.append(split)

    return splits


def _get_logger():
    global _LOGGER
    try:
        import mlrun
        # Check if MLRun is available:
        context = mlrun.get_or_create_ctx(name="mlrun")
        return context.logger
    except ModuleNotFoundError:
        return _LOGGER

18+
base_image: mlrun/mlrun
19+
commands: []
20+
code_origin: ''
21+
origin_filename: ''
22+
requirements:
23+
- bark
24+
- torchaudio
25+
entry_points:
26+
generate_multi_speakers_audio:
27+
name: generate_multi_speakers_audio
28+
doc: ''
29+
parameters:
30+
- name: data_path
31+
type: str
32+
doc: Path to the text file or directory containing the text files to generate
33+
audio from.
34+
- name: output_directory
35+
type: str
36+
doc: Path to the directory to save the generated audio files to.
37+
- name: speakers
38+
type: Union[List[str], Dict[str, int]]
39+
doc: List / Dict of speakers to generate audio for. If a list is given, the
40+
speakers will be assigned to channels in the order given. If dictionary,
41+
the keys will be the speakers and the values will be the channels.
42+
- name: available_voices
43+
type: List[str]
44+
doc: 'List of available voices to use for the generation. See here for the
45+
available voices: https://suno-ai.notion.site/8b8e8749ed514b0cbf3f699013548683?v=bc67cff786b04b50b3ceb756fd05f68c'
46+
- name: use_gpu
47+
type: bool
48+
doc: Whether to use the GPU for the generation.
49+
default: true
50+
- name: use_small_models
51+
type: bool
52+
doc: Whether to use the small models for the generation.
53+
default: false
54+
- name: offload_cpu
55+
type: bool
56+
doc: 'TODO: What does this do?'
57+
default: false
58+
- name: sample_rate
59+
type: int
60+
doc: The sampling rate of the generated audio.
61+
default: 16000
62+
- name: file_format
63+
type: str
64+
doc: The format of the generated audio files.
65+
default: wav
66+
- name: verbose
67+
type: bool
68+
doc: Whether to print the progress of the generation.
69+
default: true
70+
- name: bits_per_sample
71+
type: Optional[int]
72+
doc: Changes the bit depth for the supported formats. Supported only in "wav"
73+
or "flac" formats.
74+
default: null
75+
outputs:
76+
- doc: 'A tuple of: - The output directory path. - The generated audio files
77+
dataframe. - The errors dictionary.'
78+
default: ''
79+
lineno: 30
80+
description: Generate audio file from text using different speakers
81+
default_handler: generate_multi_speakers_audio
82+
disable_auto_mount: false
83+
clone_target_dir: ''
84+
env: []
85+
priority_class_name: ''
86+
preemption_mode: prevent
87+
affinity: null
88+
tolerations: null
89+
security_context: {}
90+
verbose: false

text_to_audio_generator/item.yaml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
apiVersion: v1
2+
categories:
3+
- data-preparation
4+
- machine-learning
5+
description: Generate audio file from text using different speakers
6+
doc: ''
7+
example: text_to_audio_generator.ipynb
8+
generationDate: 2023-12-03:15-30
9+
hidden: false
10+
icon: ''
11+
labels:
12+
author: yonatans
13+
maintainers: []
14+
marketplaceType: ''
15+
mlrunVersion: 1.5.2
16+
name: text_to_audio_generator
17+
platformVersion: 3.5.3
18+
spec:
19+
filename: text_to_audio_generator.py
20+
handler: generate_multi_speakers_audio
21+
image: mlrun/mlrun
22+
kind: job
23+
requirements:
24+
- bark
25+
- torchaudio
26+
url: ''
27+
version: 1.0.0
28+
test_valid: True
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
bark
2+
torchaudio>=2.1.0
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright 2023 Iguazio
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import mlrun
16+
import tempfile
17+
import pytest
18+
19+
20+
@pytest.mark.parametrize("file_format,bits_per_sample", [("wav", 8), ("mp3", None)])
21+
def test_generate_multi_speakers_audio(file_format, bits_per_sample):
22+
text_to_audio_generator_function = mlrun.import_function("function.yaml")
23+
with tempfile.TemporaryDirectory() as test_directory:
24+
function_run = text_to_audio_generator_function.run(
25+
handler="generate_multi_speakers_audio",
26+
inputs={"data_path": "data/test_data.txt"},
27+
params={
28+
"output_directory": test_directory,
29+
"speakers": {"Agent": 0, "Client": 1},
30+
"available_voices": [
31+
"v2/en_speaker_0",
32+
"v2/en_speaker_1",
33+
],
34+
"use_small_models": True,
35+
"use_gpu": False,
36+
"offload_cpu": True,
37+
"file_format": file_format,
38+
"bits_per_sample": bits_per_sample,
39+
},
40+
local=True,
41+
returns=[
42+
"audio_files: path",
43+
"audio_files_dataframe: dataset",
44+
"text_to_speech_errors: file",
45+
],
46+
artifact_path=test_directory,
47+
)
48+
assert function_run.error == "Run state (completed) is not in error state"
49+
for key in ["audio_files", "audio_files_dataframe", "text_to_speech_errors"]:
50+
assert key in function_run.outputs and function_run.outputs[key] is not None

0 commit comments

Comments
 (0)