mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-13 04:20:19 -06:00
Merge pull request #21 from TaterTotterson/recorder-ui
Recorder UI update
This commit is contained in:
135
.bashrc
Normal file
135
.bashrc
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
# ~/.bashrc -- read by bash(1) each time a non-login shell starts.
# Example startup files ship with the bash-doc package under
# /usr/share/doc/bash/examples/startup-files.

# Everything below only matters for interactive use; an empty PS1 is taken
# to mean the shell is not interactive, so stop here.
if [ -z "$PS1" ]; then
    return
fi

# History handling (see bash(1) for the full list of options):
#   - append to the history file instead of overwriting it
#   - skip duplicate lines and lines that start with a space
#   - HISTSIZE / HISTFILESIZE bound the in-memory and on-disk history length
shopt -s histappend
HISTCONTROL=ignoredups:ignorespace
HISTSIZE=1000
HISTFILESIZE=2000

# Re-check the window size after every command and refresh LINES and COLUMNS
# when it changed.
shopt -s checkwinsize

# Let less(1) cope with non-text input files; see lesspipe(1).
if [ -x /usr/bin/lesspipe ]; then
    eval "$(SHELL=/bin/sh lesspipe)"
fi

# Remember which chroot we are working in (shown in the prompt below).
if [ -z "$debian_chroot" ] && [ -r /etc/debian_chroot ]; then
    debian_chroot=$(< /etc/debian_chroot)
fi

# Prompt selection: plain by default, coloured only when we know we want it.
if [ "$TERM" = xterm-color ]; then
    color_prompt=yes
fi

# Uncomment to request a coloured prompt whenever the terminal is capable.
# It is off by default so that attention stays on command output rather than
# on the prompt.
#force_color_prompt=yes

if [ -n "$force_color_prompt" ]; then
    # A working "tput setaf" is taken as proof of Ecma-48 (ISO/IEC-6429)
    # colour support. Terminals lacking it are extremely rare and would tend
    # to offer setf rather than setaf.
    if [ -x /usr/bin/tput ] && tput setaf 1 > /dev/null 2>&1; then
        color_prompt=yes
    else
        color_prompt=
    fi
fi

case "$color_prompt" in
    yes)
        PS1='${debian_chroot:+($debian_chroot)}\[\033[01;32m\]\u@\h\[\033[00m\]:\[\033[01;34m\]\w\[\033[00m\]\$ '
        ;;
    *)
        PS1='${debian_chroot:+($debian_chroot)}\u@\h:\w\$ '
        ;;
esac
unset color_prompt force_color_prompt

# On xterm/rxvt style terminals also put user@host:dir into the window title.
case "$TERM" in
    xterm* | rxvt*)
        PS1="\[\e]0;${debian_chroot:+($debian_chroot)}\u@\h: \w\a\]$PS1"
        ;;
esac

# Colour support for ls and the grep family, when dircolors is installed.
if [ -x /usr/bin/dircolors ]; then
    # Prefer the user's own colour database, fall back to the built-in one.
    [ -r ~/.dircolors ] && eval "$(dircolors -b ~/.dircolors)" || eval "$(dircolors -b)"

    alias ls='ls --color=auto'
    #alias dir='dir --color=auto'
    #alias vdir='vdir --color=auto'

    alias grep='grep --color=auto'
    alias fgrep='fgrep --color=auto'
    alias egrep='egrep --color=auto'
fi

# A few more ls shortcuts.
alias ll='ls -alF' la='ls -A' l='ls -CF'

# Personal alias definitions belong in ~/.bash_aliases rather than in this
# file; it is loaded here when present.
# See /usr/share/doc/bash-doc/examples in the bash-doc package.
if [ -f ~/.bash_aliases ]; then
    source ~/.bash_aliases
fi

# Programmable completion. There is no need to enable this here when it is
# already enabled in /etc/bash.bashrc and /etc/profile sources
# /etc/bash.bashrc.
#if [ -f /etc/bash_completion ] && ! shopt -oq posix; then
#    . /etc/bash_completion
#fi
|
||||||
|
|
||||||
|
# Load optional shell customisations stored on the /data volume.
if [ -f /data/.bashrc ]; then
    . /data/.bashrc
fi

# Warn on stderr when /data is not a mount point: anything written there
# would then land in the container's own storage.
# The heredoc delimiter is quoted so the message is printed verbatim.
if ! mountpoint -q /data; then
    cat <<'EOF' >&2
=======================================================
WARNING: The /data directory is NOT mounted.
Running the training process without /data mounted
could add over 140Gb of python packages and training
files to this container's storage which is probably
NOT what you want.

You should remove this container and re-create it with
a 'docker run' option like '-v <host_work_dir>:/data'
making sure the host directory is on a device that has
enough free space.
=======================================================
EOF
fi

# Activate the Python virtual environment kept on /data when it exists;
# otherwise tell the user how to create it.
if [ -d /data/.venv ]; then
    . /data/.venv/bin/activate
else
    cat <<'EOF' >&2
=======================================================
WARNING: A python virtual environment wasn't found
at /data/.venv. You'll need to run 'setup_python_venv'
before you'll be able to use this container for
training.
=======================================================
EOF

fi

# "venv" (re)activates the virtual environment on demand. It is written as
# if/else rather than "test && source || echo" so that the "does not exist"
# message is printed only when the directory is really missing, not when the
# activation itself fails.
alias venv='if [ -d /data/.venv ]; then source /data/.venv/bin/activate; else echo "/data/.venv does not exist yet"; fi'
|
||||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,2 +1,2 @@
|
|||||||
personal_samples/*
|
personal_samples/*
|
||||||
|
.DS_Store
|
||||||
201
LICENSE
201
LICENSE
@@ -1,201 +0,0 @@
|
|||||||
Apache License
|
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
|
||||||
exercising permissions granted by this License.
|
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
|
||||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright [yyyy] [name of copyright owner]
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
160
README.md
160
README.md
@@ -1,20 +1,14 @@
|
|||||||
<div align="center">
|
# microWakeWord Nvidia Trainer & Recorder
|
||||||
<img src="https://raw.githubusercontent.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker/refs/heads/main/mmw.png" alt="MicroWakeWord Trainer Logo" width="100" />
|
|
||||||
<h1>microWakeWord Trainer Docker</h1>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
# 🥔 MicroWakeWord Trainer – Tater Approved
|
Train **microWakeWord** detection models using a simple **web-based recorder + trainer UI**, packaged in a Docker container.
|
||||||
|
|
||||||
**✅ Tater Totterson tested & working on an NVIDIA RTX 3070 Laptop GPU (8 GB VRAM).**
|
No Jupyter notebooks required. No manual cell execution. Just record your voice (optional) and train.
|
||||||
Easily train microWakeWord detection models with this pre-built Docker image and JupyterLab notebook.
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 🚀 Quick Start
|
## 🚀 Quick Start
|
||||||
|
|
||||||
Follow these steps to get up and running:
|
### 1️⃣ Pull the Docker Image
|
||||||
|
|
||||||
### 1️⃣ Pull the Pre-Built Docker Image
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker pull ghcr.io/tatertotterson/microwakeword:latest
|
docker pull ghcr.io/tatertotterson/microwakeword:latest
|
||||||
@@ -22,102 +16,118 @@ docker pull ghcr.io/tatertotterson/microwakeword:latest
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
### 2️⃣ Run the Docker Container
|
### 2️⃣ Run the Container
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker run --rm -it \
|
docker run --rm -it \
|
||||||
--gpus all \
|
--gpus all \
|
||||||
-p 8888:8888 \
|
-p 8888:8888 \
|
||||||
-v $(pwd):/data \
|
-v $(pwd):/data \
|
||||||
ghcr.io/tatertotterson/microwakeword:latest
|
ghcr.io/tatertotterson/microwakeword:latest
|
||||||
```
|
```
|
||||||
|
|
||||||
**What these flags do:**
|
**What these flags do:**
|
||||||
- `--gpus all` → Enables GPU acceleration
|
- `--gpus all` → Enables GPU acceleration
|
||||||
- `-p 8888:8888` → Exposes JupyterLab on port 8888
|
- `-p 8888:8888` → Exposes the Recorder + Trainer WebUI
|
||||||
- `-v $(pwd):/data` → Saves your work in the current folder
|
- `-v $(pwd):/data` → Persists all models, datasets, and cache
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
### 3️⃣ Open JupyterLab
|
### 3️⃣ Open the Recorder WebUI
|
||||||
|
|
||||||
Visit [http://localhost:8888](http://localhost:8888) in your browser — the notebook UI will open.
|
Open your browser and go to:
|
||||||
|
|
||||||
|
👉 **http://localhost:8888**
|
||||||
|
|
||||||
|
You’ll see the **microWakeWord Recorder & Trainer UI**.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
### 4️⃣ Set Your Wake Word
|
## 🎤 Recording Voice Samples (Optional)
|
||||||
|
|
||||||
At the **top of the notebook**, find this line:
|
Personal voice recordings are **optional**.
|
||||||
|
|
||||||
```bash
|
- You may **record your own voice** for better accuracy
|
||||||
TARGET_WORD = "hey_tater" # Change this to your desired wake word
|
- Or simply **click “Train” without recording anything**
|
||||||
|
|
||||||
|
If no recordings are present, training will proceed using **synthetic TTS samples only**.
|
||||||
|
|
||||||
|
### Remote systems (important)
|
||||||
|
If you are running this on a **remote PC / server**, browser-based recording will not work unless:
|
||||||
|
- You use a **reverse proxy** (HTTPS + mic permissions), **or**
|
||||||
|
- You access the UI via **localhost** on the same machine
|
||||||
|
|
||||||
|
Training itself works fine remotely — only recording requires local microphone access.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 🧠 Training Behavior (Important Notes)
|
||||||
|
|
||||||
|
### ⏬ First training run
|
||||||
|
The **first time you click Train**, the system will download **large training datasets** (background noise, speech corpora, etc.).
|
||||||
|
|
||||||
|
- This can take **several minutes**
|
||||||
|
- This happens **only once**
|
||||||
|
- Data is cached inside `/data`
|
||||||
|
|
||||||
|
You **will NOT need to download these again** unless you delete `/data`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 🔁 Re-training is safe and incremental
|
||||||
|
|
||||||
|
- You can train **multiple wake words** back-to-back
|
||||||
|
- You do **NOT** need to clear any folders between runs
|
||||||
|
- Old models are preserved in timestamped output directories
|
||||||
|
- All required cleanup and reuse logic is handled automatically
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 📦 Output Files
|
||||||
|
|
||||||
|
When training completes, you’ll get:
|
||||||
|
- `<wake_word>.tflite` – quantized streaming model
|
||||||
|
- `<wake_word>.json` – ESPHome-compatible metadata
|
||||||
|
|
||||||
|
Both are saved under:
|
||||||
|
|
||||||
|
```text
|
||||||
|
/data/output/
|
||||||
```
|
```
|
||||||
|
|
||||||
Change `"hey_tater"` to your desired wake word (phonetic spellings often work best).
|
Each run is placed in its own timestamped folder.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
### 5️⃣ Run the Notebook
|
## 🎤 Optional: Personal Voice Samples (Advanced)
|
||||||
|
|
||||||
Run all cells in the notebook. This process will:
|
If you record personal samples:
|
||||||
- Generate wake word samples
|
- They are automatically augmented
|
||||||
- Train a detection model
|
- They are **up-weighted during training**
|
||||||
- Output a quantized `.tflite` model ready for on-device use
|
- This significantly improves real-world accuracy
|
||||||
|
|
||||||
|
No configuration required — detection is automatic.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
### 6️⃣ Retrieve the Trained Model & JSON
|
## 🔄 Resetting Everything (Optional)
|
||||||
|
|
||||||
When training finishes, download links for both the `.tflite` model and its `.json` manifest will be displayed in the last cell.
|
If you want a **completely clean slate**:
|
||||||
|
|
||||||
---
|
Delete the /data folder
|
||||||
|
|
||||||
## 🔄 Resetting to a Clean State
|
Then restart the container.
|
||||||
|
|
||||||
If you need to start fresh:
|
⚠️ This will:
|
||||||
|
- Remove cached datasets
|
||||||
1. Delete the `data` folder that was mapped to your Docker container.
|
- Require re-downloading training data
|
||||||
2. Restart the container using the steps above.
|
- Delete trained models
|
||||||
3. A fresh copy of the notebook will be placed into the `data` directory.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 🎤 Optional: Personal Voice Samples
|
|
||||||
|
|
||||||
In addition to synthetic TTS samples, the trainer can optionally use your own real voice recordings to significantly improve accuracy for your voice and environment.
|
|
||||||
|
|
||||||
### How it works
|
|
||||||
- If a folder named personal_samples/ exists and contains .wav files, the trainer will:
|
|
||||||
- Automatically extract features from those recordings
|
|
||||||
- Include them during training alongside the synthetic TTS data
|
|
||||||
- Up-weight your personal samples during training for better real-world performance
|
|
||||||
|
|
||||||
No extra flags or configuration are required — it is detected automatically.
|
|
||||||
|
|
||||||
### How to use it
|
|
||||||
1. Create a folder in the repo root:
|
|
||||||
mkdir personal_samples
|
|
||||||
|
|
||||||
2. Record yourself saying the wake word naturally and save the files as .wav:
|
|
||||||
personal_samples/
|
|
||||||
hey_tater_01.wav
|
|
||||||
hey_tater_02.wav
|
|
||||||
hey_tater_03.wav
|
|
||||||
...
|
|
||||||
|
|
||||||
3. Run the training script as normal:
|
|
||||||
|
|
||||||
If personal samples are found, you’ll see a message during training indicating they are being included.
|
|
||||||
|
|
||||||
### Recording tips
|
|
||||||
- 10–30 recordings is usually enough to see a noticeable improvement
|
|
||||||
- Vary distance, volume, and tone slightly
|
|
||||||
- Record in the same environment where the wake word will be used (room noise matters)
|
|
||||||
- Use 16-bit WAV files if possible (most recorders do this by default)
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 🙌 Credits
|
## 🙌 Credits
|
||||||
|
|
||||||
This project builds upon the excellent work of [kahrendt/microWakeWord](https://github.com/kahrendt/microWakeWord).
|
Built on top of the excellent
|
||||||
Huge thanks to the original authors for their contributions to the open-source community!
|
**https://github.com/kahrendt/microWakeWord**
|
||||||
|
|
||||||
|
Huge thanks to the original authors ❤️
|
||||||
53
cli/cudainfo
Executable file
53
cli/cudainfo
Executable file
@@ -0,0 +1,53 @@
|
|||||||
|
#!/usr/bin/env python
"""Print a short summary of the current CUDA device.

The summary lists the GPU name, compute capability, streaming multiprocessor
count, CUDA cores (per SM and in total) and total/free memory. When no
/dev/nvidia<N> device node exists, or when querying the device through numba
fails for any reason, a single "CUDA not available" line is printed instead.
"""

import glob
import sys

# Without a /dev/nvidia<N> device node there is nothing to query; report it
# and exit with status 0.
devices = glob.glob("/dev/nvidia[0-9]")
if not devices:
    print("CUDA not available or no CUDA-capable GPU found.")
    sys.exit(0)

# CUDA cores per streaming multiprocessor, keyed by the (major, minor)
# compute capability of the device.
cc_cores_per_SM_dict = {
    (2, 0): 32,
    (2, 1): 48,
    (3, 0): 192,
    (3, 5): 192,
    (3, 7): 192,
    (5, 0): 128,
    (5, 2): 128,
    (6, 0): 64,
    (6, 1): 128,
    (7, 0): 64,
    (7, 5): 64,
    (8, 0): 64,
    (8, 6): 128,
    (8, 9): 128,
    (9, 0): 128,
    (10, 0): 128,
    (12, 0): 128,
}

try:
    # Imported here so that a missing or broken numba installation is reported
    # through the same "not available" path as a missing GPU.
    from numba import cuda

    device = cuda.get_current_device()
    ctx = cuda.current_context()
    meminfo = ctx.get_memory_info()
    compute_capability = device.compute_capability
    sms = device.MULTIPROCESSOR_COUNT

    cores_per_sm = cc_cores_per_SM_dict.get(compute_capability)
    if not cores_per_sm:
        # Compute capability missing from the table above.
        cores_per_sm = "unknown"
        total_cores = "unknown"
    else:
        total_cores = cores_per_sm * sms

    # device.name may be either str or bytes; normalise to str.
    gpu_name = device.name if isinstance(device.name, str) else device.name.decode()

    print(f" GPU Name: {gpu_name}")
    print(f" Compute Capability: {'.'.join(map(str, compute_capability)):>7}")
    print(f"Streaming Multiprocessors: {sms:>7}")
    print(f" CUDA Cores per SM: {cores_per_sm:>7}")
    print(f" Total CUDA Cores: {total_cores:>7}")
    # Byte counts are divided by 1024 twice before printing.
    print(f" Total Memory: {meminfo.total / 1024 / 1024:>7.0f} mb")
    print(f" Free Memory: {meminfo.free / 1024 / 1024:>7.0f} mb")
except Exception as e:
    print("CUDA not available or no CUDA-capable GPU found.")
    # Keep stdout unchanged but surface the real cause on stderr so that a
    # missing numba install or a driver problem can be told apart.
    print(f"({type(e).__name__}: {e})", file=sys.stderr)
|
||||||
199
cli/setup_audioset
Executable file
199
cli/setup_audioset
Executable file
@@ -0,0 +1,199 @@
|
|||||||
|
#!/bin/bash
set -euo pipefail

# Locate the directory holding this script so the shared helper library can be
# sourced regardless of the caller's working directory.
PROGPATH=$(realpath "$0")
PROGDIR=$(dirname "${PROGPATH}")

# NOTE(review): HELP and DATA_DIR are not assigned in this file; presumably
# shell.functions sets them while parsing the command line -- confirm.
source "${PROGDIR}/shell.functions"

# Print the usage text on stderr and exit with status 1 when help was requested.
# NOTE(review): the synopsis names --cleanup-input-files while the option list
# describes --cleanup-intermediate-files; verify which spelling the option
# parser accepts and make the two agree.
if [ "${HELP}" == "true" ] ; then
  cat <<EOF >&2
Usage: $0 [ --cleanup-archives ] [ --cleanup-input-files ] [ --data-dir=<data_dir> ]

--cleanup-archives : Automatically clean up any downloaded archives after
extraction.
--cleanup-intermediate-files
: Automatically clean up the intermediate files after they've
: been converted to 16k.
<data_dir> : Path to the data directory.
: Default: ${DATA_DIR}

EOF
  exit 1
fi
|
||||||
|
|
||||||
|
# Work from inside the training_datasets directory; a failure to create the
# downloads area is deliberately ignored here.
mkdir -p "${DATA_DIR}/training_datasets/downloads" || :
cd "${DATA_DIR}/training_datasets"

echo "***** Checking audioset *****"

# Remote location of the dataset and the local (relative) paths used for it.
AUDIO_URL="https://huggingface.co/datasets/agkphysics/AudioSet/resolve"
AUDIO_FILECOUNT="./downloads/audioset_filecount"
AUDIO_IN_GLOB="*.flac"
AUDIO_DIR="./audioset"
AUDIO16K_DIR="./audioset_16k"
mkdir -p "${AUDIO_DIR}"
mkdir -p "${AUDIO16K_DIR}"

# One counter per archive, bal_train00.tar through bal_train09.tar, all
# starting at zero before get_filecounts is handed the array and the
# file-count path.
declare -A filecounts
for i in 0 1 2 3 4 5 6 7 8 9 ; do
  fname="bal_train0${i}.tar"
  filecounts["${fname}"]=0
done

get_filecounts filecounts "${AUDIO_FILECOUNT}"

# Dataset revisions to probe, in the order listed.
REV_CANDIDATES=(
  6762f044d1c88619c7f2006486036192128fb07e
  0049167e89f259a010c3f070fe3666d9e5242836
  ceb9eaaa7844c9ad7351e659c84a572e376ad06d
  main
)

# Path prefixes under which the bal_train0N.tar archives may be stored.
TAR_PATTERNS=(
  data/bal_train0
  data/bal_train/bal_train0
)
|
||||||
|
|
||||||
|
#######################################
# Probe the AudioSet dataset on Hugging Face for reachable archive locations.
# Every revision in REV_CANDIDATES is combined with every prefix in
# TAR_PATTERNS, and the first archive ("<prefix>0.tar") is requested with a
# HEAD request that follows redirects.
# Globals:   REV_CANDIDATES (read), TAR_PATTERNS (read)
# Arguments: none
# Outputs:   one "<rev>,<pattern>" line on stdout per combination that
#            answered successfully, followed by one empty line
# Returns:   0 (the status of the final echo)
#######################################
find_rev() {
  # Keep the loop variables out of the caller's scope.
  local rev pattern url

  for rev in "${REV_CANDIDATES[@]}" ; do
    for pattern in "${TAR_PATTERNS[@]}" ; do
      url="https://huggingface.co/datasets/agkphysics/AudioSet/resolve/${rev}/${pattern}0.tar"
      # --fail turns an HTTP error into a non-zero exit status; the response
      # headers themselves are discarded.
      if curl -I -L --fail -s "${url}" > /dev/null ; then
        echo "${rev},${pattern}"
      fi
    done
  done
  echo ""
}
|
||||||
|
|
||||||
|
#######################################
# Convert every FLAC file found below AUDIO_DIR into a 16 kHz mono 16-bit WAV
# file inside AUDIO16K_DIR, using the Python virtual environment under
# DATA_DIR. Files whose output already exists are skipped; failures are
# collected and written to audioset_corrupted_files.log in the output
# directory.
# Globals:   DATA_DIR (read), AUDIO_DIR (read), AUDIO16K_DIR (read)
# Arguments: none
# Outputs:   progress and a final summary on stdout
#######################################
converter() {
  # shellcheck source=/dev/null
  source "${DATA_DIR}/.venv/bin/activate"

  # The heredoc delimiter is quoted so the Python source reaches the
  # interpreter verbatim, with no shell expansion applied to it.
  python - "${AUDIO_DIR}" "${AUDIO16K_DIR}" <<'EOF'
import sys
from pathlib import Path
from datetime import datetime, timezone

import numpy as np
import scipy.io.wavfile
import librosa


def write_wav(dst: Path, data: np.ndarray, sr: int):
    """Write `data` to `dst` as 16-bit PCM after clipping it to [-1.0, 1.0]."""
    dst.parent.mkdir(parents=True, exist_ok=True)
    x = np.clip(data, -1.0, 1.0)
    scipy.io.wavfile.write(dst, sr, (x * 32767).astype(np.int16))


audioset_dir = Path(sys.argv[1])
audioset_out = Path(sys.argv[2])

flacs = list(audioset_dir.rglob("*.flac"))
total = len(flacs)
print(f" FLAC files: {total}")
print(" Converting AudioSet → 16k mono WAV")
print(" Sit tight — this step can take a while.")
print("")

audioset_bad = []
ok = 0
skipped = 0

START = datetime.now(timezone.utc).replace(microsecond=0)

# Heartbeat interval (prints every N files)
HEARTBEAT_EVERY = 500

for idx, p in enumerate(flacs, start=1):
    try:
        # Output files are named after the input stem only, directly inside
        # the output directory.
        outfile = audioset_out / (p.stem + ".wav")
        if outfile.exists():
            skipped += 1
        else:
            y, _ = librosa.load(p, sr=16000, mono=True)
            if y.size == 0:
                raise ValueError("empty audio")
            write_wav(outfile, y, 16000)
            ok += 1
    except Exception as e:
        # Record the failure and carry on with the remaining files.
        audioset_bad.append(f"{p}:{e}")

    if idx == 1 or (idx % HEARTBEAT_EVERY) == 0 or idx == total:
        print(f" Progress: {idx}/{total} (ok={ok}, skipped={skipped}, failed={len(audioset_bad)})")

if audioset_bad:
    (audioset_out / "audioset_corrupted_files.log").write_text("\n".join(audioset_bad))

END = datetime.now(timezone.utc).replace(microsecond=0)
elapsed = END - START
print("")
print(f" AudioSet complete ({ok} ok, {skipped} skipped, {len(audioset_bad)} failed) Elapsed: {elapsed}")
EOF
}
|
||||||
|
|
||||||
|
# Main flow: download, extract and convert AudioSet unless converted WAVs
# are already present.
expected_filecount=$(get_total_filecount filecounts)
actual_filecount=$(find "${AUDIO16K_DIR}" -name "*.wav" 2>/dev/null | wc -l) || :
write_filecount=false

# Option B behavior: if we already have output WAVs, don't re-download/re-extract/re-convert
if [ "${actual_filecount}" -ne 0 ] ; then
  echo " Existing ${AUDIO16K_DIR} present (${actual_filecount} wav); skipping extract/convert"
else
  dl=$(find_rev)
  # Only the first "rev,pattern" line is meaningful; drop anything after it
  # so rev and pattern always come from the same probe.
  dl=${dl%%$'\n'*}
  [ -n "$dl" ] || { echo " Could not locate an AudioSet revision with FLAC tarballs still present on HF." ; exit 1 ; }
  rev=${dl%%,*}
  pattern=${dl##*,}

  echo " Checking 10 tarballs"
  for i in {0..9} ; do
    fname="downloads/bal_train0${i}.tar"
    if [ ! -f "${fname}" ] ; then
      echo " Downloading bal_train0${i}.tar"
      url="${AUDIO_URL}/${rev}/${pattern}${i}.tar"
      # Download to a temporary name and rename on success, so a failed or
      # interrupted transfer can never be mistaken for a complete tarball.
      if ! curl -L -s --fail "${url}" -o "${fname}.part" ; then
        rm -f -- "${fname}.part"
        echo "Could not fetch ${fname} at rev ${rev}; continuing."
        continue
      fi
      mv -f -- "${fname}.part" "${fname}"
    fi

    # Record the number of entries in the tarball as the expected count.
    tarball_filecount=$(tar -tvf "${fname}" | wc -l )
    filecounts["bal_train0${i}.tar"]=${tarball_filecount}
    write_filecount=true

    echo " Untarring bal_train0${i}.tar"
    tar -xf "${fname}" -C "${AUDIO_DIR}"
    if "${CLEANUP_ARCHIVES}" && [ -f "${fname}" ] ; then
      echo " Cleaning up bal_train0${i}.tar"
      rm -rf "${fname}"
    fi
  done

  rm -rf "${AUDIO16K_DIR}/audioset_corrupted_files.log" || :
  converter

  # Recompute counts and warn (but do not fail)
  expected_filecount=$(get_total_filecount filecounts)
  actual_filecount=$(find "${AUDIO16K_DIR}" -name "*.wav" 2>/dev/null | wc -l) || :
  if [ "${actual_filecount}" -ne "${expected_filecount}" ] ; then
    echo " Converted file count(${actual_filecount}) != expected file count(${expected_filecount})" >&2
    echo " WARNING: mismatch is expected if some AudioSet files are corrupted; continuing." >&2
  fi
fi

if ${write_filecount} ; then
  write_filecounts filecounts "${AUDIO_FILECOUNT}"
fi

# Remove any tarballs left over from this or an earlier run.
if "${CLEANUP_ARCHIVES}" ; then
  for i in {0..9} ; do
    fname="downloads/bal_train0${i}.tar"
    if [ -f "${fname}" ] ; then
      echo " Cleaning up bal_train0${i}.tar"
      rm -rf "${fname}"
    fi
  done
fi

if "${CLEANUP_INTERMEDIATE_FILES}" && [ -d "${AUDIO_DIR}" ] ; then
  echo " Cleaning up ${AUDIO_DIR}"
  rm -rf "${AUDIO_DIR}"
fi

echo " Audioset complete"
exit 0
|
||||||
131
cli/setup_fma
Executable file
131
cli/setup_fma
Executable file
@@ -0,0 +1,131 @@
|
|||||||
|
#!/bin/bash
#
# setup_fma: download the FMA "xsmall" archive and prepare it as 16 kHz WAV
# files under <data_dir>/training_datasets/fma_16k.
set -euo pipefail

PROGPATH=$(realpath "$0")
PROGDIR=$(dirname "${PROGPATH}")

# Provides DATA_DIR, HELP, the CLEANUP_* flags and the filecount helpers.
source "${PROGDIR}/shell.functions"

if [ "${HELP}" == "true" ] ; then
  cat <<EOF >&2
Usage: $0 [ --cleanup-archives ] [ --cleanup-intermediate-files ] [ --data-dir=<data_dir> ]

  --cleanup-archives : Automatically clean up any downloaded archives after
                     : extraction.
  --cleanup-intermediate-files
                     : Automatically clean up the intermediate files after they've
                     : been converted to 16k.
  <data_dir>         : Path to the data directory.
                     : Default: ${DATA_DIR}

EOF
  exit 1
fi

mkdir -p "${DATA_DIR}/training_datasets/downloads" || :
cd "${DATA_DIR}/training_datasets"

echo "***** Checking FMA *****"

# Paths below are relative to the training_datasets directory.
AUDIO_URL="https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/fma_xs.zip"
AUDIO_ZIPFILE="fma_xs.zip"
AUDIO_ZIP="./downloads/${AUDIO_ZIPFILE}"
AUDIO_DIR="fma"
mkdir -p "${AUDIO_DIR}" || :
AUDIO16K_DIR="fma_16k"
mkdir -p "${AUDIO16K_DIR}" || :
AUDIO_FILECOUNT="./downloads/fma_filecount"
AUDIO_IN_GLOB="*.mp3"

# Start at 0 and overwrite with the count stored by a previous run, if any.
declare -A filecounts=( [${AUDIO_ZIPFILE}]=0 )
get_filecounts filecounts "${AUDIO_FILECOUNT}"
|
||||||
|
|
||||||
|
#######################################
# Convert every MP3 found under ${AUDIO_DIR} into a 16 kHz mono 16-bit WAV
# in ${AUDIO16K_DIR}, using the Python interpreter of the virtualenv at
# ${DATA_DIR}/.venv. Files whose WAV already exists are left alone.
# Globals:   DATA_DIR, AUDIO_DIR, AUDIO16K_DIR (read)
# Outputs:   progress on stdout/stderr; files that failed to convert are
#            listed in ${AUDIO16K_DIR}/fma_corrupted_files.log
# Returns:   exit status of the Python script
#######################################
converter() {
  # Quoted so a data directory containing spaces still works.
  # shellcheck source=/dev/null
  source "${DATA_DIR}/.venv/bin/activate"

  # The delimiter is quoted so the shell never expands anything inside the
  # Python source; the two directories are passed as argv instead.
  python - "${AUDIO_DIR}" "${AUDIO16K_DIR}" <<-'EOF'
import sys
from pathlib import Path

import numpy as np
import scipy.io.wavfile
import librosa
from tqdm import tqdm


def write_wav(dst: Path, data: np.ndarray, sr: int):
    """Clip samples to [-1, 1] and write them as 16-bit PCM."""
    x = np.clip(data, -1.0, 1.0)
    scipy.io.wavfile.write(dst, sr, (x * 32767).astype(np.int16))


fma_dir = Path(sys.argv[1])
fma_out = Path(sys.argv[2])

# convert MP3 → 16k mono WAV
mp3s = list(fma_dir.rglob("*.mp3"))
print(f" MP3 files: {len(mp3s)}")
fma_bad = []
ok = 0
for p in tqdm(mp3s, desc=" FMA→WAV (resample 16k mono)"):
    try:
        outfile = fma_out / (p.stem + ".wav")
        if outfile.exists():
            continue
        y, _ = librosa.load(p, sr=16000, mono=True)
        if y.size == 0:
            raise ValueError("empty audio")
        write_wav(outfile, y, 16000)
        ok += 1
    except Exception as e:
        # Best effort: remember the failure and keep going.
        fma_bad.append(f"{p}:{e}")

if fma_bad:
    (fma_out / "fma_corrupted_files.log").write_text("\n".join(fma_bad))
print(f" FMA complete ({ok} ok, {len(fma_bad)} failed)")
EOF
}
|
||||||
|
|
||||||
|
# Main flow: the dataset is considered valid when the number of converted
# WAVs is non-zero and equals the count stored by the previous run.
expected_filecount=${filecounts[${AUDIO_ZIPFILE}]}
actual_filecount=$(find "${AUDIO16K_DIR}" -name '*.wav' 2>/dev/null | wc -l) || :
write_filecount=false

if [ "${actual_filecount}" -ne 0 ] && [ "${actual_filecount}" -eq "${expected_filecount}" ] ; then
  echo " Existing FMA valid"
else
  # Re-extract only when the intermediate MP3s are missing or incomplete.
  actual_filecount=$(find "${AUDIO_DIR}" -name "${AUDIO_IN_GLOB}" 2>/dev/null | wc -l) || :
  if [ "${actual_filecount}" -eq 0 ] || [ "${actual_filecount}" -ne "${expected_filecount}" ] ; then
    if [ ! -f "${AUDIO_ZIP}" ] ; then
      echo " Downloading ${AUDIO_ZIPFILE}"
      # Download to a temporary name and rename on success, so a failed or
      # interrupted transfer can never be mistaken for a complete archive.
      if ! curl -sfL "${AUDIO_URL}" -o "${AUDIO_ZIP}.part" ; then
        rm -f -- "${AUDIO_ZIP}.part"
        echo " Download of ${AUDIO_ZIPFILE} failed" >&2
        exit 1
      fi
      mv -f -- "${AUDIO_ZIP}.part" "${AUDIO_ZIP}"
    fi

    rm -rf "${AUDIO_DIR}" || :
    mkdir "${AUDIO_DIR}"
    echo " Unzipping ${AUDIO_ZIPFILE}"
    unzip -q -d "${AUDIO_DIR}" "${AUDIO_ZIP}"
  fi
  if "${CLEANUP_ARCHIVES}" && [ -f "${AUDIO_ZIP}" ] ; then
    echo " Cleaning up ${AUDIO_ZIPFILE}"
    rm -rf "${AUDIO_ZIP}"
  fi

  converter

  # The number of WAVs actually produced becomes the new expected count.
  actual_filecount=$(find "${AUDIO16K_DIR}" -name "*.wav" 2>/dev/null | wc -l) || :
  filecounts[${AUDIO_ZIPFILE}]="${actual_filecount}"
  write_filecount=true
fi

if ${write_filecount} ; then
  write_filecounts filecounts "${AUDIO_FILECOUNT}"
fi

if "${CLEANUP_ARCHIVES}" && [ -f "${AUDIO_ZIP}" ] ; then
  echo " Cleaning up ${AUDIO_ZIPFILE}"
  rm -rf "${AUDIO_ZIP}"
fi

if "${CLEANUP_INTERMEDIATE_FILES}" && [ -d "${AUDIO_DIR}" ]; then
  echo " Cleaning up ${AUDIO_DIR}"
  rm -rf "${AUDIO_DIR}"
fi

echo " FMA complete"
exit 0
|
||||||
|
|
||||||
124
cli/setup_mit_audio
Executable file
124
cli/setup_mit_audio
Executable file
@@ -0,0 +1,124 @@
|
|||||||
|
#!/bin/bash
#
# setup_mit_audio: download the MIT impulse-response archive and prepare it
# as 16 kHz WAV files under <data_dir>/training_datasets/mit_rirs_16k.
set -euo pipefail

PROGPATH=$(realpath "$0")
PROGDIR=$(dirname "${PROGPATH}")

# Provides DATA_DIR, HELP, the CLEANUP_* flags and the filecount helpers.
source "${PROGDIR}/shell.functions"

if [ "${HELP}" == "true" ] ; then
  cat <<EOF >&2
Usage: $0 [ --cleanup-archives ] [ --cleanup-intermediate-files ] [ --data-dir=<data_dir> ]

  --cleanup-archives : Automatically clean up any downloaded archives after
                     : extraction.
  --cleanup-intermediate-files
                     : Automatically clean up the intermediate files after they've
                     : been converted to 16k.
  <data_dir>         : Path to the data directory.
                     : Default: ${DATA_DIR}

EOF
  exit 1
fi

mkdir -p "${DATA_DIR}/training_datasets/downloads" || :
cd "${DATA_DIR}/training_datasets"

# Paths below are relative to the training_datasets directory.
AUDIO_URL="https://mcdermottlab.mit.edu/Reverb/IRMAudio/Audio.zip"
AUDIO_ZIPFILE="MIT_RIR_Audio.zip"
AUDIO_ZIP="./downloads/${AUDIO_ZIPFILE}"
AUDIO_DIR="./mit_rirs"
mkdir -p "${AUDIO_DIR}" || :
AUDIO16K_DIR="./mit_rirs_16k"
mkdir -p "${AUDIO16K_DIR}" || :
AUDIO_FILECOUNT="./downloads/mit_rir_filecount"
AUDIO_IN_GLOB="*.wav"

# Start at 0 and overwrite with the count stored by a previous run, if any.
declare -A filecounts=( [${AUDIO_ZIPFILE}]=0 )
get_filecounts filecounts "${AUDIO_FILECOUNT}"

echo "===== Checking MIT_RIR ====="
|
||||||
|
|
||||||
|
#######################################
# Normalize every WAV found under ${AUDIO_DIR} to 16 kHz mono 16-bit PCM in
# ${AUDIO16K_DIR}, using the Python interpreter of the virtualenv at
# ${DATA_DIR}/.venv. Files whose output already exists are left alone.
# Globals:   DATA_DIR, AUDIO_DIR, AUDIO16K_DIR (read)
# Outputs:   progress on stdout/stderr
# Returns:   exit status of the Python script (any conversion error is
#            re-raised, so it is non-zero on failure)
#######################################
converter() {
  # Quoted so a data directory containing spaces still works.
  # shellcheck source=/dev/null
  source "${DATA_DIR}/.venv/bin/activate"

  # The delimiter is quoted so the shell never expands anything inside the
  # Python source; the two directories are passed as argv instead.
  python - "${AUDIO_DIR}" "${AUDIO16K_DIR}" <<-'EOF'
import sys
from pathlib import Path

import numpy as np
import scipy.io.wavfile
import soundfile as sf
import librosa
from tqdm import tqdm


def write_wav(dst: Path, data: np.ndarray, sr: int):
    """Clip samples to [-1, 1] and write them as 16-bit PCM."""
    x = np.clip(data, -1.0, 1.0)
    scipy.io.wavfile.write(dst, sr, (x * 32767).astype(np.int16))


rir_in = Path(sys.argv[1])
rir_out = Path(sys.argv[2])

waves = list(rir_in.rglob("*.wav"))
try:
    print(" MIT RIR normalizing to 16k…")
    # Normalize to 16k mono
    for p in tqdm(waves, desc=" MIT_RIR (resample 16k mono)"):
        outfile = rir_out / p.name
        if outfile.exists():
            continue
        a, sr = sf.read(p, always_2d=False)
        if a.ndim > 1:
            # Multi-channel input: keep the first channel only.
            a = a[:, 0]
        if sr != 16000:
            # Let librosa reload and resample the file.
            a, _ = librosa.load(p, sr=16000, mono=True)
        write_wav(outfile, a, 16000)
    print(" MIT RIR normalization complete")
except Exception as e2:
    print(f" MIT RIR fallback failed: {e2}")
    raise
EOF
}
|
||||||
|
|
||||||
|
# Main flow: the dataset is considered valid when the number of converted
# WAVs is non-zero and equals the count stored by the previous run.
expected_filecount=${filecounts[${AUDIO_ZIPFILE}]}
actual_filecount=$(find "${AUDIO16K_DIR}" -name '*.wav' 2>/dev/null | wc -l) || :
write_filecount=false

if [ "${actual_filecount}" -ne 0 ] && [ "${actual_filecount}" -eq "${expected_filecount}" ] ; then
  echo " Existing ${AUDIO16K_DIR} valid"
else
  # Re-extract only when the intermediate WAVs are missing or incomplete.
  actual_filecount=$(find "${AUDIO_DIR}" -name "${AUDIO_IN_GLOB}" 2>/dev/null | wc -l) || :
  if [ "${actual_filecount}" -eq 0 ] || [ "${actual_filecount}" -ne "${expected_filecount}" ] ; then
    if [ ! -f "${AUDIO_ZIP}" ] ; then
      echo " Downloading ${AUDIO_ZIPFILE}"
      # Download to a temporary name and rename on success, so a failed or
      # interrupted transfer can never be mistaken for a complete archive.
      if ! curl -sfL "${AUDIO_URL}" -o "${AUDIO_ZIP}.part" ; then
        rm -f -- "${AUDIO_ZIP}.part"
        echo " Download of ${AUDIO_ZIPFILE} failed" >&2
        exit 1
      fi
      mv -f -- "${AUDIO_ZIP}.part" "${AUDIO_ZIP}"
    fi

    rm -rf "${AUDIO_DIR}" || :
    echo " Unzipping ${AUDIO_ZIPFILE}"
    unzip -u -q -d "${AUDIO_DIR}" "${AUDIO_ZIP}"
  fi
  if "${CLEANUP_ARCHIVES}" && [ -f "${AUDIO_ZIP}" ] ; then
    echo " Cleaning up ${AUDIO_ZIPFILE}"
    rm -rf "${AUDIO_ZIP}"
  fi

  converter

  # The number of WAVs actually produced becomes the new expected count.
  actual_filecount=$(find "${AUDIO16K_DIR}" -name "*.wav" 2>/dev/null | wc -l) || :
  filecounts[${AUDIO_ZIPFILE}]="${actual_filecount}"
  write_filecount=true
fi

if ${write_filecount} ; then
  write_filecounts filecounts "${AUDIO_FILECOUNT}"
fi

if "${CLEANUP_ARCHIVES}" && [ -f "${AUDIO_ZIP}" ] ; then
  echo " Cleaning up ${AUDIO_ZIPFILE}"
  rm -rf "${AUDIO_ZIP}"
fi

if "${CLEANUP_INTERMEDIATE_FILES}" && [ -d "${AUDIO_DIR}" ]; then
  echo " Cleaning up ${AUDIO_DIR}"
  rm -rf "${AUDIO_DIR}"
fi

echo " MIT_RIR complete"
exit 0
|
||||||
85
cli/setup_negative_datasets
Executable file
85
cli/setup_negative_datasets
Executable file
@@ -0,0 +1,85 @@
|
|||||||
|
#!/bin/bash
#
# setup_negative_datasets: download and unpack the pre-built negative
# datasets into <data_dir>/training_datasets/negative_datasets.
set -euo pipefail

PROGPATH=$(realpath "$0")
PROGDIR=$(dirname "${PROGPATH}")

# Provides DATA_DIR, HELP, the CLEANUP_* flags and the filecount helpers.
source "${PROGDIR}/shell.functions"

if [ "${HELP}" == "true" ] ; then
  cat <<EOF >&2
Usage: $0 [ --cleanup-archives ] [ --data-dir=<data_dir> ]

  --cleanup-archives : Automatically clean up any downloaded archives after
                     : extraction.
  <data_dir>         : Path to the data directory.
                     : Default: ${DATA_DIR}

EOF
  exit 1
fi

mkdir -p "${DATA_DIR}/training_datasets/downloads" || :
cd "${DATA_DIR}/training_datasets"

mkdir -p ./negative_datasets || :

NEGATIVE_DATASET_URL="https://huggingface.co/datasets/kahrendt/microwakeword/resolve/main"
declare -a NEGATIVE_DATASETS=( dinner_party dinner_party_eval no_speech speech )
AUDIO_FILECOUNT="./downloads/negative_filecount"

# One counter per archive: start at 0 and overwrite with the counts stored
# by a previous run, if any.
declare -A filecounts=( [dinner_party.zip]=0 [dinner_party_eval.zip]=0 [no_speech.zip]=0 [speech.zip]=0 )
get_filecounts filecounts "${AUDIO_FILECOUNT}"

echo "===== Checking negative datasets: ${NEGATIVE_DATASETS[*]} ====="
write_filecount=false
|
||||||
|
|
||||||
|
# Process each dataset: skip it when its '*.ninja' file count is non-zero
# and matches the stored count, otherwise (re)download and unpack it.
for ds in "${NEGATIVE_DATASETS[@]}" ; do
  AUDIO_ZIPFILE="${ds}.zip"
  AUDIO_ZIP="./downloads/${AUDIO_ZIPFILE}"
  AUDIO_DIR="./negative_datasets/${ds}"
  mkdir -p "${AUDIO_DIR}" || :

  expected_filecount=${filecounts[${AUDIO_ZIPFILE}]}
  actual_filecount=$(find "${AUDIO_DIR}" -name '*.ninja' 2>/dev/null | wc -l) || :

  if [ "${actual_filecount}" -ne 0 ] && [ "${actual_filecount}" -eq "${expected_filecount}" ] ; then
    echo " Existing ${ds} valid"
    continue
  fi

  if [ ! -f "${AUDIO_ZIP}" ] ; then
    echo " Downloading ${AUDIO_ZIPFILE}"
    # Download to a temporary name and rename on success, so a failed or
    # interrupted transfer can never be mistaken for a complete archive.
    if ! curl -sfL "${NEGATIVE_DATASET_URL}/${ds}.zip" -o "${AUDIO_ZIP}.part" ; then
      rm -f -- "${AUDIO_ZIP}.part"
      echo " Download of ${AUDIO_ZIPFILE} failed" >&2
      exit 1
    fi
    mv -f -- "${AUDIO_ZIP}.part" "${AUDIO_ZIP}"
  fi

  rm -rf "${AUDIO_DIR}" || :
  echo " Unzipping ${AUDIO_ZIPFILE}"
  unzip -q -d "./negative_datasets" "${AUDIO_ZIP}"

  # The number of files actually unpacked becomes the new expected count.
  actual_filecount=$(find "${AUDIO_DIR}" -name '*.ninja' 2>/dev/null | wc -l) || :
  filecounts[${AUDIO_ZIPFILE}]="${actual_filecount}"
  write_filecount=true

  if "${CLEANUP_ARCHIVES}" && [ -f "${AUDIO_ZIP}" ] ; then
    echo " Cleaning up ${AUDIO_ZIPFILE}"
    rm -rf "${AUDIO_ZIP}"
  fi
done

if ${write_filecount} ; then
  write_filecounts filecounts "${AUDIO_FILECOUNT}"
fi

# Remove archives left behind for datasets that were already valid.
if "${CLEANUP_ARCHIVES}" ; then
  for ds in "${NEGATIVE_DATASETS[@]}" ; do
    AUDIO_ZIPFILE="${ds}.zip"
    AUDIO_ZIP="./downloads/${AUDIO_ZIPFILE}"
    if [ -f "${AUDIO_ZIP}" ] ; then
      echo " Cleaning up ${AUDIO_ZIPFILE}"
      rm -rf "${AUDIO_ZIP}"
    fi
  done
fi

echo " Negative datasets complete"
|
||||||
|
|
||||||
188
cli/setup_python_venv
Executable file
188
cli/setup_python_venv
Executable file
@@ -0,0 +1,188 @@
|
|||||||
|
#!/bin/bash
#
# setup_python_venv: create or update the virtualenv at <data_dir>/.venv and
# link the CLI scripts into its bin directory.
#
# Strict mode is switched on only after the virtualenv has been activated:
# the code before that point reads variables (GPU, PYTHON, VIRTUAL_ENV) that
# may legitimately be unset.
PROGDIR="$(dirname "$(realpath "$0")")"
ROOTDIR="$(dirname "${PROGDIR}")"

KNOWN_ARGS=( data-dir python gpu no-gpu )
source "${PROGDIR}/shell.functions"

if [ ${#UNKNOWN_ARGS[@]} -gt 0 ] ; then
  echo "Unknown argument(s): ${UNKNOWN_ARGS[*]}" >&2
  HELP=true
fi

if [ "${HELP}" == "true" ] ; then
  cat <<EOF >&2
Usage: setup_python_venv [ --gpu | --no-gpu ] [ --verbose ]

Options:
--gpu: Install the GPU-capable versions of packages if available. This
is the default if the script detects that a GPU is available.

--no-gpu: Install the non-GPU-capable versions of packages even if
GPU-capable packages are available. This is the default if the script
detects that a GPU is NOT available.

--verbose: Print the detailed "pip install" output.

EOF
  exit 1
fi

[ -n "${DATA_DIR}" ] && DATA_DIR="$(realpath "${DATA_DIR}")"
[ -d "${DATA_DIR}" ] || {
  echo "Data directory '${DATA_DIR}' doesn't exist." >&2
  exit 1
}

cd "${DATA_DIR}" || exit 1

# Auto-detect the GPU unless --gpu / --no-gpu was given.
if [ -z "${GPU}" ] ; then
  GPU=false
  if [ -c /dev/nvidiactl ] ; then
    GPU=true
    echo " Nvidia GPU detected"
  fi
fi

"${GPU}" || export CUDA_VISIBLE_DEVICES=-1

VENV="${DATA_DIR}/.venv"

# `deactivate` is a shell function created by `activate`; it exists here only
# if a virtualenv was activated in this very process. When VIRTUAL_ENV was
# merely inherited from the calling shell, forget it instead of invoking a
# command that does not exist.
if [ -n "${VIRTUAL_ENV}" ] ; then
  if declare -F deactivate >/dev/null ; then
    deactivate
  else
    unset VIRTUAL_ENV
  fi
fi

if [ -n "${PYTHON}" ] ; then
  PYTHONS=( "${PYTHON}" )
  unset PYTHON
else
  # Add 3.11 as a common middle-ground (especially outside Ubuntu 24.04)
  PYTHONS=( python3.12 python3.11 python3.10 )
fi

# Pick the first candidate interpreter that actually runs.
for p in "${PYTHONS[@]}" ; do
  "${p}" --version &>/dev/null && { PYTHON="${p}" ; break ; }
done

[ -n "${PYTHON}" ] || {
  echo "A python 3.12/3.11/3.10 interpreter wasn't found. You'll need to install one before proceeding." >&2
  exit 1
}

# Reuse an existing virtualenv only when the data-dir marker file is present;
# otherwise throw it away and start fresh.
if [ -d "${VENV}" ] ; then
  if [ -f "${DATA_DIR}/.mww-data-dir" ] ; then
    source "${VENV}/bin/activate" || {
      echo "Unable to activate existing virtualenv '${VENV}'. You should delete it and try again." >&2
      exit 1
    }
  else
    rm -rf "${VENV}"
  fi
fi

echo "===== Setting up Python environment ${VENV} ====="

if [ -z "$VIRTUAL_ENV" ] ; then
  echo " ===== Creating new virtualenv at '${VENV}' ====="
else
  echo " ===== Updating virtualenv at '${VENV}' ====="
fi

"${PYTHON}" -m venv --upgrade-deps "${VENV}"
source "${VENV}/bin/activate"

set -euo pipefail

# Symlink CLI scripts into .venv/bin
# mapfile keeps each path intact even if it contains spaces.
declare -a progfiles=()
mapfile -t progfiles < <(find "${PROGDIR}" -mindepth 1 -maxdepth 1 -executable -type f)
progfiles+=( "${PROGDIR}/shell.functions" )

# Also symlink the top-level entrypoint if present
[ -x "${ROOTDIR}/train_wake_word" ] && progfiles+=( "${ROOTDIR}/train_wake_word" )

for f in "${progfiles[@]}" ; do
  ln -sfr "${f}" "${VENV}/bin/$(basename "${f}")"
done

#
# Pip doesn't process packages from requirements.txt in order but order is
# important because tensorflow, torch, onnxruntime and micro-wake-word all
# depend on CUDA packages at various versions. They need to be installed in
# this specific order or they may not be able to use the GPU.
#
export PIP_PROGRESS_BAR=off
export PIP_NO_COLOR=1
export PIP_QUIET=0
|
||||||
|
|
||||||
|
#######################################
# Run `pip install` with the given arguments.
# In verbose mode pip's output is shown as is; otherwise every output line
# is reduced to a single '.' as a progress indicator.
# Globals:   VERBOSE (read; "true" or "false")
# Arguments: passed straight to `pip install`
# Outputs:   pip output or progress dots, followed by a newline on success
# Returns:   0 on success, 1 if pip failed
#######################################
pip_install() {
  if $VERBOSE ; then
    pip install "$@" || return 1
  else
    # Every stage of a pipeline runs in a subshell, so a `return` placed
    # inside it can never leave this function. Run the pipeline in a
    # subshell with pipefail instead, so pip's failure becomes the
    # subshell's status whatever options the caller has set.
    (
      set -o pipefail
      pip install "$@" \
        | stdbuf -i0 -o0 tr -d '[:print:]' \
        | stdbuf -i0 -o0 tr '\n' '.'
    ) || return 1
  fi
  echo
}
|
||||||
|
|
||||||
|
# Install everything into the active virtualenv, in the fixed order
# described above.
START_TS=$EPOCHSECONDS

echo " ===== Installing common requirements ====="
# requirements.txt lives in repo root now
pip_install -r "${ROOTDIR}/requirements.txt"

if ${GPU} ; then tfgpu='[and-cuda]' ; else tfgpu="" ; fi
echo " ===== Installing Tensorflow${tfgpu} ====="
pip_install ai_edge_litert "tensorflow${tfgpu}==2.20.0" "tensorboard==2.20.0" \
  "tensorboard-data-server==0.7.2"

# The extra index arguments live in an array so they reach pip as separate
# words without relying on unquoted word-splitting.
torch_args=()
torch_label=""
if ${GPU} ; then
  torch_args=( --index-url https://download.pytorch.org/whl/cu129 )
  torch_label="[cuda]"
fi
echo " ===== Installing torch and torchaudio ${torch_label} ====="
pip_install "torch==2.9.1" "torchaudio==2.9.1" "${torch_args[@]}"

echo " ===== Checking microwakeword ====="
MWW="${DATA_DIR}/tools/microWakeWord"
# Re-clone when the checkout is missing or has local modifications.
if [ ! -d "${MWW}" ] || [ -n "$(git -C "${MWW}" status --porcelain)" ] ; then
  rm -rf "${MWW}" || :
  echo " Cloning micro-wake-word to ${DATA_DIR}/tools"
  git clone https://github.com/TaterTotterson/micro-wake-word "${MWW}" &>/dev/null
fi
echo " Installing microwakeword"
pip_install -e "${MWW}"

echo " ===== Checking piper-sample-generator ====="
PSG="${DATA_DIR}/tools/piper-sample-generator"
if [ ! -d "${PSG}" ] || [ -n "$(git -C "${PSG}" status --porcelain)" ] ; then
  rm -rf "${PSG}" || :
  echo " Cloning piper-sample-generator to ${DATA_DIR}/tools"
  git clone https://github.com/rhasspy/piper-sample-generator "${PSG}" &>/dev/null
fi
echo " Installing piper-sample-generator"
pip_install -e "${PSG}"
git -C "${PSG}" clean -fd &>/dev/null

MODELS_DIR="${PSG}/models"
MODEL_NAME="en_US-libritts_r-medium.pt"
MODEL_FILE="${MODELS_DIR}/${MODEL_NAME}"
MODEL_URL="https://github.com/rhasspy/piper-sample-generator/releases/download/v2.0.0/${MODEL_NAME}"

# Fetch the model and its .json companion if missing. Each file is downloaded
# under a temporary name and renamed on success, so a failed or interrupted
# transfer can never be mistaken for a complete file on the next run.
mkdir -p "${MODELS_DIR}"
for suffix in "" ".json" ; do
  if [ ! -f "${MODEL_FILE}${suffix}" ] ; then
    echo " Downloading ${MODEL_NAME}${suffix} for piper-sample-generator"
    if ! curl -sfL "${MODEL_URL}${suffix}" -o "${MODEL_FILE}${suffix}.part" ; then
      rm -f -- "${MODEL_FILE}${suffix}.part"
      echo " Download of ${MODEL_NAME}${suffix} failed" >&2
      exit 1
    fi
    mv -f -- "${MODEL_FILE}${suffix}.part" "${MODEL_FILE}${suffix}"
  fi
done

if ${GPU} ; then onnxgpu='-gpu[cuda]' ; else onnxgpu="" ; fi
echo " ===== Installing onnxruntime${onnxgpu} ====="
pip_install "onnxruntime${onnxgpu}>=1.16.0"

echo " ===== Installing keras ====="
# keras 3.13 has "issues" so we need to back down to 3.12.
pip_install "keras==3.12.0"

"${PROGDIR}/test_python" --data-dir="${DATA_DIR}"

# Marker file: an existing .venv is reused only when this file is present.
touch .mww-data-dir
END_TS=$EPOCHSECONDS

echo "Run 'source ${VENV}/bin/activate' to activate the new virtualenv in the current shell."

print_elapsed_time "${START_TS}" "${END_TS}" "Python package installation complete"
|
||||||
65
cli/setup_training_datasets
Executable file
65
cli/setup_training_datasets
Executable file
@@ -0,0 +1,65 @@
|
|||||||
|
#!/bin/bash
#
# setup_training_datasets: run every dataset setup script in turn, passing
# the cleanup flags and the data directory through to each of them.
set -euo pipefail

PROGPATH="$(realpath "$0")"
PROGDIR="$(dirname "${PROGPATH}")"
ROOTDIR="$(dirname "${PROGDIR}")" # repo root (train_wake_word, requirements.txt, etc.)

KNOWN_ARGS=( data-dir cleanup-archives cleanup-intermediate-files )
source "${PROGDIR}/shell.functions"

# Any unrecognised option turns the run into a help request.
if (( ${#UNKNOWN_ARGS[@]} > 0 )) ; then
  echo "Unknown argument(s): ${UNKNOWN_ARGS[*]}" >&2
  HELP=true
fi

if [[ "${HELP}" == "true" ]] ; then
  cat <<EOF >&2
Usage: setup_training_datasets [ --cleanup-archives ] [ --cleanup-intermediate-files ]

Options:
--cleanup-archives: Automatically delete the tarballs or zipfiles after
they've been extracted.

--cleanup-intermediate-files: Automatically delete the intermediate files
after they've been converted.

EOF
  exit 1
fi

# Normalize + validate DATA_DIR (shell.functions typically sets a default,
# but this makes the script standalone-safe)
if [ -n "${DATA_DIR:-}" ] ; then
  DATA_DIR="$(realpath "${DATA_DIR}")"
fi
if [ ! -d "${DATA_DIR}" ] ; then
  echo "Data directory '${DATA_DIR}' doesn't exist." >&2
  exit 1
fi

cd "${DATA_DIR}"

START_TS=$EPOCHSECONDS
echo -e "\n===== Setting up Training Datasets =====\n"

# Same scripts, same order and same arguments as before; with `set -e` the
# first failing script still aborts the whole run.
for dataset_script in setup_negative_datasets setup_mit_audio setup_audioset setup_fma ; do
  "${PROGDIR}/${dataset_script}" \
    --cleanup-archives="${CLEANUP_ARCHIVES}" \
    --cleanup-intermediate-files="${CLEANUP_INTERMEDIATE_FILES}" \
    --data-dir="${DATA_DIR}"
done

END_TS=$EPOCHSECONDS
print_elapsed_time "${START_TS}" "${END_TS}" "Training dataset setup"
|
||||||
150
cli/shell.functions
Normal file
150
cli/shell.functions
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
|
||||||
|
# Refuse to run as a program: this file only makes sense when sourced.
if [[ "${BASH_SOURCE[0]}" == "$0" ]] ; then
  echo "${BASH_SOURCE[0]} is meant to be 'sourced' not run directly" >&2
  exit 1
fi

# Data directory: keep a value supplied by the caller; otherwise use the
# current directory when it carries the marker file, else /data.
if [ ! -v DATA_DIR ] ; then
  if [ -f .mww-data-dir ] ; then
    DATA_DIR="${PWD}"
  else
    DATA_DIR="/data"
  fi
fi

DEFAULT_SAMPLES=50000
DEFAULT_BATCH_SIZE=100
DEFAULT_TRAINING_STEPS=40000

# Optional per-data-dir overrides; a failure while sourcing is ignored.
if [ -f "${DATA_DIR}/.defaults.env" ] ; then
  source "${DATA_DIR}/.defaults.env" || :
fi

# Fill in every setting that is still unset or empty.
SAMPLES="${SAMPLES:-${DEFAULT_SAMPLES}}"
BATCH_SIZE="${BATCH_SIZE:-${DEFAULT_BATCH_SIZE}}"
TRAINING_STEPS="${TRAINING_STEPS:-${DEFAULT_TRAINING_STEPS}}"
CLEANUP_WORK_DIR="${CLEANUP_WORK_DIR:-false}"
CLEANUP_ARCHIVES="${CLEANUP_ARCHIVES:-false}"
CLEANUP_INTERMEDIATE_FILES="${CLEANUP_INTERMEDIATE_FILES:-false}"
QUIET="${QUIET:-false}"
VERBOSE="${VERBOSE:-false}"

HELP=false
||||||
|
|
||||||
|
# ----- Command-line parsing -------------------------------------------------
# Scripts that want unknown options reported set KNOWN_ARGS before sourcing
# this file; the options handled here are always known.
if [ -v KNOWN_ARGS ] ; then
  KNOWN_ARGS+=( help verbose quiet h v q )
fi
declare -gi OPTION_COUNT=0
declare -ga POSITIONAL_ARGS=()
declare -ga EXTRA_ARGS=()
declare -ga UNKNOWN_ARGS=()
declare -gi __MWW_SHIFT_COUNT=0

#######################################
# Parse command-line arguments into global variables.
#   -h|--help      HELP=true, and parsing stops
#   -q|--quiet     QUIET=true
#   -v|--verbose   VERBOSE=true
#   --name=value   NAME=value   (name upper-cased, '-' becomes '_')
#   --no-name      NAME=false
#   --name         NAME=true
#   --             everything after it is collected in EXTRA_ARGS
#   anything else  appended to POSITIONAL_ARGS
# Globals:   KNOWN_ARGS (read, optional); OPTION_COUNT, POSITIONAL_ARGS,
#            EXTRA_ARGS, UNKNOWN_ARGS, __MWW_SHIFT_COUNT, HELP, QUIET,
#            VERBOSE and one variable per --option (written)
# Arguments: the argument list to parse
#######################################
__mww_parse_args() {
  local a l _arg _k known
  local -i stop_parsing=0

  OPTION_COUNT=0
  POSITIONAL_ARGS=()
  EXTRA_ARGS=()
  UNKNOWN_ARGS=()
  __MWW_SHIFT_COUNT=0

  for a in "$@" ; do
    if [ "$a" == "--" ] ; then
      stop_parsing=1
      __MWW_SHIFT_COUNT+=1
      continue
    fi
    if (( stop_parsing )) ; then
      EXTRA_ARGS+=( "$a" )
      __MWW_SHIFT_COUNT+=1
      continue
    fi

    # Report options the calling script did not declare as known.
    if [ -v KNOWN_ARGS ] && [[ "${a}" =~ ^--?([^=]+)=?.* ]] ; then
      _arg=${BASH_REMATCH[1]}
      known=false
      for _k in "${KNOWN_ARGS[@]}" ; do
        if [ "${_arg}" == "${_k}" ] ; then
          known=true
          break
        fi
      done
      $known || UNKNOWN_ARGS+=( "${a}" )
    fi

    OPTION_COUNT+=1
    case "$a" in
      -h | --help)
        HELP=true
        break
        ;;
      -q | --quiet)
        # No `break` here: later options must still be parsed.
        QUIET=true
        ;;
      -v | --verbose)
        VERBOSE=true
        ;;
      --*=*)
        if [[ $a =~ --([^=]+)=(.*) ]] ; then
          l=${BASH_REMATCH[1]//-/_}
          printf -v "${l^^}" '%s' "${BASH_REMATCH[2]}"
        fi
        ;;
      --no-*)
        if [[ $a =~ --no-(.+) ]] ; then
          l=${BASH_REMATCH[1]//-/_}
          printf -v "${l^^}" '%s' false
        fi
        ;;
      --*)
        if [[ $a =~ --(.+) ]] ; then
          l=${BASH_REMATCH[1]//-/_}
          printf -v "${l^^}" '%s' true
        fi
        ;;
      *)
        POSITIONAL_ARGS+=( "$a" )
        ;;
    esac
  done
}

__mww_parse_args "$@"

# Backward compatibility: the parser has always shifted the caller's
# positional parameters once for "--" and once for every argument after it.
# Scripts may depend on that, so the same number of shifts is applied here.
shift "${__MWW_SHIFT_COUNT}"
|
||||||
|
|
||||||
|
|
||||||
|
#######################################
# Print "<message> Elapsed time: <duration>" right-aligned in 80 columns,
# by default framed by two lines of 80 '=' characters.
# The duration is formatted as [D day[s], ]H:MM:SS. It is computed with
# shell arithmetic, so no Python interpreter is needed and the message is
# never interpreted as code.
# Arguments: [--no-separators]  omit the '=' lines
#            $1 - start timestamp, whole seconds since the epoch
#            $2 - end timestamp, whole seconds since the epoch
#            $3 - message (optional)
# Outputs:   the formatted lines on stdout
# Returns:   0 on success, 1 if a timestamp is not a non-negative integer
#######################################
print_elapsed_time() {
  local print_seps=true
  if [ "${1:-}" == "--no-separators" ] ; then
    shift
    print_seps=false
  fi
  local usage="Usage: print_elapsed_time [--no-separators] <start_timestamp> <end_timestamp> [message]"
  local START_TS=${1:?"${usage}"}
  local END_TS=${2:?"${usage}"}
  local message="${3:-}"

  # Validate before doing arithmetic so arbitrary text is never evaluated.
  if ! [[ "${START_TS}" =~ ^[0-9]+$ && "${END_TS}" =~ ^[0-9]+$ ]] ; then
    echo "print_elapsed_time: timestamps must be whole seconds since the epoch" >&2
    return 1
  fi

  # 10# forces base 10 so a leading zero is not read as octal.
  local -i elapsed=$(( 10#${END_TS} - 10#${START_TS} ))
  local -i days=$(( elapsed / 86400 ))
  local -i rem=$(( elapsed % 86400 ))
  if (( rem < 0 )) ; then
    # Floor the division so the H:MM:SS part is never negative.
    rem=$(( rem + 86400 ))
    days=$(( days - 1 ))
  fi

  local span
  printf -v span '%d:%02d:%02d' $(( rem / 3600 )) $(( rem % 3600 / 60 )) $(( rem % 60 ))
  if (( days != 0 )) ; then
    local unit="days"
    if (( days == 1 || days == -1 )) ; then
      unit="day"
    fi
    span="${days} ${unit}, ${span}"
  fi

  local msg="${message} Elapsed time: ${span}"
  local sep
  printf -v sep '%80s' ''
  sep=${sep// /=}

  if ${print_seps} ; then
    printf '%s\n' "${sep}"
  fi
  printf '%80s\n' "${msg}"
  if ${print_seps} ; then
    printf '%s\n' "${sep}"
  fi
}
|
||||||
|
|
||||||
|
#######################################
# Print a string centred within a field of the given width.
#
# The string is right-aligned in a field of (string length + width) / 2
# columns, which leaves it centred in `width` columns without trailing
# padding.
#
# Arguments:
#   $1  text to print (required, non-empty)
#   $2  total field width (required)
# Outputs:
#   The padded text followed by a newline on stdout.
#######################################
justify_text() {
  # Locals keep the caller's own `msg` / `len` variables intact.
  local msg="${1:?Need a string}"
  local len="${2:?Need a length}"
  printf "%*s\n" $(( (${#msg} + len) / 2 )) "${msg}"
}
|
||||||
|
|
||||||
|
#######################################
# Load "name:count" lines from a file into the caller's array.
#
# Lines that do not match `name:count` (count made of digits and '-') are
# ignored. A missing file leaves the array untouched.
#
# Arguments:
#   $1  name of the caller's array to fill (passed by name)
#   $2  path of the counts file
# Returns:
#   0 always.
#######################################
get_filecounts() {
  # The original `-l` (lowercase) attribute is dropped: the stored values are
  # numeric, so it served no purpose here.
  local -n fca=${1}
  # Oddly named locals so they cannot shadow the array the nameref points at.
  local _gfc_file=${2}
  local _gfc_line

  [[ -f "${_gfc_file}" ]] || return 0

  # `|| [[ -n ... ]]` also processes a final line without trailing newline.
  while IFS= read -r _gfc_line || [[ -n "${_gfc_line}" ]]; do
    if [[ "${_gfc_line}" =~ ^([^:]+):([0-9-]+)$ ]]; then
      fca[${BASH_REMATCH[1]}]=${BASH_REMATCH[2]}
    fi
  done < "${_gfc_file}"
  return 0
}
|
||||||
|
|
||||||
|
#######################################
# Print the sum of all values stored in the caller's array.
#
# Arguments:
#   $1  name of the array holding numeric counts (passed by name)
# Outputs:
#   The total on stdout (0 for an empty array).
#######################################
get_total_filecount() {
  local -n fca=${1}
  # Oddly named locals so they cannot shadow the array the nameref points at.
  local -i _gtf_total=0
  local _gtf_count
  for _gtf_count in "${fca[@]}"; do
    # Integer attribute makes += an arithmetic addition.
    _gtf_total+=${_gtf_count}
  done
  echo "${_gtf_total}"
}
|
||||||
|
|
||||||
|
#######################################
# Persist the caller's array as "name:count" lines.
#
# Any previous file is removed first. When the array is empty no file is
# created at all.
#
# Arguments:
#   $1  name of the array to write (passed by name)
#   $2  destination path
#######################################
write_filecounts() {
  local -n fca=${1}
  # Oddly named locals so they cannot shadow the array the nameref points at.
  local _wfc_file=${2}
  local _wfc_key

  rm -rf -- "${_wfc_file}" || :
  (( ${#fca[@]} > 0 )) || return 0

  # One redirection for the whole loop instead of reopening the file per line.
  for _wfc_key in "${!fca[@]}"; do
    printf '%s:%s\n' "${_wfc_key}" "${fca[${_wfc_key}]}"
  done > "${_wfc_file}"
}
|
||||||
18
cli/system_summary
Executable file
18
cli/system_summary
Executable file
@@ -0,0 +1,18 @@
|
|||||||
|
#!/bin/bash
#
# Print a two-line summary of the host: CPU model, core count and system
# memory, then GPU model, CUDA core count and GPU memory as reported by the
# sibling `cudainfo` program. Prints "GPU: N/A" when no GPU name is found.

PROGPATH=$(realpath "$0")
PROGDIR=$(dirname "${PROGPATH}")

CUDA_INFO=$("${PROGDIR}/cudainfo")

# The here-strings are quoted so the multi-line report reaches sed intact on
# every bash version (bash before 4.4 word-split unquoted here-strings).
CUDA_CORES=$(sed -n -r -e "s/\s*Total\s+CUDA\s+Cores:\s+([0-9]+)$/\1/gp" <<<"${CUDA_INFO}")
GPU_NAME="$(sed -n -r -e 's/\s*GPU\s+Name:\s+(.+)$/\1/gp' <<<"${CUDA_INFO}")"
GPU_MEMORY="$(sed -n -r -e 's/\s*Total\s+Memory:\s*([0-9.]+).*/\1/gp' <<<"${CUDA_INFO}")"

# /proc/cpuinfo repeats "model name" per logical CPU; keep the first.
CPU_NAME="$(sed -n -r -e 's/model\s+name\s*:\s*(.+)$/\1/gp' /proc/cpuinfo | head -1)"
CPU_CORES="$(nproc)"
SYS_MEMORY="$(free -m | sed -n -r -e 's/Mem:\s+([0-9.]+)\s+.*/\1/gp')"

printf "CPU: %s (%d cores) Memory: %s mb\n" "${CPU_NAME}" "${CPU_CORES}" "${SYS_MEMORY}"
if [[ -z "${GPU_NAME}" ]]; then
  printf "GPU: N/A\n"
else
  printf "GPU: %s (%d cores) Memory: %s mb\n" "${GPU_NAME}" "${CUDA_CORES}" "${GPU_MEMORY}"
fi
|
||||||
129
cli/test_python
Executable file
129
cli/test_python
Executable file
@@ -0,0 +1,129 @@
|
|||||||
|
#!/bin/bash
#
# Smoke-test the Python environment inside the data directory's virtualenv:
# CUDA (via the sibling `cudainfo`), TensorFlow, Torch, onnxruntime, the
# micro-wake-word modules and piper-sample-generator. Each section prints
# whether the component could be imported or run; nothing here aborts the
# script.

PROGPATH=$(realpath "$0")
PROGDIR=$(dirname "${PROGPATH}")
# NOTE(review): TRAINING_STEPS is not referenced in this script; presumably a
# default consumed by shell.functions - confirm before removing.
TRAINING_STEPS=40000
DATA_DIR=/data
source "${PROGDIR}/shell.functions"

source "${DATA_DIR}/.venv/bin/activate"

export TF_CPP_MIN_LOG_LEVEL=9
export GLOG_minloglevel=2
export GRPC_VERBOSITY="ERROR"

echo -e "\n===== Testing Python Environment =====\n"

echo -e "\n===== Testing Cuda =====\n"
"${PROGDIR}/cudainfo"

# The heredoc delimiters are quoted so the shell never expands anything
# inside the embedded Python.
python - 2>/dev/null <<'EOF'
import os, sys

print("\n===== Testing Tensorflow =====\n")
try:
    from ai_edge_litert.interpreter import Interpreter
    import tensorflow as tf

    try:
        with tf.device("/GPU:0"):
            a = tf.random.normal([10000, 10000])
            b = tf.random.normal([10000, 10000])
            c = tf.matmul(a, b)
        # The op may have been placed elsewhere; check where it really ran.
        if c.device.find("GPU") >= 0:
            result = "Available - " + c.device
        else:
            result = "Not available"
    except Exception:
        result = "Not available"

    print("GPU:", result)

    try:
        with tf.device("/CPU:0"):
            a = tf.random.normal([10000, 10000])
            b = tf.random.normal([10000, 10000])
            c = tf.matmul(a, b)
        result = "Available - " + c.device
    except Exception:
        result = "Not available"

    print("CPU:", result)
except Exception:
    print("Tensorflow not available")
EOF


python - 2>/dev/null <<'EOF'
import os, sys
print("\n===== Testing Torch =====\n")

try:
    import torch

    if torch.cuda.is_available():
        print(f"GPU: Available - {torch.cuda.get_device_name(0)}")
    else:
        print("GPU:", "Not available")
    print("CPU:", "Available")
except Exception:
    print("Torch not available")
EOF

python - 2>/dev/null <<'EOF'
import os, sys
print("\n===== Testing onnxruntime =====\n")

try:
    import onnxruntime as ort

    providers = ort.get_available_providers()
    if 'CUDAExecutionProvider' in providers:
        print("GPU:", "Available")
    else:
        print("GPU:", "Not available")

    if 'CPUExecutionProvider' in providers:
        print("CPU:", "Available")
    else:
        print("CPU:", "Not available")

    if 'TensorrtExecutionProvider' in providers:
        print("TensorRT:", "Available")
    else:
        print("TensorRT:", "Not available")
except Exception:
    print("onnxruntime not available")
EOF

python - 2>/dev/null <<'EOF'
import os, sys

print("\n===== Testing micro-wake-word =====\n")

try:
    import numpy as np
    import librosa
    from mmap_ninja.ragged import RaggedMmap
    from microwakeword.audio.augmentation import Augmentation
    from microwakeword.audio.clips import Clips
    from microwakeword.audio.spectrograms import SpectrogramGeneration
    from microwakeword.audio.audio_utils import save_clip

    print("micro-wake-word available")
except Exception:
    print("micro-wake-word not available")

print("")
EOF

echo -e "===== Testing piper-sample-generator =====\n"

# Resolved from DATA_DIR (the location the sample generator script uses)
# instead of the current directory, so the check works from any cwd.
if "${DATA_DIR}/tools/piper-sample-generator/generate_samples.py" --help &>/dev/null; then
  echo "piper-sample-generator available"
else
  echo "piper-sample-generator not available"
fi

echo
echo -e "\n===== Python Environment Testing Complete =====\n"
|
||||||
248
cli/wake_word_sample_augmenter
Normal file
248
cli/wake_word_sample_augmenter
Normal file
@@ -0,0 +1,248 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
import sys, os, gc, glob, random
|
||||||
|
import types
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from argparse import ArgumentParser as ArgParser, ArgumentError
|
||||||
|
|
||||||
|
default_data_dir = os.getcwd() if os.path.exists(".mww-data-dir") else "/data"
|
||||||
|
|
||||||
|
parser = ArgParser(exit_on_error=False)
|
||||||
|
parser.add_argument("--data-dir", type=str, help=f"Data directory. Default: {default_data_dir}", required=False, default=default_data_dir)
|
||||||
|
|
||||||
|
# Wake word (TTS/generated) inputs/outputs
|
||||||
|
parser.add_argument("--input-dir", type=str, help="Wake word input dir. Default: <data-dir>/work/wake_word_samples", required=False)
|
||||||
|
parser.add_argument("--output-dir", type=str, help="Wake word output dir. Default: <input-dir>_augmented", required=False)
|
||||||
|
|
||||||
|
# Personal inputs/outputs (NEW)
|
||||||
|
parser.add_argument("--personal-dir", type=str, help="Personal WAV dir. Default: <data-dir>/personal_samples", required=False)
|
||||||
|
parser.add_argument("--personal-output-dir", type=str, help="Personal features output dir. Default: <data-dir>/work/personal_augmented_features", required=False)
|
||||||
|
|
||||||
|
# Dataset dirs
|
||||||
|
parser.add_argument("--mit-rirs-16k-dir", type=str, help="MIT RIR input directory. Default: <data-dir>/training_datasets/mit_rirs_16k", required=False)
|
||||||
|
parser.add_argument("--fma-16k-dir", type=str, help="FMA input directory. Default: <data-dir>/training_datasets/fma_16k", required=False)
|
||||||
|
parser.add_argument("--audioset-16k-dir", type=str, help="Audioset input directory. Default: <data-dir>/training_datasets/audioset_16k", required=False)
|
||||||
|
|
||||||
|
try:
|
||||||
|
args = parser.parse_args()
|
||||||
|
except ArgumentError:
|
||||||
|
parser.print_help()
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
args.data_dir = os.path.realpath(args.data_dir)
|
||||||
|
work_dir = os.path.join(args.data_dir, "work")
|
||||||
|
|
||||||
|
# Wake word defaults
|
||||||
|
if not args.input_dir:
|
||||||
|
args.input_dir = os.path.join(work_dir, "wake_word_samples")
|
||||||
|
else:
|
||||||
|
args.input_dir = os.path.realpath(args.input_dir)
|
||||||
|
|
||||||
|
if not args.output_dir:
|
||||||
|
args.output_dir = args.input_dir + "_augmented"
|
||||||
|
else:
|
||||||
|
args.output_dir = os.path.realpath(args.output_dir)
|
||||||
|
|
||||||
|
# Personal defaults (NEW)
|
||||||
|
if not args.personal_dir:
|
||||||
|
args.personal_dir = os.path.join(args.data_dir, "personal_samples")
|
||||||
|
else:
|
||||||
|
args.personal_dir = os.path.realpath(args.personal_dir)
|
||||||
|
|
||||||
|
if not args.personal_output_dir:
|
||||||
|
args.personal_output_dir = os.path.join(work_dir, "personal_augmented_features")
|
||||||
|
else:
|
||||||
|
args.personal_output_dir = os.path.realpath(args.personal_output_dir)
|
||||||
|
|
||||||
|
# Dataset defaults
|
||||||
|
if not args.mit_rirs_16k_dir:
|
||||||
|
args.mit_rirs_16k_dir = os.path.join(args.data_dir, "training_datasets", "mit_rirs_16k")
|
||||||
|
else:
|
||||||
|
args.mit_rirs_16k_dir = os.path.realpath(args.mit_rirs_16k_dir)
|
||||||
|
|
||||||
|
if not args.fma_16k_dir:
|
||||||
|
args.fma_16k_dir = os.path.join(args.data_dir, "training_datasets", "fma_16k")
|
||||||
|
else:
|
||||||
|
args.fma_16k_dir = os.path.realpath(args.fma_16k_dir)
|
||||||
|
|
||||||
|
if not args.audioset_16k_dir:
|
||||||
|
args.audioset_16k_dir = os.path.join(args.data_dir, "training_datasets", "audioset_16k")
|
||||||
|
else:
|
||||||
|
args.audioset_16k_dir = os.path.realpath(args.audioset_16k_dir)
|
||||||
|
|
||||||
|
def validate_directories(paths):
    """Check that every path in ``paths`` is an existing directory.

    An error line is printed for each path that is missing or is not a
    directory, so all problems are visible in a single run.

    Args:
        paths: Iterable of filesystem paths to check.

    Returns:
        True when every path is a directory (also for an empty iterable),
        False otherwise.
    """
    missing = [path for path in paths if not os.path.isdir(path)]
    for path in missing:
        print(f"Error: Directory {path} does not exist. Please ensure preprocessing is complete.")
    return not missing
|
||||||
|
|
||||||
|
# Abort early (with usage) unless every required input directory is present.
required = [
    work_dir,
    args.input_dir,
    args.mit_rirs_16k_dir,
    args.fma_16k_dir,
    args.audioset_16k_dir,
]
if not validate_directories(required):
    parser.print_help()
    sys.exit(1)

# -------------------- TF + libs --------------------
print(" Initializing libraries")

# These must be in the environment before TensorFlow is imported below.
os.environ.update(
    {
        "TF_CPP_MIN_LOG_LEVEL": "3",
        "TF_FORCE_GPU_ALLOW_GROWTH": "true",
        "TF_GPU_ALLOCATOR": "cuda_malloc_async",
        "TF_XLA_FLAGS": "--tf_xla_auto_jit=0",
        "NVIDIA_TF32_OVERRIDE": "1",
        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB": "512",
        "GLOG_minloglevel": "9",
        "GRPC_VERBOSITY": "ERROR",
    }
)

print(" Loading Tensorflow")
import tensorflow as tf

print(" GPU memory config")
for g in tf.config.list_physical_devices("GPU"):
    # Best effort: a device that rejects the setting is simply left as is.
    try:
        tf.config.experimental.set_memory_growth(g, True)
    except Exception:
        pass
print(f" GPUs: {tf.config.list_physical_devices('GPU')}")
gc.collect()

import numpy as np
import librosa
from mmap_ninja.ragged import RaggedMmap
from microwakeword.audio.augmentation import Augmentation
from microwakeword.audio.clips import Clips
from microwakeword.audio.spectrograms import SpectrogramGeneration

# Wall-clock start, whole seconds, used for the final elapsed-time report.
START_TIME = datetime.now(timezone.utc).replace(microsecond=0)

impulse_paths = [args.mit_rirs_16k_dir]
background_paths = [args.fma_16k_dir, args.audioset_16k_dir]

# Probability with which each named augmentation is applied.
_augmentation_probabilities = {
    "SevenBandParametricEQ": 0.1,
    "TanhDistortion": 0.05,
    "PitchShift": 0.15,
    "BandStopFilter": 0.1,
    "AddColorNoise": 0.1,
    "AddBackgroundNoise": 0.7,
    "Gain": 0.8,
    "RIR": 0.7,
}

augmenter = Augmentation(
    augmentation_duration_s=3.2,
    augmentation_probabilities=_augmentation_probabilities,
    impulse_paths=impulse_paths,
    background_paths=background_paths,
    background_min_snr_db=5,
    background_max_snr_db=10,
    min_jitter_s=0.2,
    max_jitter_s=0.3,
)

# Output sub-directory -> generator split name, repeat count, slide frames.
split_cfg = dict(
    training=dict(name="train", repetition=2, slide_frames=10),
    validation=dict(name="validation", repetition=1, slide_frames=10),
    testing=dict(name="test", repetition=1, slide_frames=1),
)
|
||||||
|
|
||||||
|
def bind_wav_generator(clips_obj: Clips, wav_dir: str):
    """Replace ``clips_obj.audio_generator`` with one that reads WAVs from ``wav_dir``.

    The bound generator lists ``*.wav`` in ``wav_dir`` in sorted order,
    shuffles the list with a fixed seed (10) and slices it into
    train / validation / test, where validation and test each receive
    ``max(1, int(0.10 * n))`` files and train gets the remainder. It yields
    every clip of the requested split as a float32 array loaded as 16 kHz
    mono, and repeats the whole split ``repeat`` times (at least once).
    Nothing is yielded when the directory has no WAVs, the split name is
    unknown, or the split is empty.
    """

    def audio_generator_from_wavs(self, split="train", repeat=1):
        wav_paths = sorted(glob.glob(os.path.join(wav_dir, "*.wav")))
        if len(wav_paths) == 0:
            return

        # Fixed seed keeps the split assignment stable between runs.
        shuffled = list(wav_paths)
        random.Random(10).shuffle(shuffled)

        total = len(shuffled)
        val_count = max(1, int(0.10 * total))
        test_count = max(1, int(0.10 * total))
        train_count = max(0, total - val_count - test_count)
        val_end = train_count + val_count

        if split == "train":
            chosen = shuffled[:train_count]
        elif split == "validation":
            chosen = shuffled[train_count:val_end]
        elif split == "test":
            chosen = shuffled[val_end:]
        else:
            chosen = []
        if not chosen:
            return

        passes = max(1, int(repeat))
        for _pass in range(passes):
            for wav_path in chosen:
                samples, _rate = librosa.load(wav_path, sr=16000, mono=True)
                yield samples.astype(np.float32, copy=False)

    clips_obj.audio_generator = types.MethodType(audio_generator_from_wavs, clips_obj)
|
||||||
|
|
||||||
|
def generate_feature_set(input_wav_dir: str, out_root_dir: str, label: str):
    """Augment the WAVs of one directory into spectrogram feature mmaps.

    For each entry of the module-level ``split_cfg`` a
    ``<out_root_dir>/<entry>/wakeword_mmap`` store is written from the
    spectrogram generator of the corresponding split.

    Args:
        input_wav_dir: Directory searched for ``*.wav`` files.
        out_root_dir: Root directory for the generated features (created if
            missing).
        label: Short tag used only in progress messages.

    Returns:
        False when ``input_wav_dir`` holds no WAV files (nothing is written),
        True after all splits were generated.
    """
    wav_count = len(glob.glob(os.path.join(input_wav_dir, "*.wav")))
    if wav_count == 0:
        print(f"ℹ️ No WAVs found for {label} in: {input_wav_dir} (skipping)")
        return False

    print(f"\n===== Augmenting {wav_count} wake word samples ({label}) =====")

    clips = Clips(
        input_directory=input_wav_dir,
        file_pattern="*.wav",
        max_clip_duration_s=5,
        remove_silence=True,
        random_split_seed=10,
        split_count=0.1,
    )
    # Audio is served by the WAV-backed generator bound here.
    bind_wav_generator(clips, input_wav_dir)

    features_root = Path(out_root_dir)
    features_root.mkdir(parents=True, exist_ok=True)

    for split_dir_name, settings in split_cfg.items():
        split_dir = features_root / split_dir_name
        split_dir.mkdir(parents=True, exist_ok=True)

        print(f" Augmenting {split_dir_name} ({label})")
        print(" Sit tight this can take awhile ...")
        print()

        spectrogram_source = SpectrogramGeneration(
            clips=clips,
            augmenter=augmenter,
            slide_frames=settings["slide_frames"],
            step_ms=10,
        )
        RaggedMmap.from_generator(
            out_dir=str(split_dir / "wakeword_mmap"),
            sample_generator=spectrogram_source.spectrogram_generator(
                split=settings["name"],
                repeat=settings["repetition"],
            ),
            batch_size=100,
            verbose=False,
        )

        print(f" {split_dir_name} augmentation complete ({label})")

    print(f"\n✅ Features ready: {out_root_dir}/*/wakeword_mmap\n")
    return True
|
||||||
|
|
||||||
|
# Generated/TTS wake word features first, then the optional personal
# recordings; a directory without WAVs is skipped by generate_feature_set.
for _wav_dir, _features_dir, _label in (
    (args.input_dir, args.output_dir, "generated"),
    (args.personal_dir, args.personal_output_dir, "personal"),
):
    generate_feature_set(_wav_dir, _features_dir, _label)

# Report the wall-clock duration since START_TIME, whole seconds.
END_TIME = datetime.now(timezone.utc).replace(microsecond=0)
et = END_TIME - START_TIME
_rule = "=" * 80
print(f"\n{_rule}")
print(f"{'Augmentation completed.':>50s} Elapsed time: {et!s}")
print(f"{_rule}\n")
|
||||||
112
cli/wake_word_sample_generator
Executable file
112
cli/wake_word_sample_generator
Executable file
@@ -0,0 +1,112 @@
|
|||||||
|
#!/bin/bash
#
# Generate wake word WAV samples with piper-sample-generator.
#
# With --samples=1 a single test clip is written to <data-dir>/work/test_sample
# and the script exits. Otherwise <data-dir>/work/wake_word_samples is rebuilt
# unless the previous run used the same wake word, sample count and model and
# the expected number of WAV files is still present.
set -e

PROGPATH=$(realpath "$0")
PROGDIR=$(dirname "${PROGPATH}")

KNOWN_ARGS=( samples batch-size data-dir )
source "${PROGDIR}/shell.functions"
WAKE_WORD="${POSITIONAL_ARGS[0]}"

if [ ${#UNKNOWN_ARGS[@]} -gt 0 ] ; then
  echo "Unknown argument(s): ${UNKNOWN_ARGS[*]}" >&2
  HELP=true
fi

if [ "${HELP}" == "true" ] || [ -z "${WAKE_WORD}" ] ; then
  cat <<EOF >&2
Usage: $0 [ --samples=<samples> ] [ --batch-size=<batch_size> ]
          [ --data-dir=<data_dir> ] <wake_word>

  --samples:     The number of samples to generate for the wake word.
                 Default: ${DEFAULT_SAMPLES}

  --batch-size:  How many samples should be generated at a time. The more
                 samples, the more memory is needed.
                 Default: ${DEFAULT_BATCH_SIZE}

  --data-dir:    Directory holding the .venv, tools and work directories.

  <wake_word>    The word to generate samples for.
                 Required.

EOF
  exit 1
fi

# shellcheck source=/dev/null
source "${DATA_DIR}/.venv/bin/activate"

WORK_DIR="${DATA_DIR}/work"
mkdir -p "${WORK_DIR}" || :
cd "${WORK_DIR}"

PSG="${DATA_DIR}/tools/piper-sample-generator"
MODELS_DIR="${PSG}/models"
MODEL_NAME=en_US-libritts_r-medium.pt
MODEL_FILE="${MODELS_DIR}/${MODEL_NAME}"
SAMPLES_DIR="${WORK_DIR}/wake_word_samples"

mkdir -p "${SAMPLES_DIR}" || :

REGENERATE=false

if [ "${SAMPLES}" -eq 1 ] ; then
  echo "===== Generating ${SAMPLES} sample of '${WAKE_WORD}' ====="
  # Replace characters that are awkward in file names with underscores.
  wake_word_filename="${WAKE_WORD//[ \`~\!\$&*\(\)\{\}\[\]\|\;\'\"<>.?\/]/_}"

  mkdir -p "${WORK_DIR}/test_sample" || :
  "${PSG}/generate_samples.py" "${WAKE_WORD}" \
    --model "${MODEL_FILE}" \
    --max-samples "${SAMPLES}" \
    --batch-size "${BATCH_SIZE}" \
    --output-dir "${WORK_DIR}/test_sample" \
    --max-speakers 100 2>&1 | sed -r -e "s/(DEBUG|INFO):__main__:/ /g"
  mv "${WORK_DIR}/test_sample/0.wav" "${WORK_DIR}/test_sample/${wake_word_filename}.wav"
  echo "Sample available at ${WORK_DIR}/test_sample/${wake_word_filename}.wav"
  echo "Play it from your host."
  exit 0
fi

# Compare against the previous run literally and as a whole line (-F -x), so a
# wake word containing regex metacharacters, or one that is a substring of the
# previous wake word, cannot produce a false "already generated" match.
grep -qxF -- "${WAKE_WORD}:${SAMPLES}:${MODEL_NAME}" "${WORK_DIR}/last_wake_word" &>/dev/null || REGENERATE=true

# Double check that the number of existing samples matches SAMPLES
existing_samples=$(find "${SAMPLES_DIR}" -name '*.wav' | wc -l)
[ "${existing_samples}" -eq "${SAMPLES}" ] || REGENERATE=true

START_TS=$EPOCHSECONDS

if [[ "${REGENERATE}" != true ]] ; then
  echo "Sample generation not required"
  echo
  exit 0
fi

echo -e "\n===== Generating ${SAMPLES} wake word samples in batches of ${BATCH_SIZE} ====="
export TF_CPP_MIN_LOG_LEVEL=9
export TF_FORCE_GPU_ALLOW_GROWTH=true
export TF_GPU_ALLOCATOR=cuda_malloc_async
export TF_XLA_FLAGS="--tf_xla_auto_jit=0"
export NVIDIA_TF32_OVERRIDE=1
export TF_CUDNN_WORKSPACE_LIMIT_IN_MB=512
export GLOG_minloglevel=2
export GRPC_VERBOSITY=ERROR

echo " Generating samples"
rm -rf "${SAMPLES_DIR}" || :
mkdir -p "${SAMPLES_DIR}" || :
"${PSG}/generate_samples.py" "${WAKE_WORD}" \
  --model "${MODEL_FILE}" \
  --max-samples "${SAMPLES}" \
  --batch-size "${BATCH_SIZE}" \
  --output-dir "${SAMPLES_DIR}" 2>&1 | sed -r -e "s/(DEBUG|INFO):__main__:/ /g"

# The pipeline status above is sed's, so a generator failure is detected by
# counting the files it actually produced.
generated_files=$(find "${SAMPLES_DIR}" -name '*.wav' | wc -l)
if [ "${generated_files}" -ne "${SAMPLES}" ] ; then
  echo "ERROR: only generated ${generated_files} files" >&2
  exit 1
fi
echo "${WAKE_WORD}:${SAMPLES}:${MODEL_NAME}" > "${WORK_DIR}/last_wake_word"
echo
END_TS=$EPOCHSECONDS
print_elapsed_time "${START_TS}" "${END_TS}" "Generated ${SAMPLES} wake word samples."

exit 0
|
||||||
323
cli/wake_word_sample_trainer
Normal file
323
cli/wake_word_sample_trainer
Normal file
@@ -0,0 +1,323 @@
|
|||||||
|
#!/bin/bash
#
# Train a micro wake word model from previously augmented samples.
# This first section parses the command line, prints usage on request or
# error, derives the display title of the wake word and activates the
# virtualenv of the data directory.
set -e

PROGPATH=$(realpath "$0")
PROGDIR=$(dirname "${PROGPATH}")

KNOWN_ARGS=( training-steps samples data-dir )
source "${PROGDIR}/shell.functions"
WAKE_WORD="${POSITIONAL_ARGS[0]}"

if [ ${#UNKNOWN_ARGS[@]} -gt 0 ] ; then
  echo "Unknown argument(s): ${UNKNOWN_ARGS[*]}" >&2
  HELP=true
fi

if [ "${HELP}" == "true" ] || [ -z "${WAKE_WORD}" ] ; then
  cat <<EOF >&2
Usage: $0 [ --samples=<samples> ] [ --training-steps=<steps> ]
          <wake_word> [ <wake_word_title> ]

       $0 -h/--help

  --samples:          The number of samples to generate for the wake word.
                      Used only to generate output file names.

  --training-steps:   Number of training steps.
                      Default: ${DEFAULT_TRAINING_STEPS}

  <wake_word>:        The word to train spelled phonetically.
                      Required.

  <wake_word_title>:  A pretty name to save to the json metadata file.
                      Default: The wake word with individual words capitalized.

EOF
  exit 1
fi

WORK_DIR="${DATA_DIR}/work"
TRAINING_DS="${DATA_DIR}/training_datasets"

# Count the positional arguments with [@]; the bare ${#POSITIONAL_ARGS} is the
# string length of the first element, which made the title argument effective
# only for two-character wake words.
if (( ${#POSITIONAL_ARGS[@]} >= 2 )) ; then
  WAKE_WORD_TITLE="${POSITIONAL_ARGS[1]}"
fi

if [ ! -v WAKE_WORD_TITLE ] ; then
  # Split the wake word on every non-alphanumeric character and capitalise
  # the first letter of each resulting word.
  declare -a WWNA=()
  read -r -a WWNA <<< "${WAKE_WORD//[^a-zA-Z0-9]/ }"
  WAKE_WORD_TITLE="${WWNA[*]^}"
elif [ -z "$WAKE_WORD_TITLE" ] ; then
  WAKE_WORD_TITLE="$WAKE_WORD"
fi

# shellcheck source=/dev/null
source "${DATA_DIR}/.venv/bin/activate"
|
||||||
|
|
||||||
|
#######################################
# Abort the script unless every argument is an existing directory.
#
# Arguments:
#   Directory paths to verify.
# Outputs:
#   An error naming the first missing directory on stderr.
# Returns:
#   0 when all directories exist; otherwise exits the shell with status 1.
#######################################
check_directories() {
  local dir
  for dir in "$@"; do
    if [[ ! -d "${dir}" ]]; then
      echo "ERROR: Directory ${dir} not found" >&2
      exit 1
    fi
  done
}
|
||||||
|
|
||||||
|
# Quoted so a data directory containing spaces is passed as single words;
# the brace expansion still yields one argument per negative dataset.
check_directories "${WORK_DIR}/wake_word_samples_augmented" \
  "${TRAINING_DS}/negative_datasets/"{speech,dinner_party,no_speech,dinner_party_eval}

# Personal features are optional, but if present they MUST have /training
PERSONAL_FEATURES_DIR="${WORK_DIR}/personal_augmented_features"
HAS_PERSONAL="false"
if [ -d "${PERSONAL_FEATURES_DIR}/training" ] ; then
  HAS_PERSONAL="true"
  echo "✅ Found personal features: ${PERSONAL_FEATURES_DIR}/training (will weight sampling_weight=3.0)"
else
  echo "ℹ️ No personal features found at ${PERSONAL_FEATURES_DIR}/training (continuing without personal weighting)"
fi

cd "${WORK_DIR}"

echo "===== Starting ${TRAINING_STEPS} training steps ====="
START_TS=$EPOCHSECONDS

mkdir -p "${WORK_DIR}/trained_models" || :

# The YAML is written directly from heredocs. Paths are expanded by the shell
# into the output verbatim, so characters that are special in sed or perl
# replacement text (| & # @ $ \) cannot corrupt the file. The optional
# personal feature block goes right after the wake word feature block.
YAML_PATH="${WORK_DIR}/trained_models/training_parameters.yaml"

{
  cat <<EOF
batch_size: 16
clip_duration_ms: 1500
eval_step_interval: 500
features:
- features_dir: ${WORK_DIR}/wake_word_samples_augmented
  penalty_weight: 1.0
  sampling_weight: 2.0
  truncation_strategy: truncate_start
  truth: true
  type: mmap
EOF

  if [ "${HAS_PERSONAL}" = "true" ]; then
    cat <<EOF
- features_dir: ${PERSONAL_FEATURES_DIR}
  penalty_weight: 1.0
  sampling_weight: 3.0
  truncation_strategy: truncate_start
  truth: true
  type: mmap
EOF
  fi

  cat <<EOF
- features_dir: ${TRAINING_DS}/negative_datasets/speech
  penalty_weight: 1.0
  sampling_weight: 12.0
  truncation_strategy: random
  truth: false
  type: mmap
- features_dir: ${TRAINING_DS}/negative_datasets/dinner_party
  penalty_weight: 1.0
  sampling_weight: 12.0
  truncation_strategy: random
  truth: false
  type: mmap
- features_dir: ${TRAINING_DS}/negative_datasets/no_speech
  penalty_weight: 1.0
  sampling_weight: 5.0
  truncation_strategy: random
  truth: false
  type: mmap
- features_dir: ${TRAINING_DS}/negative_datasets/dinner_party_eval
  penalty_weight: 1.0
  sampling_weight: 0.0
  truncation_strategy: split
  truth: false
  type: mmap
freq_mask_count:
- 0
freq_mask_max_size:
- 0
learning_rates:
- 0.001
maximization_metric: average_viable_recall
minimization_metric: null
negative_class_weight:
- 20
positive_class_weight:
- 1
target_minimization: 0.9
time_mask_count:
- 0
time_mask_max_size:
- 0
train_dir: ${WORK_DIR}/trained_models/wakeword
training_steps:
- ${TRAINING_STEPS}
window_step_ms: 10
EOF
} > "${YAML_PATH}"

echo " Wrote training_parameters.yaml"
rm -rf "${WORK_DIR}/trained_models/wakeword"

# File-system friendly name: lower case, runs of other characters collapsed
# to a single underscore, no leading or trailing underscore.
wake_word_filename="$(
  printf '%s\n' "${WAKE_WORD}" \
    | tr '[:upper:]' '[:lower:]' \
    | sed -E 's/[^a-z0-9]+/_/g; s/^_+//; s/_+$//'
)"
[ -n "${wake_word_filename}" ] || wake_word_filename="wakeword"

OUTPUT_DIR="${DATA_DIR}/output/$(date +'%Y-%m-%d-%H-%M-%S')-${wake_word_filename}-${SAMPLES}-${TRAINING_STEPS}"
mkdir -p "${OUTPUT_DIR}/logs" || :
TRAIN_LOG="${OUTPUT_DIR}/logs/training.log"

# Arguments for the training module; everything after "mixednet" configures
# that model.
TRAIN_ARGS=(
  -m microwakeword.model_train_eval
  --training_config "${YAML_PATH}"
  --train 1
  --restore_checkpoint 1
  --test_tf_nonstreaming 0
  --test_tflite_nonstreaming 0
  --test_tflite_nonstreaming_quantized 0
  --test_tflite_streaming 0
  --test_tflite_streaming_quantized 1
  --use_weights best_weights
  mixednet
  --pointwise_filters "64,64,64,64"
  --repeat_in_block "1,1,1,1"
  --mixconv_kernel_sizes "[5], [7,11], [9,15], [23]"
  --residual_connection "0,0,0,0"
  --first_conv_filters 32
  --first_conv_kernel_size 5
  --stride 2
)

# Lower-case fragments searched for in the training log after a failed GPU
# attempt; any hit triggers the CPU fallback.
GPU_FALLBACK_MARKERS=(
  "resourceexhaustederror"
  "resource exhausted"
  "oom"
  "out of memory"
  "cuda_error_out_of_memory"
  "failed to allocate"
  "cudnn"
  "cublas"
  "internalerror: cuda"
  "failed call to cuinit"
  "dst tensor is not initialized"
  "failed copying input tensor"
  "_eagerconst"
)
|
||||||
|
|
||||||
|
run_attempt() {
|
||||||
|
local label="$1"
|
||||||
|
shift
|
||||||
|
echo
|
||||||
|
echo "================================================================================"
|
||||||
|
echo "===== ${label} ====="
|
||||||
|
echo "================================================================================"
|
||||||
|
echo "→ ${PYTHON_BIN:-python} ${TRAIN_ARGS[*]}"
|
||||||
|
echo
|
||||||
|
|
||||||
|
"${PYTHON_BIN:-python}" "${TRAIN_ARGS[@]}" 2>&1 \
|
||||||
|
| tr '\r' '\n' \
|
||||||
|
| stdbuf -i0 -o0 sed -r -e "/^Validation Batch/d" \
|
||||||
|
| tee "${TRAIN_LOG}" \
|
||||||
|
| sed -r -e "/^Validation Batch/d" -e "s/^INFO:absl:/ /g"
|
||||||
|
|
||||||
|
return ${PIPESTATUS[0]}
|
||||||
|
}
|
||||||
|
|
||||||
|
export TF_CPP_MIN_LOG_LEVEL="${TF_CPP_MIN_LOG_LEVEL:-2}"
|
||||||
|
export TF_XLA_FLAGS="${TF_XLA_FLAGS:---tf_xla_auto_jit=0}"
|
||||||
|
export NVIDIA_TF32_OVERRIDE="${NVIDIA_TF32_OVERRIDE:-1}"
|
||||||
|
export TF_FORCE_GPU_ALLOW_GROWTH="${TF_FORCE_GPU_ALLOW_GROWTH:-true}"
|
||||||
|
export TF_GPU_ALLOCATOR="${TF_GPU_ALLOCATOR:-cuda_malloc_async}"
|
||||||
|
|
||||||
|
# Try the GPU first. If that attempt fails AND its log contains one of the
# GPU_FALLBACK_MARKERS, retry once on the CPU; any other failure is fatal.
if run_attempt "Attempt 1/2: GPU training (allow_growth + cuda_malloc_async)"; then
  echo "✅ Training complete (GPU path)."
else
  echo "⚠️ GPU attempt failed. Checking whether this looks like a GPU/OOM/runtime failure…"

  # Search the log file directly (case-insensitive, fixed strings).
  # The previous `echo "$log" | grep -q` form could report "no match" when
  # `pipefail` is active: grep -q exits on the first hit, echo is killed by
  # SIGPIPE on a large log, and the pipeline status becomes non-zero.
  # A missing/unreadable log simply counts as "no marker found".
  looks_like_gpu_fail="false"
  for m in "${GPU_FALLBACK_MARKERS[@]}"; do
    if grep -qiF -- "${m}" "${TRAIN_LOG}" 2>/dev/null; then
      looks_like_gpu_fail="true"
      break
    fi
  done

  if [ "${looks_like_gpu_fail}" = "true" ]; then
    echo "↪️ Detected GPU/OOM/runtime failure markers. Falling back to CPU."

    # Hide every CUDA device and drop the GPU allocator setting for the retry.
    export CUDA_VISIBLE_DEVICES=""
    unset TF_GPU_ALLOCATOR
    if run_attempt "Attempt 2/2: CPU fallback (CUDA_VISIBLE_DEVICES='')"; then
      echo "✅ Training complete (CPU fallback)."
    else
      echo "❌ Training failed on BOTH GPU and CPU. See: ${TRAIN_LOG}" >&2
      exit 1
    fi
  else
    echo "❌ Training failed (does not look GPU/OOM/runtime). See: ${TRAIN_LOG}" >&2
    exit 1
  fi
fi
|
||||||
|
|
||||||
|
# ---- Publish the trained model and its manifest into OUTPUT_DIR ----
source_path="${WORK_DIR}/trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite"

if [ ! -f "${source_path}" ]; then
  # Diagnostics go to stderr, like the other failure messages in this script.
  echo "Output model not found! Training didn't complete successfully. See ${TRAIN_LOG}" >&2
  exit 1
fi

# Best-effort copies of the training summary and logs: a missing file here
# must not fail an otherwise successful run, hence `|| :`.
cp "${WORK_DIR}/trained_models/wakeword/model_summary.txt" "${OUTPUT_DIR}/logs/" || :
cp -a "${WORK_DIR}/trained_models/wakeword/logs/train" "${OUTPUT_DIR}/logs/" || :
cp -a "${WORK_DIR}/trained_models/wakeword/logs/validation" "${OUTPUT_DIR}/logs/" || :

echo -e "\n Training complete!"
echo " Full log: ${TRAIN_LOG}"

tflite_filename="${wake_word_filename}.tflite"
tflite_path="${OUTPUT_DIR}/${tflite_filename}"
cp "${source_path}" "${tflite_path}"

json_path="${OUTPUT_DIR}/${wake_word_filename}.json"

# Escape backslashes first, then double quotes, so a title containing either
# character still yields valid JSON.
json_wake_word="${WAKE_WORD_TITLE//\\/\\\\}"
json_wake_word="${json_wake_word//\"/\\\"}"

cat <<EOF > "${json_path}"
{
  "type": "micro",
  "wake_word": "${json_wake_word}",
  "author": "Tater Totterson",
  "website": "https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git",
  "model": "${tflite_filename}",
  "trained_languages": ["en"],
  "version": 2,
  "micro": {
    "probability_cutoff": 0.97,
    "sliding_window_size": 5,
    "feature_step_size": 10,
    "tensor_arena_size": 30000,
    "minimum_esphome_version": "2024.7.0"
  }
}
EOF

echo "Name: ${WAKE_WORD_TITLE}"
echo "Model: ${tflite_path}"
echo "Metadata: ${json_path}"
echo
END_TS=$EPOCHSECONDS
print_elapsed_time "${START_TS}" "${END_TS}" "Training completed."
echo
|
||||||
79
dockerfile
79
dockerfile
@@ -1,59 +1,40 @@
|
|||||||
# Standard Ubuntu base image. CUDA base images not needed.
|
# Base
|
||||||
FROM ubuntu:22.04
|
FROM ubuntu:24.04
|
||||||
|
|
||||||
ENV DEBIAN_FRONTEND=noninteractive \
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
PYTHONUNBUFFERED=1 \
|
|
||||||
PIP_NO_CACHE_DIR=1 \
|
|
||||||
PIP_ROOT_USER_ACTION=ignore \
|
|
||||||
HF_HUB_DISABLE_SYMLINKS_WARNING=1 \
|
|
||||||
XLA_FLAGS="--xla_gpu_cuda_data_dir=/usr/local/cuda" \
|
|
||||||
PATH="/usr/local/cuda/bin:${PATH}" \
|
|
||||||
LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"
|
|
||||||
|
|
||||||
# System deps (+dev headers for building C/C++ extensions)
|
# System deps
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
python3.10 python3.10-venv python3.10-distutils python3.10-dev python3-pip \
|
python3.12 python3.12-venv python3.12-dev python3-pip python-is-python3 \
|
||||||
git wget curl unzip ca-certificates git-lfs \
|
git wget curl unzip ca-certificates nano less \
|
||||||
build-essential g++ cmake \
|
&& rm -rf /var/lib/apt/lists/* \
|
||||||
libsndfile1 libsndfile1-dev libffi-dev \
|
&& mkdir -p /data
|
||||||
ffmpeg \
|
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
|
||||||
|
|
||||||
# Use python3.10 everywhere
|
# Recorder port
|
||||||
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1 \
|
EXPOSE 8789
|
||||||
&& update-alternatives --install /usr/bin/pip pip /usr/bin/pip3 1
|
|
||||||
|
|
||||||
# ---- No cuDNN repo meddling needed if using TF 2.17.x ----
|
# Script root
|
||||||
|
WORKDIR /root/mww-scripts
|
||||||
|
|
||||||
# Python deps
|
# Bash environment
|
||||||
# Order is important. onnxruntime, tensorflow and torch have
|
COPY --chown=root:root --chmod=0755 .bashrc /root/
|
||||||
# to be installed in the order below or their cuda dependencies
|
|
||||||
# will conflict.
|
|
||||||
COPY requirements.txt /tmp/requirements.txt
|
|
||||||
RUN pip install --upgrade pip \
|
|
||||||
&& pip install "numpy==1.26.4" "cython>=0.29.36" \
|
|
||||||
&& pip install -r /tmp/requirements.txt \
|
|
||||||
&& pip install "onnxruntime-gpu[cuda]>=1.16.0" \
|
|
||||||
&& pip install "tensorflow[and-cuda]==2.18.0" \
|
|
||||||
"tensorboard==2.18.0" \
|
|
||||||
"tensorboard-data-server==0.7.2" \
|
|
||||||
"tensorflow-io-gcs-filesystem==0.37.1" \
|
|
||||||
&& pip install \
|
|
||||||
torch==2.7.1 \
|
|
||||||
torchaudio==2.7.1 \
|
|
||||||
--index-url https://download.pytorch.org/whl/cu128
|
|
||||||
|
|
||||||
# Workspace + notebook fallback
|
# Root-level entrypoints
|
||||||
RUN mkdir -p /data
|
COPY --chown=root:root --chmod=0755 \
|
||||||
WORKDIR /data
|
train_wake_word \
|
||||||
COPY microWakeWord_training_notebook.ipynb /root/
|
run_recorder.sh \
|
||||||
|
recorder_server.py \
|
||||||
|
requirements.txt \
|
||||||
|
/root/mww-scripts/
|
||||||
|
|
||||||
# Startup script (copies default notebook if missing)
|
# CLI folder
|
||||||
COPY startup.sh /usr/local/bin/startup.sh
|
COPY --chown=root:root cli/ /root/mww-scripts/cli/
|
||||||
RUN chmod +x /usr/local/bin/startup.sh
|
|
||||||
|
|
||||||
EXPOSE 8888
|
# Make all CLI scripts executable (avoids "Permission denied")
|
||||||
|
RUN chmod -R a+x /root/mww-scripts/cli
|
||||||
|
|
||||||
CMD ["/bin/bash", "-lc", "/usr/local/bin/startup.sh && \
|
# Static UI for recorder
|
||||||
exec jupyter lab --ip=0.0.0.0 --port=8888 --no-browser --allow-root \
|
COPY --chown=root:root --chmod=0644 static/index.html /root/mww-scripts/static/index.html
|
||||||
--ServerApp.token='' --ServerApp.password='' --ServerApp.root_dir=/data"]
|
|
||||||
|
# recorder server
|
||||||
|
CMD ["/bin/bash", "-lc", "/root/mww-scripts/run_recorder.sh"]
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
660
recorder_server.py
Normal file
660
recorder_server.py
Normal file
@@ -0,0 +1,660 @@
|
|||||||
|
# recorder_server.py
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import threading
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Any, List, Optional, Tuple
|
||||||
|
|
||||||
|
from fastapi import FastAPI, UploadFile, File, Form
|
||||||
|
from fastapi.responses import HTMLResponse, JSONResponse
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
|
# Directory that contains this script; default locations below hang off it.
ROOT_DIR = Path(__file__).resolve().parent

# Strings (already lower-cased) that count as "enabled" for boolean env vars.
_TRUTHY_VALUES = ("1", "true", "yes", "y")

# In Docker CLI world, DATA_DIR should be /data
DATA_DIR = Path(os.environ.get("DATA_DIR", "/data")).resolve()

# UI files live next to this script by default
STATIC_DIR = Path(os.environ.get("STATIC_DIR", str(ROOT_DIR / "static"))).resolve()

# Personal samples MUST land in /data/personal_samples for your CLI pipeline
PERSONAL_DIR = Path(
    os.environ.get("PERSONAL_DIR", str(DATA_DIR / "personal_samples"))
).resolve()

# CLI folder inside repo
CLI_DIR = Path(os.environ.get("CLI_DIR", str(ROOT_DIR / "cli"))).resolve()

# Cleanup switches forwarded to setup_training_datasets; both default to off.
DATASET_CLEANUP_ARCHIVES = (
    os.environ.get("REC_DATASET_CLEANUP_ARCHIVES", "false").lower() in _TRUTHY_VALUES
)
DATASET_CLEANUP_INTERMEDIATE = (
    os.environ.get("REC_DATASET_CLEANUP_INTERMEDIATE_FILES", "false").lower() in _TRUTHY_VALUES
)

# Shell snippet that starts training; the wake word arguments are appended to it.
TRAIN_CMD = os.environ.get(
    "TRAIN_CMD",
    f"source '{DATA_DIR}/.venv/bin/activate' && train_wake_word --data-dir '{DATA_DIR}'",
)

# Session defaults used when the client does not supply its own values.
TAKES_PER_SPEAKER_DEFAULT = int(os.environ.get("REC_TAKES_PER_SPEAKER", "10"))
SPEAKERS_TOTAL_DEFAULT = int(os.environ.get("REC_SPEAKERS_TOTAL", "1"))

# Tail lines shown to UI
TRAIN_LOG_TAIL_LINES = int(os.environ.get("REC_TRAIN_LOG_TAIL_LINES", "400"))
# Safety cap for reads (bytes) to avoid giant file reads
TRAIN_LOG_MAX_BYTES = int(os.environ.get("REC_TRAIN_LOG_MAX_BYTES", str(512 * 1024)))  # 512KB
|
||||||
|
|
||||||
|
# Application object; the route handlers below register themselves on it.
app = FastAPI(title="microWakeWord Personal Recorder")

# Create the static directory before mounting it
# (presumably the mount requires an existing directory — TODO confirm).
STATIC_DIR.mkdir(exist_ok=True, parents=True)
app.mount(
    "/static",
    StaticFiles(directory=str(STATIC_DIR)),
    name="static",
)
|
||||||
|
|
||||||
|
|
||||||
|
def safe_name(raw: str) -> str:
    """
    Turn a free-form phrase into a lowercase identifier made of a-z, 0-9 and "_".

    Whitespace runs become a single underscore, every other character outside
    the allowed set is dropped, and leading/trailing underscores are trimmed.
    Returns "wakeword" when nothing usable remains (including for None/empty).
    """
    text = (raw or "").strip().lower()
    # Order matters: whitespace must become "_" before the character filter runs.
    for pattern, replacement in (
        (r"\s+", "_"),
        (r"[^a-z0-9_]+", ""),
        (r"^_+|_+$", ""),
    ):
        text = re.sub(pattern, replacement, text)
    if not text:
        return "wakeword"
    return text
|
||||||
|
|
||||||
|
|
||||||
|
# Process-wide recorder state. Guard every read and write with STATE_LOCK.
STATE: Dict[str, Any] = dict(
    # Phrase as typed by the user and its sanitized form.
    raw_phrase=None,
    safe_word=None,
    # Recording plan for the current session.
    speakers_total=SPEAKERS_TOTAL_DEFAULT,
    takes_per_speaker=TAKES_PER_SPEAKER_DEFAULT,
    # Uploaded takes so far (count and file names).
    takes_received=0,
    takes=[],
    # Status of the background training run.
    training=dict(
        running=False,
        exit_code=None,
        log_lines=[],  # legacy in-memory tail (kept, but not relied on)
        log_path=None,  # path to recorder_training.log
        safe_word=None,
        # prevent UI duplication when UI appends:
        last_sent_tail=[],  # last tail snapshot (list of lines)
        last_log_size=0,  # detect truncation
    ),
)

STATE_LOCK = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_personal_samples_dir():
    """
    Make sure PERSONAL_DIR exists and remove the ``*.wav`` files directly inside it.

    Deletion is best-effort: a file that cannot be removed is skipped silently.
    """
    PERSONAL_DIR.mkdir(exist_ok=True, parents=True)
    for wav_file in PERSONAL_DIR.glob("*.wav"):
        try:
            wav_file.unlink()
        except Exception:
            # Deliberately ignored; move on to the next recording.
            continue
|
||||||
|
|
||||||
|
|
||||||
|
def _clear_training_log():
    """
    Start recorder_training.log afresh with a "new session" banner and reset
    the log-related fields of STATE["training"].
    """
    target = DATA_DIR / "recorder_training.log"
    target.parent.mkdir(parents=True, exist_ok=True)

    banner = (
        "================================================================================\n",
        "===== New recorder session started =====\n",
        "================================================================================\n",
    )
    # Mode "w" truncates whatever a previous session left behind.
    with open(target, "w", encoding="utf-8") as handle:
        for banner_line in banner:
            handle.write(banner_line)
        handle.flush()

    with STATE_LOCK:
        training = STATE["training"]
        training["log_path"] = str(target)
        training["log_lines"] = []
        training["last_sent_tail"] = []
        training["last_log_size"] = 0
|
||||||
|
|
||||||
|
|
||||||
|
def _append_train_log(line: str):
    """
    Append one line to the in-memory training log, keeping only the newest 250.

    Trailing newlines are stripped; ``None`` is stored as an empty line.
    """
    entry = (line or "").rstrip("\n")
    with STATE_LOCK:
        history: List[str] = STATE["training"]["log_lines"]
        history.append(entry)
        overflow = len(history) - 250
        if overflow > 0:
            # Trim in place so other references to the list stay valid.
            del history[:overflow]
|
||||||
|
|
||||||
|
|
||||||
|
def _title_from_phrase(raw_phrase: str) -> str:
|
||||||
|
s = re.sub(r"[^a-zA-Z0-9 ]+", " ", raw_phrase or "").strip()
|
||||||
|
s = re.sub(r"\s+", " ", s)
|
||||||
|
return s.title() if s else ""
|
||||||
|
|
||||||
|
|
||||||
|
def _run_streamed(
    cmd: List[str],
    cwd: Path,
    log_path: Path,
    header: Optional[str] = None,
    env: Optional[Dict[str, str]] = None,
) -> int:
    """
    Run ``cmd`` and stream its combined stdout/stderr, line by line, into both
    the log file at ``log_path`` (appended) and the in-memory training log.

    Args:
        cmd: Argument vector to execute.
        cwd: Working directory for the child process.
        log_path: Log file to append to.
        header: Optional banner written before the command line.
        env: Environment for the child; ``None`` is passed through to Popen.

    Returns:
        The child's exit code.

    Raises:
        OSError: if the log file cannot be opened/written or the command
            cannot be started.
    """
    command_line = "→ " + " ".join(cmd)

    if header:
        _append_train_log(header)
    _append_train_log(command_line)

    with open(log_path, "a", encoding="utf-8") as lf:
        lf.write("\n" + ("=" * 80) + "\n")
        if header:
            lf.write(header + "\n")
        lf.write(command_line + "\n")
        lf.flush()

        # Context-managed so the pipe is closed and the child reaped on every
        # exit path. Output is decoded as UTF-8 with replacement so a stray
        # byte in tool output cannot abort the streaming loop.
        with subprocess.Popen(
            cmd,
            cwd=str(cwd),
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            encoding="utf-8",
            errors="replace",
            bufsize=1,
            env=env,
        ) as proc:
            assert proc.stdout is not None
            try:
                for line in proc.stdout:
                    lf.write(line)
                    lf.flush()
                    _append_train_log(line)
            except BaseException:
                # Don't leave an orphan running if we can no longer record its output.
                proc.kill()
                raise

        # Popen.__exit__ has already waited for the process.
        return proc.returncode
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_training_venv(log_path: Path) -> None:
    """
    Make sure the training virtualenv under DATA_DIR exists, running the
    ``setup_python_venv`` CLI script when its activate script is missing.

    Raises:
        RuntimeError: if the setup script is missing, exits non-zero, or
            finishes without producing the activate script.
    """
    activate_script = DATA_DIR / ".venv" / "bin" / "activate"
    if activate_script.exists():
        _append_train_log("✅ Training venv found (skipping setup_python_venv)")
        return

    setup_script = CLI_DIR / "setup_python_venv"
    if not setup_script.exists():
        raise RuntimeError(f"Missing setup_python_venv at: {setup_script}")

    shell_line = f"cd '{DATA_DIR}' && '{setup_script}' --data-dir='{DATA_DIR}'"
    exit_code = _run_streamed(
        ["bash", "-lc", shell_line],
        cwd=DATA_DIR,
        log_path=log_path,
        header="===== Ensuring Python venv (/data/.venv) =====",
    )

    if exit_code != 0:
        raise RuntimeError(f"setup_python_venv failed (exit_code={exit_code})")

    # A zero exit code alone is not trusted; the activate script must exist now.
    if not activate_script.exists():
        raise RuntimeError(f"setup_python_venv finished, but {activate_script} is still missing")
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_training_datasets(log_path: Path) -> None:
    """
    Run the ``setup_training_datasets`` CLI script, streaming its output to the log.

    The two cleanup switches come from DATASET_CLEANUP_ARCHIVES and
    DATASET_CLEANUP_INTERMEDIATE.

    Raises:
        RuntimeError: if the script is missing or exits non-zero.
    """
    setup_script = CLI_DIR / "setup_training_datasets"
    if not setup_script.exists():
        raise RuntimeError(f"Missing setup_training_datasets at: {setup_script}")

    as_flag = {True: "true", False: "false"}
    archives_flag = as_flag[bool(DATASET_CLEANUP_ARCHIVES)]
    intermediate_flag = as_flag[bool(DATASET_CLEANUP_INTERMEDIATE)]

    shell_line = " ".join(
        [
            f"cd '{DATA_DIR}' &&",
            f"'{setup_script}'",
            f"--cleanup-archives='{archives_flag}'",
            f"--cleanup-intermediate-files='{intermediate_flag}'",
            f"--data-dir='{DATA_DIR}'",
        ]
    )

    exit_code = _run_streamed(
        ["bash", "-lc", shell_line],
        cwd=DATA_DIR,
        log_path=log_path,
        header="===== Ensuring training datasets (setup_training_datasets) =====",
    )
    if exit_code != 0:
        raise RuntimeError(f"setup_training_datasets failed (exit_code={exit_code})")
|
||||||
|
|
||||||
|
|
||||||
|
def _read_tail_lines(log_path: Path, max_lines: int) -> List[str]:
|
||||||
|
"""
|
||||||
|
Read the last N lines, bounded by TRAIN_LOG_MAX_BYTES.
|
||||||
|
Returns list of lines (no trailing newlines).
|
||||||
|
"""
|
||||||
|
if not log_path.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
size = log_path.stat().st_size
|
||||||
|
start = max(0, size - TRAIN_LOG_MAX_BYTES)
|
||||||
|
with open(log_path, "rb") as f:
|
||||||
|
f.seek(start)
|
||||||
|
data = f.read()
|
||||||
|
text = data.decode("utf-8", errors="replace")
|
||||||
|
lines = text.splitlines()
|
||||||
|
if len(lines) <= max_lines:
|
||||||
|
return lines
|
||||||
|
return lines[-max_lines:]
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_new_lines(prev_tail: List[str], new_tail: List[str]) -> List[str]:
|
||||||
|
"""
|
||||||
|
Given previous and current tail snapshots, return only the newly-added lines.
|
||||||
|
Works even if the tail window shifts.
|
||||||
|
"""
|
||||||
|
if not prev_tail:
|
||||||
|
return new_tail
|
||||||
|
|
||||||
|
# Try to find the largest suffix of prev_tail that matches a prefix of new_tail
|
||||||
|
max_k = min(len(prev_tail), len(new_tail))
|
||||||
|
for k in range(max_k, 0, -1):
|
||||||
|
if prev_tail[-k:] == new_tail[:k]:
|
||||||
|
return new_tail[k:]
|
||||||
|
|
||||||
|
# If no overlap, just return full new_tail (probably truncation or big jump)
|
||||||
|
return new_tail
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------- output artifact normalization --------------------
|
||||||
|
|
||||||
|
def _find_latest_output_pair(output_dir: Path) -> Tuple[Optional[Path], Optional[Path]]:
|
||||||
|
"""
|
||||||
|
Find the most recently modified .tflite and its matching .json (same basename)
|
||||||
|
in output_dir. Falls back to newest .json if an exact match doesn't exist.
|
||||||
|
Returns (tflite_path, json_path) or (None, None).
|
||||||
|
"""
|
||||||
|
if not output_dir.exists():
|
||||||
|
return (None, None)
|
||||||
|
|
||||||
|
tflites = sorted(output_dir.glob("*.tflite"), key=lambda p: p.stat().st_mtime, reverse=True)
|
||||||
|
if not tflites:
|
||||||
|
return (None, None)
|
||||||
|
|
||||||
|
tfl = tflites[0]
|
||||||
|
js = tfl.with_suffix(".json")
|
||||||
|
if js.exists():
|
||||||
|
return (tfl, js)
|
||||||
|
|
||||||
|
jsons = sorted(output_dir.glob("*.json"), key=lambda p: p.stat().st_mtime, reverse=True)
|
||||||
|
return (tfl, jsons[0] if jsons else None)
|
||||||
|
|
||||||
|
|
||||||
|
def _deep_replace_strings(obj: Any, old: str, new: str) -> Any:
|
||||||
|
"""
|
||||||
|
Recursively replace occurrences of old in any string values with new.
|
||||||
|
"""
|
||||||
|
if isinstance(obj, str):
|
||||||
|
return obj.replace(old, new)
|
||||||
|
if isinstance(obj, list):
|
||||||
|
return [_deep_replace_strings(x, old, new) for x in obj]
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return {k: _deep_replace_strings(v, old, new) for k, v in obj.items()}
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_output_artifacts(safe_word: str, log_path: Path) -> None:
    """
    Rename output artifacts to <safe_word>.tflite / <safe_word>.json
    and patch the JSON so it references the renamed tflite.

    Handles weird trainer names like ____r_.tflite by normalizing post-run.
    Files already occupying the target names are moved aside to timestamped
    ``.bak`` names rather than overwritten.

    Args:
        safe_word: Sanitized wake word used as the target basename.
        log_path: Accepted for signature compatibility; not used here.
    """
    out_dir = DATA_DIR / "output"
    tfl, js = _find_latest_output_pair(out_dir)

    if not tfl:
        _append_train_log(f"⚠️ No .tflite found in {out_dir}")
        return

    new_tfl = out_dir / f"{safe_word}.tflite"
    new_js = out_dir / f"{safe_word}.json"
    old_tfl_name = tfl.name

    # Already normalized
    if tfl.name == new_tfl.name and (js and js.name == new_js.name):
        _append_train_log(f"✅ Output names already normalized: {new_tfl.name}")
        return

    ts = datetime.now().strftime("%Y%m%d_%H%M%S")

    def backup_if_exists(p: Path, suffix: str) -> None:
        """Move ``p`` to <safe_word>.<timestamp>.bak<suffix> if it exists."""
        if p.exists():
            bk = out_dir / f"{safe_word}.{ts}.bak{suffix}"
            shutil.move(str(p), str(bk))
            _append_train_log(f"↪️ Backed up existing {p.name} → {bk.name}")

    # Avoid clobbering existing target files (back them up)
    if new_tfl.exists() and new_tfl.resolve() != tfl.resolve():
        backup_if_exists(new_tfl, ".tflite")
    if new_js.exists() and (not js or new_js.resolve() != js.resolve()):
        backup_if_exists(new_js, ".json")

    # Rename tflite
    if tfl.resolve() != new_tfl.resolve():
        new_tfl.parent.mkdir(parents=True, exist_ok=True)
        shutil.move(str(tfl), str(new_tfl))
        _append_train_log(f"✅ Renamed model: {old_tfl_name} → {new_tfl.name}")

    # Rename + patch json if present
    if js and js.exists():
        # Read JSON before move (safer if we want the old name)
        try:
            data = json.loads(js.read_text(encoding="utf-8"))
        except Exception as exc:
            data = None
            # Surface the problem: the manifest may still name the old model file.
            _append_train_log(
                f"⚠️ Could not read/parse {js.name} ({exc!r}); metadata left unpatched"
            )

        if js.resolve() != new_js.resolve():
            shutil.move(str(js), str(new_js))
            _append_train_log(f"✅ Renamed metadata: {js.name} → {new_js.name}")

        if data is not None:
            patched = _deep_replace_strings(data, old_tfl_name, new_tfl.name)

            # Patch common keys if present
            for key in ("model", "model_file", "model_filename", "tflite", "tflite_file", "tflite_filename"):
                if isinstance(patched, dict) and key in patched and isinstance(patched[key], str):
                    patched[key] = new_tfl.name

            new_js.write_text(json.dumps(patched, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
            _append_train_log(f"✅ Patched JSON to reference: {new_tfl.name}")
    else:
        _append_train_log("⚠️ No .json found to patch (model renamed only)")
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------- training worker --------------------
|
||||||
|
|
||||||
|
def _run_training_background(safe_word: str, allow_no_personal: bool):
    """
    Worker-thread body for one training run.

    Marks training as running in STATE, ensures the venv and datasets exist,
    launches TRAIN_CMD through ``bash -lc`` while streaming its output into
    recorder_training.log, records the exit code and, on success, normalizes
    the output artifact names. Any exception is reported as exit code 999.

    Args:
        safe_word: Sanitized wake word passed to the training command.
        allow_no_personal: Exported to the child as MWW_ALLOW_NO_PERSONAL.
    """
    log_path = DATA_DIR / "recorder_training.log"

    def _note(message: str) -> None:
        """Record a status line in memory and, best-effort, in the log file.

        /api/train_status reads the file, so messages kept only in memory
        (final status, crash reason) would never show up there.
        """
        _append_train_log(message)
        try:
            with open(log_path, "a", encoding="utf-8") as fh:
                fh.write(message + "\n")
        except OSError:
            pass  # the in-memory copy above is still available

    with STATE_LOCK:
        raw_phrase = STATE.get("raw_phrase") or ""
        training = STATE["training"]
        training["running"] = True
        training["exit_code"] = None
        training["log_lines"] = []
        training["safe_word"] = safe_word
        training["last_sent_tail"] = []
        training["last_log_size"] = 0
        training["log_path"] = str(log_path)

    wake_word_title = _title_from_phrase(raw_phrase)

    _append_train_log("================================================================================")
    _append_train_log("===== Recorder Training Run =====")
    _append_train_log("================================================================================")

    # The banner in the file is cosmetic; never let it stop the run.
    try:
        with open(log_path, "a", encoding="utf-8") as lf:
            lf.write("\n" + ("=" * 80) + "\n")
            lf.write("===== Recorder Training Run =====\n")
            lf.write(("=" * 80) + "\n")
            lf.flush()
    except Exception:
        pass

    try:
        _ensure_training_venv(log_path)
        _ensure_training_datasets(log_path)

        if wake_word_title:
            cmd_str = f"{TRAIN_CMD} '{safe_word}' '{wake_word_title}'"
        else:
            cmd_str = f"{TRAIN_CMD} '{safe_word}'"

        env = os.environ.copy()
        env["MWW_ALLOW_NO_PERSONAL"] = "true" if allow_no_personal else "false"

        _note("===== Training (train_wake_word) =====")
        _note(f"→ Running: {cmd_str}")

        with open(log_path, "a", encoding="utf-8") as lf:
            # Context-managed so the pipe is closed and the child reaped on
            # every exit path; undecodable bytes are replaced, not fatal.
            with subprocess.Popen(
                ["bash", "-lc", cmd_str],
                cwd=str(DATA_DIR),
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                encoding="utf-8",
                errors="replace",
                bufsize=1,
                env=env,
            ) as proc:
                assert proc.stdout is not None
                try:
                    for line in proc.stdout:
                        lf.write(line)
                        lf.flush()
                        _append_train_log(line)
                except BaseException:
                    # If output can no longer be recorded, stop the child
                    # instead of leaving it running unattended.
                    proc.kill()
                    raise
            rc = proc.returncode

        _note(f"✓ Training finished (exit_code={rc})")
        with STATE_LOCK:
            STATE["training"]["exit_code"] = rc

        # Normalize output artifact names on success
        if rc == 0:
            _normalize_output_artifacts(safe_word, log_path)

    except Exception as e:
        _note(f"✗ Training crashed: {e!r}")
        with STATE_LOCK:
            STATE["training"]["exit_code"] = 999

    finally:
        with STATE_LOCK:
            STATE["training"]["running"] = False
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/", response_class=HTMLResponse)
def index():
    """Serve ``index.html`` from STATIC_DIR, or a 500 page if that file is missing."""
    page = STATIC_DIR / "index.html"
    if page.exists():
        return HTMLResponse(page.read_text(encoding="utf-8"))

    missing_ui = "<h3>Missing UI</h3><p>Create <code>static/index.html</code>.</p>"
    return HTMLResponse(missing_ui, status_code=500)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/start_session")
def start_session(payload: Dict[str, Any]):
    """
    Begin a new recording session.

    Expects ``phrase`` (required) and optional ``speakers_total`` /
    ``takes_per_speaker`` in the JSON body. The counts are clamped to 1..10
    and 1..50. Existing recordings are removed and the training log is reset.

    Returns 400 when the phrase is missing, the counts are not integers, or a
    training run is in progress (resetting then would delete the samples and
    truncate the log that the running job is using).
    """
    raw = str(payload.get("phrase") or "").strip()
    if not raw:
        return JSONResponse({"ok": False, "error": "phrase is required"}, status_code=400)

    safe = safe_name(raw)

    try:
        speakers_total = int(payload.get("speakers_total") or SPEAKERS_TOTAL_DEFAULT)
        takes_per_speaker = int(payload.get("takes_per_speaker") or TAKES_PER_SPEAKER_DEFAULT)
    except (TypeError, ValueError):
        return JSONResponse(
            {"ok": False, "error": "speakers_total and takes_per_speaker must be integers"},
            status_code=400,
        )

    speakers_total = max(1, min(10, speakers_total))
    takes_per_speaker = max(1, min(50, takes_per_speaker))

    with STATE_LOCK:
        if STATE["training"]["running"]:
            return JSONResponse({"ok": False, "error": "Training already running"}, status_code=400)

        STATE["raw_phrase"] = raw
        STATE["safe_word"] = safe
        STATE["speakers_total"] = speakers_total
        STATE["takes_per_speaker"] = takes_per_speaker
        STATE["takes_received"] = 0
        STATE["takes"] = []

    _reset_personal_samples_dir()

    # Always wipe log on start_session (even if same wakeword)
    _clear_training_log()

    return {
        "ok": True,
        "raw_phrase": raw,
        "safe_word": safe,
        "speakers_total": speakers_total,
        "takes_per_speaker": takes_per_speaker,
        "takes_total": speakers_total * takes_per_speaker,
        "personal_dir": str(PERSONAL_DIR),
        "data_dir": str(DATA_DIR),
    }
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/session")
def get_session():
    """
    Return a snapshot of the current session and training state.

    List values are copied while the lock is held, so the response is not
    affected by the training thread mutating them afterwards.
    """
    with STATE_LOCK:
        training_snapshot = {
            key: (list(value) if isinstance(value, list) else value)
            for key, value in STATE["training"].items()
        }
        return {
            "ok": True,
            "raw_phrase": STATE["raw_phrase"],
            "safe_word": STATE["safe_word"],
            "speakers_total": STATE["speakers_total"],
            "takes_per_speaker": STATE["takes_per_speaker"],
            "takes_received": STATE["takes_received"],
            "takes": list(STATE["takes"]),
            "training": training_snapshot,
        }
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/upload_take")
async def upload_take(
    speaker_index: int = Form(...),
    take_index: int = Form(...),
    file: UploadFile = File(...),
):
    """
    Store one uploaded recording as ``speakerNN_takeNN.wav`` in PERSONAL_DIR.

    Both indices are 1-based and must lie within the limits of the active
    session. Uploading the same speaker/take again overwrites the file and
    does not change the take count.

    Returns 400 when there is no active session, an index is out of range, or
    the upload is shorter than 44 bytes.
    """
    with STATE_LOCK:
        safe_word = STATE["safe_word"]
        speakers_total = int(STATE["speakers_total"])
        takes_per_speaker = int(STATE["takes_per_speaker"])

    if not safe_word:
        return JSONResponse({"ok": False, "error": "No active session. Call /api/start_session first."}, status_code=400)

    if speaker_index < 1 or speaker_index > speakers_total:
        return JSONResponse({"ok": False, "error": f"speaker_index must be 1..{speakers_total}"}, status_code=400)

    if take_index < 1 or take_index > takes_per_speaker:
        return JSONResponse({"ok": False, "error": f"take_index must be 1..{takes_per_speaker}"}, status_code=400)

    PERSONAL_DIR.mkdir(parents=True, exist_ok=True)

    out_name = f"speaker{speaker_index:02d}_take{take_index:02d}.wav"
    out_path = PERSONAL_DIR / out_name

    data = await file.read()
    if not data or len(data) < 44:
        return JSONResponse({"ok": False, "error": "Empty/invalid file"}, status_code=400)

    out_path.write_bytes(data)

    with STATE_LOCK:
        if out_name not in STATE["takes"]:
            STATE["takes"].append(out_name)
        STATE["takes_received"] = len(STATE["takes"])
        # Capture the count while still holding the lock so the response
        # matches the state this request produced.
        takes_received = STATE["takes_received"]

    return {"ok": True, "saved_as": out_name, "takes_received": takes_received}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/train")
def train_now(payload: Dict[str, Any] = None):
    """
    Start a training run in a background thread.

    Body (optional): ``{"allow_no_personal": bool}`` permits training when no
    personal samples were recorded.

    Returns 400 when training is already running, there is no active session,
    no samples exist and ``allow_no_personal`` is false, or some samples exist
    but fewer than the minimum (at most 3).
    """
    payload = payload or {}
    allow_no_personal = bool(payload.get("allow_no_personal", False))

    rejection: Optional[Dict[str, Any]] = None

    # Validate and claim the "running" slot in one critical section. Setting
    # the flag only inside the worker thread would let two quick requests both
    # pass the check and start two trainings.
    with STATE_LOCK:
        safe_word = STATE["safe_word"]
        takes_received = int(STATE["takes_received"])
        speakers_total = int(STATE["speakers_total"])
        takes_per_speaker = int(STATE["takes_per_speaker"])

        takes_total = speakers_total * takes_per_speaker
        min_required = max(1, min(3, takes_total))

        if STATE["training"]["running"]:
            rejection = {"ok": False, "error": "Training already running"}
        elif not safe_word:
            rejection = {"ok": False, "error": "No active session"}
        elif takes_received == 0 and not allow_no_personal:
            rejection = {
                "ok": False,
                "error": f"No personal voice samples recorded (0/{takes_total}).",
                "code": "NO_PERSONAL_SAMPLES",
                "message": "You can train without personal voices, or record samples first.",
                "takes_total": takes_total,
            }
        elif 0 < takes_received < min_required:
            rejection = {
                "ok": False,
                "error": f"Not enough takes yet ({takes_received}/{takes_total}).",
                "code": "NOT_ENOUGH_TAKES",
                "min_required": min_required,
                "takes_total": takes_total,
            }
        else:
            STATE["training"]["running"] = True

    if rejection is not None:
        return JSONResponse(rejection, status_code=400)

    try:
        t = threading.Thread(target=_run_training_background, args=(safe_word, allow_no_personal), daemon=True)
        t.start()
    except Exception:
        # The worker never ran, so release the slot we claimed above.
        with STATE_LOCK:
            STATE["training"]["running"] = False
        raise

    return {
        "ok": True,
        "started": True,
        "safe_word": safe_word,
        "personal_samples_used": takes_received >= min_required,
        "allow_no_personal": allow_no_personal,
    }
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/train_status")
def train_status():
    """
    Report training state plus the log lines added since the previous poll.

    ``log_text`` carries only the new lines (so a UI that appends does not
    show duplicates); ``log_tail_preview`` carries the whole current tail.
    """
    with STATE_LOCK:
        tr = dict(STATE["training"])
        log_path_str = tr.get("log_path")
        prev_tail = list(STATE["training"].get("last_sent_tail") or [])
        prev_size = int(STATE["training"].get("last_log_size") or 0)

    size_now = 0
    full_tail: List[str] = []
    new_lines: List[str] = []

    log_file = Path(log_path_str) if log_path_str else None
    if log_file is not None and log_file.exists():
        try:
            size_now = int(log_file.stat().st_size)
        except Exception:
            size_now = 0

        # A smaller file than last time means it was truncated/cleared:
        # forget the previous snapshot so the whole tail counts as new.
        if size_now < prev_size:
            prev_tail = []

        full_tail = _read_tail_lines(log_file, TRAIN_LOG_TAIL_LINES)
        new_lines = _compute_new_lines(prev_tail, full_tail)

    # Remember what was sent so the next poll can diff against it.
    with STATE_LOCK:
        STATE["training"]["last_sent_tail"] = full_tail
        STATE["training"]["last_log_size"] = size_now

    tr["log_text"] = "\n".join(new_lines)  # ONLY new lines
    tr["log_tail_preview"] = "\n".join(full_tail)  # optional: handy for debugging
    return {"ok": True, "training": tr}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/reset_recordings")
def reset_recordings():
    """Delete the recorded ``.wav`` takes and reset the take bookkeeping in STATE."""
    _reset_personal_samples_dir()
    with STATE_LOCK:
        STATE["takes"] = []
        STATE["takes_received"] = 0
    return {"ok": True}
|
||||||
@@ -1,28 +1,10 @@
|
|||||||
# --- Core training (Microwakeword) ---
|
# --- Packages needed by our scripts ---
|
||||||
|
|
||||||
numpy==1.26.4
|
numpy==1.26.4
|
||||||
scipy==1.12.0
|
scipy==1.12.0
|
||||||
librosa==0.10.2.post1
|
librosa==0.10.2.post1
|
||||||
soundfile==0.12.1
|
soundfile==0.12.1
|
||||||
soxr==0.5.0.post1
|
|
||||||
audiomentations==0.38.0
|
|
||||||
webrtcvad==2.0.10
|
|
||||||
tqdm==4.67.1
|
tqdm==4.67.1
|
||||||
scikit-learn==1.6.0
|
scikit-learn==1.6.0
|
||||||
numba==0.60.0
|
numba==0.63.1
|
||||||
joblib==1.4.2
|
PyYAML==6.0.3
|
||||||
pandas==2.2.3
|
|
||||||
pymicro_features @ git+https://github.com/puddly/pymicro-features@e1d3f88183e12bb8af2df9e399ea157af7393762
|
|
||||||
audio-metadata @ git+https://github.com/whatsnowplaying/audio-metadata@d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f
|
|
||||||
bitstruct==8.19.0
|
|
||||||
|
|
||||||
# --- Piper sample generation ---
|
|
||||||
piper-tts>=1.2.0
|
|
||||||
piper-phonemize-cross==1.2.1
|
|
||||||
|
|
||||||
# --- Notebook / tooling ---
|
|
||||||
ipykernel==6.29.5
|
|
||||||
jupyterlab==4.3.4
|
|
||||||
ipywidgets==8.1.5
|
|
||||||
matplotlib-inline==0.1.7
|
|
||||||
rich==13.9.4
|
|
||||||
|
|||||||
64
run_recorder.sh
Normal file
64
run_recorder.sh
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
#!/usr/bin/env bash
#
# Launch the microWakeWord recorder web UI (recorder_server:app under uvicorn).
#
# Optional environment:
#   DATA_DIR                   data root (default: /data)
#   REC_HOST, REC_PORT         bind address and port (default: 0.0.0.0, 8888)
#   REC_FASTAPI_VERSION, REC_UVICORN_VERSION, REC_PY_MULTIPART_VERSION
#                              versions pinned into the recorder venv on first run
set -euo pipefail

ROOTDIR="$(dirname "$(realpath "$0")")"

# Training convention
DATA_DIR="${DATA_DIR:-/data}"
HOST="${REC_HOST:-0.0.0.0}"
PORT="${REC_PORT:-8888}"

# Keep recorder deps separate from training venv
VENV_DIR="${DATA_DIR}/.recorder-venv"
PY="${VENV_DIR}/bin/python"
# Stored as an array rather than a "python -m pip" string: expanding a string
# unquoted would word-split/glob a DATA_DIR that contains spaces or wildcards.
PIP=("${PY}" -m pip)
PIN_FILE="${VENV_DIR}/.pinned_installed"

FASTAPI_VERSION="${REC_FASTAPI_VERSION:-0.115.6}"
UVICORN_VERSION="${REC_UVICORN_VERSION:-0.30.6}"
PY_MULTIPART_VERSION="${REC_PY_MULTIPART_VERSION:-0.0.9}"

echo "microWakeWord Recorder (Docker)"
echo "-> ROOTDIR: ${ROOTDIR}"
echo "-> DATA_DIR: ${DATA_DIR}"
echo "-> URL: http://localhost:${PORT}/"

mkdir -p "${DATA_DIR}"

# -----------------------------
# Recorder venv (separate)
# -----------------------------
if [[ ! -x "${PY}" ]]; then
  echo "Creating recorder venv: ${VENV_DIR}"
  python3 -m venv "${VENV_DIR}"
fi

# shellcheck disable=SC1091
source "${VENV_DIR}/bin/activate"

# The marker file is only created after a successful install (set -e aborts
# earlier otherwise), so an interrupted first run is retried on the next start.
if [[ ! -f "${PIN_FILE}" ]]; then
  echo "Installing pinned recorder deps"
  "${PIP[@]}" install -U pip setuptools wheel
  "${PIP[@]}" install \
    "fastapi==${FASTAPI_VERSION}" \
    "uvicorn[standard]==${UVICORN_VERSION}" \
    "python-multipart==${PY_MULTIPART_VERSION}"
  touch "${PIN_FILE}"
else
  echo "Reusing existing recorder venv (no upgrades)"
fi

# -----------------------------
# Recorder server env
# -----------------------------
export DATA_DIR="${DATA_DIR}"
export STATIC_DIR="${ROOTDIR}/static"
export PERSONAL_DIR="${DATA_DIR}/personal_samples"

# IMPORTANT: leave training venv creation to /api/train inside recorder_server.py
# but still set TRAIN_CMD so the server knows how to invoke training once ready.
# TRAIN_CMD is a command string that gets re-parsed by a shell, so the paths are
# quoted with %q; wrapping them in literal single quotes would break on a
# DATA_DIR that itself contains a single quote.
printf -v train_activate_q '%q' "${DATA_DIR}/.venv/bin/activate"
printf -v train_data_dir_q '%q' "${DATA_DIR}"
export TRAIN_CMD="source ${train_activate_q} && train_wake_word --data-dir=${train_data_dir_q}"

echo "Launching uvicorn on ${HOST}:${PORT}"
cd "${ROOTDIR}"
exec "${VENV_DIR}/bin/uvicorn" recorder_server:app --host "${HOST}" --port "${PORT}"
|
||||||
23
startup.sh
23
startup.sh
@@ -1,23 +0,0 @@
|
|||||||
#!/usr/bin/env bash
#
# Container entrypoint: prepare /data, seed the default training notebook,
# optionally align ownership, then hand over to the requested command.
set -euo pipefail

# Default to root when the caller did not supply (or supplied empty) ids
NB_UID="${NB_UID:-0}"
NB_GID="${NB_GID:-0}"
umask 002

NOTEBOOK_SRC="/root/microWakeWord_training_notebook.ipynb"
NOTEBOOK_DST="/data/microWakeWord_training_notebook.ipynb"

mkdir -p /data /data/generated_samples /data/personal_samples

# Seed the notebook only when /data does not have one yet
if [[ ! -f "${NOTEBOOK_DST}" ]]; then
  printf '%s\n' "No training notebook found in /data; copying default…"
  cp -n "${NOTEBOOK_SRC}" "${NOTEBOOK_DST}"
fi

# Best-effort ownership alignment: skipped for root:root, errors are ignored
# on purpose because the mount may not permit chown.
if ! [[ "${NB_UID}" == "0" && "${NB_GID}" == "0" ]]; then
  chown -R "${NB_UID}:${NB_GID}" /data || true
fi

exec "$@"
|
|
||||||
811
static/index.html
Normal file
811
static/index.html
Normal file
@@ -0,0 +1,811 @@
|
|||||||
|
<!doctype html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="utf-8" />
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||||
|
<title>microWakeWord Recorder</title>
|
||||||
|
<style>
|
||||||
|
:root{
|
||||||
|
--bg: #070709;
|
||||||
|
--panel: rgba(18, 18, 22, 0.78);
|
||||||
|
--panel2: rgba(24, 24, 30, 0.86);
|
||||||
|
--text: #e9e9ee;
|
||||||
|
--muted: #a2a2ad;
|
||||||
|
--line: rgba(255,255,255,0.10);
|
||||||
|
--orange: #ff8a2a;
|
||||||
|
--orange2:#ffb066;
|
||||||
|
--ok:#38d39f;
|
||||||
|
--warn:#ffb020;
|
||||||
|
--err:#ff4a4a;
|
||||||
|
--shadow: 0 18px 50px rgba(0,0,0,0.45);
|
||||||
|
--radius: 16px;
|
||||||
|
}
|
||||||
|
|
||||||
|
html, body { height: 100%; }
|
||||||
|
body {
|
||||||
|
margin: 0;
|
||||||
|
color: var(--text);
|
||||||
|
font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, sans-serif;
|
||||||
|
background:
|
||||||
|
radial-gradient(900px 500px at 12% 6%, rgba(255, 138, 42, 0.12), transparent 55%),
|
||||||
|
radial-gradient(700px 420px at 80% 14%, rgba(255, 176, 102, 0.09), transparent 60%),
|
||||||
|
radial-gradient(800px 600px at 50% 100%, rgba(255, 138, 42, 0.06), transparent 55%),
|
||||||
|
linear-gradient(180deg, #050506 0%, #09090d 100%);
|
||||||
|
}
|
||||||
|
|
||||||
|
.wrap { max-width: 940px; margin: 0 auto; padding: 26px 18px 42px; }
|
||||||
|
|
||||||
|
h2 { margin: 0 0 8px; font-size: 22px; letter-spacing: 0.2px; }
|
||||||
|
p { margin: 0 0 14px; color: var(--muted); line-height: 1.45; }
|
||||||
|
|
||||||
|
.topbar {
|
||||||
|
display:flex; align-items:center; justify-content:space-between;
|
||||||
|
gap: 12px; margin-bottom: 14px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.brand { display:flex; align-items:center; gap:10px; }
|
||||||
|
.logo {
|
||||||
|
width: 38px; height: 38px; border-radius: 12px;
|
||||||
|
background:
|
||||||
|
radial-gradient(circle at 30% 30%, rgba(255,176,102,0.55), rgba(255,138,42,0.25) 45%, rgba(0,0,0,0) 72%),
|
||||||
|
linear-gradient(180deg, rgba(255,138,42,0.22), rgba(255,138,42,0.06));
|
||||||
|
border: 1px solid rgba(255,138,42,0.30);
|
||||||
|
box-shadow: 0 10px 28px rgba(255,138,42,0.08);
|
||||||
|
}
|
||||||
|
|
||||||
|
.row { display: flex; gap: 12px; flex-wrap: wrap; align-items: center; }
|
||||||
|
|
||||||
|
.card {
|
||||||
|
border: 1px solid var(--line);
|
||||||
|
background: linear-gradient(180deg, var(--panel), var(--panel2));
|
||||||
|
border-radius: var(--radius);
|
||||||
|
padding: 16px;
|
||||||
|
margin-top: 14px;
|
||||||
|
box-shadow: var(--shadow);
|
||||||
|
backdrop-filter: blur(8px);
|
||||||
|
}
|
||||||
|
|
||||||
|
.muted { color: var(--muted); }
|
||||||
|
|
||||||
|
input[type="text"], input[type="number"]{
|
||||||
|
padding: 11px 12px;
|
||||||
|
font-size: 15px;
|
||||||
|
border-radius: 12px;
|
||||||
|
border: 1px solid rgba(255,255,255,0.12);
|
||||||
|
background: rgba(0,0,0,0.35);
|
||||||
|
color: var(--text);
|
||||||
|
outline: none;
|
||||||
|
}
|
||||||
|
input[type="text"] { width: 420px; max-width: 100%; }
|
||||||
|
input[type="number"] { width: 120px; }
|
||||||
|
input::placeholder { color: rgba(233,233,238,0.35); }
|
||||||
|
|
||||||
|
button {
|
||||||
|
padding: 10px 14px;
|
||||||
|
font-size: 13px;
|
||||||
|
cursor: pointer;
|
||||||
|
border-radius: 12px;
|
||||||
|
border: 1px solid rgba(255,255,255,0.14);
|
||||||
|
background: rgba(255,255,255,0.06);
|
||||||
|
color: var(--text);
|
||||||
|
transition: transform 0.04s ease, border-color .15s ease, background .15s ease;
|
||||||
|
}
|
||||||
|
button:hover { border-color: rgba(255,138,42,0.35); background: rgba(255,255,255,0.08); }
|
||||||
|
button:active { transform: translateY(1px); }
|
||||||
|
button:disabled { opacity: 0.45; cursor: not-allowed; }
|
||||||
|
|
||||||
|
.primary {
|
||||||
|
border-color: rgba(255,138,42,0.40);
|
||||||
|
background: linear-gradient(180deg, rgba(255,138,42,0.24), rgba(255,138,42,0.12));
|
||||||
|
}
|
||||||
|
.primary:hover { border-color: rgba(255,138,42,0.65); }
|
||||||
|
|
||||||
|
.pill {
|
||||||
|
display:inline-block;
|
||||||
|
padding: 4px 10px;
|
||||||
|
border-radius: 999px;
|
||||||
|
background: rgba(255,255,255,0.07);
|
||||||
|
border: 1px solid rgba(255,255,255,0.10);
|
||||||
|
color: var(--muted);
|
||||||
|
font-size: 12px;
|
||||||
|
}
|
||||||
|
.pill.ok { color: var(--ok); border-color: rgba(56,211,159,0.25); background: rgba(56,211,159,0.08); }
|
||||||
|
.pill.warn { color: var(--warn); border-color: rgba(255,176,32,0.25); background: rgba(255,176,32,0.08); }
|
||||||
|
.pill.err { color: var(--err); border-color: rgba(255,74,74,0.25); background: rgba(255,74,74,0.08); }
|
||||||
|
|
||||||
|
details { margin-top: 10px; }
|
||||||
|
summary { cursor: pointer; color: var(--orange2); }
|
||||||
|
summary:hover { color: var(--orange); }
|
||||||
|
|
||||||
|
label { display:flex; gap:10px; align-items:center; }
|
||||||
|
input[type="range"] { width: 240px; }
|
||||||
|
|
||||||
|
.meter {
|
||||||
|
height: 10px;
|
||||||
|
background: rgba(255,255,255,0.08);
|
||||||
|
border-radius: 999px;
|
||||||
|
overflow: hidden;
|
||||||
|
width: 280px;
|
||||||
|
border: 1px solid rgba(255,255,255,0.10);
|
||||||
|
}
|
||||||
|
.meter > div {
|
||||||
|
height: 10px;
|
||||||
|
width: 0%;
|
||||||
|
background: linear-gradient(90deg, rgba(255,138,42,0.55), rgba(255,176,102,0.85));
|
||||||
|
}
|
||||||
|
|
||||||
|
pre {
|
||||||
|
background: rgba(0,0,0,0.55);
|
||||||
|
color: #e6e6ea;
|
||||||
|
padding: 12px;
|
||||||
|
border-radius: 14px;
|
||||||
|
overflow: auto;
|
||||||
|
max-height: 300px;
|
||||||
|
border: 1px solid rgba(255,255,255,0.10);
|
||||||
|
white-space: pre-wrap;
|
||||||
|
word-break: break-word;
|
||||||
|
}
|
||||||
|
|
||||||
|
.big { font-size: 16px; }
|
||||||
|
|
||||||
|
.divider {
|
||||||
|
height: 1px;
|
||||||
|
width: 100%;
|
||||||
|
background: rgba(255,255,255,0.10);
|
||||||
|
margin: 12px 0;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
|
||||||
|
<body>
|
||||||
|
<div class="wrap">
|
||||||
|
<div class="topbar">
|
||||||
|
<div class="brand">
|
||||||
|
<div class="logo"></div>
|
||||||
|
<div>
|
||||||
|
<h2>🎙️ microWakeWord Personal Recorder</h2>
|
||||||
|
<p class="muted">Enter a wake word, test TTS pronunciation, then record takes. Recording starts when you speak and stops after silence.</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="card">
|
||||||
|
<div class="row">
|
||||||
|
<input id="phrase" type="text" placeholder='e.g. "tater totterson"' />
|
||||||
|
<button id="startSessionBtn" class="primary">Start session</button>
|
||||||
|
<button id="ttsBtn" disabled>🔊 Test TTS</button>
|
||||||
|
<span id="sessionPill" class="pill">No session</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="row" style="margin-top:10px;">
|
||||||
|
<label class="muted">Speakers
|
||||||
|
<input id="speakersTotal" type="number" min="1" max="10" value="1" />
|
||||||
|
</label>
|
||||||
|
<label class="muted">Takes / speaker
|
||||||
|
<input id="takesPerSpeaker" type="number" min="1" max="50" value="10" />
|
||||||
|
</label>
|
||||||
|
<span id="speakerPill" class="pill">Speaker: -</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Advanced (if it’s too sensitive / not sensitive enough)</summary>
|
||||||
|
<div style="margin-top:10px;">
|
||||||
|
<label>
|
||||||
|
Start sensitivity
|
||||||
|
<input id="startThresh" type="range" min="0.005" max="0.08" step="0.001" value="0.02" />
|
||||||
|
<span id="startThreshVal" class="muted"></span>
|
||||||
|
</label>
|
||||||
|
<label>
|
||||||
|
Silence stop (ms)
|
||||||
|
<input id="silenceMs" type="range" min="300" max="2000" step="50" value="900" />
|
||||||
|
<span id="silenceMsVal" class="muted"></span>
|
||||||
|
</label>
|
||||||
|
<label>
|
||||||
|
Min take length (ms)
|
||||||
|
<input id="minTakeMs" type="range" min="300" max="2000" step="50" value="650" />
|
||||||
|
<span id="minTakeMsVal" class="muted"></span>
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
</details>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="card">
|
||||||
|
<div class="row">
|
||||||
|
<button id="beginBtn" disabled class="primary">🎬 Begin recording</button>
|
||||||
|
<button id="resetBtn" disabled>🧹 Reset recordings</button>
|
||||||
|
<button id="trainBtn" disabled>🧠 Start training</button>
|
||||||
|
<span id="status" class="pill">Idle</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div style="margin-top:12px;" class="row">
|
||||||
|
<div class="meter"><div id="meterFill"></div></div>
|
||||||
|
<span class="muted" id="meterText">Mic level</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="divider"></div>
|
||||||
|
|
||||||
|
<p class="big">
|
||||||
|
Speaker: <b id="speakerNum">-</b> / <b id="speakerTotal">-</b>
|
||||||
|
<span id="speakerState" class="pill">Waiting</span>
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<p class="big">
|
||||||
|
Take: <b id="takeNum">0</b> / <b id="takeTotal">10</b>
|
||||||
|
<span id="takeState" class="pill">Not recording</span>
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<div id="takesList" class="muted"></div>
|
||||||
|
|
||||||
|
<h4 style="margin-top: 18px; margin-bottom: 10px;">Training log</h4>
|
||||||
|
<pre id="trainLog">(no training started)</pre>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
const $ = (id) => document.getElementById(id);
|
||||||
|
|
||||||
|
// Set a pill element's label and variant class ("ok", "warn", "err" or none).
function setPill(el, text, cls) {
  const variant = cls ? cls : "";
  el.className = `pill ${variant}`;
  el.textContent = text;
}
|
||||||
|
|
||||||
|
// Fetch `path` and return the parsed body (JSON when the response says so,
// text otherwise). Non-2xx responses reject with an Error whose message is the
// server-provided reason and whose `details` property holds the parsed body.
async function api(path, opts) {
  // Work on a copy so the caller's options object is never mutated
  const options = Object.assign({}, opts);
  // Always try to avoid cache for polling endpoints
  if (!options.cache) options.cache = "no-store";

  const res = await fetch(path, options);
  const ct = res.headers.get("content-type") || "";
  const data = ct.includes("application/json") ? await res.json() : await res.text();
  if (res.ok) return data;

  const err = (typeof data === "string") ? { error: data } : (data || {});
  // FastAPI puts the reason of an HTTPException under "detail"; surface it as
  // the message instead of falling through to the raw JSON dump.
  const detail = (typeof err.detail === "string") ? err.detail : "";
  const msg = err.error || err.message || detail || JSON.stringify(err);
  const e = new Error(msg);
  e.details = err;
  throw e;
}
|
||||||
|
|
||||||
|
// -------------------- log auto-scroll (sticky to bottom) --------------------
|
||||||
|
// True when the element is scrolled to within `px` pixels of its bottom edge.
function isNearBottom(el, px = 40) {
  const distanceFromBottom = el.scrollHeight - el.scrollTop - el.clientHeight;
  return distanceFromBottom <= px;
}
|
||||||
|
|
||||||
|
// Replace the element's text; keep it pinned to the bottom only if it already
// was near the bottom before the replacement.
function setLogTextAutoScroll(el, text) {
  const wasAtBottom = isNearBottom(el);
  el.textContent = text ? text : "";
  if (!wasAtBottom) return;
  el.scrollTop = el.scrollHeight;
}
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
let session = null;
|
||||||
|
let isRunning = false;
|
||||||
|
|
||||||
|
let stream = null;
|
||||||
|
let audioCtx = null;
|
||||||
|
let analyser = null;
|
||||||
|
let source = null;
|
||||||
|
|
||||||
|
let capturing = false;
|
||||||
|
let startedAt = 0;
|
||||||
|
let silenceStart = null;
|
||||||
|
let floatChunks = [];
|
||||||
|
let frameSize = 2048;
|
||||||
|
|
||||||
|
let currentSpeaker = 1;
|
||||||
|
let speakersTotal = 1;
|
||||||
|
|
||||||
|
let currentTake = 0;
|
||||||
|
let takesPerSpeaker = 10;
|
||||||
|
|
||||||
|
// --- training poll (append mode; scrollback works) ---
|
||||||
|
let trainingPollRunning = false;
|
||||||
|
let trainingPollAbort = false;
|
||||||
|
|
||||||
|
let logBuffer = ""; // full text we’ve shown in the browser
|
||||||
|
let lastChunk = ""; // last chunk we received (for de-dupe)
|
||||||
|
let seenAnyOutput = false;
|
||||||
|
|
||||||
|
// Append `chunk` to the element's text (no-op for an empty chunk) and follow
// the new output only when the view was already near the bottom.
function appendLogAutoScroll(el, chunk) {
  if (chunk) {
    const wasAtBottom = isNearBottom(el);
    el.textContent += chunk;
    if (wasAtBottom) {
      el.scrollTop = el.scrollHeight;
    }
  }
}
|
||||||
|
|
||||||
|
// Current values of the three "Advanced" sliders, parsed from the DOM.
function startThreshold() {
  const raw = $("startThresh").value;
  return parseFloat(raw);
}

function silenceStopMs() {
  const raw = $("silenceMs").value;
  return parseInt(raw, 10);
}

function minTakeMs() {
  const raw = $("minTakeMs").value;
  return parseInt(raw, 10);
}
|
||||||
|
|
||||||
|
// Mirror the slider values into the text labels shown next to each slider.
function updateAdvancedLabels() {
  const labels = {
    startThreshVal: startThreshold().toFixed(3),
    silenceMsVal: silenceStopMs() + "ms",
    minTakeMsVal: minTakeMs() + "ms",
  };
  for (const [id, text] of Object.entries(labels)) {
    $(id).textContent = text;
  }
}

// Refresh the labels on every slider move, and once up front for the defaults
for (const sliderId of ["startThresh", "silenceMs", "minTakeMs"]) {
  $(sliderId).addEventListener("input", updateAdvancedLabels);
}
updateAdvancedLabels();
|
||||||
|
|
||||||
|
// Push the speaker/take counters from the module state into the page.
function refreshUI() {
  const counters = [
    ["speakerNum", currentSpeaker],
    ["speakerTotal", speakersTotal],
    ["takeNum", currentTake],
    ["takeTotal", takesPerSpeaker],
  ];
  counters.forEach(([id, value]) => {
    $(id).textContent = String(value);
  });
  setPill($("speakerPill"), `Speaker ${currentSpeaker}/${speakersTotal}`);
}
|
||||||
|
|
||||||
|
// -------------------- mic lifecycle --------------------
|
||||||
|
// Acquire the microphone once and wire it to an analyser; starts the meter
// loop. Does nothing when a stream is already open. Throws when the browser
// exposes no getUserMedia.
async function ensureMic() {
  const media = navigator.mediaDevices;
  if (!(media && media.getUserMedia)) {
    throw new Error("Microphone not available here. Use https:// (or http://localhost) to record.");
  }
  if (stream) return;

  stream = await media.getUserMedia({ audio: true, video: false });

  const AudioCtor = window.AudioContext || window.webkitAudioContext;
  audioCtx = new AudioCtor();

  analyser = audioCtx.createAnalyser();
  analyser.fftSize = 2048;

  source = audioCtx.createMediaStreamSource(stream);
  source.connect(analyser);

  requestAnimationFrame(meterLoop);
}
|
||||||
|
|
||||||
|
// Stop recording immediately and release every audio resource: the script
// processor (if any), the media tracks and the audio context. Each teardown
// step is best-effort; failures are swallowed so the rest still runs.
async function stopMicNow() {
  isRunning = false;
  capturing = false;

  const node = window.__mw_proc;
  if (node) {
    try { node.disconnect(); } catch {}
    try { if (source) source.disconnect(node); } catch {}
    window.__mw_proc = null;
  }

  if (stream) {
    try {
      for (const track of stream.getTracks()) track.stop();
    } catch {}
    stream = null;
  }

  if (audioCtx) {
    try { await audioCtx.close(); } catch {}
    audioCtx = null;
  }

  analyser = null;
  source = null;

  // Reset the level meter to its idle appearance
  $("meterText").textContent = "Mic stopped";
  $("meterFill").style.width = "0%";
}
|
||||||
|
|
||||||
|
// Per-frame mic meter: computes the RMS of the analyser's time-domain data,
// updates the level bar/label, and drives the recorder state machine while a
// recording run is active.
function meterLoop() {
  // No analyser means the mic was released. End this loop instead of spinning
  // on requestAnimationFrame forever: ensureMic() schedules a fresh loop each
  // time it (re)acquires the mic, so idle loops would otherwise pile up.
  if (!analyser) return;

  const data = new Uint8Array(analyser.fftSize);
  analyser.getByteTimeDomainData(data);

  // Samples are unsigned bytes centred on 128; map to [-1, 1) before squaring
  let sumSq = 0;
  for (let i = 0; i < data.length; i++) {
    const v = (data[i] - 128) / 128;
    sumSq += v * v;
  }
  const rms = Math.sqrt(sumSq / data.length);
  const pct = Math.min(100, Math.max(0, rms * 600));
  $("meterFill").style.width = pct + "%";
  $("meterText").textContent = `Mic level (rms=${rms.toFixed(3)})`;

  if (isRunning) recorderTick(rms);

  requestAnimationFrame(meterLoop);
}
|
||||||
|
|
||||||
|
// -------------------- recording state machine --------------------
|
||||||
|
// One step of the voice-activated recorder. While idle, a level at or above
// the start threshold begins a capture. While capturing, a level below 65% of
// that threshold counts as silence; once silence has lasted silenceStopMs()
// the take is uploaded if it is at least minTakeMs() long, otherwise the
// silence timer restarts.
function recorderTick(rms) {
  const now = performance.now();
  const threshold = startThreshold();

  if (!capturing) {
    if (rms >= threshold) startCapture();
    return;
  }

  const isQuiet = rms < threshold * 0.65;
  if (!isQuiet) {
    silenceStart = null;
    return;
  }

  if (silenceStart === null) silenceStart = now;
  if (now - silenceStart < silenceStopMs()) return;

  if (now - startedAt >= minTakeMs()) {
    stopCaptureAndUpload();
  } else {
    silenceStart = now;
  }
}
|
||||||
|
|
||||||
|
// Begin a take: reset the capture state and attach a ScriptProcessor that
// copies every incoming audio frame into floatChunks while `capturing` is set.
async function startCapture() {
  floatChunks = [];
  silenceStart = null;
  startedAt = performance.now();
  capturing = true;

  setPill($("takeState"), "Recording…", "warn");

  const node = audioCtx.createScriptProcessor(frameSize, 1, 1);
  source.connect(node);
  node.connect(audioCtx.destination);

  node.onaudioprocess = (ev) => {
    if (capturing) {
      // Copy the frame rather than keeping a reference to the channel data
      const frame = ev.inputBuffer.getChannelData(0);
      floatChunks.push(new Float32Array(frame));
    }
  };

  // Kept on window so the stop/teardown paths can find and disconnect it
  window.__mw_proc = node;
}
|
||||||
|
|
||||||
|
// Finish the current take: detach the capture node, merge the captured
// frames, encode them as a 16 kHz mono WAV and POST it to /api/upload_take.
// After a successful upload the speaker/take state machine advances (next
// take, next speaker, or auto-start training once everything is recorded).
// If encoding or upload fails, the take counter is rolled back, the run is
// paused and the user is alerted.
async function stopCaptureAndUpload() {
  capturing = false;
  setPill($("takeState"), "Processing…");

  const proc = window.__mw_proc;
  if (proc) {
    try { proc.disconnect(); } catch {}
    try { source.disconnect(proc); } catch {}
    window.__mw_proc = null;
  }

  // Count the take up front so the UI shows the take being processed; the
  // catch block undoes this when the take could not be saved.
  currentTake += 1;
  refreshUI();

  // Set once the server has accepted the take, so that a later error in the
  // UI bookkeeping below never rolls back a take that was really stored.
  let saved = false;

  try {
    // Merging/encoding happens inside the try so that a failure here is
    // reported like an upload failure instead of leaving the UI stuck on
    // "Processing…" with an unhandled rejection.
    let totalLen = 0;
    for (const c of floatChunks) totalLen += c.length;
    const merged = new Float32Array(totalLen);
    let off = 0;
    for (const c of floatChunks) { merged.set(c, off); off += c.length; }

    const wavBlob = await floatToWav16kMono(merged, audioCtx.sampleRate);

    setPill($("status"), `Uploading speaker ${currentSpeaker} take ${currentTake}…`, "warn");

    const fd = new FormData();
    fd.append("speaker_index", String(currentSpeaker));
    fd.append("take_index", String(currentTake));
    fd.append("file", wavBlob, `take_${String(currentTake).padStart(2,"0")}.wav`);

    await api("/api/upload_take", { method:"POST", body: fd });
    saved = true;

    $("takesList").textContent = `Saved ${currentTake}/${takesPerSpeaker} takes for speaker ${currentSpeaker}/${speakersTotal}`;
    setPill($("status"), `Saved speaker ${currentSpeaker} take ${currentTake}/${takesPerSpeaker}`, "ok");

    if (currentTake >= takesPerSpeaker) {
      if (currentSpeaker >= speakersTotal) {
        // Last take of the last speaker: release the mic and start training
        setPill($("takeState"), "Done", "ok");
        setPill($("speakerState"), "All speakers done ✅", "ok");
        setPill($("status"), "All takes recorded ✅", "ok");

        await stopMicNow();
        await autoStartTraining();
        return;
      }

      // Speaker finished: pause until the next speaker clicks Begin recording
      currentSpeaker += 1;
      currentTake = 0;
      refreshUI();

      setPill($("speakerState"), `Speaker ${currentSpeaker - 1} complete ✅`, "ok");
      setPill($("takeState"), "Paused", "warn");
      setPill($("status"), `Ready for speaker ${currentSpeaker}. Click Begin recording.`, "warn");

      isRunning = false;
      $("beginBtn").disabled = false;

      await stopMicNow();
      return;
    }

    setPill($("speakerState"), `Speaker ${currentSpeaker}/${speakersTotal}`);
    setPill($("takeState"), "Listening…", "ok");

  } catch (e) {
    console.error(e);
    if (!saved) {
      // The take never reached the server: give its number back so the
      // retry reuses it and the counter matches what is actually stored.
      currentTake = Math.max(0, currentTake - 1);
      refreshUI();
    }
    setPill($("status"), "Upload failed", "err");
    setPill($("takeState"), "Error", "err");
    isRunning = false;
    $("beginBtn").disabled = false;
    alert("Upload failed: " + e.message);
  }
}
|
||||||
|
|
||||||
|
// -------------------- WAV encoding helpers --------------------
|
||||||
|
// Resample mono float samples from `srcRate` to 16 kHz by rendering them
// through an OfflineAudioContext, then wrap the result in a PCM16 WAV Blob.
async function floatToWav16kMono(float32, srcRate) {
  const TARGET_RATE = 16000;

  const inputBuffer = audioCtx.createBuffer(1, float32.length, srcRate);
  inputBuffer.copyToChannel(float32, 0);

  // At least one output frame, even for a very short input
  const outFrames = Math.max(1, Math.round(float32.length * TARGET_RATE / srcRate));
  const offlineCtx = new OfflineAudioContext(1, outFrames, TARGET_RATE);

  const player = offlineCtx.createBufferSource();
  player.buffer = inputBuffer;
  player.connect(offlineCtx.destination);
  player.start(0);

  const rendered = await offlineCtx.startRendering();
  const wavBytes = encodeWavPCM16(rendered.getChannelData(0), TARGET_RATE);
  return new Blob([wavBytes], { type: "audio/wav" });
}
|
||||||
|
|
||||||
|
// Encode mono float samples (expected in [-1, 1], clamped otherwise) as a
// little-endian 16-bit PCM WAV file and return it as an ArrayBuffer
// (44-byte RIFF/WAVE header followed by the sample data).
function encodeWavPCM16(float32, sampleRate) {
  const HEADER_BYTES = 44;
  const BYTES_PER_SAMPLE = 2;
  const sampleCount = float32.length;
  const dataBytes = sampleCount * BYTES_PER_SAMPLE;

  const buffer = new ArrayBuffer(HEADER_BYTES + dataBytes);
  const view = new DataView(buffer);

  const putTag = (at, tag) => {
    for (let k = 0; k < tag.length; k += 1) {
      view.setUint8(at + k, tag.charCodeAt(k));
    }
  };

  // RIFF container
  putTag(0, "RIFF");
  view.setUint32(4, 36 + dataBytes, true);              // file size minus 8
  putTag(8, "WAVE");

  // "fmt " chunk
  putTag(12, "fmt ");
  view.setUint32(16, 16, true);                         // fmt chunk size
  view.setUint16(20, 1, true);                          // format 1 = PCM
  view.setUint16(22, 1, true);                          // channels
  view.setUint32(24, sampleRate, true);
  view.setUint32(28, sampleRate * BYTES_PER_SAMPLE, true); // byte rate
  view.setUint16(32, BYTES_PER_SAMPLE, true);           // block align
  view.setUint16(34, 16, true);                         // bits per sample

  // "data" chunk
  putTag(36, "data");
  view.setUint32(40, dataBytes, true);

  for (let i = 0, pos = HEADER_BYTES; i < sampleCount; i += 1, pos += BYTES_PER_SAMPLE) {
    const clamped = Math.max(-1, Math.min(1, float32[i]));
    // Negative and positive halves scale differently so both -1 and +1 fit int16
    const scaled = clamped < 0 ? clamped * 0x8000 : clamped * 0x7fff;
    view.setInt16(pos, scaled, true);
  }
  return buffer;
}
|
||||||
|
|
||||||
|
// -------------------- training (manual + auto) --------------------
|
||||||
|
// Start a training run. Reads the session from the server first; when no
// personal takes were received the user must confirm training without them
// (declining returns without touching the UI). Locks the action buttons,
// resets the log state, POSTs /api/train and starts the log poller. On
// failure the buttons are unlocked again and the error is rethrown.
async function startTrainingWithPrompt(auto=false) {
  const sess = await api("/api/session", { method: "GET" });
  const takesReceived = sess.takes_received || 0;
  const total = (sess.speakers_total || 1) * (sess.takes_per_speaker || 10);

  const allowNoPersonal = (takesReceived === 0);
  if (allowNoPersonal) {
    const proceed = confirm(
      `No personal voice samples recorded (0/${total}).\n\nTrain anyway WITHOUT personal voices?`
    );
    if (!proceed) return;
  }

  const setControlsDisabled = (disabled) => {
    for (const id of ["trainBtn", "beginBtn", "resetBtn"]) {
      $(id).disabled = disabled;
    }
  };

  // lock UI immediately
  setControlsDisabled(true);

  const banner = auto ? "Auto-starting training…" : "Preparing training environment…";
  setPill($("status"), banner, "warn");

  // Fresh run: clear everything the poller tracks
  trainingPollAbort = false;
  logBuffer = "";
  lastChunk = "";
  seenAnyOutput = false;
  $("trainLog").textContent = "(preparing…)\n";

  try {
    await api("/api/train", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ allow_no_personal: allowNoPersonal })
    });

    // Polling begins only once the server accepted the training request
    if (!trainingPollRunning) {
      trainingPollRunning = true;
      pollTrainingTail();
    }

    setPill($("status"), "Training running…", "warn");
  } catch (e) {
    setControlsDisabled(false);
    trainingPollAbort = true;
    trainingPollRunning = false;
    throw e;
  }
}
|
||||||
|
|
||||||
|
// Kick off training automatically (after the last take); failures are
// logged, shown in the status pill and alerted instead of being rethrown.
async function autoStartTraining() {
  await startTrainingWithPrompt(true).catch((err) => {
    console.error(err);
    setPill($("status"), "Auto-train failed", "err");
    alert("Auto-start training failed: " + err.message);
  });
}
|
||||||
|
|
||||||
|
// Manual "Start training" button: same flow as the automatic start, with the
// failure surfaced through an alert followed by the status pill.
$("trainBtn").addEventListener("click", async function onTrainClick() {
  await startTrainingWithPrompt(false).catch((err) => {
    alert("Train failed: " + err.message);
    setPill($("status"), "Train failed", "err");
  });
});
|
||||||
|
|
||||||
|
// Poll /api/train_status once per second and append the newly reported log
// output to the training log element. The loop ends when trainingPollAbort is
// set, or when the server reports the run as finished (not running and an
// exit code present), at which point the action buttons are re-enabled.
async function pollTrainingTail() {
  const logEl = $("trainLog");

  for (;;) {
    if (trainingPollAbort) {
      trainingPollRunning = false;
      break;
    }

    try {
      const st = await api(`/api/train_status?ts=${Date.now()}`, { method:"GET", cache:"no-store" });
      const tr = st.training || {};

      // NOTE: this assumes /api/train_status returns NEW output chunks (not full tail snapshots)
      const chunkRaw = tr.log_text || "";
      // Chunks are appended back to back, so each one must end with a newline;
      // otherwise the last line of one poll and the first line of the next
      // would run together on a single line.
      const chunk = (chunkRaw && !chunkRaw.endsWith("\n")) ? chunkRaw + "\n" : chunkRaw;

      if (chunk) {
        // wipe placeholder once
        if (!seenAnyOutput) {
          logEl.textContent = "";
          logBuffer = "";
          lastChunk = "";
          seenAnyOutput = true;
        }

        // simple de-dupe: if server repeats the same chunk, skip it
        // NOTE(review): this also drops genuinely repeated output that arrives
        // in two consecutive polls — confirm that is acceptable.
        if (chunk !== lastChunk) {
          lastChunk = chunk;
          logBuffer += chunk;
          appendLogAutoScroll(logEl, chunk);
        }
      } else {
        // before first output, show waiting message but do NOT overwrite later scrollback
        if (!seenAnyOutput) {
          if (!logEl.textContent || logEl.textContent.includes("(no training") || logEl.textContent.startsWith("(preparing…")) {
            logEl.textContent = "Waiting for training output…\n";
          }
        }
      }

      const exitCodeIsSet = (tr.exit_code !== null && tr.exit_code !== undefined);

      if (!tr.running && exitCodeIsSet) {
        $("trainBtn").disabled = false;
        $("resetBtn").disabled = false;
        $("beginBtn").disabled = false;

        if (tr.exit_code === 0) setPill($("status"), "Training finished ✅", "ok");
        else setPill($("status"), `Training ended (exit=${tr.exit_code})`, "err");

        trainingPollRunning = false;
        break;
      }
    } catch (e) {
      // ignore transient polling errors
    }

    await new Promise(r => setTimeout(r, 1000));
  }
}
|
||||||
|
|
||||||
|
// -------------------- session + UI wiring --------------------
|
||||||
|
// Read the wake-word phrase aloud with the browser's speech synthesis.
$("ttsBtn").addEventListener("click", function () {
  const text = ($("phrase").value || "").trim();
  if (text.length === 0) {
    return; // nothing typed, nothing to speak
  }

  const utterance = new SpeechSynthesisUtterance(text);
  // Drop anything already queued or speaking before starting the new utterance.
  speechSynthesis.cancel();
  speechSynthesis.speak(utterance);
});
|
||||||
|
|
||||||
|
// Start a new recording session for the entered wake-word phrase: validate the
// inputs, POST them to /api/start_session, then reset the recording counters,
// the training-log state and the buttons/pills for the new session.
$("startSessionBtn").addEventListener("click", async () => {
  const phrase = ($("phrase").value || "").trim();
  if (!phrase) { alert("Enter a wake word phrase first."); return; }

  // Parse into locals and validate before touching the shared counters, so an
  // entry such as "0", "-2" or "abc" never ends up in them (or in the request
  // body) as 0, a negative number or NaN.
  const speakers = parseInt($("speakersTotal").value || "1", 10);
  const takes = parseInt($("takesPerSpeaker").value || "10", 10);
  if (!(speakers >= 1) || !(takes >= 1)) {
    alert("Speakers and takes per speaker must be whole numbers of at least 1.");
    return;
  }
  speakersTotal = speakers;
  takesPerSpeaker = takes;

  try {
    setPill($("sessionPill"), "Starting…", "warn");
    const data = await api("/api/start_session", {
      method: "POST",
      headers: {"Content-Type":"application/json"},
      body: JSON.stringify({ phrase, speakers_total: speakersTotal, takes_per_speaker: takesPerSpeaker })
    });

    session = data;

    // A fresh session always starts at the first speaker with no takes.
    currentSpeaker = 1;
    currentTake = 0;

    $("takesList").textContent = "";
    $("trainLog").textContent = "(no training started)";

    // Stop any previous poll loop cleanly and forget its log state.
    // NOTE(review): the abort flag is cleared again in `finally` below; whether
    // a poll loop that is sleeping at that moment ever observes it depends on
    // the loop's own check, which is outside this handler — confirm.
    trainingPollAbort = true;
    trainingPollRunning = false;
    logBuffer = "";
    lastChunk = "";
    seenAnyOutput = false;

    refreshUI();

    await stopMicNow();

    setPill($("sessionPill"), `Session: ${data.safe_word}`, "ok");
    $("beginBtn").disabled = false;
    $("resetBtn").disabled = false;
    $("trainBtn").disabled = false;
    $("ttsBtn").disabled = false;

    setPill($("status"), "Ready", "ok");
    setPill($("speakerState"), "Waiting");
    setPill($("takeState"), "Not recording");
  } catch (e) {
    console.error(e);
    setPill($("sessionPill"), "Session failed", "err");
    alert("Start session failed: " + e.message);
  } finally {
    // allow a new poll loop to start later
    trainingPollAbort = false;
  }
});
|
||||||
|
|
||||||
|
// Ask the server to reset the recordings (POST /api/reset_recordings) and, once
// that succeeds, rewind the local speaker/take counters and the takes list.
$("resetBtn").addEventListener("click", async function () {
  try {
    await api("/api/reset_recordings", { method: "POST" });

    // Local state is only rewound after the request has succeeded.
    currentSpeaker = 1;
    currentTake = 0;
    $("takesList").textContent = "";
    refreshUI();
    setPill($("status"), "Recordings reset", "ok");
  } catch (err) {
    alert(`Reset failed: ${err.message}`);
  }
});
|
||||||
|
|
||||||
|
// Begin recording for the active session: requires a started session and a
// working microphone, then switches the UI into the listening state.
$("beginBtn").addEventListener("click", async function () {
  if (!session) {
    alert("Start a session first.");
    return;
  }

  // Without microphone access there is nothing to record, so stop here.
  try {
    await ensureMic();
  } catch (err) {
    alert(`Mic permission failed: ${err.message}`);
    return;
  }

  $("takesList").textContent = "";
  refreshUI();

  isRunning = true;
  $("beginBtn").disabled = true;

  const speakerLabel = `Speaker ${currentSpeaker}/${speakersTotal}`;
  setPill($("speakerState"), speakerLabel);
  setPill($("status"), "Listening… say the wake word now", "ok");
  setPill($("takeState"), "Listening…", "ok");
});
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
130
train_wake_word
Normal file
130
train_wake_word
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
#!/bin/bash
# train_wake_word: run sample generation, augmentation and training for one
# wake word, then print a timing summary.
set -e

# Locate the cli/ helper directory next to this script.
PROGPATH=$(realpath "$0")
PROGDIR=$(dirname "${PROGPATH}")
CLIDIR="${PROGDIR}/cli"

# Options this script accepts; KNOWN_ARGS is set before shell.functions is
# sourced. NOTE(review): presumably shell.functions parses the command line
# against this list and fills POSITIONAL_ARGS / UNKNOWN_ARGS / HELP — confirm.
KNOWN_ARGS=( samples batch-size training-steps data-dir cleanup-work-dir )
. "${CLIDIR}/shell.functions"

# The first positional argument is the wake word itself.
WAKE_WORD=${POSITIONAL_ARGS[0]}

# Report unrecognised arguments and fall through to the usage text.
if (( ${#UNKNOWN_ARGS[@]} > 0 )) ; then
  printf '%s\n' "Unknown argument(s): ${UNKNOWN_ARGS[*]}" >&2
  HELP=true
fi
|
||||||
|
|
||||||
|
# Print the usage text on stderr and stop with status 1 when help was requested
# or the required wake word is missing.
# The DEFAULT_* values are not set in this script.
# NOTE(review): presumably they come from the sourced shell.functions — confirm.
if [ "${HELP}" == "true" ] || [ -z "${WAKE_WORD}" ] ; then
  cat <<EOF >&2
Usage: train_wake_word [ --samples=<samples> ] [ --batch-size=<batch_size> ]
[ --training-steps=<steps> ] [ --data-dir=<data_dir> ] [ --cleanup-work-dir ]
<wake_word> [ <wake_word_title> ]

Options:
--samples: The number of samples to generate for the wake word.
Default: ${DEFAULT_SAMPLES}

--batch-size: How many samples should be generated at a time. The more
samples per batch, the more memory is needed.
Default: ${DEFAULT_BATCH_SIZE}

--training-steps: Number of training steps. More training steps means better
detection and false positive rates but also more time to train.
Default: ${DEFAULT_TRAINING_STEPS}

--data-dir: The directory holding the Python virtual environment (.venv)
and the work directory used for samples and trained models.

--cleanup-work-dir: Delete the intermediate files under <data_dir>/work
after successful training.
Default: false

<wake_word> The word to train spelled phonetically.
Required.

<wake_word_title> An optional pretty name to save to the json metadata file.
Default: The wake word with individual words capitalized
and punctuation removed.

EOF
  exit 1
fi
|
||||||
|
|
||||||
|
# Activate the Python virtual environment that lives inside the data directory.
# shellcheck source=/dev/null
. "${DATA_DIR}/.venv/bin/activate"

# Run from the data directory and make sure its work/ subdirectory exists;
# a failing mkdir is ignored here.
cd "${DATA_DIR}"
mkdir -p "${DATA_DIR}/work" || true
|
||||||
|
|
||||||
|
#######################################
# Decide the pretty title stored with the trained wake word.
#   1. If exactly two positional arguments were given, the second one is the
#      title.
#   2. If no title is set at all, derive one from the wake word: every run of
#      non-alphanumeric characters becomes a word break and each word gets a
#      capital first letter ("hey_jarvis" -> "Hey Jarvis").
#   3. If the title is set but empty, fall back to the wake word unchanged.
# Globals:   POSITIONAL_ARGS (read), WAKE_WORD (read), WAKE_WORD_TITLE (written)
# Returns:   0
#######################################
resolve_wake_word_title() {
  local -a words

  # ${#POSITIONAL_ARGS[@]} is the number of arguments. Without [@] the
  # expansion is the string length of element 0, so the title was only picked
  # up when the wake word happened to be two characters long.
  if [ "${#POSITIONAL_ARGS[@]}" -eq 2 ] ; then
    WAKE_WORD_TITLE="${POSITIONAL_ARGS[1]}"
  fi

  if [ ! -v WAKE_WORD_TITLE ] ; then
    # Split on whitespace after turning every non-alphanumeric into a space.
    read -r -a words <<< "${WAKE_WORD//[^a-zA-Z0-9]/ }"
    WAKE_WORD_TITLE="${words[*]^}"
  elif [ -z "${WAKE_WORD_TITLE}" ] ; then
    WAKE_WORD_TITLE="${WAKE_WORD}"
  fi
}

resolve_wake_word_title
|
||||||
|
|
||||||
|
# Banner: a rule of 80 '=' characters, the run title and the CUDA report.
printf "%-80s\n" "=" | tr ' ' "="
echo "===== Running '${WAKE_WORD}(${WAKE_WORD_TITLE})' generation, augmentation and training ====="
"${CLIDIR}/cudainfo"
echo
START_TS=$EPOCHSECONDS

# Environment exported to the generation, augmentation and training tools.
export TF_CPP_MIN_LOG_LEVEL=9
export TF_FORCE_GPU_ALLOW_GROWTH=true
export TF_GPU_ALLOCATOR=cuda_malloc_async
export TF_XLA_FLAGS="--tf_xla_auto_jit=0"
export NVIDIA_TF32_OVERRIDE=1
export TF_CUDNN_WORKSPACE_LIMIT_IN_MB=512
export GLOG_minloglevel=2
export GRPC_VERBOSITY=ERROR

# Step 1: generate the wake word samples.
# Option values are quoted so each one always reaches the tool as one argument.
"${CLIDIR}/wake_word_sample_generator" \
  --samples="${SAMPLES}" \
  --batch-size="${BATCH_SIZE}" \
  --data-dir="${DATA_DIR}" "${WAKE_WORD}"

POST_GEN_TS=$EPOCHSECONDS

# Step 2: augment the samples, but only when needed.
AUGMENT=false
GENERATED_DIR="${DATA_DIR}/work/wake_word_samples"
AUGMENTED_DIR="${DATA_DIR}/work/wake_word_samples_augmented"

# Augment when there is no augmented directory yet ...
if [ ! -d "${AUGMENTED_DIR}" ] ; then
  AUGMENT=true
fi
# ... or when the first generated sample is newer than the augmented data
# (-nt is also true when the augmented data file does not exist).
if [ "${GENERATED_DIR}/0.wav" -nt "${AUGMENTED_DIR}/testing/wakeword_mmap/data.ninja" ] ; then
  AUGMENT=true
fi

if [ "${AUGMENT}" == "true" ] ; then
  # Start from an empty directory; errors from rm/mkdir are ignored here.
  rm -rf "${AUGMENTED_DIR}" || :
  mkdir -p "${AUGMENTED_DIR}" || :
  # On failure remove the partial output so the next run augments again.
  python -u "${CLIDIR}/wake_word_sample_augmenter" --data-dir="${DATA_DIR}" || { rm -rf "${AUGMENTED_DIR}" ; exit 1 ; }
else
  echo "Augmentation not required"
  echo
fi

POST_AUGMENT_TS=$EPOCHSECONDS

# Step 3: train the model.
"${CLIDIR}/wake_word_sample_trainer" \
  --samples="${SAMPLES}" \
  --training-steps="${TRAINING_STEPS}" \
  --data-dir="${DATA_DIR}" \
  "${WAKE_WORD}" "${WAKE_WORD_TITLE}"
|
||||||
|
|
||||||
|
#######################################
# Remove the intermediate training artifacts under ${DATA_DIR}/work, but only
# when cleanup was explicitly requested.
# Globals:   CLEANUP_WORK_DIR (read), DATA_DIR (read)
# Returns:   0 always; removal errors are deliberately ignored.
#######################################
cleanup_work_dir_if_requested() {
  # Compare against the literal "true" (as the HELP check does) instead of
  # executing the variable as a command: an unset or empty value would run as
  # an empty command, which succeeds and would trigger the deletion.
  [ "${CLEANUP_WORK_DIR}" == "true" ] || return 0

  rm -rf \
    "${DATA_DIR}/work/trained_models" \
    "${DATA_DIR}/work/wake_word_samples" \
    "${DATA_DIR}/work/wake_word_samples_augmented" \
    "${DATA_DIR}/work/personal_augmented_features" \
    "${DATA_DIR}/work/last_wake_word" || :
}

cleanup_work_dir_if_requested
|
||||||
|
|
||||||
|
END_TS=$EPOCHSECONDS

#######################################
# Print a rule of '=' characters, right-aligned in a field, without starting
# an external interpreter.
# Arguments: $1 - number of '=' characters
#            $2 - field width (optional, defaults to $1)
# Outputs:   the rule followed by a newline on stdout
#######################################
print_rule() {
  local bar
  printf -v bar '%*s' "$1" ''
  printf '%*s\n' "${2:-$1}" "${bar// /=}"
}

# Summary: system report followed by the elapsed time of each phase.
print_rule 80
printf "%44s\n\n" "Training Summary"
"${CLIDIR}/system_summary"
echo
print_elapsed_time --no-separators "${START_TS}" "${POST_GEN_TS}" "Generate ${SAMPLES} samples, ${BATCH_SIZE}/batch"
print_elapsed_time --no-separators "${POST_GEN_TS}" "${POST_AUGMENT_TS}" "Augment ${SAMPLES} samples"
print_elapsed_time --no-separators "${POST_AUGMENT_TS}" "${END_TS}" "${TRAINING_STEPS} training steps"
# 54 '=' characters right-aligned to column 80, above the total line.
print_rule 54 80
print_elapsed_time --no-separators "${START_TS}" "${END_TS}" "Total"
print_rule 80
|
||||||
Reference in New Issue
Block a user