107 lines
3.2 KiB
Python
107 lines
3.2 KiB
Python
import argparse
|
|
import asyncio
|
|
import logging
|
|
import signal
|
|
from functools import partial
|
|
from pathlib import Path
|
|
|
|
from wyoming.info import Attribution, Info, TtsProgram, TtsVoice
|
|
from wyoming.server import AsyncServer
|
|
|
|
from . import __version__
|
|
from .handler import GLaDOSEventHandler
|
|
|
|
# Logger for this module, named after the module itself.
_LOGGER: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the GLaDOS Wyoming TTS server."""
    parser = argparse.ArgumentParser(
        description="Wyoming TTS server for GLaDOS (Style-Bert-VITS2)"
    )
    parser.add_argument(
        "--uri",
        default="tcp://0.0.0.0:10200",
        help="URI for the Wyoming server",
    )
    parser.add_argument(
        "--model-dir",
        type=Path,
        required=True,
        help="Directory containing model files (config.json, *.safetensors, style_vectors.npy)",
    )
    parser.add_argument(
        "--device",
        default="cpu",
        help="Device for PyTorch (cpu, cuda)",
    )
    parser.add_argument("--debug", action="store_true", help="Log DEBUG messages")
    parser.add_argument("--version", action="version", version=__version__)
    return parser


def _build_info() -> Info:
    """Return the Wyoming ``Info`` advertising the single GLaDOS TTS program/voice."""
    # The program and its only voice credit the same upstream model; a fresh
    # Attribution is built for each so the two entries share no mutable object.
    attribution_name = "WarriorMama777"
    attribution_url = "https://huggingface.co/WarriorMama777/GLaDOS_TTS"
    return Info(
        tts=[
            TtsProgram(
                name="glados",
                description="GLaDOS TTS - Style-Bert-VITS2 voice from Portal",
                attribution=Attribution(name=attribution_name, url=attribution_url),
                installed=True,
                voices=[
                    TtsVoice(
                        name="glados",
                        description="GLaDOS (Portal) voice",
                        attribution=Attribution(
                            name=attribution_name, url=attribution_url
                        ),
                        installed=True,
                        languages=["ja", "en", "zh"],
                        version=__version__,
                    )
                ],
                version=__version__,
                supports_synthesize_streaming=False,
            )
        ],
    )


def _install_shutdown_handlers(task: asyncio.Task) -> None:
    """Cancel *task* on SIGINT/SIGTERM when the running event loop supports it."""
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, task.cancel)
        except NotImplementedError:
            # Some event loops (e.g. on Windows) do not implement
            # add_signal_handler; keep asyncio's default Ctrl+C handling
            # instead of crashing at startup.
            _LOGGER.debug("Event loop does not support signal handlers")
            break


async def main() -> None:
    """Parse CLI arguments and run the GLaDOS Wyoming TTS server until stopped.

    Each client connection gets a ``GLaDOSEventHandler`` constructed with the
    server ``Info``, the resolved model directory and the ``--device`` value.
    The server task is cancelled on SIGINT/SIGTERM where the event loop
    supports signal handlers; cancellation is logged and ``main`` returns.

    Raises:
        NotADirectoryError: if ``--model-dir`` is not an existing directory.
    """
    args = _build_parser().parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.debug else logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
    )

    model_dir: Path = args.model_dir.resolve()
    if not model_dir.is_dir():
        raise NotADirectoryError(f"Model directory not found: {model_dir}")

    wyoming_info = _build_info()

    server = AsyncServer.from_uri(args.uri)

    _LOGGER.info("Starting GLaDOS Wyoming TTS server on %s", args.uri)
    _LOGGER.info("Model directory: %s", model_dir)
    _LOGGER.info("Device: %s", args.device)

    server_task = asyncio.create_task(
        server.run(
            partial(
                GLaDOSEventHandler,
                wyoming_info,
                model_dir,
                args.device,
            )
        )
    )

    _install_shutdown_handlers(server_task)

    try:
        await server_task
    except asyncio.CancelledError:
        _LOGGER.info("Server stopped")
|
|
|
|
|
|
def run() -> None:
    """Synchronous entry point: run ``main()`` to completion with ``asyncio.run``.

    A ``KeyboardInterrupt`` escaping the event loop (e.g. Ctrl+C arriving
    before ``main`` has installed its signal handlers) ends the process
    quietly, the same way the module's ``__main__`` guard does, rather than
    printing a traceback when this function is used directly as an entry point.
    """
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        pass
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a module (the relative imports require ``python -m``):
    # start the server, and treat Ctrl+C as a normal, silent exit.
    try:
        run()
    except KeyboardInterrupt:
        ...
|