t5x.train binary

Script to pretrain or finetune in JAX using a SeqIO pipeline.

t5x.train.run_actions(mode, actions, train_state, metrics_by_task)

Invokes all actions for the given mode on host 0, then broadcasts the result to all hosts.

Parameters:
  • mode – The mode in which to run the actions; e.g., if mode is train, only actions configured to run in train mode will be invoked.

  • actions – A mapping of actions that run after train, eval or infer_eval, to inspect the model and perform useful operations, e.g., early stopping.

  • train_state – The current train_state of the trainer.

  • metrics_by_task – A map of metrics keyed by task name.

Returns:

A bool indicating whether training should be halted.

Raises:

RuntimeError – When the metrics processed on host 0 are None.
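The sketch below shows how an actions mapping keyed by the trainer's ActionMode enum might be wired into run_actions inside a training loop. The EarlyStoppingAction and TerminateOnNanAction names and their constructor arguments are assumptions based on t5x.trainer and should be checked against that module; train_state and metrics_by_task stand in for values produced by the trainer.

```python
# Minimal sketch (hedged): build an `actions` mapping keyed by the trainer's
# ActionMode enum and ask run_actions whether training should halt.
# EarlyStoppingAction / TerminateOnNanAction constructors below are
# assumptions based on t5x.trainer; verify against that module.
from t5x import train as train_lib
from t5x import trainer as trainer_lib

actions = {
    trainer_lib.ActionMode.TRAIN: [
        # Halt if the training loss becomes NaN (assumed constructor).
        trainer_lib.TerminateOnNanAction(task='my_task', metric='loss'),
    ],
    trainer_lib.ActionMode.TRAIN_EVAL: [
        # Halt if eval loss has not improved for 5 evaluations (assumed
        # constructor; `metric` is a (task_name, metric_name) pair).
        trainer_lib.EarlyStoppingAction(
            metric=('my_task', 'loss'), mode='min', patience=5),
    ],
}

# Inside the training loop, once metrics for the current step are available:
stop_training = train_lib.run_actions(
    mode=trainer_lib.ActionMode.TRAIN,
    actions=actions,
    train_state=train_state,          # current TrainState from the trainer
    metrics_by_task=metrics_by_task,  # e.g. {'my_task': {'loss': ...}}
)
if stop_training:
    ...  # break out of the training loop
```

When the same mapping is passed to train() via its actions argument, the EVAL modes only take effect with concurrent_metrics disabled (see the actions parameter below).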

t5x.train.train(*, model, train_dataset_cfg, train_eval_dataset_cfg, infer_eval_dataset_cfg, checkpoint_cfg, partitioner, trainer_cls, model_dir, total_steps, eval_steps, eval_period, stats_period=None, random_seed, use_hardware_rng=False, summarize_config_fn, inference_evaluator_cls=<class 'seqio.evaluation.Evaluator'>, get_dataset_fn=<function get_dataset>, concurrent_metrics=True, actions=None, train_eval_get_dataset_fn=<function get_training_eval_datasets>, run_eval_before_training=False, train_state_initializer_cls=<class 't5x.utils.TrainStateInitializer'>, use_orbax=False, verify_matching_vocabs_fn=<function verify_matching_vocabs>, gc_period=0)

Train function.

Parameters:
  • model – The model object to use for training.

  • train_dataset_cfg – Specification for the dataset to train with.

  • train_eval_dataset_cfg – Specification for the dataset to evaluate with using the train metrics and no inference (e.g., uses teacher forcing). If None, train eval is disabled.

  • infer_eval_dataset_cfg – Specification for the dataset to evaluate with using the inference metrics (e.g., uses sampled decoding). If None, inference eval is disabled.

  • checkpoint_cfg – Specification for saving and restoring model parameters and dataset state to/from checkpoints.

  • partitioner – Partitioner for model parameters and data across devices.

  • trainer_cls – An implementation of BaseTrainer.

  • model_dir – Path of directory to store checkpoints and metric summaries.

  • total_steps – The step number to stop training after. The number of actual steps trained in this run will be this number minus the starting step from the checkpoint. If this is set to the starting step from the checkpoint, the model will not be compiled for training and training will not be run. This can be used in conjunction with run_eval_before_training to only evaluate a model.

  • eval_steps – The number of batches to process for each train-eval loop.

  • eval_period – The number of train steps between each evaluation (both train-eval and infer-eval).

  • stats_period – The number of train steps between writing scalar stats. If None, defaults to eval_period.

  • random_seed – A random seed to use for dropout and initialization. If None, a fast, non-deterministic hardware-based RNG is used.

  • use_hardware_rng – Whether to force use of the RngBitGenerator-based hardware RNG, which takes seeds and acts similarly to a software PRNG in that it should be seed-deterministic. The RngBitGenerator custom PRNG system should be reproducible for a given sharding, but the numbers will change for different shardings of the same model.

  • summarize_config_fn – A function that takes in the model directory, a SummaryWriter, and the step number, and writes a summary of the configuration.

  • inference_evaluator_cls – seqio.Evaluator class to use for inference evaluation, potentially with bound configuration args.

  • get_dataset_fn – The callable used to get the train and train-eval datasets based on the DatasetConfig and shard information.

  • concurrent_metrics – If True, allow metrics computation and logging to overlap with training. Will likely result in additional TPU memory usage.

  • actions – A mapping of actions that run after train, eval or infer_eval, to inspect the model and perform useful operations, e.g., early stopping. The keys must map 1:1 to the ActionMode enum. For EVAL actions to actually work, concurrent_metrics must be turned off, since chaining futures and mutating states concurrently might be error-prone.

  • train_eval_get_dataset_fn – Optional callable used to get the train-eval datasets based on the DatasetConfig and shard information. If missing, it defaults to utils.get_training_eval_datasets.

  • run_eval_before_training – If True, calculate training eval and inference eval metrics before training begins.

  • train_state_initializer_cls – t5x.utils.TrainStateInitializer class for initializing partitioned TrainState from checkpoints or scratch.

  • use_orbax – If True, uses Orbax for checkpointing. Experimental feature.

  • verify_matching_vocabs_fn – Function to validate whether the task vocabulary matches the model vocabulary, if the model is a BaseTransformerModel instance. Should raise an exception on error.

  • gc_period – The number of train steps between runs of the garbage collector. If 0, the garbage collector will run at the normal frequency.

Returns:

The tuple of (last_step, last_train_state).
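
For orientation, here is a hedged sketch of calling train() directly with the required arguments documented above. In practice the t5x.train binary assembles these objects from gin configuration; the DatasetConfig, CheckpointConfig, PjitPartitioner, Trainer, and create_learning_rate_scheduler names and field values below are assumptions based on t5x.utils, t5x.partitioning, and t5x.trainer, and the model object is left as a placeholder.

```python
# Hedged sketch: invoking t5x.train.train programmatically with the required
# arguments from the parameter list above. Config/field names are assumptions
# based on t5x.utils / t5x.partitioning / t5x.trainer; `model` is a
# placeholder for a t5x model instance (normally built via gin).
import functools

from t5x import partitioning
from t5x import trainer as trainer_lib
from t5x import train as train_lib
from t5x import utils

train_dataset_cfg = utils.DatasetConfig(
    mixture_or_task_name='my_seqio_task',  # assumed registered SeqIO task
    task_feature_lengths={'inputs': 512, 'targets': 128},
    split='train',
    batch_size=32,
    shuffle=True,
    seed=None,
    pack=True,
    use_cached=False,
)

checkpoint_cfg = utils.CheckpointConfig(
    save=utils.SaveCheckpointConfig(period=1000),
    restore=utils.RestoreCheckpointConfig(path=[], mode='latest'),
)

# Trainer with its own constructor arguments pre-bound (normally done in gin).
trainer_cls = functools.partial(
    trainer_lib.Trainer,
    learning_rate_fn=utils.create_learning_rate_scheduler(),  # assumed helper
    num_microbatches=None,
)

last_step, last_train_state = train_lib.train(
    model=model,                     # placeholder: a t5x model object
    train_dataset_cfg=train_dataset_cfg,
    train_eval_dataset_cfg=None,     # disable train-eval
    infer_eval_dataset_cfg=None,     # disable inference eval
    checkpoint_cfg=checkpoint_cfg,
    partitioner=partitioning.PjitPartitioner(num_partitions=1),
    trainer_cls=trainer_cls,
    model_dir='/tmp/t5x_model_dir',
    total_steps=10_000,
    eval_steps=20,
    eval_period=1000,
    random_seed=0,
    summarize_config_fn=lambda model_dir, summary_writer, step: None,  # no-op
)
```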