Coverage for core / sharding.py: 95%
20 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 21:16 +0200
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 21:16 +0200
1from __future__ import annotations
3from typing import TypeVar
5import jax
6import numpy as np
7from jax.sharding import Mesh, NamedSharding
8from jax.sharding import PartitionSpec as P
# Generic type variable: replicate() and shard_batch() are annotated to return
# the same type of pytree they are given.
_T = TypeVar("_T")
def make_mesh(n_devices: int = 0) -> Mesh:
    """Build a one-dimensional data-parallel mesh with a single axis ``"data"``.

    Parameters
    ----------
    n_devices:
        How many local devices to place in the mesh, taken in the order
        returned by ``jax.local_devices()``. ``0`` (the default) selects every
        local device. Any negative value, or any value not smaller than the
        number of local devices, also falls back to every local device.
    """
    available = jax.local_devices()
    # Only a strictly in-range request trims the device list; everything else
    # keeps the full list.
    take = n_devices if 0 < n_devices < len(available) else len(available)
    return Mesh(np.array(available[:take]), axis_names=("data",))
def replicate(pytree: _T, mesh: Mesh) -> _T:
    """Place a full copy of *pytree* on each device of *mesh*.

    Leaves get an empty ``PartitionSpec``, so no array axis is split across
    the mesh.
    """
    return jax.device_put(pytree, NamedSharding(mesh, P()))
def shard_batch(pytree: _T, mesh: Mesh) -> _T:
    """Distribute *pytree* over *mesh* by splitting its leading (batch) axis.

    Axis 0 is partitioned along the mesh axis named ``"data"``; any further
    axes are left unpartitioned.

    NOTE(review): JAX is expected to reject leaves whose leading dimension is
    not divisible by the number of devices on the ``"data"`` axis. Confirm
    that callers size or pad batches accordingly.
    """
    along_batch = NamedSharding(mesh, P("data"))
    return jax.device_put(pytree, along_batch)
def num_devices(mesh: Mesh) -> int:
    """Report how many devices *mesh* spans, as given by ``Mesh.size``."""
    total: int = mesh.size
    return total