Coverage for dantinox / hub.py: 34%

29 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-02 10:39 +0200

"""
HuggingFace Hub integration for DantinoX.

Push a trained checkpoint to the Hub and pull it back on any machine.

Examples
--------
CLI:
    dantinox push --run_dir runs/my_run --repo my-org/dantinox-dante
    dantinox pull --repo my-org/dantinox-dante --local_dir runs/pulled

Python API:
    from dantinox.hub import push, pull

    url = push("runs/my_run", "my-org/dantinox-dante", private=True)
    run_dir = pull("my-org/dantinox-dante")
    gen = Generator(run_dir)
"""

19 

20from __future__ import annotations 

21 

22import logging 

23import os 

24 

# Logger for this module, named after the module's dotted import path.
log = logging.getLogger(__name__)

# Glob patterns passed to ``upload_folder(ignore_patterns=...)`` by ``push``:
# log files and Python bytecode caches are left out of uploaded checkpoints.
_UPLOAD_IGNORE = [
    "*.log",
    "__pycache__/",
    "*.pyc",
]

28 

29 

30def resolve_checkpoint( 

31 path_or_repo: str, 

32 *, 

33 token: str | None = None, 

34 revision: str | None = None, 

35) -> str: 

36 """Return a local directory path for *path_or_repo*. 

37 

38 If *path_or_repo* is an existing local directory it is returned unchanged. 

39 Otherwise it is treated as a HuggingFace Hub repo ID (e.g. 

40 ``"my-org/dantinox-dante"``) and the checkpoint is downloaded via 

41 :func:`pull` before returning the local cache path. 

42 

43 Parameters 

44 ---------- 

45 path_or_repo: 

46 Local run directory **or** HuggingFace Hub repo ID. 

47 token: 

48 HuggingFace access token for private repositories. 

49 revision: 

50 Branch, tag, or commit SHA to download. 

51 

52 Returns 

53 ------- 

54 str 

55 Absolute path to a local directory suitable for passing to 

56 ``Generator()``, ``Transformer.from_pretrained()``, etc. 

57 """ 

58 if os.path.isdir(path_or_repo): 

59 return path_or_repo 

60 return pull(path_or_repo, token=token, revision=revision) 

61 

62 

63def push( 

64 run_dir: str, 

65 repo_id: str, 

66 *, 

67 private: bool = False, 

68 token: str | None = None, 

69 commit_message: str | None = None, 

70) -> str: 

71 """ 

72 Upload a run directory to a HuggingFace Hub model repository. 

73 

74 Creates the repository if it does not exist. Only the core checkpoint 

75 files are uploaded (``config.yaml``, ``tokenizer.json``, 

76 ``model_weights.msgpack``, ``best_model_weights.msgpack``, 

77 ``model_summary.json``). Log files are excluded. 

78 

79 Parameters 

80 ---------- 

81 run_dir : str 

82 Local path to a DantinoX run directory. 

83 repo_id : str 

84 Hub repository in the form ``"owner/repo-name"``. 

85 private : bool 

86 Create the repository as private (default False). 

87 token : str, optional 

88 HuggingFace access token. Falls back to the ``HF_TOKEN`` 

89 environment variable or the cached login token. 

90 commit_message : str, optional 

91 Commit message for the upload (auto-generated if omitted). 

92 

93 Returns 

94 ------- 

95 str 

96 URL of the Hub repository after the upload. 

97 

98 Raises 

99 ------ 

100 ImportError 

101 If ``huggingface_hub`` is not installed. 

102 """ 

103 try: 

104 from huggingface_hub import HfApi 

105 except ImportError as exc: 

106 raise ImportError( 

107 "huggingface_hub is required for Hub integration: " 

108 "pip install huggingface-hub" 

109 ) from exc 

110 

111 import os 

112 msg = commit_message or f"Upload DantinoX checkpoint from {os.path.basename(run_dir)}" 

113 

114 api = HfApi(token=token) 

115 api.create_repo(repo_id, repo_type="model", private=private, exist_ok=True) 

116 

117 url = api.upload_folder( 

118 folder_path=run_dir, 

119 repo_id=repo_id, 

120 repo_type="model", 

121 commit_message=msg, 

122 ignore_patterns=_UPLOAD_IGNORE, 

123 ) 

124 

125 log.info("Pushed %s → %s", run_dir, url) 

126 return str(url) 

127 

128 

129def pull( 

130 repo_id: str, 

131 *, 

132 local_dir: str | None = None, 

133 token: str | None = None, 

134 revision: str | None = None, 

135) -> str: 

136 """ 

137 Download a DantinoX checkpoint from HuggingFace Hub. 

138 

139 Parameters 

140 ---------- 

141 repo_id : str 

142 Hub repository in the form ``"owner/repo-name"``. 

143 local_dir : str, optional 

144 Where to store the downloaded files. Defaults to the HuggingFace 

145 cache directory (``~/.cache/huggingface/hub/...``). 

146 token : str, optional 

147 HuggingFace access token for private repositories. 

148 revision : str, optional 

149 Git revision (branch, tag, or commit SHA) to download. 

150 

151 Returns 

152 ------- 

153 str 

154 Path to the local directory containing the checkpoint. Pass this 

155 directly to ``Generator(run_dir)``. 

156 

157 Raises 

158 ------ 

159 ImportError 

160 If ``huggingface_hub`` is not installed. 

161 """ 

162 try: 

163 from huggingface_hub import snapshot_download 

164 except ImportError as exc: 

165 raise ImportError( 

166 "huggingface_hub is required for Hub integration: " 

167 "pip install huggingface-hub" 

168 ) from exc 

169 

170 run_dir: str = snapshot_download( 

171 repo_id=repo_id, 

172 repo_type="model", 

173 local_dir=local_dir, 

174 token=token, 

175 revision=revision, 

176 ) 

177 

178 log.info("Pulled %s → %s", repo_id, run_dir) 

179 return run_dir