t5x.decoding package#

Fast decoding routines for inference from a trained model.

class t5x.decoding.BeamState(cur_index, live_logprobs, finished_scores, live_seqs, finished_seqs, finished_flags, cache, initial_index)[source]#

Holds beam search state data.

replace(**updates)#

Returns a new object replacing the specified fields with new values.

class t5x.decoding.DecodingState(cur_index, sequences, cur_token, cache)[source]#

Holds decoding state data.

Used to communicate the current decoding state to tokens_to_logits methods. Note that we use a different class than SamplingLoopState or BeamState to decouple the concerns of what data is useful for the loop vs. what the sampling method needs. Decodes for a given batch entry are flattened in a column-major way so that decodes from the same batch entry are grouped together.

cur_index#

[batch_size * num_decodes] array position of the sampling loop in the length dimension.

Type:

jax.Array

sequences#

[batch_size * num_decodes, max_decode_len] array of current sampled sequence prefixes.

Type:

jax.Array

cur_token#

[batch_size * num_decodes] single timestep slice containing current tokens.

Type:

jax.Array

cache#

any mapping of arrays, e.g. flax attention cache.

Type:

Mapping[str, jax.Array]

replace(**updates)#

Returns a new object replacing the specified fields with new values.
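For illustration, here is a minimal sketch of a tokens_to_logits callable that consumes a DecodingState. The vocabulary size and the trivial behavior (uniform logits, cache returned unchanged) are assumptions for this sketch; a real implementation would run one decoder step against the attention cache.

```
import jax.numpy as jnp

VOCAB_SIZE = 32  # hypothetical vocabulary size for this sketch


def toy_tokens_to_logits(state):
  """Toy tokens_to_logits consuming a t5x.decoding.DecodingState.

  A real implementation would run one decoder step using state.cur_token,
  state.cur_index, and the flax attention cache in state.cache, and would
  return the updated cache.
  """
  batch_times_decodes = state.cur_token.shape[0]
  logits = jnp.zeros((batch_times_decodes, VOCAB_SIZE))  # [B * num_decodes, vocab]
  return logits, state.cache  # cache returned unchanged in this sketch
```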

class t5x.decoding.SamplingLoopState(step, cur_index, sequences, cache, cur_token, ended, rng, log_prob)[source]#

Holds sampling state data.

step#

Scalar decoding step count. Starts from zero.

Type:

jax.Array

cur_index#

[batch_size * num_decodes] array position of the sampling loop in the length dimension.

Type:

jax.Array

sequences#

[batch_size * num_decodes, max_decode_len] array of current sampled sequence prefixes.

Type:

jax.Array

cache#

any mapping of arrays, e.g. flax attention cache.

Type:

Mapping[str, jax.Array]

cur_token#

[batch_size * num_decodes] single timestep slice containing current tokens.

Type:

jax.Array

ended#

[batch_size * num_decodes] binary array marking completed sequences.

Type:

jax.Array

rng#

JAX PRNGKey.

Type:

jax.Array

log_prob#

[batch_size * num_decodes] array of log probs for each sequence.

Type:

jax.Array

replace(**updates)#

Returns a new object replacing the specified fields with new values.
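As a hedged illustration of how this state is consumed, the sketch below shows a state_callback_fn (see temperature_sample below) that reads SamplingLoopState fields and returns an updated copy via replace(); the particular manipulation is hypothetical.

```
import jax.numpy as jnp


def state_callback_fn(state):
  """Hypothetical callback: from step 5 onward, force token id 7."""
  forced = jnp.where(state.step >= 5,
                     jnp.full_like(state.cur_token, 7),
                     state.cur_token)
  return state.replace(cur_token=forced)
```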

t5x.decoding.add_beam_dim(x, beam_size, offset=0)[source]#

Creates a new beam dimension in a non-scalar array and tiles into it.

t5x.decoding.beam_init(batch_size, beam_size, max_decode_len, cache, offset=0, live_seqs=None, initial_index=None)[source]#

Initializes the beam search state data structure.

t5x.decoding.beam_search(inputs, cache, tokens_to_logits, eos_id, num_decodes, alpha, max_decode_len, min_log_prob, decode_rng, cache_offset, initial_index)[source]#

Beam search for transformer machine translation.

If inputs has non-zero entries, those values are not modified, i.e., the sampled values for those positions are discarded. This simulates teacher forcing on the prefix positions.

NOTE: When using initial_index with prompts of variable lengths, to comply with the max_decode_len length requirement we might now return sequences that were still live (i.e., EOS not yet decoded) when they exceeded their length allowance, along with sequences that finished (i.e., EOS was decoded). Furthermore, there might be sequences that finished decoding after their max_decode_len was reached but appear truncated in the output at max_decode_len.

TODO(afrozm): Solve this, if needed, by having a third class of sequences apart from live and finished called “truncated”, then after beam search completes, we will order them as finished > truncated > live, rather than finished > live that happens right now.

Parameters:
  • inputs – array: [batch_size, length] int32 sequence of tokens.

  • cache – flax attention cache.

  • tokens_to_logits – fast autoregressive decoder function taking single token slices and cache and returning next-token logits and updated cache.

  • eos_id – int: id of end-of-sentence token for target vocabulary.

  • num_decodes – number of decoded sequences to be returned. This is equivalent to the number of beams used in the beam search.

  • alpha – float: scaling factor for brevity penalty.

  • max_decode_len – int: an optional maximum length of decoded sequence. If None, it uses inputs.shape[1] as max_decode_len.

  • min_log_prob – the beam search will stop if there is no live beam entry with higher raw score (ignoring brevity penalty) than this.

  • decode_rng – Unused decoder RNG seed.

  • cache_offset – axis offset for cache, arising from scanned layers.

  • initial_index – Optional[jnp.ndarray], the index from which to start decoding autoregressively if set. If unset, then we teacher-force the prefix, but autoregressively (so it will be slow). When set, this also assumes that the cache is appropriately populated. Since inputs are padded on the left with BOS = 0, these are also the lengths of the prompts.

Returns:

A tuple of [batch_size, beam_size, max_decode_len] top-scoring sequences and [batch_size, beam_size] beam-search scores.
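A minimal usage sketch under the parameter list above; inputs, cache, and tokens_to_logits are placeholders supplied by the caller, and the keyword values are illustrative.

```
from t5x import decoding

# inputs: [batch_size, length] int32; cache/tokens_to_logits from the model.
beam_seqs, beam_scores = decoding.beam_search(
    inputs, cache, tokens_to_logits, eos_id=1, num_decodes=4, alpha=0.6)
# beam_seqs:   [batch_size, 4, max_decode_len]
# beam_scores: [batch_size, 4]
```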

t5x.decoding.brevity_penalty(alpha, length)[source]#

Brevity penalty function for beam search penalizing short sequences.

Parameters:
  • alpha – float: brevity-penalty scaling parameter.

  • length – int: length of considered sequence.

Returns:

Brevity penalty score as jax scalar.
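For reference, the GNMT-style length-normalization term commonly used for this penalty can be written as below; the exact constants are an assumption about this implementation, not a guarantee.

```
import jax.numpy as jnp


def gnmt_brevity_penalty(alpha, length):
  # GNMT-style length normalization: ((5 + length) / 6) ** alpha.
  return jnp.power((5.0 + length) / 6.0, alpha)

# e.g. gnmt_brevity_penalty(0.6, 10) ~= 1.73
```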

t5x.decoding.cache_gather_beams(nested, beam_indices, batch_size, old_beam_size, new_beam_size, one_hot=True, offset=0)[source]#

Gathers the cache beam slices indexed by beam_indices into new beam array.

Parameters:
  • nested – cache pytree.

  • beam_indices – array of beam_indices

  • batch_size – size of batch.

  • old_beam_size – size of _old_ beam dimension.

  • new_beam_size – size of _new_ beam dimension.

  • one_hot – whether to perform gathers by one-hot contraction or directly.

  • offset – cache axis offset from scanned layers.

Returns:

New pytree with new beam arrays. [batch_size, old_beam_size, …] –> [batch_size, new_beam_size, …]

t5x.decoding.cache_map(fn, cache, apply_to_index=False)[source]#

Maps a function over the cache arrays, handling multiple caches in various layers.

Parameters:
  • fn – The function to apply.

  • cache – The cache to apply it to.

  • apply_to_index – Whether to apply the function to the cache index.

Returns:

The result of applying fn to the cache.
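A hedged usage sketch: the cache structure and key names below are illustrative only. cache_map applies the function to every cached array, and with the default apply_to_index=False the cache index is left untouched.

```
import jax.numpy as jnp
from t5x import decoding

# Toy flax-style attention cache with a single layer (illustrative shapes).
cache = {'decoder': {'layers_0': {
    'cached_key': jnp.zeros((2, 8, 4, 16)),
    'cached_value': jnp.zeros((2, 8, 4, 16))}}}

# Tile every cached array along a new beam axis of size 4.
expanded = decoding.cache_map(lambda x: decoding.add_beam_dim(x, 4), cache)
```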

t5x.decoding.flat_batch_beam_expand(x, beam_size, offset=0)[source]#

Expands each batch item by beam_size in the batch dimension.

t5x.decoding.flatten_beam_dim(x, offset=0)[source]#

Flattens the first two dimensions of a non-scalar array.
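To make the shape conventions concrete, here is a small sketch of how these helpers reshape arrays, with the expected shapes (per the descriptions in this section) noted in comments.

```
import jax.numpy as jnp
from t5x import decoding

x = jnp.zeros((2, 5))                               # [batch, length]
tiled = decoding.add_beam_dim(x, 3)                 # [2, 3, 5]
flat = decoding.flatten_beam_dim(tiled)             # [6, 5]
expanded = decoding.flat_batch_beam_expand(x, 3)    # [6, 5], each batch item repeated
restored = decoding.unflatten_beam_dim(flat, 2, 3)  # [2, 3, 5] (see unflatten_beam_dim below)
```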

t5x.decoding.gather_beams(nested, beam_indices, batch_size, old_beam_size, new_beam_size, one_hot=True)[source]#

Gathers the beam slices indexed by beam_indices into new beam array.

Parameters:
  • nested – pytree of arrays or scalars (the latter ignored).

  • beam_indices – array of beam_indices

  • batch_size – size of batch.

  • old_beam_size – size of _old_ beam dimension.

  • new_beam_size – size of _new_ beam dimension.

  • one_hot – whether to perform gathers by one-hot contraction or directly.

Returns:

New pytree with new beam arrays. [batch_size, old_beam_size, …] –> [batch_size, new_beam_size, …]
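A hedged shape sketch, assuming beam_indices has shape [batch_size, new_beam_size] and selects which old beams to keep for each batch element.

```
import jax.numpy as jnp
from t5x import decoding

nested = {'seqs': jnp.zeros((2, 4, 10))}   # [batch, old_beam, length]
beam_indices = jnp.array([[0, 2],          # per batch item, which of the
                          [1, 3]])         # 4 old beams to keep
gathered = decoding.gather_beams(
    nested, beam_indices, batch_size=2, old_beam_size=4, new_beam_size=2)
# gathered['seqs'].shape == (2, 2, 10)
```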

t5x.decoding.gather_topk_beams(nested, score_or_log_prob, batch_size, new_beam_size)[source]#

Gathers the top-k beam slices given by score_or_log_prob array.

Parameters:
  • nested – pytree of arrays or scalars (the latter ignored).

  • score_or_log_prob – [batch_size, old_beam_size] array of values to sort by for top-k selection of beam slices.

  • batch_size – int: size of batch.

  • new_beam_size – int: size of _new_ top-k selected beam dimension

Returns:

New pytree with new beam arrays containing top k new_beam_size slices. [batch_size, old_beam_size, …] –> [batch_size, new_beam_size, …]

t5x.decoding.temperature_sample(inputs, cache, tokens_to_logits, eos_id, decode_rng=None, num_decodes=1, temperature=1.0, topk=1, topp=0.0, cache_offset=0, initial_index=None, max_decode_steps=None, max_decode_steps_hard_limit=None, rescale_log_probs=True, state_callback_fn=None, logit_callback_fn=None)[source]#

Temperature sampling for language model generation.

The temperature sampling is performed num_decodes times in a vectorized manner by expanding the batch dimension. This is similar to how beam search expands the batch dimension to process each batch element with multiple beams.

This function dynamically updates the inputs array by sampling from the model logits, which are provided by the tokens_to_logits callable. The input sequences are expanded at the end, populated, and sliced by dropping the first position.

If inputs has non-zero entries, those values are not modified, i.e., the sampled values for those positions are discarded. This simulates teacher forcing on the prefix positions.

There are a few important observations related to this function.

  1. The inputs is assumed to be a non-packed sequence.

  2. If initial_index=None, then inputs[:, 0] is ignored. We will use 0 as a BOS token to start the generation. This inherently assumes that inputs is already shifted to the right by one position. If initial_index is an array, the token values at inputs[:, initial_index] are used as the tokens to start the generation.

  3. The loop index, i, is a vector of shape [batch_size]. When beginning generation from scratch, every entry has the same value. When beginning with a partially filled cache, the loop indices of different elements can differ, via providing a value for initial_index.

  4. Unless all batch elements generate the eos_id before reaching the end, we always make max_decode_len = inputs.shape[1] calls to tokens_to_logits when decoding from scratch and max_decode_len - jnp.minimum(initial_index) calls when starting from a partially filled cache.

  5. Let output be the output sequences, i.e., sequences[:, 1:]. Then output[:, j] are the tokens generated when the while-loop counter i = j. Therefore, we generate the last token when i = max_decode_len - 1 and exit the while loop once all i's are incremented to max_decode_len.

  6. Once eos_id = 1 is generated, the subsequent predictions are all replaced by the padding token 0.

  7. When using a partially filled cache, different batch elements can have different prompt lengths. This means a batch element with a longer prompt will have fewer steps until its i value reaches max_decode_len than one with a shorter prompt. We keep these longer examples alive, doing busy work that continually overwrites a garbage token at the end of the sequence until the shorter examples finish.

  8. When using a partially filled cache, i.e., providing a value for initial_index, the attention cache index should be a vector of shape [batch_size].

We show four examples to illustrate how this function works. In addition to the input and output of the function, we also show two intermediate values: expanded_prompt_inputs and final_sequences. For simplicity, the examples are limited to num_decodes = 1 usage and the num_decodes dimension is omitted.

```
Example 1:
                  inputs = [0, 5, 6, 1, 0]
  expanded_prompt_inputs = [0, 5, 6, 1, 0, 0]
         final_sequences = [0, 5, 6, 1, a, b]  # before slicing.
                  output = [5, 6, 1, a, b]

  where a is the prediction while taking 1 as input and b is the prediction
  while taking a as input.

Example 2 (early stopping):
                  inputs = [[0, 5, 1, 0, 0, 0, 0],
                            [0, 8, 0, 0, 0, 0, 0]]
  expanded_prompt_inputs = [[0, 5, 1, 0, 0, 0, 0, 0],
                            [0, 8, 0, 0, 0, 0, 0, 0]]
         final_sequences = [[0, 5, 1, a, b, c=1, 0, 0],
                            [0, 8, d, e, f=1, g=0, 0, 0]]
                  output = [[5, 1, a, b, c=1, 0, 0],
                            [8, d, e, f=1, g=0, 0, 0]]

  In this example, there are two sequences. Let's look at sequence 0. The
  first generated token is a, which is in turn used to generate b. Finally,
  c = 1 is generated with the input b. Then the loop terminates early because
  1 is the eos_id.

  Now consider sequence 1. When f = 1 was generated, that sequence is
  considered done. Since sequence 0 is not done at this point, the next
  prediction, i.e., g, is zeroed out. This continues until the end.

Example 3 (prefilled cache):
                  inputs = [[0, 5, 2, 6, 1, 0],
                            [0, 8, 1, 0, 0, 0]]
  expanded_prompt_inputs = [[0, 5, 2, 6, 1, 0, 0, 0],
                            [0, 8, 1, 0, 0, 0, 0, 0]]
  max_decode_length = 6

  i = [4, 2]
            input_tokens = [[1],
                            [1]]
           output_tokens = [[a],
                            [b]]
  expanded_prompt_inputs = [[0, 5, 2, 6, 1, a, 0, 0],
                            [0, 8, 1, b, 0, 0, 0, 0]]

  i = [5, 3]
            input_tokens = [[a],
                            [b]]
           output_tokens = [[c],
                            [d]]
  expanded_prompt_inputs = [[0, 5, 2, 6, 1, a, c, 0],
                            [0, 8, 1, b, d, 0, 0, 0]]

  i = [6, 4]
            input_tokens = [[c],
                            [d]]
           output_tokens = [[y],
                            [e]]
  expanded_prompt_inputs = [[0, 5, 2, 6, 1, a, c, y],
                            [0, 8, 1, b, d, e, 0, 0]]

  i = [6, 5]
            input_tokens = [[z],
                            [e]]
           output_tokens = [[z],
                            [f]]
  expanded_prompt_inputs = [[0, 5, 2, 6, 1, a, c, z],
                            [0, 8, 1, b, d, e, f, 0]]

  i = [6, 6]
  exit
                 outputs = [[5, 2, 6, 1, a, c],
                            [8, 1, b, d, e, f]]

  In this example, there are two sequences with different input lengths. Thus
  the two caches had been filled to different positions. As we decode, the
  first sequence hits the max decode length before the second. In order to
  avoid prematurely ending decoding for the second sequence, the first
  sequence continually overwrites the final token.

Example 4 (prefilled cache and max decode steps):
                  inputs = [[0, 2, 0, 0, 0, 0, 0, 0],
                            [0, 3, 4, 0, 0, 0, 0, 0]]
  expanded_prompt_inputs = [[0, 2, 0, 0, 0, 0, 0, 0, 0, 0],
                            [0, 3, 4, 0, 0, 0, 0, 0, 0, 0]]
  initial_indices = [1, 2]
  max_decode_step = 2

  Then max_decode_len = [3, 4].

  i = [1, 2]
            input_tokens = [[2],
                            [4]]
           output_tokens = [[a],
                            [b]]
  expanded_prompt_inputs = [[0, 2, a, 0, 0, 0, 0, 0, 0, 0],
                            [0, 3, 4, b, 0, 0, 0, 0, 0, 0]]

  i = [2, 3]
            input_tokens = [[a],
                            [b]]
           output_tokens = [[c],
                            [d]]
  expanded_prompt_inputs = [[0, 2, a, c, 0, 0, 0, 0, 0, 0],
                            [0, 3, 4, b, d, 0, 0, 0, 0, 0]]

  This is the last while-loop iteration with i == max_decode_len - 1.

                 outputs = [[2, a, c, 0, 0, 0, 0, 0],
                            [3, 4, b, d, 0, 0, 0, 0]]
```

Parameters:
  • inputs – array: [batch_size, max_decode_len] int32 sequence of tokens.

  • cache – flax attention cache.

  • tokens_to_logits – fast autoregressive decoder function taking single token slices and cache and returning next-token logits and updated cache.

  • eos_id – int: end-of-sentence token for target vocabulary.

  • decode_rng – JAX PRNGKey.

  • num_decodes – number of decoded sequences to be returned.

  • temperature – float: sampling temperature factor. As it approaches zero this becomes equivalent to greedy sampling. You may also provide an array of floats of size batch_size to use different temperature values for each batch item.

  • topk – integer: if nonzero only use the top-k logits to sample next token, if zero don’t use any cutoff and sample from full logits over vocabulary.

  • topp – float: if nonzero only use the smallest number of logits whose cumulative sum of probs adds up to (at least) topp. Will raise ValueError if it’s nonzero when topk is nonzero.

  • cache_offset – axis offset for cache, arising from scanned layers.

  • initial_index – Optional[array]: [batch_size] int32 a vector of loop indexes to start decoding at.

  • max_decode_steps – int: an optional maximum number of decoding steps. If None, it will decode until the full input shape inputs.shape[1] is filled. max_decode_steps begins counting after the prompt, so it will decode at most len(prompt) + max_decode_steps tokens.

  • max_decode_steps_hard_limit – int: an optional fixed hard limit on max_decode_steps. If this is set (not None and > 0), and max_decode_steps is also set, then max_decode_steps will be clipped to this limit. The value max_decode_steps can be an ndarray, but max_decode_steps_hard_limit must be a Python integer or None.

  • rescale_log_probs – bool: whether to apply temperature, topp, and topk rescaling to the log probs which are returned. If True, the log_probs will include these transformations (for example, with topk=1, all log_probs will be identically 0.0). If False, the log_probs will not be affected, and topk/topp/temperature will not affect sequence probabilities.

  • state_callback_fn – Function that modifies the sampling loop state before each step. This can be used to manipulate any part of the state either on the accelerator or on the host using host callback. The function should take a SamplingLoopState as argument, and it returns the updated state. See decoding_test.py for an example usage.

  • logit_callback_fn – Function that modifies the logits before each temperature sampling step. The function should take arguments (logits, state) and it should return the modified logits. See decoding_test.py for an example usage.

Returns:

A tuple (decodes, log_prob) where decodes contains the sampled sequences with shape [batch_size, num_decodes, max_decode_len], sorted by log_prob, which holds the log probability of each sampled sequence.
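A minimal usage sketch tying the parameters above together; inputs, cache, and tokens_to_logits are placeholders supplied by the caller, and the sampling settings are illustrative.

```
import jax
from t5x import decoding

# inputs: [batch_size, max_decode_len] int32 with the prompt teacher-forced
# into the non-zero positions; cache/tokens_to_logits from the model.
decodes, log_probs = decoding.temperature_sample(
    inputs, cache, tokens_to_logits,
    eos_id=1,
    decode_rng=jax.random.PRNGKey(0),
    num_decodes=4,
    temperature=0.7,
    topk=40)
# decodes:   [batch_size, 4, max_decode_len], sorted by log_probs
# log_probs: [batch_size, 4]
```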

t5x.decoding.top_k_two_stage(x, k)[source]#

Wrapper around lax.top_k with low-batch optimization.

Parameters:
  • x – tensor with shape f32[batch, num_samples].

  • k – integer indicating how many top values to return.

Returns:

Largest k values and indices with shape (f32[batch, k], s32[batch, k]).
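A brief call sketch showing the shapes described above.

```
import jax
from t5x import decoding

x = jax.random.uniform(jax.random.PRNGKey(0), (8, 32000))  # f32[batch, num_samples]
values, indices = decoding.top_k_two_stage(x, k=4)
# values: f32[8, 4] largest entries per row; indices: s32[8, 4]
```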

t5x.decoding.unflatten_beam_dim(x, batch_size, beam_size, offset=0)[source]#

Unflattens the first, flat batch*beam dimension of a non-scalar array.