52 lines
1.6 KiB
Python
52 lines
1.6 KiB
Python
import argparse
|
|
import logging
|
|
import shutil
|
|
from pathlib import Path
|
|
|
|
from huggingface_hub import hf_hub_download, list_repo_files, snapshot_download
|
|
|
|
# Hugging Face Hub repository, and the directory inside it that holds the
# GLaDOS model files this script fetches.
REPO_ID = "WarriorMama777/GLaDOS_TTS"
MODEL_SUBDIR = "Models/Style-Bert_VITS2/Portal_GLaDOS_v1"

# Root logging for the script: INFO and above, prefixed with a timestamp and
# the level name; plus the module-level logger used for progress messages.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
_LOGGER = logging.getLogger(name=__name__)
|
|
|
|
|
|
def download_model(output_dir: Path) -> Path:
    """Download every file under ``MODEL_SUBDIR`` of ``REPO_ID`` into *output_dir*.

    Each file is fetched with ``hf_hub_download`` (which stores it in the
    local Hugging Face cache) and then copied, flattened to its base name,
    directly into *output_dir*.

    Args:
        output_dir: Destination directory. It is resolved to an absolute
            path and created (including parents) if it does not exist.

    Returns:
        The resolved absolute path of *output_dir*.

    Raises:
        ValueError: If the repository has no files under ``MODEL_SUBDIR``,
            or if two files in that subtree share a base name (flattening
            would make one silently overwrite the other).
    """
    output_dir = output_dir.resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    # Match on "<subdir>/" rather than the bare prefix, so sibling
    # directories such as "<subdir>_old/" are not pulled in as well.
    prefix = MODEL_SUBDIR.rstrip("/") + "/"
    model_files = [f for f in list_repo_files(REPO_ID) if f.startswith(prefix)]

    if not model_files:
        raise ValueError(f"No files found in {REPO_ID}/{MODEL_SUBDIR}")

    # Files are flattened into output_dir by base name. Refuse to continue
    # (before downloading anything) if two repo files would land on the same
    # destination path.
    by_name = {}
    for file_path in model_files:
        name = file_path.rsplit("/", 1)[-1]
        if name in by_name:
            raise ValueError(
                f"Files {by_name[name]!r} and {file_path!r} would both be "
                f"copied to {output_dir / name}"
            )
        by_name[name] = file_path

    for name, file_path in by_name.items():
        _LOGGER.info("Downloading %s...", file_path)
        # hf_hub_download returns the file's path inside the local HF cache.
        # local_dir_use_symlinks is not passed: it only has an effect together
        # with local_dir, and it is deprecated in current huggingface_hub.
        src = Path(hf_hub_download(repo_id=REPO_ID, filename=file_path))
        dst = output_dir / name
        if src != dst:
            _LOGGER.info("Copying %s -> %s", src.name, dst)
            shutil.copy2(src, dst)

    _LOGGER.info("Model downloaded to %s", output_dir)
    # Report only the files this call placed in output_dir. The directory may
    # hold unrelated files (e.g. a shared /data volume).
    for name in sorted(by_name):
        _LOGGER.info(" %s (%d bytes)", name, (output_dir / name).stat().st_size)

    return output_dir
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: fetch the model into --output-dir.
    # argparse applies `type` to string defaults, so the "/data" default
    # also arrives as a Path.
    cli = argparse.ArgumentParser(description="Download GLaDOS TTS model")
    cli.add_argument(
        "--output-dir",
        type=Path,
        default="/data",
        help="Output directory for model files",
    )
    options = cli.parse_args()
    download_model(output_dir=options.output_dir)
|