"""Wyoming event handler that synthesizes speech with a Style-Bert-VITS2 model."""

import asyncio
import logging
import re
from pathlib import Path
from typing import Optional

import numpy as np
from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.error import Error
from wyoming.event import Event
from wyoming.info import Describe, Info
from wyoming.server import AsyncEventHandler
from wyoming.tts import Synthesize

from style_bert_vits2.nlp import bert_models
from style_bert_vits2.constants import Languages
from style_bert_vits2.tts_model import TTSModel

_LOGGER = logging.getLogger(__name__)

# Serializes model loading and inference across all client connections.
_VOICE_LOCK = asyncio.Lock()
# Lazily created on the first synthesis request and shared by every handler.
_MODEL: Optional[TTSModel] = None

_BERT_MODEL_NAMES = {
    Languages.JP: "ku-nlp/deberta-v2-large-japanese-char-wwm",
    Languages.EN: "microsoft/deberta-v3-large",
    Languages.ZH: "hfl/chinese-roberta-wwm-ext-large",
}

_HIRAGANA_KATAKANA = re.compile(r"[\u3040-\u309F\u30A0-\u30FF]")
_CJK = re.compile(r"[\u4E00-\u9FFF]")

# Output audio format: 16-bit mono PCM, streamed in fixed-size chunks.
_SAMPLE_WIDTH = 2
_CHANNELS = 1
_SAMPLES_PER_CHUNK = 1024
_INT16_MIN = -32768
_INT16_MAX = 32767


def _detect_language(text: str) -> Languages:
    """Guess the language of *text* from the scripts it contains.

    Any hiragana/katakana character selects Japanese; otherwise any
    character in U+4E00-U+9FFF selects Chinese; everything else is English.
    Text made up only of such ideographs is therefore always reported as
    Chinese.
    """
    if _HIRAGANA_KATAKANA.search(text):
        return Languages.JP
    if _CJK.search(text):
        return Languages.ZH
    return Languages.EN


def _load_bert_for_language(language: Languages, device: str) -> None:
    """Ensure the BERT model and tokenizer for *language* are loaded as float32.

    Args:
        language: Language whose BERT model/tokenizer should be available.
        device: Accepted for interface compatibility; not used here.

    Raises:
        KeyError: If no BERT model name is configured for *language*.
    """
    model_name = _BERT_MODEL_NAMES[language]

    if not bert_models.is_model_loaded(language):
        _LOGGER.info("Loading BERT model for %s (%s)", language.name, model_name)
        bert_models.load_model(language, model_name)
    if not bert_models.is_tokenizer_loaded(language):
        bert_models.load_tokenizer(language, model_name)

    # NOTE: this reaches into the library's module-private cache. It must stay
    # in a module-level function: inside a class body the double-underscore
    # name would be mangled and the lookup would fail.
    bert = bert_models.__loaded_models.get(language)
    if bert is not None:
        bert = bert.float()
        bert.eval()
        bert_models.__loaded_models[language] = bert
        _LOGGER.info("BERT model for %s cast to float32", language.name)


def _model_files_in(directory: Path) -> Optional[tuple[Path, Path, Optional[Path]]]:
    """Return ``(weights, config, style_vectors_or_None)`` for *directory*.

    Returns ``None`` unless the directory holds at least one ``*.safetensors``
    file together with a ``config.json``.
    """
    # Sorted so the chosen checkpoint is deterministic when several exist.
    weights = sorted(directory.glob("*.safetensors"))
    config = directory / "config.json"
    if not weights or not config.exists():
        return None
    style = directory / "style_vectors.npy"
    return weights[0], config, style if style.exists() else None


def _find_model_files(model_dir: Path):
    """Locate model files in *model_dir* or one of its immediate subdirectories.

    *model_dir* itself is checked first, then its subdirectories in sorted
    order; the first directory containing both a ``*.safetensors`` file and
    ``config.json`` wins.

    Returns:
        ``(model_path, config_path, style_path)`` where ``style_path`` is
        ``None`` when ``style_vectors.npy`` is absent.

    Raises:
        FileNotFoundError: If *model_dir* is not a directory or no usable
            model is found.
    """
    model_dir = model_dir.resolve()
    if not model_dir.is_dir():
        raise FileNotFoundError(f"Model directory does not exist: {model_dir}")

    found = _model_files_in(model_dir)
    if found is not None:
        return found

    for subdir in sorted(model_dir.iterdir()):
        if not subdir.is_dir():
            continue
        found = _model_files_in(subdir)
        if found is not None:
            return found

    raise FileNotFoundError(
        f"No .safetensors file with a config.json found in {model_dir} "
        "or its subdirectories"
    )


def _load_model(model_dir: Path, device: str) -> TTSModel:
    """Build a ``TTSModel`` from *model_dir*, load it and cast it to float32.

    Raises:
        FileNotFoundError: If no usable model files are found.
    """
    model_path, config_path, style_path = _find_model_files(model_dir)
    _LOGGER.info(
        "Creating TTSModel (model=%s, config=%s, device=%s)",
        model_path.name,
        config_path.name,
        device,
    )

    model = TTSModel(
        model_path=model_path,
        config_path=config_path,
        style_vec_path=style_path,
        device=device,
    )

    _LOGGER.info("Loading model weights...")
    model.load()

    # The network lives in a name-mangled private attribute of TTSModel, so it
    # is looked up by its mangled name; a missing attribute is tolerated.
    net_g = getattr(model, "_TTSModel__net_g", None)
    if net_g is not None:
        setattr(model, "_TTSModel__net_g", net_g.float())
        _LOGGER.info("TTS network cast to float32")

    _LOGGER.info("Model loaded successfully")
    return model


def _to_pcm16(audio) -> bytes:
    """Convert synthesized samples to raw 16-bit PCM bytes.

    Samples are rounded and clipped to the int16 range before the cast so
    out-of-range values saturate instead of wrapping around.
    Assumes the samples are already on a 16-bit scale - TODO confirm.
    """
    samples = np.clip(np.round(np.asarray(audio)), _INT16_MIN, _INT16_MAX)
    return samples.astype(np.int16).tobytes()


class GLaDOSEventHandler(AsyncEventHandler):
    """Answers Wyoming ``Describe`` and ``Synthesize`` events for one client."""

    def __init__(
        self,
        wyoming_info: Info,
        model_dir: Path,
        device: str,
        *args,
        **kwargs,
    ) -> None:
        """Store the service description and model settings.

        Args:
            wyoming_info: Service description returned for ``Describe``.
            model_dir: Directory searched for the TTS model files.
            device: Device string passed to the TTS model.
            *args: Forwarded to ``AsyncEventHandler``.
            **kwargs: Forwarded to ``AsyncEventHandler``.
        """
        super().__init__(*args, **kwargs)
        self.wyoming_info_event = wyoming_info.event()
        self.model_dir = model_dir
        self.device = device

    async def handle_event(self, event: Event) -> bool:
        """Dispatch one incoming event; always returns ``True``.

        ``Describe`` is answered with the stored info event, ``Synthesize``
        is handed to ``_handle_synthesize`` and every other event is ignored.
        """
        if Describe.is_type(event.type):
            await self.write_event(self.wyoming_info_event)
            return True

        if not Synthesize.is_type(event.type):
            return True

        synthesize = Synthesize.from_event(event)
        return await self._handle_synthesize(synthesize)

    async def _handle_synthesize(self, synthesize: Synthesize) -> bool:
        """Synthesize the requested text and stream it back as PCM audio.

        Blank text is ignored. On failure the exception is logged and an
        ``Error`` event is written instead of audio. Always returns ``True``.
        """
        global _MODEL

        text = synthesize.text.strip()
        if not text:
            return True

        language = _detect_language(text)
        speaker_id = 0
        style = "Neutral"

        if synthesize.voice is not None and synthesize.voice.speaker:
            try:
                speaker_id = int(synthesize.voice.speaker)
            except ValueError:
                # Keep the best-effort fallback, but make it visible.
                _LOGGER.warning(
                    "Ignoring non-numeric speaker '%s'; using speaker %s",
                    synthesize.voice.speaker,
                    speaker_id,
                )

        _LOGGER.info(
            "Synthesizing: text='%s' language=%s speaker=%s style=%s",
            text[:80],
            language.name,
            speaker_id,
            style,
        )

        try:
            async with _VOICE_LOCK:
                # Loading is slow and synchronous; run it in a worker thread
                # so the event loop keeps serving other connections.
                if _MODEL is None:
                    _LOGGER.info(
                        "Loading GLaDOS model from %s on %s",
                        self.model_dir,
                        self.device,
                    )
                    _MODEL = await asyncio.to_thread(
                        _load_model, self.model_dir, self.device
                    )

                await asyncio.to_thread(
                    _load_bert_for_language, language, self.device
                )

                sr, audio = await asyncio.to_thread(
                    _MODEL.infer,
                    text=text,
                    language=language,
                    speaker_id=speaker_id,
                    style=style,
                )

            raw_bytes = _to_pcm16(audio)
            rate = sr

            await self.write_event(
                AudioStart(
                    rate=rate, width=_SAMPLE_WIDTH, channels=_CHANNELS
                ).event()
            )

            bytes_per_chunk = _SAMPLE_WIDTH * _CHANNELS * _SAMPLES_PER_CHUNK
            for offset in range(0, len(raw_bytes), bytes_per_chunk):
                await self.write_event(
                    AudioChunk(
                        audio=raw_bytes[offset:offset + bytes_per_chunk],
                        rate=rate,
                        width=_SAMPLE_WIDTH,
                        channels=_CHANNELS,
                    ).event()
                )

            await self.write_event(AudioStop().event())
            return True

        except Exception as err:
            # Connection-level boundary: report the failure to the client
            # rather than letting the exception escape the handler.
            _LOGGER.exception("Synthesis failed")
            await self.write_event(
                Error(text=str(err), code=err.__class__.__name__).event()
            )
            return True