diff --git a/dlclibrary/dlcmodelzoo/modelzoo_download.py b/dlclibrary/dlcmodelzoo/modelzoo_download.py
index 12dff74..49162a5 100644
--- a/dlclibrary/dlcmodelzoo/modelzoo_download.py
+++ b/dlclibrary/dlcmodelzoo/modelzoo_download.py
@@ -13,6 +13,8 @@
 import json
 import os
 import tarfile
+import shutil
+import tempfile
 from pathlib import Path
 
 from huggingface_hub import hf_hub_download
@@ -111,34 +113,62 @@ def get_available_models(dataset: str) -> list[str]:
     return list(_load_pytorch_dataset_models(dataset)["pose_models"].keys())
 
 
+
 def _handle_downloaded_file(
     file_path: str, target_dir: str, rename_mapping: dict | None = None
 ):
-    """Handle the downloaded file from HuggingFace"""
+    """Place the artifact downloaded from HuggingFace into target_dir.
+
+    A tar archive (any compression) has its regular files extracted flat into
+    target_dir. Any other file is copied there, renamed through rename_mapping
+    when the mapping has an entry for its file name.
+
+    Raises:
+        tarfile.ReadError: If the file opens as an archive but yields no regular files.
+    """
     file_name = os.path.basename(file_path)
+
     try:
-        with tarfile.open(file_path, mode="r:gz") as tar:
-            for member in tar:
-                if not member.isdir():
-                    fname = Path(member.name).name
-                    tar.makefile(member, os.path.join(target_dir, fname))
-    except tarfile.ReadError:  # The model is a .pt file
+        # Be permissive about compression type
+        tar = tarfile.open(file_path, mode="r:*")
+    except tarfile.ReadError:
+        # Not an archive -> treat as a direct model file (.pt/.pth/etc.)
         if rename_mapping is not None:
             file_name = rename_mapping.get(file_name, file_name)
-        if os.path.islink(file_path):
-            file_path_ = os.readlink(file_path)
-            if not os.path.isabs(file_path_):
-                file_path_ = os.path.abspath(
-                    os.path.join(os.path.dirname(file_path), file_path_)
-                )
-            file_path = file_path_
-        os.rename(file_path, os.path.join(target_dir, file_name))
+        shutil.copy2(file_path, os.path.join(target_dir, file_name))
+        return
+
+    # Extraction happens outside the try block so that errors raised from here
+    # on reach the caller instead of being mistaken for "not an archive".
+    extracted_any = False
+    with tar:
+        for member in tar.getmembers():
+            # Only extract regular files
+            if not member.isfile():
+                continue
+
+            fname = Path(member.name).name
+            if not fname:
+                continue
+
+            src = tar.extractfile(member)
+            if src is None:
+                continue
+
+            extracted_path = os.path.join(target_dir, fname)
+            with src, open(extracted_path, "wb") as dst:
+                shutil.copyfileobj(src, dst)
+
+            extracted_any = True
+
+    # If it opened as a tar but contained nothing useful, fail loudly
+    if not extracted_any:
+        raise tarfile.ReadError(f"No regular files extracted from archive: {file_path}")
 
 
 def download_huggingface_model(
     model_name: str,
     target_dir: str = ".",
-    remove_hf_folder: bool = True,
     rename_mapping: str | dict | None = None,
 ):
     """
@@ -151,10 +181,6 @@ def download_huggingface_model(
         target_dir (str, optional):
             Target directory where the model weights will be stored.
             Defaults to the current directory.
-        remove_hf_folder (bool, optional):
-            Whether to remove the directory structure created by HuggingFace
-            after downloading and decompressing the data into DeepLabCut format.
-            Defaults to True.
         rename_mapping (dict | str | None, optional):
             - If a dictionary, it should map the original Hugging Face filenames
               to new filenames (e.g. {"snapshot-12345.tar.gz": "mymodel.tar.gz"}).
@@ -164,7 +178,7 @@ def download_huggingface_model( Examples: >>> # Download without renaming, keep original filename - download_huggingface_model("superanimal_bird_resnet_50", remove_hf_folder=False) + download_huggingface_model("superanimal_bird_resnet_50") >>> # Download and rename by specifying the new name directly download_huggingface_model( @@ -188,25 +202,22 @@ def download_huggingface_model( if not os.path.isabs(target_dir): target_dir = os.path.abspath(target_dir) + os.makedirs(target_dir, exist_ok=True) - for url in urls: - url = url.split("/") - repo_id, targzfn = url[0] + "/" + url[1], str(url[-1]) - - hf_hub_download(repo_id, targzfn, cache_dir=str(target_dir)) - - # Create a new subfolder as indicated below, unzipping from there and deleting this folder - hf_folder = f"models--{url[0]}--{url[1]}" - path_ = os.path.join(target_dir, hf_folder, "snapshots") - commit = os.listdir(path_)[0] - file_name = os.path.join(path_, commit, targzfn) - - if isinstance(rename_mapping, str): - rename_mapping = {targzfn: rename_mapping} + with tempfile.TemporaryDirectory(prefix="dlc_hf_") as hf_cache_dir: + for url in urls: + url = url.split("/") + repo_id, targzfn = url[0] + "/" + url[1], str(url[-1]) - _handle_downloaded_file(file_name, target_dir, rename_mapping) + downloaded = hf_hub_download( + repo_id=repo_id, + filename=targzfn, + cache_dir=hf_cache_dir, + ) - if remove_hf_folder: - import shutil + if isinstance(rename_mapping, str): + mapping = {targzfn: rename_mapping} + else: + mapping = rename_mapping - shutil.rmtree(os.path.join(target_dir, hf_folder)) + _handle_downloaded_file(downloaded, target_dir, mapping) diff --git a/tests/test_modeldownload.py b/tests/test_modeldownload.py index 6532f0e..2e0d7bc 100644 --- a/tests/test_modeldownload.py +++ b/tests/test_modeldownload.py @@ -8,27 +8,125 @@ # # Licensed under GNU Lesser General Public License v3.0 # -import dlclibrary +from __future__ import annotations + +import io import os +import tarfile +from 
pathlib import Path + import pytest + +import dlclibrary +import dlclibrary.dlcmodelzoo.modelzoo_download as modelzoo_download from dlclibrary.dlcmodelzoo.modelzoo_download import MODELOPTIONS -def test_download_huggingface_model(tmp_path_factory, model="full_cat"): - folder = tmp_path_factory.mktemp("temp") +def _fake_model_names(): + """ + Return a deterministic fake URL for each model. + Alternate between tar.gz and .pt to test both branches. + """ + mapping = {} + for i, model in enumerate(MODELOPTIONS): + ext = ".tar.gz" if i % 2 == 0 else ".pt" + mapping[model] = f"fakeorg/{model}-repo/{model}{ext}" + return mapping + + +def _write_fake_tar_gz(path: Path): + """ + Create a fake tar.gz archive with the files the downloader expects + for archive-based DLC models. + """ + path.parent.mkdir(parents=True, exist_ok=True) + + with tarfile.open(path, mode="w:gz") as tar: + files = { + "pose_cfg.yaml": b"all_joints: [0, 1]\n", + "snapshot-103000.index": b"fake index", + "snapshot-103000.data-00000-of-00001": b"fake weights", + "snapshot-103000.meta": b"fake meta", + } + + for name, content in files.items(): + info = tarfile.TarInfo(name=name) + info.size = len(content) + tar.addfile(info, io.BytesIO(content)) + + +def _write_fake_pt(path: Path): + """ + Create a fake .pt / .pth weight file. + """ + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(b"fake pytorch weights") + + +@pytest.fixture +def mock_modelzoo(monkeypatch): + """ + Patch both: + - model name resolution + - hf_hub_download network call + + so all downloads are local and deterministic. 
+ """ + fake_names = _fake_model_names() + + monkeypatch.setattr(modelzoo_download, "_load_model_names", lambda: fake_names) + + def fake_hf_hub_download(repo_id, filename, cache_dir): + cache_dir = Path(cache_dir) + hf_folder = cache_dir / f"models--{repo_id.replace('/', '--')}" + snapshot_dir = hf_folder / "snapshots" / "fakecommit123" + returned_file = snapshot_dir / filename + + if filename.endswith(".tar.gz"): + _write_fake_tar_gz(returned_file) + elif filename.endswith(".pt") or filename.endswith(".pth"): + _write_fake_pt(returned_file) + else: + raise AssertionError(f"Unexpected mocked filename: {filename}") + + return str(returned_file) + + monkeypatch.setattr(modelzoo_download, "hf_hub_download", fake_hf_hub_download) + + return fake_names + + +def _assert_download_success(folder: Path, model: str): + """ + Shared assertion helper for download_huggingface_model. + """ dlclibrary.download_huggingface_model(model, str(folder)) - try: # These are not created for .pt models - assert os.path.exists(folder / "pose_cfg.yaml") - assert any(f.startswith("snapshot-") for f in os.listdir(folder)) - except AssertionError: - assert any(f.endswith(".pth") for f in os.listdir(folder)) + files = {p.name for p in folder.iterdir()} + + # Archive-based DLC model + if "pose_cfg.yaml" in files: + assert "pose_cfg.yaml" in files + assert any(name.startswith("snapshot-") for name in files) + + # Direct PyTorch model + else: + assert any(name.endswith((".pt", ".pth")) for name in files) + + # Verify that the Hugging Face cache folder was removed + assert not any(name.startswith("models--") for name in files) - # Verify that the Hugging Face folder was removed - assert not any(f.startswith("models--") for f in os.listdir(folder)) +def test_download_huggingface_model_tar_or_pt(tmp_path, mock_modelzoo): + folder = tmp_path / "download_one" + folder.mkdir() -def test_download_huggingface_wrong_model(): + # "full_cat" may map to tar.gz or .pt depending on ordering; + # this assertion 
helper supports both branches. + _assert_download_success(folder, "full_cat") + + +def test_download_huggingface_wrong_model(mock_modelzoo): with pytest.raises(ValueError): dlclibrary.download_huggingface_model("wrong_model_name") @@ -40,6 +138,34 @@ def test_parse_superanimal_models(): @pytest.mark.parametrize("model", MODELOPTIONS) -def test_download_all_models(tmp_path_factory, model): - print("Downloading ...", model) - test_download_huggingface_model(tmp_path_factory, model) +def test_download_all_models(tmp_path, mock_modelzoo, model): + folder = tmp_path / model + folder.mkdir() + _assert_download_success(folder, model) + + +def test_download_with_rename_mapping_for_pt(tmp_path, mock_modelzoo): + """ + Explicitly test rename_mapping for a .pt model. + """ + # Pick one of the mocked .pt models + pt_model = None + for i, model in enumerate(MODELOPTIONS): + if i % 2 == 1: + pt_model = model + break + + assert pt_model is not None, "Expected at least one mocked .pt model" + + folder = tmp_path / "rename_pt" + folder.mkdir() + + dlclibrary.download_huggingface_model( + pt_model, + str(folder), + rename_mapping="renamed_weights.pt", + ) + + files = {p.name for p in folder.iterdir()} + assert "renamed_weights.pt" in files + assert not any(name.startswith("models--") for name in files)