diff --git a/advanced_training_notebook.ipynb b/advanced_training_notebook.ipynb index aee839a..ba5ec23 100644 --- a/advanced_training_notebook.ipynb +++ b/advanced_training_notebook.ipynb @@ -136,7 +136,7 @@ "id": "YJRG4Qvo9nXG" }, "outputs": [], - "source": [ + "source": [ "# Downloads audio data for augmentation. This can be slow!\n", "# Borrowed from openWakeWord's automatic_model_training.ipynb, accessed March 4, 2024\n", "#\n", @@ -144,141 +144,129 @@ "# licenses and usage restrictions. As such, any custom models trained with this\n", "# data should be considered as appropriate for **non-commercial** personal use only.\n", "\n", - "import datasets\n", - "import scipy\n", "import os\n", - "import requests\n", - "\n", + "import scipy.io.wavfile\n", "import numpy as np\n", - "\n", + "from datasets import Dataset, Audio, load_dataset\n", "from pathlib import Path\n", "from tqdm import tqdm\n", "\n", - "## Download MIR RIR data\n", - "\n", + "# -----------------------------\n", + "# Download and Process MIT RIR\n", + "# -----------------------------\n", "output_dir = \"./mit_rirs\"\n", "if not os.path.exists(output_dir):\n", " os.mkdir(output_dir)\n", - " rir_dataset = datasets.load_dataset(\"davidscripka/MIT_environmental_impulse_responses\", split=\"train\", streaming=True)\n", - " # Save clips to 16-bit PCM wav files\n", + " rir_dataset = load_dataset(\"davidscripka/MIT_environmental_impulse_responses\", split=\"train\", streaming=True)\n", + " print(f\"Downloading MIT RIR dataset to {output_dir}...\")\n", " for row in tqdm(rir_dataset):\n", - " name = row['audio']['path'].split('/')[-1]\n", - " scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array'] * 32767).astype(np.int16))\n", + " name = row[\"audio\"][\"path\"].split(\"/\")[-1]\n", + " scipy.io.wavfile.write(\n", + " os.path.join(output_dir, name), \n", + " 16000, \n", + " (row[\"audio\"][\"array\"] * 32767).astype(np.int16)\n", + " )\n", + " print(f\"Finished downloading MIT RIR dataset to {output_dir}.\\n\")\n", + "else:\n", + " print(f\"{output_dir} already exists. Skipping download.\")\n", "\n", - "## Download noise and background audio\n", + "# -----------------------------\n", + "# Download and Process Audioset\n", + "# -----------------------------\n", + "audioset_dir = \"./audioset\"\n", + "if not os.path.exists(audioset_dir):\n", + " os.mkdir(audioset_dir)\n", "\n", "# Full-scale dataset download links\n", "dataset_links = [\n", " f\"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/bal_train0{i}.tar\"\n", - " for i in range(10) # Adjust for additional parts\n", + " for i in range(10)\n", "]\n", "\n", - "if not os.path.exists(\"audioset\"):\n", - " os.mkdir(\"audioset\")\n", + "for link in dataset_links:\n", + " file_name = link.split(\"/\")[-1]\n", + " out_dir = os.path.join(audioset_dir, file_name)\n", + " if not os.path.exists(out_dir):\n", + " print(f\"Downloading {file_name}...\")\n", + " os.system(f\"wget --quiet -O {out_dir} {link}\")\n", + " print(f\"Extracting {file_name}...\")\n", + " os.system(f\"cd {audioset_dir} && tar -xf {file_name}\")\n", "\n", - " for link in dataset_links:\n", - " fname = link.split(\"/\")[-1]\n", - " out_dir = f\"audioset/{fname}\"\n", - " if not os.path.exists(out_dir):\n", - " print(f\"Downloading {fname}...\")\n", - " response = requests.get(link, stream=True)\n", - " with open(out_dir, \"wb\") as f:\n", - " for chunk in response.iter_content(chunk_size=1024):\n", - " if chunk:\n", - " f.write(chunk)\n", - " print(f\"Downloaded {fname}\")\n", + "output_dir = \"./audioset_16k\"\n", + "if not os.path.exists(output_dir):\n", + " os.mkdir(output_dir)\n", "\n", - " # Extract the tar file\n", - " print(f\"Extracting {fname}...\")\n", - " os.system(f\"cd audioset && tar -xf {fname}\")\n", + "# Save clips to 16-bit PCM wav files\n", + "audioset_files = list(Path(\"audioset/audio\").glob(\"**/*.flac\"))\n", + "print(f\"Number of FLAC files found: {len(audioset_files)}\")\n", + "if audioset_files:\n", + " audioset_dataset = Dataset.from_dict({\"audio\": [str(file) for file in audioset_files]})\n", + " audioset_dataset = audioset_dataset.cast_column(\"audio\", Audio(sampling_rate=16000))\n", "\n", - " output_dir = \"./audioset_16k\"\n", - " if not os.path.exists(output_dir):\n", - " os.mkdir(output_dir)\n", - "\n", - " # Save clips to 16-bit PCM wav files\n", - " audio_files = list(Path(\"audioset/audio\").glob(\"**/*.flac\"))\n", - " print(f\"Number of FLAC files found: {len(audio_files)}\")\n", - " if not audio_files:\n", - " raise FileNotFoundError(\"No .flac files found in the audioset directory. Check your dataset extraction.\")\n", - "\n", - " print(\"Converting audioset files to 16kHz WAV...\")\n", - " for file_path in tqdm(audio_files):\n", + " corrupted_files = []\n", + " print(\"Converting Audioset files to 16kHz WAV...\")\n", + " for row in tqdm(audioset_dataset):\n", " try:\n", - " # Read and convert the .flac file\n", - " data, samplerate = scipy.io.wavfile.read(file_path)\n", - " if samplerate != 16000:\n", - " data = np.interp(\n", - " np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n", - " np.arange(len(data)),\n", - " data,\n", - " )\n", - "\n", - " # Save as WAV\n", - " output_path = Path(output_dir) / file_path.name.replace(\".flac\", \".wav\")\n", + " name = row[\"audio\"][\"path\"].split(\"/\")[-1].replace(\".flac\", \".wav\")\n", " scipy.io.wavfile.write(\n", - " output_path,\n", - " 16000,\n", - " (data * 32767).astype(np.int16),\n", + " os.path.join(output_dir, name), \n", + " 16000, \n", + " (row[\"audio\"][\"array\"] * 32767).astype(np.int16)\n", " )\n", " except Exception as e:\n", - " print(f\"Error converting {file_path}: {e}\")\n", - " continue # Skip and proceed with the next file\n", + " print(f\"Error converting {row['audio']['path']}: {e}\")\n", + " corrupted_files.append(row[\"audio\"][\"path\"])\n", "\n", - "# Free Music Archive dataset\n", - "# https://github.com/mdeff/fma\n", - "# (Third-party mchl914 extra small set)\n", + " if corrupted_files:\n", + " with open(\"audioset_corrupted_files.log\", \"w\") as log_file:\n", + " log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n", + "else:\n", + " print(\"No FLAC files found in Audioset.\")\n", "\n", + "# -----------------------------\n", + "# Download and Process FMA\n", + "# -----------------------------\n", "output_dir = \"./fma\"\n", "if not os.path.exists(output_dir):\n", " os.mkdir(output_dir)\n", " fname = \"fma_xs.zip\"\n", " link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/\" + fname\n", - " out_dir = f\"fma/{fname}\"\n", - " print(\"Downloading fma_xs dataset...\")\n", - " response = requests.get(link, stream=True)\n", - " with open(out_dir, \"wb\") as f:\n", - " for chunk in response.iter_content(chunk_size=1024):\n", - " if chunk:\n", - " f.write(chunk)\n", - " print(\"Downloaded fma_xs dataset.\")\n", + " out_dir = os.path.join(output_dir, fname)\n", + " os.system(f\"wget -O {out_dir} {link}\")\n", + " os.system(f\"cd {output_dir} && unzip -q {fname}\")\n", "\n", - " os.system(f\"cd fma && unzip -q {fname}\")\n", + "output_dir = \"./fma_16k\"\n", + "if not os.path.exists(output_dir):\n", + " os.mkdir(output_dir)\n", "\n", - " output_dir = \"./fma_16k\"\n", - " if not os.path.exists(output_dir):\n", - " os.mkdir(output_dir)\n", + "# Save clips to 16-bit PCM wav files\n", + "fma_files = list(Path(\"fma/fma_small\").glob(\"**/*.mp3\"))\n", + "print(f\"Number of MP3 files found: {len(fma_files)}\")\n", + "if fma_files:\n", + " fma_dataset = Dataset.from_dict({\"audio\": [str(file) for file in fma_files]})\n", + " fma_dataset = fma_dataset.cast_column(\"audio\", Audio(sampling_rate=16000))\n", "\n", - " # Save clips to 16-bit PCM wav files\n", - " audio_files = list(Path(\"fma/fma_small\").glob(\"**/*.mp3\"))\n", - " print(f\"Number of MP3 files found: {len(audio_files)}\")\n", - " if not audio_files:\n", - " raise FileNotFoundError(\"No .mp3 files found in the fma directory. Check your dataset extraction.\")\n", - "\n", - " print(\"Converting fma_xs files to 16kHz WAV...\")\n", - " for file_path in tqdm(audio_files):\n", + " corrupted_files = []\n", + " print(\"Converting FMA files to 16kHz WAV...\")\n", + " for row in tqdm(fma_dataset):\n", " try:\n", - " # Read and convert the .mp3 file\n", - " data, samplerate = scipy.io.wavfile.read(file_path)\n", - " if samplerate != 16000:\n", - " data = np.interp(\n", - " np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n", - " np.arange(len(data)),\n", - " data,\n", - " )\n", - "\n", - " # Save as WAV\n", - " output_path = Path(output_dir) / file_path.name.replace(\".mp3\", \".wav\")\n", + " name = row[\"audio\"][\"path\"].split(\"/\")[-1].replace(\".mp3\", \".wav\")\n", " scipy.io.wavfile.write(\n", - " output_path,\n", - " 16000,\n", - " (data * 32767).astype(np.int16),\n", + " os.path.join(output_dir, name), \n", + " 16000, \n", + " (row[\"audio\"][\"array\"] * 32767).astype(np.int16)\n", " )\n", " except Exception as e:\n", - " print(f\"Error converting {file_path}: {e}\")\n", - " continue # Skip and proceed with the next file\n", + " print(f\"Error converting {row['audio']['path']}: {e}\")\n", + " corrupted_files.append(row[\"audio\"][\"path\"])\n", "\n", - "print(\"All datasets processed successfully!\")" + " if corrupted_files:\n", + " with open(\"fma_corrupted_files.log\", \"w\") as log_file:\n", + " log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n", + "else:\n", + " print(\"No MP3 files found in FMA.\")\n", + "\n", + "print(\"Dataset preparation complete!\")" ] }, {