Update advanced_training_notebook.ipynb

This commit is contained in:
MasterPhooey
2025-01-05 10:42:01 -06:00
committed by GitHub
parent 07addfe2ce
commit 53fae46446

View File

@@ -144,141 +144,129 @@
"# licenses and usage restrictions. As such, any custom models trained with this\n", "# licenses and usage restrictions. As such, any custom models trained with this\n",
"# data should be considered as appropriate for **non-commercial** personal use only.\n", "# data should be considered as appropriate for **non-commercial** personal use only.\n",
"\n", "\n",
"import datasets\n",
"import scipy\n",
"import os\n", "import os\n",
"import requests\n", "import scipy.io.wavfile\n",
"\n",
"import numpy as np\n", "import numpy as np\n",
"\n", "from datasets import Dataset, Audio, load_dataset\n",
"from pathlib import Path\n", "from pathlib import Path\n",
"from tqdm import tqdm\n", "from tqdm import tqdm\n",
"\n", "\n",
"## Download MIR RIR data\n", "# -----------------------------\n",
"\n", "# Download and Process MIT RIR\n",
"# -----------------------------\n",
"output_dir = \"./mit_rirs\"\n", "output_dir = \"./mit_rirs\"\n",
"if not os.path.exists(output_dir):\n", "if not os.path.exists(output_dir):\n",
" os.mkdir(output_dir)\n", " os.mkdir(output_dir)\n",
" rir_dataset = datasets.load_dataset(\"davidscripka/MIT_environmental_impulse_responses\", split=\"train\", streaming=True)\n", " rir_dataset = load_dataset(\"davidscripka/MIT_environmental_impulse_responses\", split=\"train\", streaming=True)\n",
" # Save clips to 16-bit PCM wav files\n", " print(f\"Downloading MIT RIR dataset to {output_dir}...\")\n",
" for row in tqdm(rir_dataset):\n", " for row in tqdm(rir_dataset):\n",
" name = row['audio']['path'].split('/')[-1]\n", " name = row[\"audio\"][\"path\"].split(\"/\")[-1]\n",
" scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array'] * 32767).astype(np.int16))\n", " scipy.io.wavfile.write(\n",
" os.path.join(output_dir, name), \n",
" 16000, \n",
" (row[\"audio\"][\"array\"] * 32767).astype(np.int16)\n",
" )\n",
" print(f\"Finished downloading MIT RIR dataset to {output_dir}.\\n\")\n",
"else:\n",
" print(f\"{output_dir} already exists. Skipping download.\")\n",
"\n", "\n",
"## Download noise and background audio\n", "# -----------------------------\n",
"# Download and Process Audioset\n",
"# -----------------------------\n",
"audioset_dir = \"./audioset\"\n",
"if not os.path.exists(audioset_dir):\n",
" os.mkdir(audioset_dir)\n",
"\n", "\n",
"# Full-scale dataset download links\n", "# Full-scale dataset download links\n",
"dataset_links = [\n", "dataset_links = [\n",
" f\"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/bal_train0{i}.tar\"\n", " f\"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/bal_train0{i}.tar\"\n",
" for i in range(10) # Adjust for additional parts\n", " for i in range(10)\n",
"]\n", "]\n",
"\n", "\n",
"if not os.path.exists(\"audioset\"):\n",
" os.mkdir(\"audioset\")\n",
"\n",
"for link in dataset_links:\n", "for link in dataset_links:\n",
" fname = link.split(\"/\")[-1]\n", " file_name = link.split(\"/\")[-1]\n",
" out_dir = f\"audioset/{fname}\"\n", " out_dir = os.path.join(audioset_dir, file_name)\n",
" if not os.path.exists(out_dir):\n", " if not os.path.exists(out_dir):\n",
" print(f\"Downloading {fname}...\")\n", " print(f\"Downloading {file_name}...\")\n",
" response = requests.get(link, stream=True)\n", " os.system(f\"wget --quiet -O {out_dir} {link}\")\n",
" with open(out_dir, \"wb\") as f:\n", " print(f\"Extracting {file_name}...\")\n",
" for chunk in response.iter_content(chunk_size=1024):\n", " os.system(f\"cd {audioset_dir} && tar -xf {file_name}\")\n",
" if chunk:\n",
" f.write(chunk)\n",
" print(f\"Downloaded {fname}\")\n",
"\n",
" # Extract the tar file\n",
" print(f\"Extracting {fname}...\")\n",
" os.system(f\"cd audioset && tar -xf {fname}\")\n",
"\n", "\n",
"output_dir = \"./audioset_16k\"\n", "output_dir = \"./audioset_16k\"\n",
"if not os.path.exists(output_dir):\n", "if not os.path.exists(output_dir):\n",
" os.mkdir(output_dir)\n", " os.mkdir(output_dir)\n",
"\n", "\n",
"# Save clips to 16-bit PCM wav files\n", "# Save clips to 16-bit PCM wav files\n",
" audio_files = list(Path(\"audioset/audio\").glob(\"**/*.flac\"))\n", "audioset_files = list(Path(\"audioset/audio\").glob(\"**/*.flac\"))\n",
" print(f\"Number of FLAC files found: {len(audio_files)}\")\n", "print(f\"Number of FLAC files found: {len(audioset_files)}\")\n",
" if not audio_files:\n", "if audioset_files:\n",
" raise FileNotFoundError(\"No .flac files found in the audioset directory. Check your dataset extraction.\")\n", " audioset_dataset = Dataset.from_dict({\"audio\": [str(file) for file in audioset_files]})\n",
" audioset_dataset = audioset_dataset.cast_column(\"audio\", Audio(sampling_rate=16000))\n",
"\n", "\n",
" print(\"Converting audioset files to 16kHz WAV...\")\n", " corrupted_files = []\n",
" for file_path in tqdm(audio_files):\n", " print(\"Converting Audioset files to 16kHz WAV...\")\n",
" for row in tqdm(audioset_dataset):\n",
" try:\n", " try:\n",
" # Read and convert the .flac file\n", " name = row[\"audio\"][\"path\"].split(\"/\")[-1].replace(\".flac\", \".wav\")\n",
" data, samplerate = scipy.io.wavfile.read(file_path)\n",
" if samplerate != 16000:\n",
" data = np.interp(\n",
" np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n",
" np.arange(len(data)),\n",
" data,\n",
" )\n",
"\n",
" # Save as WAV\n",
" output_path = Path(output_dir) / file_path.name.replace(\".flac\", \".wav\")\n",
" scipy.io.wavfile.write(\n", " scipy.io.wavfile.write(\n",
" output_path,\n", " os.path.join(output_dir, name), \n",
" 16000, \n", " 16000, \n",
" (data * 32767).astype(np.int16),\n", " (row[\"audio\"][\"array\"] * 32767).astype(np.int16)\n",
" )\n", " )\n",
" except Exception as e:\n", " except Exception as e:\n",
" print(f\"Error converting {file_path}: {e}\")\n", " print(f\"Error converting {row['audio']['path']}: {e}\")\n",
" continue # Skip and proceed with the next file\n", " corrupted_files.append(row[\"audio\"][\"path\"])\n",
"\n", "\n",
"# Free Music Archive dataset\n", " if corrupted_files:\n",
"# https://github.com/mdeff/fma\n", " with open(\"audioset_corrupted_files.log\", \"w\") as log_file:\n",
"# (Third-party mchl914 extra small set)\n", " log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n",
"else:\n",
" print(\"No FLAC files found in Audioset.\")\n",
"\n", "\n",
"# -----------------------------\n",
"# Download and Process FMA\n",
"# -----------------------------\n",
"output_dir = \"./fma\"\n", "output_dir = \"./fma\"\n",
"if not os.path.exists(output_dir):\n", "if not os.path.exists(output_dir):\n",
" os.mkdir(output_dir)\n", " os.mkdir(output_dir)\n",
" fname = \"fma_xs.zip\"\n", " fname = \"fma_xs.zip\"\n",
" link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/\" + fname\n", " link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/\" + fname\n",
" out_dir = f\"fma/{fname}\"\n", " out_dir = os.path.join(output_dir, fname)\n",
" print(\"Downloading fma_xs dataset...\")\n", " os.system(f\"wget -O {out_dir} {link}\")\n",
" response = requests.get(link, stream=True)\n", " os.system(f\"cd {output_dir} && unzip -q {fname}\")\n",
" with open(out_dir, \"wb\") as f:\n",
" for chunk in response.iter_content(chunk_size=1024):\n",
" if chunk:\n",
" f.write(chunk)\n",
" print(\"Downloaded fma_xs dataset.\")\n",
"\n",
" os.system(f\"cd fma && unzip -q {fname}\")\n",
"\n", "\n",
"output_dir = \"./fma_16k\"\n", "output_dir = \"./fma_16k\"\n",
"if not os.path.exists(output_dir):\n", "if not os.path.exists(output_dir):\n",
" os.mkdir(output_dir)\n", " os.mkdir(output_dir)\n",
"\n", "\n",
"# Save clips to 16-bit PCM wav files\n", "# Save clips to 16-bit PCM wav files\n",
" audio_files = list(Path(\"fma/fma_small\").glob(\"**/*.mp3\"))\n", "fma_files = list(Path(\"fma/fma_small\").glob(\"**/*.mp3\"))\n",
" print(f\"Number of MP3 files found: {len(audio_files)}\")\n", "print(f\"Number of MP3 files found: {len(fma_files)}\")\n",
" if not audio_files:\n", "if fma_files:\n",
" raise FileNotFoundError(\"No .mp3 files found in the fma directory. Check your dataset extraction.\")\n", " fma_dataset = Dataset.from_dict({\"audio\": [str(file) for file in fma_files]})\n",
" fma_dataset = fma_dataset.cast_column(\"audio\", Audio(sampling_rate=16000))\n",
"\n", "\n",
" print(\"Converting fma_xs files to 16kHz WAV...\")\n", " corrupted_files = []\n",
" for file_path in tqdm(audio_files):\n", " print(\"Converting FMA files to 16kHz WAV...\")\n",
" for row in tqdm(fma_dataset):\n",
" try:\n", " try:\n",
" # Read and convert the .mp3 file\n", " name = row[\"audio\"][\"path\"].split(\"/\")[-1].replace(\".mp3\", \".wav\")\n",
" data, samplerate = scipy.io.wavfile.read(file_path)\n",
" if samplerate != 16000:\n",
" data = np.interp(\n",
" np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n",
" np.arange(len(data)),\n",
" data,\n",
" )\n",
"\n",
" # Save as WAV\n",
" output_path = Path(output_dir) / file_path.name.replace(\".mp3\", \".wav\")\n",
" scipy.io.wavfile.write(\n", " scipy.io.wavfile.write(\n",
" output_path,\n", " os.path.join(output_dir, name), \n",
" 16000, \n", " 16000, \n",
" (data * 32767).astype(np.int16),\n", " (row[\"audio\"][\"array\"] * 32767).astype(np.int16)\n",
" )\n", " )\n",
" except Exception as e:\n", " except Exception as e:\n",
" print(f\"Error converting {file_path}: {e}\")\n", " print(f\"Error converting {row['audio']['path']}: {e}\")\n",
" continue # Skip and proceed with the next file\n", " corrupted_files.append(row[\"audio\"][\"path\"])\n",
"\n", "\n",
"print(\"All datasets processed successfully!\")" " if corrupted_files:\n",
" with open(\"fma_corrupted_files.log\", \"w\") as log_file:\n",
" log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n",
"else:\n",
" print(\"No MP3 files found in FMA.\")\n",
"\n",
"print(\"Dataset preparation complete!\")"
] ]
}, },
{ {