Coverage for core / sharding.py: 95%

20 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-01 21:16 +0200

1from __future__ import annotations 

2 

3from typing import TypeVar 

4 

5import jax 

6import numpy as np 

7from jax.sharding import Mesh, NamedSharding 

8from jax.sharding import PartitionSpec as P 

9 

# Generic type variable for the pytree helpers: each is annotated to return the
# same type of pytree it was given.
_T = TypeVar("_T")

11 

12 

def make_mesh(n_devices: int = 0) -> Mesh:
    """Create a 1-D data-parallel mesh over this process's local devices.

    Parameters
    ----------
    n_devices:
        Number of devices to use, taken from the front of
        ``jax.local_devices()``. 0 (default) means all available local devices.

    Returns
    -------
    Mesh
        A mesh with a single axis named ``"data"``.

    Raises
    ------
    ValueError
        If *n_devices* is negative or larger than the number of local devices.
    """
    devices = jax.local_devices()
    # Reject impossible requests explicitly instead of silently falling back to
    # "all devices": a smaller-than-requested mesh would quietly change how a
    # batch is split across devices.
    if n_devices < 0:
        raise ValueError(f"n_devices must be >= 0, got {n_devices}")
    if n_devices > len(devices):
        raise ValueError(
            f"requested {n_devices} devices but only {len(devices)} "
            "local devices are available"
        )
    if n_devices:
        devices = devices[:n_devices]
    return Mesh(np.array(devices), axis_names=("data",))

25 

26 

def replicate(pytree: _T, mesh: Mesh) -> _T:
    """Place a complete copy of every leaf of *pytree* on each device of *mesh*.

    An empty ``PartitionSpec`` splits no array axis, so each device ends up
    holding the whole value.
    """
    return jax.device_put(pytree, NamedSharding(mesh, P()))

31 

32 

def shard_batch(pytree: _T, mesh: Mesh) -> _T:
    """Split every leaf of *pytree* along axis 0 over the ``"data"`` mesh axis.

    Trailing dimensions are not partitioned. NOTE(review): presumably each
    leaf needs at least one dimension, and a leading size that JAX can divide
    across the mesh's devices; confirm against JAX's sharding rules.
    """
    batch_axis_spec = P("data")
    return jax.device_put(pytree, NamedSharding(mesh, batch_axis_spec))

37 

38 

def num_devices(mesh: Mesh) -> int:
    """Count the devices in *mesh*, i.e. the product of all its axis sizes."""
    device_grid = mesh.devices
    return device_grid.size