Coverage for dantinox / hub.py: 34%
29 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-02 10:39 +0200
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-02 10:39 +0200
"""
HuggingFace Hub integration for DantinoX.

Push a trained checkpoint to the Hub and pull it back on any machine.

Examples
--------
CLI:
    dantinox push --run_dir runs/my_run --repo my-org/dantinox-dante
    dantinox pull --repo my-org/dantinox-dante --local_dir runs/pulled

Python API:
    from dantinox.hub import push, pull

    url = push("runs/my_run", "my-org/dantinox-dante", private=True)
    run_dir = pull("my-org/dantinox-dante")
    gen = Generator(run_dir)
"""
20from __future__ import annotations
22import logging
23import os
log = logging.getLogger(__name__)

# Glob patterns that push() asks the Hub client to skip:
# run log files and Python bytecode caches.
_UPLOAD_IGNORE = [
    "*.log",
    "__pycache__/",
    "*.pyc",
]
30def resolve_checkpoint(
31 path_or_repo: str,
32 *,
33 token: str | None = None,
34 revision: str | None = None,
35) -> str:
36 """Return a local directory path for *path_or_repo*.
38 If *path_or_repo* is an existing local directory it is returned unchanged.
39 Otherwise it is treated as a HuggingFace Hub repo ID (e.g.
40 ``"my-org/dantinox-dante"``) and the checkpoint is downloaded via
41 :func:`pull` before returning the local cache path.
43 Parameters
44 ----------
45 path_or_repo:
46 Local run directory **or** HuggingFace Hub repo ID.
47 token:
48 HuggingFace access token for private repositories.
49 revision:
50 Branch, tag, or commit SHA to download.
52 Returns
53 -------
54 str
55 Absolute path to a local directory suitable for passing to
56 ``Generator()``, ``Transformer.from_pretrained()``, etc.
57 """
58 if os.path.isdir(path_or_repo):
59 return path_or_repo
60 return pull(path_or_repo, token=token, revision=revision)
63def push(
64 run_dir: str,
65 repo_id: str,
66 *,
67 private: bool = False,
68 token: str | None = None,
69 commit_message: str | None = None,
70) -> str:
71 """
72 Upload a run directory to a HuggingFace Hub model repository.
74 Creates the repository if it does not exist. Only the core checkpoint
75 files are uploaded (``config.yaml``, ``tokenizer.json``,
76 ``model_weights.msgpack``, ``best_model_weights.msgpack``,
77 ``model_summary.json``). Log files are excluded.
79 Parameters
80 ----------
81 run_dir : str
82 Local path to a DantinoX run directory.
83 repo_id : str
84 Hub repository in the form ``"owner/repo-name"``.
85 private : bool
86 Create the repository as private (default False).
87 token : str, optional
88 HuggingFace access token. Falls back to the ``HF_TOKEN``
89 environment variable or the cached login token.
90 commit_message : str, optional
91 Commit message for the upload (auto-generated if omitted).
93 Returns
94 -------
95 str
96 URL of the Hub repository after the upload.
98 Raises
99 ------
100 ImportError
101 If ``huggingface_hub`` is not installed.
102 """
103 try:
104 from huggingface_hub import HfApi
105 except ImportError as exc:
106 raise ImportError(
107 "huggingface_hub is required for Hub integration: "
108 "pip install huggingface-hub"
109 ) from exc
111 import os
112 msg = commit_message or f"Upload DantinoX checkpoint from {os.path.basename(run_dir)}"
114 api = HfApi(token=token)
115 api.create_repo(repo_id, repo_type="model", private=private, exist_ok=True)
117 url = api.upload_folder(
118 folder_path=run_dir,
119 repo_id=repo_id,
120 repo_type="model",
121 commit_message=msg,
122 ignore_patterns=_UPLOAD_IGNORE,
123 )
125 log.info("Pushed %s → %s", run_dir, url)
126 return str(url)
129def pull(
130 repo_id: str,
131 *,
132 local_dir: str | None = None,
133 token: str | None = None,
134 revision: str | None = None,
135) -> str:
136 """
137 Download a DantinoX checkpoint from HuggingFace Hub.
139 Parameters
140 ----------
141 repo_id : str
142 Hub repository in the form ``"owner/repo-name"``.
143 local_dir : str, optional
144 Where to store the downloaded files. Defaults to the HuggingFace
145 cache directory (``~/.cache/huggingface/hub/...``).
146 token : str, optional
147 HuggingFace access token for private repositories.
148 revision : str, optional
149 Git revision (branch, tag, or commit SHA) to download.
151 Returns
152 -------
153 str
154 Path to the local directory containing the checkpoint. Pass this
155 directly to ``Generator(run_dir)``.
157 Raises
158 ------
159 ImportError
160 If ``huggingface_hub`` is not installed.
161 """
162 try:
163 from huggingface_hub import snapshot_download
164 except ImportError as exc:
165 raise ImportError(
166 "huggingface_hub is required for Hub integration: "
167 "pip install huggingface-hub"
168 ) from exc
170 run_dir: str = snapshot_download(
171 repo_id=repo_id,
172 repo_type="model",
173 local_dir=local_dir,
174 token=token,
175 revision=revision,
176 )
178 log.info("Pulled %s → %s", repo_id, run_dir)
179 return run_dir