Update advanced_training_notebook.ipynb

This commit is contained in:
MasterPhooey
2025-01-05 09:50:20 -06:00
committed by GitHub
parent f596a611bb
commit 07addfe2ce

View File

@@ -144,79 +144,28 @@
"# licenses and usage restrictions. As such, any custom models trained with this\n",
"# data should be considered as appropriate for **non-commercial** personal use only.\n",
"\n",
"import datasets\n",
"import scipy\n",
"import os\n",
"import scipy.io.wavfile\n",
"import requests\n",
"\n",
"import numpy as np\n",
"import soundfile as sf\n",
"\n",
"from pathlib import Path\n",
"from tqdm import tqdm\n",
"import requests\n",
"import tarfile\n",
"import zipfile\n",
"from datasets import load_dataset\n",
"\n",
"# Function to download and process RIR dataset\n",
"# Function to download and process RIR dataset\n",
"def download_rir_dataset(dataset_name, output_dir, split=\"train\"):\n",
" output_dir = Path(output_dir)\n",
" if not output_dir.exists():\n",
" output_dir.mkdir(parents=True, exist_ok=True)\n",
" try:\n",
" rir_dataset = load_dataset(dataset_name, split=split, streaming=True)\n",
" print(f\"Downloading {dataset_name} to {output_dir}...\")\n",
" for row in tqdm(rir_dataset):\n",
" name = Path(row['audio']['path']).name\n",
" # Save the original audio file\n",
" with open(output_dir / name, \"wb\") as audio_file:\n",
" audio_file.write(row[\"audio\"][\"bytes\"])\n",
" print(f\"Finished downloading {dataset_name} to {output_dir}.\\n\")\n",
" except Exception as e:\n",
" print(f\"Error downloading {dataset_name}: {e}\")\n",
" else:\n",
" print(f\"{output_dir} already exists. Skipping download.\\n\")\n",
" \n",
"# Download MIT RIRs\n",
"download_rir_dataset(\n",
" \"davidscripka/MIT_environmental_impulse_responses\",\n",
" \"./mit_rirs\"\n",
")\n",
"## Download MIR RIR data\n",
"\n",
"# Function to download files\n",
"def download_file(url, output_path):\n",
" response = requests.get(url, stream=True)\n",
" total_size = int(response.headers.get('content-length', 0))\n",
" with open(output_path, \"wb\") as f, tqdm(\n",
" desc=f\"Downloading {output_path.name}\",\n",
" total=total_size,\n",
" unit=\"B\",\n",
" unit_scale=True,\n",
" unit_divisor=1024,\n",
" ) as bar:\n",
" for chunk in response.iter_content(chunk_size=1024):\n",
" f.write(chunk)\n",
" bar.update(len(chunk))\n",
" print(f\"Downloaded {output_path}\")\n",
"output_dir = \"./mit_rirs\"\n",
"if not os.path.exists(output_dir):\n",
" os.mkdir(output_dir)\n",
" rir_dataset = datasets.load_dataset(\"davidscripka/MIT_environmental_impulse_responses\", split=\"train\", streaming=True)\n",
" # Save clips to 16-bit PCM wav files\n",
" for row in tqdm(rir_dataset):\n",
" name = row['audio']['path'].split('/')[-1]\n",
" scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array'] * 32767).astype(np.int16))\n",
"\n",
"# Function to extract .tar files\n",
"def extract_tar(file_path, extract_dir):\n",
" with tarfile.open(file_path, \"r\") as tar:\n",
" tar.extractall(path=extract_dir)\n",
" print(f\"Extracted {file_path} to {extract_dir}\")\n",
"\n",
"# Function to download and extract ZIP files\n",
"def download_and_extract_zip(url, extract_to):\n",
" file_name = url.split(\"/\")[-1]\n",
" local_path = Path(extract_to) / file_name\n",
" download_file(url, local_path)\n",
" with zipfile.ZipFile(local_path, 'r') as zip_ref:\n",
" zip_ref.extractall(extract_to)\n",
" print(f\"Extracted {file_name} to {extract_to}\")\n",
"\n",
"# Directories\n",
"raw_dir = Path(\"./audioset_raw\")\n",
"processed_dir = Path(\"./audioset_16k\")\n",
"raw_dir.mkdir(exist_ok=True)\n",
"processed_dir.mkdir(exist_ok=True)\n",
"## Download noise and background audio\n",
"\n",
"# Full-scale dataset download links\n",
"dataset_links = [\n",
@@ -224,120 +173,112 @@
" for i in range(10) # Adjust for additional parts\n",
"]\n",
"\n",
"# Step 1: Download all parts of the dataset\n",
"print(\"Downloading datasets...\")\n",
"for link in dataset_links:\n",
" file_name = link.split(\"/\")[-1]\n",
" output_path = raw_dir / file_name\n",
" if not output_path.exists():\n",
" download_file(link, output_path)\n",
"if not os.path.exists(\"audioset\"):\n",
" os.mkdir(\"audioset\")\n",
"\n",
"# Step 2: Extract all .tar files\n",
"print(\"Extracting datasets...\")\n",
"for file_path in raw_dir.glob(\"*.tar\"):\n",
" extract_dir = raw_dir / file_path.stem\n",
" if not extract_dir.exists():\n",
" extract_tar(file_path, extract_dir)\n",
" for link in dataset_links:\n",
" fname = link.split(\"/\")[-1]\n",
" out_dir = f\"audioset/{fname}\"\n",
" if not os.path.exists(out_dir):\n",
" print(f\"Downloading {fname}...\")\n",
" response = requests.get(link, stream=True)\n",
" with open(out_dir, \"wb\") as f:\n",
" for chunk in response.iter_content(chunk_size=1024):\n",
" if chunk:\n",
" f.write(chunk)\n",
" print(f\"Downloaded {fname}\")\n",
"\n",
"# Step 3: Convert audio files to 16kHz WAV\n",
"audio_files = list(Path(raw_dir).glob(\"**/*.flac\"))\n",
"print(f\"Number of FLAC files found: {len(audio_files)}\")\n",
"if not audio_files:\n",
" raise FileNotFoundError(\"No .flac files found in the raw directories. Check your dataset extraction.\")\n",
" # Extract the tar file\n",
" print(f\"Extracting {fname}...\")\n",
" os.system(f\"cd audioset && tar -xf {fname}\")\n",
"\n",
"print(\"Converting audio files to 16kHz WAV...\")\n",
"corrupted_files = []\n",
"resampled_files = []\n",
" output_dir = \"./audioset_16k\"\n",
" if not os.path.exists(output_dir):\n",
" os.mkdir(output_dir)\n",
"\n",
"for file_path in tqdm(audio_files, desc=\"Processing audio files\"):\n",
" try:\n",
" # Read the .flac file\n",
" data, samplerate = sf.read(file_path)\n",
" # Save clips to 16-bit PCM wav files\n",
" audio_files = list(Path(\"audioset/audio\").glob(\"**/*.flac\"))\n",
" print(f\"Number of FLAC files found: {len(audio_files)}\")\n",
" if not audio_files:\n",
" raise FileNotFoundError(\"No .flac files found in the audioset directory. Check your dataset extraction.\")\n",
"\n",
" # Check and resample if needed\n",
" if samplerate != 16000:\n",
" resampled_files.append(str(file_path))\n",
" data = np.interp(\n",
" np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n",
" np.arange(len(data)),\n",
" data,\n",
" print(\"Converting audioset files to 16kHz WAV...\")\n",
" for file_path in tqdm(audio_files):\n",
" try:\n",
" # Read and convert the .flac file\n",
" data, samplerate = scipy.io.wavfile.read(file_path)\n",
" if samplerate != 16000:\n",
" data = np.interp(\n",
" np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n",
" np.arange(len(data)),\n",
" data,\n",
" )\n",
"\n",
" # Save as WAV\n",
" output_path = Path(output_dir) / file_path.name.replace(\".flac\", \".wav\")\n",
" scipy.io.wavfile.write(\n",
" output_path,\n",
" 16000,\n",
" (data * 32767).astype(np.int16),\n",
" )\n",
" except Exception as e:\n",
" print(f\"Error converting {file_path}: {e}\")\n",
" continue # Skip and proceed with the next file\n",
"\n",
" # Convert and save as WAV\n",
" output_path = processed_dir / file_path.name.replace(\".flac\", \".wav\")\n",
" scipy.io.wavfile.write(\n",
" output_path,\n",
" 16000,\n",
" (data * 32767).astype(np.int16),\n",
" )\n",
" except Exception as e:\n",
" corrupted_files.append(str(file_path))\n",
"# Free Music Archive dataset\n",
"# https://github.com/mdeff/fma\n",
"# (Third-party mchl914 extra small set)\n",
"\n",
"# Log corrupted files\n",
"if corrupted_files:\n",
" with open(\"corrupted_files.log\", \"w\") as log_file:\n",
" log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n",
"\n",
"# Log resampled files\n",
"if resampled_files:\n",
" with open(\"resampled_files.log\", \"w\") as log_file:\n",
" log_file.writelines(f\"{file}\\n\" for file in resampled_files)\n",
"\n",
"print(f\"Audio conversion complete! {len(corrupted_files)} files corrupted and logged.\")\n",
"print(f\"{len(resampled_files)} files resampled and logged.\")\n",
"\n",
"# Process fma_xs dataset\n",
"fma_raw_dir = Path(\"./fma\")\n",
"fma_processed_dir = Path(\"./fma_16k\") # Separate directory for fma_xs processed files\n",
"fma_raw_dir.mkdir(exist_ok=True)\n",
"fma_processed_dir.mkdir(exist_ok=True)\n",
"\n",
"fma_link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/fma_xs.zip\"\n",
"fma_zip_path = fma_raw_dir / \"fma_xs.zip\"\n",
"\n",
"if not fma_zip_path.exists():\n",
"output_dir = \"./fma\"\n",
"if not os.path.exists(output_dir):\n",
" os.mkdir(output_dir)\n",
" fname = \"fma_xs.zip\"\n",
" link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/\" + fname\n",
" out_dir = f\"fma/{fname}\"\n",
" print(\"Downloading fma_xs dataset...\")\n",
" download_file(fma_link, fma_zip_path)\n",
" response = requests.get(link, stream=True)\n",
" with open(out_dir, \"wb\") as f:\n",
" for chunk in response.iter_content(chunk_size=1024):\n",
" if chunk:\n",
" f.write(chunk)\n",
" print(\"Downloaded fma_xs dataset.\")\n",
"\n",
"print(\"Extracting fma_xs dataset...\")\n",
"download_and_extract_zip(fma_link, fma_raw_dir)\n",
" os.system(f\"cd fma && unzip -q {fname}\")\n",
"\n",
"fma_audio_files = list(fma_raw_dir.glob(\"**/*.mp3\"))\n",
"print(f\"Number of MP3 files found: {len(fma_audio_files)}\")\n",
"if not fma_audio_files:\n",
" raise FileNotFoundError(\"No .mp3 files found in the fma directory. Check your dataset extraction.\")\n",
" output_dir = \"./fma_16k\"\n",
" if not os.path.exists(output_dir):\n",
" os.mkdir(output_dir)\n",
"\n",
"print(\"Converting fma_xs files to 16kHz WAV...\")\n",
"for file_path in tqdm(fma_audio_files, desc=\"Processing fma_xs files\"):\n",
" try:\n",
" # Read the .mp3 file\n",
" data, samplerate = sf.read(file_path)\n",
" # Save clips to 16-bit PCM wav files\n",
" audio_files = list(Path(\"fma/fma_small\").glob(\"**/*.mp3\"))\n",
" print(f\"Number of MP3 files found: {len(audio_files)}\")\n",
" if not audio_files:\n",
" raise FileNotFoundError(\"No .mp3 files found in the fma directory. Check your dataset extraction.\")\n",
"\n",
" # Check and resample if needed\n",
" if samplerate != 16000:\n",
" data = np.interp(\n",
" np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n",
" np.arange(len(data)),\n",
" data,\n",
" print(\"Converting fma_xs files to 16kHz WAV...\")\n",
" for file_path in tqdm(audio_files):\n",
" try:\n",
" # Read and convert the .mp3 file\n",
" data, samplerate = scipy.io.wavfile.read(file_path)\n",
" if samplerate != 16000:\n",
" data = np.interp(\n",
" np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n",
" np.arange(len(data)),\n",
" data,\n",
" )\n",
"\n",
" # Save as WAV\n",
" output_path = Path(output_dir) / file_path.name.replace(\".mp3\", \".wav\")\n",
" scipy.io.wavfile.write(\n",
" output_path,\n",
" 16000,\n",
" (data * 32767).astype(np.int16),\n",
" )\n",
" except Exception as e:\n",
" print(f\"Error converting {file_path}: {e}\")\n",
" continue # Skip and proceed with the next file\n",
"\n",
" # Convert and save as WAV\n",
" output_path = fma_processed_dir / file_path.name.replace(\".mp3\", \".wav\")\n",
" scipy.io.wavfile.write(\n",
" output_path,\n",
" 16000,\n",
" (data * 32767).astype(np.int16),\n",
" )\n",
" except Exception as e:\n",
" corrupted_files.append(str(file_path))\n",
"\n",
"# Log corrupted files from fma_xs\n",
"if corrupted_files:\n",
" with open(\"fma_corrupted_files.log\", \"w\") as log_file:\n",
" log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n",
"\n",
"print(\"fma_xs processing complete!\")\n",
"print(\"Full-scale dataset preparation complete!\")"
"print(\"All datasets processed successfully!\")"
]
},
{
@@ -712,7 +653,7 @@
"# Define the JSON metadata for the model\n",
"json_data = {\n",
" \"type\": \"micro\",\n",
" \"wake_word\": \"hey_norman\", # Adjust based on your target wake word\n",
" \"wake_word\": \"khum_puter\", # Adjust based on your target wake word\n",
" \"author\": \"master phooey\",\n",
" \"website\": \"https://github.com/MasterPhooey/MicroWakeWord-Trainer-Docker\",\n",
" \"model\": \"stream_state_internal_quant.tflite\",\n",