t5x.trainer package#

Trainer and MetricsManager classes for use in train loop.

To create a custom trainer, subclass BaseTrainer and implement the _partitioned_train_step and _partitioned_eval_step methods, possibly reusing the utility functions provided in this module.

class t5x.trainer.ActionMode(value)[source]#

Defines when to run an action.

For example, TRAIN means to run an action after a TRAIN loop is done.

class t5x.trainer.ArrayMapFuture(*args, **kwargs)[source]#
class t5x.trainer.BaseAction[source]#

Base action class, intended to be subclassed. The action itself does nothing.

abstract run(train_state, metrics_by_task)[source]#

Runs an action for the given train_state and metrics.

Parameters:
  • train_state – The current train_state in the training loop.

  • metrics_by_task – A map of metrics grouped by task.

Returns:

A bool indicating whether training should be halted.
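A minimal sketch of a custom action, assuming only the interface documented above. BaseAction below is a local stand-in rather than an import from t5x, StopAtLossThreshold and its arguments are hypothetical, and metric values are treated as plain floats for illustration.

```python
import abc
from typing import Any, Mapping


class BaseAction(abc.ABC):
    """Local stand-in for the documented interface (not imported from t5x)."""

    @abc.abstractmethod
    def run(self, train_state: Any,
            metrics_by_task: Mapping[str, Mapping[str, float]]) -> bool:
        """Returns True if training should be halted."""


class StopAtLossThreshold(BaseAction):
    """Hypothetical action: halt once a task's loss drops below a threshold."""

    def __init__(self, task: str, threshold: float):
        self._task = task
        self._threshold = threshold

    def run(self, train_state, metrics_by_task):
        task_metrics = metrics_by_task.get(self._task, {})
        if "loss" not in task_metrics:
            return False  # Nothing to check, so keep training.
        return task_metrics["loss"] < self._threshold


action = StopAtLossThreshold(task="my_task", threshold=0.1)
keep_going = action.run(train_state=None,
                        metrics_by_task={"my_task": {"loss": 0.5}})
halt = action.run(train_state=None,
                  metrics_by_task={"my_task": {"loss": 0.05}})
```

The return value follows the documented contract: True asks the training loop to halt, False lets it continue.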

class t5x.trainer.BaseTrainer(model, train_state, partitioner, eval_names, summary_dir, train_state_axes, rng)[source]#

Abstract base trainer class.

Internally this uses MetricsManagers that start threads. You should use the trainer as a context manager, or call close() directly in order to wait for these threads to finish after training is done.
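The context-manager pattern described above can be sketched without t5x. ToyTrainer is a hypothetical stand-in that hands work to a background thread pool and waits for it in close(); the real trainer's internals may differ.

```python
from concurrent.futures import ThreadPoolExecutor


class ToyTrainer:
    """Hypothetical stand-in: background writes that close() waits for."""

    def __init__(self):
        # A single worker keeps the writes in submission order.
        self._pool = ThreadPoolExecutor(max_workers=1)
        self.written = []

    def log(self, step, value):
        # Schedule the write in the background and return immediately.
        self._pool.submit(self.written.append, (step, value))

    def close(self):
        # Block until every pending write has finished.
        self._pool.shutdown(wait=True)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()


with ToyTrainer() as trainer:
    for step in range(3):
        trainer.log(step, step * 0.5)

# Once the with-block has exited, all background writes are complete.
results = trainer.written
```

Using the with-statement guarantees that close() runs even if the loop body raises, which is the main reason to prefer it over calling close() by hand.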

close()[source]#

Stops the threads of all train metric managers.

compile_eval(batches)[source]#

Pre-compiles eval step (if not yet compiled).

Not required.

Pre-compiles the evaluation step for each evaluation dataset, reusing cached compilations where possible. In other words, if multiple evaluation datasets have equivalent shapes/dtypes for the batch and initial metrics, recompilation will be avoided.

If not called before eval, compilation will occur automatically on the first step and JAX’s “jit cache” will be used to avoid recompilation for future steps.

Parameters:

batches – a mapping from evaluation dataset name to a sample batch. The batch may contain dummy values, but the shapes and dtypes must be correct.
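The reuse of compilations across datasets with equivalent shapes and dtypes can be illustrated with a small sketch. Everything here is hypothetical: ArraySpec, batch_signature and compile_for stand in for real arrays and for actual compilation, and the real method also takes the initial metrics into account.

```python
from typing import Dict, NamedTuple, Tuple


class ArraySpec(NamedTuple):
    """Hypothetical description of an array: only shape and dtype matter."""
    shape: Tuple[int, ...]
    dtype: str


def batch_signature(batch: Dict[str, ArraySpec]):
    # Batches with the same feature names, shapes and dtypes share a key.
    return tuple(sorted((name, spec.shape, spec.dtype)
                        for name, spec in batch.items()))


compiled = {}  # signature -> placeholder for a compiled eval step


def compile_for(batches: Dict[str, Dict[str, ArraySpec]]):
    """Maps each dataset name to its (possibly shared) compiled step."""
    steps = {}
    for name, batch in batches.items():
        sig = batch_signature(batch)
        if sig not in compiled:
            # Only a previously unseen signature triggers a "compilation".
            compiled[sig] = f"compiled_step_{len(compiled)}"
        steps[name] = compiled[sig]
    return steps


short = ArraySpec(shape=(8, 128), dtype="int32")
long_ = ArraySpec(shape=(8, 512), dtype="int32")
steps = compile_for({
    "validation": {"decoder_input_tokens": short, "decoder_target_tokens": short},
    "test": {"decoder_input_tokens": short, "decoder_target_tokens": short},
    "long": {"decoder_input_tokens": long_, "decoder_target_tokens": long_},
})
```

Here three datasets lead to only two cache entries, because "validation" and "test" have identical signatures.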

compile_train(batch)[source]#

Pre-compiles train step (if not yet compiled).

Not required.

If not called before train, compilation will occur automatically on the first step and JAX’s “jit cache” will be used to avoid recompilation for future steps.

Parameters:

batch – A sample batch that may contain dummy values, but with correct shapes and dtypes.

eval(batch_iters)[source]#

Runs evaluation loop over the iterator and writes summary.

train(batch_iter, num_steps, start_step=None)[source]#

Runs the train loop for the given number of steps.

class t5x.trainer.BaseTrainerConstructor(*args, **kwargs)[source]#

A function that returns a BaseTrainer.

class t5x.trainer.EarlyStoppingAction(metric, mode, patience=3, atol=0.0, rtol=0.0)[source]#

Terminates training when the specified metric is not improving.

Checks whether the monitored metric is still decreasing after every train loop, every eval loop, or both. If the metric has not decreased for patience consecutive checks, the training process is terminated.

run(train_state, metrics_by_task)[source]#

Runs an action for the given train_state and metrics.

Parameters:
  • train_state – The current train_state in the training loop.

  • metrics_by_task – A map of metrics grouped by task.

Returns:

A bool indicating whether training should be halted.
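A minimal sketch of patience-based stopping for a metric that should decrease. The tolerance rule used here (an improvement must beat the best value by more than atol + rtol * abs(best)) is an assumption made for illustration and may not match the library's exact comparison; PatienceTracker is a hypothetical name.

```python
class PatienceTracker:
    """Hypothetical helper: stop after `patience` checks without improvement."""

    def __init__(self, patience=3, atol=0.0, rtol=0.0):
        self.patience = patience
        self.atol = atol
        self.rtol = rtol
        self.best = None
        self.bad_checks = 0

    def should_stop(self, value):
        if self.best is None:
            self.best = value  # The first observation sets the baseline.
            return False
        # Assumed rule: an improvement must be lower by more than delta.
        delta = self.atol + self.rtol * abs(self.best)
        if value < self.best - delta:
            self.best = value
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


tracker = PatienceTracker(patience=3)
losses = [1.0, 0.8, 0.7, 0.7, 0.75, 0.71]
decisions = [tracker.should_stop(loss) for loss in losses]
```

With patience=3, the tracker signals a stop on the third check in a row that fails to improve on the best loss of 0.7.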

class t5x.trainer.LearningRateCallable(*args, **kwargs)[source]#
class t5x.trainer.MetricValueMapFuture(*args, **kwargs)[source]#
class t5x.trainer.MetricsManager(name, summary_dir=None)[source]#

Manages a set of distributed metrics and their logging.

Logging is disabled on all but host 0.

Logs to:
  • TensorBoard

  • ABSL

You should call close() to wait for threads started by this class to finish.

start_duration_timer(block_on=())[source]#

Starts the duration timer.

property summary_writer#

Returns the MetricWriter used by this class.

write_metrics_summary(metrics, step, num_steps)[source]#

Writes summary based on accumulated metrics in a background thread.

Duration is automatically computed as the interval between completions of metrics fetching. This closely approximates the duration of num_steps, because the steps must be computed sequentially. It is also more accurate than measuring the time since the call to the step function, because the step's actual execution occurs asynchronously on the TPU/GPU device.

Parameters:
  • metrics – accumulated metric values.

  • step – the current train step.

  • num_steps – the number of steps the metrics are accumulated across.

Returns:

A mapping of name -> scalar value of the written summary. The real scalar values are returned only on host 0; other hosts return None.
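The duration measurement described above can be sketched as a timer that records the interval between successive "metrics fetched" events. IntervalTimer, its method names and the returned keys are hypothetical; a fake clock is injected so that the example is deterministic, whereas real code would more likely pass time.monotonic.

```python
class IntervalTimer:
    """Hypothetical helper: time between successive metric fetches."""

    def __init__(self, clock):
        self._clock = clock
        self._last = None

    def start(self):
        self._last = self._clock()

    def mark_fetched(self, num_steps):
        # Called once the metrics for `num_steps` steps are available,
        # i.e. after the device has actually finished those steps.
        now = self._clock()
        duration = now - self._last
        self._last = now
        return {"duration_s": duration,
                "steps_per_second": num_steps / duration}


ticks = iter([10.0, 14.0, 19.0])  # Fake timestamps, in seconds.
timer = IntervalTimer(clock=lambda: next(ticks))
timer.start()
first = timer.mark_fetched(num_steps=100)
second = timer.mark_fetched(num_steps=100)
```

Timing from fetch to fetch, rather than from the asynchronous step call, means each interval covers work that has really completed.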

write_scalar(key, val, step)[source]#

Writes scalar value to metric writers in a threadsafe manner.

write_scalars(step, scalars)[source]#

Writes scalar values to metric writers in a threadsafe manner.

class t5x.trainer.PartitionedEvalCallable(*args, **kwargs)[source]#

Protocol for a partitioned eval step.

class t5x.trainer.PartitionedTrainCallable(*args, **kwargs)[source]#

Protocol for a partitioned train step.

exception t5x.trainer.PreemptionError[source]#

Training has been interrupted and needs an emergency checkpoint.

class t5x.trainer.SummarizeMetricsCallable(*args, **kwargs)[source]#

PyType template for a metrics summary function.

class t5x.trainer.TerminateOnNanAction(task, metric='loss')[source]#

Terminates training when NaN loss is detected.

Checks whether the loss metric for the given task is NaN or Inf and terminates training if so.

run(train_state, metrics_by_task)[source]#

Runs an action for the given train_state and metrics.

Parameters:
  • train_state – The current train_state in the training loop.

  • metrics_by_task – A map of metrics grouped by task.

Returns:

A bool indicating whether training should be halted.
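The NaN/Inf check can be sketched with the standard library. is_bad_loss is a hypothetical helper, metric values are treated as plain floats, and returning False for a missing task or metric is an assumption rather than documented behaviour.

```python
import math


def is_bad_loss(metrics_by_task, task, metric="loss"):
    """Returns True if the metric for `task` is NaN or infinite."""
    value = metrics_by_task.get(task, {}).get(metric)
    if value is None:
        return False  # Assumption: a missing metric does not halt training.
    return math.isnan(value) or math.isinf(value)


nan_case = is_bad_loss({"my_task": {"loss": float("nan")}}, "my_task")
inf_case = is_bad_loss({"my_task": {"loss": float("inf")}}, "my_task")
ok_case = is_bad_loss({"my_task": {"loss": 0.3}}, "my_task")
```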

class t5x.trainer.TimeFuture(*args, **kwargs)[source]#
class t5x.trainer.Trainer(model, train_state, partitioner, eval_names, summary_dir, train_state_axes, rng, learning_rate_fn, num_microbatches, weight_metrics_computer=None)[source]#

Training loop with optional microbatches.

class t5x.trainer.WeightMetricsComputer[source]#

Decides which weight metrics to compute during training.

compute_metrics(gradients, old_train_state, new_train_state)[source]#

Computes some metrics about the weights after they have been updated.

Parameters:
  • gradients – The gradients of all weights.

  • old_train_state – The training state before applying the gradients.

  • new_train_state – The training state after applying the gradients.

Returns:

A dictionary of Metrics, where the keys are either metric names, or of the form metric_name/parameter_name, depending on whether or not they are global to the model, or specific to each model parameter.
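The key convention for the returned dictionary can be illustrated with a small sketch. The metric names gradient_norm and gradient_rms, the function gradient_metrics, and the use of flat lists of floats in place of arrays are all assumptions for illustration.

```python
import math


def gradient_metrics(gradients):
    """gradients: flat mapping from parameter name to a list of floats."""
    metrics = {}
    total_sq = 0.0
    for name, values in gradients.items():
        sq = sum(v * v for v in values)
        total_sq += sq
        # Per-parameter metric: key has the form metric_name/parameter_name.
        metrics[f"gradient_rms/{name}"] = math.sqrt(sq / len(values))
    # Global metric: the key is just the metric name.
    metrics["gradient_norm"] = math.sqrt(total_sq)
    return metrics


grads = {"encoder_kernel": [3.0, 4.0], "decoder_bias": [0.0, 0.0]}
weight_metrics = gradient_metrics(grads)
```

The result contains one global key (gradient_norm) plus one gradient_rms/... key per parameter.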

t5x.trainer.accumulate_grads_microbatched(model, train_state, batch, dropout_rng, num_microbatches, data_partition_spec=PartitionSpec('data'))[source]#

Implements optional microbatched gradient accumulation.

Parameters:
  • model – the instantiation of BaseModel to train.

  • train_state – A train state with model parameters and optimizer state.

  • batch – input batch consisting of either:

      - simply-padded batched features: ‘encoder_input_tokens’, ‘decoder_input_tokens’, ‘decoder_target_tokens’, ‘decoder_loss_weights’

      - packed, batched features with additional “(encoder|decoder)_segment_id”, “(encoder|decoder)_position”

  • dropout_rng – jax PRNGKey for dropout.

  • num_microbatches – the number of microbatches to use, or None for direct training.

  • data_partition_spec – the PartitionSpec to use for partitioning annotations on the batch.

Returns:

Accumulated gradients and incremental metrics.
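The idea behind microbatched gradient accumulation can be sketched in plain Python with a one-parameter model. This is not the library's implementation: the loss is assumed to be a sum over examples, so that the gradients of the microbatches add up to the full-batch gradient, and both function names are hypothetical.

```python
def grad_sum_squared_error(w, xs, ys):
    """Gradient w.r.t. w of sum((w * x - y) ** 2) over the examples."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys))


def accumulate_grads(w, xs, ys, num_microbatches=None):
    if not num_microbatches or num_microbatches == 1:
        return grad_sum_squared_error(w, xs, ys)  # Direct, single pass.
    if len(xs) % num_microbatches != 0:
        raise ValueError("Batch size must be divisible by num_microbatches.")
    size = len(xs) // num_microbatches
    total = 0.0
    for i in range(num_microbatches):
        lo, hi = i * size, (i + 1) * size
        # Only one microbatch is processed at a time, which bounds memory.
        total += grad_sum_squared_error(w, xs[lo:hi], ys[lo:hi])
    return total


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
full = accumulate_grads(0.5, xs, ys)
micro = accumulate_grads(0.5, xs, ys, num_microbatches=2)
```

Both calls yield the same gradient; the microbatched version trades extra sequential passes for a smaller peak working set.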

t5x.trainer.apply_grads(train_state, grad_accum, metrics, learning_rate, weight_metrics_computer, other_state_variables=None)[source]#

Applies gradients to the optimizer.

Parameters:
  • train_state – A train state that contains model and optimizer params.

  • grad_accum – results of accumulate_grads call.

  • metrics – incremental metrics from accumulate_grads call.

  • learning_rate – the learning rate to use for this step.

  • weight_metrics_computer – A WeightMetricsComputer instance, or None, to decide what metrics, if any, to log about weights and weight updates during training.

  • other_state_variables – other variables to update the state with.

Returns:

The updated train state and metrics.

t5x.trainer.eval_step(model, train_state, batch)[source]#

Default evaluation step.

t5x.trainer.train_with_lr(train_state, batch, learning_rate, dropout_rng, model, num_microbatches, weight_metrics_computer=None, data_partition_spec=PartitionSpec('data'))[source]#

Main training function with LR schedule.
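A toy sketch of a training step driven by a learning-rate schedule, that is, a callable mapping the step number to a learning rate. The schedule shape, the function names and the use of plain gradient descent are all assumptions for illustration; the real function operates on a train state, a model and batches of arrays.

```python
import math


def rsqrt_schedule(step, base_lr=1.0, warmup_steps=100):
    """Hypothetical schedule: flat during warmup, then base_lr / sqrt(step)."""
    return base_lr / math.sqrt(max(step, warmup_steps))


def toy_train_with_lr(w, xs, ys, learning_rate):
    """One step: compute the gradient, then apply it with the given rate."""
    # Gradient of the mean squared error of the model y = w * x.
    grad = sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)
    # Assumption: plain gradient descent stands in for the real optimizer.
    return w - learning_rate * grad


w = 0.0
xs, ys = [1.0, 2.0], [2.0, 4.0]
for step in range(200):
    w = toy_train_with_lr(w, xs, ys, learning_rate=rsqrt_schedule(step))
```

The learning rate is evaluated once per step and passed into the step function, so the schedule can be swapped without touching the update logic.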