diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..263d9e5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +personal_samples/* + diff --git a/microWakeWord_training_notebook.ipynb b/microWakeWord_training_notebook.ipynb index b848e8b..7470a11 100644 --- a/microWakeWord_training_notebook.ipynb +++ b/microWakeWord_training_notebook.ipynb @@ -463,6 +463,7 @@ "# To improve your model, experiment with these settings and use more sources of\n", "# background clips.\n", "import sys, os\n", + "from pathlib import Path\n", "\n", "# try the common places we’ve used\n", "candidates = [\n", @@ -498,7 +499,8 @@ "if not validate_directories(impulse_paths + background_paths):\n", " raise ValueError(\"One or more required directories are missing.\")\n", "\n", - "clips = Clips(\n", + "# Process TTS generated samples (default)\n", + "clips_tts = Clips(\n", " input_directory='./generated_samples',\n", " file_pattern='*.wav',\n", " max_clip_duration_s=5,\n", @@ -507,6 +509,19 @@ " split_count=0.1,\n", ")\n", "\n", + "# Process personal recordings if available (optional)\n", + "clips_personal = None\n", + "if os.path.exists(\"./personal_samples\") and any(Path(\"./personal_samples\").glob(\"*.wav\")):\n", + " clips_personal = Clips(\n", + " input_directory=\"./personal_samples\",\n", + " file_pattern=\"*.wav\",\n", + " max_clip_duration_s=5,\n", + " remove_silence=True,\n", + " random_split_seed=10,\n", + " split_count=0.1,\n", + " )\n", + " print(\"✅ Found personal samples, will create separate feature set\")\n", + "\n", "augmenter = Augmentation(\n", " augmentation_duration_s=3.2,\n", " augmentation_probabilities={\n", @@ -593,14 +608,14 @@ "from microwakeword.audio.spectrograms import SpectrogramGeneration\n", "\n", "# ---- Patch: drive clips from generated_samples/*.wav (no datasets.Audio, no torchcodec) ----\n", - "def audio_generator_from_wavs(self, split=\"train\", repeat=1):\n", + "def audio_generator_from_wavs(self, split=\"train\", repeat=1, source_dir=\"generated_samples\"):\n", " \"\"\"\n", - " Yield 1-D float32 arrays loaded via librosa from generated_samples/*.wav.\n", + " Yield 1-D float32 arrays loaded via librosa from source_dir/*.wav.\n", " Deterministic 80/10/10 split with seed 10 to mirror original Clips behavior.\n", " \"\"\"\n", - " files = sorted(glob.glob(\"generated_samples/*.wav\"))\n", + " files = sorted(glob.glob(f\"{source_dir}/*.wav\"))\n", " if not files:\n", - " raise SystemExit(\"❌ No WAVs in generated_samples/. Generate TTS samples first.\")\n", + " raise SystemExit(f\"❌ No WAVs in {source_dir}/. Generate samples first.\")\n", "\n", " rng = random.Random(10) # deterministic shuffling like Clips(random_split_seed=10)\n", " files_shuf = files[:]\n", @@ -624,9 +639,19 @@ " y, sr = librosa.load(p, sr=16000, mono=True)\n", " yield y.astype(np.float32, copy=False)\n", "\n", - "# Bind the patched generator to your existing `clips` instance\n", - "clips.audio_generator = types.MethodType(audio_generator_from_wavs, clips)\n", - "print(\"✅ Patched clips.audio_generator to stream from generated_samples/*.wav (no torchcodec).\")\n", + "# Bind the patched generator to clips_tts instance\n", + "def audio_generator_tts(self, split=\"train\", repeat=1):\n", + " return audio_generator_from_wavs(self, split, repeat, \"generated_samples\")\n", + "\n", + "clips_tts.audio_generator = types.MethodType(audio_generator_tts, clips_tts)\n", + "print(\"✅ Patched clips_tts.audio_generator to stream from generated_samples/*.wav (no torchcodec).\")\n", + "\n", + "# Bind the patched generator to clips_personal if it exists\n", + "if clips_personal is not None:\n", + " def audio_generator_personal(self, split=\"train\", repeat=1):\n", + " return audio_generator_from_wavs(self, split, repeat, \"personal_samples\")\n", + " clips_personal.audio_generator = types.MethodType(audio_generator_personal, clips_personal)\n", + " print(\"✅ Patched clips_personal.audio_generator to stream from personal_samples/*.wav (no torchcodec).\")\n", "\n", "# ---- Validate augmentation asset folders exist ----\n", "def validate(paths):\n", @@ -649,14 +674,14 @@ " \"testing\": {\"name\": \"test\", \"repetition\": 1, \"slide_frames\": 1},\n", "}\n", "\n", - "# ---- Generate features ----\n", + "# ---- Generate features for TTS samples ----\n", "for split, cfg in split_cfg.items():\n", " out_dir = out_root / split\n", " out_dir.mkdir(parents=True, exist_ok=True)\n", - " print(f\"🧪 Processing {split} …\")\n", + " print(f\"🧪 Processing {split} (TTS) …\")\n", "\n", " spectros = SpectrogramGeneration(\n", - " clips=clips, # now backed by our WAV loader\n", + " clips=clips_tts, # now backed by our WAV loader\n", " augmenter=augmenter, # your existing augmenter\n", " slide_frames=cfg[\"slide_frames\"],\n", " step_ms=10,\n", @@ -671,6 +696,27 @@ " verbose=True,\n", " )\n", "\n", + "# ---- Generate features for personal samples if available ----\n", + "if clips_personal is not None:\n", + " out_root_personal = Path(\"personal_augmented_features\")\n", + " out_root_personal.mkdir(exist_ok=True)\n", + " for split, cfg in split_cfg.items():\n", + " out_dir = out_root_personal / split\n", + " out_dir.mkdir(parents=True, exist_ok=True)\n", + " print(f\"🧪 Processing {split} (personal) …\")\n", + " spectros = SpectrogramGeneration(\n", + " clips=clips_personal,\n", + " augmenter=augmenter,\n", + " slide_frames=cfg[\"slide_frames\"],\n", + " step_ms=10,\n", + " )\n", + " RaggedMmap.from_generator(\n", + " out_dir=str(out_dir / \"wakeword_mmap\"),\n", + " sample_generator=spectros.spectrogram_generator(split=cfg[\"name\"], repeat=cfg[\"repetition\"]),\n", + " batch_size=100,\n", + " verbose=True,\n", + " )\n", + "\n", "print(\"✅ Features ready (generated_augmented_features/*/wakeword_mmap)\")" ] }, @@ -752,6 +798,7 @@ "# --- Save a yaml config that controls the training process ---\n", "\n", "import os, sys, yaml\n", + "from pathlib import Path\n", "\n", "config = {}\n", "\n", @@ -766,6 +813,11 @@ " {\"features_dir\":\"negative_datasets/dinner_party_eval\",\"sampling_weight\":0.0,\"penalty_weight\":1.0,\"truth\":False,\"truncation_strategy\":\"split\",\"type\":\"mmap\"},\n", "]\n", "\n", + "# Add personal features if they exist\n", + "if os.path.exists(\"personal_augmented_features/training\"):\n", + " config[\"features\"].insert(1, {\"features_dir\": \"personal_augmented_features\", \"sampling_weight\": 3.0, \"penalty_weight\": 1.0, \"truth\": True, \"truncation_strategy\": \"truncate_start\", \"type\": \"mmap\"})\n", + " print(\"✅ Added personal features with higher weight (3.0)\")\n", + "\n", "config[\"training_steps\"] = [40000]\n", "config[\"positive_class_weight\"] = [1]\n", "config[\"negative_class_weight\"] = [20]\n", diff --git a/startup.sh b/startup.sh index 92dab09..bb4f3e2 100644 --- a/startup.sh +++ b/startup.sh @@ -8,7 +8,7 @@ umask 002 NOTEBOOK_SRC="/root/microWakeWord_training_notebook.ipynb" NOTEBOOK_DST="/data/microWakeWord_training_notebook.ipynb" -mkdir -p /data /data/generated_samples +mkdir -p /data /data/generated_samples /data/personal_samples if [[ ! -f "$NOTEBOOK_DST" ]]; then echo "No training notebook found in /data; copying default…"