t5x.losses package#

Loss functions.

class t5x.losses.SpecialLossNormalizingFactor(value)[source]#

Specially calculated loss_normalizing_factors that are not constant.


NUM_REAL_TARGET_TOKENS#

Divide the loss by the number of real (non-padding) tokens in the current target batch. If 'decoder_loss_weights' are specified, this is the sum of the weights; otherwise it is the number of non-zero 'decoder_target_tokens'.


NUM_TOTAL_TARGET_TOKENS#

Divide the loss by the total number of target tokens, i.e., batch_size * target_seq_length (including padding).


AVERAGE_PER_SEQUENCE#

First compute the per-sequence loss (averaged over the number of real target tokens in the sequence), then average that over the sequences. This can be preferable to NUM_REAL_TARGET_TOKENS for finetuning because it weighs all examples equally regardless of sequence length, which can be especially important for multi-task finetuning.
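A minimal NumPy sketch of the difference between the two real-token normalizations (illustrative only, not the t5x implementation; the array names are made up for the example):

```python
import numpy as np

# Toy batch: 2 sequences padded to length 4; weight 0 marks padding.
per_token_loss = np.array([[1.0, 1.0, 1.0, 1.0],
                           [2.0, 2.0, 0.0, 0.0]])
weights = np.array([[1.0, 1.0, 1.0, 1.0],
                    [1.0, 1.0, 0.0, 0.0]])

# NUM_REAL_TARGET_TOKENS: one global average over real tokens,
# so longer sequences contribute more.
num_real = (per_token_loss * weights).sum() / weights.sum()  # 8/6

# AVERAGE_PER_SEQUENCE: per-sequence average first, then average
# over sequences, so every example weighs equally.
per_seq = (per_token_loss * weights).sum(axis=1) / weights.sum(axis=1)
avg_per_seq = per_seq.mean()  # (1.0 + 2.0) / 2 = 1.5
```

Note how the short sequence's higher loss counts more under AVERAGE_PER_SEQUENCE (1.5) than under the global token average (about 1.33).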

t5x.losses.compute_weighted_cross_entropy(logits, targets, weights=None, label_smoothing=0.0, z_loss=0.0, loss_normalizing_factor=None)[source]#

Compute weighted cross entropy and entropy for log probs and targets.

Parameters:

  • logits – [batch, length, num_classes] float array.

  • targets – categorical targets [batch, length] int array.

  • weights – None or array of shape [batch, length].

  • label_smoothing – label smoothing constant, used to determine the on and off values.

  • z_loss – coefficient for auxiliary z-loss loss term.

  • loss_normalizing_factor – Constant to divide loss by. If not specified, loss will not be normalized. Intended for backward compatibility with T5-MTF training. Should not normally be used.


Returns:

Tuple of scalar loss, z_loss, and weight sum.
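The core computation can be sketched in plain NumPy (a simplified stand-in, not the t5x implementation; label smoothing and the loss_normalizing_factor division are omitted, and all names are illustrative):

```python
import numpy as np

def weighted_xent_sketch(logits, targets, weights=None, z_loss=0.0):
    # Numerically stable log of the softmax normalizer, log Z.
    m = logits.max(axis=-1, keepdims=True)
    log_z = (m + np.log(np.exp(logits - m).sum(axis=-1, keepdims=True))).squeeze(-1)
    # Cross entropy of the target class: log Z - logit[target].
    target_logits = np.take_along_axis(
        logits, targets[..., None], axis=-1).squeeze(-1)
    xent = log_z - target_logits
    # Auxiliary z-loss term penalizes large log Z (keeps logits calibrated).
    aux = z_loss * log_z ** 2
    if weights is None:
        weights = np.ones_like(xent)
    total = ((xent + aux) * weights).sum()  # unnormalized sum
    return total, (aux * weights).sum(), weights.sum()

# Confident, correct predictions -> near-zero loss.
logits = np.array([[[10.0, 0.0], [0.0, 10.0]]])
targets = np.array([[0, 1]])
loss, zl, wsum = weighted_xent_sketch(logits, targets)
```

When a loss_normalizing_factor is supplied, the returned sum would additionally be divided by it.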


t5x.losses.convert_special_loss_normalizing_factor_to_enum(x)[source]#

Converts stringified version of LNF to an enum.

This is useful because gin dynamic registration does not (currently) have support for enum.


Parameters:

  • x – stringified version of the SpecialLossNormalizingFactor enum.


Returns:

SpecialLossNormalizingFactor enum instance.
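A minimal sketch of such a string-to-enum conversion, using a stand-in enum (the member set here is an assumption for illustration; only NUM_REAL_TARGET_TOKENS is confirmed by the text above):

```python
import enum

class SpecialLossNormalizingFactor(enum.Enum):  # illustrative stand-in
    NUM_REAL_TARGET_TOKENS = 1
    NUM_TOTAL_TARGET_TOKENS = 2
    AVERAGE_PER_SEQUENCE = 3

def convert_to_enum(x):
    # Accept either "AVERAGE_PER_SEQUENCE" or the dotted
    # "SpecialLossNormalizingFactor.AVERAGE_PER_SEQUENCE" spelling,
    # case-insensitively, and look the name up on the enum.
    name = str(x).split('.')[-1].upper()
    return SpecialLossNormalizingFactor[name]
```

This kind of shim lets gin configs pass plain strings where the enum is expected.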

t5x.losses.get_loss_normalizing_factor_and_weights(loss_normalizing_factor, batch)[source]#

Get the float loss_normalizing_factor and loss weights.

If loss_normalizing_factor is float or None, this will simply return the input loss_normalizing_factor and batch.

If loss_normalizing_factor is a SpecialLossNormalizingFactor, it will return a float loss_normalizing_factor and loss weights corresponding to the special LNF. See SpecialLossNormalizingFactor for more details.

Parameters:

  • loss_normalizing_factor – The input LNF, which may be a float, None, or SpecialLossNormalizingFactor (or a stringified SLNF).

  • batch – Input data batch.


Returns:

Tuple of (output_loss_normalizing_factor, loss_weights).

'output_loss_normalizing_factor' is a scalar float (Python float or jnp float). 'loss_weights' is the per-token loss weight JNP array.
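The dispatch described above might be sketched as follows, using strings in place of the enum and the batch keys 'decoder_target_tokens' and 'decoder_loss_weights' from the class description (a hypothetical sketch, not the t5x implementation):

```python
import numpy as np

def get_lnf_and_weights_sketch(loss_normalizing_factor, batch):
    # Floats and None pass straight through, along with any existing weights.
    if loss_normalizing_factor is None or isinstance(loss_normalizing_factor, float):
        return loss_normalizing_factor, batch.get('decoder_loss_weights')
    # Special LNF: prefer explicit loss weights; otherwise treat non-zero
    # target tokens as real (weight 1) and padding as weight 0.
    weights = batch.get('decoder_loss_weights')
    if weights is None:
        weights = (batch['decoder_target_tokens'] != 0).astype(np.float32)
    if loss_normalizing_factor == 'NUM_REAL_TARGET_TOKENS':
        return float(weights.sum()), weights
    if loss_normalizing_factor == 'NUM_TOTAL_TARGET_TOKENS':
        return float(weights.size), weights
    raise ValueError(f'Unsupported LNF: {loss_normalizing_factor}')
```

The returned pair then feeds directly into a loss such as compute_weighted_cross_entropy.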