t5x.checkpoints package#

Utilities for reading and writing sharded checkpoints.

The checkpointing utilities here can be used in two ways. The first is to use the Checkpointer class. This requires having an optimizer and various partitioning utilities set up, but allows for reading and writing of partitioned parameters. It also allows different hosts to read different parameter partitions in a multi-host setup, which results in much faster reads. This is normally used during training, where you have already created an optimizer based on a config.

The second way is to use the load_t5x_checkpoint function. This doesn’t require an optimizer to be provided up front, so it is useful for things like debugging and analysis of learned weights. However, this means partitioned reads are not possible, so loading will be slower than with the Checkpointer class.
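For instance, a quick weight-inspection workflow might look like the following minimal sketch (the path and step number are placeholders):

    from t5x import checkpoints

    # Load a checkpoint into a nested dict without constructing an optimizer.
    # Reads are not partitioned, so this is slower than Checkpointer.restore.
    state = checkpoints.load_t5x_checkpoint('/path/to/model_dir', step=100000)

    # In T5X checkpoints the model parameters typically live under 'target',
    # alongside the optimizer state.
    params = state['target']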

class t5x.checkpoints.CheckpointManagerConstructor(*args, **kwargs)[source]#

A function that returns a checkpoints.CheckpointManager.

This type annotation allows users to partially bind args to the constructors of CheckpointManager subclasses without triggering type errors.

class t5x.checkpoints.Checkpointer(train_state, partitioner, checkpoints_dir, dataset_iterator=None, *, keep=None, save_dtype=<class 'numpy.float32'>, restore_dtype=None, keep_dataset_checkpoints=None)[source]#

Handles saving and restoring potentially-sharded T5X checkpoints.

Checkpoints are stored using a combination of msgpack (via flax.serialization) and TensorStore.

Parameters (and other objects) that are not partitioned are written to the msgpack binary directly (by host 0). Partitioned parameters are each written to their own TensorStore, with each host writing their portion to the same TensorStore in parallel. If a partition is written on multiple hosts, the partition is further sharded across these replicas to avoid additional overhead. In place of the parameter, a tensorstore.Spec is written to the msgpack (by host 0) as a reference to be used during restore. Note that the path of the array being written is relative. This makes the checkpoints portable. In other words, even if the checkpoint files are moved to a new directory, they can still be loaded. Because the path is relative, the checkpoint directory information has to be dynamically provided. This is done by _update_ts_path_from_relative_to_absolute.

For the TensorStore driver that uses the Google Cloud Storage (GCS) key-value store, the GCS bucket information is necessary. When a checkpoint is written using the gcs driver, we don’t want to hardcode the bucket information in the resulting file, in order to keep the checkpoint portable. Therefore, we use a dummy bucket name of “t5x-dummy-bucket”. When reading or writing the checkpoint, the bucket information is parsed from the checkpoint directory and dynamically updated.
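As an illustration only (the spec layout below is a hand-written sketch, not the exact on-disk format produced by TensorStore), the idea is that the stored spec references the dummy bucket and is rewritten once the real checkpoint location is known:

    # Illustrative spec as it might appear in the msgpack file (simplified):
    stored_spec = {
        'driver': 'zarr',
        'kvstore': {'driver': 'gcs', 'bucket': 't5x-dummy-bucket'},
        'path': 'target.encoder.relpos_bias.rel_embedding',
    }

    # At read/write time the actual bucket is parsed from the checkpoint
    # directory (e.g. 'gs://my-bucket/model_dir/checkpoint_100000') and the
    # dummy value is replaced before the spec is handed to TensorStore.
    stored_spec['kvstore']['bucket'] = 'my-bucket'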

checkpoints_dir#

a path to a directory to save checkpoints in and restore them from.

keep#

an optional maximum number of checkpoints to keep. If more than this number of checkpoints exist after a save, the oldest ones will be automatically deleted to save space.

restore_dtype#

optional dtype to cast targets to after restoring.

save_dtype#

dtype to cast targets to before saving.

keep_dataset_checkpoints#

an optional maximum number of data iterators to keep. If more than this number of data iterators exist after a save, the oldest ones will be automatically deleted to save space.
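A minimal construction sketch, assuming a TrainState and partitioner have already been built elsewhere (train_state and partitioner below are placeholders):

    import numpy as np
    from t5x import checkpoints

    checkpointer = checkpoints.Checkpointer(
        train_state=train_state,        # placeholder: a t5x TrainState
        partitioner=partitioner,        # placeholder: a t5x partitioner
        checkpoints_dir='/tmp/t5x_model_dir',
        keep=3,                         # keep at most the 3 most recent checkpoints
        save_dtype=np.float32,          # cast targets to float32 before saving
    )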

all_dataset_checkpoint_steps()[source]#

Returns list of available step numbers in ascending order.

all_steps()[source]#

Returns list of available step numbers in ascending order.

convert_from_tf_checkpoint(path_or_dir, *, state_transformation_fns=(), concurrent_gb=16, translator=None)[source]#

Convert from a TensorFlow-based T5 checkpoint.

latest_step()[source]#

Returns latest step number or None if no checkpoints exist.

restore(step=None, path=None, state_transformation_fns=(), fallback_state=None, lazy_parameters=False)[source]#

Restores the host-specific parameters in an Optimizer.

Either step or path can be specified, but not both. If neither are specified, restores from the latest checkpoint in the checkpoints directory.

Parameters:
  • step – the optional step number to restore from.

  • path – an optional absolute path to a checkpoint file to restore from.

  • state_transformation_fns – Transformations to apply, in order, to the state after reading.

  • fallback_state – a state dict of an optimizer to fall back to for loading params that do not exist in the checkpoint (after applying all state_transformation_fns), but do exist in Checkpointer.optimizer. The union of fallback_state and state loaded from the checkpoint must match Checkpointer.optimizer.

  • lazy_parameters – whether to load the parameters as LazyArrays to preserve memory.

Returns:

The restored train state.

Raises:
  • ValueError – if both step and path are specified.

  • ValueError – if the checkpoint at path or step does not exist.

  • ValueError – if neither step nor path is specified and no checkpoint is found in the checkpoints directory.
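A minimal usage sketch (the step number is a placeholder; checkpointer is a Checkpointer as constructed above):

    # Restore from the latest checkpoint in checkpoints_dir.
    train_state = checkpointer.restore()

    # Restore a specific step, loading parameters lazily to reduce host memory use.
    train_state = checkpointer.restore(step=100000, lazy_parameters=True)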

restore_from_tf_checkpoint(path_or_dir, strict=True, translator=None)[source]#

Restore from a TensorFlow-based T5 checkpoint.

save(train_state, state_transformation_fns=(), *, concurrent_gb=128)[source]#

Saves a checkpoint for the given train state.

Parameters:
  • train_state – the train state to save. May contain a combination of LazyArray objects and arrays (e.g., np.ndarray, jax.DeviceArray)

  • state_transformation_fns – Transformations to apply, in order, to the state before writing.

  • concurrent_gb – the approximate number of gigabytes of partitionable parameters to process in parallel. Useful to preserve RAM.
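A minimal usage sketch (train_state is the placeholder TrainState from above):

    # Unpartitioned values are written to the msgpack file by host 0; partitioned
    # parameters are written in parallel to per-parameter TensorStores.
    checkpointer.save(train_state)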

class t5x.checkpoints.CheckpointerConstructor(*args, **kwargs)[source]#

A function that returns a checkpoints.Checkpointer.

This type annotation allows users to partially bind args to the constructors of Checkpointer subclasses without triggering type errors.

class t5x.checkpoints.DatasetCheckpointHandler(checkpoint_filename)[source]#

A CheckpointHandler implementation that handles tf.data.Iterator.

restore(directory, item=None)[source]#

Restores the given item.

Parameters:
  • directory – restore location directory.

  • item – a tf.data.Iterator to be restored. Required despite the None default.

Returns:

a tf.data.Iterator restored from directory.

save(directory, item)[source]#

Saves the given item.

Parameters:
  • directory – save location directory.

  • item – a tf.data.Iterator to be saved.

structure(directory)[source]#

Unimplemented. See parent class.

class t5x.checkpoints.OrbaxCheckpointManagerInterface(directory, train_state, partitioner, dataset_iterator=None, save_dtype=None, restore_dtype=None, keep=None, period=1, checkpoint_steps=None, keep_dataset_checkpoints=None, force_keep_period=None, metric_name_to_monitor=None, metric_mode='max', keep_checkpoints_without_metrics=True)[source]#

Wrapper for ocp.CheckpointManager.

restore(step=None, path=None, fallback_state=None, state_transformation_fns=(), lazy_parameters=False)[source]#

Restores a TrainState from the given step or path.

Note: can only provide one of step or path.

Parameters:
  • step – the step number to restore from.

  • path – the full path to restore from.

  • fallback_state – a state dict of an optimizer to fall back to for loading params that do not exist in the checkpoint (after applying all state_transformation_fns), but do exist in Checkpointer.optimizer. The union of fallback_state and state loaded from the checkpoint must match Checkpointer.optimizer.

  • state_transformation_fns – Transformations to apply, in order, to the state after reading.

  • lazy_parameters – whether to load the parameters as LazyArrays to preserve memory.

Returns:

The restored train state.

restore_from_tf_checkpoint(path_or_dir, strict=True, translator=None)[source]#

Restore from a TensorFlow-based T5 checkpoint.

save(train_state, state_transformation_fns=(), force=True)[source]#

Saves a checkpoint for the given train state.

Parameters:
  • train_state – the train state to save. May contain a combination of LazyArray objects and arrays (e.g., np.ndarray, jax.DeviceArray)

  • state_transformation_fns – Transformations to apply, in order, to the state before writing.

  • force – Saves regardless of whether should_save is False. Defaults to True because the should_save logic is handled externally to this class in T5X, due to a feature that decouples the actual step from the step offset.

Returns:

Whether the save was performed or not.
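A minimal usage sketch, again assuming placeholder train_state and partitioner objects and a hypothetical step number:

    from t5x import checkpoints

    manager = checkpoints.OrbaxCheckpointManagerInterface(
        directory='/tmp/t5x_model_dir',   # placeholder path
        train_state=train_state,
        partitioner=partitioner,
        keep=5,
    )

    saved = manager.save(train_state)          # returns whether a save happened
    restored = manager.restore(step=100000)    # step is a placeholder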

class t5x.checkpoints.RestoreStateTransformationFn(*args, **kwargs)[source]#
class t5x.checkpoints.SaveBestCheckpointer(train_state, partitioner, checkpoints_dir, dataset_iterator=None, *, keep=None, save_dtype=<class 'numpy.float32'>, restore_dtype=None, metric_name_to_monitor='train/accuracy', metric_mode='max', keep_checkpoints_without_metrics=True, force_keep_period=None, keep_dataset_checkpoints=None)[source]#

A Checkpointer class that keeps checkpoints based on ‘best’ metrics.

This extends the standard Checkpointer to garbage collect checkpoints based on metric values, instead of step recency. It uses TensorBoard summary files to determine the best values for a given, user-configured metric name. Events are read and parsed using TensorBoard’s event_processing packages.

The metric name must be of the form {run_name}/{tag_name}. For example, ‘train/accuracy’ or ‘inference_eval/glue_cola_v002/eval/accuracy’.

A few important features of this checkpointer:

  • Fallback behavior. It is not possible to verify whether metric names are valid during initialization, since some metrics may only get written out after some time (e.g., during an evaluation). As such, when user-provided metric names are not found, this checkpointer can be configured with two fallback strategies: (1) if keep_checkpoints_without_metrics is False, we fall back to the “most recent checkpoint” strategy of the standard checkpointer; (2) if keep_checkpoints_without_metrics is True, we keep all checkpoints until metrics become available (potentially indefinitely if summary files have been deleted or corrupted).

  • The number of checkpoints to keep is always increased by 1. Since it’s crucial to always keep the latest checkpoint (for recovery purposes), we always store the latest checkpoint plus the keep best checkpoints.

  • It is assumed that TensorBoard summary (event) files share a common root directory with checkpoint_dir, which is the directory passed to the logdir crawler that searches for event files.

checkpoints_dir#

a path to a directory to save checkpoints in and restore them from.

keep#

an optional maximum number of checkpoints to keep. If more than this number of checkpoints exist after a save, the oldest ones will be automatically deleted to save space.

restore_dtype#

optional dtype to cast targets to after restoring.

save_dtype#

dtype to cast targets to before saving.

metric_name_to_monitor#

Name of metric to monitor. Must be in the format {run_name}/{tag_name} (e.g., ‘train/accuracy’, ‘inference_eval/glue_cola_v002/eval/accuracy’).

metric_mode#

Mode to use to compare metric values. One of ‘max’ or ‘min’.

keep_checkpoints_without_metrics#

Whether to always keep (or delete) checkpoints for which a metric value has not been found.

force_keep_period#

When removing checkpoints, skip those whose step is divisible by force_keep_period (step % force_keep_period == 0).

keep_dataset_checkpoints#

an optional maximum number of data iterators to keep. If more than this number of data iterators exist after a save, the oldest ones will be automatically deleted to save space.
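A minimal construction sketch (train_state and partitioner are placeholders; the metric name must match a run/tag that actually appears in the TensorBoard summaries):

    from t5x import checkpoints

    best_checkpointer = checkpoints.SaveBestCheckpointer(
        train_state=train_state,
        partitioner=partitioner,
        checkpoints_dir='/tmp/t5x_model_dir',
        keep=3,                                   # 3 best checkpoints (+ the latest)
        metric_name_to_monitor='inference_eval/glue_cola_v002/eval/accuracy',
        metric_mode='max',
        keep_checkpoints_without_metrics=True,
        force_keep_period=10000,                  # never delete steps where step % 10000 == 0
    )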

class t5x.checkpoints.SaveStateTransformationFn(*args, **kwargs)[source]#
t5x.checkpoints.all_dataset_checkpoint_steps(checkpoints_dir)[source]#

Returns available dataset checkpoint step numbers in ascending order.

t5x.checkpoints.all_steps(checkpoints_dir)[source]#

Returns list of available step numbers in ascending order.

t5x.checkpoints.fake_param_info(maybe_tspec)[source]#

Create _ParameterInfo that results in a full read.

t5x.checkpoints.find_checkpoint(path, step=None)[source]#

Find the checkpoint file based on paths and steps.

Parameters:
  • path – The location of the checkpoint. Can point to the model_dir, the checkpoint dir with a step, or the actual checkpoint file.

  • step – The step to load. Only used if path points to the model_dir.

Raises:

ValueError – if the checkpoint file can't be found.

Returns:

The path to the checkpoint file.
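For example (paths are placeholders; the checkpoint_100000 directory name follows the usual T5X layout):

    from t5x import checkpoints

    # Point at a model_dir and pass the step explicitly...
    ckpt = checkpoints.find_checkpoint('/tmp/t5x_model_dir', step=100000)

    # ...or point directly at a checkpoint directory, where step is not needed.
    ckpt = checkpoints.find_checkpoint('/tmp/t5x_model_dir/checkpoint_100000')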

t5x.checkpoints.get_checkpoint_dir(checkpoints_dir, step, step_format_fixed_length=None)[source]#

Returns path to a checkpoint dir given a parent directory and step.

t5x.checkpoints.get_local_data(x)[source]#

Get local buffer for input data.

t5x.checkpoints.get_step_from_checkpoint_dir(checkpoints_dir)[source]#

Returns a step number and the parent directory.

t5x.checkpoints.latest_step(checkpoints_dir)[source]#

Returns latest step number or None if no checkpoints exist.

t5x.checkpoints.load_t5x_checkpoint(path, step=None, state_transformation_fns=(), remap=True, restore_dtype=None, lazy_parameters=False)[source]#

Load a T5X checkpoint without pre-defining the optimizer.

Note

This only works for T5X checkpoints, not TF checkpoints.

Parameters:
  • path – The location of the checkpoint.

  • step – The checkpoint from which step should be loaded.

  • state_transformation_fns – Transformations to apply, in order, to the state after reading.

  • remap – Whether to rename the checkpoint variables to the newest version.

  • restore_dtype – optional dtype to cast targets to after restoring. If None, no parameter casting is performed.

  • lazy_parameters – whether to load the parameters as LazyArrays to preserve memory.

Returns:

A nested dictionary of weights and parameter states from the checkpoint.
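A slightly fuller sketch for inspecting a checkpoint's contents (the path is a placeholder; flatten_dict from Flax is used here only for convenience):

    import numpy as np
    from flax import traverse_util
    from t5x import checkpoints

    # Load lazily and cast to float32; no optimizer needs to be constructed first.
    state = checkpoints.load_t5x_checkpoint(
        '/path/to/model_dir', lazy_parameters=True, restore_dtype=np.float32)

    # Flatten the nested dict to list parameter names and shapes.
    for name, value in traverse_util.flatten_dict(state, sep='/').items():
        print(name, getattr(value, 'shape', None))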

t5x.checkpoints.populate_metrics_for_steps(checkpoints_dir, metric_name, steps)[source]#

Iterate through summary event files and return metrics for steps.