t5x.utils package#

General utility functions for t5x.

class t5x.utils.CheckpointConfig(save=None, restore=None)[source]#

Configuration for checkpointing of model and dataset.

class t5x.utils.DatasetConfig(mixture_or_task_name, task_feature_lengths, split, batch_size, shuffle, seed, use_cached=False, pack=False, use_custom_packing_ops=False, module=None, use_memory_cache=True, trim_output_features=True)[source]#

Configuration for loading a dataset from a SeqIO Task or Mixture.
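
A minimal sketch of constructing a DatasetConfig for training; the task name, feature lengths, and other values below are hypothetical placeholders, not library defaults.

```python
from t5x import utils

# Hedged sketch: a training split of a hypothetical registered SeqIO task.
ds_cfg = utils.DatasetConfig(
    mixture_or_task_name='my_task',  # hypothetical task name
    task_feature_lengths={'inputs': 512, 'targets': 114},
    split='train',
    batch_size=128,
    shuffle=True,
    seed=42,
    pack=True,
)
```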

class t5x.utils.EvaluatorConstructor(*args, **kwargs)[source]#

A function that returns an Evaluator.

This protocol represents the actual call site for the seqio.Evaluator constructor in this file. It allows users to bind additional args with partial() and pass that partial into the function without causing type-check issues.
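
For example, a caller might pre-bind one of the seqio.Evaluator constructor arguments; the sketch below assumes your seqio version exposes a num_examples argument and is illustrative only.

```python
import functools
import seqio

# Hedged sketch: bind an extra argument and pass the partial wherever an
# EvaluatorConstructor is expected.
evaluator_constructor = functools.partial(seqio.Evaluator, num_examples=64)
```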

class t5x.utils.GetDatasetCallable(*args, **kwargs)[source]#

Interface for a function returning a dataset (iterator).

class t5x.utils.GetEvalDatasetCallable(*args, **kwargs)[source]#

Interface for a function returning a dataset (iterator).

class t5x.utils.InferFnCallable(*args, **kwargs)[source]#
class t5x.utils.InferStepWithRngCallable(*args, **kwargs)[source]#
class t5x.utils.InferStepWithoutRngCallable(*args, **kwargs)[source]#
class t5x.utils.InitFnCallable(*args, **kwargs)[source]#

A callable that initializes model variables.

class t5x.utils.LearningRateCallable(*args, **kwargs)[source]#
class t5x.utils.LegacyCheckpointManager(*, save_cfg, restore_cfg, train_state_shape, partitioner, ds_iter=None, model_dir=None)[source]#

Implementation of CheckpointManager interface for T5X.

Uses underlying LegacyCheckpointer to handle save/restore for Dataset and TrainState.

restore(paths, restore_cfg=None, fallback_state=None)[source]#

Performs restore operation using restore_checkpointer.

Determines whether the indicated path is a TensorFlow checkpoint.

Parameters:
  • paths – A sequence of paths to restore from.

  • restore_cfg – RestoreCheckpointConfig specifying restoration information.

  • fallback_state – a state dict of an optimizer to fall back to for loading params that do not exist in the checkpoint (after applying all state_transformation_fns), but do exist in Checkpointer.optimizer. The union of fallback_state and state loaded from the checkpoint must match Checkpointer.optimizer.

Returns:

The restored TrainState if only one TrainState can be restored from the given paths, otherwise a sequence of TrainStates. May return None.

save(train_state, state_transformation_fns=())[source]#

Performs save operation.

Parameters:
  • train_state – a TrainState PyTree to save.

  • state_transformation_fns – Transformations to apply, in order, to the state before writing.
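
A minimal usage sketch, assuming save_cfg, restore_cfg, train_state_shape, train_state, and partitioner come from the surrounding training setup; the directory and checkpoint paths are hypothetical.

```python
from t5x import utils

# Hedged sketch: wire up the legacy manager and perform a save and a restore.
manager = utils.LegacyCheckpointManager(
    save_cfg=save_cfg,
    restore_cfg=restore_cfg,
    train_state_shape=train_state_shape,  # e.g. TrainStateInitializer.global_train_state_shape
    partitioner=partitioner,
    model_dir='/tmp/t5x_model',  # hypothetical directory
)
manager.save(train_state)
restored = manager.restore(
    paths=['/tmp/t5x_model/checkpoint_1000'],  # hypothetical checkpoint path
    restore_cfg=restore_cfg)
```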

class t5x.utils.LegacyCheckpointer(*, save_checkpointer=None, restore_checkpointer, strict=False)[source]#

Implementation of Checkpointer interface for T5X.

Relies on underlying save_checkpointer and restore_checkpointer, which are t5x.checkpoints.Checkpointer objects.

restore(path, item=None, state_transformation_fns=(), fallback_state=None, lazy_parameters=False)[source]#

Performs restore operation using restore_checkpointer.

Determines whether the indicated path is a TensorFlow checkpoint.

Parameters:
  • path – the string path to restore from.

  • item – a TrainState PyTree to restore. Unused.

  • state_transformation_fns – Transformations to apply, in order, to the state after it is read from the checkpoint.

  • fallback_state – a state dict of an optimizer to fall back to for loading params that do not exist in the checkpoint (after applying all state_transformation_fns), but do exist in Checkpointer.optimizer. The union of fallback_state and state loaded from the checkpoint must match Checkpointer.optimizer.

  • lazy_parameters – whether to load the parameters as LazyArrays to preserve memory.

Returns:

The restored train state.

save(path, item, force=False, state_transformation_fns=(), *, concurrent_gb=128)[source]#

Performs save operation using save_checkpointer.

Parameters:
  • path – path to save item to.

  • item – a TrainState PyTree to save.

  • force – unused.

  • state_transformation_fns – Transformations to apply, in order, to the state before writing.

  • concurrent_gb – the approximate number of gigabytes of partitionable parameters to process in parallel. Useful to preserve RAM.

class t5x.utils.RestoreCheckpointConfig(path, mode='latest', assignment_map=None, strict=True, fallback_to_scratch=False, dtype=None, restore_dataset=False, checkpointer_cls=<class 't5x.checkpoints.Checkpointer'>, state_transformation_fns=(), checkpoint_manager_cls=<class 't5x.checkpoints.OrbaxCheckpointManagerInterface'>)[source]#

Configuration for restoring model from checkpoint.

checkpoint_manager_cls#

alias of OrbaxCheckpointManagerInterface

checkpointer_cls#

alias of Checkpointer

class t5x.utils.SaveCheckpointConfig(dtype='float32', period=None, checkpoint_steps=None, keep=None, keep_dataset_checkpoints=None, save_dataset=False, checkpointer_cls=<class 't5x.checkpoints.Checkpointer'>, state_transformation_fns=<factory>, checkpoint_manager_cls=<class 't5x.checkpoints.OrbaxCheckpointManagerInterface'>)[source]#

Configuration for saving model checkpoints.

checkpoint_manager_cls#

alias of OrbaxCheckpointManagerInterface

checkpointer_cls#

alias of Checkpointer
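
A minimal sketch tying these together into the CheckpointConfig described above; the save period, keep count, and restore path are hypothetical values.

```python
from t5x import utils

# Hedged sketch: save every 1000 steps, keep the 3 newest checkpoints, and
# restore from the latest checkpoint under a hypothetical directory.
ckpt_cfg = utils.CheckpointConfig(
    save=utils.SaveCheckpointConfig(
        dtype='float32',
        period=1000,
        keep=3,
    ),
    restore=utils.RestoreCheckpointConfig(
        path='gs://my-bucket/t5x/checkpoints',  # hypothetical path
        mode='latest',
    ),
)
```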

class t5x.utils.ShardedDatasetIterator(iterator, partitioner, global_shapes)[source]#

A wrapper iterator that returns sharded arrays.

property element_spec#

Returns the element spec of the iterator.

reset()[source]#

Resets the iterator back to the beginning.

restore(filename)[source]#

Restores the iterator from a file (if available).

This should only handle this iterator - not iterators in other processes.

Parameters:

filename – Name of the checkpoint.

save(filename)[source]#

Saves the state of the iterator to a file.

This should only handle this iterator - not iterators in other processes.

Parameters:

filename – Name of the checkpoint.

class t5x.utils.TrainStateInitializer(optimizer_def, init_fn, input_shapes, partitioner, input_types=None)[source]#

Helper for initializing partitioned TrainState from checkpoints or scratch.

Common use cases:

  • To restore from a single checkpoint, use from_checkpoint.

  • To iterate over multiple checkpoints without recompiling the model, use from_checkpoints.

  • To initialize from scratch, use from_scratch.

  • To restore from a checkpoint with a fallback to initializing from scratch, use from_checkpoint_or_scratch.

global_train_state_shape#

a TrainState containing the global (unpartitioned) shape (as a jax.ShapeDtypeStruct) of each parameter instead of its value.

train_state_axes#

a TrainState object containing a PartitionSpec (or None) for each parameter, in place of the parameter itself.

from_checkpoint(ckpt_cfgs, *, ds_iter=None, init_rng=None)[source]#

Restores (at most) 1 checkpoint using from_checkpoints, or dies.

from_checkpoint_or_scratch(ckpt_cfgs, *, init_rng, ds_iter=None)[source]#

Initializes from checkpoint, if found, or from scratch.

from_checkpoints(restore_cfgs, ds_iter=None, init_rng=None)[source]#

Yields 0 or more restored partitioned Optimizers, and maybe datasets.

The manner in which parameters are initialized depends on restore_cfgs. The configs are iterated over in order, and the first config that matches one or more existing checkpoints is used to generate restored optimizers from the checkpoint(s). Any remaining configs are ignored.

Parameters:
  • restore_cfgs – ordered sequence of configurations specifying checkpoint(s) to restore from. The first config to match a checkpoint will be used.

  • ds_iter – a tf.data.Iterator for the input data, or None. If provided, the referenced iterator’s state may be silently restored (depending on the config’s restore_dataset value) along with the optimizer.

  • init_rng – rng used to initialize parameters from scratch when they are not available in the checkpoint and fallback_to_scratch is True.

Yields:

The TrainState with initialized optimizer and parameters copied to devices, together with the path to the restored checkpoint.

from_scratch(init_rng)[source]#

Initializes the partitioned Optimizer from scratch.
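
A minimal sketch of the common use cases listed above, assuming model is a T5X model exposing optimizer_def and get_initial_variables, and that partitioner and restore_cfg come from the surrounding setup; the input shapes and RNG seed are illustrative.

```python
import jax
from t5x import utils

# Hedged sketch: build the initializer, then restore from a checkpoint if one
# matches restore_cfg, otherwise initialize from scratch.
train_state_initializer = utils.TrainStateInitializer(
    optimizer_def=model.optimizer_def,
    init_fn=model.get_initial_variables,
    input_shapes={'encoder_input_tokens': (128, 512),
                  'decoder_input_tokens': (128, 114)},
    partitioner=partitioner,
)
train_state = train_state_initializer.from_checkpoint_or_scratch(
    [restore_cfg], init_rng=jax.random.PRNGKey(0))
```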

t5x.utils.create_checkpoint_manager_and_restore(train_state_initializer, partitioner, restore_checkpoint_cfg, restore_path, fallback_init_rng, save_checkpoint_cfg=None, model_dir=None, ds_iter=None, use_orbax=False)[source]#

Creates a CheckpointManager and restores TrainState if available.

t5x.utils.create_learning_rate_scheduler(factors='constant * linear_warmup * rsqrt_decay', base_learning_rate=0.5, warmup_steps=1000, decay_factor=0.5, steps_per_decay=20000, steps_per_cycle=100000, step_offset=0, min_learning_rate=1e-08)[source]#

Creates learning rate schedule.

Interprets the factors in the factors string, which can consist of:

  • constant: interpreted as the constant value.

  • linear_warmup: interpreted as linear warmup until warmup_steps.

  • linear_decay: linear decay from warmup_steps with decay_factor slope. Note that this option implies ‘constant * linear_warmup’ and should not be used in conjunction with the constant or linear_warmup factors.

  • rsqrt_decay: divide by square root of max(step, warmup_steps).

  • rsqrt_normalized_decay: divide by square root of max(step/warmup_steps, 1).

  • decay_every: every k steps, decay the learning rate by decay_factor.

  • cosine_decay: cyclic cosine decay, uses steps_per_cycle parameter.

Parameters:
  • factors – string, factors separated by ‘*’ that defines the schedule.

  • base_learning_rate – float, the starting constant for the lr schedule.

  • warmup_steps – int, how many steps to warm up for in the warmup schedule.

  • decay_factor – float, the amount to decay the learning rate by.

  • steps_per_decay – int, how often to decay the learning rate.

  • steps_per_cycle – int, steps per cycle when using cosine decay.

  • step_offset – int, an offset that the step argument passed to the returned function is relative to.

  • min_learning_rate – float, minimum learning rate to output. Useful for cases when a decay function is (mis)configured to decay to non-positive values.

Returns:

float -> {‘learning_rate’: float}, the step-dependent lr.

Return type:

a function learning_rate(step)
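
A minimal sketch of the default warmup-plus-inverse-sqrt schedule; the numbers below follow the factor definitions above and are illustrative.

```python
from t5x import utils

# Hedged sketch: constant * linear_warmup * rsqrt_decay.
schedule_fn = utils.create_learning_rate_scheduler(
    factors='constant * linear_warmup * rsqrt_decay',
    base_learning_rate=0.5,
    warmup_steps=1000,
)
# Per the factor definitions above, roughly:
#   step 100  -> 0.5 * (100 / 1000) / sqrt(1000)   (still warming up)
#   step 4000 -> 0.5 * 1.0 / sqrt(4000)            (past warmup)
lr_early = schedule_fn(100)
lr_late = schedule_fn(4000)
```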

t5x.utils.create_orbax_checkpoint_manager(*, save_cfg=None, restore_cfg=None, train_state, partitioner, ds_iter=None, model_dir=None)[source]#

Creates Orbax CheckpointManager.

t5x.utils.find_first_checkpoint_step(checkpoint_steps_index, checkpoint_steps, first_step, host_step)[source]#

Finds the first valid step in the checkpoint_steps list at which to save a checkpoint.

Parameters:
  • checkpoint_steps_index – Current index in checkpoint_steps list while training.

  • checkpoint_steps – List of checkpoint steps passed in as a parameter in checkpoint_cfg.save.

  • first_step – First step in epoch while training.

  • host_step – Host step of training.

Returns:

Integer containing the first valid checkpoint step index at which to start epoch training.

t5x.utils.find_next_checkpoint_step(checkpoint_steps_index, inner_num_steps, is_checkpoint_step, host_step, checkpoint_steps, epoch_end_step, checkpoint_period, first_step)[source]#

Finds the next valid checkpoint step in the checkpoint_steps list at which to stop scalar training and save a checkpoint.

A checkpoint step is considered valid if it is less than the epoch end step, greater than the epoch first step, and does not coincide with a checkpoint_period step.

Parameters:
  • checkpoint_steps_index – Current index in checkpoint_steps list while training.

  • inner_num_steps – Number of scalar steps to iterate through in training.

  • is_checkpoint_step – Dictates whether the current subset of scalar steps contains a valid checkpoint_step to save.

  • host_step – Host step of training.

  • checkpoint_steps – List of checkpoint steps passed in as a parameter in checkpoint_cfg.save.

  • epoch_end_step – Last training step in epoch.

  • checkpoint_period – Period value passed in as parameter in checkpoint_cfg.save.

  • first_step – First step in epoch while training.

Returns:

Tuple containing the (possibly truncated) inner_num_steps value and the is_checkpoint_step flag indicating whether a valid checkpoint step was found.

t5x.utils.flatten_dict_string_keys(x)[source]#

Flattens a nested dictionary to have string keys and ‘/’ separators.
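
A minimal sketch of the expected behavior described above; the nested keys are illustrative.

```python
from t5x import utils

nested = {'optimizer': {'target': {'embedding': 1}}}
flat = utils.flatten_dict_string_keys(nested)
# Expected, per the description above: {'optimizer/target/embedding': 1}
```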

t5x.utils.get_dataset(cfg, shard_id, num_shards, feature_converter_cls, num_epochs=None, continue_from_last_checkpoint=False)[source]#

Returns a dataset from SeqIO based on a DatasetConfig.
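
A minimal sketch, assuming ds_cfg is a DatasetConfig like the one sketched earlier and that each JAX process reads its own shard; the feature converter choice is an assumption for an encoder-decoder setup.

```python
import jax
import seqio
from t5x import utils

ds = utils.get_dataset(
    cfg=ds_cfg,
    shard_id=jax.process_index(),
    num_shards=jax.process_count(),
    feature_converter_cls=seqio.EncDecFeatureConverter,
)
```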

t5x.utils.get_dataset_inner(cfg, shard_info, feature_converter_cls, seed=None, num_epochs=None)[source]#

Internal fn to load a dataset from SeqIO based on a DatasetConfig.

t5x.utils.get_fallback_state(restore_cfg, init_fn, init_rng)[source]#

Returns the fallback_state that can be used in restore().

t5x.utils.get_first_valid_restore_config_and_paths(restore_cfgs)[source]#

Returns first valid restore_cfg and the paths to restore.

Parameters:

restore_cfgs – a sequence of RestoreCheckpointConfig objects, which should be filtered to determine the first valid object.

Returns:

Tuple of valid RestoreCheckpointConfig and a sequence of paths. If the first config encountered has mode ‘specific’, it is immediately returned, along with its specified paths. If the mode is ‘all’ or ‘latest’, checks to ensure that there are valid checkpoints at each of the provided paths and filters the returned paths accordingly.

t5x.utils.get_infer_fn(infer_step, batch_size, train_state_axes, partitioner, keep_aux_as_numpy=False)[source]#

Get prediction function for the SeqIO evaluator.

The returned prediction function should take in an enumerated dataset, make predictions and return in an enumerated form with the original indices and examples zipped together. This ensures that the predictions are compared to the targets in a correct order even if the dataset is sharded across multiple hosts and gathered in a nondeterministic way.

The host with jax.process_index() == 0 is used as the “main host”, i.e., it gathers all inference results and returns them.

Shape notation:

  • Per replica set num replicas: R

  • Per replica set batch size: B

  • Number of replica sets: H

  • Length: L

Some transformations have shape transformation annotation, e.g., [B, L] -> [R, B/R, L].

Parameters:
  • infer_step – a callable that executes one prediction step. Should not yet be partitioned or pmapped.

  • batch_size – the number of examples in the global infer batch.

  • train_state_axes – Partitioning info for the train state object.

  • partitioner – partitioner to use.

  • keep_aux_as_numpy – bool, whether to leave aux values as numpy arrays; can be used to save space when saving bfloat16 values.

Returns:

a callable which takes in the enumerated infer dataset and an optimizer and runs the prediction.

Return type:

predict_fn
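
A minimal sketch, assuming model, partitioner, train_state, and a TrainStateInitializer from the surrounding setup; using model.predict_batch as the inference step is an assumption, not a requirement of this API.

```python
from t5x import utils

# Hedged sketch: build the prediction function for the SeqIO evaluator.
predict_fn = utils.get_infer_fn(
    infer_step=model.predict_batch,  # assumed step fn; not yet partitioned
    batch_size=32,
    train_state_axes=train_state_initializer.train_state_axes,
    partitioner=partitioner,
)
# The SeqIO evaluator supplies an enumerated dataset of (index, example) pairs.
predictions = predict_fn(enumerated_infer_ds, train_state)
```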

t5x.utils.get_training_eval_datasets(cfg, shard_id, num_shards, eval_steps, feature_converter_cls, deterministic=False, model_dir=None, start_step=0)[source]#

Returns a mapping from eval task name to its dataset.

t5x.utils.get_vocabulary(cfg)[source]#

Returns seqio.Vocabulary objects associated with the Mixture/Task.

Parameters:

cfg – the DatasetConfig specifying which mixture or task to get the vocabularies for.

Returns:

A tuple of seqio.Vocabulary for inputs and targets.

Raises:

ValueError – if inputs and targets are not both present and vocabularies are different.
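
A minimal sketch, assuming ds_cfg is a DatasetConfig for a registered task:

```python
from t5x import utils

input_vocab, target_vocab = utils.get_vocabulary(ds_cfg)
```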

t5x.utils.get_zeros_batch_like_dataset(dataset, batch_size=None)[source]#

Gets a batch of zeros matching the dataset element spec.

t5x.utils.import_module(module)[source]#

Imports the given module at runtime.

t5x.utils.log_model_info(log_file, full_train_state, partitioner)[source]#

Logs the variable shape information and optionally writes it to a file.

t5x.utils.multihost_assert_equal(input_tree, fail_message='')[source]#

Verifies that all the hosts have the same tree of values.

t5x.utils.override_params_axes_names(model_variables, params_axes_names_override=())[source]#

Applies parameter axis names overrides to axes variables.

Parameters:
  • model_variables – the original model variables containing the ‘params_axes’ collection.

  • params_axes_names_override – a priority-ordered mapping from regex patterns (fully matching parameter names) to tuples containing string logical axis names to replace model-derived names.

Returns:

an updated set of model variables with the overrides applied to the ‘params_axes’ collection.
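
A minimal sketch; the parameter-name pattern and logical axis names below are illustrative, not defaults.

```python
from t5x import utils

# Hedged sketch: relabel the logical axes of a hypothetical relative position
# bias parameter in the 'params_axes' collection.
updated_variables = utils.override_params_axes_names(
    model_variables,
    params_axes_names_override=(
        ('encoder/relpos_bias/rel_embedding', ('heads', 'relpos_buckets')),
    ),
)
```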

t5x.utils.prepare_train_iter(train_iter, *, partitioner, checkpoint_cfg, data_layout)[source]#

Prepares the training input iterator.

t5x.utils.restore(checkpoint_manager, paths, restore_cfg, fallback_state=None)[source]#

Performs restore operation using restore_checkpointer.

Determines whether the indicated path is a TensorFlow checkpoint.

Parameters:
  • checkpoint_manager – OrbaxCheckpointManagerInterface

  • paths – A sequence of paths to restore from.

  • restore_cfg – RestoreCheckpointConfig specifying restoration information.

  • fallback_state – a state dict of an optimizer to fall back to for loading params that do not exist in the checkpoint (after applying all state_transformation_fns), but do exist in Checkpointer.optimizer. The union of fallback_state and state loaded from the checkpoint must match Checkpointer.optimizer.

Returns:

The restored TrainState if only one TrainState can be restored from the given paths, otherwise a sequence of TrainStates.

t5x.utils.round_vocab_size_to_multiple(vocabulary, divisor=128)[source]#

Round up vocabulary size for improved TPU performance.

t5x.utils.sync_global_devices(name)[source]#

Creates a barrier with given name across all hosts/devices.

t5x.utils.verify_matching_vocabs(cfg, model)[source]#

Verify whether the task vocab matches the model vocab.

The seqio Task and the Model both define their vocabularies separately, but these vocabularies must match or else the training/inference results will not be sensible. This function validates that they do match, under the assumption that this is a standard Encoder-only, Decoder-only, or Encoder-decoder model.

Parameters:
  • cfg – The DatasetConfig of the training/inference task.

  • model – A BaseTransformerModel model with input_vocabulary and output_vocabulary attributes.

Raises:

ValueError – If the task vocabulary does not match the model vocabulary.
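
A minimal sketch, assuming ds_cfg and model come from the surrounding setup:

```python
from t5x import utils

# Raises ValueError if the task and model vocabularies differ.
utils.verify_matching_vocabs(ds_cfg, model)
```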