Decoding#

This page outlines the decoding functions that T5X provides out of the box and explains how to use custom decoding functions with a Transformer model, i.e., an instance of BaseTransformerModel. Here, decoding refers to the process of generating a sequence of items from a fixed alphabet (e.g., generating token ids from the vocabulary).

There are two major ways to configure the decoding routine. The first method is to define a decode function that follows the DecodeFnCallable signature. This is more restrictive because it enforces the call signature, but users don’t need to modify the model code.

The second method is to subclass a model class and override the predict_batch_with_aux method. While this provides more flexibility, it requires rewriting the method.

Option 1: defining a decoding function#

If the desired decoding process fits the DecodeFnCallable signature, it can be registered as a private attribute of a BaseTransformerModel by passing it as the decode_fn argument to the model's constructor.

Decoding function call signature#

DecodeFnCallable defines the call signature that every decoding function must follow. Its main arguments are described below.

It takes in inputs, which is an int32 array with shape [batch_size, max_decode_len]. These are the input tokens to the decoder. For standard encoder-decoder models like T5, the array is initialized as zeros with the desired decoding length. The decoding function populates the array with the sampled token ids and returns it.

For decoder-only architectures such as a Prefix Language Model, inputs can be a concatenated sequence of “inputs” and “targets” token ids.

tokens_to_logits is a callable that takes in a batch of token ids and the current autoregressive cache, performs the forward pass, and returns the resulting logits and an updated cache. Note that for incremental decoding, this function operates on a single token at a time, i.e., the length dimension is assumed to be 1.
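To make the roles of inputs, tokens_to_logits, and the cache concrete, here is a minimal pure-Python sketch of an incremental decoding loop. It is not T5X code: toy_tokens_to_logits, greedy_decode, the dict-based cache, and the eos_id parameter are illustrative assumptions, and the sketch handles a single sequence rather than a batched array.

```python
from typing import Callable, Dict, List, Tuple

VOCAB_SIZE = 5
EOS_ID = 3  # assumed end-of-sequence id for this toy example

Cache = Dict[str, int]
Logits = List[float]


def toy_tokens_to_logits(token: int, cache: Cache) -> Tuple[Logits, Cache]:
    """Hypothetical stand-in for a model's single-token forward pass.

    It favors `token + 1` (wrapping around the vocabulary) and uses the
    cache only to count how many steps have been taken.
    """
    logits = [0.0] * VOCAB_SIZE
    logits[(token + 1) % VOCAB_SIZE] = 1.0
    new_cache = {"step": cache["step"] + 1}
    return logits, new_cache


def greedy_decode(
    inputs: List[int],
    cache: Cache,
    tokens_to_logits: Callable[[int, Cache], Tuple[Logits, Cache]],
    eos_id: int,
) -> Tuple[List[int], Cache]:
    """Fills a zero-initialized `inputs` row one token at a time."""
    decoded = list(inputs)
    token = 0  # decoding starts from the zero token
    for position in range(len(decoded)):
        # One forward pass per position, always on a single token.
        logits, cache = tokens_to_logits(token, cache)
        token = max(range(len(logits)), key=lambda i: logits[i])  # argmax
        decoded[position] = token
        if token == eos_id:
            break  # remaining positions keep their initial zeros
    return decoded, cache


decoded, final_cache = greedy_decode([0, 0, 0, 0], {"step": 0},
                                     toy_tokens_to_logits, EOS_ID)
```

The loop stops as soon as the assumed end-of-sequence id is produced, so any positions after it keep their initial zeros.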

DecodeFnCallable is designed to be as general as possible. This results in some of the arguments being somewhat generic for a specialized decoding algorithm. For example, num_decodes refers to the number of decoded samples to be returned. In the case of beam search, num_decodes corresponds to what is commonly known as beam_size, with returned sequences sorted by the beam scores. For temperature sampling, we perform num_decodes independent sampling procedures with different random seeds and sort them by the log probability of the generated sequences.
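The num_decodes behavior for sampling can be sketched in a few lines of plain Python. This is only an illustration under simplifying assumptions: the next-token distribution PROBS is fixed at every step, the helper names are hypothetical, and sorting in descending order of log probability is our choice for the sketch (the text above does not state the direction).

```python
import math
import random
from typing import List, Tuple

# Toy next-token distribution, assumed identical at every step.
PROBS = [0.6, 0.3, 0.1]


def sample_sequence(rng: random.Random, length: int) -> Tuple[List[int], float]:
    """Samples `length` tokens and returns them with their total log probability."""
    tokens = rng.choices(range(len(PROBS)), weights=PROBS, k=length)
    log_prob = sum(math.log(PROBS[t]) for t in tokens)
    return tokens, log_prob


def sample_num_decodes(
    num_decodes: int, length: int, seed: int
) -> List[Tuple[List[int], float]]:
    """Draws `num_decodes` independent samples and sorts them by log probability."""
    samples = [
        sample_sequence(random.Random(seed + i), length)  # one seed per sample
        for i in range(num_decodes)
    ]
    # Descending order is an assumption made for this sketch.
    return sorted(samples, key=lambda s: s[1], reverse=True)


decodes = sample_num_decodes(num_decodes=4, length=5, seed=0)
```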

For custom decoding functions, there might be additional arguments. To support these, we provide **kwargs.

Another use of **kwargs is calling the decoding function multiple times without recompiling the model. This pattern is used in the Prediction Service. For a compiled model, different values of alpha can be passed at call time, e.g., decoder_params = {"alpha": 0.7}, where decoder_params is an argument to predict_batch_with_aux. The dictionary is unpacked and passed to the beam_search function. Note that the Prediction Service uses predict_batch_with_aux, which is one of the two public methods. This method is useful if auxiliary outputs (e.g., scores of the predictions) are to be returned. The other method is predict_batch, which simply returns the predictions.
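The unpacking step can be illustrated with plain Python. In the sketch below, toy_beam_search and toy_predict_batch_with_aux are hypothetical stand-ins, not the T5X functions, and the default value of alpha is chosen arbitrarily. The sketch only shows how a decoder_params dictionary reaches the decoding function as keyword arguments; it does not model compilation.

```python
from typing import Any, Callable, Dict, List, Optional


def toy_beam_search(inputs: List[int], num_decodes: int, alpha: float = 0.6,
                    **kwargs: Any) -> Dict[str, Any]:
    """Hypothetical stand-in that only records the arguments it received."""
    return {"num_decodes": num_decodes, "alpha": alpha, "extra": kwargs}


def toy_predict_batch_with_aux(
    inputs: List[int],
    decode_fn: Callable[..., Dict[str, Any]],
    num_decodes: int = 1,
    decoder_params: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Unpacks `decoder_params` into keyword arguments of `decode_fn`."""
    decoder_params = decoder_params or {}
    return decode_fn(inputs=inputs, num_decodes=num_decodes, **decoder_params)


# Same function, two calls with different decoding parameters.
default_call = toy_predict_batch_with_aux([0, 0], toy_beam_search)
tuned_call = toy_predict_batch_with_aux([0, 0], toy_beam_search,
                                        decoder_params={"alpha": 0.7})
```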

Temperature sampling#

Temperature sampling can be used for multiple decoding strategies. The following lines configure temperature sampling as the decode_fn.

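# Return a single decoded sequence per input example.
models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 1
# Use temperature sampling as the decoding function.
models.EncoderDecoderModel.decode_fn = @decoding.temperature_sample
# Arguments bound to temperature_sample; see the sections below.
decoding.temperature_sample:
  temperature = 0.5
  topk = 20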

A similar specification can be used for other model types by replacing models.EncoderDecoderModel with the relevant model class, e.g., models.PrefixLanguageModel.

The sampling behavior is controlled by the arguments passed to temperature_sample. We provide details for a few of them below.

temperature#

A probabilistic model outputs a probability distribution over a pre-defined alphabet. For example, a language model outputs logits, which are unnormalized log-probabilities for each item in the vocabulary. We use a language model as a running example. A sampling process involves sampling from the predicted distribution one item at a time, conditioned on the previously generated items, until a given number of items have been generated or a sentinel token that represents the end of the sequence is generated.

Temperature modifies the unnormalized probability distribution at each step. For each item \(i\) in the vocabulary, its probability predicted by the model is given by

\[p_i \propto \exp\left(\frac{x_i}{T} \right)\]

where \(T\) is the temperature and \(x_i\) is the logit corresponding to item \(i\). As \(T \to 0\), the distribution puts all probability mass on the item with the highest probability. In other words, the sampling process becomes a greedy search.

In the other extreme, as \(T \to \infty\), the predicted distribution becomes uniform.
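The effect of the temperature on a small set of logits can be checked numerically. The following is a minimal pure-Python sketch of the formula above, not the T5X implementation; the function name and the example logits are our own.

```python
import math
from typing import List


def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:
    """Returns p_i proportional to exp(x_i / T) for a strictly positive T."""
    if temperature <= 0:
        raise ValueError("temperature must be strictly positive")
    scaled = [x / temperature for x in logits]
    shift = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(s - shift) for s in scaled]
    total = sum(weights)
    return [w / total for w in weights]


logits = [2.0, 1.0, 0.0]
sharp = softmax_with_temperature(logits, 0.1)    # close to greedy
plain = softmax_with_temperature(logits, 1.0)    # ordinary softmax
flat = softmax_with_temperature(logits, 100.0)   # close to uniform
```

With a low temperature almost all of the mass lands on the largest logit, while with a high temperature the three probabilities are nearly equal.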

topk#

By specifying a strictly positive integer value for topk, the sampling process in each step is limited to the k items with the highest probabilities. Temperature is still applied when topk is set: it modifies the logits corresponding to the top k items.
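As a rough illustration of how top-k filtering and temperature combine, the sketch below keeps the k largest logits, applies the temperature to them, and renormalizes. It is a simplified single-step example in plain Python; the function name is hypothetical and the details may differ from temperature_sample.

```python
import math
from typing import List


def top_k_probabilities(logits: List[float], k: int,
                        temperature: float = 1.0) -> List[float]:
    """Renormalizes over the k largest logits; other items get probability 0."""
    if k <= 0:
        raise ValueError("k must be a strictly positive integer")
    # Indices of the k largest logits (ties go to the lower index).
    keep = set(sorted(range(len(logits)), key=lambda i: -logits[i])[:k])
    shift = max(logits)  # for numerical stability
    weights = [
        math.exp((x - shift) / temperature) if i in keep else 0.0
        for i, x in enumerate(logits)
    ]
    total = sum(weights)
    return [w / total for w in weights]


probs = top_k_probabilities([2.0, 1.0, 0.0, -1.0], k=2, temperature=0.5)
```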

topp#

By specifying a strictly positive float value for topp, the sampling process is limited to a subset of the vocabulary \(V^{(p)} \subset V\), defined as the smallest set such that

\[\sum_{i \in V^{(p)}} p_i \ge p\]

where \(p_i\) is the conditional probability of item \(i\) at each time step. This is called “nucleus sampling”, which was introduced by Holtzman et al. (ICLR 2020).
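The definition of \(V^{(p)}\) can be sketched directly: sort the items by probability, add them one at a time, and stop once the cumulative probability reaches \(p\). The helper names below are hypothetical and the example works on probabilities for a single step; it is meant as an illustration of the definition rather than a copy of the T5X code.

```python
from typing import List


def nucleus(probs: List[float], p: float) -> List[int]:
    """Returns indices of the smallest set whose total probability is >= p."""
    order = sorted(range(len(probs)), key=lambda i: -probs[i])
    selected: List[int] = []
    cumulative = 0.0
    for i in order:
        selected.append(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    return selected


def top_p_probabilities(probs: List[float], p: float) -> List[float]:
    """Zeroes out items outside the nucleus and renormalizes the rest."""
    keep = set(nucleus(probs, p))
    total = sum(probs[i] for i in keep)
    return [probs[i] / total if i in keep else 0.0 for i in range(len(probs))]


probs = [0.5, 0.3, 0.15, 0.05]
kept = nucleus(probs, p=0.7)
filtered = top_p_probabilities(probs, p=0.7)
```

Here the two most likely items already cover a probability of 0.8, so the remaining two items are excluded from sampling.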

IMPORTANT: Only one of topk or topp can be used.

Option 2: subclassing a model class#

If DecodeFnCallable is not flexible enough for your custom decoding function, you can subclass the model class and override the predict_batch_with_aux method. While the model class can be any subclass of BaseTransformerModel, we recommend that you subclass one of the existing models, such as EncoderDecoderModel, and override only the predict_batch_with_aux method.

The predict_batch_with_aux method also has a required call signature, but it is significantly more flexible. It should return a tuple of the predicted sequence array and auxiliary outputs such as scores.
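The shape of such an override can be sketched without T5X. In the example below, ToyEncoderDecoderModel is a hypothetical stand-in for an existing model class, the params and batch arguments and the "encoder_input_tokens" key are assumptions, and the "decoding" logic is deliberately trivial. The point is only that the subclass replaces one method and returns a (predictions, auxiliary outputs) tuple.

```python
from typing import Any, Dict, List, Optional, Tuple

Batch = Dict[str, List[List[int]]]
Aux = Dict[str, List[float]]


class ToyEncoderDecoderModel:
    """Hypothetical stand-in for an existing model class."""

    def predict_batch_with_aux(
        self,
        params: Any,
        batch: Batch,
        decoder_params: Optional[Dict[str, Any]] = None,
    ) -> Tuple[List[List[int]], Aux]:
        # Placeholder "decoding": echo the inputs with a zero score.
        predictions = [list(row) for row in batch["encoder_input_tokens"]]
        return predictions, {"scores": [0.0] * len(predictions)}


class ReversingModel(ToyEncoderDecoderModel):
    """Overrides only the decoding entry point with custom logic."""

    def predict_batch_with_aux(
        self,
        params: Any,
        batch: Batch,
        decoder_params: Optional[Dict[str, Any]] = None,
    ) -> Tuple[List[List[int]], Aux]:
        # Custom "decoding": reverse each row and score it by its length.
        predictions = [list(reversed(row)) for row in batch["encoder_input_tokens"]]
        aux = {"scores": [float(len(row)) for row in predictions]}
        return predictions, aux


model = ReversingModel()
predictions, aux = model.predict_batch_with_aux(
    params=None, batch={"encoder_input_tokens": [[3, 4, 5], [6, 7]]}
)
```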