please god work
This commit is contained in:
202
wyoming_glados/handler.py
Normal file
202
wyoming_glados/handler.py
Normal file
@@ -0,0 +1,202 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from wyoming.audio import AudioChunk, AudioStart, AudioStop
|
||||
from wyoming.error import Error
|
||||
from wyoming.event import Event
|
||||
from wyoming.info import Describe, Info
|
||||
from wyoming.server import AsyncEventHandler
|
||||
from wyoming.tts import Synthesize
|
||||
|
||||
from style_bert_vits2.nlp import bert_models
|
||||
from style_bert_vits2.constants import Languages
|
||||
from style_bert_vits2.tts_model import TTSModel
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_VOICE_LOCK = asyncio.Lock()
|
||||
_MODEL: Optional[TTSModel] = None
|
||||
|
||||
_BERT_MODEL_NAMES = {
|
||||
Languages.JP: "ku-nlp/deberta-v2-large-japanese-char-wwm",
|
||||
Languages.EN: "microsoft/deberta-v3-large",
|
||||
Languages.ZH: "hfl/chinese-roberta-wwm-ext-large",
|
||||
}
|
||||
|
||||
_HIRAGANA_KATAKANA = re.compile(r"[\u3040-\u309F\u30A0-\u30FF]")
|
||||
_CJK = re.compile(r"[\u4E00-\u9FFF]")
|
||||
|
||||
|
||||
def _detect_language(text: str) -> Languages:
|
||||
if _HIRAGANA_KATAKANA.search(text):
|
||||
return Languages.JP
|
||||
if _CJK.search(text):
|
||||
return Languages.ZH
|
||||
return Languages.EN
|
||||
|
||||
|
||||
def _load_bert_for_language(language: Languages, device: str) -> None:
|
||||
model_name = _BERT_MODEL_NAMES[language]
|
||||
if not bert_models.is_model_loaded(language):
|
||||
_LOGGER.info("Loading BERT model for %s (%s)", language.name, model_name)
|
||||
bert_models.load_model(language, model_name)
|
||||
if not bert_models.is_tokenizer_loaded(language):
|
||||
bert_models.load_tokenizer(language, model_name)
|
||||
bert = bert_models.__loaded_models.get(language)
|
||||
if bert is not None:
|
||||
bert = bert.float()
|
||||
bert.eval()
|
||||
bert_models.__loaded_models[language] = bert
|
||||
_LOGGER.info("BERT model for %s cast to float32", language.name)
|
||||
|
||||
|
||||
def _find_model_files(model_dir: Path):
|
||||
model_dir = model_dir.resolve()
|
||||
safetensors = list(model_dir.glob("*.safetensors"))
|
||||
config = model_dir / "config.json"
|
||||
style = model_dir / "style_vectors.npy"
|
||||
|
||||
if safetensors and config.exists():
|
||||
return safetensors[0], config, style if style.exists() else None
|
||||
|
||||
for subdir in sorted(model_dir.iterdir()):
|
||||
if not subdir.is_dir():
|
||||
continue
|
||||
safetensors = list(subdir.glob("*.safetensors"))
|
||||
config = subdir / "config.json"
|
||||
style = subdir / "style_vectors.npy"
|
||||
if safetensors and config.exists():
|
||||
return safetensors[0], config, style if style.exists() else None
|
||||
|
||||
raise FileNotFoundError(
|
||||
f"No .safetensors files found in {model_dir} or its subdirectories"
|
||||
)
|
||||
|
||||
|
||||
def _load_model(model_dir: Path, device: str) -> TTSModel:
|
||||
model_path, config_path, style_path = _find_model_files(model_dir)
|
||||
|
||||
_LOGGER.info("Creating TTSModel (model=%s, config=%s, device=%s)",
|
||||
model_path.name, config_path.name, device)
|
||||
|
||||
model = TTSModel(
|
||||
model_path=model_path,
|
||||
config_path=config_path,
|
||||
style_vec_path=style_path,
|
||||
device=device,
|
||||
)
|
||||
|
||||
_LOGGER.info("Loading model weights...")
|
||||
model.load()
|
||||
|
||||
net_g = getattr(model, "_TTSModel__net_g", None)
|
||||
if net_g is not None:
|
||||
net_g = net_g.float()
|
||||
setattr(model, "_TTSModel__net_g", net_g)
|
||||
_LOGGER.info("TTS network cast to float32")
|
||||
|
||||
_LOGGER.info("Model loaded successfully")
|
||||
return model
|
||||
|
||||
|
||||
class GLaDOSEventHandler(AsyncEventHandler):
|
||||
def __init__(
|
||||
self,
|
||||
wyoming_info: Info,
|
||||
model_dir: Path,
|
||||
device: str,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.wyoming_info_event = wyoming_info.event()
|
||||
self.model_dir = model_dir
|
||||
self.device = device
|
||||
|
||||
async def handle_event(self, event: Event) -> bool:
|
||||
if Describe.is_type(event.type):
|
||||
await self.write_event(self.wyoming_info_event)
|
||||
return True
|
||||
|
||||
if not Synthesize.is_type(event.type):
|
||||
return True
|
||||
|
||||
synthesize = Synthesize.from_event(event)
|
||||
return await self._handle_synthesize(synthesize)
|
||||
|
||||
async def _handle_synthesize(self, synthesize: Synthesize) -> bool:
|
||||
global _MODEL
|
||||
|
||||
text = synthesize.text.strip()
|
||||
if not text:
|
||||
return True
|
||||
|
||||
language = _detect_language(text)
|
||||
|
||||
speaker_id = 0
|
||||
style = "Neutral"
|
||||
if synthesize.voice is not None and synthesize.voice.speaker:
|
||||
try:
|
||||
speaker_id = int(synthesize.voice.speaker)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
_LOGGER.info("Synthesizing: text='%s' language=%s speaker=%s style=%s",
|
||||
text[:80], language.name, speaker_id, style)
|
||||
|
||||
try:
|
||||
async with _VOICE_LOCK:
|
||||
if _MODEL is None:
|
||||
_LOGGER.info("Loading GLaDOS model from %s on %s",
|
||||
self.model_dir, self.device)
|
||||
_MODEL = _load_model(self.model_dir, self.device)
|
||||
|
||||
_load_bert_for_language(language, self.device)
|
||||
|
||||
sr, audio = await asyncio.to_thread(
|
||||
_MODEL.infer,
|
||||
text=text,
|
||||
language=language,
|
||||
speaker_id=speaker_id,
|
||||
style=style,
|
||||
)
|
||||
|
||||
audio_int16 = np.round(audio).astype(np.int16)
|
||||
raw_bytes = audio_int16.tobytes()
|
||||
|
||||
rate = sr
|
||||
width = 2
|
||||
channels = 1
|
||||
|
||||
await self.write_event(
|
||||
AudioStart(rate=rate, width=width, channels=channels).event()
|
||||
)
|
||||
|
||||
samples_per_chunk = 1024
|
||||
bytes_per_sample = width * channels
|
||||
bytes_per_chunk = bytes_per_sample * samples_per_chunk
|
||||
|
||||
for i in range(0, len(raw_bytes), bytes_per_chunk):
|
||||
chunk = raw_bytes[i:i + bytes_per_chunk]
|
||||
await self.write_event(
|
||||
AudioChunk(
|
||||
audio=chunk,
|
||||
rate=rate,
|
||||
width=width,
|
||||
channels=channels,
|
||||
).event()
|
||||
)
|
||||
|
||||
await self.write_event(AudioStop().event())
|
||||
return True
|
||||
|
||||
except Exception as err:
|
||||
_LOGGER.exception("Synthesis failed")
|
||||
await self.write_event(
|
||||
Error(text=str(err), code=err.__class__.__name__).event()
|
||||
)
|
||||
return True
|
||||
Reference in New Issue
Block a user