t5x.losses package#
Loss functions.
- class t5x.losses.SpecialLossNormalizingFactor(value)[source]#
Specially calculated loss_normalizing_factors that are not constant.
- NUM_REAL_TARGET_TOKENS#
Divide the loss by the number of real (non-padding) tokens in the current target batch. If 'decoder_loss_weights' are specified, this is the sum of the weights; otherwise it is the number of non-zero 'decoder_target_tokens'.
- NUM_TOTAL_TARGET_TOKENS#
Divide the loss by the total number of target tokens, i.e., batch_size * target_seq_length (including padding).
- AVERAGE_PER_SEQUENCE#
This first computes the per-sequence loss (averaged over the number of real target tokens in the sequence), then averages over the sequences. This can be preferable to NUM_REAL_TARGET_TOKENS for finetuning, because it weights all examples equally regardless of sequence length (which can be especially important for multi-task finetuning).
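The three strategies can give quite different results on the same batch. The following sketch (plain numpy, with a made-up two-sequence batch and made-up per-token losses) illustrates how each factor normalizes the same total loss; it is not t5x code.

```python
import numpy as np

# Hypothetical batch: 2 sequences of length 4; token id 0 marks padding.
decoder_target_tokens = np.array([[11, 12, 13, 0],
                                  [21, 0, 0, 0]])
# Made-up per-token losses (already zeroed at padding positions).
per_token_loss = np.array([[0.5, 1.0, 1.5, 0.0],
                           [2.0, 0.0, 0.0, 0.0]])
weights = (decoder_target_tokens != 0).astype(np.float32)

# NUM_REAL_TARGET_TOKENS: divide the total loss by the number of
# non-padding tokens (the sum of the weights).
num_real = weights.sum()                        # 4.0
loss_real = per_token_loss.sum() / num_real     # 5.0 / 4 = 1.25

# NUM_TOTAL_TARGET_TOKENS: divide by batch_size * target_seq_length.
num_total = decoder_target_tokens.size          # 8
loss_total = per_token_loss.sum() / num_total   # 5.0 / 8 = 0.625

# AVERAGE_PER_SEQUENCE: mean over real tokens within each sequence,
# then mean over sequences -- short sequences count as much as long ones.
per_seq = per_token_loss.sum(-1) / weights.sum(-1)  # [1.0, 2.0]
loss_avg_seq = per_seq.mean()                       # 1.5
```

Note how the short second sequence dominates under AVERAGE_PER_SEQUENCE relative to NUM_REAL_TARGET_TOKENS, since each sequence contributes equally rather than each token.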
- t5x.losses.compute_weighted_cross_entropy(logits, targets, weights=None, label_smoothing=0.0, z_loss=0.0, loss_normalizing_factor=None)[source]#
Compute weighted cross entropy and entropy for log probs and targets.
- Parameters:
logits – [batch, length, num_classes] float array.
targets – categorical targets [batch, length] int array.
weights – None or array of shape [batch, length].
label_smoothing – label smoothing constant, used to determine the on and off values.
z_loss – coefficient for the auxiliary z-loss term.
loss_normalizing_factor – Constant to divide loss by. If not specified, loss will not be normalized. Intended for backward compatibility with T5-MTF training. Should not normally be used.
- Returns:
Tuple of scalar loss, z_loss, and weight sum.
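To make the interplay of label smoothing, z-loss, weights, and the normalizing factor concrete, here is an illustrative numpy re-implementation of the same computation. It is a sketch of the standard formulation (smoothed soft targets, cross entropy against log-softmax, z-loss as the squared log-partition), not the t5x source.

```python
import numpy as np

def weighted_xent_sketch(logits, targets, weights=None,
                         label_smoothing=0.0, z_loss=0.0,
                         loss_normalizing_factor=None):
    """Illustrative weighted cross entropy with z-loss; not t5x source."""
    vocab = logits.shape[-1]
    # Label smoothing: `confidence` on the target class, the remainder
    # spread uniformly over the other classes.
    confidence = 1.0 - label_smoothing
    low = label_smoothing / (vocab - 1)
    soft = np.full(logits.shape, low)
    np.put_along_axis(soft, targets[..., None], confidence, axis=-1)

    # Numerically stable log-softmax via the log-partition log_z.
    m = logits.max(-1, keepdims=True)
    log_z = np.log(np.exp(logits - m).sum(-1, keepdims=True)) + m
    log_probs = logits - log_z
    xent = -(soft * log_probs).sum(-1)

    # Auxiliary z-loss penalizes large log-partition values.
    z = z_loss * np.squeeze(log_z, -1) ** 2
    total = xent + z
    if weights is not None:
        total = total * weights
        z = z * weights
        weight_sum = float(weights.sum())
    else:
        weight_sum = float(np.prod(targets.shape))
    loss, zl = total.sum(), z.sum()
    if loss_normalizing_factor is not None:
        loss /= loss_normalizing_factor
        zl /= loss_normalizing_factor
    return loss, zl, weight_sum
```

With uniform (all-zero) logits over a 3-token vocabulary and no smoothing, each position contributes log(3) to the loss, as expected for a uniform predictive distribution.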
- t5x.losses.convert_special_loss_normalizing_factor_to_enum(x)[source]#
Converts stringified version of LNF to an enum.
This is useful because gin dynamic registration does not (currently) have support for enums.
- Parameters:
x – stringified version of SpecialLossNormalizingFactor enum.
- Returns:
SpecialLossNormalizingFactor enum instance.
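A conversion like this can be sketched in a few lines. The enum mirror below matches the member names documented above; the handling of a dotted "SpecialLossNormalizingFactor.X" prefix is an assumption about what "stringified" covers, not confirmed t5x behavior.

```python
import enum

class SpecialLossNormalizingFactor(enum.Enum):
    # Mirror of the documented enum members (values are arbitrary here).
    NUM_REAL_TARGET_TOKENS = 1
    NUM_TOTAL_TARGET_TOKENS = 2
    AVERAGE_PER_SEQUENCE = 3

def convert_sketch(x):
    """Illustrative string-to-enum conversion; not the t5x source."""
    if isinstance(x, SpecialLossNormalizingFactor):
        return x
    # Assumed: accept both 'NUM_REAL_TARGET_TOKENS' and
    # 'SpecialLossNormalizingFactor.NUM_REAL_TARGET_TOKENS'.
    return SpecialLossNormalizingFactor[x.split('.')[-1]]
```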
- t5x.losses.get_loss_normalizing_factor_and_weights(loss_normalizing_factor, batch)[source]#
Get the float loss_normalizing_factor and loss weights.
If loss_normalizing_factor is float or None, this will simply return the input loss_normalizing_factor and batch.
If loss_normalizing_factor is a SpecialLossNormalizingFactor, it will return a float loss_normalizing_factor and loss weights corresponding to the special LNF. See SpecialLossNormalizingFactor for more details.
- Parameters:
loss_normalizing_factor – The input LNF, which may be a float, None, or SpecialLossNormalizingFactor (or a stringified SLNF).
batch – Input data batch.
- Returns:
- Tuple of (output_loss_normalizing_factor, loss_weights).
'output_loss_normalizing_factor' is a scalar float (Python float or jnp float). 'loss_weights' is the per-token loss weight JNP array.
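The following numpy sketch shows the plausible shape of this dispatch. Plain strings stand in for the enum, weights are derived from non-zero 'decoder_target_tokens' when 'decoder_loss_weights' is absent, and the AVERAGE_PER_SEQUENCE branch (rescaling each sequence's weights to sum to one, then dividing by the number of sequences) is one reasonable reading of the description above, not confirmed t5x behavior.

```python
import numpy as np

def get_lnf_and_weights_sketch(lnf, batch):
    """Illustrative version of the LNF/weights dispatch; not t5x source."""
    targets = batch['decoder_target_tokens']
    weights = batch.get('decoder_loss_weights')
    # Plain float or None passes straight through.
    if lnf is None or isinstance(lnf, (int, float)):
        return lnf, weights
    if weights is None:
        weights = (targets != 0).astype(np.float32)
    if lnf == 'NUM_REAL_TARGET_TOKENS':
        return float(weights.sum()), weights
    if lnf == 'NUM_TOTAL_TARGET_TOKENS':
        return float(targets.size), weights
    if lnf == 'AVERAGE_PER_SEQUENCE':
        # Rescale each sequence's weights to sum to 1, then normalize
        # by the number of sequences, so every example counts equally.
        per_seq = weights.sum(axis=-1, keepdims=True)
        return float(targets.shape[0]), weights / np.maximum(per_seq, 1.0)
    raise ValueError(f'Unsupported loss_normalizing_factor: {lnf}')
```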