please god work
This commit is contained in:
51
download_model.py
Normal file
51
download_model.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import argparse
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import hf_hub_download, list_repo_files, snapshot_download
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
REPO_ID = "WarriorMama777/GLaDOS_TTS"
|
||||
MODEL_SUBDIR = "Models/Style-Bert_VITS2/Portal_GLaDOS_v1"
|
||||
|
||||
|
||||
def download_model(output_dir: Path) -> Path:
|
||||
output_dir = output_dir.resolve()
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
files = list_repo_files(REPO_ID)
|
||||
model_files = [f for f in files if f.startswith(MODEL_SUBDIR)]
|
||||
|
||||
if not model_files:
|
||||
raise ValueError(f"No files found in {REPO_ID}/{MODEL_SUBDIR}")
|
||||
|
||||
for file_path in model_files:
|
||||
_LOGGER.info("Downloading %s...", file_path)
|
||||
downloaded = hf_hub_download(
|
||||
repo_id=REPO_ID,
|
||||
filename=file_path,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
src = Path(downloaded)
|
||||
dst = output_dir / src.name
|
||||
if src != dst:
|
||||
_LOGGER.info("Copying %s -> %s", src.name, dst)
|
||||
shutil.copy2(src, dst)
|
||||
|
||||
_LOGGER.info("Model downloaded to %s", output_dir)
|
||||
for f in sorted(output_dir.iterdir()):
|
||||
if f.is_file():
|
||||
_LOGGER.info(" %s (%d bytes)", f.name, f.stat().st_size)
|
||||
|
||||
return output_dir
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Download GLaDOS TTS model")
|
||||
parser.add_argument("--output-dir", type=Path, default="/data",
|
||||
help="Output directory for model files")
|
||||
args = parser.parse_args()
|
||||
download_model(args.output_dir)
|
||||
Reference in New Issue
Block a user