mirror of
https://github.com/TaterTotterson/microWakeWord-Trainer-Nvidia-Docker.git
synced 2026-06-12 20:10:19 -06:00
cli + web recorder ui
This commit is contained in:
@@ -67,42 +67,66 @@ find_rev() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
converter() {
|
converter() {
|
||||||
source ${DATA_DIR}/.venv/bin/activate
|
# shellcheck source=/dev/null
|
||||||
|
source "${DATA_DIR}/.venv/bin/activate"
|
||||||
|
|
||||||
python - "${AUDIO_DIR}" "${AUDIO16K_DIR}" <<-EOF
|
python - "${AUDIO_DIR}" "${AUDIO16K_DIR}" <<-EOF
|
||||||
import os, sys, subprocess, scipy.io.wavfile, numpy as np
|
import os, sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import soundfile as sf
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import scipy.io.wavfile
|
||||||
import librosa
|
import librosa
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
def write_wav(dst: Path, data: np.ndarray, sr: int):
|
def write_wav(dst: Path, data: np.ndarray, sr: int):
|
||||||
|
dst.parent.mkdir(parents=True, exist_ok=True)
|
||||||
x = np.clip(data, -1.0, 1.0)
|
x = np.clip(data, -1.0, 1.0)
|
||||||
scipy.io.wavfile.write(dst, sr, (x * 32767).astype(np.int16))
|
scipy.io.wavfile.write(dst, sr, (x * 32767).astype(np.int16))
|
||||||
|
|
||||||
audioset_dir = Path(sys.argv[1])
|
audioset_dir = Path(sys.argv[1])
|
||||||
audioset_out = Path(sys.argv[2])
|
audioset_out = Path(sys.argv[2])
|
||||||
|
|
||||||
# convert FLAC → 16k mono WAV
|
|
||||||
flacs = list(audioset_dir.rglob("*.flac"))
|
flacs = list(audioset_dir.rglob("*.flac"))
|
||||||
print(f" FLAC files: {len(flacs)}")
|
total = len(flacs)
|
||||||
|
print(f" FLAC files: {total}")
|
||||||
|
print(" Converting AudioSet → 16k mono WAV")
|
||||||
|
print(" Sit tight — this step can take a while.")
|
||||||
|
print("")
|
||||||
|
|
||||||
audioset_bad = []
|
audioset_bad = []
|
||||||
ok = 0
|
ok = 0
|
||||||
for p in tqdm(flacs, desc=" AudioSet→WAV (resample 16k mono)"):
|
skipped = 0
|
||||||
|
|
||||||
|
START = datetime.now(timezone.utc).replace(microsecond=0)
|
||||||
|
|
||||||
|
# Heartbeat interval (prints every N files)
|
||||||
|
HEARTBEAT_EVERY = 500
|
||||||
|
|
||||||
|
for idx, p in enumerate(flacs, start=1):
|
||||||
try:
|
try:
|
||||||
outfile = Path(audioset_out / (p.stem + ".wav"))
|
outfile = audioset_out / (p.stem + ".wav")
|
||||||
if outfile.exists():
|
if outfile.exists():
|
||||||
continue
|
skipped += 1
|
||||||
y, _ = librosa.load(p, sr=16000, mono=True)
|
else:
|
||||||
if y.size == 0:
|
y, _ = librosa.load(p, sr=16000, mono=True)
|
||||||
raise ValueError("empty audio")
|
if y.size == 0:
|
||||||
write_wav(outfile, y, 16000)
|
raise ValueError("empty audio")
|
||||||
ok += 1
|
write_wav(outfile, y, 16000)
|
||||||
|
ok += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
audioset_bad.append(f"{p}:{e}")
|
audioset_bad.append(f"{p}:{e}")
|
||||||
|
|
||||||
|
if idx == 1 or (idx % HEARTBEAT_EVERY) == 0 or idx == total:
|
||||||
|
print(f" Progress: {idx}/{total} (ok={ok}, skipped={skipped}, failed={len(audioset_bad)})")
|
||||||
|
|
||||||
if audioset_bad:
|
if audioset_bad:
|
||||||
(audioset_out / "audioset_corrupted_files.log").write_text("\n".join(audioset_bad))
|
(audioset_out / "audioset_corrupted_files.log").write_text("\n".join(audioset_bad))
|
||||||
print(f" AudioSet complete ({ok} ok, {len(audioset_bad)} failed)")
|
|
||||||
|
END = datetime.now(timezone.utc).replace(microsecond=0)
|
||||||
|
elapsed = END - START
|
||||||
|
print("")
|
||||||
|
print(f" AudioSet complete ({ok} ok, {skipped} skipped, {len(audioset_bad)} failed) Elapsed: {elapsed}")
|
||||||
EOF
|
EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,13 +134,15 @@ expected_filecount=$(get_total_filecount filecounts)
|
|||||||
actual_filecount=$(find "${AUDIO16K_DIR}" -name "*.wav" 2>/dev/null | wc -l) || :
|
actual_filecount=$(find "${AUDIO16K_DIR}" -name "*.wav" 2>/dev/null | wc -l) || :
|
||||||
write_filecount=false
|
write_filecount=false
|
||||||
|
|
||||||
if [ "${actual_filecount}" -ne 0 ] && [ "${actual_filecount}" -eq "${expected_filecount}" ] ; then
|
# If any output WAVs already exist, skip download/extract/convert entirely (the count is not checked against the expected total)
|
||||||
echo " Existing Audioset valid"
|
if [ "${actual_filecount}" -ne 0 ] ; then
|
||||||
|
echo " Existing ${AUDIO16K_DIR} present (${actual_filecount} wav); skipping extract/convert"
|
||||||
else
|
else
|
||||||
dl=$(find_rev)
|
dl=$(find_rev)
|
||||||
[ -n "$dl" ] || { echo " Could not locate an AudioSet revision with FLAC tarballs still present on HF." ; exit 1 ; }
|
[ -n "$dl" ] || { echo " Could not locate an AudioSet revision with FLAC tarballs still present on HF." ; exit 1 ; }
|
||||||
rev=${dl%%,*}
|
rev=${dl%%,*}
|
||||||
pattern=${dl##*,}
|
pattern=${dl##*,}
|
||||||
|
|
||||||
echo " Checking 10 tarballs"
|
echo " Checking 10 tarballs"
|
||||||
for i in {0..9} ; do
|
for i in {0..9} ; do
|
||||||
fname="downloads/bal_train0${i}.tar"
|
fname="downloads/bal_train0${i}.tar"
|
||||||
@@ -137,17 +163,16 @@ else
|
|||||||
rm -rf "${fname}"
|
rm -rf "${fname}"
|
||||||
fi
|
fi
|
||||||
done
|
done
|
||||||
|
|
||||||
rm -rf "${AUDIO16K_DIR}/audioset_corrupted_files.log" || :
|
rm -rf "${AUDIO16K_DIR}/audioset_corrupted_files.log" || :
|
||||||
converter
|
converter
|
||||||
if [ -f "${AUDIO16K_DIR}/audioset_corrupted_files.log" ] ; then
|
|
||||||
failed=$(cat "${AUDIO16K_DIR}/audioset_corrupted_files.log" | wc -l)
|
# Recompute counts and warn (but do not fail)
|
||||||
filecounts[failed]=-${failed}
|
|
||||||
fi
|
|
||||||
expected_filecount=$(get_total_filecount filecounts)
|
expected_filecount=$(get_total_filecount filecounts)
|
||||||
actual_filecount=$(find ${AUDIO16K_DIR} -name "*.wav" 2>/dev/null | wc -l) || :
|
actual_filecount=$(find "${AUDIO16K_DIR}" -name "*.wav" 2>/dev/null | wc -l) || :
|
||||||
if [ "${actual_filecount}" -ne "${expected_filecount}" ] ; then
|
if [ "${actual_filecount}" -ne "${expected_filecount}" ] ; then
|
||||||
echo " Converted file count(${actual_filecount}) != expected file count(${expected_filecount})" >&2
|
echo " Converted file count(${actual_filecount}) != expected file count(${expected_filecount})" >&2
|
||||||
exit 1
|
echo " WARNING: mismatch is expected if some AudioSet files are corrupted; continuing." >&2
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
@@ -171,5 +196,4 @@ if "${CLEANUP_INTERMEDIATE_FILES}" && [ -d "${AUDIO_DIR}" ] ; then
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
echo " Audioset complete"
|
echo " Audioset complete"
|
||||||
exit 0
|
exit 0
|
||||||
|
|
||||||
@@ -8,9 +8,9 @@ if [ ! -v DATA_DIR ] ; then
|
|||||||
[ -f .mww-data-dir ] && DATA_DIR="${PWD}" || DATA_DIR="/data"
|
[ -f .mww-data-dir ] && DATA_DIR="${PWD}" || DATA_DIR="/data"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
DEFAULT_SAMPLES=20000
|
DEFAULT_SAMPLES=50000
|
||||||
DEFAULT_BATCH_SIZE=100
|
DEFAULT_BATCH_SIZE=100
|
||||||
DEFAULT_TRAINING_STEPS=25000
|
DEFAULT_TRAINING_STEPS=40000
|
||||||
|
|
||||||
[ -f "${DATA_DIR}/.defaults.env" ] && source "${DATA_DIR}/.defaults.env" || :
|
[ -f "${DATA_DIR}/.defaults.env" ] && source "${DATA_DIR}/.defaults.env" || :
|
||||||
|
|
||||||
|
|||||||
@@ -71,17 +71,16 @@ if not files:
|
|||||||
max_samples = len(files)
|
max_samples = len(files)
|
||||||
|
|
||||||
print(f"\n===== Augmenting {max_samples} wake word samples =====")
|
print(f"\n===== Augmenting {max_samples} wake word samples =====")
|
||||||
|
|
||||||
print(" Initializing libraries")
|
print(" Initializing libraries")
|
||||||
|
|
||||||
os.environ["TF_CPP_MIN_LOG_LEVEL"]="3"
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||||
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"]="true"
|
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
|
||||||
os.environ["TF_GPU_ALLOCATOR"]="cuda_malloc_async"
|
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"
|
||||||
os.environ["TF_XLA_FLAGS"]="--tf_xla_auto_jit=0"
|
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0"
|
||||||
os.environ["NVIDIA_TF32_OVERRIDE"]="1"
|
os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
|
||||||
os.environ["TF_CUDNN_WORKSPACE_LIMIT_IN_MB"]="512"
|
os.environ["TF_CUDNN_WORKSPACE_LIMIT_IN_MB"] = "512"
|
||||||
os.environ["GLOG_minloglevel"]="9"
|
os.environ["GLOG_minloglevel"] = "9"
|
||||||
os.environ["GRPC_VERBOSITY"]="ERROR"
|
os.environ["GRPC_VERBOSITY"] = "ERROR"
|
||||||
|
|
||||||
print(" Loading Tensorflow")
|
print(" Loading Tensorflow")
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@@ -98,6 +97,7 @@ gc.collect()
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import librosa
|
import librosa
|
||||||
|
from tqdm import tqdm
|
||||||
from mmap_ninja.ragged import RaggedMmap
|
from mmap_ninja.ragged import RaggedMmap
|
||||||
from microwakeword.audio.augmentation import Augmentation
|
from microwakeword.audio.augmentation import Augmentation
|
||||||
from microwakeword.audio.clips import Clips
|
from microwakeword.audio.clips import Clips
|
||||||
@@ -108,7 +108,7 @@ START_TIME = datetime.now(timezone.utc).replace(microsecond=0)
|
|||||||
|
|
||||||
# Paths to augmented data
|
# Paths to augmented data
|
||||||
impulse_paths = [ args.mit_rirs_16k_dir ]
|
impulse_paths = [ args.mit_rirs_16k_dir ]
|
||||||
background_paths = [ args.fma_16k_dir, args.audioset_16k_dir]
|
background_paths = [ args.fma_16k_dir, args.audioset_16k_dir ]
|
||||||
|
|
||||||
clips = Clips(
|
clips = Clips(
|
||||||
input_directory=args.input_dir,
|
input_directory=args.input_dir,
|
||||||
@@ -139,8 +139,6 @@ augmenter = Augmentation(
|
|||||||
max_jitter_s=0.3,
|
max_jitter_s=0.3,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Augment samples and save the training, validation, and testing sets.
|
|
||||||
|
|
||||||
def audio_generator_from_wavs(self, split="train", repeat=1):
|
def audio_generator_from_wavs(self, split="train", repeat=1):
|
||||||
"""
|
"""
|
||||||
Yield 1-D float32 arrays loaded via librosa from input_dir/*.wav.
|
Yield 1-D float32 arrays loaded via librosa from input_dir/*.wav.
|
||||||
@@ -175,7 +173,7 @@ def audio_generator_from_wavs(self, split="train", repeat=1):
|
|||||||
# Bind the patched generator to your existing `clips` instance
|
# Bind the patched generator to your existing `clips` instance
|
||||||
clips.audio_generator = types.MethodType(audio_generator_from_wavs, clips)
|
clips.audio_generator = types.MethodType(audio_generator_from_wavs, clips)
|
||||||
|
|
||||||
# ---- Split config (same as before) ----
|
# ---- Split config ----
|
||||||
split_cfg = {
|
split_cfg = {
|
||||||
"training": {"name": "train", "repetition": 2, "slide_frames": 10},
|
"training": {"name": "train", "repetition": 2, "slide_frames": 10},
|
||||||
"validation": {"name": "validation", "repetition": 1, "slide_frames": 10},
|
"validation": {"name": "validation", "repetition": 1, "slide_frames": 10},
|
||||||
@@ -188,28 +186,34 @@ for split, cfg in split_cfg.items():
|
|||||||
out_dir.mkdir(parents=True, exist_ok=True)
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
print(f" Augmenting {split}")
|
print(f" Augmenting {split}")
|
||||||
|
|
||||||
print(f" Generating spectrograms")
|
print(" Generating spectrograms")
|
||||||
spectros = SpectrogramGeneration(
|
spectros = SpectrogramGeneration(
|
||||||
clips=clips, # now backed by our WAV loader
|
clips=clips,
|
||||||
augmenter=augmenter, # your existing augmenter
|
augmenter=augmenter,
|
||||||
slide_frames=cfg["slide_frames"],
|
slide_frames=cfg["slide_frames"],
|
||||||
step_ms=10,
|
step_ms=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f" Generating files")
|
print(" Generating files")
|
||||||
|
print(" Sit tight — this step can take a while.")
|
||||||
|
|
||||||
|
gen = spectros.spectrogram_generator(
|
||||||
|
split=cfg["name"],
|
||||||
|
repeat=cfg["repetition"],
|
||||||
|
)
|
||||||
|
|
||||||
RaggedMmap.from_generator(
|
RaggedMmap.from_generator(
|
||||||
out_dir=str(out_dir / "wakeword_mmap"),
|
out_dir=str(out_dir / "wakeword_mmap"),
|
||||||
sample_generator=spectros.spectrogram_generator(
|
sample_generator=gen,
|
||||||
split=cfg["name"], repeat=cfg["repetition"]
|
|
||||||
),
|
|
||||||
batch_size=100,
|
batch_size=100,
|
||||||
verbose=False,
|
verbose=False, # keep mmap quiet
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f" {split} augmentation complete")
|
print(f" {split} augmentation complete")
|
||||||
|
|
||||||
END_TIME = datetime.now(timezone.utc).replace(microsecond=0)
|
END_TIME = datetime.now(timezone.utc).replace(microsecond=0)
|
||||||
et = END_TIME - START_TIME
|
et = END_TIME - START_TIME
|
||||||
print(f"\n{'=' * 80}")
|
print(f"\n{'=' * 80}")
|
||||||
msg=f"Augmented {max_samples} wake word samples."
|
msg = f"Augmented {max_samples} wake word samples."
|
||||||
print(f"{msg:>50s} Elapsed time: {et!s}")
|
print(f"{msg:>50s} Elapsed time: {et!s}")
|
||||||
print(f"{'=' * 80}\n")
|
print(f"{'=' * 80}\n")
|
||||||
@@ -129,88 +129,136 @@ EOF
|
|||||||
echo " Wrote training_parameters.yaml"
|
echo " Wrote training_parameters.yaml"
|
||||||
rm -rf "${WORK_DIR}/trained_models/wakeword"
|
rm -rf "${WORK_DIR}/trained_models/wakeword"
|
||||||
|
|
||||||
export TF_CPP_MIN_LOG_LEVEL=9
|
# Replace spaces and shell/filesystem-unsafe punctuation in the wake word with "_"
# so the result can be used in output directory and model file names.
# NOTE: the parentheses and square brackets must stay backslash-escaped literals inside
# the bracket expression; any "$name" text here would be expanded by the shell instead.
wake_word_filename="${WAKE_WORD//[ \`~\!\$&*\(\)\{\}\[\]\|\;\'\"<>.?\/]/_}"
|
||||||
export TF_FORCE_GPU_ALLOW_GROWTH=true
|
|
||||||
export TF_GPU_ALLOCATOR=cuda_malloc_async
|
|
||||||
export TF_XLA_FLAGS="--tf_xla_auto_jit=0"
|
|
||||||
export NVIDIA_TF32_OVERRIDE=1
|
|
||||||
export TF_CUDNN_WORKSPACE_LIMIT_IN_MB=512
|
|
||||||
export GLOG_minloglevel=9
|
|
||||||
export GRPC_VERBOSITY=ERROR
|
|
||||||
|
|
||||||
echo " Loading Tensorflow"
|
|
||||||
|
|
||||||
wake_word_filename="${WAKE_WORD//[ \`~\!\$&*\(\)\{\}\[\]\|\;\'\"<>.?\/]/_}"
|
|
||||||
OUTPUT_DIR="${DATA_DIR}/output/$(date +'%Y-%m-%d-%H-%M-%S')-${wake_word_filename}-${SAMPLES}-${TRAINING_STEPS}"
|
OUTPUT_DIR="${DATA_DIR}/output/$(date +'%Y-%m-%d-%H-%M-%S')-${wake_word_filename}-${SAMPLES}-${TRAINING_STEPS}"
|
||||||
mkdir -p "${OUTPUT_DIR}/logs" || :
|
mkdir -p "${OUTPUT_DIR}/logs" || :
|
||||||
|
|
||||||
python - \
|
TRAIN_LOG="${OUTPUT_DIR}/logs/training.log"
|
||||||
--training_config="${WORK_DIR}/trained_models/training_parameters.yaml" \
|
|
||||||
--train 1 \
|
|
||||||
--restore_checkpoint 1 \
|
|
||||||
--test_tf_nonstreaming 0 \
|
|
||||||
--test_tflite_nonstreaming 0 \
|
|
||||||
--test_tflite_nonstreaming_quantized 0 \
|
|
||||||
--test_tflite_streaming 0 \
|
|
||||||
--test_tflite_streaming_quantized 1 \
|
|
||||||
--use_weights "best_weights" \
|
|
||||||
mixednet \
|
|
||||||
--pointwise_filters "64,64,64,64" \
|
|
||||||
--repeat_in_block "1,1,1,1" \
|
|
||||||
--mixconv_kernel_sizes "[5], [7,11], [9,15], [23]" \
|
|
||||||
--residual_connection "0,0,0,0" \
|
|
||||||
--first_conv_filters 32 \
|
|
||||||
--first_conv_kernel_size 5 \
|
|
||||||
--stride 2 <<EOF 2>&1 | tr '\r' '\n' | stdbuf -i0 -o0 sed -r -e "/^Validation Batch/d" |\
|
|
||||||
tee "${OUTPUT_DIR}/logs/training.log" | sed -r -e '/^INFO:absl:/!d' \
|
|
||||||
-r -e "/None|Sharding|unsupported characters|AUC|fingerprint/d" \
|
|
||||||
-r -e 's/INFO:absl:/ /g' \
|
|
||||||
-r -e "s/, (recall =|estimated false|average viable recall)/,\n \1/g"
|
|
||||||
|
|
||||||
import sys, os, gc
|
# ------------------------------------------------------------------
|
||||||
import runpy
|
# Arguments passed to microwakeword.model_train_eval (unchanged from the previous inline invocation)
|
||||||
import yaml
|
# ------------------------------------------------------------------
|
||||||
print(" Loading Tensorflow")
|
TRAIN_ARGS=(
|
||||||
import tensorflow as tf
|
-m microwakeword.model_train_eval
|
||||||
|
--training_config "${WORK_DIR}/trained_models/training_parameters.yaml"
|
||||||
|
--train 1
|
||||||
|
--restore_checkpoint 1
|
||||||
|
--test_tf_nonstreaming 0
|
||||||
|
--test_tflite_nonstreaming 0
|
||||||
|
--test_tflite_nonstreaming_quantized 0
|
||||||
|
--test_tflite_streaming 0
|
||||||
|
--test_tflite_streaming_quantized 1
|
||||||
|
--use_weights best_weights
|
||||||
|
mixednet
|
||||||
|
--pointwise_filters "64,64,64,64"
|
||||||
|
--repeat_in_block "1,1,1,1"
|
||||||
|
--mixconv_kernel_sizes "[5], [7,11], [9,15], [23]"
|
||||||
|
--residual_connection "0,0,0,0"
|
||||||
|
--first_conv_filters 32
|
||||||
|
--first_conv_kernel_size 5
|
||||||
|
--stride 2
|
||||||
|
)
|
||||||
|
|
||||||
print(" GPU memory config")
|
# ------------------------------------------------------------------
|
||||||
# Per-device memory growth (belt + suspenders)
|
# GPU failure markers that should trigger CPU fallback
|
||||||
for g in tf.config.list_physical_devices("GPU"):
|
# (OOM + known GPU runtime/copy/init failures)
|
||||||
try:
|
# ------------------------------------------------------------------
|
||||||
tf.config.experimental.set_memory_growth(g, True)
|
GPU_FALLBACK_MARKERS=(
|
||||||
except Exception:
|
"resourceexhaustederror"
|
||||||
pass
|
"resource exhausted"
|
||||||
print(f"INFO:absl:GPUs: {tf.config.list_physical_devices('GPU')}")
|
"oom"
|
||||||
gc.collect()
|
"out of memory"
|
||||||
|
"cuda_error_out_of_memory"
|
||||||
|
"failed to allocate"
|
||||||
|
"cudnn"
|
||||||
|
"cublas"
|
||||||
|
"internalerror: cuda"
|
||||||
|
"failed call to cuinit"
|
||||||
|
"dst tensor is not initialized"
|
||||||
|
"failed copying input tensor"
|
||||||
|
"_eagerconst"
|
||||||
|
)
|
||||||
|
|
||||||
print()
|
run_attempt() {
|
||||||
try:
|
local label="$1"
|
||||||
runpy.run_module("microwakeword.model_train_eval", run_name="__main__", alter_sys=True)
|
shift
|
||||||
except Exception as e:
|
echo
|
||||||
print(e, file=sys.stderr)
|
echo "================================================================================"
|
||||||
sys.exit(1)
|
echo "===== ${label} ====="
|
||||||
EOF
|
echo "================================================================================"
|
||||||
|
echo "→ ${PYTHON_BIN:-python} ${TRAIN_ARGS[*]}"
|
||||||
|
echo
|
||||||
|
|
||||||
|
# stream everything except validation minibatch spam
|
||||||
|
"${PYTHON_BIN:-python}" "${TRAIN_ARGS[@]}" 2>&1 \
|
||||||
|
| tr '\r' '\n' \
|
||||||
|
| stdbuf -i0 -o0 sed -r -e "/^Validation Batch/d" \
|
||||||
|
| tee "${TRAIN_LOG}" \
|
||||||
|
| sed -r -e "/^Validation Batch/d" -e "s/^INFO:absl:/ /g"
|
||||||
|
|
||||||
|
return ${PIPESTATUS[0]}
|
||||||
|
}
|
||||||
|
|
||||||
|
# ---- Common TF env defaults (mirrors the training notebook; each value can be overridden from the caller's environment) ----
|
||||||
|
export TF_CPP_MIN_LOG_LEVEL="${TF_CPP_MIN_LOG_LEVEL:-2}"
|
||||||
|
export TF_XLA_FLAGS="${TF_XLA_FLAGS:---tf_xla_auto_jit=0}"
|
||||||
|
export NVIDIA_TF32_OVERRIDE="${NVIDIA_TF32_OVERRIDE:-1}"
|
||||||
|
export TF_FORCE_GPU_ALLOW_GROWTH="${TF_FORCE_GPU_ALLOW_GROWTH:-true}"
|
||||||
|
export TF_GPU_ALLOCATOR="${TF_GPU_ALLOCATOR:-cuda_malloc_async}"
|
||||||
|
|
||||||
|
# Attempt 1: GPU
|
||||||
|
if run_attempt "Attempt 1/2: GPU training (allow_growth + cuda_malloc_async)" ; then
|
||||||
|
echo "✅ Training complete (GPU path)."
|
||||||
|
else
|
||||||
|
echo "⚠️ GPU attempt failed. Checking whether this looks like a GPU/OOM/runtime failure…"
|
||||||
|
|
||||||
|
# Check log for GPU/OOM/runtime markers
|
||||||
|
log_lc="$(tr '[:upper:]' '[:lower:]' < "${TRAIN_LOG}" || true)"
|
||||||
|
looks_like_gpu_fail="false"
|
||||||
|
for m in "${GPU_FALLBACK_MARKERS[@]}"; do
|
||||||
|
# Feed the lower-cased log to grep via a here-string rather than `echo … | grep -q`:
# grep -q exits on the first hit, which can SIGPIPE the writer of a large log and,
# if pipefail is enabled, make a genuine match evaluate as "not found".
if grep -qF -- "${m}" <<< "${log_lc}"; then
|
||||||
|
looks_like_gpu_fail="true"
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
if [ "${looks_like_gpu_fail}" = "true" ]; then
|
||||||
|
echo "↪️ Detected GPU/OOM/runtime failure markers. Falling back to CPU."
|
||||||
|
|
||||||
|
# Attempt 2: CPU (hide GPU completely)
|
||||||
|
export CUDA_VISIBLE_DEVICES=""
|
||||||
|
unset TF_GPU_ALLOCATOR
|
||||||
|
if run_attempt "Attempt 2/2: CPU fallback (CUDA_VISIBLE_DEVICES='')" ; then
|
||||||
|
echo "✅ Training complete (CPU fallback)."
|
||||||
|
else
|
||||||
|
echo "❌ Training failed on BOTH GPU and CPU. See: ${TRAIN_LOG}" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo "❌ Training failed (does not look GPU/OOM/runtime). See: ${TRAIN_LOG}" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
source_path="${WORK_DIR}/trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite"
|
source_path="${WORK_DIR}/trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite"
|
||||||
|
|
||||||
if [ ! -f "${source_path}" ] ; then
|
if [ ! -f "${source_path}" ] ; then
|
||||||
echo "Output model not found! Training didn't complete successfully. See ${WORK_DIR}/training.log"
|
echo "Output model not found! Training didn't complete successfully. See ${TRAIN_LOG}"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
cp "${WORK_DIR}/trained_models/wakeword/model_summary.txt" "${OUTPUT_DIR}/logs/"
|
cp "${WORK_DIR}/trained_models/wakeword/model_summary.txt" "${OUTPUT_DIR}/logs/" || :
|
||||||
cp -a "${WORK_DIR}/trained_models/wakeword/logs/train" "${OUTPUT_DIR}/logs/"
|
cp -a "${WORK_DIR}/trained_models/wakeword/logs/train" "${OUTPUT_DIR}/logs/" || :
|
||||||
cp -a "${WORK_DIR}/trained_models/wakeword/logs/validation" "${OUTPUT_DIR}/logs/"
|
cp -a "${WORK_DIR}/trained_models/wakeword/logs/validation" "${OUTPUT_DIR}/logs/" || :
|
||||||
|
|
||||||
echo -e "\n Training complete!"
|
echo -e "\n Training complete!"
|
||||||
echo " Full log: ${OUTPUT_DIR}/logs/training.log"
|
echo " Full log: ${TRAIN_LOG}"
|
||||||
|
|
||||||
tflite_filename="${wake_word_filename}.tflite"
|
tflite_filename="${wake_word_filename}.tflite"
|
||||||
tflite_path="${OUTPUT_DIR}/${tflite_filename}"
|
tflite_path="${OUTPUT_DIR}/${tflite_filename}"
|
||||||
|
|
||||||
cp "${source_path}" "${tflite_path}"
|
cp "${source_path}" "${tflite_path}"
|
||||||
|
|
||||||
# --- Write JSON metadata file with matching model name ---
|
|
||||||
json_path="${OUTPUT_DIR}/${wake_word_filename}.json"
|
json_path="${OUTPUT_DIR}/${wake_word_filename}.json"
|
||||||
cat <<-EOF > "${json_path}"
|
cat <<-EOF > "${json_path}"
|
||||||
{
|
{
|
||||||
@@ -237,5 +285,4 @@ echo "Metadata: ${json_path}"
|
|||||||
echo
|
echo
|
||||||
END_TS=$EPOCHSECONDS
|
END_TS=$EPOCHSECONDS
|
||||||
print_elapsed_time "${START_TS}" "${END_TS}" "Training completed."
|
print_elapsed_time "${START_TS}" "${END_TS}" "Training completed."
|
||||||
echo
|
echo
|
||||||
|
|
||||||
@@ -27,9 +27,12 @@ COPY --chown=root:root --chmod=0755 \
|
|||||||
requirements.txt \
|
requirements.txt \
|
||||||
/root/mww-scripts/
|
/root/mww-scripts/
|
||||||
|
|
||||||
# CLI folder (THIS IS THE IMPORTANT CHANGE)
|
# CLI folder
|
||||||
COPY --chown=root:root cli/ /root/mww-scripts/cli/
|
COPY --chown=root:root cli/ /root/mww-scripts/cli/
|
||||||
|
|
||||||
|
# Make all CLI scripts executable (avoids "Permission denied")
|
||||||
|
RUN chmod -R a+x /root/mww-scripts/cli
|
||||||
|
|
||||||
# Static UI for recorder
|
# Static UI for recorder
|
||||||
COPY --chown=root:root --chmod=0644 static/index.html /root/mww-scripts/static/index.html
|
COPY --chown=root:root --chmod=0644 static/index.html /root/mww-scripts/static/index.html
|
||||||
|
|
||||||
|
|||||||
@@ -4,9 +4,9 @@ import re
|
|||||||
import subprocess
|
import subprocess
|
||||||
import threading
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Any, List, Optional, Tuple
|
from typing import Dict, Any, List, Optional
|
||||||
|
|
||||||
from fastapi import FastAPI, UploadFile, File, Form, Query
|
from fastapi import FastAPI, UploadFile, File, Form
|
||||||
from fastapi.responses import HTMLResponse, JSONResponse
|
from fastapi.responses import HTMLResponse, JSONResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
@@ -24,14 +24,9 @@ PERSONAL_DIR = Path(os.environ.get("PERSONAL_DIR", str(DATA_DIR / "personal_samp
|
|||||||
# CLI folder inside repo
|
# CLI folder inside repo
|
||||||
CLI_DIR = Path(os.environ.get("CLI_DIR", str(ROOT_DIR / "cli"))).resolve()
|
CLI_DIR = Path(os.environ.get("CLI_DIR", str(ROOT_DIR / "cli"))).resolve()
|
||||||
|
|
||||||
# If you want cleanup defaults for auto dataset setup, set these env vars:
|
|
||||||
# REC_DATASET_CLEANUP_ARCHIVES=true/false
|
|
||||||
# REC_DATASET_CLEANUP_INTERMEDIATE_FILES=true/false
|
|
||||||
DATASET_CLEANUP_ARCHIVES = os.environ.get("REC_DATASET_CLEANUP_ARCHIVES", "false").lower() in ("1", "true", "yes", "y")
|
DATASET_CLEANUP_ARCHIVES = os.environ.get("REC_DATASET_CLEANUP_ARCHIVES", "false").lower() in ("1", "true", "yes", "y")
|
||||||
DATASET_CLEANUP_INTERMEDIATE = os.environ.get("REC_DATASET_CLEANUP_INTERMEDIATE_FILES", "false").lower() in ("1", "true", "yes", "y")
|
DATASET_CLEANUP_INTERMEDIATE = os.environ.get("REC_DATASET_CLEANUP_INTERMEDIATE_FILES", "false").lower() in ("1", "true", "yes", "y")
|
||||||
|
|
||||||
# We want "Start training" to trigger your CLI entrypoint, using the existing venv
|
|
||||||
# (train_wake_word should be in /data/.venv/bin via setup_python_venv)
|
|
||||||
TRAIN_CMD = os.environ.get(
|
TRAIN_CMD = os.environ.get(
|
||||||
"TRAIN_CMD",
|
"TRAIN_CMD",
|
||||||
f"source '{DATA_DIR}/.venv/bin/activate' && train_wake_word --data-dir '{DATA_DIR}'"
|
f"source '{DATA_DIR}/.venv/bin/activate' && train_wake_word --data-dir '{DATA_DIR}'"
|
||||||
@@ -40,14 +35,13 @@ TRAIN_CMD = os.environ.get(
|
|||||||
TAKES_PER_SPEAKER_DEFAULT = int(os.environ.get("REC_TAKES_PER_SPEAKER", "10"))
|
TAKES_PER_SPEAKER_DEFAULT = int(os.environ.get("REC_TAKES_PER_SPEAKER", "10"))
|
||||||
SPEAKERS_TOTAL_DEFAULT = int(os.environ.get("REC_SPEAKERS_TOTAL", "1"))
|
SPEAKERS_TOTAL_DEFAULT = int(os.environ.get("REC_SPEAKERS_TOTAL", "1"))
|
||||||
|
|
||||||
# How many lines to show in WebUI (tail)
|
# Tail lines shown to UI
|
||||||
TRAIN_LOG_TAIL_LINES = int(os.environ.get("REC_TRAIN_LOG_TAIL_LINES", "400"))
|
TRAIN_LOG_TAIL_LINES = int(os.environ.get("REC_TRAIN_LOG_TAIL_LINES", "400"))
|
||||||
# If you prefer bytes-based tailing (fast), keep this non-zero.
|
# Safety cap for reads (bytes) to avoid giant file reads
|
||||||
TRAIN_LOG_MAX_BYTES = int(os.environ.get("REC_TRAIN_LOG_MAX_BYTES", str(512 * 1024))) # 512KB
|
TRAIN_LOG_MAX_BYTES = int(os.environ.get("REC_TRAIN_LOG_MAX_BYTES", str(512 * 1024))) # 512KB
|
||||||
|
|
||||||
app = FastAPI(title="microWakeWord Personal Recorder")
|
app = FastAPI(title="microWakeWord Personal Recorder")
|
||||||
|
|
||||||
# Serve static UI
|
|
||||||
STATIC_DIR.mkdir(parents=True, exist_ok=True)
|
STATIC_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
|
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
|
||||||
|
|
||||||
@@ -60,7 +54,6 @@ def safe_name(raw: str) -> str:
|
|||||||
return s or "wakeword"
|
return s or "wakeword"
|
||||||
|
|
||||||
|
|
||||||
# -------------------- In-memory session state --------------------
|
|
||||||
STATE: Dict[str, Any] = {
|
STATE: Dict[str, Any] = {
|
||||||
"raw_phrase": None,
|
"raw_phrase": None,
|
||||||
"safe_word": None,
|
"safe_word": None,
|
||||||
@@ -74,12 +67,13 @@ STATE: Dict[str, Any] = {
|
|||||||
"training": {
|
"training": {
|
||||||
"running": False,
|
"running": False,
|
||||||
"exit_code": None,
|
"exit_code": None,
|
||||||
"log_lines": [], # legacy in-memory tail (still maintained)
|
"log_lines": [], # legacy in-memory tail (kept, but not relied on)
|
||||||
"log_path": None, # path to recorder_training.log
|
"log_path": None, # path to recorder_training.log
|
||||||
"safe_word": None,
|
"safe_word": None,
|
||||||
|
|
||||||
# NEW: byte offset for efficient log tailing
|
# NEW: prevent UI duplication when UI appends:
|
||||||
"log_offset": 0,
|
"last_sent_tail": [], # last tail snapshot (list of lines)
|
||||||
|
"last_log_size": 0, # detect truncation
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,6 +89,26 @@ def _reset_personal_samples_dir():
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _clear_training_log():
|
||||||
|
"""
|
||||||
|
Truncate recorder_training.log for a fresh session.
|
||||||
|
"""
|
||||||
|
log_path = DATA_DIR / "recorder_training.log"
|
||||||
|
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
with open(log_path, "w", encoding="utf-8") as lf:
|
||||||
|
lf.write("================================================================================\n")
|
||||||
|
lf.write("===== New recorder session started =====\n")
|
||||||
|
lf.write("================================================================================\n")
|
||||||
|
lf.flush()
|
||||||
|
|
||||||
|
with STATE_LOCK:
|
||||||
|
STATE["training"]["log_path"] = str(log_path)
|
||||||
|
STATE["training"]["log_lines"] = []
|
||||||
|
STATE["training"]["last_sent_tail"] = []
|
||||||
|
STATE["training"]["last_log_size"] = 0
|
||||||
|
|
||||||
|
|
||||||
def _append_train_log(line: str):
|
def _append_train_log(line: str):
|
||||||
line = (line or "").rstrip("\n")
|
line = (line or "").rstrip("\n")
|
||||||
with STATE_LOCK:
|
with STATE_LOCK:
|
||||||
@@ -105,7 +119,6 @@ def _append_train_log(line: str):
|
|||||||
|
|
||||||
|
|
||||||
def _title_from_phrase(raw_phrase: str) -> str:
|
def _title_from_phrase(raw_phrase: str) -> str:
|
||||||
# Keep it human-friendly for the optional <wake_word_title> argument
|
|
||||||
s = re.sub(r"[^a-zA-Z0-9 ]+", " ", raw_phrase or "").strip()
|
s = re.sub(r"[^a-zA-Z0-9 ]+", " ", raw_phrase or "").strip()
|
||||||
s = re.sub(r"\s+", " ", s)
|
s = re.sub(r"\s+", " ", s)
|
||||||
return s.title() if s else ""
|
return s.title() if s else ""
|
||||||
@@ -118,12 +131,6 @@ def _run_streamed(
|
|||||||
header: Optional[str] = None,
|
header: Optional[str] = None,
|
||||||
env: Optional[Dict[str, str]] = None,
|
env: Optional[Dict[str, str]] = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
|
||||||
Run a command streaming stdout/stderr to both:
|
|
||||||
- recorder_training.log (disk)
|
|
||||||
- STATE["training"]["log_lines"] (UI) [best-effort]
|
|
||||||
Returns process exit code.
|
|
||||||
"""
|
|
||||||
if header:
|
if header:
|
||||||
_append_train_log(header)
|
_append_train_log(header)
|
||||||
|
|
||||||
@@ -156,9 +163,6 @@ def _run_streamed(
|
|||||||
|
|
||||||
|
|
||||||
def _ensure_training_venv(log_path: Path) -> None:
|
def _ensure_training_venv(log_path: Path) -> None:
|
||||||
"""
|
|
||||||
Ensure /data/.venv exists by running cli/setup_python_venv if needed.
|
|
||||||
"""
|
|
||||||
activate = DATA_DIR / ".venv" / "bin" / "activate"
|
activate = DATA_DIR / ".venv" / "bin" / "activate"
|
||||||
if activate.exists():
|
if activate.exists():
|
||||||
_append_train_log("✅ Training venv found (skipping setup_python_venv)")
|
_append_train_log("✅ Training venv found (skipping setup_python_venv)")
|
||||||
@@ -183,10 +187,6 @@ def _ensure_training_venv(log_path: Path) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def _ensure_training_datasets(log_path: Path) -> None:
|
def _ensure_training_datasets(log_path: Path) -> None:
|
||||||
"""
|
|
||||||
Always run setup_training_datasets before training.
|
|
||||||
The underlying scripts should skip work when already done.
|
|
||||||
"""
|
|
||||||
setup = CLI_DIR / "setup_training_datasets"
|
setup = CLI_DIR / "setup_training_datasets"
|
||||||
if not setup.exists():
|
if not setup.exists():
|
||||||
raise RuntimeError(f"Missing setup_training_datasets at: {setup}")
|
raise RuntimeError(f"Missing setup_training_datasets at: {setup}")
|
||||||
@@ -217,67 +217,45 @@ def _ensure_training_datasets(log_path: Path) -> None:
|
|||||||
raise RuntimeError(f"setup_training_datasets failed (exit_code={rc})")
|
raise RuntimeError(f"setup_training_datasets failed (exit_code={rc})")
|
||||||
|
|
||||||
|
|
||||||
def _read_log_tail_by_bytes(log_path: Path, max_bytes: int) -> str:
|
def _read_tail_lines(log_path: Path, max_lines: int) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Read up to the last max_bytes from a file (UTF-8 best effort).
|
Read the last N lines, bounded by TRAIN_LOG_MAX_BYTES.
|
||||||
|
Returns list of lines (no trailing newlines).
|
||||||
"""
|
"""
|
||||||
if not log_path.exists():
|
if not log_path.exists():
|
||||||
return ""
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
size = log_path.stat().st_size
|
size = log_path.stat().st_size
|
||||||
start = max(0, size - max_bytes)
|
start = max(0, size - TRAIN_LOG_MAX_BYTES)
|
||||||
with open(log_path, "rb") as f:
|
with open(log_path, "rb") as f:
|
||||||
f.seek(start)
|
f.seek(start)
|
||||||
data = f.read()
|
data = f.read()
|
||||||
# If we started in the middle of a line, it's ok; UI will show partial.
|
|
||||||
return data.decode("utf-8", errors="replace")
|
|
||||||
except Exception:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def _read_log_tail_by_lines(log_path: Path, max_lines: int) -> str:
|
|
||||||
"""
|
|
||||||
Read last N lines of a file (simple, may be slower on huge files).
|
|
||||||
"""
|
|
||||||
if not log_path.exists():
|
|
||||||
return ""
|
|
||||||
try:
|
|
||||||
# Read by bytes limited first, then line-tail
|
|
||||||
raw = _read_log_tail_by_bytes(log_path, TRAIN_LOG_MAX_BYTES)
|
|
||||||
if not raw:
|
|
||||||
return ""
|
|
||||||
lines = raw.splitlines()
|
|
||||||
if len(lines) <= max_lines:
|
|
||||||
return "\n".join(lines)
|
|
||||||
return "\n".join(lines[-max_lines:])
|
|
||||||
except Exception:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def _read_log_since_offset(log_path: Path, offset: int, max_bytes: int = 256 * 1024) -> Tuple[str, int]:
|
|
||||||
"""
|
|
||||||
Read log file incrementally starting from `offset`.
|
|
||||||
Returns (new_text, new_offset). Caps bytes read per call.
|
|
||||||
"""
|
|
||||||
if not log_path.exists():
|
|
||||||
return ("", offset)
|
|
||||||
|
|
||||||
try:
|
|
||||||
size = log_path.stat().st_size
|
|
||||||
# If file rotated/truncated, reset offset
|
|
||||||
if offset > size:
|
|
||||||
offset = 0
|
|
||||||
|
|
||||||
with open(log_path, "rb") as f:
|
|
||||||
f.seek(offset)
|
|
||||||
data = f.read(max_bytes)
|
|
||||||
|
|
||||||
new_offset = offset + len(data)
|
|
||||||
text = data.decode("utf-8", errors="replace")
|
text = data.decode("utf-8", errors="replace")
|
||||||
return (text, new_offset)
|
lines = text.splitlines()
|
||||||
|
if len(lines) <= max_lines:
|
||||||
|
return lines
|
||||||
|
return lines[-max_lines:]
|
||||||
except Exception:
|
except Exception:
|
||||||
return ("", offset)
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_new_lines(prev_tail: List[str], new_tail: List[str]) -> List[str]:
|
||||||
|
"""
|
||||||
|
Given previous and current tail snapshots, return only the newly-added lines.
|
||||||
|
Works even if the tail window shifts.
|
||||||
|
"""
|
||||||
|
if not prev_tail:
|
||||||
|
return new_tail
|
||||||
|
|
||||||
|
# Try to find the largest suffix of prev_tail that matches a prefix of new_tail
|
||||||
|
max_k = min(len(prev_tail), len(new_tail))
|
||||||
|
for k in range(max_k, 0, -1):
|
||||||
|
if prev_tail[-k:] == new_tail[:k]:
|
||||||
|
return new_tail[k:]
|
||||||
|
|
||||||
|
# If no overlap, just return full new_tail (probably truncation or big jump)
|
||||||
|
return new_tail
|
||||||
|
|
||||||
|
|
||||||
def _run_training_background(safe_word: str, allow_no_personal: bool):
|
def _run_training_background(safe_word: str, allow_no_personal: bool):
|
||||||
@@ -291,16 +269,15 @@ def _run_training_background(safe_word: str, allow_no_personal: bool):
|
|||||||
STATE["training"]["exit_code"] = None
|
STATE["training"]["exit_code"] = None
|
||||||
STATE["training"]["log_lines"] = []
|
STATE["training"]["log_lines"] = []
|
||||||
STATE["training"]["safe_word"] = safe_word
|
STATE["training"]["safe_word"] = safe_word
|
||||||
|
STATE["training"]["last_sent_tail"] = []
|
||||||
|
STATE["training"]["last_log_size"] = 0
|
||||||
log_path = Path(str(DATA_DIR / "recorder_training.log"))
|
log_path = Path(str(DATA_DIR / "recorder_training.log"))
|
||||||
STATE["training"]["log_path"] = str(log_path)
|
STATE["training"]["log_path"] = str(log_path)
|
||||||
STATE["training"]["log_offset"] = 0
|
|
||||||
|
|
||||||
# fresh header at the start of a run
|
|
||||||
_append_train_log("================================================================================")
|
_append_train_log("================================================================================")
|
||||||
_append_train_log("===== Recorder Training Run =====")
|
_append_train_log("===== Recorder Training Run =====")
|
||||||
_append_train_log("================================================================================")
|
_append_train_log("================================================================================")
|
||||||
|
|
||||||
# Ensure the log exists and starts cleanly with a header separator for this run
|
|
||||||
try:
|
try:
|
||||||
with open(log_path, "a", encoding="utf-8") as lf:
|
with open(log_path, "a", encoding="utf-8") as lf:
|
||||||
lf.write("\n" + ("=" * 80) + "\n")
|
lf.write("\n" + ("=" * 80) + "\n")
|
||||||
@@ -311,13 +288,9 @@ def _run_training_background(safe_word: str, allow_no_personal: bool):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1) Ensure venv (auto-installs)
|
|
||||||
_ensure_training_venv(log_path)
|
_ensure_training_venv(log_path)
|
||||||
|
|
||||||
# 2) Ensure datasets (auto-installs / skips if already present)
|
|
||||||
_ensure_training_datasets(log_path)
|
_ensure_training_datasets(log_path)
|
||||||
|
|
||||||
# 3) Run training
|
|
||||||
if wake_word_title:
|
if wake_word_title:
|
||||||
cmd_str = f"{TRAIN_CMD} '{safe_word}' '{wake_word_title}'"
|
cmd_str = f"{TRAIN_CMD} '{safe_word}' '{wake_word_title}'"
|
||||||
else:
|
else:
|
||||||
@@ -361,7 +334,6 @@ def _run_training_background(safe_word: str, allow_no_personal: bool):
|
|||||||
STATE["training"]["running"] = False
|
STATE["training"]["running"] = False
|
||||||
|
|
||||||
|
|
||||||
# -------------------- Routes --------------------
|
|
||||||
@app.get("/", response_class=HTMLResponse)
|
@app.get("/", response_class=HTMLResponse)
|
||||||
def index():
|
def index():
|
||||||
html_path = STATIC_DIR / "index.html"
|
html_path = STATIC_DIR / "index.html"
|
||||||
@@ -394,10 +366,12 @@ def start_session(payload: Dict[str, Any]):
|
|||||||
STATE["takes_per_speaker"] = takes_per_speaker
|
STATE["takes_per_speaker"] = takes_per_speaker
|
||||||
STATE["takes_received"] = 0
|
STATE["takes_received"] = 0
|
||||||
STATE["takes"] = []
|
STATE["takes"] = []
|
||||||
# do not interrupt training if running
|
|
||||||
|
|
||||||
_reset_personal_samples_dir()
|
_reset_personal_samples_dir()
|
||||||
|
|
||||||
|
# Always wipe log on start_session (even if same wakeword)
|
||||||
|
_clear_training_log()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"ok": True,
|
"ok": True,
|
||||||
"raw_phrase": raw,
|
"raw_phrase": raw,
|
||||||
@@ -523,64 +497,42 @@ def train_now(payload: Dict[str, Any] = None):
|
|||||||
|
|
||||||
|
|
||||||
@app.get("/api/train_status")
|
@app.get("/api/train_status")
|
||||||
def train_status(
|
def train_status():
|
||||||
offset: int = Query(0, ge=0),
|
|
||||||
max_bytes: int = Query(65536, ge=1024, le=262144),
|
|
||||||
last_size: int = Query(0, ge=0),
|
|
||||||
last_mtime: float = Query(0.0, ge=0.0),
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Stream training output from the log file on disk.
|
Return only NEW lines since last poll (prevents UI duplication spam even if UI appends).
|
||||||
|
|
||||||
Robust to log overwrite/truncation:
|
|
||||||
- UI passes offset + last_size + last_mtime
|
|
||||||
- If file shrinks or mtime goes backwards/changes weirdly, reset offset to 0
|
|
||||||
"""
|
"""
|
||||||
with STATE_LOCK:
|
with STATE_LOCK:
|
||||||
tr = dict(STATE["training"])
|
tr = dict(STATE["training"])
|
||||||
log_path_str = tr.get("log_path")
|
log_path_str = tr.get("log_path")
|
||||||
|
prev_tail = list(STATE["training"].get("last_sent_tail") or [])
|
||||||
|
prev_size = int(STATE["training"].get("last_log_size") or 0)
|
||||||
|
|
||||||
log_text = ""
|
new_lines: List[str] = []
|
||||||
next_offset = offset
|
full_tail: List[str] = []
|
||||||
log_size = 0
|
size_now = 0
|
||||||
log_mtime = 0.0
|
|
||||||
|
|
||||||
if log_path_str:
|
if log_path_str:
|
||||||
p = Path(log_path_str)
|
p = Path(log_path_str)
|
||||||
if p.exists():
|
if p.exists():
|
||||||
try:
|
try:
|
||||||
st = p.stat()
|
size_now = int(p.stat().st_size)
|
||||||
log_size = int(st.st_size)
|
except Exception:
|
||||||
log_mtime = float(st.st_mtime)
|
size_now = 0
|
||||||
|
|
||||||
# Detect overwrite/truncate/reset:
|
# If file was truncated/cleared, reset history
|
||||||
# - file shrank
|
if size_now < prev_size:
|
||||||
# - file mtime moved "backwards" (rare) or changed while size reset
|
prev_tail = []
|
||||||
# If anything indicates a reset, restart from beginning.
|
|
||||||
if (log_size < last_size) or (last_mtime and log_mtime < last_mtime):
|
|
||||||
offset = 0
|
|
||||||
|
|
||||||
# Clamp offset to current file size
|
full_tail = _read_tail_lines(p, TRAIN_LOG_TAIL_LINES)
|
||||||
if offset > log_size:
|
new_lines = _compute_new_lines(prev_tail, full_tail)
|
||||||
offset = log_size
|
|
||||||
|
|
||||||
# Read incrementally from the file
|
# Save snapshot for next poll
|
||||||
with p.open("rb") as f:
|
with STATE_LOCK:
|
||||||
f.seek(offset)
|
STATE["training"]["last_sent_tail"] = full_tail
|
||||||
chunk = f.read(max_bytes)
|
STATE["training"]["last_log_size"] = size_now
|
||||||
|
|
||||||
log_text = chunk.decode("utf-8", errors="replace")
|
|
||||||
next_offset = offset + len(chunk)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
log_text = f"\n[log read error: {e!r}]\n"
|
|
||||||
next_offset = offset
|
|
||||||
|
|
||||||
tr["log_text"] = log_text
|
|
||||||
tr["next_offset"] = next_offset
|
|
||||||
tr["log_size"] = log_size
|
|
||||||
tr["log_mtime"] = log_mtime
|
|
||||||
|
|
||||||
|
tr["log_text"] = "\n".join(new_lines) # ONLY new lines
|
||||||
|
tr["log_tail_preview"] = "\n".join(full_tail) # optional: handy for debugging
|
||||||
return {"ok": True, "training": tr}
|
return {"ok": True, "training": tr}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -250,6 +250,10 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
async function api(path, opts) {
|
async function api(path, opts) {
|
||||||
|
opts = opts || {};
|
||||||
|
// Always try to avoid cache for polling endpoints
|
||||||
|
if (!opts.cache) opts.cache = "no-store";
|
||||||
|
|
||||||
const res = await fetch(path, opts);
|
const res = await fetch(path, opts);
|
||||||
const ct = res.headers.get("content-type") || "";
|
const ct = res.headers.get("content-type") || "";
|
||||||
const data = ct.includes("application/json") ? await res.json() : await res.text();
|
const data = ct.includes("application/json") ? await res.json() : await res.text();
|
||||||
@@ -268,10 +272,9 @@
|
|||||||
return (el.scrollHeight - el.scrollTop - el.clientHeight) <= px;
|
return (el.scrollHeight - el.scrollTop - el.clientHeight) <= px;
|
||||||
}
|
}
|
||||||
|
|
||||||
function appendLogChunkAutoScroll(el, chunk) {
|
function setLogTextAutoScroll(el, text) {
|
||||||
if (!chunk) return;
|
|
||||||
const stick = isNearBottom(el);
|
const stick = isNearBottom(el);
|
||||||
el.textContent += chunk;
|
el.textContent = text || "";
|
||||||
if (stick) el.scrollTop = el.scrollHeight;
|
if (stick) el.scrollTop = el.scrollHeight;
|
||||||
}
|
}
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
@@ -296,12 +299,21 @@
|
|||||||
let currentTake = 0;
|
let currentTake = 0;
|
||||||
let takesPerSpeaker = 10;
|
let takesPerSpeaker = 10;
|
||||||
|
|
||||||
// --- incremental log streaming state ---
|
// --- training poll (append mode; scrollback works) ---
|
||||||
// Polls /api/train_status?offset=<N> and appends training.log_text (reads /data/recorder_training.log)
|
|
||||||
let trainOffset = 0;
|
|
||||||
let trainingPollRunning = false;
|
let trainingPollRunning = false;
|
||||||
let trainingPollAbort = false;
|
let trainingPollAbort = false;
|
||||||
|
|
||||||
|
let logBuffer = ""; // full text we’ve shown in the browser
|
||||||
|
let lastChunk = ""; // last chunk we received (for de-dupe)
|
||||||
|
let seenAnyOutput = false;
|
||||||
|
|
||||||
|
function appendLogAutoScroll(el, chunk) {
|
||||||
|
if (!chunk) return;
|
||||||
|
const stick = isNearBottom(el);
|
||||||
|
el.textContent += chunk;
|
||||||
|
if (stick) el.scrollTop = el.scrollHeight;
|
||||||
|
}
|
||||||
|
|
||||||
function startThreshold() { return parseFloat($("startThresh").value); }
|
function startThreshold() { return parseFloat($("startThresh").value); }
|
||||||
function silenceStopMs() { return parseInt($("silenceMs").value, 10); }
|
function silenceStopMs() { return parseInt($("silenceMs").value, 10); }
|
||||||
function minTakeMs() { return parseInt($("minTakeMs").value, 10); }
|
function minTakeMs() { return parseInt($("minTakeMs").value, 10); }
|
||||||
@@ -585,9 +597,11 @@
|
|||||||
|
|
||||||
setPill($("status"), auto ? "Auto-starting training…" : "Preparing training environment…", "warn");
|
setPill($("status"), auto ? "Auto-starting training…" : "Preparing training environment…", "warn");
|
||||||
|
|
||||||
// reset streaming log state (we show recorder_training.log from the start of this run)
|
// Reset log state for a fresh run
|
||||||
trainOffset = 0;
|
|
||||||
trainingPollAbort = false;
|
trainingPollAbort = false;
|
||||||
|
logBuffer = "";
|
||||||
|
lastChunk = "";
|
||||||
|
seenAnyOutput = false;
|
||||||
|
|
||||||
const logEl = $("trainLog");
|
const logEl = $("trainLog");
|
||||||
logEl.textContent = "(preparing…)\n";
|
logEl.textContent = "(preparing…)\n";
|
||||||
@@ -603,7 +617,7 @@
|
|||||||
// Only start polling AFTER training was successfully kicked off
|
// Only start polling AFTER training was successfully kicked off
|
||||||
if (!trainingPollRunning) {
|
if (!trainingPollRunning) {
|
||||||
trainingPollRunning = true;
|
trainingPollRunning = true;
|
||||||
pollTrainingIncremental();
|
pollTrainingTail();
|
||||||
}
|
}
|
||||||
|
|
||||||
setPill($("status"), "Training running…", "warn");
|
setPill($("status"), "Training running…", "warn");
|
||||||
@@ -636,9 +650,7 @@
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// Polls /api/train_status?offset=<trainOffset>
|
async function pollTrainingTail() {
|
||||||
// Expects JSON: { ok: true, training: { running, exit_code, log_text, next_offset } }
|
|
||||||
async function pollTrainingIncremental() {
|
|
||||||
const logEl = $("trainLog");
|
const logEl = $("trainLog");
|
||||||
|
|
||||||
for (;;) {
|
for (;;) {
|
||||||
@@ -648,22 +660,37 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const st = await api(`/api/train_status?offset=${trainOffset}`, { method:"GET" });
|
const st = await api(`/api/train_status?ts=${Date.now()}`, { method:"GET", cache:"no-store" });
|
||||||
const tr = st.training || {};
|
const tr = st.training || {};
|
||||||
|
|
||||||
const chunk = tr.log_text || "";
|
// NOTE: this assumes /api/train_status returns NEW output chunks (not full tail snapshots)
|
||||||
const next = (typeof tr.next_offset === "number") ? tr.next_offset : trainOffset;
|
const chunkRaw = tr.log_text || "";
|
||||||
|
const chunk = chunkRaw; // keep exact newlines from server
|
||||||
|
|
||||||
// If we got real output, replace the "(preparing…)" placeholder
|
if (chunk) {
|
||||||
if (chunk && logEl.textContent.startsWith("(preparing…)")) {
|
// wipe placeholder once
|
||||||
logEl.textContent = "";
|
if (!seenAnyOutput) {
|
||||||
|
logEl.textContent = "";
|
||||||
|
logBuffer = "";
|
||||||
|
lastChunk = "";
|
||||||
|
seenAnyOutput = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// simple de-dupe: if server repeats the same chunk, skip it
|
||||||
|
if (chunk !== lastChunk) {
|
||||||
|
lastChunk = chunk;
|
||||||
|
logBuffer += chunk;
|
||||||
|
appendLogAutoScroll(logEl, chunk);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// before first output, show waiting message but do NOT overwrite later scrollback
|
||||||
|
if (!seenAnyOutput) {
|
||||||
|
if (!logEl.textContent || logEl.textContent.includes("(no training") || logEl.textContent.startsWith("(preparing…")) {
|
||||||
|
logEl.textContent = "Waiting for training output…\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (chunk) appendLogChunkAutoScroll(logEl, chunk);
|
|
||||||
|
|
||||||
trainOffset = next;
|
|
||||||
|
|
||||||
// Stop polling only when training has ended and exit_code is set
|
|
||||||
const exitCodeIsSet = (tr.exit_code !== null && tr.exit_code !== undefined);
|
const exitCodeIsSet = (tr.exit_code !== null && tr.exit_code !== undefined);
|
||||||
|
|
||||||
if (!tr.running && exitCodeIsSet) {
|
if (!tr.running && exitCodeIsSet) {
|
||||||
@@ -681,7 +708,7 @@
|
|||||||
// ignore transient polling errors
|
// ignore transient polling errors
|
||||||
}
|
}
|
||||||
|
|
||||||
await new Promise(r => setTimeout(r, 1500));
|
await new Promise(r => setTimeout(r, 1000));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -717,11 +744,12 @@
|
|||||||
$("takesList").textContent = "";
|
$("takesList").textContent = "";
|
||||||
$("trainLog").textContent = "(no training started)";
|
$("trainLog").textContent = "(no training started)";
|
||||||
|
|
||||||
trainOffset = 0;
|
// Stop any previous poll loop cleanly
|
||||||
|
|
||||||
// If a previous training poll loop is running, ask it to stop
|
|
||||||
trainingPollAbort = true;
|
trainingPollAbort = true;
|
||||||
trainingPollRunning = false;
|
trainingPollRunning = false;
|
||||||
|
logBuffer = "";
|
||||||
|
lastChunk = "";
|
||||||
|
seenAnyOutput = false;
|
||||||
|
|
||||||
refreshUI();
|
refreshUI();
|
||||||
|
|
||||||
@@ -741,6 +769,7 @@
|
|||||||
setPill($("sessionPill"), "Session failed", "err");
|
setPill($("sessionPill"), "Session failed", "err");
|
||||||
alert("Start session failed: " + e.message);
|
alert("Start session failed: " + e.message);
|
||||||
} finally {
|
} finally {
|
||||||
|
// allow a new poll loop to start later
|
||||||
trainingPollAbort = false;
|
trainingPollAbort = false;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -93,7 +93,7 @@ AUGMENTED_DIR="${DATA_DIR}/work/wake_word_samples_augmented"
|
|||||||
if ${AUGMENT} ; then
|
if ${AUGMENT} ; then
|
||||||
rm -rf "${AUGMENTED_DIR}" || :
|
rm -rf "${AUGMENTED_DIR}" || :
|
||||||
mkdir -p "${AUGMENTED_DIR}" || :
|
mkdir -p "${AUGMENTED_DIR}" || :
|
||||||
"${CLIDIR}/wake_word_sample_augmenter" --data-dir="${DATA_DIR}" || { rm -rf "${AUGMENTED_DIR}" ; exit 1 ; }
|
python -u "${CLIDIR}/wake_word_sample_augmenter" --data-dir="${DATA_DIR}" || { rm -rf "${AUGMENTED_DIR}" ; exit 1 ; }
|
||||||
else
|
else
|
||||||
echo "Augmentation not required"
|
echo "Augmentation not required"
|
||||||
echo
|
echo
|
||||||
|
|||||||
Reference in New Issue
Block a user