mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
Add support for personal samples
This commit is contained in:
@@ -463,6 +463,7 @@
|
||||
"# To improve your model, experiment with these settings and use more sources of\n",
|
||||
"# background clips.\n",
|
||||
"import sys, os\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"# try the common places we’ve used\n",
|
||||
"candidates = [\n",
|
||||
@@ -498,7 +499,8 @@
|
||||
"if not validate_directories(impulse_paths + background_paths):\n",
|
||||
" raise ValueError(\"One or more required directories are missing.\")\n",
|
||||
"\n",
|
||||
"clips = Clips(\n",
|
||||
"# Process TTS generated samples (default)\n",
|
||||
"clips_tts = Clips(\n",
|
||||
" input_directory='./generated_samples',\n",
|
||||
" file_pattern='*.wav',\n",
|
||||
" max_clip_duration_s=5,\n",
|
||||
@@ -507,6 +509,19 @@
|
||||
" split_count=0.1,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Process personal recordings if available (optional)\n",
|
||||
"clips_personal = None\n",
|
||||
"if os.path.exists(\"./personal_samples\") and any(Path(\"./personal_samples\").glob(\"*.wav\")):\n",
|
||||
" clips_personal = Clips(\n",
|
||||
" input_directory=\"./personal_samples\",\n",
|
||||
" file_pattern=\"*.wav\",\n",
|
||||
" max_clip_duration_s=5,\n",
|
||||
" remove_silence=True,\n",
|
||||
" random_split_seed=10,\n",
|
||||
" split_count=0.1,\n",
|
||||
" )\n",
|
||||
" print(\"✅ Found personal samples, will create separate feature set\")\n",
|
||||
"\n",
|
||||
"augmenter = Augmentation(\n",
|
||||
" augmentation_duration_s=3.2,\n",
|
||||
" augmentation_probabilities={\n",
|
||||
@@ -593,14 +608,14 @@
|
||||
"from microwakeword.audio.spectrograms import SpectrogramGeneration\n",
|
||||
"\n",
|
||||
"# ---- Patch: drive clips from <source_dir>/*.wav (default: generated_samples; no datasets.Audio, no torchcodec) ----\n",
|
||||
"def audio_generator_from_wavs(self, split=\"train\", repeat=1):\n",
|
||||
"def audio_generator_from_wavs(self, split=\"train\", repeat=1, source_dir=\"generated_samples\"):\n",
|
||||
" \"\"\"\n",
|
||||
" Yield 1-D float32 arrays loaded via librosa from generated_samples/*.wav.\n",
|
||||
" Yield 1-D float32 arrays loaded via librosa from source_dir/*.wav.\n",
|
||||
" Deterministic 80/10/10 split with seed 10 to mirror original Clips behavior.\n",
|
||||
" \"\"\"\n",
|
||||
" files = sorted(glob.glob(\"generated_samples/*.wav\"))\n",
|
||||
" files = sorted(glob.glob(f\"{source_dir}/*.wav\"))\n",
|
||||
" if not files:\n",
|
||||
" raise SystemExit(\"❌ No WAVs in generated_samples/. Generate TTS samples first.\")\n",
|
||||
" raise SystemExit(f\"❌ No WAVs in {source_dir}/. Generate samples first.\")\n",
|
||||
"\n",
|
||||
" rng = random.Random(10) # deterministic shuffling like Clips(random_split_seed=10)\n",
|
||||
" files_shuf = files[:]\n",
|
||||
@@ -624,9 +639,19 @@
|
||||
" y, sr = librosa.load(p, sr=16000, mono=True)\n",
|
||||
" yield y.astype(np.float32, copy=False)\n",
|
||||
"\n",
|
||||
"# Bind the patched generator to your existing `clips` instance\n",
|
||||
"clips.audio_generator = types.MethodType(audio_generator_from_wavs, clips)\n",
|
||||
"print(\"✅ Patched clips.audio_generator to stream from generated_samples/*.wav (no torchcodec).\")\n",
|
||||
"# Bind the patched generator to clips_tts instance\n",
|
||||
"def audio_generator_tts(self, split=\"train\", repeat=1):\n",
|
||||
" return audio_generator_from_wavs(self, split, repeat, \"generated_samples\")\n",
|
||||
"\n",
|
||||
"clips_tts.audio_generator = types.MethodType(audio_generator_tts, clips_tts)\n",
|
||||
"print(\"✅ Patched clips_tts.audio_generator to stream from generated_samples/*.wav (no torchcodec).\")\n",
|
||||
"\n",
|
||||
"# Bind the patched generator to clips_personal if it exists\n",
|
||||
"if clips_personal is not None:\n",
|
||||
" def audio_generator_personal(self, split=\"train\", repeat=1):\n",
|
||||
" return audio_generator_from_wavs(self, split, repeat, \"personal_samples\")\n",
|
||||
" clips_personal.audio_generator = types.MethodType(audio_generator_personal, clips_personal)\n",
|
||||
" print(\"✅ Patched clips_personal.audio_generator to stream from personal_samples/*.wav (no torchcodec).\")\n",
|
||||
"\n",
|
||||
"# ---- Validate augmentation asset folders exist ----\n",
|
||||
"def validate(paths):\n",
|
||||
@@ -649,14 +674,14 @@
|
||||
" \"testing\": {\"name\": \"test\", \"repetition\": 1, \"slide_frames\": 1},\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# ---- Generate features ----\n",
|
||||
"# ---- Generate features for TTS samples ----\n",
|
||||
"for split, cfg in split_cfg.items():\n",
|
||||
" out_dir = out_root / split\n",
|
||||
" out_dir.mkdir(parents=True, exist_ok=True)\n",
|
||||
" print(f\"🧪 Processing {split} …\")\n",
|
||||
" print(f\"🧪 Processing {split} (TTS) …\")\n",
|
||||
"\n",
|
||||
" spectros = SpectrogramGeneration(\n",
|
||||
" clips=clips, # now backed by our WAV loader\n",
|
||||
" clips=clips_tts, # now backed by our WAV loader\n",
|
||||
" augmenter=augmenter, # your existing augmenter\n",
|
||||
" slide_frames=cfg[\"slide_frames\"],\n",
|
||||
" step_ms=10,\n",
|
||||
@@ -671,6 +696,27 @@
|
||||
" verbose=True,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"# ---- Generate features for personal samples if available ----\n",
|
||||
"if clips_personal is not None:\n",
|
||||
" out_root_personal = Path(\"personal_augmented_features\")\n",
|
||||
" out_root_personal.mkdir(exist_ok=True)\n",
|
||||
" for split, cfg in split_cfg.items():\n",
|
||||
" out_dir = out_root_personal / split\n",
|
||||
" out_dir.mkdir(parents=True, exist_ok=True)\n",
|
||||
" print(f\"🧪 Processing {split} (personal) …\")\n",
|
||||
" spectros = SpectrogramGeneration(\n",
|
||||
" clips=clips_personal,\n",
|
||||
" augmenter=augmenter,\n",
|
||||
" slide_frames=cfg[\"slide_frames\"],\n",
|
||||
" step_ms=10,\n",
|
||||
" )\n",
|
||||
" RaggedMmap.from_generator(\n",
|
||||
" out_dir=str(out_dir / \"wakeword_mmap\"),\n",
|
||||
" sample_generator=spectros.spectrogram_generator(split=cfg[\"name\"], repeat=cfg[\"repetition\"]),\n",
|
||||
" batch_size=100,\n",
|
||||
" verbose=True,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"print(\"✅ Features ready (generated_augmented_features/*/wakeword_mmap)\")"
|
||||
]
|
||||
},
|
||||
@@ -752,6 +798,7 @@
|
||||
"# --- Save a yaml config that controls the training process ---\n",
|
||||
"\n",
|
||||
"import os, sys, yaml\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"config = {}\n",
|
||||
"\n",
|
||||
@@ -766,6 +813,11 @@
|
||||
" {\"features_dir\":\"negative_datasets/dinner_party_eval\",\"sampling_weight\":0.0,\"penalty_weight\":1.0,\"truth\":False,\"truncation_strategy\":\"split\",\"type\":\"mmap\"},\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"# Add personal features if they exist\n",
|
||||
"if os.path.exists(\"personal_augmented_features/training\"):\n",
|
||||
" config[\"features\"].insert(1, {\"features_dir\": \"personal_augmented_features\", \"sampling_weight\": 3.0, \"penalty_weight\": 1.0, \"truth\": True, \"truncation_strategy\": \"truncate_start\", \"type\": \"mmap\"})\n",
|
||||
" print(\"✅ Added personal features with higher weight (3.0)\")\n",
|
||||
"\n",
|
||||
"config[\"training_steps\"] = [40000]\n",
|
||||
"config[\"positive_class_weight\"] = [1]\n",
|
||||
"config[\"negative_class_weight\"] = [20]\n",
|
||||
|
||||
Reference in New Issue
Block a user