Update basic_training_notebook.ipynb

This commit is contained in:
MasterPhooey
2025-01-03 08:07:24 -06:00
committed by GitHub
parent 8a10e4f8f7
commit 2e80414de7

View File

@@ -501,36 +501,38 @@
"id": "WoEXJBaiC9mf" "id": "WoEXJBaiC9mf"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"# Trains a model. When finished, it will quantize and convert the model to a\n", "# Trains a model. When finished, it will quantize and convert the model to a\n",
"# streaming version suitable for on-device detection.\n", "# streaming version suitable for on-device detection.\n",
"# It will resume if stopped, but it will start over at the configured training\n", "# It will resume if stopped, but it will start over at the configured training\n",
"# steps in the yaml file.\n", "# steps in the yaml file.\n",
"# Change --train 0 to only convert and test the best-weighted model.\n", "# Change --train 0 to only convert and test the best-weighted model.\n",
"# On Google colab, it doesn't print the mini-batch results, so it may appear\n", "# On Google colab, it doesn't print the mini-batch results, so it may appear\n",
"# stuck for several minutes! Additionally, it is very slow compared to training\n", "# stuck for several minutes! Additionally, it is very slow compared to training\n",
"# on a local GPU.\n", "# on a local GPU.\n",
"\n", "import os\n",
"!python -m microwakeword.model_train_eval \\\n", "os.environ['LD_LIBRARY_PATH'] = \"/usr/lib/x86_64-linux-gnu:\" + os.environ.get('LD_LIBRARY_PATH', '')\n",
"--training_config='training_parameters.yaml' \\\n", "\n",
"--train 1 \\\n", "!\"{sys.executable}\" -m microwakeword.model_train_eval \\\n",
"--restore_checkpoint 1 \\\n", "--training_config='training_parameters.yaml' \\\n",
"--test_tf_nonstreaming 0 \\\n", "--train 1 \\\n",
"--test_tflite_nonstreaming 0 \\\n", "--restore_checkpoint 1 \\\n",
"--test_tflite_nonstreaming_quantized 0 \\\n", "--test_tf_nonstreaming 0 \\\n",
"--test_tflite_streaming 0 \\\n", "--test_tflite_nonstreaming 0 \\\n",
"--test_tflite_streaming_quantized 1 \\\n", "--test_tflite_nonstreaming_quantized 0 \\\n",
"--use_weights \"best_weights\" \\\n", "--test_tflite_streaming 0 \\\n",
"mixednet \\\n", "--test_tflite_streaming_quantized 1 \\\n",
"--pointwise_filters \"64,64,64,64\" \\\n", "--use_weights \"best_weights\" \\\n",
"--repeat_in_block \"1, 1, 1, 1\" \\\n", "mixednet \\\n",
"--mixconv_kernel_sizes '[5], [7,11], [9,15], [23]' \\\n", "--pointwise_filters \"64,64,64,64\" \\\n",
"--residual_connection \"0,0,0,0\" \\\n", "--repeat_in_block \"1, 1, 1, 1\" \\\n",
"--first_conv_filters 32 \\\n", "--mixconv_kernel_sizes '[5], [7,11], [9,15], [23]' \\\n",
"--first_conv_kernel_size 5 \\\n", "--residual_connection \"0,0,0,0\" \\\n",
"--stride 3" "--first_conv_filters 32 \\\n",
] "--first_conv_kernel_size 5 \\\n",
}, "--stride 3"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
@@ -539,20 +541,21 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"# Downloads the tflite model file. To use on the device, you need to write a\n", "import shutil\n",
"# Model JSON file. See https://esphome.io/components/micro_wake_word for the\n", "from IPython.display import FileLink\n",
"# documentation and\n",
"# https://github.com/esphome/micro-wake-word-models/tree/main/models/v2 for\n",
"# examples. Adjust the probability threshold based on the test results obtained\n",
"# after training is finished. You may also need to increase the Tensor arena\n",
"# model size if the model fails to load.\n",
"\n", "\n",
"from google.colab import files\n", "# Define the source path and desired download location\n",
"source_path = \"trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite\"\n",
"destination_path = \"./stream_state_internal_quant.tflite\"\n",
"\n", "\n",
"files.download(f\"trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite\")" "# Copy the file to the current working directory\n",
"shutil.copy(source_path, destination_path)\n",
"\n",
"# Generate a link to download the file\n",
"print(\"Download your file:\")\n",
"FileLink(destination_path)"
] ]
} },
],
"metadata": { "metadata": {
"accelerator": "GPU", "accelerator": "GPU",
"colab": { "colab": {