Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 50 additions & 39 deletions dlclibrary/dlcmodelzoo/modelzoo_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import json
import os
import tarfile
import shutil
import tempfile
from pathlib import Path

from huggingface_hub import hf_hub_download
Expand Down Expand Up @@ -111,34 +113,50 @@ def get_available_models(dataset: str) -> list[str]:
return list(_load_pytorch_dataset_models(dataset)["pose_models"].keys())



def _handle_downloaded_file(
file_path: str, target_dir: str, rename_mapping: dict | None = None
):
"""Handle the downloaded file from HuggingFace"""
"""Handle the downloaded file from HuggingFace cache and place the final artifact in target_dir."""
file_name = os.path.basename(file_path)

try:
with tarfile.open(file_path, mode="r:gz") as tar:
for member in tar:
if not member.isdir():
fname = Path(member.name).name
tar.makefile(member, os.path.join(target_dir, fname))
except tarfile.ReadError: # The model is a .pt file
# Be permissive about compression type
with tarfile.open(file_path, mode="r:*") as tar:
extracted_any = False
for member in tar.getmembers():
# Only extract regular files
if not member.isfile():
continue

fname = Path(member.name).name
if not fname:
continue

src = tar.extractfile(member)
if src is None:
continue

extracted_path = os.path.join(target_dir, fname)
with src, open(extracted_path, "wb") as dst:
shutil.copyfileobj(src, dst)

extracted_any = True

# If it opened as a tar but contained nothing useful, fail loudly
if not extracted_any:
raise tarfile.ReadError(f"No regular files extracted from archive: {file_path}")

except tarfile.ReadError:
# Not an archive -> treat as a direct model file (.pt/.pth/etc.)
if rename_mapping is not None:
file_name = rename_mapping.get(file_name, file_name)
if os.path.islink(file_path):
file_path_ = os.readlink(file_path)
if not os.path.isabs(file_path_):
file_path_ = os.path.abspath(
os.path.join(os.path.dirname(file_path), file_path_)
)
file_path = file_path_
os.rename(file_path, os.path.join(target_dir, file_name))
shutil.copy2(file_path, os.path.join(target_dir, file_name))


def download_huggingface_model(
model_name: str,
target_dir: str = ".",
remove_hf_folder: bool = True,
rename_mapping: str | dict | None = None,
):
"""
Expand All @@ -151,10 +169,6 @@ def download_huggingface_model(
target_dir (str, optional):
Target directory where the model weights will be stored.
Defaults to the current directory.
remove_hf_folder (bool, optional):
Whether to remove the directory structure created by HuggingFace
after downloading and decompressing the data into DeepLabCut format.
Defaults to True.
rename_mapping (dict | str | None, optional):
- If a dictionary, it should map the original Hugging Face filenames
to new filenames (e.g. {"snapshot-12345.tar.gz": "mymodel.tar.gz"}).
Expand All @@ -164,7 +178,7 @@ def download_huggingface_model(

Examples:
>>> # Download without renaming, keep original filename
download_huggingface_model("superanimal_bird_resnet_50", remove_hf_folder=False)
download_huggingface_model("superanimal_bird_resnet_50")

>>> # Download and rename by specifying the new name directly
download_huggingface_model(
Expand All @@ -188,25 +202,22 @@ def download_huggingface_model(

if not os.path.isabs(target_dir):
target_dir = os.path.abspath(target_dir)
os.makedirs(target_dir, exist_ok=True)

for url in urls:
url = url.split("/")
repo_id, targzfn = url[0] + "/" + url[1], str(url[-1])

hf_hub_download(repo_id, targzfn, cache_dir=str(target_dir))

# Create a new subfolder as indicated below, unzipping from there and deleting this folder
hf_folder = f"models--{url[0]}--{url[1]}"
path_ = os.path.join(target_dir, hf_folder, "snapshots")
commit = os.listdir(path_)[0]
file_name = os.path.join(path_, commit, targzfn)

if isinstance(rename_mapping, str):
rename_mapping = {targzfn: rename_mapping}
with tempfile.TemporaryDirectory(prefix="dlc_hf_") as hf_cache_dir:
for url in urls:
url = url.split("/")
repo_id, targzfn = url[0] + "/" + url[1], str(url[-1])

_handle_downloaded_file(file_name, target_dir, rename_mapping)
downloaded = hf_hub_download(
repo_id=repo_id,
filename=targzfn,
cache_dir=hf_cache_dir,
)

if remove_hf_folder:
import shutil
if isinstance(rename_mapping, str):
mapping = {targzfn: rename_mapping}
else:
mapping = rename_mapping

shutil.rmtree(os.path.join(target_dir, hf_folder))
_handle_downloaded_file(downloaded, target_dir, mapping)
154 changes: 140 additions & 14 deletions tests/test_modeldownload.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,125 @@
#
# Licensed under GNU Lesser General Public License v3.0
#
import dlclibrary
from __future__ import annotations

import io
import os
import tarfile
from pathlib import Path

import pytest

import dlclibrary
import dlclibrary.dlcmodelzoo.modelzoo_download as modelzoo_download
from dlclibrary.dlcmodelzoo.modelzoo_download import MODELOPTIONS


def test_download_huggingface_model(tmp_path_factory, model="full_cat"):
folder = tmp_path_factory.mktemp("temp")
def _fake_model_names():
    """
    Build a deterministic fake HuggingFace URL for every model option.

    Models alternate between tar.gz archives and bare .pt files so both
    download branches are exercised.
    """
    return {
        model: f"fakeorg/{model}-repo/{model}{'.tar.gz' if i % 2 == 0 else '.pt'}"
        for i, model in enumerate(MODELOPTIONS)
    }


def _write_fake_tar_gz(path: Path):
"""
Create a fake tar.gz archive with the files the downloader expects
for archive-based DLC models.
"""
path.parent.mkdir(parents=True, exist_ok=True)

with tarfile.open(path, mode="w:gz") as tar:
files = {
"pose_cfg.yaml": b"all_joints: [0, 1]\n",
"snapshot-103000.index": b"fake index",
"snapshot-103000.data-00000-of-00001": b"fake weights",
"snapshot-103000.meta": b"fake meta",
}

for name, content in files.items():
info = tarfile.TarInfo(name=name)
info.size = len(content)
tar.addfile(info, io.BytesIO(content))


def _write_fake_pt(path: Path):
"""
Create a fake .pt / .pth weight file.
"""
path.parent.mkdir(parents=True, exist_ok=True)
path.write_bytes(b"fake pytorch weights")


@pytest.fixture
def mock_modelzoo(monkeypatch):
    """
    Patch both:
      - model name resolution
      - hf_hub_download network call

    so all downloads are local and deterministic.

    Returns:
        The fake model-name -> URL mapping installed by the patch.
    """
    fake_names = _fake_model_names()

    monkeypatch.setattr(modelzoo_download, "_load_model_names", lambda: fake_names)

    def fake_hf_hub_download(repo_id, filename, cache_dir):
        # Mirror the real HF cache layout:
        #   {cache_dir}/models--{org}--{repo}/snapshots/{commit}/{filename}
        cache_dir = Path(cache_dir)
        hf_folder = cache_dir / f"models--{repo_id.replace('/', '--')}"
        snapshot_dir = hf_folder / "snapshots" / "fakecommit123"
        returned_file = snapshot_dir / filename

        if filename.endswith(".tar.gz"):
            _write_fake_tar_gz(returned_file)
        elif filename.endswith(".pt") or filename.endswith(".pth"):
            _write_fake_pt(returned_file)
        else:
            # Fix: the message previously omitted the offending name,
            # making failures impossible to diagnose.
            raise AssertionError(f"Unexpected mocked filename: {filename}")

        return str(returned_file)

    monkeypatch.setattr(modelzoo_download, "hf_hub_download", fake_hf_hub_download)

    return fake_names


def _assert_download_success(folder: Path, model: str):
    """
    Shared assertion helper for download_huggingface_model.

    Downloads *model* into *folder* and asserts that the expected artifacts
    are present and that no Hugging Face cache folder leaked into the target.
    """
    dlclibrary.download_huggingface_model(model, str(folder))

    files = {p.name for p in folder.iterdir()}

    # Archive-based DLC model: config plus at least one snapshot file.
    if "pose_cfg.yaml" in files:
        assert any(name.startswith("snapshot-") for name in files)

    # Direct PyTorch model.
    else:
        assert any(name.endswith((".pt", ".pth")) for name in files)

    # Verify that the Hugging Face cache folder was removed.
    assert not any(name.startswith("models--") for name in files)

def test_download_huggingface_model_tar_or_pt(tmp_path, mock_modelzoo):
    # "full_cat" may map to tar.gz or .pt depending on ordering;
    # the shared assertion helper supports both branches.
    download_dir = tmp_path / "download_one"
    download_dir.mkdir()
    _assert_download_success(download_dir, "full_cat")


def test_download_huggingface_wrong_model(mock_modelzoo):
    # An unknown model name must raise ValueError before any download happens.
    pytest.raises(
        ValueError, dlclibrary.download_huggingface_model, "wrong_model_name"
    )

Expand All @@ -40,6 +138,34 @@ def test_parse_superanimal_models():


@pytest.mark.parametrize("model", MODELOPTIONS)
def test_download_all_models(tmp_path, mock_modelzoo, model):
    """Every advertised model downloads successfully against the mocked HF hub."""
    folder = tmp_path / model
    folder.mkdir()
    _assert_download_success(folder, model)


def test_download_with_rename_mapping_for_pt(tmp_path, mock_modelzoo):
    """
    Explicitly test rename_mapping for a .pt model.
    """
    # The mocked names alternate tar.gz/.pt; the first odd index is a .pt model.
    pt_model = next(
        (name for idx, name in enumerate(MODELOPTIONS) if idx % 2 == 1), None
    )
    assert pt_model is not None, "Expected at least one mocked .pt model"

    target = tmp_path / "rename_pt"
    target.mkdir()

    dlclibrary.download_huggingface_model(
        pt_model,
        str(target),
        rename_mapping="renamed_weights.pt",
    )

    downloaded = {entry.name for entry in target.iterdir()}
    assert "renamed_weights.pt" in downloaded
    assert not any(entry.startswith("models--") for entry in downloaded)
Loading