diff --git a/basic_training_notebook.ipynb b/basic_training_notebook.ipynb
index baa88fb..5032690 100644
--- a/basic_training_notebook.ipynb
+++ b/basic_training_notebook.ipynb
@@ -501,36 +501,42 @@
         "id": "WoEXJBaiC9mf"
       },
       "outputs": [],
-      "source": [
-        "# Trains a model. When finished, it will quantize and convert the model to a\n",
-        "# streaming version suitable for on-device detection.\n",
-        "# It will resume if stopped, but it will start over at the configured training\n",
-        "# steps in the yaml file.\n",
-        "# Change --train 0 to only convert and test the best-weighted model.\n",
-        "# On Google colab, it doesn't print the mini-batch results, so it may appear\n",
-        "# stuck for several minutes! Additionally, it is very slow compared to training\n",
-        "# on a local GPU.\n",
-        "\n",
-        "!python -m microwakeword.model_train_eval \\\n",
-        "--training_config='training_parameters.yaml' \\\n",
-        "--train 1 \\\n",
-        "--restore_checkpoint 1 \\\n",
-        "--test_tf_nonstreaming 0 \\\n",
-        "--test_tflite_nonstreaming 0 \\\n",
-        "--test_tflite_nonstreaming_quantized 0 \\\n",
-        "--test_tflite_streaming 0 \\\n",
-        "--test_tflite_streaming_quantized 1 \\\n",
-        "--use_weights \"best_weights\" \\\n",
-        "mixednet \\\n",
-        "--pointwise_filters \"64,64,64,64\" \\\n",
-        "--repeat_in_block \"1, 1, 1, 1\" \\\n",
-        "--mixconv_kernel_sizes '[5], [7,11], [9,15], [23]' \\\n",
-        "--residual_connection \"0,0,0,0\" \\\n",
-        "--first_conv_filters 32 \\\n",
-        "--first_conv_kernel_size 5 \\\n",
-        "--stride 3"
-      ]
-    },
+      "source": [
+        "# Trains a model. When finished, it will quantize and convert the model to a\n",
+        "# streaming version suitable for on-device detection.\n",
+        "# It will resume if stopped, but it will start over at the configured training\n",
+        "# steps in the yaml file.\n",
+        "# Change --train 0 to only convert and test the best-weighted model.\n",
+        "# On Google colab, it doesn't print the mini-batch results, so it may appear\n",
+        "# stuck for several minutes! Additionally, it is very slow compared to training\n",
+        "# on a local GPU.\n",
+        "import os\n",
+        "import sys\n",
+        "\n",
+        "# Prepend the system library dir; the subprocess launched below inherits it.\n",
+        "os.environ['LD_LIBRARY_PATH'] = \"/usr/lib/x86_64-linux-gnu:\" + os.environ.get('LD_LIBRARY_PATH', '')\n",
+        "\n",
+        "# Run with the kernel's own interpreter so the same environment is used.\n",
+        "!\"{sys.executable}\" -m microwakeword.model_train_eval \\\n",
+        "--training_config='training_parameters.yaml' \\\n",
+        "--train 1 \\\n",
+        "--restore_checkpoint 1 \\\n",
+        "--test_tf_nonstreaming 0 \\\n",
+        "--test_tflite_nonstreaming 0 \\\n",
+        "--test_tflite_nonstreaming_quantized 0 \\\n",
+        "--test_tflite_streaming 0 \\\n",
+        "--test_tflite_streaming_quantized 1 \\\n",
+        "--use_weights \"best_weights\" \\\n",
+        "mixednet \\\n",
+        "--pointwise_filters \"64,64,64,64\" \\\n",
+        "--repeat_in_block \"1, 1, 1, 1\" \\\n",
+        "--mixconv_kernel_sizes '[5], [7,11], [9,15], [23]' \\\n",
+        "--residual_connection \"0,0,0,0\" \\\n",
+        "--first_conv_filters 32 \\\n",
+        "--first_conv_kernel_size 5 \\\n",
+        "--stride 3"
+      ]
+    },
     {
       "cell_type": "code",
       "execution_count": null,
@@ -539,20 +545,29 @@
       },
       "outputs": [],
       "source": [
         "# Downloads the tflite model file. To use on the device, you need to write a\n",
         "# Model JSON file. See https://esphome.io/components/micro_wake_word for the\n",
         "# documentation and\n",
         "# https://github.com/esphome/micro-wake-word-models/tree/main/models/v2 for\n",
         "# examples. Adjust the probability threshold based on the test results obtained\n",
         "# after training is finished. You may also need to increase the Tensor arena\n",
         "# model size if the model fails to load.\n",
+        "import shutil\n",
+        "from IPython.display import FileLink\n",
         "\n",
-        "from google.colab import files\n",
+        "# Define the source path and desired download location\n",
+        "source_path = \"trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite\"\n",
+        "destination_path = \"./stream_state_internal_quant.tflite\"\n",
         "\n",
-        "files.download(f\"trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite\")"
+        "# Copy the file to the current working directory\n",
+        "shutil.copy(source_path, destination_path)\n",
+        "\n",
+        "# Generate a link to download the file\n",
+        "print(\"Download your file:\")\n",
+        "FileLink(destination_path)"
       ]
     }
   ],
   "metadata": {
     "accelerator": "GPU",
     "colab": {