Add support for personal samples

This commit is contained in:
Nicolas Mowen
2025-12-28 13:14:48 -07:00
parent 4dd7503248
commit 53d858e403
3 changed files with 66 additions and 12 deletions

2
.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
personal_samples/*

View File

@@ -463,6 +463,7 @@
"# To improve your model, experiment with these settings and use more sources of\n", "# To improve your model, experiment with these settings and use more sources of\n",
"# background clips.\n", "# background clips.\n",
"import sys, os\n", "import sys, os\n",
"from pathlib import Path\n",
"\n", "\n",
"# try the common places weve used\n", "# try the common places weve used\n",
"candidates = [\n", "candidates = [\n",
@@ -498,7 +499,8 @@
"if not validate_directories(impulse_paths + background_paths):\n", "if not validate_directories(impulse_paths + background_paths):\n",
" raise ValueError(\"One or more required directories are missing.\")\n", " raise ValueError(\"One or more required directories are missing.\")\n",
"\n", "\n",
"clips = Clips(\n", "# Process TTS generated samples (default)\n",
"clips_tts = Clips(\n",
" input_directory='./generated_samples',\n", " input_directory='./generated_samples',\n",
" file_pattern='*.wav',\n", " file_pattern='*.wav',\n",
" max_clip_duration_s=5,\n", " max_clip_duration_s=5,\n",
@@ -507,6 +509,19 @@
" split_count=0.1,\n", " split_count=0.1,\n",
")\n", ")\n",
"\n", "\n",
"# Process personal recordings if available (optional)\n",
"clips_personal = None\n",
"if os.path.exists(\"./personal_samples\") and any(Path(\"./personal_samples\").glob(\"*.wav\")):\n",
" clips_personal = Clips(\n",
" input_directory=\"./personal_samples\",\n",
" file_pattern=\"*.wav\",\n",
" max_clip_duration_s=5,\n",
" remove_silence=True,\n",
" random_split_seed=10,\n",
" split_count=0.1,\n",
" )\n",
" print(\"✅ Found personal samples, will create separate feature set\")\n",
"\n",
"augmenter = Augmentation(\n", "augmenter = Augmentation(\n",
" augmentation_duration_s=3.2,\n", " augmentation_duration_s=3.2,\n",
" augmentation_probabilities={\n", " augmentation_probabilities={\n",
@@ -593,14 +608,14 @@
"from microwakeword.audio.spectrograms import SpectrogramGeneration\n", "from microwakeword.audio.spectrograms import SpectrogramGeneration\n",
"\n", "\n",
"# ---- Patch: drive clips from generated_samples/*.wav (no datasets.Audio, no torchcodec) ----\n", "# ---- Patch: drive clips from generated_samples/*.wav (no datasets.Audio, no torchcodec) ----\n",
"def audio_generator_from_wavs(self, split=\"train\", repeat=1):\n", "def audio_generator_from_wavs(self, split=\"train\", repeat=1, source_dir=\"generated_samples\"):\n",
" \"\"\"\n", " \"\"\"\n",
" Yield 1-D float32 arrays loaded via librosa from generated_samples/*.wav.\n", " Yield 1-D float32 arrays loaded via librosa from source_dir/*.wav.\n",
" Deterministic 80/10/10 split with seed 10 to mirror original Clips behavior.\n", " Deterministic 80/10/10 split with seed 10 to mirror original Clips behavior.\n",
" \"\"\"\n", " \"\"\"\n",
" files = sorted(glob.glob(\"generated_samples/*.wav\"))\n", " files = sorted(glob.glob(f\"{source_dir}/*.wav\"))\n",
" if not files:\n", " if not files:\n",
" raise SystemExit(\"❌ No WAVs in generated_samples/. Generate TTS samples first.\")\n", " raise SystemExit(f\"❌ No WAVs in {source_dir}/. Generate samples first.\")\n",
"\n", "\n",
" rng = random.Random(10) # deterministic shuffling like Clips(random_split_seed=10)\n", " rng = random.Random(10) # deterministic shuffling like Clips(random_split_seed=10)\n",
" files_shuf = files[:]\n", " files_shuf = files[:]\n",
@@ -624,9 +639,19 @@
" y, sr = librosa.load(p, sr=16000, mono=True)\n", " y, sr = librosa.load(p, sr=16000, mono=True)\n",
" yield y.astype(np.float32, copy=False)\n", " yield y.astype(np.float32, copy=False)\n",
"\n", "\n",
"# Bind the patched generator to your existing `clips` instance\n", "# Bind the patched generator to clips_tts instance\n",
"clips.audio_generator = types.MethodType(audio_generator_from_wavs, clips)\n", "def audio_generator_tts(self, split=\"train\", repeat=1):\n",
"print(\"✅ Patched clips.audio_generator to stream from generated_samples/*.wav (no torchcodec).\")\n", " return audio_generator_from_wavs(self, split, repeat, \"generated_samples\")\n",
"\n",
"clips_tts.audio_generator = types.MethodType(audio_generator_tts, clips_tts)\n",
"print(\"✅ Patched clips_tts.audio_generator to stream from generated_samples/*.wav (no torchcodec).\")\n",
"\n",
"# Bind the patched generator to clips_personal if it exists\n",
"if clips_personal is not None:\n",
" def audio_generator_personal(self, split=\"train\", repeat=1):\n",
" return audio_generator_from_wavs(self, split, repeat, \"personal_samples\")\n",
" clips_personal.audio_generator = types.MethodType(audio_generator_personal, clips_personal)\n",
" print(\"✅ Patched clips_personal.audio_generator to stream from personal_samples/*.wav (no torchcodec).\")\n",
"\n", "\n",
"# ---- Validate augmentation asset folders exist ----\n", "# ---- Validate augmentation asset folders exist ----\n",
"def validate(paths):\n", "def validate(paths):\n",
@@ -649,14 +674,14 @@
" \"testing\": {\"name\": \"test\", \"repetition\": 1, \"slide_frames\": 1},\n", " \"testing\": {\"name\": \"test\", \"repetition\": 1, \"slide_frames\": 1},\n",
"}\n", "}\n",
"\n", "\n",
"# ---- Generate features ----\n", "# ---- Generate features for TTS samples ----\n",
"for split, cfg in split_cfg.items():\n", "for split, cfg in split_cfg.items():\n",
" out_dir = out_root / split\n", " out_dir = out_root / split\n",
" out_dir.mkdir(parents=True, exist_ok=True)\n", " out_dir.mkdir(parents=True, exist_ok=True)\n",
" print(f\"🧪 Processing {split} …\")\n", " print(f\"🧪 Processing {split} (TTS) …\")\n",
"\n", "\n",
" spectros = SpectrogramGeneration(\n", " spectros = SpectrogramGeneration(\n",
" clips=clips, # now backed by our WAV loader\n", " clips=clips_tts, # now backed by our WAV loader\n",
" augmenter=augmenter, # your existing augmenter\n", " augmenter=augmenter, # your existing augmenter\n",
" slide_frames=cfg[\"slide_frames\"],\n", " slide_frames=cfg[\"slide_frames\"],\n",
" step_ms=10,\n", " step_ms=10,\n",
@@ -671,6 +696,27 @@
" verbose=True,\n", " verbose=True,\n",
" )\n", " )\n",
"\n", "\n",
"# ---- Generate features for personal samples if available ----\n",
"if clips_personal is not None:\n",
" out_root_personal = Path(\"personal_augmented_features\")\n",
" out_root_personal.mkdir(exist_ok=True)\n",
" for split, cfg in split_cfg.items():\n",
" out_dir = out_root_personal / split\n",
" out_dir.mkdir(parents=True, exist_ok=True)\n",
" print(f\"🧪 Processing {split} (personal) …\")\n",
" spectros = SpectrogramGeneration(\n",
" clips=clips_personal,\n",
" augmenter=augmenter,\n",
" slide_frames=cfg[\"slide_frames\"],\n",
" step_ms=10,\n",
" )\n",
" RaggedMmap.from_generator(\n",
" out_dir=str(out_dir / \"wakeword_mmap\"),\n",
" sample_generator=spectros.spectrogram_generator(split=cfg[\"name\"], repeat=cfg[\"repetition\"]),\n",
" batch_size=100,\n",
" verbose=True,\n",
" )\n",
"\n",
"print(\"✅ Features ready (generated_augmented_features/*/wakeword_mmap)\")" "print(\"✅ Features ready (generated_augmented_features/*/wakeword_mmap)\")"
] ]
}, },
@@ -752,6 +798,7 @@
"# --- Save a yaml config that controls the training process ---\n", "# --- Save a yaml config that controls the training process ---\n",
"\n", "\n",
"import os, sys, yaml\n", "import os, sys, yaml\n",
"from pathlib import Path\n",
"\n", "\n",
"config = {}\n", "config = {}\n",
"\n", "\n",
@@ -766,6 +813,11 @@
" {\"features_dir\":\"negative_datasets/dinner_party_eval\",\"sampling_weight\":0.0,\"penalty_weight\":1.0,\"truth\":False,\"truncation_strategy\":\"split\",\"type\":\"mmap\"},\n", " {\"features_dir\":\"negative_datasets/dinner_party_eval\",\"sampling_weight\":0.0,\"penalty_weight\":1.0,\"truth\":False,\"truncation_strategy\":\"split\",\"type\":\"mmap\"},\n",
"]\n", "]\n",
"\n", "\n",
"# Add personal features if they exist\n",
"if os.path.exists(\"personal_augmented_features/training\"):\n",
" config[\"features\"].insert(1, {\"features_dir\": \"personal_augmented_features\", \"sampling_weight\": 3.0, \"penalty_weight\": 1.0, \"truth\": True, \"truncation_strategy\": \"truncate_start\", \"type\": \"mmap\"})\n",
" print(\"✅ Added personal features with higher weight (3.0)\")\n",
"\n",
"config[\"training_steps\"] = [40000]\n", "config[\"training_steps\"] = [40000]\n",
"config[\"positive_class_weight\"] = [1]\n", "config[\"positive_class_weight\"] = [1]\n",
"config[\"negative_class_weight\"] = [20]\n", "config[\"negative_class_weight\"] = [20]\n",

View File

@@ -8,7 +8,7 @@ umask 002
NOTEBOOK_SRC="/root/microWakeWord_training_notebook.ipynb" NOTEBOOK_SRC="/root/microWakeWord_training_notebook.ipynb"
NOTEBOOK_DST="/data/microWakeWord_training_notebook.ipynb" NOTEBOOK_DST="/data/microWakeWord_training_notebook.ipynb"
mkdir -p /data /data/generated_samples mkdir -p /data /data/generated_samples /data/personal_samples
if [[ ! -f "$NOTEBOOK_DST" ]]; then if [[ ! -f "$NOTEBOOK_DST" ]]; then
echo "No training notebook found in /data; copying default…" echo "No training notebook found in /data; copying default…"