t5x.models package#

T5X Models.

This module uses layers.py to build a higher-level model structure and defines methods for loss computation as well as train, prediction, and evaluation steps.

class t5x.models.BaseModel(optimizer_def)[source]#

Abstract base class for models.

Wraps a flax module to provide a basic interface for computing loss, evaluation metrics, prediction, and scoring.

Subclasses must implement the abstract methods. Any additional arguments added to these methods must have defaults or be bound at run time to fit the interface expected by the standard training, inference, and evaluation functions.

eval_fn(params, batch)[source]#

Computes loss and metrics during evaluation.

Parameters:
  • params – model parameters.

  • batch – a batch of inputs.

Returns:

loss – the loss computed for the given inputs and parameters.

aux – auxiliary outputs: weight_sum, the sum of the per-token weights applied to the loss, and metrics, a mapping of metrics computed for this batch.

Return type:

loss

abstract get_initial_variables(rng, input_shapes, input_types=None)[source]#

Returns the initial variables of the model.

abstract loss_fn(params, batch, dropout_rng)[source]#

Computes loss and metrics.

Parameters:
  • params – model parameters.

  • batch – a batch of inputs.

  • dropout_rng – rng to use for dropout, or None for deterministic mode.

Returns:

loss – the loss computed for the given inputs and parameters.

aux – auxiliary outputs: weight_sum, the sum of the per-token weights applied to the loss, and metrics, a mapping of metrics computed for this batch.

Return type:

loss

predict_batch(params, batch, rng=None)[source]#

Predicts a batch of outputs from the model.

Parameters:
  • params – model parameters.

  • batch – a batch of inputs.

  • rng – an optional RNG to use during prediction (e.g., for decoding).

Returns:

The model predictions.

abstract predict_batch_with_aux(params, batch, rng=None)[source]#

Predict a batch from the model with auxiliary outputs.

Parameters:
  • params – model parameters.

  • batch – a batch of inputs.

  • rng – an optional RNG key to use during prediction (e.g., for decoding).

Returns:

predictions – the model predictions.

aux – auxiliary data.

Return type:

predictions

abstract score_batch(params, batch, return_intermediates=False)[source]#

Computes scores for the batch.

class t5x.models.BaseTransformerModel(module, input_vocabulary, output_vocabulary, optimizer_def, decode_fn=None, label_smoothing=0.0, z_loss=0.0, loss_normalizing_factor=None)[source]#

Abstract base class for Transformer models.

Subclasses must implement predict_batch_with_aux, score_batch, get_initial_variables from BaseModel as well as _compute_logits.

loss_fn(params, batch, dropout_rng)[source]#

Loss function used for training with a cross-entropy loss.

class t5x.models.DecodeFnCallable(*args, **kwargs)[source]#

Decoding function call signature.

class t5x.models.DecoderOnlyModel(module, vocabulary, optimizer_def, decode_fn=<function temperature_sample>, inputs_bidirectional_attention=False, feature_converter_cls=None, label_smoothing=0.0, z_loss=0.0, loss_normalizing_factor=None)[source]#

Model class for the decoder-only modules.

It accepts inputs made out of only ‘targets’ or both ‘inputs’ and ‘targets’. If both ‘inputs’ and ‘targets’ are present, the loss will be computed only on ‘targets’.

By default the self-attention is fully causal: a given position attends only to itself and the time steps before it. If inputs_bidirectional_attention = True, the attention in the “inputs” region is bidirectional. This architecture was referred to as “Prefix LM” in Raffel et al. 2019 (https://arxiv.org/abs/1910.10683).
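For illustration, a hedged construction sketch for this Prefix LM setup; decoder_module, vocab, and optimizer_def are placeholders assumed to come from the rest of a T5X configuration:

```python
from t5x import models

# Hypothetical sketch: decoder_module, vocab, and optimizer_def are assumed to be
# defined elsewhere in a T5X setup; the point here is inputs_bidirectional_attention.
prefix_lm = models.DecoderOnlyModel(
    module=decoder_module,                # a decoder-only flax module (assumed)
    vocabulary=vocab,                     # a seqio vocabulary (assumed)
    optimizer_def=optimizer_def,          # an optimizer definition (assumed)
    inputs_bidirectional_attention=True,  # bidirectional attention over the "inputs" region
)
```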

FEATURE_CONVERTER_CLS#

alias of DecoderFeatureConverter

get_initial_variables(rng, input_shapes, input_types=None)[source]#

Get the initial variables.

predict_batch_with_aux(params, batch, rng=None, *, return_all_decodes=False, num_decodes=1, decoder_params=None)[source]#

Predict with prefix.

decoder_params can be used to pass dynamic configurations to self.decode_fn. An example usage is to pass a different random seed, i.e., jax.random.PRNGKey(seed) with a different seed value. This can be done by setting decoder_params['decode_rng'] = jax.random.PRNGKey(seed).
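A minimal sketch of that usage, assuming a DecoderOnlyModel instance model with its params and a batch of features already in hand:

```python
import jax

# Dynamic decoding configuration: a per-call decoding RNG.
decoder_params = {'decode_rng': jax.random.PRNGKey(42)}

# `model`, `params`, and `batch` are assumed to exist already.
sampled_sequences, aux = model.predict_batch_with_aux(
    params, batch, decoder_params=decoder_params)
```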

Although this method is short, there are a few subtle points worth making explicit. We use a running example to make these points clear.

Example:

```
inputs = [9, 4, 6, 1]
targets = [3, 9, 1]

seqio.DecoderFeatureConverter will generate this set of features:

   decoder_target_tokens = [9, 4, 6, 1, 3, 9, 1, 0, 0]
    decoder_input_tokens = [0, 9, 4, 6, 1, 3, 9, 1, 0]
decoder_causal_attention = [1, 1, 1, 1, 1, 0, 0, 0, 0]

The output of this function is (a through e are the sampled token ids):

       sampled_sequences = [9, 4, 6, 1, a, b, c, d, e]
```

Given this set of features, we make a few important observations.

  1. When a decoder-only model is used for supervised learning with “inputs” and “targets”, one way to handle this is to concatenate the “inputs” and “targets”. For training, we use teacher forcing for the entire concatenated sequence. For inference, on the other hand, we don’t have the targets. This requires that we use teacher forcing on the “inputs” portion while using the generated token as the input token for the next decoding step. For evaluation, we do have “targets”, but we only want to use them for computing metrics, i.e., by comparing to the sequence generated by the model.

    This function is currently used for evaluation mode, but by ignoring “targets”, it can be extended to inference mode.

  2. In evaluation mode, the targets portion is zeroed out and then filled with the sampled token ids. The inputs portion is kept intact.

  3. Note that decoder_causal_attention has an additional 1 after the final “inputs” token. This is because the position where the last “inputs” token (in this case 1) is input and the output is the first “target” token (in this case 3) can be included in the non-causal attention region.

    This results in an alignment between decoder_input_tokens and decoder_causal_attention because the former is shifted to the right by one position. So we use decoder_causal_attention as a binary mask to zero out the target tokens in decoder_input_tokens, as sketched below.
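A minimal sketch of that masking step using the running example above (the values are the hypothetical features, not library code):

```python
import jax.numpy as jnp

# Features from the running example above.
decoder_input_tokens = jnp.array([0, 9, 4, 6, 1, 3, 9, 1, 0])
decoder_causal_attention = jnp.array([1, 1, 1, 1, 1, 0, 0, 0, 0])

# Zero out the target tokens so only the "inputs" prompt remains;
# decoding then fills in the zeroed positions.
prompt_inputs = decoder_input_tokens * decoder_causal_attention
# prompt_inputs == [0, 9, 4, 6, 1, 0, 0, 0, 0]
```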

Note

In order to use a custom self._decode_fn with this model it must support:

  1. Decoding from a partially decoded state by accepting a vector of initial_indices that specify where in the input to start decoding from.

  2. Using a vector as the loop counter to support different examples being a different number of steps into their decoding loop.

  3. Handling one batch element reaching max_decode_length before the others without causing the model to prematurely stop decoding.

Parameters:
  • params – model parameters.

  • batch – batch element with the model features specified in seqio.DecoderFeatureConverter.

  • rng – an optional RNG key to use during prediction, which is passed as ‘decode_rng’ to the decoding function.

  • return_all_decodes – if True, will return all batch_size * num_decodes samples from the model as an array of shape [batch_size, num_decodes, sequence_length]. In this case the num_decodes dimension is sorted in increasing order of log-probability. Otherwise returns only the most likely samples as an array of shape [batch_size, sequence_length].

  • num_decodes – number of decoded sequences to be returned.

  • decoder_params – additional (model-independent) parameters for the decoder.

Returns:

an array of shape [batch, max_decode_length].

Return type:

sampled_sequences

score_batch(params, batch, return_intermediates=False)[source]#

Compute log likelihood score on a batch.

class t5x.models.DecoderParams(return_all_decodes: bool = False, num_decodes: int = 1)[source]#
class t5x.models.EncoderDecoderModel(module, input_vocabulary, output_vocabulary, optimizer_def, decode_fn=<function beam_search>, feature_converter_cls=None, label_smoothing=0.0, z_loss=0.0, loss_normalizing_factor=None, default_decoder_params=None)[source]#

Wrapper class for the models.Transformer nn.module.

FEATURE_CONVERTER_CLS#

alias of EncDecFeatureConverter

get_initial_variables(rng, input_shapes, input_types=None)[source]#

Get the initial variables for an encoder-decoder model.

predict_batch_with_aux(params, batch, rng=None, decoder_params=None, return_all_decodes=None, num_decodes=None, prompt_with_targets=False)[source]#

Predict with fast decoding beam search on a batch.

Here we refer to “parameters” for values that can be compiled into the model dynamically, as opposed to static configuration settings that require a recompile. For example, the model weights and the decoder brevity-penalty are parameters and can be modified without requiring a recompile. The number of layers, the batch size and the decoder beam size are configuration options that require recompilation if changed.

This method can be used with a customizable decoding function as long as it follows the signature of DecodeFnCallable. In order to provide a unified interface for the decoding functions, we use generic names. For example, beam size is a concept unique to beam search. Conceptually, it corresponds to the number of sequences returned by the beam search. Therefore, the generic argument num_decodes corresponds to the beam size if self._decode_fn is a beam search. For temperature sampling, num_decodes corresponds to the number of independent sequences to be sampled. Typically num_decodes = 1 is used for temperature sampling.

If return_all_decodes = True, the return tuple contains the predictions with a shape [batch, num_decodes, max_decode_len] and the scores (i.e., log probability of the generated sequence) with a shape [batch, num_decodes]. The beam dimension is sorted in increasing order of log-probability.

If return_all_decodes = False, the return tuple contains the predictions with a shape [batch, max_decode_len] and the scores with a shape [batch].

decoder_params can be used to pass dynamic configurations to self.decode_fn. An example usage is to pass a different random seed, i.e., jax.random.PRNGKey(seed) with a different seed value. This can be done by setting decoder_params['decode_rng'] = jax.random.PRNGKey(seed).

If prompt_with_targets = True, then decoder_prompt_inputs is initialized from the batch’s decoder_input_tokens. The EOS is stripped (by matching against output_vocabulary.eos_id) so that decoding does not stop immediately after the prompt.
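A hedged usage sketch, assuming an EncoderDecoderModel instance model with a beam-search decode_fn and existing params and batch; the 'scores' key name in aux is an assumption for illustration:

```python
import jax

# `model`, `params`, and `batch` are assumed to exist already.
decodes, aux = model.predict_batch_with_aux(
    params,
    batch,
    decoder_params={'decode_rng': jax.random.PRNGKey(0)},  # dynamic decoder config
    return_all_decodes=True,  # keep the entire beam
    num_decodes=4,            # beam size when decode_fn is beam search
)

# decodes: [batch, num_decodes, max_decode_len], sorted by increasing log-probability.
# aux['scores']: [batch, num_decodes] log-probabilities (key name assumed).
```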

Parameters:
  • params – model parameters.

  • batch – a batch of inputs.

  • rng – an optional RNG key to use during prediction, which is passed as ‘decode_rng’ to the decoding function.

  • decoder_params – additional (model-independent) parameters for the decoder.

  • return_all_decodes – whether to return the entire beam or just the top-1.

  • num_decodes – the number of beams to use in beam search.

  • prompt_with_targets – whether to force decode decoder_inputs.

Returns:

the batch of predictions, with the entire beam if requested, and an auxiliary dictionary of decoder scores.

Return type:

A tuple containing

score_batch(params, batch, return_intermediates=False)[source]#

Compute log likelihood score on a batch.

class t5x.models.TokensIdsToLogitsCallable(*args, **kwargs)[source]#

Token ids to logits mapping call signature.

t5x.models.compute_base_metrics(logits, targets, mask, loss, z_loss=None, segment_ids=None)[source]#

Compute summary metrics.

Parameters:
  • logits – [batch, length, num_classes] float array.

  • targets – categorical targets [batch, length] int array of categories.

  • mask – None or array of shape [batch, length]. Note: must consist of boolean values (float-valued weights not supported).

  • loss – loss (float)

  • z_loss – z_loss (float)

  • segment_ids – optional dictionary whose keys are feature names and whose values are the segment ids used for packing, i.e. [batch, length] arrays.

Returns:

Dict of metrics.

t5x.models.compute_metrics(logits, targets, weights, loss, weight_sum, additional_metrics)[source]#

Compute summary metrics.

t5x.models.compute_weighted_accuracy(logits, targets, weights=None)[source]#

Compute weighted accuracy for log probs and targets.

Parameters:
  • logits – [batch, length, num_classes] float array.

  • targets – categorical targets [batch, length] int array of categories.

  • weights – None or array of shape [batch, length]

Returns:

Scalar accuracy.
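A minimal sketch of one way such a weighted accuracy could be computed (illustrative only; whether the library normalizes by the weight sum inside this function is not specified here):

```python
import jax.numpy as jnp

def weighted_accuracy_sketch(logits, targets, weights=None):
  """Token accuracy: argmax(logits) vs. targets, optionally weighted."""
  correct = (jnp.argmax(logits, axis=-1) == targets).astype(jnp.float32)
  if weights is None:
    return correct.mean()
  # Weighted fraction of correct tokens (this normalization choice is an assumption).
  return jnp.sum(correct * weights) / jnp.sum(weights)
```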

t5x.models.count_packed_examples(segment_ids)[source]#

Return the number of packed examples.

After packing, each row of segment_ids contains the ids of packed examples. For some model inputs, a feature may be present for some examples but not for others. For example, two tasks in a multimodal setup could be (1) text -> text and (2) image -> text. Examples from (1) will be missing the image input feature, and examples from (2) will be missing the text input feature.

To count the packed examples, we count the unique ids in segment_ids, excluding 0s (which denote padding). This can be implemented by counting the number of non-zero values in the first discrete difference along axis=1, adding the number of rows in segment_ids, and subtracting the number of padded examples (i.e., rows that contain padding).

Example

[[1, 1, 3, 3, 0, 0],
 [2, 2, 2, 2, 2, 2],
 [2, 7, 7, 7, 7, 0]] has 5 packed examples.
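A minimal sketch of this counting scheme as described above (a plausible reading, not necessarily the library’s exact implementation):

```python
import jax.numpy as jnp

def count_packed_examples_sketch(segment_ids):
  """Counts packed examples in a [B, L] array of segment ids (0 = padding)."""
  # Transitions between adjacent segment ids within each row.
  boundaries = jnp.sum(jnp.diff(segment_ids, axis=1) != 0)
  num_rows = segment_ids.shape[0]
  # Each row containing padding contributes one spurious "segment" of zeros.
  padded_rows = jnp.sum(jnp.any(segment_ids == 0, axis=1))
  return boundaries + num_rows - padded_rows

segment_ids = jnp.array([[1, 1, 3, 3, 0, 0],
                         [2, 2, 2, 2, 2, 2],
                         [2, 7, 7, 7, 7, 0]])
print(count_packed_examples_sketch(segment_ids))  # 5
```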

Parameters:

segment_ids – [B, L] array.

Returns:

Scalar count.

t5x.models.remove_prefix(sequence, prefix_length)[source]#

Vectorized version of remove_prefix. Takes similar arguments as remove_prefix but with additional array axes over which remove_prefix is mapped.

Original documentation:

Remove the prefix portion and shift to the left by the prefix length.

The example below uses the non-decorated function definition, i.e., the arrays do not have a batch dimension. jax.vmap internally inserts the batch dimension at axis=0. The shape annotations do not include the batch dimension either.

Example:

```python
sequence = [1, 2, 3, 4, 5, 6, 7, 0]
prefix_length = 2
remove_prefix(sequence, prefix_length) = [3, 4, 5, 6, 7, 0, 0, 0]
```

Note that this function assumes that the padding token has an id of 0.

Args:

sequence: [length] array.

prefix_length: scalar, i.e., rank 0 array.

Returns:

[length] array with the prefix removed and the suffix shifted.
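A minimal sketch of a per-example remove_prefix under these assumptions (padding id 0), wrapped with jax.vmap as the documentation describes; the library’s actual implementation may differ:

```python
import jax
import jax.numpy as jnp

def _remove_prefix(sequence, prefix_length):
  """Shifts `sequence` left by `prefix_length`; vacated tail positions become 0."""
  positions = jnp.arange(sequence.shape[-1])
  # Indices past the end of the sequence are filled with the padding id (0).
  return jnp.take(sequence, positions + prefix_length, mode='fill', fill_value=0)

# Vectorized over the batch dimension, mirroring the documented behavior.
remove_prefix_sketch = jax.vmap(_remove_prefix)

sequences = jnp.array([[1, 2, 3, 4, 5, 6, 7, 0]])
prefix_lengths = jnp.array([2])
print(remove_prefix_sketch(sequences, prefix_lengths))  # [[3 4 5 6 7 0 0 0]]
```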