#!/usr/bin/env python

import sys, os, gc, glob, random
import types
from datetime import datetime, timezone
from pathlib import Path
from argparse import ArgumentParser as ArgParser, ArgumentError

# Data root: the current directory when a ".mww-data-dir" marker path exists there, otherwise /data.
default_data_dir = os.getcwd() if os.path.exists(".mww-data-dir") else "/data"

# exit_on_error=False lets the try/except below catch ArgumentError and show the full help.
# NOTE: depending on the Python version, some failures (e.g. unrecognized arguments) may
# still be reported by argparse itself, which then exits with status 2.
parser = ArgParser(exit_on_error=False)
parser.add_argument("--data-dir", type=str, help=f"Data directory. Default: {default_data_dir}", required=False, default=default_data_dir)

# Wake word (TTS/generated) inputs/outputs
parser.add_argument("--input-dir", type=str, help="Wake word input dir. Default: <data-dir>/work/wake_word_samples", required=False)
parser.add_argument("--output-dir", type=str, help="Wake word output dir. Default: <input-dir>_augmented", required=False)

# Personal and reviewed-negative inputs/outputs
parser.add_argument("--personal-dir", type=str, help="Personal WAV dir. Default: <data-dir>/personal_samples", required=False)
parser.add_argument("--personal-output-dir", type=str, help="Personal features output dir. Default: <data-dir>/work/personal_augmented_features", required=False)
parser.add_argument("--negative-dir", type=str, help="Reviewed negative WAV dir. Default: <data-dir>/negative_samples", required=False)
parser.add_argument("--negative-output-dir", type=str, help="Reviewed negative features output dir. Default: <data-dir>/work/reviewed_negative_features", required=False)

# Dataset dirs (augmentation corpora)
parser.add_argument("--mit-rirs-16k-dir", type=str, help="MIT RIR input directory. Default: <data-dir>/training_datasets/mit_rirs_16k", required=False)
parser.add_argument("--fma-16k-dir", type=str, help="FMA input directory. Default: <data-dir>/training_datasets/fma_16k", required=False)
parser.add_argument("--audioset-16k-dir", type=str, help="Audioset input directory. Default: <data-dir>/training_datasets/audioset_16k", required=False)
parser.add_argument("--wham-16k-dir", type=str, help="WHAM input directory. Default: <data-dir>/training_datasets/wham_16k", required=False)
parser.add_argument("--chime-16k-dir", type=str, help="CHiME input directory. Default: <data-dir>/training_datasets/chime_16k", required=False)

try:
    args = parser.parse_args()
except ArgumentError as err:
    # Say *why* parsing failed before dumping the help text; otherwise the cause is lost.
    print(f"Error: {err}", file=sys.stderr)
    parser.print_help()
    sys.exit(1)

def _resolve_dir(user_value, fallback):
    """Return ``realpath(user_value)`` when a path was given on the command line, else *fallback*.

    Fallbacks are returned as-is (not passed through realpath); they are all built
    from the already-resolved data directory.
    """
    return os.path.realpath(user_value) if user_value else fallback


# Canonicalise the data root; every default below is derived from it.
args.data_dir = os.path.realpath(args.data_dir)
work_dir = os.path.join(args.data_dir, "work")
training_datasets_dir = os.path.join(args.data_dir, "training_datasets")

# Wake word (TTS/generated) samples; output defaults to "<input-dir>_augmented",
# so input_dir must be resolved first.
args.input_dir = _resolve_dir(args.input_dir, os.path.join(work_dir, "wake_word_samples"))
args.output_dir = _resolve_dir(args.output_dir, args.input_dir + "_augmented")

# Personal recordings and their feature output.
args.personal_dir = _resolve_dir(args.personal_dir, os.path.join(args.data_dir, "personal_samples"))
args.personal_output_dir = _resolve_dir(args.personal_output_dir, os.path.join(work_dir, "personal_augmented_features"))

# Reviewed negatives and their feature output.
args.negative_dir = _resolve_dir(args.negative_dir, os.path.join(args.data_dir, "negative_samples"))
args.negative_output_dir = _resolve_dir(args.negative_output_dir, os.path.join(work_dir, "reviewed_negative_features"))

# Augmentation corpora under <data-dir>/training_datasets.
args.mit_rirs_16k_dir = _resolve_dir(args.mit_rirs_16k_dir, os.path.join(training_datasets_dir, "mit_rirs_16k"))
args.fma_16k_dir = _resolve_dir(args.fma_16k_dir, os.path.join(training_datasets_dir, "fma_16k"))
args.audioset_16k_dir = _resolve_dir(args.audioset_16k_dir, os.path.join(training_datasets_dir, "audioset_16k"))
args.wham_16k_dir = _resolve_dir(args.wham_16k_dir, os.path.join(training_datasets_dir, "wham_16k"))
args.chime_16k_dir = _resolve_dir(args.chime_16k_dir, os.path.join(training_datasets_dir, "chime_16k"))

def validate_directories(paths):
    """Check that every path in *paths* is an existing directory.

    Every offending path is reported on stderr, not just the first one, so a new
    setup can be fixed in a single pass.

    Args:
        paths: Iterable of filesystem paths that must be directories.

    Returns:
        True when all paths are existing directories (also for an empty iterable),
        otherwise False.
    """
    ok = True
    for path in paths:
        if os.path.isdir(path):
            continue
        ok = False
        # A plain file at the location would otherwise pass an exists() check and fail later.
        reason = "is not a directory" if os.path.exists(path) else "does not exist"
        print(
            f"Error: Directory {path} {reason}. Please ensure preprocessing is complete.",
            file=sys.stderr,
        )
    return ok

# Inputs that must already exist before augmentation starts. The personal and
# reviewed-negative sample dirs are deliberately not listed: generate_feature_set()
# skips a set whose directory yields no WAVs.
_dataset_dirs = [
    args.mit_rirs_16k_dir,
    args.wham_16k_dir,
    args.chime_16k_dir,
    args.fma_16k_dir,
    args.audioset_16k_dir,
]
required = [work_dir, args.input_dir] + _dataset_dirs

if not validate_directories(required):
    parser.print_help()
    sys.exit(1)

# -------------------- TF + libs --------------------
print("   Initializing libraries")

# Set before importing TensorFlow so the runtime sees these when it loads.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0"
os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
os.environ["TF_CUDNN_WORKSPACE_LIMIT_IN_MB"] = "512"
os.environ["GLOG_minloglevel"] = "9"
os.environ["GRPC_VERBOSITY"] = "ERROR"

print("   Loading Tensorflow")
import tensorflow as tf

print("   GPU memory config")
gpus = tf.config.list_physical_devices("GPU")
for g in gpus:
    try:
        tf.config.experimental.set_memory_growth(g, True)
    except (RuntimeError, ValueError) as exc:
        # Best effort: TF raises RuntimeError if the GPU runtime is already
        # initialised and ValueError for an invalid device. Warn rather than
        # swallowing the failure silently; the env var above is a fallback.
        print(f"   Warning: could not enable memory growth on {g.name}: {exc}", file=sys.stderr)
print(f"   GPUs: {gpus}")
gc.collect()

import numpy as np
import librosa
from mmap_ninja.ragged import RaggedMmap
from microwakeword.audio.augmentation import Augmentation
from microwakeword.audio.clips import Clips
from microwakeword.audio.spectrograms import SpectrogramGeneration

# Timestamp taken after library loading; the closing report measures elapsed time from here.
START_TIME = datetime.now(timezone.utc).replace(microsecond=0)

# Impulse-response corpus (presumably consumed by the "RIR" transform).
impulse_paths = [args.mit_rirs_16k_dir]
# Background-noise corpora handed to the augmenter.
background_paths = [args.wham_16k_dir, args.chime_16k_dir, args.fma_16k_dir, args.audioset_16k_dir]

# Per-transform probabilities passed to Augmentation (presumably the chance that each
# transform is applied to a given clip).
augmentation_probabilities = dict(
    SevenBandParametricEQ=0.1,
    TanhDistortion=0.05,
    PitchShift=0.15,
    BandStopFilter=0.1,
    AddColorNoise=0.1,
    AddBackgroundNoise=0.7,
    Gain=0.8,
    RIR=0.7,
)

augmenter = Augmentation(
    augmentation_duration_s=3.2,
    augmentation_probabilities=augmentation_probabilities,
    impulse_paths=impulse_paths,
    background_paths=background_paths,
    background_min_snr_db=5,
    background_max_snr_db=10,
    min_jitter_s=0.2,
    max_jitter_s=0.3,
)

# Output subdirectory -> settings: "name" is the split requested from the generator,
# "repetition" its repeat count, and "slide_frames" goes to SpectrogramGeneration.
split_cfg = {
    out_subdir: {"name": split_name, "repetition": repeats, "slide_frames": slide}
    for out_subdir, split_name, repeats, slide in (
        ("training", "train", 2, 10),
        ("validation", "validation", 1, 10),
        ("testing", "test", 1, 1),
    )
}

def _split_wav_files(files, *, seed=10, val_fraction=0.10, test_fraction=0.10):
    """Shuffle *files* deterministically and cut them into train/validation/test lists.

    A private ``random.Random(seed)`` is used, so the global RNG is untouched and the
    same input list always yields the same split. Validation and test each get
    ``int(fraction * n)`` files but at least one; training gets the remainder. With
    very small inputs training can therefore be empty (n <= 2 with the default
    fractions), and with n == 1 the only file goes to validation.

    Returns:
        ``{"train": [...], "validation": [...], "test": [...]}``: disjoint lists whose
        concatenation is a permutation of *files*.
    """
    shuffled = list(files)
    random.Random(seed).shuffle(shuffled)

    n = len(shuffled)
    n_val = max(1, int(val_fraction * n))
    n_test = max(1, int(test_fraction * n))
    n_train = max(0, n - n_val - n_test)

    return {
        "train": shuffled[:n_train],
        "validation": shuffled[n_train:n_train + n_val],
        "test": shuffled[n_train + n_val:],
    }


def bind_wav_generator(clips_obj: "Clips", wav_dir: str, *, seed: int = 10,
                       val_fraction: float = 0.10, test_fraction: float = 0.10):
    """Replace ``clips_obj.audio_generator`` with one that decodes WAVs directly from *wav_dir*.

    The ``*.wav`` files (non-recursive) are listed once, here. They are sorted so the
    split does not depend on filesystem order, then split deterministically (80/10/10,
    seed 10 with the defaults, mirroring the reference notebook). Listing once keeps
    train/validation/test disjoint across calls even if the directory changes while
    features are being generated.

    NOTE: the replacement generator loads audio with ``librosa.load`` (16 kHz, mono) and
    does not consult any ``Clips`` settings such as ``remove_silence`` or clip-duration
    limits. It yields each whole file as a float32 array.

    Args:
        clips_obj: Object whose ``audio_generator`` attribute is replaced (bound as a method).
        wav_dir: Directory scanned for ``*.wav`` files.
        seed, val_fraction, test_fraction: Split parameters; the defaults reproduce the
            original behavior.

    The bound ``audio_generator(split="train", repeat=1)`` yields nothing for an unknown
    or empty split. Otherwise it walks the split's files ``max(1, int(repeat))`` times.
    """
    wav_files = sorted(glob.glob(os.path.join(wav_dir, "*.wav")))
    splits = _split_wav_files(
        wav_files, seed=seed, val_fraction=val_fraction, test_fraction=test_fraction
    )

    def audio_generator_from_wavs(self, split="train", repeat=1):
        # `self` is unused but required so this can be bound with types.MethodType.
        file_list = splits.get(split, [])
        if not file_list:
            return

        for _ in range(max(1, int(repeat))):
            for path in file_list:
                audio, _sr = librosa.load(path, sr=16000, mono=True)
                yield audio.astype(np.float32, copy=False)

    clips_obj.audio_generator = types.MethodType(audio_generator_from_wavs, clips_obj)

def generate_feature_set(input_wav_dir: str, out_root_dir: str, label: str, *, remove_silence: bool = True):
    """Augment the WAVs in *input_wav_dir* and write spectrogram features for every split.

    For each entry of ``split_cfg``, the generator output is written via
    ``RaggedMmap.from_generator`` to ``<out_root_dir>/<split>/wakeword_mmap``.

    Args:
        input_wav_dir: Directory scanned (non-recursively) for ``*.wav`` files.
        out_root_dir: Root output directory; created if missing.
        label: Human-readable name of the set, used only in log messages.
        remove_silence: Forwarded to ``Clips``. NOTE(review): audio is actually served
            by the generator installed by ``bind_wav_generator``, which loads whole WAV
            files itself. This flag therefore appears to have no effect on the features
            produced; confirm before relying on it.

    Returns:
        False (writing nothing) when no WAVs are found, otherwise True.
    """
    files = glob.glob(os.path.join(input_wav_dir, "*.wav"))
    if not files:
        print(f"ℹ️  No WAVs found for {label} in: {input_wav_dir} (skipping)")
        return False

    max_samples = len(files)
    # Neutral wording: this function also processes reviewed negatives, which are not wake words.
    print(f"\n===== Augmenting {max_samples} samples ({label}) =====")

    clips = Clips(
        input_directory=input_wav_dir,
        file_pattern="*.wav",
        max_clip_duration_s=5,
        remove_silence=remove_silence,
        random_split_seed=10,
        split_count=0.1,
    )
    # Replace the Clips audio source with direct WAV loading and a fixed split.
    bind_wav_generator(clips, input_wav_dir)

    out_root = Path(out_root_dir)
    out_root.mkdir(parents=True, exist_ok=True)

    produced = 0

    def _counting(source):
        """Pass items from *source* through unchanged while counting them in ``produced``."""
        nonlocal produced
        for item in source:
            produced += 1
            yield item

    for split, cfg in split_cfg.items():
        out_dir = out_root / split
        out_dir.mkdir(parents=True, exist_ok=True)

        print(f"   Augmenting {split} ({label})")
        print("   Sit tight this can take a while ...")
        print()

        spectros = SpectrogramGeneration(
            clips=clips,
            augmenter=augmenter,
            slide_frames=cfg["slide_frames"],
            step_ms=10,
        )
        gen = spectros.spectrogram_generator(
            split=cfg["name"],
            repeat=cfg["repetition"],
        )

        produced = 0
        RaggedMmap.from_generator(
            out_dir=str(out_dir / "wakeword_mmap"),
            sample_generator=_counting(gen),
            batch_size=100,
            verbose=False,
        )

        if produced == 0:
            # Small inputs can leave a split empty (e.g. no training files when only one
            # or two WAVs exist). Make that visible instead of silently reporting success.
            print(f"      ⚠️  {split} ({label}) produced no samples; the input set may be too small for this split")
        print(f"      {split} augmentation complete ({label})")

    print(f"\n✅ Features ready: {out_root_dir}/*/wakeword_mmap\n")
    return True

# One feature set per input source, in order: generated/TTS wake words, personal
# recordings, then reviewed false-positive / hard-negative clips (requested with
# remove_silence=False). A set whose directory holds no WAVs is skipped.
feature_jobs = (
    (args.input_dir, args.output_dir, "generated", {}),
    (args.personal_dir, args.personal_output_dir, "personal", {}),
    (args.negative_dir, args.negative_output_dir, "reviewed negatives", {"remove_silence": False}),
)
for job_input, job_output, job_label, job_options in feature_jobs:
    generate_feature_set(job_input, job_output, job_label, **job_options)

# Closing report: wall-clock time since START_TIME, at whole-second resolution.
END_TIME = datetime.now(timezone.utc).replace(microsecond=0)
et = END_TIME - START_TIME
rule = "=" * 80
print(f"\n{rule}")
print(f"{'Augmentation completed.':>50s} Elapsed time: {et!s}")
print(f"{rule}\n")
