mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
Update advanced_training_notebook.ipynb
This commit is contained in:
@@ -10,7 +10,6 @@
|
|||||||
" <img src=\"https://raw.githubusercontent.com/MasterPhooey/MicroWakeWord-Trainer-Docker/refs/heads/main/mmw.png\" alt=\"MicroWakeWord Trainer Logo\" width=\"100\" />\n",
|
" <img src=\"https://raw.githubusercontent.com/MasterPhooey/MicroWakeWord-Trainer-Docker/refs/heads/main/mmw.png\" alt=\"MicroWakeWord Trainer Logo\" width=\"100\" />\n",
|
||||||
" <h1>MicroWakeWord Trainer Docker</h1>\n",
|
" <h1>MicroWakeWord Trainer Docker</h1>\n",
|
||||||
"</div>\n",
|
"</div>\n",
|
||||||
"**This notebook requires 24G of Vram.**\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"This notebook steps you through training a robust microWakeWord model. It is intended as a **starting point** for users looking to create a high-performance wake word detection model. This notebook is optimized for Python 3.10.\n",
|
"This notebook steps you through training a robust microWakeWord model. It is intended as a **starting point** for users looking to create a high-performance wake word detection model. This notebook is optimized for Python 3.10.\n",
|
||||||
"\n",
|
"\n",
|
||||||
@@ -641,7 +640,7 @@
|
|||||||
"config[\"clip_duration_ms\"] = 2000 # Increased\n",
|
"config[\"clip_duration_ms\"] = 2000 # Increased\n",
|
||||||
"\n",
|
"\n",
|
||||||
"config[\"target_minimization\"] = 0.9\n",
|
"config[\"target_minimization\"] = 0.9\n",
|
||||||
"config[\"minimization_metric\"] = None # Updated\n",
|
"config[\"minimization_metric\"] = \"false_positive_rate\" # Updated\n",
|
||||||
"config[\"maximization_metric\"] = \"average_viable_recall\"\n",
|
"config[\"maximization_metric\"] = \"average_viable_recall\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"with open(os.path.join(\"training_parameters.yaml\"), \"w\") as file:\n",
|
"with open(os.path.join(\"training_parameters.yaml\"), \"w\") as file:\n",
|
||||||
@@ -676,19 +675,19 @@
|
|||||||
"--training_config='training_parameters.yaml' \\\n",
|
"--training_config='training_parameters.yaml' \\\n",
|
||||||
"--train 1 \\\n",
|
"--train 1 \\\n",
|
||||||
"--restore_checkpoint 1 \\\n",
|
"--restore_checkpoint 1 \\\n",
|
||||||
"--test_tf_nonstreaming 1 \\\n",
|
"--test_tf_nonstreaming 0 \\\n",
|
||||||
"--test_tflite_nonstreaming 1 \\\n",
|
"--test_tflite_nonstreaming 0 \\\n",
|
||||||
"--test_tflite_nonstreaming_quantized 1 \\\n",
|
"--test_tflite_nonstreaming_quantized 0 \\\n",
|
||||||
"--test_tflite_streaming 1 \\\n",
|
"--test_tflite_streaming 0 \\\n",
|
||||||
"--test_tflite_streaming_quantized 1 \\\n",
|
"--test_tflite_streaming_quantized 1 \\\n",
|
||||||
"--use_weights \"best_weights\" \\\n",
|
"--use_weights \"best_weights\" \\\n",
|
||||||
"mixednet \\\n",
|
"mixednet \\\n",
|
||||||
"--pointwise_filters \"64,96,128,160\" \\\n",
|
"--pointwise_filters \"64,64,64,64\" \\\n",
|
||||||
"--repeat_in_block \"2,2,3,3\" \\\n",
|
"--repeat_in_block \"1,1,1,1\" \\\n",
|
||||||
"--mixconv_kernel_sizes '[5], [7,11], [9,15], [17,23]' \\\n",
|
"--mixconv_kernel_sizes '[5], [7,11], [9,15], [23]' \\\n",
|
||||||
"--residual_connection \"1,1,1,0\" \\\n",
|
"--residual_connection \"0,0,0,0\" \\\n",
|
||||||
"--first_conv_filters 48 \\\n",
|
"--first_conv_filters 32 \\\n",
|
||||||
"--first_conv_kernel_size 7 \\\n",
|
"--first_conv_kernel_size 5 \\\n",
|
||||||
"--stride 2\n"
|
"--stride 2\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@@ -702,7 +701,9 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"import shutil\n",
|
"import shutil\n",
|
||||||
"import json\n",
|
"import json\n",
|
||||||
"from IPython.display import FileLink, HTML\n",
|
"from http.server import HTTPServer, SimpleHTTPRequestHandler\n",
|
||||||
|
"import threading\n",
|
||||||
|
"from pathlib import Path\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Copy the TFLite model file to the working directory\n",
|
"# Copy the TFLite model file to the working directory\n",
|
||||||
"source_path = \"trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite\"\n",
|
"source_path = \"trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite\"\n",
|
||||||
@@ -712,7 +713,7 @@
|
|||||||
"# Define the JSON metadata for the model\n",
|
"# Define the JSON metadata for the model\n",
|
||||||
"json_data = {\n",
|
"json_data = {\n",
|
||||||
" \"type\": \"micro\",\n",
|
" \"type\": \"micro\",\n",
|
||||||
" \"wake_word\": \"hey_norman\", # Adjust based on your target wake word\n",
|
" \"wake_word\": \"khum_puter\", # Adjust based on your target wake word\n",
|
||||||
" \"author\": \"master phooey\",\n",
|
" \"author\": \"master phooey\",\n",
|
||||||
" \"website\": \"https://github.com/MasterPhooey/MicroWakeWord-Trainer-Docker\",\n",
|
" \"website\": \"https://github.com/MasterPhooey/MicroWakeWord-Trainer-Docker\",\n",
|
||||||
" \"model\": \"stream_state_internal_quant.tflite\",\n",
|
" \"model\": \"stream_state_internal_quant.tflite\",\n",
|
||||||
@@ -732,19 +733,37 @@
|
|||||||
"with open(json_path, \"w\") as json_file:\n",
|
"with open(json_path, \"w\") as json_file:\n",
|
||||||
" json.dump(json_data, json_file, indent=2)\n",
|
" json.dump(json_data, json_file, indent=2)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Generate download links with styled HTML\n",
|
"# Custom HTTPRequestHandler to set the Content-Disposition header\n",
|
||||||
"tflite_link = FileLink(destination_path)\n",
|
"class CustomHTTPRequestHandler(SimpleHTTPRequestHandler):\n",
|
||||||
"json_link = FileLink(json_path)\n",
|
" def end_headers(self):\n",
|
||||||
|
" if self.path.endswith(\".json\"):\n",
|
||||||
|
" self.send_header(\"Content-Disposition\", \"attachment\")\n",
|
||||||
|
" super().end_headers()\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"# Start an HTTP server in a separate thread\n",
|
||||||
|
"def start_server():\n",
|
||||||
|
" server = HTTPServer((\"0.0.0.0\", 8000), CustomHTTPRequestHandler)\n",
|
||||||
|
" print(\"Serving files on http://localhost:8000\")\n",
|
||||||
|
" server.serve_forever()\n",
|
||||||
|
"\n",
|
||||||
|
"thread = threading.Thread(target=start_server, daemon=True)\n",
|
||||||
|
"thread.start()\n",
|
||||||
|
"\n",
|
||||||
|
"# Generate download links with styled HTML\n",
|
||||||
"html_content = f\"\"\"\n",
|
"html_content = f\"\"\"\n",
|
||||||
|
"<div align=\\\"center\\\">,\n",
|
||||||
|
"<img src=\\\"https://raw.githubusercontent.com/MasterPhooey/MicroWakeWord-Trainer-Docker/refs/heads/main/mmw.png\\\" alt=\\\"MicroWakeWord Trainer Logo\\\" width=\\\"100\\\" />\\n\"\n",
|
||||||
|
"<h1>MicroWakeWord Trainer Docker</h1>\n",
|
||||||
|
"</div>\\n\n",
|
||||||
"<h3 style=\"color:orange;\">Your files are ready for download:</h3>\n",
|
"<h3 style=\"color:orange;\">Your files are ready for download:</h3>\n",
|
||||||
"<ul>\n",
|
"<ul>\n",
|
||||||
" <li><b><a href=\"{tflite_link.url}\" target=\"_blank\" style=\"color:orange;\">TFLite Model: stream_state_internal_quant.tflite</a></b></li>\n",
|
" <li><b><a href=\"http://localhost:8000/stream_state_internal_quant.tflite\" target=\"_blank\" style=\"color:orange;\">TFLite Model: stream_state_internal_quant.tflite</a></b></li>\n",
|
||||||
" <li><b><a href=\"{json_link.url}\" target=\"_blank\" style=\"color:orange;\">JSON Metadata: stream_state_internal_quant.json</a></b></li>\n",
|
" <li><b><a href=\"http://localhost:8000/stream_state_internal_quant.json\" target=\"_blank\" style=\"color:orange;\">JSON Metadata: stream_state_internal_quant.json</a></b></li>\n",
|
||||||
"</ul>\n",
|
"</ul>\n",
|
||||||
"<p style=\"font-size:12px; color:gray;\">Click the links to download the files. Ensure the files are moved to the correct directory for deployment.</p>\n",
|
"<p style=\"font-size:12px; color:gray;\">Click the links to download the files. Ensure the files are moved to the correct directory for deployment.</p>\n",
|
||||||
"\"\"\"\n",
|
"\"\"\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"from IPython.display import HTML\n",
|
||||||
"display(HTML(html_content))"
|
"display(HTML(html_content))"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user