Update advanced_training_notebook.ipynb

This commit is contained in:
MasterPhooey
2025-01-05 09:50:20 -06:00
committed by GitHub
parent f596a611bb
commit 07addfe2ce

View File

@@ -144,79 +144,28 @@
"# licenses and usage restrictions. As such, any custom models trained with this\n", "# licenses and usage restrictions. As such, any custom models trained with this\n",
"# data should be considered as appropriate for **non-commercial** personal use only.\n", "# data should be considered as appropriate for **non-commercial** personal use only.\n",
"\n", "\n",
"import datasets\n",
"import scipy\n",
"import os\n", "import os\n",
"import scipy.io.wavfile\n", "import requests\n",
"\n",
"import numpy as np\n", "import numpy as np\n",
"import soundfile as sf\n", "\n",
"from pathlib import Path\n", "from pathlib import Path\n",
"from tqdm import tqdm\n", "from tqdm import tqdm\n",
"import requests\n",
"import tarfile\n",
"import zipfile\n",
"from datasets import load_dataset\n",
"\n", "\n",
"# Function to download and process RIR dataset\n", "## Download MIT RIR data\n",
"# Function to download and process RIR dataset\n", "\n",
"def download_rir_dataset(dataset_name, output_dir, split=\"train\"):\n", "output_dir = \"./mit_rirs\"\n",
" output_dir = Path(output_dir)\n", "if not os.path.exists(output_dir):\n",
" if not output_dir.exists():\n", " os.mkdir(output_dir)\n",
" output_dir.mkdir(parents=True, exist_ok=True)\n", " rir_dataset = datasets.load_dataset(\"davidscripka/MIT_environmental_impulse_responses\", split=\"train\", streaming=True)\n",
" try:\n", " # Save clips to 16-bit PCM wav files\n",
" rir_dataset = load_dataset(dataset_name, split=split, streaming=True)\n",
" print(f\"Downloading {dataset_name} to {output_dir}...\")\n",
" for row in tqdm(rir_dataset):\n", " for row in tqdm(rir_dataset):\n",
" name = Path(row['audio']['path']).name\n", " name = row['audio']['path'].split('/')[-1]\n",
" # Save the original audio file\n", " scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array'] * 32767).astype(np.int16))\n",
" with open(output_dir / name, \"wb\") as audio_file:\n",
" audio_file.write(row[\"audio\"][\"bytes\"])\n",
" print(f\"Finished downloading {dataset_name} to {output_dir}.\\n\")\n",
" except Exception as e:\n",
" print(f\"Error downloading {dataset_name}: {e}\")\n",
" else:\n",
" print(f\"{output_dir} already exists. Skipping download.\\n\")\n",
"\n", "\n",
"# Download MIT RIRs\n", "## Download noise and background audio\n",
"download_rir_dataset(\n",
" \"davidscripka/MIT_environmental_impulse_responses\",\n",
" \"./mit_rirs\"\n",
")\n",
"\n",
"# Function to download files\n",
"def download_file(url, output_path):\n",
" response = requests.get(url, stream=True)\n",
" total_size = int(response.headers.get('content-length', 0))\n",
" with open(output_path, \"wb\") as f, tqdm(\n",
" desc=f\"Downloading {output_path.name}\",\n",
" total=total_size,\n",
" unit=\"B\",\n",
" unit_scale=True,\n",
" unit_divisor=1024,\n",
" ) as bar:\n",
" for chunk in response.iter_content(chunk_size=1024):\n",
" f.write(chunk)\n",
" bar.update(len(chunk))\n",
" print(f\"Downloaded {output_path}\")\n",
"\n",
"# Function to extract .tar files\n",
"def extract_tar(file_path, extract_dir):\n",
" with tarfile.open(file_path, \"r\") as tar:\n",
" tar.extractall(path=extract_dir)\n",
" print(f\"Extracted {file_path} to {extract_dir}\")\n",
"\n",
"# Function to download and extract ZIP files\n",
"def download_and_extract_zip(url, extract_to):\n",
" file_name = url.split(\"/\")[-1]\n",
" local_path = Path(extract_to) / file_name\n",
" download_file(url, local_path)\n",
" with zipfile.ZipFile(local_path, 'r') as zip_ref:\n",
" zip_ref.extractall(extract_to)\n",
" print(f\"Extracted {file_name} to {extract_to}\")\n",
"\n",
"# Directories\n",
"raw_dir = Path(\"./audioset_raw\")\n",
"processed_dir = Path(\"./audioset_16k\")\n",
"raw_dir.mkdir(exist_ok=True)\n",
"processed_dir.mkdir(exist_ok=True)\n",
"\n", "\n",
"# Full-scale dataset download links\n", "# Full-scale dataset download links\n",
"dataset_links = [\n", "dataset_links = [\n",
@@ -224,96 +173,93 @@
" for i in range(10) # Adjust for additional parts\n", " for i in range(10) # Adjust for additional parts\n",
"]\n", "]\n",
"\n", "\n",
"# Step 1: Download all parts of the dataset\n", "if not os.path.exists(\"audioset\"):\n",
"print(\"Downloading datasets...\")\n", " os.mkdir(\"audioset\")\n",
"\n",
" for link in dataset_links:\n", " for link in dataset_links:\n",
" file_name = link.split(\"/\")[-1]\n", " fname = link.split(\"/\")[-1]\n",
" output_path = raw_dir / file_name\n", " out_dir = f\"audioset/{fname}\"\n",
" if not output_path.exists():\n", " if not os.path.exists(out_dir):\n",
" download_file(link, output_path)\n", " print(f\"Downloading {fname}...\")\n",
" response = requests.get(link, stream=True)\n",
" with open(out_dir, \"wb\") as f:\n",
" for chunk in response.iter_content(chunk_size=1024):\n",
" if chunk:\n",
" f.write(chunk)\n",
" print(f\"Downloaded {fname}\")\n",
"\n", "\n",
"# Step 2: Extract all .tar files\n", " # Extract the tar file\n",
"print(\"Extracting datasets...\")\n", " print(f\"Extracting {fname}...\")\n",
"for file_path in raw_dir.glob(\"*.tar\"):\n", " os.system(f\"cd audioset && tar -xf {fname}\")\n",
" extract_dir = raw_dir / file_path.stem\n",
" if not extract_dir.exists():\n",
" extract_tar(file_path, extract_dir)\n",
"\n", "\n",
"# Step 3: Convert audio files to 16kHz WAV\n", " output_dir = \"./audioset_16k\"\n",
"audio_files = list(Path(raw_dir).glob(\"**/*.flac\"))\n", " if not os.path.exists(output_dir):\n",
" os.mkdir(output_dir)\n",
"\n",
" # Save clips to 16-bit PCM wav files\n",
" audio_files = list(Path(\"audioset/audio\").glob(\"**/*.flac\"))\n",
" print(f\"Number of FLAC files found: {len(audio_files)}\")\n", " print(f\"Number of FLAC files found: {len(audio_files)}\")\n",
" if not audio_files:\n", " if not audio_files:\n",
" raise FileNotFoundError(\"No .flac files found in the raw directories. Check your dataset extraction.\")\n", " raise FileNotFoundError(\"No .flac files found in the audioset directory. Check your dataset extraction.\")\n",
"\n", "\n",
"print(\"Converting audio files to 16kHz WAV...\")\n", " print(\"Converting audioset files to 16kHz WAV...\")\n",
"corrupted_files = []\n", " for file_path in tqdm(audio_files):\n",
"resampled_files = []\n",
"\n",
"for file_path in tqdm(audio_files, desc=\"Processing audio files\"):\n",
" try:\n", " try:\n",
" # Read the .flac file\n", " # Read and convert the .flac file\n",
" data, samplerate = sf.read(file_path)\n", " data, samplerate = scipy.io.wavfile.read(file_path)\n",
"\n",
" # Check and resample if needed\n",
" if samplerate != 16000:\n", " if samplerate != 16000:\n",
" resampled_files.append(str(file_path))\n",
" data = np.interp(\n", " data = np.interp(\n",
" np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n", " np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n",
" np.arange(len(data)),\n", " np.arange(len(data)),\n",
" data,\n", " data,\n",
" )\n", " )\n",
"\n", "\n",
" # Convert and save as WAV\n", " # Save as WAV\n",
" output_path = processed_dir / file_path.name.replace(\".flac\", \".wav\")\n", " output_path = Path(output_dir) / file_path.name.replace(\".flac\", \".wav\")\n",
" scipy.io.wavfile.write(\n", " scipy.io.wavfile.write(\n",
" output_path,\n", " output_path,\n",
" 16000,\n", " 16000,\n",
" (data * 32767).astype(np.int16),\n", " (data * 32767).astype(np.int16),\n",
" )\n", " )\n",
" except Exception as e:\n", " except Exception as e:\n",
" corrupted_files.append(str(file_path))\n", " print(f\"Error converting {file_path}: {e}\")\n",
" continue # Skip and proceed with the next file\n",
"\n", "\n",
"# Log corrupted files\n", "# Free Music Archive dataset\n",
"if corrupted_files:\n", "# https://github.com/mdeff/fma\n",
" with open(\"corrupted_files.log\", \"w\") as log_file:\n", "# (Third-party mchl914 extra small set)\n",
" log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n",
"\n", "\n",
"# Log resampled files\n", "output_dir = \"./fma\"\n",
"if resampled_files:\n", "if not os.path.exists(output_dir):\n",
" with open(\"resampled_files.log\", \"w\") as log_file:\n", " os.mkdir(output_dir)\n",
" log_file.writelines(f\"{file}\\n\" for file in resampled_files)\n", " fname = \"fma_xs.zip\"\n",
"\n", " link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/\" + fname\n",
"print(f\"Audio conversion complete! {len(corrupted_files)} files corrupted and logged.\")\n", " out_dir = f\"fma/{fname}\"\n",
"print(f\"{len(resampled_files)} files resampled and logged.\")\n",
"\n",
"# Process fma_xs dataset\n",
"fma_raw_dir = Path(\"./fma\")\n",
"fma_processed_dir = Path(\"./fma_16k\") # Separate directory for fma_xs processed files\n",
"fma_raw_dir.mkdir(exist_ok=True)\n",
"fma_processed_dir.mkdir(exist_ok=True)\n",
"\n",
"fma_link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/fma_xs.zip\"\n",
"fma_zip_path = fma_raw_dir / \"fma_xs.zip\"\n",
"\n",
"if not fma_zip_path.exists():\n",
" print(\"Downloading fma_xs dataset...\")\n", " print(\"Downloading fma_xs dataset...\")\n",
" download_file(fma_link, fma_zip_path)\n", " response = requests.get(link, stream=True)\n",
" with open(out_dir, \"wb\") as f:\n",
" for chunk in response.iter_content(chunk_size=1024):\n",
" if chunk:\n",
" f.write(chunk)\n",
" print(\"Downloaded fma_xs dataset.\")\n",
"\n", "\n",
"print(\"Extracting fma_xs dataset...\")\n", " os.system(f\"cd fma && unzip -q {fname}\")\n",
"download_and_extract_zip(fma_link, fma_raw_dir)\n",
"\n", "\n",
"fma_audio_files = list(fma_raw_dir.glob(\"**/*.mp3\"))\n", " output_dir = \"./fma_16k\"\n",
"print(f\"Number of MP3 files found: {len(fma_audio_files)}\")\n", " if not os.path.exists(output_dir):\n",
"if not fma_audio_files:\n", " os.mkdir(output_dir)\n",
"\n",
" # Save clips to 16-bit PCM wav files\n",
" audio_files = list(Path(\"fma/fma_small\").glob(\"**/*.mp3\"))\n",
" print(f\"Number of MP3 files found: {len(audio_files)}\")\n",
" if not audio_files:\n",
" raise FileNotFoundError(\"No .mp3 files found in the fma directory. Check your dataset extraction.\")\n", " raise FileNotFoundError(\"No .mp3 files found in the fma directory. Check your dataset extraction.\")\n",
"\n", "\n",
" print(\"Converting fma_xs files to 16kHz WAV...\")\n", " print(\"Converting fma_xs files to 16kHz WAV...\")\n",
"for file_path in tqdm(fma_audio_files, desc=\"Processing fma_xs files\"):\n", " for file_path in tqdm(audio_files):\n",
" try:\n", " try:\n",
" # Read the .mp3 file\n", " # Read and convert the .mp3 file\n",
" data, samplerate = sf.read(file_path)\n", " data, samplerate = scipy.io.wavfile.read(file_path)\n",
"\n",
" # Check and resample if needed\n",
" if samplerate != 16000:\n", " if samplerate != 16000:\n",
" data = np.interp(\n", " data = np.interp(\n",
" np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n", " np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n",
@@ -321,23 +267,18 @@
" data,\n", " data,\n",
" )\n", " )\n",
"\n", "\n",
" # Convert and save as WAV\n", " # Save as WAV\n",
" output_path = fma_processed_dir / file_path.name.replace(\".mp3\", \".wav\")\n", " output_path = Path(output_dir) / file_path.name.replace(\".mp3\", \".wav\")\n",
" scipy.io.wavfile.write(\n", " scipy.io.wavfile.write(\n",
" output_path,\n", " output_path,\n",
" 16000,\n", " 16000,\n",
" (data * 32767).astype(np.int16),\n", " (data * 32767).astype(np.int16),\n",
" )\n", " )\n",
" except Exception as e:\n", " except Exception as e:\n",
" corrupted_files.append(str(file_path))\n", " print(f\"Error converting {file_path}: {e}\")\n",
" continue # Skip and proceed with the next file\n",
"\n", "\n",
"# Log corrupted files from fma_xs\n", "print(\"All datasets processed successfully!\")"
"if corrupted_files:\n",
" with open(\"fma_corrupted_files.log\", \"w\") as log_file:\n",
" log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n",
"\n",
"print(\"fma_xs processing complete!\")\n",
"print(\"Full-scale dataset preparation complete!\")"
] ]
}, },
{ {
@@ -712,7 +653,7 @@
"# Define the JSON metadata for the model\n", "# Define the JSON metadata for the model\n",
"json_data = {\n", "json_data = {\n",
" \"type\": \"micro\",\n", " \"type\": \"micro\",\n",
" \"wake_word\": \"hey_norman\", # Adjust based on your target wake word\n", " \"wake_word\": \"khum_puter\", # Adjust based on your target wake word\n",
" \"author\": \"master phooey\",\n", " \"author\": \"master phooey\",\n",
" \"website\": \"https://github.com/MasterPhooey/MicroWakeWord-Trainer-Docker\",\n", " \"website\": \"https://github.com/MasterPhooey/MicroWakeWord-Trainer-Docker\",\n",
" \"model\": \"stream_state_internal_quant.tflite\",\n", " \"model\": \"stream_state_internal_quant.tflite\",\n",