mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
Update advanced_training_notebook.ipynb
This commit is contained in:
@@ -124,7 +124,7 @@
|
|||||||
"# wake word samples, possibly with different phonetic pronunciations.\n",
|
"# wake word samples, possibly with different phonetic pronunciations.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"!\"{sys.executable}\" piper-sample-generator/generate_samples.py \"{target_word}\" \\\n",
|
"!\"{sys.executable}\" piper-sample-generator/generate_samples.py \"{target_word}\" \\\n",
|
||||||
"--max-samples 10000 \\\n",
|
"--max-samples 50000 \\\n",
|
||||||
"--batch-size 100 \\\n",
|
"--batch-size 100 \\\n",
|
||||||
"--output-dir generated_samples"
|
"--output-dir generated_samples"
|
||||||
]
|
]
|
||||||
@@ -246,13 +246,12 @@
|
|||||||
" fname = \"fma_xs.zip\"\n",
|
" fname = \"fma_xs.zip\"\n",
|
||||||
" link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/\" + fname\n",
|
" link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/\" + fname\n",
|
||||||
" out_dir = os.path.join(output_dir, fname)\n",
|
" out_dir = os.path.join(output_dir, fname)\n",
|
||||||
" print(f\"Downloading {fname}...\")\n",
|
" os.system(f\"wget -O {out_dir} {link}\")\n",
|
||||||
" os.system(f\"wget -q -O {out_dir} {link}\")\n",
|
" os.system(f\"cd {output_dir} && unzip -q {fname}\")\n",
|
||||||
" print(f\"Extracting {fname}...\")\n",
|
|
||||||
" os.system(f\"unzip -q -o {out_dir} -d {output_dir}\")\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"output_dir = \"./fma_16k\"\n",
|
"output_dir = \"./fma_16k\"\n",
|
||||||
"os.makedirs(output_dir, exist_ok=True)\n",
|
"if not os.path.exists(output_dir):\n",
|
||||||
|
" os.mkdir(output_dir)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Save clips to 16-bit PCM wav files\n",
|
"# Save clips to 16-bit PCM wav files\n",
|
||||||
"fma_files = list(Path(\"fma/fma_small\").glob(\"**/*.mp3\"))\n",
|
"fma_files = list(Path(\"fma/fma_small\").glob(\"**/*.mp3\"))\n",
|
||||||
@@ -263,32 +262,25 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" corrupted_files = []\n",
|
" corrupted_files = []\n",
|
||||||
" print(\"Converting FMA files to 16kHz WAV...\")\n",
|
" print(\"Converting FMA files to 16kHz WAV...\")\n",
|
||||||
" for row in tqdm(fma_dataset, desc=\"Processing FMA files\", unit=\"file\"):\n",
|
" for row in tqdm(fma_dataset):\n",
|
||||||
" try:\n",
|
" try:\n",
|
||||||
" name = Path(row[\"audio\"][\"path\"]).stem + \".wav\"\n",
|
" name = row[\"audio\"][\"path\"].split(\"/\")[-1].replace(\".mp3\", \".wav\")\n",
|
||||||
" output_path = Path(output_dir) / name\n",
|
|
||||||
"\n",
|
|
||||||
" # Check if audio data is valid before writing\n",
|
|
||||||
" if row[\"audio\"][\"array\"] is None or len(row[\"audio\"][\"array\"]) == 0:\n",
|
|
||||||
" raise ValueError(\"Empty or invalid audio data\")\n",
|
|
||||||
"\n",
|
|
||||||
" scipy.io.wavfile.write(\n",
|
" scipy.io.wavfile.write(\n",
|
||||||
" output_path,\n",
|
" os.path.join(output_dir, name), \n",
|
||||||
" 16000, \n",
|
" 16000, \n",
|
||||||
" (row[\"audio\"][\"array\"] * 32767).astype(np.int16),\n",
|
" (row[\"audio\"][\"array\"] * 32767).astype(np.int16)\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" except Exception as e:\n",
|
" except Exception as e:\n",
|
||||||
|
" print(f\"Error converting {row['audio']['path']}: {e}\")\n",
|
||||||
" corrupted_files.append(row[\"audio\"][\"path\"])\n",
|
" corrupted_files.append(row[\"audio\"][\"path\"])\n",
|
||||||
"\n",
|
"\n",
|
||||||
" if corrupted_files:\n",
|
" if corrupted_files:\n",
|
||||||
" log_path = Path(output_dir) / \"fma_corrupted_files.log\"\n",
|
" with open(\"fma_corrupted_files.log\", \"w\") as log_file:\n",
|
||||||
" with open(log_path, \"w\") as log_file:\n",
|
|
||||||
" log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n",
|
" log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n",
|
||||||
" print(f\"Logged {len(corrupted_files)} corrupted files to {log_path}\")\n",
|
|
||||||
"else:\n",
|
"else:\n",
|
||||||
" print(\"No MP3 files found in FMA.\")\n",
|
" print(\"No MP3 files found in FMA.\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print(\"FMA dataset preparation complete!\")"
|
"print(\"Dataset preparation complete!\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -339,16 +331,16 @@
|
|||||||
" \"PitchShift\": 0.15,\n",
|
" \"PitchShift\": 0.15,\n",
|
||||||
" \"BandStopFilter\": 0.1,\n",
|
" \"BandStopFilter\": 0.1,\n",
|
||||||
" \"AddColorNoise\": 0.1,\n",
|
" \"AddColorNoise\": 0.1,\n",
|
||||||
" \"AddBackgroundNoise\": 0.9,\n",
|
" \"AddBackgroundNoise\": 0.7,\n",
|
||||||
" \"Gain\": 0.8,\n",
|
" \"Gain\": 0.8,\n",
|
||||||
" \"RIR\": 0.7,\n",
|
" \"RIR\": 0.7,\n",
|
||||||
" },\n",
|
" },\n",
|
||||||
" impulse_paths=impulse_paths,\n",
|
" impulse_paths=impulse_paths,\n",
|
||||||
" background_paths=background_paths,\n",
|
" background_paths=background_paths,\n",
|
||||||
" background_min_snr_db=-10,\n",
|
" background_min_snr_db=5,\n",
|
||||||
" background_max_snr_db=15,\n",
|
" background_max_snr_db=10,\n",
|
||||||
" min_jitter_s=0.15,\n",
|
" min_jitter_s=0.2,\n",
|
||||||
" max_jitter_s=0.25,\n",
|
" max_jitter_s=0.3,\n",
|
||||||
")\n"
|
")\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@@ -543,7 +535,7 @@
|
|||||||
" },\n",
|
" },\n",
|
||||||
" {\n",
|
" {\n",
|
||||||
" \"features_dir\": \"negative_datasets/speech\",\n",
|
" \"features_dir\": \"negative_datasets/speech\",\n",
|
||||||
" \"sampling_weight\": 10.0, # Adjusted\n",
|
" \"sampling_weight\": 12.0, # Adjusted\n",
|
||||||
" \"penalty_weight\": 1.0,\n",
|
" \"penalty_weight\": 1.0,\n",
|
||||||
" \"truth\": False,\n",
|
" \"truth\": False,\n",
|
||||||
" \"truncation_strategy\": \"random\",\n",
|
" \"truncation_strategy\": \"random\",\n",
|
||||||
@@ -551,7 +543,7 @@
|
|||||||
" },\n",
|
" },\n",
|
||||||
" {\n",
|
" {\n",
|
||||||
" \"features_dir\": \"negative_datasets/dinner_party\",\n",
|
" \"features_dir\": \"negative_datasets/dinner_party\",\n",
|
||||||
" \"sampling_weight\": 10.0, # Adjusted\n",
|
" \"sampling_weight\": 12.0, # Adjusted\n",
|
||||||
" \"penalty_weight\": 1.0,\n",
|
" \"penalty_weight\": 1.0,\n",
|
||||||
" \"truth\": False,\n",
|
" \"truth\": False,\n",
|
||||||
" \"truncation_strategy\": \"random\",\n",
|
" \"truncation_strategy\": \"random\",\n",
|
||||||
@@ -559,7 +551,7 @@
|
|||||||
" },\n",
|
" },\n",
|
||||||
" {\n",
|
" {\n",
|
||||||
" \"features_dir\": \"negative_datasets/no_speech\",\n",
|
" \"features_dir\": \"negative_datasets/no_speech\",\n",
|
||||||
" \"sampling_weight\": 7.0, # Balanced\n",
|
" \"sampling_weight\": 5.0, # Balanced\n",
|
||||||
" \"penalty_weight\": 1.0,\n",
|
" \"penalty_weight\": 1.0,\n",
|
||||||
" \"truth\": False,\n",
|
" \"truth\": False,\n",
|
||||||
" \"truncation_strategy\": \"random\",\n",
|
" \"truncation_strategy\": \"random\",\n",
|
||||||
@@ -567,7 +559,7 @@
|
|||||||
" },\n",
|
" },\n",
|
||||||
" {\n",
|
" {\n",
|
||||||
" \"features_dir\": \"negative_datasets/dinner_party_eval\",\n",
|
" \"features_dir\": \"negative_datasets/dinner_party_eval\",\n",
|
||||||
" \"sampling_weight\": 8.0,\n",
|
" \"sampling_weight\": 0.0,\n",
|
||||||
" \"penalty_weight\": 1.0,\n",
|
" \"penalty_weight\": 1.0,\n",
|
||||||
" \"truth\": False,\n",
|
" \"truth\": False,\n",
|
||||||
" \"truncation_strategy\": \"split\",\n",
|
" \"truncation_strategy\": \"split\",\n",
|
||||||
@@ -575,18 +567,18 @@
|
|||||||
" },\n",
|
" },\n",
|
||||||
"]\n",
|
"]\n",
|
||||||
"\n",
|
"\n",
|
||||||
"config[\"training_steps\"] = [30000] # Increased\n",
|
"config[\"training_steps\"] = [40000] # Increased\n",
|
||||||
"config[\"positive_class_weight\"] = [2]\n",
|
"config[\"positive_class_weight\"] = [1]\n",
|
||||||
"config[\"negative_class_weight\"] = [15] # Adjusted\n",
|
"config[\"negative_class_weight\"] = [20] # Adjusted\n",
|
||||||
"config[\"learning_rates\"] = [0.001, 0.0001] # Adjusted\n",
|
"config[\"learning_rates\"] = [0.001] # Adjusted\n",
|
||||||
"config[\"batch_size\"] = 256\n",
|
"config[\"batch_size\"] = 128\n",
|
||||||
"\n",
|
"\n",
|
||||||
"config[\"time_mask_max_size\"] = [50] # Enabled SpecAugment\n",
|
"config[\"time_mask_max_size\"] = [0] # Enabled SpecAugment\n",
|
||||||
"config[\"time_mask_count\"] = [2]\n",
|
"config[\"time_mask_count\"] = [0]\n",
|
||||||
"config[\"freq_mask_max_size\"] = [5]\n",
|
"config[\"freq_mask_max_size\"] = [0]\n",
|
||||||
"config[\"freq_mask_count\"] = [2]\n",
|
"config[\"freq_mask_count\"] = [0]\n",
|
||||||
"\n",
|
"\n",
|
||||||
"config[\"eval_step_interval\"] = 250 # Adjusted\n",
|
"config[\"eval_step_interval\"] = 500 # Adjusted\n",
|
||||||
"config[\"clip_duration_ms\"] = 1500 # Increased\n",
|
"config[\"clip_duration_ms\"] = 1500 # Increased\n",
|
||||||
"\n",
|
"\n",
|
||||||
"config[\"target_minimization\"] = 0.9\n",
|
"config[\"target_minimization\"] = 0.9\n",
|
||||||
@@ -632,9 +624,9 @@
|
|||||||
"--test_tflite_streaming_quantized 1 \\\n",
|
"--test_tflite_streaming_quantized 1 \\\n",
|
||||||
"--use_weights \"best_weights\" \\\n",
|
"--use_weights \"best_weights\" \\\n",
|
||||||
"mixednet \\\n",
|
"mixednet \\\n",
|
||||||
"--pointwise_filters \"64,64,64,64\" \\\n",
|
"--pointwise_filters \"80,80,80,80\" \\\n",
|
||||||
"--repeat_in_block \"1,1,1,1\" \\\n",
|
"--repeat_in_block \"1,1,1,1\" \\\n",
|
||||||
"--mixconv_kernel_sizes '[5], [7,11], [9,15], [23]' \\\n",
|
"--mixconv_kernel_sizes '[7], [9,13], [11,17], [25]' \\\n",
|
||||||
"--residual_connection \"0,0,0,0\" \\\n",
|
"--residual_connection \"0,0,0,0\" \\\n",
|
||||||
"--first_conv_filters 32 \\\n",
|
"--first_conv_filters 32 \\\n",
|
||||||
"--first_conv_kernel_size 5 \\\n",
|
"--first_conv_kernel_size 5 \\\n",
|
||||||
|
|||||||
Reference in New Issue
Block a user