Update advanced_training_notebook.ipynb

This commit is contained in:
MasterPhooey
2025-01-08 20:57:36 -06:00
committed by GitHub
parent b9be19db0d
commit f2368903d2

View File

@@ -246,12 +246,13 @@
" fname = \"fma_xs.zip\"\n", " fname = \"fma_xs.zip\"\n",
" link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/\" + fname\n", " link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/\" + fname\n",
" out_dir = os.path.join(output_dir, fname)\n", " out_dir = os.path.join(output_dir, fname)\n",
" os.system(f\"wget -O {out_dir} {link}\")\n", " print(f\"Downloading {fname}...\")\n",
" os.system(f\"cd {output_dir} && unzip -q {fname}\")\n", " os.system(f\"wget -q -O {out_dir} {link}\")\n",
" print(f\"Extracting {fname}...\")\n",
" os.system(f\"unzip -q -o {out_dir} -d {output_dir}\")\n",
"\n", "\n",
"output_dir = \"./fma_16k\"\n", "output_dir = \"./fma_16k\"\n",
"if not os.path.exists(output_dir):\n", "os.makedirs(output_dir, exist_ok=True)\n",
" os.mkdir(output_dir)\n",
"\n", "\n",
"# Save clips to 16-bit PCM wav files\n", "# Save clips to 16-bit PCM wav files\n",
"fma_files = list(Path(\"fma/fma_small\").glob(\"**/*.mp3\"))\n", "fma_files = list(Path(\"fma/fma_small\").glob(\"**/*.mp3\"))\n",
@@ -262,25 +263,32 @@
"\n", "\n",
" corrupted_files = []\n", " corrupted_files = []\n",
" print(\"Converting FMA files to 16kHz WAV...\")\n", " print(\"Converting FMA files to 16kHz WAV...\")\n",
" for row in tqdm(fma_dataset):\n", " for row in tqdm(fma_dataset, desc=\"Processing FMA files\", unit=\"file\"):\n",
" try:\n", " try:\n",
" name = row[\"audio\"][\"path\"].split(\"/\")[-1].replace(\".mp3\", \".wav\")\n", " name = Path(row[\"audio\"][\"path\"]).stem + \".wav\"\n",
" output_path = Path(output_dir) / name\n",
"\n",
" # Check if audio data is valid before writing\n",
" if row[\"audio\"][\"array\"] is None or len(row[\"audio\"][\"array\"]) == 0:\n",
" raise ValueError(\"Empty or invalid audio data\")\n",
"\n",
" scipy.io.wavfile.write(\n", " scipy.io.wavfile.write(\n",
" os.path.join(output_dir, name), \n", " output_path,\n",
" 16000,\n", " 16000,\n",
" (row[\"audio\"][\"array\"] * 32767).astype(np.int16)\n", " (row[\"audio\"][\"array\"] * 32767).astype(np.int16),\n",
" )\n", " )\n",
" except Exception as e:\n", " except Exception as e:\n",
" print(f\"Error converting {row['audio']['path']}: {e}\")\n",
" corrupted_files.append(row[\"audio\"][\"path\"])\n", " corrupted_files.append(row[\"audio\"][\"path\"])\n",
"\n", "\n",
" if corrupted_files:\n", " if corrupted_files:\n",
" with open(\"fma_corrupted_files.log\", \"w\") as log_file:\n", " log_path = Path(output_dir) / \"fma_corrupted_files.log\"\n",
" with open(log_path, \"w\") as log_file:\n",
" log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n", " log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n",
" print(f\"Logged {len(corrupted_files)} corrupted files to {log_path}\")\n",
"else:\n", "else:\n",
" print(\"No MP3 files found in FMA.\")\n", " print(\"No MP3 files found in FMA.\")\n",
"\n", "\n",
"print(\"Dataset preparation complete!\")" "print(\"FMA dataset preparation complete!\")"
] ]
}, },
{ {
@@ -551,7 +559,7 @@
" },\n", " },\n",
" {\n", " {\n",
" \"features_dir\": \"negative_datasets/no_speech\",\n", " \"features_dir\": \"negative_datasets/no_speech\",\n",
" \"sampling_weight\": 5.0, # Balanced\n", " \"sampling_weight\": 7.0, # Balanced\n",
" \"penalty_weight\": 1.0,\n", " \"penalty_weight\": 1.0,\n",
" \"truth\": False,\n", " \"truth\": False,\n",
" \"truncation_strategy\": \"random\",\n", " \"truncation_strategy\": \"random\",\n",
@@ -559,7 +567,7 @@
" },\n", " },\n",
" {\n", " {\n",
" \"features_dir\": \"negative_datasets/dinner_party_eval\",\n", " \"features_dir\": \"negative_datasets/dinner_party_eval\",\n",
" \"sampling_weight\": 0.0,\n", " \"sampling_weight\": 8.0,\n",
" \"penalty_weight\": 1.0,\n", " \"penalty_weight\": 1.0,\n",
" \"truth\": False,\n", " \"truth\": False,\n",
" \"truncation_strategy\": \"split\",\n", " \"truncation_strategy\": \"split\",\n",
@@ -567,18 +575,18 @@
" },\n", " },\n",
"]\n", "]\n",
"\n", "\n",
"config[\"training_steps\"] = [20000] # Increased\n", "config[\"training_steps\"] = [30000] # Increased\n",
"config[\"positive_class_weight\"] = [1]\n", "config[\"positive_class_weight\"] = [2]\n",
"config[\"negative_class_weight\"] = [20] # Adjusted\n", "config[\"negative_class_weight\"] = [15] # Adjusted\n",
"config[\"learning_rates\"] = [0.001] # Adjusted\n", "config[\"learning_rates\"] = [0.001, 0.0001] # Adjusted\n",
"config[\"batch_size\"] = 128\n", "config[\"batch_size\"] = 256\n",
"\n", "\n",
"config[\"time_mask_max_size\"] = [0] # Enabled SpecAugment\n", "config[\"time_mask_max_size\"] = [50] # Enabled SpecAugment\n",
"config[\"time_mask_count\"] = [0]\n", "config[\"time_mask_count\"] = [2]\n",
"config[\"freq_mask_max_size\"] = [0]\n", "config[\"freq_mask_max_size\"] = [5]\n",
"config[\"freq_mask_count\"] = [0]\n", "config[\"freq_mask_count\"] = [2]\n",
"\n", "\n",
"config[\"eval_step_interval\"] = 500 # Adjusted\n", "config[\"eval_step_interval\"] = 250 # Adjusted\n",
"config[\"clip_duration_ms\"] = 1500 # Increased\n", "config[\"clip_duration_ms\"] = 1500 # Increased\n",
"\n", "\n",
"config[\"target_minimization\"] = 0.9\n", "config[\"target_minimization\"] = 0.9\n",