Add files via upload

This commit is contained in:
MasterPhooey
2025-01-04 21:10:44 -06:00
committed by GitHub
parent d25bfbbea9
commit 43812d5db6

View File

@@ -11,8 +11,6 @@
" <h1>MicroWakeWord Trainer Docker</h1>\n",
"</div>\n",
"\n",
"# Training a microWakeWord Model\n",
"\n",
"This notebook steps you through training a robust microWakeWord model. It is intended as a **starting point** for users looking to create a high-performance wake word detection model. This notebook is optimized for Python 3.10.\n",
"\n",
"**The model generated from this notebook is designed for practical use, but achieving optimal performance will require experimentation with various settings and datasets. The provided scripts and configurations aim to give you a strong foundation to build upon.**\n",
@@ -90,7 +88,7 @@
" !wget -O piper-sample-generator/models/en_US-libritts_r-medium.pt 'https://github.com/rhasspy/piper-sample-generator/releases/download/v2.0.0/en_US-libritts_r-medium.pt'\n",
"\n",
"# Install system dependencies\n",
"!\"{sys.executable}\" -m pip install torch torchaudio piper-phonemize-cross==1.2.1 --root-user-action=ignore\n",
"!\"{sys.executable}\" -m pip install torch torchaudio piper-phonemize-cross==1.2.1\n",
"\n",
"# Ensure the repository path is in sys.path\n",
"if \"piper-sample-generator/\" not in sys.path:\n",
@@ -139,23 +137,65 @@
},
"outputs": [],
"source": [
"# Downloads audio data for augmentation. This can be slow!\n",
"# Borrowed from openWakeWord's automatic_model_training.ipynb, accessed March 4, 2024\n",
"#\n",
"# **Important note!** The data downloaded here has a mixture of difference\n",
"# licenses and usage restrictions. As such, any custom models trained with this\n",
"# data should be considered as appropriate for **non-commercial** personal use only.\n",
"\n",
"import os\n",
"import scipy.io.wavfile\n",
"import numpy as np\n",
"from datasets import load_dataset, Audio\n",
"import soundfile as sf\n",
"from pathlib import Path\n",
"from tqdm import tqdm\n",
"from multiprocessing import Pool\n",
"import requests\n",
"import tarfile\n",
"import zipfile\n",
"from datasets import load_dataset\n",
"\n",
"# Function to download and process RIR dataset\n",
"def download_rir_dataset(dataset_name, output_dir, split=\"train\"):\n",
" output_dir = Path(output_dir)\n",
" if not output_dir.exists():\n",
" output_dir.mkdir(parents=True, exist_ok=True)\n",
" try:\n",
" rir_dataset = load_dataset(dataset_name, split=split, streaming=True)\n",
" print(f\"Downloading {dataset_name} to {output_dir}...\")\n",
" for row in tqdm(rir_dataset):\n",
" name = Path(row['audio']['path']).name\n",
" scipy.io.wavfile.write(\n",
" output_dir / name,\n",
" 16000,\n",
" (row['audio']['array'] * 32767).astype(np.int16)\n",
" )\n",
" print(f\"Finished downloading {dataset_name} to {output_dir}.\\n\")\n",
" except Exception as e:\n",
" print(f\"Error downloading {dataset_name}: {e}\")\n",
" else:\n",
" print(f\"{output_dir} already exists. Skipping download.\\n\")\n",
"\n",
"# Download MIT RIRs\n",
"download_rir_dataset(\n",
" \"davidscripka/MIT_environmental_impulse_responses\",\n",
" \"./mit_rirs\"\n",
")\n",
"\n",
"# Function to download files\n",
"def download_file(url, output_path):\n",
" response = requests.get(url, stream=True)\n",
" with open(output_path, \"wb\") as f:\n",
" total_size = int(response.headers.get('content-length', 0))\n",
" with open(output_path, \"wb\") as f, tqdm(\n",
" desc=f\"Downloading {output_path.name}\",\n",
" total=total_size,\n",
" unit=\"B\",\n",
" unit_scale=True,\n",
" unit_divisor=1024,\n",
" ) as bar:\n",
" for chunk in response.iter_content(chunk_size=1024):\n",
" f.write(chunk)\n",
" bar.update(len(chunk))\n",
" print(f\"Downloaded {output_path}\")\n",
"\n",
"# Function to extract .tar files\n",
@@ -173,14 +213,6 @@
" zip_ref.extractall(extract_to)\n",
" print(f\"Extracted {file_name} to {extract_to}\")\n",
"\n",
"# Function to convert audio files to 16kHz WAV format\n",
"def convert_audio(file_path, output_dir):\n",
" output_file = output_dir / file_path.name.replace(\".flac\", \".wav\")\n",
" audio = Audio(sampling_rate=16000).decode_example({\"path\": str(file_path)})\n",
" scipy.io.wavfile.write(\n",
" output_file, 16000, (audio[\"array\"] * 32767).astype(np.int16)\n",
" )\n",
"\n",
"# Directories\n",
"raw_dir = Path(\"./audioset_raw\")\n",
"processed_dir = Path(\"./audioset_16k\")\n",
@@ -194,46 +226,119 @@
"]\n",
"\n",
"# Step 1: Download all parts of the dataset\n",
"print(\"Downloading datasets...\")\n",
"for link in dataset_links:\n",
" file_name = link.split(\"/\")[-1]\n",
" output_path = raw_dir / file_name\n",
" if not output_path.exists():\n",
" print(f\"Downloading {file_name}...\")\n",
" download_file(link, output_path)\n",
"\n",
"# Step 2: Extract all .tar files\n",
"print(\"Extracting datasets...\")\n",
"for file_path in raw_dir.glob(\"*.tar\"):\n",
" extract_dir = raw_dir / file_path.stem\n",
" if not extract_dir.exists():\n",
" print(f\"Extracting {file_path}...\")\n",
" extract_tar(file_path, extract_dir)\n",
"\n",
"# Step 3: Convert audio files to 16kHz WAV\n",
"audio_files = list(Path(raw_dir).glob(\"**/*.flac\"))\n",
"\n",
"def process_audio(file_path):\n",
" convert_audio(file_path, processed_dir)\n",
"print(f\"Number of FLAC files found: {len(audio_files)}\")\n",
"if not audio_files:\n",
" raise FileNotFoundError(\"No .flac files found in the raw directories. Check your dataset extraction.\")\n",
"\n",
"print(\"Converting audio files to 16kHz WAV...\")\n",
"with Pool() as pool:\n",
" list(tqdm(pool.imap(process_audio, audio_files), total=len(audio_files)))\n",
"corrupted_files = []\n",
"resampled_files = []\n",
"\n",
"# Optional: Download and process additional datasets\n",
"additional_datasets = {\n",
" \"fsd50k\": \"https://zenodo.org/record/4060432/files/FSD50K.dev_audio.zip\",\n",
" \"fma_xs\": \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/fma_xs.zip\",\n",
"}\n",
"\n",
"for dataset_name, link in additional_datasets.items():\n",
" dataset_dir = Path(f\"./{dataset_name}\")\n",
" dataset_dir.mkdir(exist_ok=True)\n",
"for file_path in tqdm(audio_files, desc=\"Processing audio files\"):\n",
" try:\n",
" download_and_extract_zip(link, dataset_dir)\n",
" # Add specific processing logic for each dataset if required\n",
" except Exception as e:\n",
" print(f\"Error processing {dataset_name}: {e}\")\n",
" # Read the .flac file\n",
" data, samplerate = sf.read(file_path)\n",
"\n",
"print(\"Full-scale dataset preparation complete!\")\n"
" # Check and resample if needed\n",
" if samplerate != 16000:\n",
" resampled_files.append(str(file_path))\n",
" data = np.interp(\n",
" np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n",
" np.arange(len(data)),\n",
" data,\n",
" )\n",
"\n",
" # Convert and save as WAV\n",
" output_path = processed_dir / file_path.name.replace(\".flac\", \".wav\")\n",
" scipy.io.wavfile.write(\n",
" output_path,\n",
" 16000,\n",
" (data * 32767).astype(np.int16),\n",
" )\n",
" except Exception as e:\n",
" corrupted_files.append(str(file_path))\n",
"\n",
"# Log corrupted files\n",
"if corrupted_files:\n",
" with open(\"corrupted_files.log\", \"w\") as log_file:\n",
" log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n",
"\n",
"# Log resampled files\n",
"if resampled_files:\n",
" with open(\"resampled_files.log\", \"w\") as log_file:\n",
" log_file.writelines(f\"{file}\\n\" for file in resampled_files)\n",
"\n",
"print(f\"Audio conversion complete! {len(corrupted_files)} files corrupted and logged.\")\n",
"print(f\"{len(resampled_files)} files resampled and logged.\")\n",
"\n",
"# Process fma_xs dataset\n",
"fma_raw_dir = Path(\"./fma\")\n",
"fma_processed_dir = Path(\"./fma_16k\") # Separate directory for fma_xs processed files\n",
"fma_raw_dir.mkdir(exist_ok=True)\n",
"fma_processed_dir.mkdir(exist_ok=True)\n",
"\n",
"fma_link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/fma_xs.zip\"\n",
"fma_zip_path = fma_raw_dir / \"fma_xs.zip\"\n",
"\n",
"if not fma_zip_path.exists():\n",
" print(\"Downloading fma_xs dataset...\")\n",
" download_file(fma_link, fma_zip_path)\n",
"\n",
"print(\"Extracting fma_xs dataset...\")\n",
"download_and_extract_zip(fma_link, fma_raw_dir)\n",
"\n",
"fma_audio_files = list(fma_raw_dir.glob(\"**/*.mp3\"))\n",
"print(f\"Number of MP3 files found: {len(fma_audio_files)}\")\n",
"if not fma_audio_files:\n",
" raise FileNotFoundError(\"No .mp3 files found in the fma directory. Check your dataset extraction.\")\n",
"\n",
"print(\"Converting fma_xs files to 16kHz WAV...\")\n",
"for file_path in tqdm(fma_audio_files, desc=\"Processing fma_xs files\"):\n",
" try:\n",
" # Read the .mp3 file\n",
" data, samplerate = sf.read(file_path)\n",
"\n",
" # Check and resample if needed\n",
" if samplerate != 16000:\n",
" data = np.interp(\n",
" np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n",
" np.arange(len(data)),\n",
" data,\n",
" )\n",
"\n",
" # Convert and save as WAV\n",
" output_path = fma_processed_dir / file_path.name.replace(\".mp3\", \".wav\")\n",
" scipy.io.wavfile.write(\n",
" output_path,\n",
" 16000,\n",
" (data * 32767).astype(np.int16),\n",
" )\n",
" except Exception as e:\n",
" corrupted_files.append(str(file_path))\n",
"\n",
"# Log corrupted files from fma_xs\n",
"if corrupted_files:\n",
" with open(\"fma_corrupted_files.log\", \"w\") as log_file:\n",
" log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n",
"\n",
"print(\"fma_xs processing complete!\")\n",
"print(\"Full-scale dataset preparation complete!\")"
]
},
{
@@ -261,7 +366,7 @@
" return True\n",
"\n",
"# Paths to augmented data\n",
"impulse_paths = ['mit_rirs', 'openair_rirs']\n",
"impulse_paths = ['mit_rirs']\n",
"background_paths = ['fma_16k', 'audioset_16k']\n",
"\n",
"if not validate_directories(impulse_paths + background_paths):\n",