t5x.partitioning package#

Utilities for partitioning.

class t5x.partitioning.AxisNames(*names)[source]#

Tuple of strings specifying name for each axis.

We create a separate class for this so that JAX’s pytree utilities treat it as a leaf rather than as an ordinary tuple to be flattened into a pytree.

class t5x.partitioning.BasePartitioner(num_partitions=None, model_parallel_submesh=None, params_on_devices=True, backend=None, ici_mesh_shape=None, dcn_mesh_shape=None)[source]#

Interface for partitioning computations across hardware devices.

abstract compile(partitioned_fn, *args)[source]#

Compiles and returns the partitioned function, or the original.

Parameters:
  • partitioned_fn – The partitioned function.

  • *args – Sample arguments to the partitioned function matching the input shapes that will be passed to the compiled function.

Returns:

The compiled function, or the original if this partitioner does not support compilation.

property data_mesh_size#

Data mesh size.

Data mesh size is defined as the number of global devices involved in data parallelism. For example, given a global mesh (‘replica’: 2, ‘data’: 4, ‘model’: 2) where the ‘replica’ and ‘data’ axes are responsible for data parallelism, 2 * 4 = 8 devices are involved, i.e., the data mesh size is 8 (see the sketch below).

Returns:

The data mesh size, i.e., the number of global devices participating in data parallelism.
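
For illustration, the arithmetic above can be written out directly (a standalone sketch; the dictionary stands in for the example mesh shape and is not a real Mesh object):

    # Illustrative arithmetic only: 'replica' and 'data' carry data parallelism.
    mesh_shape = {'replica': 2, 'data': 4, 'model': 2}
    data_mesh_size = mesh_shape['replica'] * mesh_shape['data']
    print(data_mesh_size)  # 8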

property data_shard_id#

Data shard id for the current host.

Returns:

Index of data shard that will be sent to the current local host.

property data_shards#

Number of data shards.

For example, suppose we are dealing with 2 slices of df4x2 TPUs. The data pipeline needs to prepare and send one data shard to each local host, so we need 4 shards since there are 4 local hosts. How do we infer the number of hosts from mesh information? In this case, the global mesh is (‘replica’: 2, ‘data’: 8, ‘model’: 2), and each local host (i.e., df2x2) has the local mesh (‘replica’: 1, ‘data’: 4, ‘model’: 2). Dividing the global mesh by the local mesh gives the host count (see the sketch after this entry).

Returns:

Number of data shards. Each shard will be sent to one local host.
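
A sketch of the host-count arithmetic described above, using plain dictionaries in place of real mesh objects (illustrative only):

    # Divide the global mesh by the per-host local mesh to infer the number of
    # data shards (one per local host).
    global_mesh = {'replica': 2, 'data': 8, 'model': 2}
    local_mesh = {'replica': 1, 'data': 4, 'model': 2}
    num_shards = 1
    for axis in global_mesh:
      num_shards *= global_mesh[axis] // local_mesh[axis]
    print(num_shards)  # 4 data shards, one per local host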

get_data_layout(batch_size=None, host_index=None)[source]#

Returns filled DataLayout based on the partitioned model layout.

Parameters:
  • batch_size – if set, indicates the requested batch size. An exception will be raised if this batch size is not compatible with the layout. If not set, the batch size is inferred from the layout.

  • host_index – indicates the host index to use for the calculations; if not set, the JAX-provided one is used. Should be in the [0, num_hosts) interval, and the order should match the order of the corresponding CPU devices in jax.devices().

Returns:

Filled DataLayout structure.
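
A minimal usage sketch, assuming a PjitPartitioner with illustrative constructor arguments:

    from t5x import partitioning

    partitioner = partitioning.PjitPartitioner(num_partitions=1)
    data_layout = partitioner.get_data_layout(batch_size=128)
    # DataLayout fields: batch_size, shard_id, num_shards,
    # is_first_host_in_replica_set.
    print(data_layout.shard_id, data_layout.num_shards)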

get_local_chunk_info(global_shape, mesh_axes)[source]#

Returns the local chunk info for a given array shape and sharded axes.

get_logical_axes(train_state)[source]#

Returns a copy of TrainState with Optional[AxisNames] as leaves.

get_mesh_axes(train_state)[source]#

Returns a copy of TrainState with Optional[PartitionSpecs] as leaves.

move_params_to_devices(train_state, train_state_axes)[source]#

Moves the optimizer parameters to devices.

abstract partition(fn, in_axis_resources, out_axis_resources, static_argnums=(), donate_argnums=())[source]#

Partitions the computation using partitioner-specific implementation.

Parameters:
  • fn – the function to partition.

  • in_axis_resources

    Pytree of structure matching that of arguments to fn, with all actual arguments replaced by resource assignment specifications. It is also valid to specify a pytree prefix (e.g. one value in place of a whole subtree), in which case the leaves get broadcast to all values in that subtree. The valid resource assignment specifications are:

    None: the value will be replicated on all devices.

    PartitionSpec: a tuple of length at most equal to the rank of the partitioned value. Each element can be None, a mesh axis, or a tuple of mesh axes, and specifies the set of resources assigned to partition the value’s dimension matching its position in the spec.

  • out_axis_resources – Like in_axis_resources, but specifies resource assignment for function outputs.

  • static_argnums – an optional int or collection of ints that specify which positional arguments to treat as static (compile-time constant) in the partitioned function.

  • donate_argnums – an optional int or collection of ints that specify which argument buffers are “donated” to the computation. It is safe to donate argument buffers if you no longer need them once the computation has finished.

Returns:

A partitioned version of the input function.
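
A hedged sketch of how partition might be called on a concrete partitioner (PjitPartitioner shown here); the function, shapes, and axis choices are illustrative:

    import jax.numpy as jnp
    from jax.sharding import PartitionSpec as P
    from t5x import partitioning

    partitioner = partitioning.PjitPartitioner(num_partitions=1)

    def scale(x, factor):
      return x * factor

    # Shard the first argument and the output along the 'data' mesh axis;
    # replicate the second argument (None).
    partitioned_scale = partitioner.partition(
        scale,
        in_axis_resources=(P('data'), None),
        out_axis_resources=P('data'))

    y = partitioned_scale(jnp.ones((8, 16)), jnp.float32(2.0))

The result can then be passed to compile (documented above) along with sample arguments to obtain an ahead-of-time compiled version.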

class t5x.partitioning.BasePjitPartitioner(num_partitions=None, model_parallel_submesh=None, params_on_devices=True, backend=None, ici_mesh_shape=None, dcn_mesh_shape=None)[source]#

Partitioner that uses T5X version of jax.pjit.

compile(partitioned_fn, *args)[source]#

Compiles and returns the partitioned function, or the original.

Parameters:
  • partitioned_fn – The partitioned function.

  • *args – Sample arguments to the partitioned function matching the input shapes that will be passed to the compiled function.

Returns:

The compiled function, or the original if this partitioner does not support compilation.

partition(fn, in_axis_resources, out_axis_resources, static_argnums=(), donate_argnums=())[source]#

Partitions the computation using partitioner-specific implementation.

Parameters:
  • fn – the function to partition.

  • in_axis_resources

    Pytree of structure matching that of arguments to fn, with all actual arguments replaced by resource assignment specifications. It is also valid to specify a pytree prefix (e.g. one value in place of a whole subtree), in which case the leaves get broadcast to all values in that subtree. The valid resource assignment specifications are:

    None: the value will be replicated on all devices.

    PartitionSpec: a tuple of length at most equal to the rank of the partitioned value. Each element can be None, a mesh axis, or a tuple of mesh axes, and specifies the set of resources assigned to partition the value’s dimension matching its position in the spec.

  • out_axis_resources – Like in_axis_resources, but specifies resource assignment for function outputs.

  • static_argnums – an optional int or collection of ints that specify which positional arguments to treat as static (compile-time constant) in the partitioned function.

  • donate_argnums – an optional int or collection of ints that specify which argument buffers are “donated” to the computation. It is safe to donate argument buffers if you no longer need them once the computation has finished.

Returns:

A partitioned version of the input function.

class t5x.partitioning.DataLayout(batch_size, shard_id, num_shards, is_first_host_in_replica_set)[source]#

Represents data layout for the partitioned model.

class t5x.partitioning.LocalChunkInfo(slice: Tuple[slice, ...], replica_id: int)[source]#

Slice and replica index for a host’s local chunk of a global array.

class t5x.partitioning.LocalChunker(global_mesh)[source]#

Utility class to aid chunking of sharded arrays in multihost settings.

get_local_chunk_info(global_shape, mesh_axes)[source]#

Get the local chunk info for a given array shape and sharded axes.

Parameters:
  • global_shape – the global, unsharded shape of the array to chunk.

  • mesh_axes – a sequence of names (or None) of equal rank to global_shape that specifies which mesh dimensions the array is sharded along.

Returns:

LocalChunkInfo containing the logical slices of the array found on this host’s local devices, as well as the replica index for this chunk among chunks with the same slice. The latter is used to determine which host should write this chunk during checkpointing.
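
An illustrative sketch of querying this host’s chunk, assuming a default single-slice mesh:

    from t5x import partitioning

    mesh = partitioning.default_mesh(num_partitions=1)
    chunker = partitioning.LocalChunker(mesh)
    chunk_info = chunker.get_local_chunk_info(
        global_shape=(1024, 512), mesh_axes=('data', 'model'))
    print(chunk_info.slice)       # tuple of Python slices for this host's chunk
    print(chunk_info.replica_id)  # replica index among hosts with the same slice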

get_replica_id(sharded_mesh_axes)[source]#

Given mesh axes used for sharding, computes current host’s replica id.

To give an example, say there are two global axes: ‘data’ and ‘model’, and the mesh axes for sharding are (‘data’,), meaning an array is partitioned along the ‘data’ axis and replicated along the ‘model’ axis. The replica_id is the index of the current local host along the ‘model’ axis.

Parameters:

sharded_mesh_axes – the mesh axes for sharding.

Returns:

the index of the current local host along the non-sharding (i.e., replicated) axes.

get_shard_id(sharded_mesh_axes)[source]#

Given mesh axes used for sharding, computes current host’s shard id.

To give an example, say there are three global axes: ‘replica’, ‘data’, and ‘model’, and the mesh axes for sharding are (‘replica’, ‘data’), meaning an array is partitioned along the ‘replica’ and ‘data’ axes. The shard_id is the index of the current local host along the sharding axes (in this example, the ‘replica’ and ‘data’ axes).

More concretely, say we have 4 local hosts and use the ‘replica’ and ‘data’ axes for data parallelism (2 hosts along the ‘replica’ axis and 2 hosts along the ‘data’ axis). The host located at (‘replica’: 0, ‘data’: 0) is assigned data shard-0; host (‘replica’: 0, ‘data’: 1) gets shard-1; host (‘replica’: 1, ‘data’: 0) gets shard-2; and host (‘replica’: 1, ‘data’: 1) gets shard-3 (see the sketch after this entry).

Note: the host location along the ‘replica’ and ‘data’ axes, e.g., (‘replica’: 0, ‘data’: 0), is called the chunk_id and is stored in self._local_chunker.chunk_ids[axis].

Parameters:

sharded_mesh_axes – the mesh axes for sharding.

Returns:

the index of the current local host along the sharding axes.
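
The shard assignment in the 2x2 host example above can be sketched as a row-major flattening of the (‘replica’, ‘data’) chunk coordinates (illustrative arithmetic only):

    hosts_along_replica, hosts_along_data = 2, 2
    for r in range(hosts_along_replica):
      for d in range(hosts_along_data):
        shard_id = r * hosts_along_data + d
        print(f"host at (replica={r}, data={d}) -> shard {shard_id}")
    # prints shard 0, 1, 2, 3 as described above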

class t5x.partitioning.PjitPartitioner(num_partitions=None, model_parallel_submesh=None, params_on_devices=True, backend=None, ici_mesh_shape=None, dcn_mesh_shape=None, logical_axis_rules=None)[source]#

Partitioner that uses named axes and jax.pjit.
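
A minimal construction sketch; the argument values are illustrative:

    from t5x import partitioning

    partitioner = partitioning.PjitPartitioner(
        num_partitions=2,  # degree of model parallelism
        logical_axis_rules=partitioning.standard_logical_axis_rules())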

get_logical_axes(train_state)[source]#

Returns a copy of TrainState with Optional[AxisNames] as leaves.

get_mesh_axes(train_state)[source]#

Returns a copy of TrainState with Optional[PartitionSpecs] as leaves.

property logical_axis_rules#

Returns the logical axis rules.

partition(fn, in_axis_resources, out_axis_resources, static_argnums=(), donate_argnums=())[source]#

Partitions the function using jax.pjit.

class t5x.partitioning.PjittedFnWithContext(pjitted_fn, partition_mesh, logical_axis_rules=())[source]#

Wraps pjitted function to apply the appropriate contexts.

t5x.partitioning.bounds_from_last_device(last_device)[source]#

Get the bound from the given last device.

t5x.partitioning.default_mesh(num_partitions, model_parallel_submesh=None, backend=None, ici_mesh_shape=None, dcn_mesh_shape=None)[source]#

Attempt to return a default mesh for simple cases.

Parameters:
  • num_partitions – number of partitions to use, will be ignored if model_parallel_submesh is provided.

  • model_parallel_submesh – 4-tuple that specifies the x,y,z,c submesh to use as the model-parallel device tile.

  • backend – get devices from the pinned backend, if specified. This is useful for explicitly specifying the devices other than relying on jax_platform_name.

  • ici_mesh_shape – Shape of the logical mesh used for SPMD parallelism in each slice. The meaning of each mesh axis is defined by mesh_axis_names, so these two params must be the same length. If dcn_mesh_shape is present, the overall mesh is the product of ici_mesh_shape and dcn_mesh_shape. For example, an ici_mesh_shape of [2, 3, 4] with mesh_axis_names [‘replica’, ‘data’, ‘model’] indicates 2-way replica parallelism, 3-way data parallelism, and 4-way model parallelism over 24 devices. None, the default, is equivalent to a sequence of ones and means that the model is placed on a single device.

  • dcn_mesh_shape – Shape of the logical mesh used for SPMD parallelism over multiple slices. The overall mesh is the product of ici_mesh_shape and dcn_mesh_shape, and the meaning of each mesh axis is defined by mesh_axis_names, so these three params must be the same length.

Returns:

An xmap/pjit 2D Mesh with ‘data’ and ‘model’ mesh axes if single-slice; otherwise a 3D Mesh with ‘replica’, ‘data’, and ‘model’ mesh axes.
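
An illustrative call for the simplest single-slice case:

    from t5x import partitioning

    mesh = partitioning.default_mesh(num_partitions=1)
    print(mesh.axis_names)  # ('data', 'model') in the single-slice case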

t5x.partitioning.get_coords(device)[source]#

Returns the coordinates of the given device.

t5x.partitioning.get_cpu_mesh()[source]#

Trivial mesh for CPU testing.

t5x.partitioning.get_gpu_mesh(num_partitions)[source]#

Mesh for GPUs that preferentially places ‘model’ on NVLink.

t5x.partitioning.get_mesh(model_parallel_submesh, input_devices=(), input_local_devices=(), tile_by_host_if_needed=True, backend=None)[source]#

Construct an xmap/pjit Mesh for the given model-parallel submesh.

The resulting mesh has two resource axes: ‘model’, with the provided submesh shape, and ‘data’, which covers the rest of the mesh.

Parameters:
  • model_parallel_submesh

    a HardwareMesh spec, namely (x, y, z, core) on TPU for a single model-parallel replica’s “tile” in the physical device mesh. The first three elements (x, y, and z) should be factors of the pod slice; e.g., if you are using df_4x8, then x should be a factor of 4 (one of 1, 2, 4), y should be a factor of 8 (one of 1, 2, 4, 8), and z must be 1, because TPU v3 slices are only 2D. z can be >1 for TPU v4 (and maybe later TPUs) that allow 3D slices. core is the number of cores to use from each TPU node. As communication is usually fastest inside the same node, if you need a tile of more than 1 core, you should first increase core: e.g., for TPU v3, (1, 1, 1, 2) is better than (2, 1, 1, 1). To pick a good spec, try a few possible values until you get high TPU utilization.

  • input_devices – the devices to use, will use jax.devices() if this is not set.

  • input_local_devices – the local devices to use, will use jax.local_devices() if this is not set.

  • tile_by_host_if_needed – JAX currently requires that the parts of any sharded array that are located on one host’s local devices form a single contiguous slice. A best effort will be made to achieve this without “tiling” the device assignment over hosts (which can reduce XLA collective performance). If this flag is True, then the device assignment will be tiled over hosts if necessary to satisfy this constraint and create a buildable mesh; if false, mesh construction will fail instead.

  • backend – get devices from the pinned backend, if specified. This is useful for explicitly specifying the devices other than relying on jax_platform_name.

Returns:

An xmap/pjit Mesh containing the virtual device mesh with ‘data’ and ‘model’ axes.

t5x.partitioning.global_mesh_defined()[source]#

Checks if global xmap/pjit mesh resource environment is defined.

t5x.partitioning.standard_logical_axis_rules(activation_partitioning_dims=1, parameter_partitioning_dims=1, additional_rules=None)[source]#

Default sharding rules for T5X model in terms of logical axis names.

Parameters:
  • activation_partitioning_dims – enables 2-D activation sharding when set to 2.

  • parameter_partitioning_dims – enables 2-D parameter sharding when set to 2.

  • additional_rules – additional rules (a sequence of tuples) that will be appended to the standard rules.

Returns:

Sequence of logical axis rules.
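
A hedged sketch of extending the defaults; the custom logical axis name below is hypothetical:

    from t5x import partitioning

    rules = partitioning.standard_logical_axis_rules(
        activation_partitioning_dims=2,  # enable 2-D activation sharding
        additional_rules=[('my_custom_axis', 'model')])  # hypothetical logical axis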

t5x.partitioning.with_sharding_constraint(x, axis_resources)[source]#

Wrapper for lax.with_sharding_constraint; a no-op on CPU or outside pjit.
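
A hedged sketch of annotating an intermediate value; the mesh axis name assumes a ‘data’ axis exists in the active mesh:

    from jax.sharding import PartitionSpec as P
    from t5x import partitioning

    def fn(x):
      y = x * 2
      # Constrain y to be sharded along the 'data' mesh axis; this is a no-op
      # on CPU or outside of a pjit/mesh context.
      y = partitioning.with_sharding_constraint(y, P('data'))
      return y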