Update advanced_training_notebook.ipynb

This commit is contained in:
MasterPhooey
2025-01-05 09:50:20 -06:00
committed by GitHub
parent f596a611bb
commit 07addfe2ce

View File

@@ -144,79 +144,28 @@
"# licenses and usage restrictions. As such, any custom models trained with this\n", "# licenses and usage restrictions. As such, any custom models trained with this\n",
"# data should be considered as appropriate for **non-commercial** personal use only.\n", "# data should be considered as appropriate for **non-commercial** personal use only.\n",
"\n", "\n",
"import datasets\n",
"import scipy\n",
"import os\n", "import os\n",
"import scipy.io.wavfile\n", "import requests\n",
"\n",
"import numpy as np\n", "import numpy as np\n",
"import soundfile as sf\n", "\n",
"from pathlib import Path\n", "from pathlib import Path\n",
"from tqdm import tqdm\n", "from tqdm import tqdm\n",
"import requests\n",
"import tarfile\n",
"import zipfile\n",
"from datasets import load_dataset\n",
"\n", "\n",
"# Function to download and process RIR dataset\n", "## Download MIR RIR data\n",
"# Function to download and process RIR dataset\n",
"def download_rir_dataset(dataset_name, output_dir, split=\"train\"):\n",
" output_dir = Path(output_dir)\n",
" if not output_dir.exists():\n",
" output_dir.mkdir(parents=True, exist_ok=True)\n",
" try:\n",
" rir_dataset = load_dataset(dataset_name, split=split, streaming=True)\n",
" print(f\"Downloading {dataset_name} to {output_dir}...\")\n",
" for row in tqdm(rir_dataset):\n",
" name = Path(row['audio']['path']).name\n",
" # Save the original audio file\n",
" with open(output_dir / name, \"wb\") as audio_file:\n",
" audio_file.write(row[\"audio\"][\"bytes\"])\n",
" print(f\"Finished downloading {dataset_name} to {output_dir}.\\n\")\n",
" except Exception as e:\n",
" print(f\"Error downloading {dataset_name}: {e}\")\n",
" else:\n",
" print(f\"{output_dir} already exists. Skipping download.\\n\")\n",
" \n",
"# Download MIT RIRs\n",
"download_rir_dataset(\n",
" \"davidscripka/MIT_environmental_impulse_responses\",\n",
" \"./mit_rirs\"\n",
")\n",
"\n", "\n",
"# Function to download files\n", "output_dir = \"./mit_rirs\"\n",
"def download_file(url, output_path):\n", "if not os.path.exists(output_dir):\n",
" response = requests.get(url, stream=True)\n", " os.mkdir(output_dir)\n",
" total_size = int(response.headers.get('content-length', 0))\n", " rir_dataset = datasets.load_dataset(\"davidscripka/MIT_environmental_impulse_responses\", split=\"train\", streaming=True)\n",
" with open(output_path, \"wb\") as f, tqdm(\n", " # Save clips to 16-bit PCM wav files\n",
" desc=f\"Downloading {output_path.name}\",\n", " for row in tqdm(rir_dataset):\n",
" total=total_size,\n", " name = row['audio']['path'].split('/')[-1]\n",
" unit=\"B\",\n", " scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array'] * 32767).astype(np.int16))\n",
" unit_scale=True,\n",
" unit_divisor=1024,\n",
" ) as bar:\n",
" for chunk in response.iter_content(chunk_size=1024):\n",
" f.write(chunk)\n",
" bar.update(len(chunk))\n",
" print(f\"Downloaded {output_path}\")\n",
"\n", "\n",
"# Function to extract .tar files\n", "## Download noise and background audio\n",
"def extract_tar(file_path, extract_dir):\n",
" with tarfile.open(file_path, \"r\") as tar:\n",
" tar.extractall(path=extract_dir)\n",
" print(f\"Extracted {file_path} to {extract_dir}\")\n",
"\n",
"# Function to download and extract ZIP files\n",
"def download_and_extract_zip(url, extract_to):\n",
" file_name = url.split(\"/\")[-1]\n",
" local_path = Path(extract_to) / file_name\n",
" download_file(url, local_path)\n",
" with zipfile.ZipFile(local_path, 'r') as zip_ref:\n",
" zip_ref.extractall(extract_to)\n",
" print(f\"Extracted {file_name} to {extract_to}\")\n",
"\n",
"# Directories\n",
"raw_dir = Path(\"./audioset_raw\")\n",
"processed_dir = Path(\"./audioset_16k\")\n",
"raw_dir.mkdir(exist_ok=True)\n",
"processed_dir.mkdir(exist_ok=True)\n",
"\n", "\n",
"# Full-scale dataset download links\n", "# Full-scale dataset download links\n",
"dataset_links = [\n", "dataset_links = [\n",
@@ -224,120 +173,112 @@
" for i in range(10) # Adjust for additional parts\n", " for i in range(10) # Adjust for additional parts\n",
"]\n", "]\n",
"\n", "\n",
"# Step 1: Download all parts of the dataset\n", "if not os.path.exists(\"audioset\"):\n",
"print(\"Downloading datasets...\")\n", " os.mkdir(\"audioset\")\n",
"for link in dataset_links:\n",
" file_name = link.split(\"/\")[-1]\n",
" output_path = raw_dir / file_name\n",
" if not output_path.exists():\n",
" download_file(link, output_path)\n",
"\n", "\n",
"# Step 2: Extract all .tar files\n", " for link in dataset_links:\n",
"print(\"Extracting datasets...\")\n", " fname = link.split(\"/\")[-1]\n",
"for file_path in raw_dir.glob(\"*.tar\"):\n", " out_dir = f\"audioset/{fname}\"\n",
" extract_dir = raw_dir / file_path.stem\n", " if not os.path.exists(out_dir):\n",
" if not extract_dir.exists():\n", " print(f\"Downloading {fname}...\")\n",
" extract_tar(file_path, extract_dir)\n", " response = requests.get(link, stream=True)\n",
" with open(out_dir, \"wb\") as f:\n",
" for chunk in response.iter_content(chunk_size=1024):\n",
" if chunk:\n",
" f.write(chunk)\n",
" print(f\"Downloaded {fname}\")\n",
"\n", "\n",
"# Step 3: Convert audio files to 16kHz WAV\n", " # Extract the tar file\n",
"audio_files = list(Path(raw_dir).glob(\"**/*.flac\"))\n", " print(f\"Extracting {fname}...\")\n",
"print(f\"Number of FLAC files found: {len(audio_files)}\")\n", " os.system(f\"cd audioset && tar -xf {fname}\")\n",
"if not audio_files:\n",
" raise FileNotFoundError(\"No .flac files found in the raw directories. Check your dataset extraction.\")\n",
"\n", "\n",
"print(\"Converting audio files to 16kHz WAV...\")\n", " output_dir = \"./audioset_16k\"\n",
"corrupted_files = []\n", " if not os.path.exists(output_dir):\n",
"resampled_files = []\n", " os.mkdir(output_dir)\n",
"\n", "\n",
"for file_path in tqdm(audio_files, desc=\"Processing audio files\"):\n", " # Save clips to 16-bit PCM wav files\n",
" try:\n", " audio_files = list(Path(\"audioset/audio\").glob(\"**/*.flac\"))\n",
" # Read the .flac file\n", " print(f\"Number of FLAC files found: {len(audio_files)}\")\n",
" data, samplerate = sf.read(file_path)\n", " if not audio_files:\n",
" raise FileNotFoundError(\"No .flac files found in the audioset directory. Check your dataset extraction.\")\n",
"\n", "\n",
" # Check and resample if needed\n", " print(\"Converting audioset files to 16kHz WAV...\")\n",
" if samplerate != 16000:\n", " for file_path in tqdm(audio_files):\n",
" resampled_files.append(str(file_path))\n", " try:\n",
" data = np.interp(\n", " # Read and convert the .flac file\n",
" np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n", " data, samplerate = scipy.io.wavfile.read(file_path)\n",
" np.arange(len(data)),\n", " if samplerate != 16000:\n",
" data,\n", " data = np.interp(\n",
" np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n",
" np.arange(len(data)),\n",
" data,\n",
" )\n",
"\n",
" # Save as WAV\n",
" output_path = Path(output_dir) / file_path.name.replace(\".flac\", \".wav\")\n",
" scipy.io.wavfile.write(\n",
" output_path,\n",
" 16000,\n",
" (data * 32767).astype(np.int16),\n",
" )\n", " )\n",
" except Exception as e:\n",
" print(f\"Error converting {file_path}: {e}\")\n",
" continue # Skip and proceed with the next file\n",
"\n", "\n",
" # Convert and save as WAV\n", "# Free Music Archive dataset\n",
" output_path = processed_dir / file_path.name.replace(\".flac\", \".wav\")\n", "# https://github.com/mdeff/fma\n",
" scipy.io.wavfile.write(\n", "# (Third-party mchl914 extra small set)\n",
" output_path,\n",
" 16000,\n",
" (data * 32767).astype(np.int16),\n",
" )\n",
" except Exception as e:\n",
" corrupted_files.append(str(file_path))\n",
"\n", "\n",
"# Log corrupted files\n", "output_dir = \"./fma\"\n",
"if corrupted_files:\n", "if not os.path.exists(output_dir):\n",
" with open(\"corrupted_files.log\", \"w\") as log_file:\n", " os.mkdir(output_dir)\n",
" log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n", " fname = \"fma_xs.zip\"\n",
"\n", " link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/\" + fname\n",
"# Log resampled files\n", " out_dir = f\"fma/{fname}\"\n",
"if resampled_files:\n",
" with open(\"resampled_files.log\", \"w\") as log_file:\n",
" log_file.writelines(f\"{file}\\n\" for file in resampled_files)\n",
"\n",
"print(f\"Audio conversion complete! {len(corrupted_files)} files corrupted and logged.\")\n",
"print(f\"{len(resampled_files)} files resampled and logged.\")\n",
"\n",
"# Process fma_xs dataset\n",
"fma_raw_dir = Path(\"./fma\")\n",
"fma_processed_dir = Path(\"./fma_16k\") # Separate directory for fma_xs processed files\n",
"fma_raw_dir.mkdir(exist_ok=True)\n",
"fma_processed_dir.mkdir(exist_ok=True)\n",
"\n",
"fma_link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/fma_xs.zip\"\n",
"fma_zip_path = fma_raw_dir / \"fma_xs.zip\"\n",
"\n",
"if not fma_zip_path.exists():\n",
" print(\"Downloading fma_xs dataset...\")\n", " print(\"Downloading fma_xs dataset...\")\n",
" download_file(fma_link, fma_zip_path)\n", " response = requests.get(link, stream=True)\n",
" with open(out_dir, \"wb\") as f:\n",
" for chunk in response.iter_content(chunk_size=1024):\n",
" if chunk:\n",
" f.write(chunk)\n",
" print(\"Downloaded fma_xs dataset.\")\n",
"\n", "\n",
"print(\"Extracting fma_xs dataset...\")\n", " os.system(f\"cd fma && unzip -q {fname}\")\n",
"download_and_extract_zip(fma_link, fma_raw_dir)\n",
"\n", "\n",
"fma_audio_files = list(fma_raw_dir.glob(\"**/*.mp3\"))\n", " output_dir = \"./fma_16k\"\n",
"print(f\"Number of MP3 files found: {len(fma_audio_files)}\")\n", " if not os.path.exists(output_dir):\n",
"if not fma_audio_files:\n", " os.mkdir(output_dir)\n",
" raise FileNotFoundError(\"No .mp3 files found in the fma directory. Check your dataset extraction.\")\n",
"\n", "\n",
"print(\"Converting fma_xs files to 16kHz WAV...\")\n", " # Save clips to 16-bit PCM wav files\n",
"for file_path in tqdm(fma_audio_files, desc=\"Processing fma_xs files\"):\n", " audio_files = list(Path(\"fma/fma_small\").glob(\"**/*.mp3\"))\n",
" try:\n", " print(f\"Number of MP3 files found: {len(audio_files)}\")\n",
" # Read the .mp3 file\n", " if not audio_files:\n",
" data, samplerate = sf.read(file_path)\n", " raise FileNotFoundError(\"No .mp3 files found in the fma directory. Check your dataset extraction.\")\n",
"\n", "\n",
" # Check and resample if needed\n", " print(\"Converting fma_xs files to 16kHz WAV...\")\n",
" if samplerate != 16000:\n", " for file_path in tqdm(audio_files):\n",
" data = np.interp(\n", " try:\n",
" np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n", " # Read and convert the .mp3 file\n",
" np.arange(len(data)),\n", " data, samplerate = scipy.io.wavfile.read(file_path)\n",
" data,\n", " if samplerate != 16000:\n",
" data = np.interp(\n",
" np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n",
" np.arange(len(data)),\n",
" data,\n",
" )\n",
"\n",
" # Save as WAV\n",
" output_path = Path(output_dir) / file_path.name.replace(\".mp3\", \".wav\")\n",
" scipy.io.wavfile.write(\n",
" output_path,\n",
" 16000,\n",
" (data * 32767).astype(np.int16),\n",
" )\n", " )\n",
" except Exception as e:\n",
" print(f\"Error converting {file_path}: {e}\")\n",
" continue # Skip and proceed with the next file\n",
"\n", "\n",
" # Convert and save as WAV\n", "print(\"All datasets processed successfully!\")"
" output_path = fma_processed_dir / file_path.name.replace(\".mp3\", \".wav\")\n",
" scipy.io.wavfile.write(\n",
" output_path,\n",
" 16000,\n",
" (data * 32767).astype(np.int16),\n",
" )\n",
" except Exception as e:\n",
" corrupted_files.append(str(file_path))\n",
"\n",
"# Log corrupted files from fma_xs\n",
"if corrupted_files:\n",
" with open(\"fma_corrupted_files.log\", \"w\") as log_file:\n",
" log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n",
"\n",
"print(\"fma_xs processing complete!\")\n",
"print(\"Full-scale dataset preparation complete!\")"
] ]
}, },
{ {
@@ -712,7 +653,7 @@
"# Define the JSON metadata for the model\n", "# Define the JSON metadata for the model\n",
"json_data = {\n", "json_data = {\n",
" \"type\": \"micro\",\n", " \"type\": \"micro\",\n",
" \"wake_word\": \"hey_norman\", # Adjust based on your target wake word\n", " \"wake_word\": \"khum_puter\", # Adjust based on your target wake word\n",
" \"author\": \"master phooey\",\n", " \"author\": \"master phooey\",\n",
" \"website\": \"https://github.com/MasterPhooey/MicroWakeWord-Trainer-Docker\",\n", " \"website\": \"https://github.com/MasterPhooey/MicroWakeWord-Trainer-Docker\",\n",
" \"model\": \"stream_state_internal_quant.tflite\",\n", " \"model\": \"stream_state_internal_quant.tflite\",\n",