From 4d25fc3051d82590b03d5324202007961309130c Mon Sep 17 00:00:00 2001 From: MasterPhooey <106418429+MasterPhooey@users.noreply.github.com> Date: Sun, 19 Jan 2025 15:45:38 -0600 Subject: [PATCH] Update advanced_training_notebook.ipynb --- advanced_training_notebook.ipynb | 76 ++++++++++++++------------------ 1 file changed, 34 insertions(+), 42 deletions(-) diff --git a/advanced_training_notebook.ipynb b/advanced_training_notebook.ipynb index df9d4e2..88765ac 100644 --- a/advanced_training_notebook.ipynb +++ b/advanced_training_notebook.ipynb @@ -124,7 +124,7 @@ "# wake word samples, possibly with different phonetic pronunciations.\n", "\n", "!\"{sys.executable}\" piper-sample-generator/generate_samples.py \"{target_word}\" \\\n", - "--max-samples 10000 \\\n", + "--max-samples 50000 \\\n", "--batch-size 100 \\\n", "--output-dir generated_samples" ] @@ -246,13 +246,12 @@ " fname = \"fma_xs.zip\"\n", " link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/\" + fname\n", " out_dir = os.path.join(output_dir, fname)\n", - " print(f\"Downloading {fname}...\")\n", - " os.system(f\"wget -q -O {out_dir} {link}\")\n", - " print(f\"Extracting {fname}...\")\n", - " os.system(f\"unzip -q -o {out_dir} -d {output_dir}\")\n", + " os.system(f\"wget -O {out_dir} {link}\")\n", + " os.system(f\"cd {output_dir} && unzip -q {fname}\")\n", "\n", "output_dir = \"./fma_16k\"\n", - "os.makedirs(output_dir, exist_ok=True)\n", + "if not os.path.exists(output_dir):\n", + " os.mkdir(output_dir)\n", "\n", "# Save clips to 16-bit PCM wav files\n", "fma_files = list(Path(\"fma/fma_small\").glob(\"**/*.mp3\"))\n", @@ -263,32 +262,25 @@ "\n", " corrupted_files = []\n", " print(\"Converting FMA files to 16kHz WAV...\")\n", - " for row in tqdm(fma_dataset, desc=\"Processing FMA files\", unit=\"file\"):\n", + " for row in tqdm(fma_dataset):\n", " try:\n", - " name = Path(row[\"audio\"][\"path\"]).stem + \".wav\"\n", - " output_path = Path(output_dir) / name\n", - "\n", - " # Check if audio data is valid before writing\n", - " if row[\"audio\"][\"array\"] is None or len(row[\"audio\"][\"array\"]) == 0:\n", - " raise ValueError(\"Empty or invalid audio data\")\n", - "\n", + " name = row[\"audio\"][\"path\"].split(\"/\")[-1].replace(\".mp3\", \".wav\")\n", " scipy.io.wavfile.write(\n", - " output_path,\n", - " 16000,\n", - " (row[\"audio\"][\"array\"] * 32767).astype(np.int16),\n", + " os.path.join(output_dir, name), \n", + " 16000, \n", + " (row[\"audio\"][\"array\"] * 32767).astype(np.int16)\n", " )\n", " except Exception as e:\n", + " print(f\"Error converting {row['audio']['path']}: {e}\")\n", " corrupted_files.append(row[\"audio\"][\"path\"])\n", "\n", " if corrupted_files:\n", - " log_path = Path(output_dir) / \"fma_corrupted_files.log\"\n", - " with open(log_path, \"w\") as log_file:\n", + " with open(\"fma_corrupted_files.log\", \"w\") as log_file:\n", " log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n", - " print(f\"Logged {len(corrupted_files)} corrupted files to {log_path}\")\n", "else:\n", " print(\"No MP3 files found in FMA.\")\n", "\n", - "print(\"FMA dataset preparation complete!\")" + "print(\"Dataset preparation complete!\")" ] }, { @@ -339,16 +331,16 @@ " \"PitchShift\": 0.15,\n", " \"BandStopFilter\": 0.1,\n", " \"AddColorNoise\": 0.1,\n", - " \"AddBackgroundNoise\": 0.9,\n", + " \"AddBackgroundNoise\": 0.7,\n", " \"Gain\": 0.8,\n", " \"RIR\": 0.7,\n", " },\n", " impulse_paths=impulse_paths,\n", " background_paths=background_paths,\n", - " background_min_snr_db=-10,\n", - " background_max_snr_db=15,\n", - " min_jitter_s=0.15,\n", - " max_jitter_s=0.25,\n", + " background_min_snr_db=5,\n", + " background_max_snr_db=10,\n", + " min_jitter_s=0.2,\n", + " max_jitter_s=0.3,\n", ")\n" ] }, @@ -543,7 +535,7 @@ " },\n", " {\n", " \"features_dir\": \"negative_datasets/speech\",\n", - " \"sampling_weight\": 10.0, # Adjusted\n", + " \"sampling_weight\": 12.0, # Adjusted\n", " \"penalty_weight\": 1.0,\n", " \"truth\": False,\n", " \"truncation_strategy\": \"random\",\n", @@ -551,7 +543,7 @@ " },\n", " {\n", " \"features_dir\": \"negative_datasets/dinner_party\",\n", - " \"sampling_weight\": 10.0, # Adjusted\n", + " \"sampling_weight\": 12.0, # Adjusted\n", " \"penalty_weight\": 1.0,\n", " \"truth\": False,\n", " \"truncation_strategy\": \"random\",\n", @@ -559,7 +551,7 @@ " },\n", " {\n", " \"features_dir\": \"negative_datasets/no_speech\",\n", - " \"sampling_weight\": 7.0, # Balanced\n", + " \"sampling_weight\": 5.0, # Balanced\n", " \"penalty_weight\": 1.0,\n", " \"truth\": False,\n", " \"truncation_strategy\": \"random\",\n", @@ -567,7 +559,7 @@ " },\n", " {\n", " \"features_dir\": \"negative_datasets/dinner_party_eval\",\n", - " \"sampling_weight\": 8.0,\n", + " \"sampling_weight\": 0.0,\n", " \"penalty_weight\": 1.0,\n", " \"truth\": False,\n", " \"truncation_strategy\": \"split\",\n", @@ -575,18 +567,18 @@ " },\n", "]\n", "\n", - "config[\"training_steps\"] = [30000] # Increased\n", - "config[\"positive_class_weight\"] = [2]\n", - "config[\"negative_class_weight\"] = [15] # Adjusted\n", - "config[\"learning_rates\"] = [0.001, 0.0001] # Adjusted\n", - "config[\"batch_size\"] = 256\n", + "config[\"training_steps\"] = [40000] # Increased\n", + "config[\"positive_class_weight\"] = [1]\n", + "config[\"negative_class_weight\"] = [20] # Adjusted\n", + "config[\"learning_rates\"] = [0.001] # Adjusted\n", + "config[\"batch_size\"] = 128\n", "\n", - "config[\"time_mask_max_size\"] = [50] # Enabled SpecAugment\n", - "config[\"time_mask_count\"] = [2]\n", - "config[\"freq_mask_max_size\"] = [5]\n", - "config[\"freq_mask_count\"] = [2]\n", + "config[\"time_mask_max_size\"] = [0] # Enabled SpecAugment\n", + "config[\"time_mask_count\"] = [0]\n", + "config[\"freq_mask_max_size\"] = [0]\n", + "config[\"freq_mask_count\"] = [0]\n", "\n", - "config[\"eval_step_interval\"] = 250 # Adjusted\n", + "config[\"eval_step_interval\"] = 500 # Adjusted\n", "config[\"clip_duration_ms\"] = 1500 # Increased\n", "\n", "config[\"target_minimization\"] = 0.9\n", @@ -632,9 +624,9 @@ "--test_tflite_streaming_quantized 1 \\\n", "--use_weights \"best_weights\" \\\n", "mixednet \\\n", - "--pointwise_filters \"64,64,64,64\" \\\n", + "--pointwise_filters \"80,80,80,80\" \\\n", "--repeat_in_block \"1,1,1,1\" \\\n", - "--mixconv_kernel_sizes '[5], [7,11], [9,15], [23]' \\\n", + "--mixconv_kernel_sizes '[7], [9,13], [11,17], [25]' \\\n", "--residual_connection \"0,0,0,0\" \\\n", "--first_conv_filters 32 \\\n", "--first_conv_kernel_size 5 \\\n",