diff --git a/advanced_training_notebook.ipynb b/advanced_training_notebook.ipynb index 626f037..aee839a 100644 --- a/advanced_training_notebook.ipynb +++ b/advanced_training_notebook.ipynb @@ -144,79 +144,28 @@ "# licenses and usage restrictions. As such, any custom models trained with this\n", "# data should be considered as appropriate for **non-commercial** personal use only.\n", "\n", + "import datasets\n", + "import scipy\n", "import os\n", - "import scipy.io.wavfile\n", + "import requests\n", + "\n", "import numpy as np\n", - "import soundfile as sf\n", + "\n", "from pathlib import Path\n", "from tqdm import tqdm\n", - "import requests\n", - "import tarfile\n", - "import zipfile\n", - "from datasets import load_dataset\n", "\n", - "# Function to download and process RIR dataset\n", - "# Function to download and process RIR dataset\n", - "def download_rir_dataset(dataset_name, output_dir, split=\"train\"):\n", - " output_dir = Path(output_dir)\n", - " if not output_dir.exists():\n", - " output_dir.mkdir(parents=True, exist_ok=True)\n", - " try:\n", - " rir_dataset = load_dataset(dataset_name, split=split, streaming=True)\n", - " print(f\"Downloading {dataset_name} to {output_dir}...\")\n", - " for row in tqdm(rir_dataset):\n", - " name = Path(row['audio']['path']).name\n", - " # Save the original audio file\n", - " with open(output_dir / name, \"wb\") as audio_file:\n", - " audio_file.write(row[\"audio\"][\"bytes\"])\n", - " print(f\"Finished downloading {dataset_name} to {output_dir}.\\n\")\n", - " except Exception as e:\n", - " print(f\"Error downloading {dataset_name}: {e}\")\n", - " else:\n", - " print(f\"{output_dir} already exists. 
Skipping download.\\n\")\n", - " \n", - "# Download MIT RIRs\n", - "download_rir_dataset(\n", - " \"davidscripka/MIT_environmental_impulse_responses\",\n", - " \"./mit_rirs\"\n", - ")\n", + "## Download MIT RIR data\n", "\n", - "# Function to download files\n", - "def download_file(url, output_path):\n", - " response = requests.get(url, stream=True)\n", - " total_size = int(response.headers.get('content-length', 0))\n", - " with open(output_path, \"wb\") as f, tqdm(\n", - " desc=f\"Downloading {output_path.name}\",\n", - " total=total_size,\n", - " unit=\"B\",\n", - " unit_scale=True,\n", - " unit_divisor=1024,\n", - " ) as bar:\n", - " for chunk in response.iter_content(chunk_size=1024):\n", - " f.write(chunk)\n", - " bar.update(len(chunk))\n", - " print(f\"Downloaded {output_path}\")\n", + "output_dir = \"./mit_rirs\"\n", + "if not os.path.exists(output_dir):\n", + " os.mkdir(output_dir)\n", + " rir_dataset = datasets.load_dataset(\"davidscripka/MIT_environmental_impulse_responses\", split=\"train\", streaming=True)\n", + " # Save clips to 16-bit PCM wav files\n", + " for row in tqdm(rir_dataset):\n", + " name = row['audio']['path'].split('/')[-1]\n", + " scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array'] * 32767).astype(np.int16))\n", "\n", - "# Function to extract .tar files\n", - "def extract_tar(file_path, extract_dir):\n", - " with tarfile.open(file_path, \"r\") as tar:\n", - " tar.extractall(path=extract_dir)\n", - " print(f\"Extracted {file_path} to {extract_dir}\")\n", - "\n", - "# Function to download and extract ZIP files\n", - "def download_and_extract_zip(url, extract_to):\n", - " file_name = url.split(\"/\")[-1]\n", - " local_path = Path(extract_to) / file_name\n", - " download_file(url, local_path)\n", - " with zipfile.ZipFile(local_path, 'r') as zip_ref:\n", - " zip_ref.extractall(extract_to)\n", - " print(f\"Extracted {file_name} to {extract_to}\")\n", - "\n", - "# Directories\n", - "raw_dir = 
Path(\"./audioset_raw\")\n", - "processed_dir = Path(\"./audioset_16k\")\n", - "raw_dir.mkdir(exist_ok=True)\n", - "processed_dir.mkdir(exist_ok=True)\n", + "## Download noise and background audio\n", "\n", "# Full-scale dataset download links\n", "dataset_links = [\n", @@ -224,120 +173,112 @@ " for i in range(10) # Adjust for additional parts\n", "]\n", "\n", - "# Step 1: Download all parts of the dataset\n", - "print(\"Downloading datasets...\")\n", - "for link in dataset_links:\n", - " file_name = link.split(\"/\")[-1]\n", - " output_path = raw_dir / file_name\n", - " if not output_path.exists():\n", - " download_file(link, output_path)\n", + "if not os.path.exists(\"audioset\"):\n", + " os.mkdir(\"audioset\")\n", "\n", - "# Step 2: Extract all .tar files\n", - "print(\"Extracting datasets...\")\n", - "for file_path in raw_dir.glob(\"*.tar\"):\n", - " extract_dir = raw_dir / file_path.stem\n", - " if not extract_dir.exists():\n", - " extract_tar(file_path, extract_dir)\n", + " for link in dataset_links:\n", + " fname = link.split(\"/\")[-1]\n", + " out_dir = f\"audioset/{fname}\"\n", + " if not os.path.exists(out_dir):\n", + " print(f\"Downloading {fname}...\")\n", + " response = requests.get(link, stream=True)\n", + " with open(out_dir, \"wb\") as f:\n", + " for chunk in response.iter_content(chunk_size=1024):\n", + " if chunk:\n", + " f.write(chunk)\n", + " print(f\"Downloaded {fname}\")\n", "\n", - "# Step 3: Convert audio files to 16kHz WAV\n", - "audio_files = list(Path(raw_dir).glob(\"**/*.flac\"))\n", - "print(f\"Number of FLAC files found: {len(audio_files)}\")\n", - "if not audio_files:\n", - " raise FileNotFoundError(\"No .flac files found in the raw directories. 
Check your dataset extraction.\")\n", + " # Extract the tar file\n", + " print(f\"Extracting {fname}...\")\n", + " os.system(f\"cd audioset && tar -xf {fname}\")\n", "\n", - "print(\"Converting audio files to 16kHz WAV...\")\n", - "corrupted_files = []\n", - "resampled_files = []\n", + " output_dir = \"./audioset_16k\"\n", + " if not os.path.exists(output_dir):\n", + " os.mkdir(output_dir)\n", "\n", - "for file_path in tqdm(audio_files, desc=\"Processing audio files\"):\n", - " try:\n", - " # Read the .flac file\n", - " data, samplerate = sf.read(file_path)\n", + " # Save clips to 16-bit PCM wav files\n", + " audio_files = list(Path(\"audioset/audio\").glob(\"**/*.flac\"))\n", + " print(f\"Number of FLAC files found: {len(audio_files)}\")\n", + " if not audio_files:\n", + " raise FileNotFoundError(\"No .flac files found in the audioset directory. Check your dataset extraction.\")\n", "\n", - " # Check and resample if needed\n", - " if samplerate != 16000:\n", - " resampled_files.append(str(file_path))\n", - " data = np.interp(\n", - " np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n", - " np.arange(len(data)),\n", - " data,\n", + " print(\"Converting audioset files to 16kHz WAV...\")\n", + " for file_path in tqdm(audio_files):\n", + " try:\n", + " # Read and convert the .flac file\n", + " data, samplerate = scipy.io.wavfile.read(file_path)\n", + " if samplerate != 16000:\n", + " data = np.interp(\n", + " np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n", + " np.arange(len(data)),\n", + " data,\n", + " )\n", + "\n", + " # Save as WAV\n", + " output_path = Path(output_dir) / file_path.name.replace(\".flac\", \".wav\")\n", + " scipy.io.wavfile.write(\n", + " output_path,\n", + " 16000,\n", + " (data * 32767).astype(np.int16),\n", " )\n", + " except Exception as e:\n", + " print(f\"Error converting {file_path}: {e}\")\n", + " continue # Skip and proceed with the next file\n", "\n", - " # Convert and save as WAV\n", - " output_path = 
processed_dir / file_path.name.replace(\".flac\", \".wav\")\n", - " scipy.io.wavfile.write(\n", - " output_path,\n", - " 16000,\n", - " (data * 32767).astype(np.int16),\n", - " )\n", - " except Exception as e:\n", - " corrupted_files.append(str(file_path))\n", + "# Free Music Archive dataset\n", + "# https://github.com/mdeff/fma\n", + "# (Third-party mchl914 extra small set)\n", "\n", - "# Log corrupted files\n", - "if corrupted_files:\n", - " with open(\"corrupted_files.log\", \"w\") as log_file:\n", - " log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n", - "\n", - "# Log resampled files\n", - "if resampled_files:\n", - " with open(\"resampled_files.log\", \"w\") as log_file:\n", - " log_file.writelines(f\"{file}\\n\" for file in resampled_files)\n", - "\n", - "print(f\"Audio conversion complete! {len(corrupted_files)} files corrupted and logged.\")\n", - "print(f\"{len(resampled_files)} files resampled and logged.\")\n", - "\n", - "# Process fma_xs dataset\n", - "fma_raw_dir = Path(\"./fma\")\n", - "fma_processed_dir = Path(\"./fma_16k\") # Separate directory for fma_xs processed files\n", - "fma_raw_dir.mkdir(exist_ok=True)\n", - "fma_processed_dir.mkdir(exist_ok=True)\n", - "\n", - "fma_link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/fma_xs.zip\"\n", - "fma_zip_path = fma_raw_dir / \"fma_xs.zip\"\n", - "\n", - "if not fma_zip_path.exists():\n", + "output_dir = \"./fma\"\n", + "if not os.path.exists(output_dir):\n", + " os.mkdir(output_dir)\n", + " fname = \"fma_xs.zip\"\n", + " link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/\" + fname\n", + " out_dir = f\"fma/{fname}\"\n", " print(\"Downloading fma_xs dataset...\")\n", - " download_file(fma_link, fma_zip_path)\n", + " response = requests.get(link, stream=True)\n", + " with open(out_dir, \"wb\") as f:\n", + " for chunk in response.iter_content(chunk_size=1024):\n", + " if chunk:\n", + " f.write(chunk)\n", + " print(\"Downloaded fma_xs 
dataset.\")\n", "\n", - "print(\"Extracting fma_xs dataset...\")\n", - "download_and_extract_zip(fma_link, fma_raw_dir)\n", + " os.system(f\"cd fma && unzip -q {fname}\")\n", "\n", - "fma_audio_files = list(fma_raw_dir.glob(\"**/*.mp3\"))\n", - "print(f\"Number of MP3 files found: {len(fma_audio_files)}\")\n", - "if not fma_audio_files:\n", - " raise FileNotFoundError(\"No .mp3 files found in the fma directory. Check your dataset extraction.\")\n", + " output_dir = \"./fma_16k\"\n", + " if not os.path.exists(output_dir):\n", + " os.mkdir(output_dir)\n", "\n", - "print(\"Converting fma_xs files to 16kHz WAV...\")\n", - "for file_path in tqdm(fma_audio_files, desc=\"Processing fma_xs files\"):\n", - " try:\n", - " # Read the .mp3 file\n", - " data, samplerate = sf.read(file_path)\n", + " # Save clips to 16-bit PCM wav files\n", + " audio_files = list(Path(\"fma/fma_small\").glob(\"**/*.mp3\"))\n", + " print(f\"Number of MP3 files found: {len(audio_files)}\")\n", + " if not audio_files:\n", + " raise FileNotFoundError(\"No .mp3 files found in the fma directory. 
Check your dataset extraction.\")\n", "\n", - " # Check and resample if needed\n", - " if samplerate != 16000:\n", - " data = np.interp(\n", - " np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n", - " np.arange(len(data)),\n", - " data,\n", + " print(\"Converting fma_xs files to 16kHz WAV...\")\n", + " for file_path in tqdm(audio_files):\n", + " try:\n", + " # Read and convert the .mp3 file\n", + " data, samplerate = scipy.io.wavfile.read(file_path)\n", + " if samplerate != 16000:\n", + " data = np.interp(\n", + " np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n", + " np.arange(len(data)),\n", + " data,\n", + " )\n", + "\n", + " # Save as WAV\n", + " output_path = Path(output_dir) / file_path.name.replace(\".mp3\", \".wav\")\n", + " scipy.io.wavfile.write(\n", + " output_path,\n", + " 16000,\n", + " (data * 32767).astype(np.int16),\n", " )\n", + " except Exception as e:\n", + " print(f\"Error converting {file_path}: {e}\")\n", + " continue # Skip and proceed with the next file\n", "\n", - " # Convert and save as WAV\n", - " output_path = fma_processed_dir / file_path.name.replace(\".mp3\", \".wav\")\n", - " scipy.io.wavfile.write(\n", - " output_path,\n", - " 16000,\n", - " (data * 32767).astype(np.int16),\n", - " )\n", - " except Exception as e:\n", - " corrupted_files.append(str(file_path))\n", - "\n", - "# Log corrupted files from fma_xs\n", - "if corrupted_files:\n", - " with open(\"fma_corrupted_files.log\", \"w\") as log_file:\n", - " log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n", - "\n", - "print(\"fma_xs processing complete!\")\n", - "print(\"Full-scale dataset preparation complete!\")" + "print(\"All datasets processed successfully!\")" ] }, { @@ -712,7 +653,7 @@ "# Define the JSON metadata for the model\n", "json_data = {\n", " \"type\": \"micro\",\n", - " \"wake_word\": \"hey_norman\", # Adjust based on your target wake word\n", + " \"wake_word\": \"khum_puter\", # Adjust based on your target wake 
word\n", " \"author\": \"master phooey\",\n", " \"website\": \"https://github.com/MasterPhooey/MicroWakeWord-Trainer-Docker\",\n", " \"model\": \"stream_state_internal_quant.tflite\",\n",