Skip to content

Commit

Permalink
feat(tts): Support AzureTTS. (#474)
Browse files Browse the repository at this point in the history
* feat(tts): Support AzureTTS.

1. Add `AzureTTS`
2. Update `README.md`
3. Add `azure_tts_speech_key` and `azure_tts_service_region` to `Config`
4. Update example configure file and parser
5. Add `azure-cognitiveservices-speech` to dependencies.

* style(tts/azure): Use logger and sort imports.

* fix: better requirements and close stream when use tts(not mi)

Signed-off-by: yihong0618 <[email protected]>

---------

Signed-off-by: yihong0618 <[email protected]>
Co-authored-by: yihong0618 <[email protected]>
  • Loading branch information
laipz8200 and yihong0618 authored Apr 12, 2024
1 parent 0e73138 commit d0dfa75
Show file tree
Hide file tree
Showing 11 changed files with 669 additions and 226 deletions.
65 changes: 34 additions & 31 deletions README.md

Large diffs are not rendered by default.

638 changes: 466 additions & 172 deletions pdm.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies = [
"aiohttp",
"rich",
"zhipuai==2.0.1",
"httpx==0.24.1",
"bardapi",
"edge-tts>=6.1.3",
"EdgeGPT==0.1.26",
Expand All @@ -26,6 +27,8 @@ dependencies = [
"google-generativeai",
"numexpr>=2.8.6",
"dashscope==1.10.0",
"httpcore==0.15.0",
"azure-cognitiveservices-speech>=1.37.0",
]
license = {text = "MIT"}
dynamic = ["version"]
Expand Down
51 changes: 33 additions & 18 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,72 +1,85 @@
# This file is @generated by PDM.
# Please do not edit it manually.

aiohttp==3.9.2
aiohttp==3.9.4
aiosignal==1.3.1
annotated-types==0.6.0
anyio==4.2.0
anyio==3.7.1
async-timeout==4.0.3; python_version < "3.11"
attrs==23.2.0
bardapi==0.1.39
azure-cognitiveservices-speech==1.37.0
bardapi==1.0.0
beautifulsoup4==4.12.3
bingimagecreator==0.5.0
browser-cookie3==0.19.1
cachetools==5.3.2
certifi==2023.7.22
certifi==2024.2.2
charset-normalizer==3.3.2
colorama==0.4.6
dashscope==1.10.0
dataclasses-json==0.6.3
distro==1.9.0
edge-tts==6.1.9
edge-tts==6.1.10
edgegpt==0.1.26
exceptiongroup==1.2.0; python_version < "3.11"
frozenlist==1.4.1
google-ai-generativelanguage==0.4.0
google-ai-generativelanguage==0.6.1
google-api-core==2.15.0
google-api-python-client==2.125.0
google-auth==2.26.1
google-generativeai==0.3.2
google-auth-httplib2==0.2.0
google-generativeai==0.5.0
google-search-results==2.4.2
googleapis-common-protos==1.62.0
greenlet==3.0.3; platform_machine == "win32" or platform_machine == "WIN32" or platform_machine == "AMD64" or platform_machine == "amd64" or platform_machine == "x86_64" or platform_machine == "ppc64le" or platform_machine == "aarch64"
grpcio==1.60.0
grpcio-status==1.60.0
h11==0.14.0
h11==0.12.0
h2==4.1.0
hpack==4.0.0
httpcore==1.0.2
httpx==0.26.0
httpcore==0.15.0
httplib2==0.22.0
httpx==0.24.1
hyperframe==6.0.1
idna==3.6
jeepney==0.8.0; "bsd" in sys_platform and python_version >= "3.7" or sys_platform == "linux" and python_version >= "3.7"
jsonpatch==1.33
jsonpointer==2.4
langchain==0.1.4
langchain-community==0.0.16
langchain-core==0.1.16
langsmith==0.0.84
langchain==0.1.16
langchain-community==0.0.32
langchain-core==0.1.42
langchain-text-splitters==0.0.1
langsmith==0.1.45
loguru==0.7.2
lz4==4.3.3
markdown-it-py==3.0.0
marshmallow==3.20.1
mdurl==0.1.2
miservice-fork==2.3.2
miservice-fork==2.4.1
multidict==6.0.4
mutagen==1.47.0
mypy-extensions==1.0.0
numexpr==2.9.0
numexpr==2.10.0
numpy==1.26.3
openai==1.10.0
openai==1.17.0
orjson==3.10.0
packaging==23.2
prompt-toolkit==3.0.43
proto-plus==1.23.0
protobuf==4.25.1
pyasn1==0.5.1
pyasn1-modules==0.3.0
pycryptodomex==3.20.0
pydantic==2.5.3
pydantic-core==2.14.6
pygments==2.17.2
pyjwt==2.8.0
pyparsing==3.1.2; python_version > "3.0"
python-gemini-api==2.4.2
pyyaml==6.0.1
regex==2023.12.25
requests==2.31.0
rich==13.7.0
rich==13.7.1
rsa==4.9
sniffio==1.3.0
soupsieve==2.5
Expand All @@ -75,8 +88,10 @@ tenacity==8.2.3
tqdm==4.66.1
typing-extensions==4.9.0
typing-inspect==0.9.0
uritemplate==4.1.1
urllib3==2.1.0
wcwidth==0.2.13
websockets==12.0
win32-setctime==1.1.0; sys_platform == "win32"
yarl==1.9.4
zhipuai==2.0.1
4 changes: 3 additions & 1 deletion xiao_config.json.example
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,7 @@
"bing_cookie_path": "",
"bing_cookies": {},
"api_base": "https://abc-def.openai.azure.com/",
"deployment_id": ""
"deployment_id": "",
"azure_tts_speech_key": null,
"azure_tts_service_region": "eastasia"
}
19 changes: 18 additions & 1 deletion xiaogpt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,16 @@ def main():
default=None,
help="show info",
)
parser.add_argument(
"--azure_tts_speech_key",
dest="azure_tts_speech_key",
help="if use azure tts",
)
parser.add_argument(
"--azure_tts_service_region",
dest="azure_tts_service_region",
help="if use azure tts",
)
tts_group = parser.add_mutually_exclusive_group()
tts_group.add_argument(
"--enable_edge_tts",
Expand All @@ -98,7 +108,11 @@ def main():
const="edge",
help="if use edge tts",
)
tts_group.add_argument("--tts", help="tts type", choices=["mi", "edge", "openai"])
tts_group.add_argument(
"--tts",
help="tts type",
choices=["mi", "edge", "openai", "azure"],
)
bot_group = parser.add_mutually_exclusive_group()
bot_group.add_argument(
"--use_gpt3",
Expand Down Expand Up @@ -197,6 +211,9 @@ def main():
options = parser.parse_args()
if options.bot in ["bard"] and options.stream:
raise Exception("For now Bard do not support stream")
if options.tts in ["edge", "openai", "azure"]:
print("Will close stream to better tts")
options.stream = False
config = Config.from_options(options)

miboy = MiGPT(config)
Expand Down
9 changes: 8 additions & 1 deletion xiaogpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,13 @@ class Config:
start_conversation: str = "开始持续对话"
end_conversation: str = "结束持续对话"
stream: bool = False
tts: Literal["mi", "edge"] = "mi"
tts: Literal["mi", "edge", "azure", "openai"] = "mi"
tts_voice: str | None = None
gpt_options: dict[str, Any] = field(default_factory=dict)
bing_cookie_path: str = ""
bing_cookies: dict | None = None
azure_tts_speech_key: str | None = None
azure_tts_service_region: str = "eastasia"

def __post_init__(self) -> None:
if self.proxy:
Expand All @@ -110,6 +112,11 @@ def __post_init__(self) -> None:
raise Exception(
"Using GPT api needs openai API key, please google how to"
)
if self.tts == "azure" and not self.azure_tts_speech_key:
raise Exception("Using Azure TTS needs azure speech key")
if self.tts in ["azure", "edge", "openai"]:
print("Will close stream when use tts: {self.tts} for better experience")
self.stream = False

@property
def tts_command(self) -> str:
Expand Down
3 changes: 3 additions & 0 deletions xiaogpt/tts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from xiaogpt.tts.base import TTS as TTS
from xiaogpt.tts.edge import EdgeTTS as EdgeTTS
from xiaogpt.tts.mi import MiTTS as MiTTS
from xiaogpt.tts.azure import AzureTTS

__all__ = ["TTS", "EdgeTTS", "MiTTS", "AzureTTS"]
97 changes: 97 additions & 0 deletions xiaogpt/tts/azure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from __future__ import annotations

import logging
import tempfile
from pathlib import Path
from typing import Optional

import azure.cognitiveservices.speech as speechsdk

from xiaogpt.tts.base import AudioFileTTS
from xiaogpt.utils import calculate_tts_elapse

logger = logging.getLogger(__name__)


class AzureTTS(AudioFileTTS):
voice_name = "zh-CN-XiaoxiaoMultilingualNeural"

async def make_audio_file(self, query: str, text: str) -> tuple[Path, float]:
output_file = tempfile.NamedTemporaryFile(
suffix=".mp3", mode="wb", delete=False, dir=self.dirname.name
)

speech_synthesizer = self._build_speech_synthesizer(output_file.name)
result: Optional[speechsdk.SpeechSynthesisResult] = (
speech_synthesizer.speak_text_async(text).get()
)
if result is None:
raise RuntimeError(
f"Failed to get tts from azure with voice={self.voice_name}"
)
# Check result
if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
logger.debug("Speech synthesized for text [{}]".format(text))
return Path(output_file.name), calculate_tts_elapse(text)
elif result.reason == speechsdk.ResultReason.Canceled:
cancellation_details = result.cancellation_details
logger.warning(f"Speech synthesis canceled: {cancellation_details.reason}")
if cancellation_details.reason == speechsdk.CancellationReason.Error:
errmsg = f"Error details: {cancellation_details.error_details}"
logger.error(errmsg)
raise RuntimeError(errmsg)
raise RuntimeError(f"Failed to get tts from azure with voice={self.voice_name}")

def _build_speech_synthesizer(self, filename: str):
speech_key = self.config.azure_tts_speech_key
service_region = self.config.azure_tts_service_region
if not speech_key:
raise Exception("Azure tts need speech key")
speech_config = speechsdk.SpeechConfig(
subscription=speech_key, region=service_region
)
speech_config.set_speech_synthesis_output_format(
speechsdk.SpeechSynthesisOutputFormat.Audio16Khz32KBitRateMonoMp3
)
if self.config.proxy:
host, port, username, password = self._parse_proxy(self.config.proxy)

if username and password:
speech_config.set_proxy(
hostname=host, port=port, username=username, password=password
)
else:
speech_config.set_proxy(hostname=host, port=port)

speech_config.speech_synthesis_voice_name = (
self.config.tts_voice or self.voice_name
)
speech_synthesizer = speechsdk.SpeechSynthesizer(
speech_config=speech_config,
audio_config=speechsdk.audio.AudioOutputConfig(filename=filename), # type: ignore
)
return speech_synthesizer

def _parse_proxy(self, proxy_str: str):
proxy_str = proxy_str
proxy_str_splited = proxy_str.split("://")
proxy_type = proxy_str_splited[0]
proxy_addr = proxy_str_splited[1]

if proxy_type == "http":
if "@" in proxy_addr:
proxy_addr_splited = proxy_addr.split("@")
proxy_auth = proxy_addr_splited[0]
proxy_addr_netloc = proxy_addr_splited[1]
proxy_auth_splited = proxy_auth.split(":")
username = proxy_auth_splited[0]
password = proxy_auth_splited[1]
else:
proxy_addr_netloc = proxy_addr
username, password = None, None

proxy_addr_netloc_splited = proxy_addr_netloc.split(":")
host = proxy_addr_netloc_splited[0]
port = int(proxy_addr_netloc_splited[1])
return host, port, username, password
raise NotImplementedError
2 changes: 1 addition & 1 deletion xiaogpt/tts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ async def worker():
break
else:
url, duration = result
logger.debug("Playing URL %s(%s seconds)", url, duration)
logger.debug("Playing URL %s (%s seconds)", url, duration)
await self.mina_service.play_by_url(self.device_id, url)
await self.wait_for_duration(duration)
await task
Expand Down
4 changes: 3 additions & 1 deletion xiaogpt/xiaogpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
WAKEUP_KEYWORD,
Config,
)
from xiaogpt.tts import TTS, EdgeTTS, MiTTS
from xiaogpt.tts import TTS, EdgeTTS, MiTTS, AzureTTS
from xiaogpt.tts.openai import OpenAITTS
from xiaogpt.utils import (
parse_cookie_string,
Expand Down Expand Up @@ -260,6 +260,8 @@ async def do_tts(self, value):
def tts(self) -> TTS:
if self.config.tts == "edge":
return EdgeTTS(self.mina_service, self.device_id, self.config)
elif self.config.tts == "azure":
return AzureTTS(self.mina_service, self.device_id, self.config)
elif self.config.tts == "openai":
return OpenAITTS(self.mina_service, self.device_id, self.config)
else:
Expand Down

0 comments on commit d0dfa75

Please sign in to comment.