# Decoding
This page outlines the decoding functions that T5X provides out of the box and how custom decoding functions can be used for a Transformer model, i.e., an instance of `BaseTransformerModel`.
Here we refer to decoding as the process of generating a sequence of items from a fixed alphabet (e.g., generating token ids from the vocabulary).
There are two major ways to configure the decoding routine. The first method is to define a decode function that follows the `DecodeFnCallable` signature. This is more restrictive, since it enforces the call signature, but users don't need to modify the model code. The second method is to subclass a model class and override the `predict_batch_with_aux` method. While this provides more flexibility, it requires rewriting the method.
## Option 1: defining a decoding function
If a desired decoding process can follow `DecodeFnCallable`, it can be registered as a private attribute of a `BaseTransformerModel` by passing it as the `decode_fn` argument to its constructor.
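In Python this looks roughly as follows. This is a minimal sketch: the constructor arguments other than `decode_fn` are elided with `...` placeholders and must be filled in for a real model, and argument names may vary across T5X versions.

```python
import functools

from t5x import decoding, models

model = models.EncoderDecoderModel(
    module=...,            # Flax module implementing the network
    input_vocabulary=...,  # seqio vocabularies
    output_vocabulary=...,
    optimizer_def=...,
    # The decode function is stored on the model and used at prediction time.
    decode_fn=functools.partial(decoding.temperature_sample, temperature=0.7),
)
```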
### Decoding function call signature
`DecodeFnCallable` has the following call signature.
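The authoritative definition lives in the T5X source; the sketch below is reconstructed from the argument descriptions that follow, so treat the exact parameter list as an assumption that may differ slightly between versions.

```python
from typing import Any, Callable, Mapping, Tuple

import jax.numpy as jnp

def decode_fn(
    *,
    inputs: jnp.ndarray,                   # int32 [batch_size, max_decode_len]
    cache: Mapping[str, jnp.ndarray],      # initial autoregressive cache
    tokens_to_logits: Callable[..., Any],  # single-step forward pass (below)
    eos_id: int,                           # token id that ends a sequence
    num_decodes: int,                      # number of samples/beams to return
    **kwargs: Any,                         # algorithm-specific extras
) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Returns (decoded token ids, per-sequence scores)."""
  ...
```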
It takes in `inputs`, an int32 array with shape `[batch_size, max_decode_len]` holding the input tokens to the decoder. For standard encoder-decoder models like T5, this is initialized as zeros with a desired decoding length. The decoding function populates the array with the sampled token ids and returns it.

For decoder-only architectures such as a Prefix Language Model, `inputs` can be a concatenated sequence of “inputs” and “targets” token ids.
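For the encoder-decoder case, the initialization described above amounts to something like the following (sizes are example values):

```python
import jax.numpy as jnp

batch_size, max_decode_len = 8, 128  # example values
# All-zeros decoder inputs; the decoding function fills in sampled token ids.
inputs = jnp.zeros((batch_size, max_decode_len), dtype=jnp.int32)
```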
`tokens_to_logits` is a callable that takes in a batch of token ids and the current autoregressive cache, performs the forward pass, and returns the resulting logits and an updated cache. Note that for incremental decoding, this function operates on a single token, i.e., the length dimension is assumed to be 1.
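Schematically, and with hypothetical names, such a callable looks like:

```python
from typing import Mapping, Tuple

import jax.numpy as jnp

def tokens_to_logits(
    token_ids: jnp.ndarray,            # int32 [batch, 1]: one step at a time
    cache: Mapping[str, jnp.ndarray],  # current autoregressive cache
) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]:
  """Runs one decoder step; returns ([batch, vocab] logits, updated cache)."""
  ...
```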
`DecodeFnCallable` is designed to be as general as possible. This results in some of the arguments being somewhat generic for a specialized decoding algorithm. For example, `num_decodes` refers to the number of decoded samples to be returned. In the case of beam search, `num_decodes` corresponds to what is commonly known as `beam_size`, with the returned sequences sorted by their beam scores. For temperature sampling, we perform `num_decodes` independent sampling procedures with different random seeds and sort the results by the log probability of the generated sequences.
Custom decoding functions might take additional arguments. To support these, we provide `**kwargs`.

Another use of `**kwargs` is calling the decode function multiple times without recompiling the model. This pattern is used in the Prediction Service.
For a compiled model, different values of `alpha` can be passed, e.g., `decoder_params = {"alpha": 0.7}`, where `decoder_params` is an argument to `predict_batch_with_aux`. It is unpacked and passed to the `beam_search` function.
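For example, a compiled model could be queried twice with different length penalties; `params` and `batch` below are hypothetical placeholders for the model variables and the input batch:

```python
# `decoder_params` is forwarded (unpacked) into the decode function, so a
# new value of `alpha` reaches beam_search without triggering a recompile.
preds_a, aux_a = model.predict_batch_with_aux(
    params, batch, decoder_params={"alpha": 0.7})
preds_b, aux_b = model.predict_batch_with_aux(
    params, batch, decoder_params={"alpha": 0.3})
```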
Note that the Prediction Service uses `predict_batch_with_aux`, which is one of the two public prediction methods. This method is useful if auxiliary outputs (e.g., scores of the predictions) are to be returned. The other method is `predict_batch`, which simply returns the predictions.
### Beam search
The following lines can be added to a gin file in order to use beam search as the decoding function for an encoder-decoder model.

```gin
models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4
models.EncoderDecoderModel.decode_fn = @decoding.beam_search
decoding.beam_search.alpha = 0.6
```

Note that we skip gin boilerplate such as gin dynamic registration. Please refer to the T5X Gin Primer for more details.
The beam search behavior is controlled by the arguments passed to `beam_search`. We provide details for a few of them below.
#### `num_decodes`
If `num_decodes` is configured with `gin.register`, it is overridden by the value explicitly passed by the caller, e.g., to `models.EncoderDecoderModel.predict_batch_with_aux`. This is because the information about `num_decodes` is needed to prepare the encoder inputs and outputs, which are expanded `num_decodes` times in the batch dimension. We recommend that `num_decodes` be specified only in `models.EncoderDecoderModel.predict_batch_with_aux`.
#### `alpha`
This is the brevity penalty introduced in Wu et al. 2016 to penalize short sequences.
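For reference, the length penalty from Wu et al. 2016 is

$$lp(Y) = \frac{(5 + |Y|)^{\alpha}}{(5 + 1)^{\alpha}},$$

and beam hypotheses are ranked by $\log P(Y \mid X) / lp(Y)$, so larger values of $\alpha$ penalize short sequences more heavily.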
#### `max_decode_len`
For evaluation, we typically don't want to truncate the examples to a specified sequence length. Therefore, we dynamically obtain the length information from the batch of examples. The default behavior of `seqio.Evaluator` is to use the maximum length of a task, but this can be overridden.

Since the length information is provided dynamically, we don't set `max_decode_len` in gin. Instead, we pass to `beam_search` an `inputs` array whose length is the dynamically determined maximum length. If `max_decode_len` is explicitly specified via gin, it will override the implicitly determined length information, unless it is passed by `predict_batch_with_aux`.
### Temperature sampling
Temperature sampling can be used for multiple decoding strategies. The following lines configure temperature sampling as the `decode_fn`.

```gin
models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 1
models.EncoderDecoderModel.decode_fn = @decoding.temperature_sample
decoding.temperature_sample:
  temperature = 0.5
  topk = 20
```
A similar specification can be used for other model types by replacing `models.EncoderDecoderModel` with the relevant model class, e.g., `models.PrefixLanguageModel`.
The sampling behavior is controlled by the arguments passed to `temperature_sample`. We provide details for a few of them below.
#### `temperature`
A probabilistic model outputs a probability distribution over a pre-defined alphabet. For example, a language model outputs logits, which are unnormalized probability values for each item in the vocabulary; we use a language model as a running example. A sampling process involves sampling from the predicted distribution one item at a time, conditioned on the previously generated items, until a given number of items has been generated or a sentinel token representing the end of the sequence is generated.
Temperature modifies the unnormalized probability distribution at each step. For each item $i$ in the vocabulary, the probability predicted by the model is given by

$$p_i = \frac{\exp(x_i / T)}{\sum_j \exp(x_j / T)},$$

where $T$ is the temperature and $x_i$ is the logit value corresponding to item $i$. As $T \to 0$, the distribution puts all probability mass on the item with the highest probability; in other words, the sampling process becomes a greedy search. In the other extreme, as $T \to \infty$, the predicted distribution becomes uniform.
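As an illustration (not the T5X implementation), a single temperature-adjusted sampling step over a toy logits vector:

```python
import jax
import jax.numpy as jnp

def sample_with_temperature(rng, logits, temperature):
  """Draws one token id from temperature-adjusted logits."""
  # Dividing logits by T before the softmax implements p_i above;
  # jax.random.categorical applies the softmax implicitly.
  return jax.random.categorical(rng, logits / temperature)

logits = jnp.array([2.0, 1.0, 0.1])  # toy unnormalized scores
rng = jax.random.PRNGKey(0)
token = sample_with_temperature(rng, logits, temperature=0.5)
# As T -> 0 this approaches jnp.argmax(logits); as T -> inf, a uniform draw.
```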
#### `topk`
By specifying a strictly positive integer value for `topk`, the sampling process at each step is limited to the `k` items with the highest probabilities. `topk` also uses `temperature` to modify the logits corresponding to the top `k` items.
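A minimal sketch of this restriction (illustrative, not the T5X implementation):

```python
import jax
import jax.numpy as jnp

def sample_topk(rng, logits, k, temperature=1.0):
  """Samples only from the k highest-probability items."""
  top_logits, top_ids = jax.lax.top_k(logits, k)  # keep the k best logits
  # Temperature is applied to the surviving logits, matching the text above.
  choice = jax.random.categorical(rng, top_logits / temperature)
  return top_ids[choice]
```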
#### `topp`
By specifying a non-zero positive float value for `topp`, the sampling process is limited to a subset of the vocabulary $V^{(p)} \subset V$, defined as the smallest set such that

$$\sum_{i \in V^{(p)}} p_i \ge p,$$

where $p_i$ is the conditional distribution at each time step for item $i$. This is called “nucleus sampling”, which was introduced by Holtzman et al. ICLR 2020.
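A minimal sketch of nucleus filtering (illustrative, not the T5X implementation):

```python
import jax
import jax.numpy as jnp

def sample_topp(rng, logits, p, temperature=1.0):
  """Samples from the smallest set of items whose probability mass >= p."""
  probs = jax.nn.softmax(logits / temperature)
  order = jnp.argsort(-probs)        # vocabulary ids, most probable first
  sorted_probs = probs[order]
  cumulative = jnp.cumsum(sorted_probs)
  # Keep every item whose exclusive cumulative mass is still below p; this
  # includes the item that pushes the total mass past the threshold.
  keep = (cumulative - sorted_probs) < p
  masked = jnp.where(keep, sorted_probs, 0.0)
  choice = jax.random.categorical(rng, jnp.log(masked))  # log(0) -> -inf
  return order[choice]
```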
IMPORTANT: Only one of `topk` or `topp` can be used.
## Option 2: subclassing a model class
If `DecodeFnCallable` is not flexible enough for your custom decoding function, you can subclass the model class and override the `predict_batch_with_aux` method. While the model class can be any instance of `BaseTransformerModel`, we recommend subclassing an existing model such as `EncoderDecoderModel` and overriding only the `predict_batch_with_aux` method.
The `predict_batch_with_aux` method also has a required call signature, but it is significantly more flexible. It should return a tuple of the predicted sequence array and auxiliary outputs such as scores.
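A minimal sketch of this pattern is shown below; `my_custom_decode` is a hypothetical helper, and argument names beyond `params` and `batch` may differ between T5X versions:

```python
from typing import Any, Mapping, Optional, Tuple

import jax.numpy as jnp
from t5x import models

class MyModel(models.EncoderDecoderModel):
  """Encoder-decoder model with a custom decoding routine."""

  def predict_batch_with_aux(
      self,
      params: Mapping[str, Any],
      batch: Mapping[str, jnp.ndarray],
      rng: Optional[jnp.ndarray] = None,
      **kwargs: Any,
  ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]:
    # Run the custom decoding procedure, then return the sampled sequences
    # together with auxiliary outputs such as their scores.
    sequences, scores = my_custom_decode(params, batch, rng)  # hypothetical
    return sequences, {"scores": scores}
```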