# Listing metadata: 203 lines, 6.3 KiB, Python
import asyncio
|
|
import logging
|
|
import re
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import numpy as np
|
|
from wyoming.audio import AudioChunk, AudioStart, AudioStop
|
|
from wyoming.error import Error
|
|
from wyoming.event import Event
|
|
from wyoming.info import Describe, Info
|
|
from wyoming.server import AsyncEventHandler
|
|
from wyoming.tts import Synthesize
|
|
|
|
from style_bert_vits2.nlp import bert_models
|
|
from style_bert_vits2.constants import Languages
|
|
from style_bert_vits2.tts_model import TTSModel
|
|
|
|
# Module-wide logger.
_LOGGER = logging.getLogger(__name__)

# Character-class patterns used by _detect_language to spot Japanese kana
# and CJK ideographs respectively.
_HIRAGANA_KATAKANA = re.compile(r"[\u3040-\u309F\u30A0-\u30FF]")
_CJK = re.compile(r"[\u4E00-\u9FFF]")

# Model name handed to bert_models.load_model / load_tokenizer per language.
_BERT_MODEL_NAMES = {
    Languages.JP: "ku-nlp/deberta-v2-large-japanese-char-wwm",
    Languages.EN: "microsoft/deberta-v3-large",
    Languages.ZH: "hfl/chinese-roberta-wwm-ext-large",
}

# TTS model shared by all handlers; it is created on the first synthesis
# request, and _VOICE_LOCK is held while that creation happens.
_MODEL: Optional[TTSModel] = None
_VOICE_LOCK = asyncio.Lock()
|
|
|
|
|
|
def _detect_language(text: str) -> Languages:
    """Guess the language of ``text`` from the scripts it contains.

    The checks run in priority order: any hiragana/katakana character yields
    ``Languages.JP``; otherwise any character in U+4E00-U+9FFF yields
    ``Languages.ZH``; anything else is ``Languages.EN``. Text made only of
    ideographs (no kana) is therefore reported as ``Languages.ZH``.

    Args:
        text: The text to inspect.

    Returns:
        The detected ``Languages`` member.
    """
    # Order matters: kana must be tested before the generic ideograph range.
    script_checks = (
        (_HIRAGANA_KATAKANA, Languages.JP),
        (_CJK, Languages.ZH),
    )
    for pattern, detected in script_checks:
        if pattern.search(text) is not None:
            return detected
    return Languages.EN
|
|
|
|
|
|
def _load_bert_for_language(language: Languages, device: str) -> None:
    """Ensure the BERT model and tokenizer for ``language`` are available.

    A missing model or tokenizer is loaded through ``bert_models`` using the
    name registered in ``_BERT_MODEL_NAMES``. Whatever model is registered
    for the language afterwards is cast to float32, put in eval mode and
    written back to the library's registry.

    Args:
        language: Language whose BERT assets are needed.
        device: Currently unused by this function.

    Raises:
        KeyError: If ``language`` has no entry in ``_BERT_MODEL_NAMES``.
    """
    checkpoint = _BERT_MODEL_NAMES[language]

    if not bert_models.is_model_loaded(language):
        _LOGGER.info("Loading BERT model for %s (%s)", language.name, checkpoint)
        bert_models.load_model(language, checkpoint)

    if not bert_models.is_tokenizer_loaded(language):
        bert_models.load_tokenizer(language, checkpoint)

    # NOTE: this reaches into the library's module-private registry; there is
    # no name mangling here because the access is not inside a class body.
    loaded = bert_models.__loaded_models.get(language)
    if loaded is None:
        return

    loaded = loaded.float()
    loaded.eval()
    bert_models.__loaded_models[language] = loaded
    _LOGGER.info("BERT model for %s cast to float32", language.name)
|
|
|
|
|
|
def _find_model_files(model_dir: Path):
    """Locate the model weights, config and optional style vectors.

    ``model_dir`` itself is probed first, then each immediate subdirectory in
    sorted order. A directory qualifies when it holds at least one
    ``*.safetensors`` file together with a ``config.json``. When several
    ``*.safetensors`` files are present, the first one in sorted order is
    chosen so the result does not depend on filesystem listing order.

    Args:
        model_dir: Directory to search (resolved to an absolute path first).

    Returns:
        A ``(weights_path, config_path, style_path)`` tuple. ``style_path`` is
        ``None`` when ``style_vectors.npy`` is not present next to the weights.

    Raises:
        FileNotFoundError: If ``model_dir`` does not exist, or if neither it
            nor any immediate subdirectory qualifies.
        NotADirectoryError: If ``model_dir`` exists but is not a directory.
    """
    model_dir = model_dir.resolve()
    if not model_dir.exists():
        raise FileNotFoundError(f"Model directory does not exist: {model_dir}")
    if not model_dir.is_dir():
        raise NotADirectoryError(f"Model path is not a directory: {model_dir}")

    found = _probe_model_dir(model_dir)
    if found is not None:
        return found

    for subdir in sorted(model_dir.iterdir()):
        if not subdir.is_dir():
            continue
        found = _probe_model_dir(subdir)
        if found is not None:
            return found

    raise FileNotFoundError(
        f"No model found in {model_dir} or its subdirectories "
        "(need a .safetensors file alongside config.json)"
    )


def _probe_model_dir(directory: Path):
    """Return ``(weights, config, style_or_None)`` for ``directory``, or ``None``."""
    # Sorted so the choice among several checkpoints is deterministic.
    weights = sorted(directory.glob("*.safetensors"))
    config = directory / "config.json"
    if not weights or not config.exists():
        return None
    style = directory / "style_vectors.npy"
    return weights[0], config, style if style.exists() else None
|
|
|
|
|
|
def _load_model(model_dir: Path, device: str) -> TTSModel:
    """Create a ``TTSModel`` from the files in ``model_dir`` and load it.

    After loading, the model's private generator network (if present) is cast
    to float32 and stored back on the model.

    Args:
        model_dir: Directory passed to ``_find_model_files``.
        device: Device string forwarded to ``TTSModel``.

    Returns:
        The loaded ``TTSModel`` instance.

    Raises:
        FileNotFoundError: Propagated from ``_find_model_files`` when no model
            files can be located.
    """
    weights_file, config_file, style_file = _find_model_files(model_dir)

    _LOGGER.info(
        "Creating TTSModel (model=%s, config=%s, device=%s)",
        weights_file.name,
        config_file.name,
        device,
    )
    tts = TTSModel(
        model_path=weights_file,
        config_path=config_file,
        style_vec_path=style_file,
        device=device,
    )

    _LOGGER.info("Loading model weights...")
    tts.load()

    # The network lives in a double-underscore attribute of TTSModel, so it is
    # only reachable from outside the class through its mangled name.
    private_net_attr = "_TTSModel__net_g"
    generator = getattr(tts, private_net_attr, None)
    if generator is not None:
        setattr(tts, private_net_attr, generator.float())
        _LOGGER.info("TTS network cast to float32")

    _LOGGER.info("Model loaded successfully")
    return tts
|
|
|
|
|
|
class GLaDOSEventHandler(AsyncEventHandler):
    """Wyoming event handler that serves TTS requests from a shared model.

    ``Describe`` events are answered with the service info event.
    ``Synthesize`` events are answered with an ``AudioStart``, a series of
    ``AudioChunk`` events and an ``AudioStop`` carrying 16-bit mono PCM, or
    with an ``Error`` event when synthesis fails. Other events are ignored.
    """

    # Number of samples carried by each AudioChunk event.
    _SAMPLES_PER_CHUNK = 1024
    # Output format: 2 bytes per sample, one channel.
    _SAMPLE_WIDTH = 2
    _CHANNELS = 1
    # Style name passed to every inference call.
    _DEFAULT_STYLE = "Neutral"

    def __init__(
        self,
        wyoming_info: Info,
        model_dir: Path,
        device: str,
        *args,
        **kwargs,
    ) -> None:
        """Store the service info event and the model location.

        Args:
            wyoming_info: Service description returned for ``Describe`` events.
            model_dir: Directory the TTS model is loaded from on first use.
            device: Device string forwarded to the model loaders.
            *args: Forwarded to ``AsyncEventHandler``.
            **kwargs: Forwarded to ``AsyncEventHandler``.
        """
        super().__init__(*args, **kwargs)
        self.wyoming_info_event = wyoming_info.event()
        self.model_dir = model_dir
        self.device = device

    async def handle_event(self, event: Event) -> bool:
        """Dispatch one incoming event.

        Args:
            event: The event received from the client.

        Returns:
            Always ``True``.
        """
        if Describe.is_type(event.type):
            await self.write_event(self.wyoming_info_event)
            return True

        if not Synthesize.is_type(event.type):
            return True

        synthesize = Synthesize.from_event(event)
        return await self._handle_synthesize(synthesize)

    async def _handle_synthesize(self, synthesize: Synthesize) -> bool:
        """Synthesize the requested text and stream it back as PCM audio.

        Blank text is ignored without writing any event. Any exception raised
        during synthesis or streaming is logged and reported to the client as
        an ``Error`` event.

        Args:
            synthesize: The parsed synthesis request.

        Returns:
            Always ``True``.
        """
        text = synthesize.text.strip()
        if not text:
            return True

        language = _detect_language(text)
        speaker_id = self._parse_speaker_id(synthesize)
        style = self._DEFAULT_STYLE

        _LOGGER.info(
            "Synthesizing: text='%s' language=%s speaker=%s style=%s",
            text[:80],
            language.name,
            speaker_id,
            style,
        )

        try:
            sample_rate, audio = await self._run_inference(
                text, language, speaker_id, style
            )

            # Clip before the cast so out-of-range samples saturate instead of
            # wrapping around. NOTE(review): assumes infer() returns samples
            # already scaled to the 16-bit PCM range -- confirm.
            limits = np.iinfo(np.int16)
            pcm = np.clip(np.round(audio), limits.min, limits.max).astype(np.int16)

            await self._write_audio(int(sample_rate), pcm.tobytes())
        except Exception as err:
            # Request boundary: report the failure to the client rather than
            # letting it escape the handler.
            _LOGGER.exception("Synthesis failed")
            await self.write_event(
                Error(text=str(err), code=err.__class__.__name__).event()
            )

        return True

    @staticmethod
    def _parse_speaker_id(synthesize: Synthesize) -> int:
        """Return the requested numeric speaker id, or 0 if absent or invalid."""
        voice = synthesize.voice
        if voice is None or not voice.speaker:
            return 0
        try:
            return int(voice.speaker)
        except ValueError:
            _LOGGER.warning(
                "Ignoring non-numeric speaker %r; using speaker 0", voice.speaker
            )
            return 0

    async def _run_inference(self, text, language, speaker_id, style):
        """Load models if needed and run inference; return ``(rate, audio)``.

        The module-level lock is held for the whole operation, so only one
        request at a time loads or uses the shared model. All blocking work
        runs in worker threads to keep the event loop responsive.
        """
        global _MODEL

        async with _VOICE_LOCK:
            if _MODEL is None:
                _LOGGER.info(
                    "Loading GLaDOS model from %s on %s",
                    self.model_dir,
                    self.device,
                )
                # Model loading is blocking; run it off the event loop.
                _MODEL = await asyncio.to_thread(
                    _load_model, self.model_dir, self.device
                )

            await asyncio.to_thread(_load_bert_for_language, language, self.device)

            return await asyncio.to_thread(
                _MODEL.infer,
                text=text,
                language=language,
                speaker_id=speaker_id,
                style=style,
            )

    async def _write_audio(self, rate: int, pcm_bytes: bytes) -> None:
        """Write ``pcm_bytes`` as AudioStart, chunked AudioChunks and AudioStop."""
        width = self._SAMPLE_WIDTH
        channels = self._CHANNELS

        await self.write_event(
            AudioStart(rate=rate, width=width, channels=channels).event()
        )

        bytes_per_chunk = width * channels * self._SAMPLES_PER_CHUNK
        for offset in range(0, len(pcm_bytes), bytes_per_chunk):
            await self.write_event(
                AudioChunk(
                    audio=pcm_bytes[offset:offset + bytes_per_chunk],
                    rate=rate,
                    width=width,
                    channels=channels,
                ).event()
            )

        await self.write_event(AudioStop().event())
|