mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
Update advanced_training_notebook.ipynb
This commit is contained in:
@@ -144,141 +144,129 @@
|
|||||||
"# licenses and usage restrictions. As such, any custom models trained with this\n",
|
"# licenses and usage restrictions. As such, any custom models trained with this\n",
|
||||||
"# data should be considered as appropriate for **non-commercial** personal use only.\n",
|
"# data should be considered as appropriate for **non-commercial** personal use only.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"import datasets\n",
|
|
||||||
"import scipy\n",
|
|
||||||
"import os\n",
|
"import os\n",
|
||||||
"import requests\n",
|
"import scipy.io.wavfile\n",
|
||||||
"\n",
|
|
||||||
"import numpy as np\n",
|
"import numpy as np\n",
|
||||||
"\n",
|
"from datasets import Dataset, Audio, load_dataset\n",
|
||||||
"from pathlib import Path\n",
|
"from pathlib import Path\n",
|
||||||
"from tqdm import tqdm\n",
|
"from tqdm import tqdm\n",
|
||||||
"\n",
|
"\n",
|
||||||
"## Download MIR RIR data\n",
|
"# -----------------------------\n",
|
||||||
"\n",
|
"# Download and Process MIT RIR\n",
|
||||||
|
"# -----------------------------\n",
|
||||||
"output_dir = \"./mit_rirs\"\n",
|
"output_dir = \"./mit_rirs\"\n",
|
||||||
"if not os.path.exists(output_dir):\n",
|
"if not os.path.exists(output_dir):\n",
|
||||||
" os.mkdir(output_dir)\n",
|
" os.mkdir(output_dir)\n",
|
||||||
" rir_dataset = datasets.load_dataset(\"davidscripka/MIT_environmental_impulse_responses\", split=\"train\", streaming=True)\n",
|
" rir_dataset = load_dataset(\"davidscripka/MIT_environmental_impulse_responses\", split=\"train\", streaming=True)\n",
|
||||||
" # Save clips to 16-bit PCM wav files\n",
|
" print(f\"Downloading MIT RIR dataset to {output_dir}...\")\n",
|
||||||
" for row in tqdm(rir_dataset):\n",
|
" for row in tqdm(rir_dataset):\n",
|
||||||
" name = row['audio']['path'].split('/')[-1]\n",
|
" name = row[\"audio\"][\"path\"].split(\"/\")[-1]\n",
|
||||||
" scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array'] * 32767).astype(np.int16))\n",
|
" scipy.io.wavfile.write(\n",
|
||||||
|
" os.path.join(output_dir, name), \n",
|
||||||
|
" 16000, \n",
|
||||||
|
" (row[\"audio\"][\"array\"] * 32767).astype(np.int16)\n",
|
||||||
|
" )\n",
|
||||||
|
" print(f\"Finished downloading MIT RIR dataset to {output_dir}.\\n\")\n",
|
||||||
|
"else:\n",
|
||||||
|
" print(f\"{output_dir} already exists. Skipping download.\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"## Download noise and background audio\n",
|
"# -----------------------------\n",
|
||||||
|
"# Download and Process Audioset\n",
|
||||||
|
"# -----------------------------\n",
|
||||||
|
"audioset_dir = \"./audioset\"\n",
|
||||||
|
"if not os.path.exists(audioset_dir):\n",
|
||||||
|
" os.mkdir(audioset_dir)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Full-scale dataset download links\n",
|
"# Full-scale dataset download links\n",
|
||||||
"dataset_links = [\n",
|
"dataset_links = [\n",
|
||||||
" f\"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/bal_train0{i}.tar\"\n",
|
" f\"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/bal_train0{i}.tar\"\n",
|
||||||
" for i in range(10) # Adjust for additional parts\n",
|
" for i in range(10)\n",
|
||||||
"]\n",
|
"]\n",
|
||||||
"\n",
|
"\n",
|
||||||
"if not os.path.exists(\"audioset\"):\n",
|
|
||||||
" os.mkdir(\"audioset\")\n",
|
|
||||||
"\n",
|
|
||||||
"for link in dataset_links:\n",
|
"for link in dataset_links:\n",
|
||||||
" fname = link.split(\"/\")[-1]\n",
|
" file_name = link.split(\"/\")[-1]\n",
|
||||||
" out_dir = f\"audioset/{fname}\"\n",
|
" out_dir = os.path.join(audioset_dir, file_name)\n",
|
||||||
" if not os.path.exists(out_dir):\n",
|
" if not os.path.exists(out_dir):\n",
|
||||||
" print(f\"Downloading {fname}...\")\n",
|
" print(f\"Downloading {file_name}...\")\n",
|
||||||
" response = requests.get(link, stream=True)\n",
|
" os.system(f\"wget --quiet -O {out_dir} {link}\")\n",
|
||||||
" with open(out_dir, \"wb\") as f:\n",
|
" print(f\"Extracting {file_name}...\")\n",
|
||||||
" for chunk in response.iter_content(chunk_size=1024):\n",
|
" os.system(f\"cd {audioset_dir} && tar -xf {file_name}\")\n",
|
||||||
" if chunk:\n",
|
|
||||||
" f.write(chunk)\n",
|
|
||||||
" print(f\"Downloaded {fname}\")\n",
|
|
||||||
"\n",
|
|
||||||
" # Extract the tar file\n",
|
|
||||||
" print(f\"Extracting {fname}...\")\n",
|
|
||||||
" os.system(f\"cd audioset && tar -xf {fname}\")\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"output_dir = \"./audioset_16k\"\n",
|
"output_dir = \"./audioset_16k\"\n",
|
||||||
"if not os.path.exists(output_dir):\n",
|
"if not os.path.exists(output_dir):\n",
|
||||||
" os.mkdir(output_dir)\n",
|
" os.mkdir(output_dir)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Save clips to 16-bit PCM wav files\n",
|
"# Save clips to 16-bit PCM wav files\n",
|
||||||
" audio_files = list(Path(\"audioset/audio\").glob(\"**/*.flac\"))\n",
|
"audioset_files = list(Path(\"audioset/audio\").glob(\"**/*.flac\"))\n",
|
||||||
" print(f\"Number of FLAC files found: {len(audio_files)}\")\n",
|
"print(f\"Number of FLAC files found: {len(audioset_files)}\")\n",
|
||||||
" if not audio_files:\n",
|
"if audioset_files:\n",
|
||||||
" raise FileNotFoundError(\"No .flac files found in the audioset directory. Check your dataset extraction.\")\n",
|
" audioset_dataset = Dataset.from_dict({\"audio\": [str(file) for file in audioset_files]})\n",
|
||||||
|
" audioset_dataset = audioset_dataset.cast_column(\"audio\", Audio(sampling_rate=16000))\n",
|
||||||
"\n",
|
"\n",
|
||||||
" print(\"Converting audioset files to 16kHz WAV...\")\n",
|
" corrupted_files = []\n",
|
||||||
" for file_path in tqdm(audio_files):\n",
|
" print(\"Converting Audioset files to 16kHz WAV...\")\n",
|
||||||
|
" for row in tqdm(audioset_dataset):\n",
|
||||||
" try:\n",
|
" try:\n",
|
||||||
" # Read and convert the .flac file\n",
|
" name = row[\"audio\"][\"path\"].split(\"/\")[-1].replace(\".flac\", \".wav\")\n",
|
||||||
" data, samplerate = scipy.io.wavfile.read(file_path)\n",
|
|
||||||
" if samplerate != 16000:\n",
|
|
||||||
" data = np.interp(\n",
|
|
||||||
" np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n",
|
|
||||||
" np.arange(len(data)),\n",
|
|
||||||
" data,\n",
|
|
||||||
" )\n",
|
|
||||||
"\n",
|
|
||||||
" # Save as WAV\n",
|
|
||||||
" output_path = Path(output_dir) / file_path.name.replace(\".flac\", \".wav\")\n",
|
|
||||||
" scipy.io.wavfile.write(\n",
|
" scipy.io.wavfile.write(\n",
|
||||||
" output_path,\n",
|
" os.path.join(output_dir, name), \n",
|
||||||
" 16000, \n",
|
" 16000, \n",
|
||||||
" (data * 32767).astype(np.int16),\n",
|
" (row[\"audio\"][\"array\"] * 32767).astype(np.int16)\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" except Exception as e:\n",
|
" except Exception as e:\n",
|
||||||
" print(f\"Error converting {file_path}: {e}\")\n",
|
" print(f\"Error converting {row['audio']['path']}: {e}\")\n",
|
||||||
" continue # Skip and proceed with the next file\n",
|
" corrupted_files.append(row[\"audio\"][\"path\"])\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Free Music Archive dataset\n",
|
" if corrupted_files:\n",
|
||||||
"# https://github.com/mdeff/fma\n",
|
" with open(\"audioset_corrupted_files.log\", \"w\") as log_file:\n",
|
||||||
"# (Third-party mchl914 extra small set)\n",
|
" log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n",
|
||||||
|
"else:\n",
|
||||||
|
" print(\"No FLAC files found in Audioset.\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"# -----------------------------\n",
|
||||||
|
"# Download and Process FMA\n",
|
||||||
|
"# -----------------------------\n",
|
||||||
"output_dir = \"./fma\"\n",
|
"output_dir = \"./fma\"\n",
|
||||||
"if not os.path.exists(output_dir):\n",
|
"if not os.path.exists(output_dir):\n",
|
||||||
" os.mkdir(output_dir)\n",
|
" os.mkdir(output_dir)\n",
|
||||||
" fname = \"fma_xs.zip\"\n",
|
" fname = \"fma_xs.zip\"\n",
|
||||||
" link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/\" + fname\n",
|
" link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/\" + fname\n",
|
||||||
" out_dir = f\"fma/{fname}\"\n",
|
" out_dir = os.path.join(output_dir, fname)\n",
|
||||||
" print(\"Downloading fma_xs dataset...\")\n",
|
" os.system(f\"wget -O {out_dir} {link}\")\n",
|
||||||
" response = requests.get(link, stream=True)\n",
|
" os.system(f\"cd {output_dir} && unzip -q {fname}\")\n",
|
||||||
" with open(out_dir, \"wb\") as f:\n",
|
|
||||||
" for chunk in response.iter_content(chunk_size=1024):\n",
|
|
||||||
" if chunk:\n",
|
|
||||||
" f.write(chunk)\n",
|
|
||||||
" print(\"Downloaded fma_xs dataset.\")\n",
|
|
||||||
"\n",
|
|
||||||
" os.system(f\"cd fma && unzip -q {fname}\")\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"output_dir = \"./fma_16k\"\n",
|
"output_dir = \"./fma_16k\"\n",
|
||||||
"if not os.path.exists(output_dir):\n",
|
"if not os.path.exists(output_dir):\n",
|
||||||
" os.mkdir(output_dir)\n",
|
" os.mkdir(output_dir)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Save clips to 16-bit PCM wav files\n",
|
"# Save clips to 16-bit PCM wav files\n",
|
||||||
" audio_files = list(Path(\"fma/fma_small\").glob(\"**/*.mp3\"))\n",
|
"fma_files = list(Path(\"fma/fma_small\").glob(\"**/*.mp3\"))\n",
|
||||||
" print(f\"Number of MP3 files found: {len(audio_files)}\")\n",
|
"print(f\"Number of MP3 files found: {len(fma_files)}\")\n",
|
||||||
" if not audio_files:\n",
|
"if fma_files:\n",
|
||||||
" raise FileNotFoundError(\"No .mp3 files found in the fma directory. Check your dataset extraction.\")\n",
|
" fma_dataset = Dataset.from_dict({\"audio\": [str(file) for file in fma_files]})\n",
|
||||||
|
" fma_dataset = fma_dataset.cast_column(\"audio\", Audio(sampling_rate=16000))\n",
|
||||||
"\n",
|
"\n",
|
||||||
" print(\"Converting fma_xs files to 16kHz WAV...\")\n",
|
" corrupted_files = []\n",
|
||||||
" for file_path in tqdm(audio_files):\n",
|
" print(\"Converting FMA files to 16kHz WAV...\")\n",
|
||||||
|
" for row in tqdm(fma_dataset):\n",
|
||||||
" try:\n",
|
" try:\n",
|
||||||
" # Read and convert the .mp3 file\n",
|
" name = row[\"audio\"][\"path\"].split(\"/\")[-1].replace(\".mp3\", \".wav\")\n",
|
||||||
" data, samplerate = scipy.io.wavfile.read(file_path)\n",
|
|
||||||
" if samplerate != 16000:\n",
|
|
||||||
" data = np.interp(\n",
|
|
||||||
" np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n",
|
|
||||||
" np.arange(len(data)),\n",
|
|
||||||
" data,\n",
|
|
||||||
" )\n",
|
|
||||||
"\n",
|
|
||||||
" # Save as WAV\n",
|
|
||||||
" output_path = Path(output_dir) / file_path.name.replace(\".mp3\", \".wav\")\n",
|
|
||||||
" scipy.io.wavfile.write(\n",
|
" scipy.io.wavfile.write(\n",
|
||||||
" output_path,\n",
|
" os.path.join(output_dir, name), \n",
|
||||||
" 16000, \n",
|
" 16000, \n",
|
||||||
" (data * 32767).astype(np.int16),\n",
|
" (row[\"audio\"][\"array\"] * 32767).astype(np.int16)\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" except Exception as e:\n",
|
" except Exception as e:\n",
|
||||||
" print(f\"Error converting {file_path}: {e}\")\n",
|
" print(f\"Error converting {row['audio']['path']}: {e}\")\n",
|
||||||
" continue # Skip and proceed with the next file\n",
|
" corrupted_files.append(row[\"audio\"][\"path\"])\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print(\"All datasets processed successfully!\")"
|
" if corrupted_files:\n",
|
||||||
|
" with open(\"fma_corrupted_files.log\", \"w\") as log_file:\n",
|
||||||
|
" log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n",
|
||||||
|
"else:\n",
|
||||||
|
" print(\"No MP3 files found in FMA.\")\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"Dataset preparation complete!\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
Reference in New Issue
Block a user