Files
glados-ladosp-tts/wyoming_glados/handler.py
taco b90e94b83b
Some checks failed
Build and Publish Docker Images / build-cpu (push) Waiting to run
Build and Publish Docker Images / build-cuda (push) Waiting to run
Build and Publish Docker Images / build-rocm (push) Failing after 2m5s
please god work
2026-06-12 17:06:52 -06:00

203 lines
6.3 KiB
Python

import asyncio
import logging
import re
from pathlib import Path
from typing import Optional
import numpy as np
from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.error import Error
from wyoming.event import Event
from wyoming.info import Describe, Info
from wyoming.server import AsyncEventHandler
from wyoming.tts import Synthesize
from style_bert_vits2.nlp import bert_models
from style_bert_vits2.constants import Languages
from style_bert_vits2.tts_model import TTSModel
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
# Lock acquired by GLaDOSEventHandler._handle_synthesize while it uses the
# shared model below.
_VOICE_LOCK = asyncio.Lock()
# Shared TTS model; stays None until the first synthesis request loads it.
_MODEL: Optional[TTSModel] = None
# Model name handed to bert_models.load_model / load_tokenizer per language.
_BERT_MODEL_NAMES = {
Languages.JP: "ku-nlp/deberta-v2-large-japanese-char-wwm",
Languages.EN: "microsoft/deberta-v3-large",
Languages.ZH: "hfl/chinese-roberta-wwm-ext-large",
}
# Matches any character in the Hiragana (U+3040-U+309F) or Katakana
# (U+30A0-U+30FF) Unicode ranges.
_HIRAGANA_KATAKANA = re.compile(r"[\u3040-\u309F\u30A0-\u30FF]")
# Matches any character in U+4E00-U+9FFF (CJK Unified Ideographs).
_CJK = re.compile(r"[\u4E00-\u9FFF]")
def _detect_language(text: str) -> Languages:
    """Guess the synthesis language of *text* from the scripts it contains.

    The checks run in order and the first match wins: any kana character
    selects Japanese, otherwise any character matched by ``_CJK`` selects
    Chinese, otherwise English is returned.

    Note that text written only in ideographs is reported as Chinese, since
    the kana check is the sole signal used for Japanese.
    """
    script_checks = (
        (_HIRAGANA_KATAKANA, Languages.JP),
        (_CJK, Languages.ZH),
    )
    for pattern, language in script_checks:
        if pattern.search(text) is not None:
            return language
    return Languages.EN
def _load_bert_for_language(language: Languages, device: str) -> None:
    """Ensure the BERT model and tokenizer for *language* are loaded as float32.

    Args:
        language: Language whose entry in ``_BERT_MODEL_NAMES`` should be loaded.
        device: Accepted for signature symmetry with the model loader; this
            function does not read it.

    Raises:
        KeyError: If *language* has no entry in ``_BERT_MODEL_NAMES``.
    """
    bert_name = _BERT_MODEL_NAMES[language]

    model_ready = bert_models.is_model_loaded(language)
    if not model_ready:
        _LOGGER.info("Loading BERT model for %s (%s)", language.name, bert_name)
        bert_models.load_model(language, bert_name)

    tokenizer_ready = bert_models.is_tokenizer_loaded(language)
    if not tokenizer_ready:
        bert_models.load_tokenizer(language, bert_name)

    # The library keeps its cache in a module attribute literally named
    # "__loaded_models"; getattr with a string reads that exact attribute.
    cache = getattr(bert_models, "__loaded_models")
    cached_bert = cache.get(language)
    if cached_bert is None:
        return

    # Runs on every call, not only right after a fresh load.
    float_bert = cached_bert.float()
    float_bert.eval()
    cache[language] = float_bert
    _LOGGER.info("BERT model for %s cast to float32", language.name)
def _find_model_files(model_dir: Path):
model_dir = model_dir.resolve()
safetensors = list(model_dir.glob("*.safetensors"))
config = model_dir / "config.json"
style = model_dir / "style_vectors.npy"
if safetensors and config.exists():
return safetensors[0], config, style if style.exists() else None
for subdir in sorted(model_dir.iterdir()):
if not subdir.is_dir():
continue
safetensors = list(subdir.glob("*.safetensors"))
config = subdir / "config.json"
style = subdir / "style_vectors.npy"
if safetensors and config.exists():
return safetensors[0], config, style if style.exists() else None
raise FileNotFoundError(
f"No .safetensors files found in {model_dir} or its subdirectories"
)
def _load_model(model_dir: Path, device: str) -> TTSModel:
    """Build a ``TTSModel`` from the files under *model_dir* and load it.

    Args:
        model_dir: Directory searched by ``_find_model_files``.
        device: Device string handed to ``TTSModel``.

    Returns:
        The model after ``load()`` has been called and, when the private
        network attribute is present, after that network was cast with
        ``.float()``.

    Raises:
        FileNotFoundError: Propagated from ``_find_model_files``.
    """
    weights_file, config_file, style_file = _find_model_files(model_dir)
    _LOGGER.info(
        "Creating TTSModel (model=%s, config=%s, device=%s)",
        weights_file.name,
        config_file.name,
        device,
    )
    # NOTE(review): style_file is None when style_vectors.npy is missing;
    # confirm TTSModel accepts None for style_vec_path.
    tts_model = TTSModel(
        model_path=weights_file,
        config_path=config_file,
        style_vec_path=style_file,
        device=device,
    )

    _LOGGER.info("Loading model weights...")
    tts_model.load()

    # "__net_g" is name-mangled inside TTSModel, so it is only reachable
    # from outside the class under this spelling.
    mangled_attr = "_TTSModel__net_g"
    network = getattr(tts_model, mangled_attr, None)
    if network is not None:
        setattr(tts_model, mangled_attr, network.float())
        _LOGGER.info("TTS network cast to float32")

    _LOGGER.info("Model loaded successfully")
    return tts_model
class GLaDOSEventHandler(AsyncEventHandler):
    """Wyoming event handler that answers Describe and Synthesize events.

    Describe events are answered with the pre-built info event. Synthesize
    events are rendered with the module-wide TTS model (loaded on first use)
    and sent back as AudioStart, a series of AudioChunk events, and AudioStop.
    """

    def __init__(
        self,
        wyoming_info: Info,
        model_dir: Path,
        device: str,
        *args,
        **kwargs,
    ) -> None:
        """Store the service description and the model settings.

        Args:
            wyoming_info: Service description; its event is built once here
                and re-sent for every Describe request.
            model_dir: Directory handed to ``_load_model`` on first synthesis.
            device: Device string handed to the model and BERT loaders.
            *args: Forwarded unchanged to ``AsyncEventHandler``.
            **kwargs: Forwarded unchanged to ``AsyncEventHandler``.
        """
        super().__init__(*args, **kwargs)
        self.wyoming_info_event = wyoming_info.event()
        self.model_dir = model_dir
        self.device = device

    async def handle_event(self, event: Event) -> bool:
        """Dispatch one incoming event; every path returns True.

        Describe events get the info event written back, Synthesize events
        are synthesized, and any other event type is ignored.
        """
        if Describe.is_type(event.type):
            await self.write_event(self.wyoming_info_event)
            return True

        if not Synthesize.is_type(event.type):
            return True

        synthesize = Synthesize.from_event(event)
        return await self._handle_synthesize(synthesize)

    async def _handle_synthesize(self, synthesize: Synthesize) -> bool:
        """Synthesize the request text and stream the audio to the client.

        Blank text is ignored. Any exception raised while loading, inferring
        or streaming is logged and reported to the client as an Error event
        instead of propagating. Always returns True.
        """
        text = synthesize.text.strip()
        if not text:
            return True

        language = _detect_language(text)
        speaker_id = self._parse_speaker_id(synthesize)
        style = "Neutral"
        _LOGGER.info(
            "Synthesizing: text='%s' language=%s speaker=%s style=%s",
            text[:80],
            language.name,
            speaker_id,
            style,
        )

        try:
            sample_rate, audio = await self._synthesize_audio(
                text, language, speaker_id, style
            )
            await self._write_audio(sample_rate, audio)
        except Exception as err:
            _LOGGER.exception("Synthesis failed")
            await self.write_event(
                Error(text=str(err), code=err.__class__.__name__).event()
            )
        return True

    @staticmethod
    def _parse_speaker_id(synthesize: Synthesize) -> int:
        """Return the numeric speaker requested by the client, defaulting to 0."""
        voice = synthesize.voice
        if voice is None or not voice.speaker:
            return 0
        try:
            return int(voice.speaker)
        except ValueError:
            # Keep the best-effort fallback, but leave a trace so a client
            # sending a speaker name is not silently given speaker 0.
            _LOGGER.warning(
                "Ignoring non-numeric speaker %r; using speaker 0", voice.speaker
            )
            return 0

    async def _synthesize_audio(self, text: str, language, speaker_id: int, style: str):
        """Load the models if needed and run inference under ``_VOICE_LOCK``.

        Returns whatever ``_MODEL.infer`` returns, which the caller unpacks
        as ``(sample_rate, audio)``.
        """
        global _MODEL
        async with _VOICE_LOCK:
            # Model and BERT loading are blocking calls; running them in a
            # worker thread keeps the event loop free for other connections.
            if _MODEL is None:
                _LOGGER.info(
                    "Loading GLaDOS model from %s on %s",
                    self.model_dir,
                    self.device,
                )
                _MODEL = await asyncio.to_thread(
                    _load_model, self.model_dir, self.device
                )
            await asyncio.to_thread(_load_bert_for_language, language, self.device)
            return await asyncio.to_thread(
                _MODEL.infer,
                text=text,
                language=language,
                speaker_id=speaker_id,
                style=style,
            )

    async def _write_audio(self, sample_rate: int, audio) -> None:
        """Send *audio* as AudioStart, 1024-sample AudioChunks, then AudioStop."""
        # NOTE(review): assumes infer() returns samples already scaled to the
        # int16 range - confirm against the TTS library.
        raw_bytes = np.round(audio).astype(np.int16).tobytes()

        width = 2  # bytes per int16 sample
        channels = 1
        await self.write_event(
            AudioStart(rate=sample_rate, width=width, channels=channels).event()
        )

        samples_per_chunk = 1024
        bytes_per_chunk = width * channels * samples_per_chunk
        for offset in range(0, len(raw_bytes), bytes_per_chunk):
            await self.write_event(
                AudioChunk(
                    audio=raw_bytes[offset:offset + bytes_per_chunk],
                    rate=sample_rate,
                    width=width,
                    channels=channels,
                ).event()
            )

        await self.write_event(AudioStop().event())