Fine Tuning a Model


This page outlines the steps to fine-tune an existing pre-trained model with T5X on common downstream tasks defined with SeqIO. This is one of the simplest and most common use cases of T5X. If you’re new to T5X, this tutorial is the recommended starting point.


Fine-tuning a model with T5X consists of the following steps:

  1. Choose the pre-trained model to fine-tune.

  2. Choose the SeqIO Task/Mixture to fine-tune the model on.

  3. Write a Gin file that configures the pre-trained model, SeqIO Task/Mixture and other details of your fine-tuning run.

  4. Launch your experiment locally or on XManager.

  5. Monitor your experiment and parse metrics.

These steps are explained in detail in the following sections. An example run that fine-tunes a T5-small checkpoint on the WMT14 English-to-German translation benchmark is also showcased.

Step 1: Choose a pre-trained model

To use a pre-trained model, you need two things: a Gin config file that defines the model params, and the model checkpoint to load from. For your convenience, TensorFlow checkpoints and Gin configs for common T5 pre-trained models have been made available for use in T5X. A list of all available pre-trained models (with model checkpoints and Gin config files) is available in the Models documentation.

For the example run, you will use the T5 1.1 Small model. The Gin file for this model is located at t5x/examples/t5/t5_1_1/small.gin, and the checkpoint is located at gs://t5-data/pretrained_models/t5x/t5_1_1_small.

Step 2: Choose a SeqIO Task/Mixture

A SeqIO Task encapsulates the data source, the preprocessing logic to be performed on the data before querying the model, the postprocessing logic to be performed on model outputs, and the metrics to be computed given the postprocessed outputs and targets. A SeqIO Mixture denotes a collection of Tasks and enables fine-tuning a model on multiple Tasks simultaneously.
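To make these four responsibilities concrete, here is a minimal, framework-free Python sketch. It does not use the SeqIO API: `ToyTask`, `ToyMixture`, and `exact_match` are hypothetical names chosen for illustration, and the mixture samples Tasks by fixed relative rates, which is only one way a Mixture can be weighted.

```python
import itertools
import random
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, List, Sequence

Example = Dict[str, str]
MetricFn = Callable[[List[str], List[str]], Dict[str, float]]


@dataclass
class ToyTask:
    """Hypothetical stand-in for a SeqIO Task; this is not the SeqIO API."""
    name: str
    source: Callable[[], Iterable[Example]]   # data source: yields raw examples
    preprocess: Callable[[Example], Example]  # runs before examples reach the model
    postprocess: Callable[[str], str]         # runs on raw model outputs
    metric_fns: Sequence[MetricFn]            # compare postprocessed outputs and targets

    def examples(self):
        """Yield preprocessed examples from the source."""
        for raw in self.source():
            yield self.preprocess(raw)

    def evaluate(self, targets, raw_outputs):
        """Postprocess model outputs, then merge the results of every metric fn."""
        predictions = [self.postprocess(output) for output in raw_outputs]
        results = {}
        for metric_fn in self.metric_fns:
            results.update(metric_fn(targets, predictions))
        return results


@dataclass
class ToyMixture:
    """Hypothetical stand-in for a SeqIO Mixture: Tasks sampled by relative rate."""
    tasks: Sequence[ToyTask]
    rates: Sequence[float]

    def sample_examples(self, n, seed=0):
        rng = random.Random(seed)
        # Cycle each Task's examples so that small Tasks are repeated, not exhausted.
        streams = [itertools.cycle(list(task.examples())) for task in self.tasks]
        picks = rng.choices(range(len(self.tasks)), weights=self.rates, k=n)
        return [next(streams[i]) for i in picks]


def exact_match(targets, predictions):
    hits = sum(t == p for t, p in zip(targets, predictions))
    return {"exact_match": 100.0 * hits / len(targets)}


task = ToyTask(
    name="toy_en_de",
    source=lambda: [{"en": "Hello", "de": "Hallo"}],
    preprocess=lambda ex: {"inputs": "translate English to German: " + ex["en"], "targets": ex["de"]},
    postprocess=str.strip,
    metric_fns=[exact_match],
)
mixture = ToyMixture(tasks=[task], rates=[1.0])
print(task.evaluate(["Hallo"], [" Hallo "]))  # {'exact_match': 100.0}
```

The split mirrors the description above: preprocessing happens on the way into the model, postprocessing on the way out, and metrics only ever see postprocessed predictions.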

Standard Tasks

Many common datasets and benchmarks (e.g. GLUE, SuperGLUE, WMT, SQuAD, CNN/Daily Mail) have been implemented as SeqIO Tasks/Mixtures and can be used directly. These Tasks/Mixtures are defined under third_party/py/t5/data/.

For the example run, you will fine-tune the model on the WMT14 English to German translation benchmark, which has been implemented as the wmt_t2t_ende_v003 Task.

Custom Tasks

It is also possible to define your own custom Task; see the SeqIO documentation for how to do this. Note that Tasks defined using the old T5 codebase can also be used with T5X. If you use a custom Task, follow the instructions in the Advanced Topics section at the end of this tutorial to make sure the module containing your Task is included.

When defining a custom Task, you have the option to cache it on disk before fine-tuning; the SeqIO documentation describes how to do this. Caching may improve performance for Tasks with expensive preprocessing. By default, T5X expects Tasks to be cached. To fine-tune on a Task that has not been cached, set --gin.USE_CACHED_TASKS=False.
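A `--gin.NAME=VALUE` flag carries a Gin literal, which is why USE_CACHED_TASKS is written as a bare False while string values such as MODEL_DIR need quotes. The hypothetical helpers below (not part of T5X) build such flags from Python values, assuming that Python `repr` output for strings, numbers, booleans, and dicts is valid Gin literal syntax.

```python
def gin_flag(name, value):
    """Format one Gin binding as a command-line flag (hypothetical helper, not part of T5X).

    repr() renders strings with quotes and booleans, numbers, and dicts in
    Python literal syntax, which this sketch assumes Gin accepts as-is.
    """
    return f"--gin.{name}={value!r}"


def gin_flags(bindings):
    """Turn a dict of macro names to Python values into a list of flags."""
    return [gin_flag(name, value) for name, value in bindings.items()]


# Fine-tune on a Task that has not been cached.
print(gin_flag("USE_CACHED_TASKS", False))  # --gin.USE_CACHED_TASKS=False
# Strings keep their quotes; "/tmp/finetune-model" is a hypothetical directory.
print(gin_flag("MODEL_DIR", "/tmp/finetune-model"))  # --gin.MODEL_DIR='/tmp/finetune-model'
```

When such a flag is typed into a shell instead of passed as a list element, the shell strips unescaped quotes, which is why the launch command in Step 4 escapes them.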

Step 3: Write a Gin Config

After choosing the pre-trained model and SeqIO Task/Mixture for your run, the next step is to configure your run using Gin. If you’re not familiar with Gin, reading the T5X Gin Primer is recommended.

T5X provides a Gin file that configures the T5X trainer for fine-tuning (located at t5x/configs/runs/finetune.gin) and expects a few params from you. These params can be specified in a separate Gin file or via command-line flags. The required params are:

  • INITIAL_CHECKPOINT_PATH: This is the path to the pre-trained checkpoint (from Step 1). For the example run, set this to 'gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000'.

  • TRAIN_STEPS: The total number of training steps, including the steps the model was already pre-trained for. Add the step number of the INITIAL_CHECKPOINT_PATH to the number of fine-tuning steps you want. For the example run, to fine-tune for 20_000 steps, set this to 1_020_000, since the initial checkpoint is the 1_000_000th step.

  • MIXTURE_OR_TASK_NAME: This is the SeqIO Task or Mixture name to run (from Step 2). For the example run, set this to 'wmt_t2t_ende_v003'.

  • TASK_FEATURE_LENGTHS: This is a dict mapping feature key to maximum int length for that feature. After preprocessing, features are truncated to the provided value. For the example run, set this to {'inputs': 256, 'targets': 256}.

  • MODEL_DIR: A path to write fine-tuned checkpoints to. When launching using XManager, this path is automatically set and can be accessed from the XManager Artifacts page. When running locally using Blaze, you can explicitly pass a directory using a flag. Launch commands are provided in the next step.

  • LOSS_NORMALIZING_FACTOR: When fine-tuning a model that was pre-trained using Mesh TensorFlow (e.g. the public T5 / mT5 / ByT5 models), set this to the pre-training batch size multiplied by the pre-training target token length. For T5 and T5.1.1: 2048 * 114. For mT5: 1024 * 229. For ByT5: 1024 * 189.
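The arithmetic behind TRAIN_STEPS, TASK_FEATURE_LENGTHS, and LOSS_NORMALIZING_FACTOR can be checked with a short Python sketch. The helper names and dict keys are hypothetical (they are not T5X functions), the step number is assumed to be the numeric suffix of a `checkpoint_<step>` path as in the example checkpoint, and the truncation helper models only the truncation described above, not padding or packing.

```python
import re

# Pre-training batch size * pre-training target length, per the list above.
# The dict keys are arbitrary labels chosen for this sketch.
LOSS_NORMALIZING_FACTORS = {
    "t5": 2048 * 114,  # also used for T5.1.1
    "mt5": 1024 * 229,
    "byt5": 1024 * 189,
}


def checkpoint_step(path):
    """Return the step encoded in a path ending in checkpoint_<step>."""
    match = re.search(r"checkpoint_(\d+)/?$", path)
    if match is None:
        raise ValueError(f"expected a checkpoint_<step> suffix: {path!r}")
    return int(match.group(1))


def total_train_steps(initial_checkpoint_path, finetune_steps):
    """TRAIN_STEPS includes the pre-training steps already in the checkpoint."""
    return checkpoint_step(initial_checkpoint_path) + finetune_steps


def truncate(example, feature_lengths):
    """Cut each listed feature to its maximum length; other features pass through."""
    return {
        key: value[: feature_lengths[key]] if key in feature_lengths else value
        for key, value in example.items()
    }


path = "gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000"
print(total_train_steps(path, 20_000))  # 1020000
print(LOSS_NORMALIZING_FACTORS["t5"])   # 233472
print(len(truncate({"inputs": list(range(300))}, {"inputs": 256, "targets": 256})["inputs"]))  # 256
```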

In addition to the above params, you will need to include finetune.gin and the Gin file for the pre-trained model, which for the example run is t5_1_1/small.gin.

include 't5x/configs/runs/finetune.gin'
include 't5x/examples/t5/t5_1_1/small.gin'

You will also need to import the Python module(s) that register the SeqIO Tasks and Mixtures used in your run. For the example run, add an import of the module in which wmt_t2t_ende_v003 is registered.
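The import matters because Task names are looked up in a registry that is filled in as a side effect of importing the module that defines them. The sketch below models that pattern with a plain dictionary; `TASK_REGISTRY`, `register`, and `get_task` are hypothetical names, not the SeqIO API (SeqIO's own registry is `seqio.TaskRegistry`).

```python
# A module-level registry, filled in when Task-defining modules are imported.
TASK_REGISTRY = {}


def register(name, task):
    """Add a Task under a unique name (hypothetical helper)."""
    if name in TASK_REGISTRY:
        raise ValueError(f"task {name!r} is already registered")
    TASK_REGISTRY[name] = task


def get_task(name):
    """Look a Task up by name, failing loudly if its module was never imported."""
    try:
        return TASK_REGISTRY[name]
    except KeyError:
        raise ValueError(
            f"task {name!r} not found; was the module that registers it imported?"
        ) from None


# Before the defining module is imported, the lookup fails.
try:
    get_task("wmt_t2t_ende_v003")
except ValueError as err:
    print(err)

# Importing a Task module would run a registration like this at import time.
register("wmt_t2t_ende_v003", "placeholder task object")
print(get_task("wmt_t2t_ende_v003"))  # placeholder task object
```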

Finally, your Gin file should look like this:

include 't5x/configs/runs/finetune.gin'
include 't5x/examples/t5/t5_1_1/small.gin'

# Register necessary SeqIO Tasks/Mixtures.

MIXTURE_OR_TASK_NAME = "wmt_t2t_ende_v003"
TASK_FEATURE_LENGTHS = {"inputs": 256, "targets": 256}
TRAIN_STEPS = 1_020_000  # 1000000 pre-trained steps + 20000 fine-tuning steps.
INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000"

See t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin for this example.

Step 4: Launch your experiment

To launch your experiment locally (for debugging only; larger checkpoints may cause issues), run the following on the command line:

# MODEL_DIR must point to the directory where fine-tuned checkpoints will be written.
python -m t5x.train_unfragmented \
  --gin_file=t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin \
  --gin.MODEL_DIR=\"${MODEL_DIR}\"

Note that multiple comma-separated paths can be passed to the gin_search_paths flag, and these paths should contain all Gin files used or included in your experiment.
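If you launch from a Python script rather than a shell, the same command can be assembled as an argument list. This is a sketch under stated assumptions: `build_train_command` is a hypothetical helper, the model and search-path directories are placeholders, and `--gin_search_paths` receives a single comma-separated string as described above. Passing a list to `subprocess.run` avoids shell quoting, but the Gin string literal for MODEL_DIR still needs its own quotes.

```python
import sys


def build_train_command(model_dir, gin_file, gin_search_paths):
    """Assemble the local launch command as an argv list (hypothetical helper)."""
    return [
        sys.executable, "-m", "t5x.train_unfragmented",
        f"--gin_file={gin_file}",
        # Multiple search paths are joined into one comma-separated value.
        f"--gin_search_paths={','.join(gin_search_paths)}",
        # MODEL_DIR is a Gin string, so the value carries its own quotes.
        f'--gin.MODEL_DIR="{model_dir}"',
    ]


cmd = build_train_command(
    model_dir="/tmp/finetune-model",  # hypothetical output directory
    gin_file="t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin",
    gin_search_paths=["/path/to/t5x", "/path/to/my_configs"],  # hypothetical
)
print(cmd)
# To actually launch (requires T5X to be installed):
#   import subprocess; subprocess.run(cmd, check=True)
```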

Step 5: Monitor your experiment and parse metrics

After fine-tuning has completed, you can parse metrics into CSV format using the following script:

MODEL_DIR= # from Step 4 if running locally, from XManager Artifacts otherwise
VAL_DIR= # directory under MODEL_DIR that holds the inference_eval summaries
python -m t5.scripts.parse_tb \
  --summary_dir="$VAL_DIR" \
  --seqio_summaries \
  --out_file="$VAL_DIR/results.csv"
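Once results.csv exists, a few lines of Python can pick out the best evaluated step for a metric. The column layout is an assumption for illustration (a `step` column plus one column per metric); inspect the header of your own file and adjust the names before relying on this. `best_step` is a hypothetical helper.

```python
import csv
import io


def best_step(csv_text, metric, step_column="step", higher_is_better=True):
    """Return (step, value) with the best value of `metric`, or None if absent.

    Assumes one row per evaluated step with a `step_column` column; rows whose
    step is not an integer (e.g. summary rows) or whose metric cell is empty
    are skipped.
    """
    best = None
    for row in csv.DictReader(io.StringIO(csv_text)):
        step = (row.get(step_column) or "").strip()
        value = (row.get(metric) or "").strip()
        if not step.isdigit() or not value:
            continue
        step_num, score = int(step), float(value)
        if best is None or (score > best[1] if higher_is_better else score < best[1]):
            best = (step_num, score)
    return best


# Usage (the file path and the "bleu" column name are hypothetical):
# with open("results.csv", newline="") as f:
#     print(best_step(f.read(), "bleu"))
```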

Metric Explanations

By default, T5X logs many metrics to TensorBoard. Many of these look similar but have important distinctions.

The first two graphs you will see are the accuracy and cross_ent_loss graphs. These are the token-level teacher-forced accuracy and cross-entropy loss, respectively. Each graph can contain multiple curves. The first is the train curve, which is calculated as a running sum that is then normalized over the whole training set. The second class of curves has the form training_eval/${task_name}. These curves are created by running a subset of the validation split of ${task_name} (controlled by the eval_steps parameter of the main train function) through the model and calculating these metrics using teacher forcing. These graphs are commonly used to spot “failure to learn” cases and as a warning sign of overfitting, but they are usually not the final metrics one would report.
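For intuition, here is a simplified, framework-free sketch of what token-level teacher-forced accuracy and cross-entropy measure: at every target position the model has seen the true previous tokens, and its prediction for the current token is scored. This is not T5X's implementation; padding positions are assumed to carry weight 0, and the loss is averaged over weighted tokens rather than divided by a fixed LOSS_NORMALIZING_FACTOR.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one position's vocabulary scores."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def teacher_forced_metrics(logits, targets, weights):
    """Token-level accuracy and mean cross-entropy over weighted positions.

    logits[i] holds the model's scores over the vocabulary at position i,
    computed while conditioning on the true targets[:i] (teacher forcing).
    """
    total_weight = sum(weights)
    if total_weight == 0:
        raise ValueError("no non-padding positions")
    correct = 0.0
    loss = 0.0
    for scores, target, weight in zip(logits, targets, weights):
        if weight == 0:
            continue  # padding position
        prediction = max(range(len(scores)), key=scores.__getitem__)
        correct += weight * (prediction == target)
        loss -= weight * log_softmax(scores)[target]
    return {"accuracy": correct / total_weight, "cross_ent_loss": loss / total_weight}
```

Because the true prefix is always supplied, a model can score well here while still drifting during free-running decoding, which is one reason these curves are not the metrics usually reported.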

The second set of graphs are the ones under the collapsible eval section in TensorBoard. These graphs are created based on the metric_fns defined in the SeqIO task. The curves on these graphs have the form inference_eval/${task_name}. Values are calculated by running the whole validation split through the model in inference mode, commonly auto-regressive decoding or output scoring. Most likely these are the metrics that will be reported.
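These inference-mode values come from the Task's metric functions, which compare postprocessed predictions against targets. As a hedged sketch in the style SeqIO uses for prediction-based metrics (a function from `(targets, predictions)` to a dict of named scores), here is a whitespace-token F1. `token_f1` is illustrative only, not a metric shipped with T5 or SeqIO, and scaling the score to 0 to 100 is an assumption.

```python
from collections import Counter
from typing import Dict, Sequence


def token_f1(targets: Sequence[str], predictions: Sequence[str]) -> Dict[str, float]:
    """Mean whitespace-token F1 between each decoded prediction and its target (0 to 100)."""
    if len(targets) != len(predictions):
        raise ValueError("targets and predictions must have the same length")
    scores = []
    for target, prediction in zip(targets, predictions):
        target_tokens, pred_tokens = target.split(), prediction.split()
        # Multiset intersection counts each shared token at most min(count) times.
        overlap = sum((Counter(target_tokens) & Counter(pred_tokens)).values())
        if overlap == 0:
            scores.append(0.0)
            continue
        precision = overlap / len(pred_tokens)
        recall = overlap / len(target_tokens)
        scores.append(2 * precision * recall / (precision + recall))
    return {"token_f1": 100.0 * sum(scores) / len(scores) if scores else 0.0}
```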

More information about how the datasets used for these different metrics are configured is given in the Advanced Topics section below.

In summary, the metric you actually care about most likely lives under the eval tab rather than in the accuracy graph.

Next Steps

Now that you have successfully fine-tuned a pre-trained model on WMT, here are some topics you might want to explore next:

We also touch upon a few advanced topics related to fine-tuning below that might be useful, especially when customizing your fine-tuning job.

Advanced Topics

train, train_eval and infer_eval

A DatasetConfig object is used to configure loading SeqIO Tasks/Mixtures for training and eval. If you take a closer look at runs/finetune.gin, you will see that there are three DatasetConfig objects defined and passed to the train function: train_dataset_cfg, train_eval_dataset_cfg, infer_eval_dataset_cfg. Here’s a brief description of these configs:

  • train: This configures the Task/Mixture that the model will be fine-tuned on.

  • train_eval: This configures the Task/Mixture that is used to compute training metrics on the eval split, e.g. perplexity. These metrics are defined in the Model class and the eval fn is located here.

  • infer_eval: This configures the Task/Mixture that is used to compute metrics on inferred model outputs (e.g., comparing decoded model outputs and targets). These metrics are defined in the SeqIO Task/Mixture and the eval fn is located here.
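The relationship between the three configs can be sketched with a small dataclass. `DatasetConfigSketch` is a hypothetical, simplified stand-in with only the fields discussed here, not the real DatasetConfig; the split names follow the convention stated in the next section (train split for training, validation split for evals).

```python
from dataclasses import dataclass, replace
from typing import Dict


@dataclass
class DatasetConfigSketch:
    """Hypothetical, simplified stand-in for a DatasetConfig (not the T5X class)."""
    mixture_or_task_name: str
    task_feature_lengths: Dict[str, int]
    split: str


MIXTURE_OR_TASK_NAME = "wmt_t2t_ende_v003"
TASK_FEATURE_LENGTHS = {"inputs": 256, "targets": 256}

# By default all three configs share the same Task/Mixture and feature lengths.
train_dataset_cfg = DatasetConfigSketch(MIXTURE_OR_TASK_NAME, TASK_FEATURE_LENGTHS, "train")
# Teacher-forced metrics (e.g. perplexity) on the eval split.
train_eval_dataset_cfg = replace(train_dataset_cfg, split="validation")
# Metrics on decoded model outputs, computed by the Task's metric fns.
infer_eval_dataset_cfg = replace(train_dataset_cfg, split="validation")
```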

Using separate SeqIO Tasks/Mixtures for fine-tuning and eval

Commonly, the same SeqIO Task/Mixture is used for training and eval. It is set by the MIXTURE_OR_TASK_NAME macro in your fine-tune Gin file from Step 3 above and is passed to train_dataset_cfg, train_eval_dataset_cfg, and infer_eval_dataset_cfg. The train split is used for training and the validation split is used for evals. However, you can override these params in your fine-tune Gin config. For example, if you want to fine-tune on all GLUE tasks but evaluate only on the GLUE STS benchmark, you can override the SeqIO Task/Mixture used for infer_eval in your fine-tune Gin file as follows:

include 'runs/finetune.gin'
include 'models/t5_small.gin'

MIXTURE_OR_TASK_NAME = 'glue_v002_proportional'
TASK_FEATURE_LENGTHS =  {'inputs': 512, 'targets': 84}
TRAIN_STEPS = 1_500_000  # includes 1_000_000 pretrain steps
INITIAL_CHECKPOINT_PATH = 'gs://t5-data/pretrained_models/t5x/t5_small/checkpoint_1000000'
infer_eval/utils.DatasetConfig.mixture_or_task_name = 'glue_stsb_v002'

Other params in finetune.gin can be overridden in the same way.
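The infer_eval/ prefix in the last binding of that example is a Gin scope: the value applies only where DatasetConfig is configured under the infer_eval scope, so the train and train_eval configs keep using MIXTURE_OR_TASK_NAME. The sketch below models that lookup with a dictionary in which a later binding for the same scoped name wins. It is a deliberately simplified illustration, not Gin's implementation, and it assumes finetune.gin binds each scope to the %MIXTURE_OR_TASK_NAME macro.

```python
MACROS = {"MIXTURE_OR_TASK_NAME": "glue_v002_proportional"}
PARAM = "utils.DatasetConfig.mixture_or_task_name"

# Bindings keyed by (scope, parameter); a later assignment to the same key wins,
# as a binding in your own Gin file overrides one from an included file.
bindings = {}
for scope in ("train", "train_eval", "infer_eval"):
    bindings[(scope, PARAM)] = "%MIXTURE_OR_TASK_NAME"  # modeled finetune.gin default
bindings[("infer_eval", PARAM)] = "glue_stsb_v002"     # the override from your file


def resolve(scope, parameter):
    """Look up a scoped binding, falling back to an unscoped one; expand %MACRO values."""
    value = bindings.get((scope, parameter), bindings.get(("", parameter)))
    if isinstance(value, str) and value.startswith("%"):
        value = MACROS[value[1:]]
    return value


for scope in ("train", "train_eval", "infer_eval"):
    print(scope, resolve(scope, PARAM))
```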

Defining a custom SeqIO Task/Mixture to fine-tune on

Refer to the SeqIO documentation.