mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
Update advanced_training_notebook.ipynb
This commit is contained in:
@@ -144,79 +144,28 @@
|
|||||||
"# licenses and usage restrictions. As such, any custom models trained with this\n",
|
"# licenses and usage restrictions. As such, any custom models trained with this\n",
|
||||||
"# data should be considered as appropriate for **non-commercial** personal use only.\n",
|
"# data should be considered as appropriate for **non-commercial** personal use only.\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"import datasets\n",
|
||||||
|
"import scipy\n",
|
||||||
"import os\n",
|
"import os\n",
|
||||||
"import scipy.io.wavfile\n",
|
"import requests\n",
|
||||||
|
"\n",
|
||||||
"import numpy as np\n",
|
"import numpy as np\n",
|
||||||
"import soundfile as sf\n",
|
"\n",
|
||||||
"from pathlib import Path\n",
|
"from pathlib import Path\n",
|
||||||
"from tqdm import tqdm\n",
|
"from tqdm import tqdm\n",
|
||||||
"import requests\n",
|
|
||||||
"import tarfile\n",
|
|
||||||
"import zipfile\n",
|
|
||||||
"from datasets import load_dataset\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"# Function to download and process RIR dataset\n",
|
"## Download MIR RIR data\n",
|
||||||
"# Function to download and process RIR dataset\n",
|
|
||||||
"def download_rir_dataset(dataset_name, output_dir, split=\"train\"):\n",
|
|
||||||
" output_dir = Path(output_dir)\n",
|
|
||||||
" if not output_dir.exists():\n",
|
|
||||||
" output_dir.mkdir(parents=True, exist_ok=True)\n",
|
|
||||||
" try:\n",
|
|
||||||
" rir_dataset = load_dataset(dataset_name, split=split, streaming=True)\n",
|
|
||||||
" print(f\"Downloading {dataset_name} to {output_dir}...\")\n",
|
|
||||||
" for row in tqdm(rir_dataset):\n",
|
|
||||||
" name = Path(row['audio']['path']).name\n",
|
|
||||||
" # Save the original audio file\n",
|
|
||||||
" with open(output_dir / name, \"wb\") as audio_file:\n",
|
|
||||||
" audio_file.write(row[\"audio\"][\"bytes\"])\n",
|
|
||||||
" print(f\"Finished downloading {dataset_name} to {output_dir}.\\n\")\n",
|
|
||||||
" except Exception as e:\n",
|
|
||||||
" print(f\"Error downloading {dataset_name}: {e}\")\n",
|
|
||||||
" else:\n",
|
|
||||||
" print(f\"{output_dir} already exists. Skipping download.\\n\")\n",
|
|
||||||
" \n",
|
|
||||||
"# Download MIT RIRs\n",
|
|
||||||
"download_rir_dataset(\n",
|
|
||||||
" \"davidscripka/MIT_environmental_impulse_responses\",\n",
|
|
||||||
" \"./mit_rirs\"\n",
|
|
||||||
")\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"# Function to download files\n",
|
"output_dir = \"./mit_rirs\"\n",
|
||||||
"def download_file(url, output_path):\n",
|
"if not os.path.exists(output_dir):\n",
|
||||||
" response = requests.get(url, stream=True)\n",
|
" os.mkdir(output_dir)\n",
|
||||||
" total_size = int(response.headers.get('content-length', 0))\n",
|
" rir_dataset = datasets.load_dataset(\"davidscripka/MIT_environmental_impulse_responses\", split=\"train\", streaming=True)\n",
|
||||||
" with open(output_path, \"wb\") as f, tqdm(\n",
|
" # Save clips to 16-bit PCM wav files\n",
|
||||||
" desc=f\"Downloading {output_path.name}\",\n",
|
" for row in tqdm(rir_dataset):\n",
|
||||||
" total=total_size,\n",
|
" name = row['audio']['path'].split('/')[-1]\n",
|
||||||
" unit=\"B\",\n",
|
" scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array'] * 32767).astype(np.int16))\n",
|
||||||
" unit_scale=True,\n",
|
|
||||||
" unit_divisor=1024,\n",
|
|
||||||
" ) as bar:\n",
|
|
||||||
" for chunk in response.iter_content(chunk_size=1024):\n",
|
|
||||||
" f.write(chunk)\n",
|
|
||||||
" bar.update(len(chunk))\n",
|
|
||||||
" print(f\"Downloaded {output_path}\")\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"# Function to extract .tar files\n",
|
"## Download noise and background audio\n",
|
||||||
"def extract_tar(file_path, extract_dir):\n",
|
|
||||||
" with tarfile.open(file_path, \"r\") as tar:\n",
|
|
||||||
" tar.extractall(path=extract_dir)\n",
|
|
||||||
" print(f\"Extracted {file_path} to {extract_dir}\")\n",
|
|
||||||
"\n",
|
|
||||||
"# Function to download and extract ZIP files\n",
|
|
||||||
"def download_and_extract_zip(url, extract_to):\n",
|
|
||||||
" file_name = url.split(\"/\")[-1]\n",
|
|
||||||
" local_path = Path(extract_to) / file_name\n",
|
|
||||||
" download_file(url, local_path)\n",
|
|
||||||
" with zipfile.ZipFile(local_path, 'r') as zip_ref:\n",
|
|
||||||
" zip_ref.extractall(extract_to)\n",
|
|
||||||
" print(f\"Extracted {file_name} to {extract_to}\")\n",
|
|
||||||
"\n",
|
|
||||||
"# Directories\n",
|
|
||||||
"raw_dir = Path(\"./audioset_raw\")\n",
|
|
||||||
"processed_dir = Path(\"./audioset_16k\")\n",
|
|
||||||
"raw_dir.mkdir(exist_ok=True)\n",
|
|
||||||
"processed_dir.mkdir(exist_ok=True)\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"# Full-scale dataset download links\n",
|
"# Full-scale dataset download links\n",
|
||||||
"dataset_links = [\n",
|
"dataset_links = [\n",
|
||||||
@@ -224,120 +173,112 @@
|
|||||||
" for i in range(10) # Adjust for additional parts\n",
|
" for i in range(10) # Adjust for additional parts\n",
|
||||||
"]\n",
|
"]\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Step 1: Download all parts of the dataset\n",
|
"if not os.path.exists(\"audioset\"):\n",
|
||||||
"print(\"Downloading datasets...\")\n",
|
" os.mkdir(\"audioset\")\n",
|
||||||
"for link in dataset_links:\n",
|
|
||||||
" file_name = link.split(\"/\")[-1]\n",
|
|
||||||
" output_path = raw_dir / file_name\n",
|
|
||||||
" if not output_path.exists():\n",
|
|
||||||
" download_file(link, output_path)\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"# Step 2: Extract all .tar files\n",
|
" for link in dataset_links:\n",
|
||||||
"print(\"Extracting datasets...\")\n",
|
" fname = link.split(\"/\")[-1]\n",
|
||||||
"for file_path in raw_dir.glob(\"*.tar\"):\n",
|
" out_dir = f\"audioset/{fname}\"\n",
|
||||||
" extract_dir = raw_dir / file_path.stem\n",
|
" if not os.path.exists(out_dir):\n",
|
||||||
" if not extract_dir.exists():\n",
|
" print(f\"Downloading {fname}...\")\n",
|
||||||
" extract_tar(file_path, extract_dir)\n",
|
" response = requests.get(link, stream=True)\n",
|
||||||
|
" with open(out_dir, \"wb\") as f:\n",
|
||||||
|
" for chunk in response.iter_content(chunk_size=1024):\n",
|
||||||
|
" if chunk:\n",
|
||||||
|
" f.write(chunk)\n",
|
||||||
|
" print(f\"Downloaded {fname}\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Step 3: Convert audio files to 16kHz WAV\n",
|
" # Extract the tar file\n",
|
||||||
"audio_files = list(Path(raw_dir).glob(\"**/*.flac\"))\n",
|
" print(f\"Extracting {fname}...\")\n",
|
||||||
"print(f\"Number of FLAC files found: {len(audio_files)}\")\n",
|
" os.system(f\"cd audioset && tar -xf {fname}\")\n",
|
||||||
"if not audio_files:\n",
|
|
||||||
" raise FileNotFoundError(\"No .flac files found in the raw directories. Check your dataset extraction.\")\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"print(\"Converting audio files to 16kHz WAV...\")\n",
|
" output_dir = \"./audioset_16k\"\n",
|
||||||
"corrupted_files = []\n",
|
" if not os.path.exists(output_dir):\n",
|
||||||
"resampled_files = []\n",
|
" os.mkdir(output_dir)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"for file_path in tqdm(audio_files, desc=\"Processing audio files\"):\n",
|
" # Save clips to 16-bit PCM wav files\n",
|
||||||
" try:\n",
|
" audio_files = list(Path(\"audioset/audio\").glob(\"**/*.flac\"))\n",
|
||||||
" # Read the .flac file\n",
|
" print(f\"Number of FLAC files found: {len(audio_files)}\")\n",
|
||||||
" data, samplerate = sf.read(file_path)\n",
|
" if not audio_files:\n",
|
||||||
|
" raise FileNotFoundError(\"No .flac files found in the audioset directory. Check your dataset extraction.\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Check and resample if needed\n",
|
" print(\"Converting audioset files to 16kHz WAV...\")\n",
|
||||||
" if samplerate != 16000:\n",
|
" for file_path in tqdm(audio_files):\n",
|
||||||
" resampled_files.append(str(file_path))\n",
|
" try:\n",
|
||||||
" data = np.interp(\n",
|
" # Read and convert the .flac file\n",
|
||||||
" np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n",
|
" data, samplerate = scipy.io.wavfile.read(file_path)\n",
|
||||||
" np.arange(len(data)),\n",
|
" if samplerate != 16000:\n",
|
||||||
" data,\n",
|
" data = np.interp(\n",
|
||||||
|
" np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n",
|
||||||
|
" np.arange(len(data)),\n",
|
||||||
|
" data,\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
" # Save as WAV\n",
|
||||||
|
" output_path = Path(output_dir) / file_path.name.replace(\".flac\", \".wav\")\n",
|
||||||
|
" scipy.io.wavfile.write(\n",
|
||||||
|
" output_path,\n",
|
||||||
|
" 16000,\n",
|
||||||
|
" (data * 32767).astype(np.int16),\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
|
" except Exception as e:\n",
|
||||||
|
" print(f\"Error converting {file_path}: {e}\")\n",
|
||||||
|
" continue # Skip and proceed with the next file\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Convert and save as WAV\n",
|
"# Free Music Archive dataset\n",
|
||||||
" output_path = processed_dir / file_path.name.replace(\".flac\", \".wav\")\n",
|
"# https://github.com/mdeff/fma\n",
|
||||||
" scipy.io.wavfile.write(\n",
|
"# (Third-party mchl914 extra small set)\n",
|
||||||
" output_path,\n",
|
|
||||||
" 16000,\n",
|
|
||||||
" (data * 32767).astype(np.int16),\n",
|
|
||||||
" )\n",
|
|
||||||
" except Exception as e:\n",
|
|
||||||
" corrupted_files.append(str(file_path))\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"# Log corrupted files\n",
|
"output_dir = \"./fma\"\n",
|
||||||
"if corrupted_files:\n",
|
"if not os.path.exists(output_dir):\n",
|
||||||
" with open(\"corrupted_files.log\", \"w\") as log_file:\n",
|
" os.mkdir(output_dir)\n",
|
||||||
" log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n",
|
" fname = \"fma_xs.zip\"\n",
|
||||||
"\n",
|
" link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/\" + fname\n",
|
||||||
"# Log resampled files\n",
|
" out_dir = f\"fma/{fname}\"\n",
|
||||||
"if resampled_files:\n",
|
|
||||||
" with open(\"resampled_files.log\", \"w\") as log_file:\n",
|
|
||||||
" log_file.writelines(f\"{file}\\n\" for file in resampled_files)\n",
|
|
||||||
"\n",
|
|
||||||
"print(f\"Audio conversion complete! {len(corrupted_files)} files corrupted and logged.\")\n",
|
|
||||||
"print(f\"{len(resampled_files)} files resampled and logged.\")\n",
|
|
||||||
"\n",
|
|
||||||
"# Process fma_xs dataset\n",
|
|
||||||
"fma_raw_dir = Path(\"./fma\")\n",
|
|
||||||
"fma_processed_dir = Path(\"./fma_16k\") # Separate directory for fma_xs processed files\n",
|
|
||||||
"fma_raw_dir.mkdir(exist_ok=True)\n",
|
|
||||||
"fma_processed_dir.mkdir(exist_ok=True)\n",
|
|
||||||
"\n",
|
|
||||||
"fma_link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/fma_xs.zip\"\n",
|
|
||||||
"fma_zip_path = fma_raw_dir / \"fma_xs.zip\"\n",
|
|
||||||
"\n",
|
|
||||||
"if not fma_zip_path.exists():\n",
|
|
||||||
" print(\"Downloading fma_xs dataset...\")\n",
|
" print(\"Downloading fma_xs dataset...\")\n",
|
||||||
" download_file(fma_link, fma_zip_path)\n",
|
" response = requests.get(link, stream=True)\n",
|
||||||
|
" with open(out_dir, \"wb\") as f:\n",
|
||||||
|
" for chunk in response.iter_content(chunk_size=1024):\n",
|
||||||
|
" if chunk:\n",
|
||||||
|
" f.write(chunk)\n",
|
||||||
|
" print(\"Downloaded fma_xs dataset.\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print(\"Extracting fma_xs dataset...\")\n",
|
" os.system(f\"cd fma && unzip -q {fname}\")\n",
|
||||||
"download_and_extract_zip(fma_link, fma_raw_dir)\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"fma_audio_files = list(fma_raw_dir.glob(\"**/*.mp3\"))\n",
|
" output_dir = \"./fma_16k\"\n",
|
||||||
"print(f\"Number of MP3 files found: {len(fma_audio_files)}\")\n",
|
" if not os.path.exists(output_dir):\n",
|
||||||
"if not fma_audio_files:\n",
|
" os.mkdir(output_dir)\n",
|
||||||
" raise FileNotFoundError(\"No .mp3 files found in the fma directory. Check your dataset extraction.\")\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"print(\"Converting fma_xs files to 16kHz WAV...\")\n",
|
" # Save clips to 16-bit PCM wav files\n",
|
||||||
"for file_path in tqdm(fma_audio_files, desc=\"Processing fma_xs files\"):\n",
|
" audio_files = list(Path(\"fma/fma_small\").glob(\"**/*.mp3\"))\n",
|
||||||
" try:\n",
|
" print(f\"Number of MP3 files found: {len(audio_files)}\")\n",
|
||||||
" # Read the .mp3 file\n",
|
" if not audio_files:\n",
|
||||||
" data, samplerate = sf.read(file_path)\n",
|
" raise FileNotFoundError(\"No .mp3 files found in the fma directory. Check your dataset extraction.\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Check and resample if needed\n",
|
" print(\"Converting fma_xs files to 16kHz WAV...\")\n",
|
||||||
" if samplerate != 16000:\n",
|
" for file_path in tqdm(audio_files):\n",
|
||||||
" data = np.interp(\n",
|
" try:\n",
|
||||||
" np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n",
|
" # Read and convert the .mp3 file\n",
|
||||||
" np.arange(len(data)),\n",
|
" data, samplerate = scipy.io.wavfile.read(file_path)\n",
|
||||||
" data,\n",
|
" if samplerate != 16000:\n",
|
||||||
|
" data = np.interp(\n",
|
||||||
|
" np.linspace(0, len(data), int(len(data) * 16000 / samplerate)),\n",
|
||||||
|
" np.arange(len(data)),\n",
|
||||||
|
" data,\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
" # Save as WAV\n",
|
||||||
|
" output_path = Path(output_dir) / file_path.name.replace(\".mp3\", \".wav\")\n",
|
||||||
|
" scipy.io.wavfile.write(\n",
|
||||||
|
" output_path,\n",
|
||||||
|
" 16000,\n",
|
||||||
|
" (data * 32767).astype(np.int16),\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
|
" except Exception as e:\n",
|
||||||
|
" print(f\"Error converting {file_path}: {e}\")\n",
|
||||||
|
" continue # Skip and proceed with the next file\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Convert and save as WAV\n",
|
"print(\"All datasets processed successfully!\")"
|
||||||
" output_path = fma_processed_dir / file_path.name.replace(\".mp3\", \".wav\")\n",
|
|
||||||
" scipy.io.wavfile.write(\n",
|
|
||||||
" output_path,\n",
|
|
||||||
" 16000,\n",
|
|
||||||
" (data * 32767).astype(np.int16),\n",
|
|
||||||
" )\n",
|
|
||||||
" except Exception as e:\n",
|
|
||||||
" corrupted_files.append(str(file_path))\n",
|
|
||||||
"\n",
|
|
||||||
"# Log corrupted files from fma_xs\n",
|
|
||||||
"if corrupted_files:\n",
|
|
||||||
" with open(\"fma_corrupted_files.log\", \"w\") as log_file:\n",
|
|
||||||
" log_file.writelines(f\"{file}\\n\" for file in corrupted_files)\n",
|
|
||||||
"\n",
|
|
||||||
"print(\"fma_xs processing complete!\")\n",
|
|
||||||
"print(\"Full-scale dataset preparation complete!\")"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -712,7 +653,7 @@
|
|||||||
"# Define the JSON metadata for the model\n",
|
"# Define the JSON metadata for the model\n",
|
||||||
"json_data = {\n",
|
"json_data = {\n",
|
||||||
" \"type\": \"micro\",\n",
|
" \"type\": \"micro\",\n",
|
||||||
" \"wake_word\": \"hey_norman\", # Adjust based on your target wake word\n",
|
" \"wake_word\": \"khum_puter\", # Adjust based on your target wake word\n",
|
||||||
" \"author\": \"master phooey\",\n",
|
" \"author\": \"master phooey\",\n",
|
||||||
" \"website\": \"https://github.com/MasterPhooey/MicroWakeWord-Trainer-Docker\",\n",
|
" \"website\": \"https://github.com/MasterPhooey/MicroWakeWord-Trainer-Docker\",\n",
|
||||||
" \"model\": \"stream_state_internal_quant.tflite\",\n",
|
" \"model\": \"stream_state_internal_quant.tflite\",\n",
|
||||||
|
|||||||
Reference in New Issue
Block a user