Pretraining a model#


This page outlines the steps to pretrain a model with T5X on common tasks defined with SeqIO.


Pretraining a model with T5X consists of the following steps:

  1. Choose the model architecture.

  2. Choose the SeqIO Task/Mixture to for training.

  3. Write a Gin file that configures the model, SeqIO Task/Mixture and other details of your pretraining run.

  4. Launch your experiment locally or on XManager.

  5. Monitor your experiment.

These steps are explained in detail in the following sections. An example run that trains a T5 1.1 Small checkpoint from scratch on the C4 dataset using the span corruption pretraining objective is also showcased.

Step 1: Choose a model architecture#

To train a model, you need a Gin config file that defines the model params. For your convenience, Gin configs for common models have been made available for use in T5X. A list of all the available pre-trained models (with model checkpoints and Gin config files) are available in the Models documentation.

For the example run, you will use the T5 1.1 Small model. The Gin file for this model is located at /t5x/examples/t5/t5_1_1/1_1_small.gin.

Step 2: Choose a SeqIO Task/Mixture#

A SeqIO Task encapsulates the data source, the preprocessing logic to be performed on the data before querying the model, the postprocessing logic to be performed on model outputs, and the metrics to be computed given the postprocessed outputs and targets. A SeqIO Mixture denotes a collection of Tasks and enables pretraining a model on multiple Tasks simultaneously.

Many common datasets and benchmarks, e.g. GLUE, SuperGLUE, WMT, SQUAD, CNN/Daily Mail, etc. have been implemented as SeqIO Tasks/Mixtures and can be used directly. These Tasks/Mixtures are defined in third_party/py/t5/data/ and third_party/py/t5/data/

For the example run, you will train the model on c4_v220_span_corruption Task that implements the span corruption pretraining objective using the C4 dataset. This is the final pretraining Task used in the T5 paper.

TIP: Want to use a custom Task or Mixture? See section below called “Adding SeqIO Task/Mixture modules and Gin files”

Step 3: Write a Gin Config#

After choosing the model architecture and SeqIO Task/Mixture for your run, the next step is to configure your run using Gin. If you’re not familiar with Gin, reading the T5X Gin Primer is recommended.

T5X provides a Gin file that configures the T5X trainer for pretraining (located at runs/pretrain.gin), and expects a few params from you. These params can be specified in a separate Gin file, or via commandline flags. Following are the required params:

  • TRAIN_STEPS: Number of training steps. For the example run, set this to 100_000.

  • MIXTURE_OR_TASK_NAME: This is the SeqIO Task or Mixture name to run (from Step 2). For the example run, set this to 'c4_v220_span_corruption'.

  • TASK_FEATURE_LENGTHS: This is a dict mapping feature key to maximum int length for that feature. After preprocessing, features are truncated to the provided value. For the example run, set this to {"inputs": 512, "targets": 114}, following the original T5 pretraining setup.

  • MODEL_DIR: A path to write pretrained checkpoints to. When launching using XManager, this path is automatically set and can be accessed from the XManager Artifacts page. When running locally using Blaze, you can explicitly pass a directory using a flag. Launch commands are provided in the next step.

In addition to the above params, you will need to import pretrain.gin and the Gin file for the pretrained model, which for the example run is t5_1_1/small.gin.

include 't5x/configs/runs/pretrain.gin'
include 't5x/examples/t5/t5_1_1/small.gin'

Note that the include statements can use relative paths in this example for which You will pass an appropriate gin_search_paths flag to locate these files when launching your run. However, we recommend that you use absolute paths because it can be more difficult to locate the gin files speicified via relative paths without inspecting the launch command.

You will also need to import the Python module(s) that register SeqIO Tasks and Mixtures used in your run. For the example run, we add import since it is where ‘glue_v002_proportional’ is registered. Note that this module must also be included as a dependency in the T5X trainer binary. Most common Task/Mixture modules, such as this one, are already included. If your module is not included, see the Advanced Topics section at the end of this tutorial for instructions to add it.

Finally, your Gin file should look like this:

include 't5x/examples/t5/t5_1_1/small.gin'
include 't5x/configs/runs/pretrain.gin'

# Register necessary SeqIO Tasks/Mixtures.

MIXTURE_OR_TASK_NAME = "c4_v220_span_corruption"
TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 114}

See t5x/examples/t5/t5_1_1/examples/small_c4_pretrain.gin for this example.

Step 4: Launch your experiment#

To launch your experiment locally (for debugging only; larger checkpoints may cause issues), run the following on commandline:

python -m t5x.train_unfragmented \
  --gin_file=t5x/examples/t5/t5_1_1/c4_pretrain_small.gin \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \

Note that multiple comma-separated paths can be passed to the gin_search_paths flag, and these paths should contain all Gin files used or included in your experiment.

Next Steps#

Now that you have successfully pretrained a model, here are some topics you might want to explore next:

We also touch upon a few advanced topics related to pretraining below that might be useful, especially when customizing your pretraining job.

Advanced Topics#

train, train_eval {#train-eval .no-toc}#

A DatasetConfig object is used to configure loading SeqIO Tasks/Mixtures for training and eval. If you take a closer look at runs/pretrain.gin, you will see that there are two DatasetConfig objects defined and passed to the train function: train_dataset_cfg and train_eval_dataset_cfg. Here’s a brief description of these configs:

  • train: This configures the Task/Mixture that the model will be pretrained on.

  • train_eval: This configures the Task/Mixture that is used to compute training metrics on the eval split, e.g. perplexity. These metrics are defined in the Model class and the eval fn is located here.

Deterministic training {.no-toc}#

A training run may consist of various randomized operations, e.g. dataset shuffling, dropout, etc. However, it is often useful to have deterministic training, meaning that the random operations are reproducible and robust to preemption/restarts. To make your pretraining deterministic, in addition to the params configured in pretrain.gin, you need to add the following configs:

  • sets the dataset seed to a fixed value: train/utils.DatasetConfig.seed = 42.

  • sets the dropout seed to a fixed value: train_script.train.random_seed = 42.

  • enables dataset checkpointing: utils.SaveCheckpointConfig.save_dataset = True. This means that the dataset iterator is checkpointed periodically during training, and in case of preemptions, training resumes from the latest dataset checkpoint to ensure deterministic behavior. The checkpointing frequency is set using utils.SaveCheckpointConfig.period (1000 by default), meaning that the dataset is checkpointed after processing 1000 batches (batches, not examples; batch size can be overridden using train/DatasetConfig.batch_size and is set to 128 by default).

Defining a custom SeqIO Task/Mixture to pretrain on {.no-toc}#

Refer to SeqIO documentation.