mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
Add VAD trimming and Docker publishing
This commit is contained in:
@@ -37,10 +37,11 @@ STATIC_DIR = Path(os.environ.get("STATIC_DIR", str(ROOT_DIR / "static"))).resolv
|
||||
PERSONAL_DIR = Path(os.environ.get("PERSONAL_DIR", str(DATA_DIR / "personal_samples"))).resolve()
|
||||
CAPTURED_DIR = Path(os.environ.get("CAPTURED_DIR", str(DATA_DIR / "captured_audio"))).resolve()
|
||||
NEGATIVE_DIR = Path(os.environ.get("NEGATIVE_DIR", str(DATA_DIR / "negative_samples"))).resolve()
|
||||
TRIM_HISTORY_DIR = Path(os.environ.get("TRIM_HISTORY_DIR", str(DATA_DIR / "trim_history"))).resolve()
|
||||
TRIM_HISTORY_DIR.mkdir(parents=True, exist_ok=True)
|
||||
TRAINED_WAKE_WORDS_DIR = Path(
|
||||
os.environ.get("TRAINED_WAKE_WORDS_DIR", str(DATA_DIR / "trained_wake_words"))
|
||||
).resolve()
|
||||
|
||||
CLI_DIR = Path(os.environ.get("CLI_DIR", str(ROOT_DIR / "cli"))).resolve()
|
||||
PIPER_ROOT = DATA_DIR / "tools" / "piper-sample-generator"
|
||||
PIPER_VOICES_DIR = PIPER_ROOT / "voices"
|
||||
@@ -169,11 +170,58 @@ FIRMWARE_LOCK = threading.Lock()
|
||||
FIRMWARE_SESSIONS: Dict[str, Dict[str, Any]] = {}
|
||||
ANSI_ESCAPE_RE = re.compile(r"\x1B(?:\[[0-?]*[ -/]*[@-~]|[@-Z\\-_])")
|
||||
|
||||
# --- Silero VAD (lazy-loaded) ---
# Process-wide cache filled in by _load_silero_vad() on first use, so that
# importing this module does not import torch / silero_vad.
_silero_vad_model = None
_silero_vad_utils = None
# Serializes the one-time model load across threads.
_SILERO_VAD_LOCK = threading.Lock()
# Seconds of padding added before / after the VAD-selected speech segment.
VAD_SELECTION_PAD_START_S = 0.08
VAD_SELECTION_PAD_END_S = 0.08


def _load_silero_vad():
    """Lazy-load the Silero VAD model on first use.

    Returns:
        A ``(model, utils)`` tuple. ``utils`` is a dict whose ``"torch"``
        entry is the imported ``torch`` module.

    The first caller imports ``torch`` and ``silero_vad`` and loads the model
    while holding ``_SILERO_VAD_LOCK``; later callers return the cached pair
    without taking the lock.
    """
    global _silero_vad_model, _silero_vad_utils
    # Fast path: no lock once the model has been published.
    if _silero_vad_model is not None:
        return _silero_vad_model, _silero_vad_utils
    with _SILERO_VAD_LOCK:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting.
        if _silero_vad_model is not None:
            return _silero_vad_model, _silero_vad_utils
        import torch
        import silero_vad

        model = silero_vad.load_silero_vad()
        model.eval()
        utils = {"torch": torch}
        # Publish utils BEFORE the model. The lock-free fast path treats a
        # non-None model as "everything is ready", so the model must be the
        # last global written; otherwise a concurrent caller could observe
        # the model together with utils=None.
        _silero_vad_utils = utils
        _silero_vad_model = model
        return model, utils
|
||||
|
||||
|
||||
def _detect_speech_segments(wav_bytes: bytes) -> List[Dict[str, float]]:
    """Run Silero VAD on 16 kHz mono 16-bit PCM WAV bytes.

    Args:
        wav_bytes: Complete contents of a WAV file.

    Returns:
        One ``{"start": seconds, "end": seconds}`` dict per detected speech
        segment, each value rounded to 3 decimals. Empty when the file holds
        no audio frames.

    Raises:
        ValueError: If the WAV is not 16 kHz, mono, 16-bit. The samples are
            reinterpreted as int16 at a fixed 16 kHz rate below, so any other
            layout would yield meaningless timestamps instead of an error.
        wave.Error: If ``wav_bytes`` is not a readable WAV file.
    """
    # Check the header before loading the model: it is cheap and lets bad
    # input fail without importing torch.
    with wave.open(io.BytesIO(wav_bytes), "rb") as wf:
        channels = wf.getnchannels()
        sample_width = wf.getsampwidth()
        sample_rate = wf.getframerate()
        raw = wf.readframes(wf.getnframes())

    if channels != 1 or sample_width != 2 or sample_rate != 16000:
        raise ValueError(
            "VAD requires 16 kHz mono 16-bit PCM WAV audio "
            f"(got {sample_rate} Hz, {channels} channel(s), {sample_width * 8}-bit)."
        )
    if not raw:
        # Nothing to analyse; skip the model entirely.
        return []

    model, utils = _load_silero_vad()
    torch = utils["torch"]
    import numpy as np
    from silero_vad.utils_vad import get_speech_timestamps

    # Scale int16 PCM to float32 in [-1.0, 1.0).
    samples = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
    audio_tensor = torch.from_numpy(samples)

    timestamps = get_speech_timestamps(
        audio_tensor,
        model,
        sampling_rate=16000,
        threshold=0.5,
        min_speech_duration_ms=150,
        min_silence_duration_ms=100,
        return_seconds=True,
    )
    return [{"start": round(ts["start"], 3), "end": round(ts["end"], 3)} for ts in timestamps]
|
||||
|
||||
|
||||
class _FirmwareYamlLoader(yaml.SafeLoader):
    """Module-private ``yaml.SafeLoader`` subclass that adds no members itself.

    NOTE(review): presumably exists so firmware-specific constructors can be
    registered on it without altering ``yaml.SafeLoader`` -- confirm.
    """
|
||||
|
||||
|
||||
class _FirmwareYamlDumper(yaml.SafeDumper):
    """Module-private ``yaml.SafeDumper`` subclass that adds no members itself.

    NOTE(review): presumably exists so firmware-specific representers can be
    registered on it without altering ``yaml.SafeDumper`` -- confirm.
    """
|
||||
|
||||
@@ -1009,7 +1057,7 @@ def _list_captured_items() -> List[Dict[str, Any]]:
|
||||
def _sample_item_from_path(audio_path: Path, bucket: str) -> Dict[str, Any]:
|
||||
meta = _load_sidecar_json(audio_path)
|
||||
stat = audio_path.stat()
|
||||
final_format = meta.get("final_format") or _inspect_wav_bytes(audio_path.read_bytes()) or {}
|
||||
final_format = meta.get("final_format") or meta.get("detected_format") or _inspect_wav_bytes(audio_path.read_bytes()) or {}
|
||||
return {
|
||||
"bucket": bucket,
|
||||
"saved_as": audio_path.name,
|
||||
@@ -1021,6 +1069,8 @@ def _sample_item_from_path(audio_path: Path, bucket: str) -> Dict[str, Any]:
|
||||
"reviewed_at": meta.get("reviewed_at") or "",
|
||||
"created_at": datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat(),
|
||||
"converted": bool(meta.get("converted")),
|
||||
"trimmed": bool(meta.get("trimmed")),
|
||||
"source_file": meta.get("source_file") or "",
|
||||
"final_format": final_format,
|
||||
"message": meta.get("message") or "",
|
||||
"size_bytes": stat.st_size,
|
||||
@@ -1036,9 +1086,10 @@ def _list_sample_items(directory: Path, bucket: str) -> List[Dict[str, Any]]:
|
||||
items.append(_sample_item_from_path(audio_path, bucket))
|
||||
except Exception:
|
||||
continue
|
||||
# Untrimmed first (stable sort preserves mtime order within each group).
|
||||
items.sort(key=lambda x: x.get("trimmed", False))
|
||||
return items
|
||||
|
||||
|
||||
def _samples_payload() -> Dict[str, Any]:
|
||||
takes = _sync_personal_samples_state()
|
||||
personal_items = _list_sample_items(PERSONAL_DIR, "personal")
|
||||
@@ -2768,7 +2819,143 @@ def delete_sample(bucket: str, file_name: str):
|
||||
_remove_audio_with_sidecar(path)
|
||||
except FileNotFoundError as e:
|
||||
return JSONResponse({"ok": False, "error": str(e)}, status_code=404)
|
||||
return _samples_payload()
|
||||
return {"ok": True, "deleted_bucket": bucket, "deleted_file": file_name, "message": f"Deleted {file_name}"}
|
||||
|
||||
|
||||
@app.post("/api/samples/{bucket}/{file_name}/vad")
def vad_segments(bucket: str, file_name: str):
    """Suggest a trim selection for a stored sample using voice-activity detection.

    Looks the sample up in the ``personal`` or ``negative`` bucket, runs
    ``_detect_speech_segments`` on its bytes and returns at most one segment:
    the first detected one lasting at least 0.25 s, widened by
    ``VAD_SELECTION_PAD_START_S`` / ``VAD_SELECTION_PAD_END_S`` and clamped to
    the file's bounds.

    Returns:
        A dict with ``ok``, ``file_name``, ``segments`` and ``segment_count``,
        or a ``JSONResponse`` with status 404 (unknown bucket / missing file)
        or 500 (VAD raised).
    """
    sample_dirs = {"personal": PERSONAL_DIR, "negative": NEGATIVE_DIR}
    if bucket not in sample_dirs:
        return JSONResponse({"ok": False, "error": "Unknown sample bucket."}, status_code=404)

    try:
        audio_path = _resolve_audio_path(sample_dirs[bucket], file_name)
    except FileNotFoundError as e:
        return JSONResponse({"ok": False, "error": str(e)}, status_code=404)

    payload = audio_path.read_bytes()
    try:
        detected = _detect_speech_segments(payload)
    except Exception as e:
        return JSONResponse({"ok": False, "error": f"VAD failed: {str(e)}"}, status_code=500)

    # Keep only segments of at least 250 ms; the earliest one becomes the
    # suggested selection.
    long_enough = [seg for seg in detected if (seg["end"] - seg["start"]) >= 0.25]
    if not long_enough:
        return {"ok": True, "file_name": file_name, "segments": [], "segment_count": 0}
    first = long_enough[0]

    wav_info = _inspect_wav_bytes(payload) or {}
    total_s = float(wav_info.get("duration_s") or 0.0)

    # Deterministic padding so the suggestion does not clip quiet wake-word
    # edges; the start is clamped at 0 and the end at the file duration
    # whenever that duration is known.
    padded_start = max(0.0, round(first["start"] - VAD_SELECTION_PAD_START_S, 3))
    padded_end = round(first["end"] + VAD_SELECTION_PAD_END_S, 3)
    if total_s > 0:
        padded_end = min(total_s, padded_end)
    if padded_end <= padded_start:
        # Never hand back an empty or inverted selection.
        padded_end = padded_start + 0.001

    return {
        "ok": True,
        "file_name": file_name,
        "segments": [{"start": padded_start, "end": padded_end}],
        "segment_count": 1,
    }
|
||||
|
||||
|
||||
@app.post("/api/samples/trim")
async def trim_sample_upload(
    file: UploadFile = File(...),
    bucket: str = Form(...),
    source_file: str = Form(...),
    start_time: str | None = Form(None),
    end_time: str | None = Form(None),
):
    """Replace a stored sample with an uploaded, already-trimmed version.

    The upload is normalized to the target WAV format when needed. The current
    file (and its sidecar, if any) is copied into ``TRIM_HISTORY_DIR`` before
    being overwritten, and the backup's name is stored in the sidecar under
    ``undo_backup_file``.

    Args:
        file: Trimmed audio uploaded by the client.
        bucket: ``"personal"`` or ``"negative"``.
        source_file: Name of the existing sample to overwrite.
        start_time: Optional trim start in seconds, recorded in the sidecar.
        end_time: Optional trim end in seconds, recorded in the sidecar.

    Returns:
        ``{"ok": True, "updated_sample": ..., "message": ...}`` on success, or
        a ``JSONResponse`` with status 404 (unknown bucket / missing sample)
        or 400 (empty upload, failed normalization, invalid trim bounds).
    """
    import math

    bucket_map = {"personal": PERSONAL_DIR, "negative": NEGATIVE_DIR}
    directory = bucket_map.get(bucket)
    if directory is None:
        return JSONResponse({"ok": False, "error": "Unknown sample bucket."}, status_code=404)

    data = await file.read()
    if not data:
        return JSONResponse({"ok": False, "error": "Empty audio file."}, status_code=400)

    # Anything that is not already a target-format WAV (including bytes that
    # cannot be inspected as WAV at all) goes through normalization.
    info = _inspect_wav_bytes(data)
    if not info or not _is_target_wav(info):
        try:
            data = _normalize_audio_to_target_wav(data, file.filename or "trimmed.wav")
        except Exception as e:
            return JSONResponse({"ok": False, "error": f"Audio normalization failed: {e}"}, status_code=400)

    try:
        orig_path = _resolve_audio_path(directory, source_file)
    except FileNotFoundError as e:
        return JSONResponse({"ok": False, "error": str(e)}, status_code=404)

    # Validate the trim bounds BEFORE touching disk. Parsing them after the
    # overwrite would, on bad input, leave the sample replaced with no
    # undo_backup_file recorded in its sidecar.
    trim_bounds: Dict[str, Any] = {}
    for field_name, raw_value in (("start_time", start_time), ("end_time", end_time)):
        if not raw_value:
            trim_bounds[field_name] = None
            continue
        try:
            parsed = float(raw_value)
        except ValueError:
            parsed = math.nan
        if not math.isfinite(parsed):
            return JSONResponse(
                {"ok": False, "error": f"Invalid {field_name}: {raw_value!r}"},
                status_code=400,
            )
        trim_bounds[field_name] = parsed

    TRIM_HISTORY_DIR.mkdir(parents=True, exist_ok=True)
    ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S%f")
    # Build the backup name from the resolved file's own name rather than the
    # raw form value, so it can never contain path separators.
    backup_name = f"{ts}_{orig_path.name}"
    backup_path = TRIM_HISTORY_DIR / backup_name
    shutil.copy2(orig_path, backup_path)

    orig_sidecar = _audio_sidecar_path(orig_path)
    if orig_sidecar.exists():
        shutil.copy2(orig_sidecar, _audio_sidecar_path(backup_path))

    orig_path.write_bytes(data)

    # NOTE(review): keys already in the sidecar (e.g. "final_format") are
    # carried over unchanged -- confirm none of them describe the pre-trim
    # audio in a way that is now stale.
    old_sidecar = _load_sidecar_json(orig_path)
    sidecar = {
        **old_sidecar,
        "trimmed": True,
        "source_file": source_file,
        "source_bucket": bucket,
        "trim_start_s": trim_bounds["start_time"],
        "trim_end_s": trim_bounds["end_time"],
        "undo_backup_file": backup_name,
    }
    _write_sidecar_json(orig_path, sidecar)

    updated_item = _sample_item_from_path(orig_path, bucket)
    updated_item["trimmed"] = True
    updated_item["source_file"] = source_file
    return {"ok": True, "updated_sample": updated_item, "message": f"Trimmed {source_file}"}
|
||||
|
||||
|
||||
@app.post("/api/samples/revert")
def revert_trim(
    bucket: str = Form(...),
    file_name: str = Form(...),
):
    """Undo the most recent trim of a sample by restoring its backup.

    Reads ``undo_backup_file`` from the sample's sidecar, copies that backup
    (and its sidecar, when one exists) from ``TRIM_HISTORY_DIR`` back over the
    sample, then deletes the backup.

    Args:
        bucket: ``"personal"`` or ``"negative"``.
        file_name: Name of the sample to restore.

    Returns:
        ``{"ok": True, "updated_sample": ..., "message": ...}`` on success, or
        a ``JSONResponse`` with status 404 (unknown bucket, missing sample,
        missing backup) or 400 (no backup recorded for this sample).
    """
    bucket_map = {"personal": PERSONAL_DIR, "negative": NEGATIVE_DIR}
    directory = bucket_map.get(bucket)
    if directory is None:
        return JSONResponse({"ok": False, "error": "Unknown sample bucket."}, status_code=404)

    try:
        file_path = _resolve_audio_path(directory, file_name)
    except FileNotFoundError as e:
        return JSONResponse({"ok": False, "error": str(e)}, status_code=404)

    sidecar = _load_sidecar_json(file_path)
    backup_name = sidecar.get("undo_backup_file")
    if not backup_name:
        return JSONResponse({"ok": False, "error": "No trim backup found for this sample."}, status_code=400)

    # Use only the final path component so a hand-edited sidecar cannot point
    # outside TRIM_HISTORY_DIR, and require a regular file (not a directory).
    backup_path = TRIM_HISTORY_DIR / Path(str(backup_name)).name
    if not backup_path.is_file():
        return JSONResponse({"ok": False, "error": "Trim backup file missing."}, status_code=404)

    shutil.copy2(backup_path, file_path)

    backup_sidecar = _audio_sidecar_path(backup_path)
    live_sidecar = _audio_sidecar_path(file_path)
    if backup_sidecar.exists():
        shutil.copy2(backup_sidecar, live_sidecar)
        backup_sidecar.unlink()
    else:
        # The pre-trim sample had no sidecar, so the one on disk was written
        # by the trim. Remove the trim bookkeeping; otherwise the restored
        # sample would still report trimmed=True and name a backup that is
        # deleted just below.
        trim_keys = {
            "trimmed",
            "source_file",
            "source_bucket",
            "trim_start_s",
            "trim_end_s",
            "undo_backup_file",
        }
        remaining = {key: value for key, value in sidecar.items() if key not in trim_keys}
        if remaining:
            _write_sidecar_json(file_path, remaining)
        else:
            live_sidecar.unlink(missing_ok=True)

    backup_path.unlink()

    updated_item = _sample_item_from_path(file_path, bucket)
    return {"ok": True, "updated_sample": updated_item, "message": f"Reverted {file_name}"}
|
||||
|
||||
|
||||
@app.post("/api/captured_audio/{file_name}/approve_personal")
|
||||
|
||||
Reference in New Issue
Block a user