t5x.binary_search package#

Binary search over float32 bits.

Includes fast algorithms for top-k masking and top-p masking of probability distributions.

t5x.binary_search.float32_bsearch(batch_shape, predicate)[source]#

Binary search on finite float32 numbers.

For each element of the batch, this function searches for the largest finite non-NaN float32 for which the predicate is False.

Parameters:
  • batch_shape – Shape of the search that we’re batching over.

  • predicate – the query we’re searching for. This is required to be monotonic with respect to the floating point order, i.e. it must be False for all numbers <= a threshold, and then True for all numbers > the threshold. The threshold may be different for different elements of the batch.

Returns:

For each element of the batch, the largest float32 for which the predicate returns False. Shape: f32[batch_shape].
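The contract can be illustrated with a minimal, non-batched Python sketch. This is not the t5x implementation (which batches the search and runs a fixed number of steps on accelerator hardware); it binary searches over a monotone integer remapping of the float32 bit patterns, and the function name is hypothetical:

```python
import struct

def float32_bsearch_scalar(predicate):
    """Largest finite float32 for which predicate is False (sketch).

    If the predicate is True everywhere, this returns -FLT_MAX.
    """
    def float_to_key(x):
        # Map float32 bits to a uint32 key ordered the same way as the
        # floats themselves (the standard sign-flip trick).
        u = struct.unpack('<I', struct.pack('<f', x))[0]
        return (u | 0x80000000) if u < 0x80000000 else (~u & 0xFFFFFFFF)

    def key_to_float(key):
        u = (key ^ 0x80000000) if key >= 0x80000000 else (~key & 0xFFFFFFFF)
        return struct.unpack('<f', struct.pack('<I', u))[0]

    FLT_MAX = struct.unpack('<f', b'\xff\xff\x7f\x7f')[0]
    lo = float_to_key(-FLT_MAX)  # every key in [lo, hi] decodes to a
    hi = float_to_key(FLT_MAX)   # finite, non-NaN float32
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if predicate(key_to_float(mid)):
            hi = mid - 1  # threshold lies below mid
        else:
            lo = mid      # mid is still False; keep it
    return key_to_float(lo)
```

For a predicate like `lambda x: x > 1.0`, the search converges to exactly `1.0`, the largest float32 where the predicate is still False.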

t5x.binary_search.int32_bsearch(batch_shape, predicate)[source]#

Batched binary search over int32 values.

For each element of the batch, search for the largest int32 (closest to positive infinity) for which the predicate is False. If the predicate is always True, returns the minimum int32 value.

Parameters:
  • batch_shape – Shape of the search that we’re batching over.

  • predicate – the query we’re searching for. For every batch element, this is required to be a monotonic function from int32 to bool. In other words, the predicate must return False for all numbers <= some threshold and then return True for all numbers > that threshold. The threshold may be different for different elements of the batch.

Returns:

For each element of the batch, the largest int32 for which the predicate returns False. Shape: batch_shape.
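A batched NumPy sketch of this contract (again a hypothetical helper mirroring the interface, not the t5x implementation): it builds each answer greedily one bit at a time, so the predicate is evaluated a fixed 32 times regardless of the thresholds.

```python
import numpy as np

def int32_bsearch_sketch(batch_shape, predicate):
    # Shift int32 into "key" space: key = x + 2**31 maps int32
    # monotonically onto [0, 2**32), so the answer can be built
    # greedily from the most significant bit down.
    key = np.zeros(batch_shape, dtype=np.int64)
    for bit in range(31, -1, -1):
        trial = key | (1 << bit)
        # Keep the bit wherever the predicate is still False there;
        # monotonicity guarantees this finds the largest False key.
        keep = ~predicate(trial - 2**31)
        key = np.where(keep, trial, key)
    # If the predicate was True everywhere, key is still 0, i.e. the
    # minimum int32, matching the documented behavior.
    return key - 2**31
```

Each batch element can have its own threshold: passing `predicate = lambda x: x > thresholds` recovers `thresholds` exactly.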

t5x.binary_search.topk_mask(x, k, replace_val)[source]#

Sets everything to replace_val, except the top k values per batch element.

Sharding considerations: this function does 32 reductions over the vocab_size axis of the input array. To avoid excessive latency from these reductions, you should ensure that the vocab_size axis is unsharded on input to this function. Prefer to shard the batch axes instead.

Scratchpad memory considerations: this function is most efficient if the entire input array can fit in a fast memory tier. To help ensure this, you may wish to split the batch axes into microbatches and process the microbatches in a sequential loop.

Parameters:
  • x – Values before masking. [batch…, vocab_size]

  • k – Number of values to keep unmasked per batch element. In the presence of ties, more than k values might be kept.

  • replace_val – Value with which to overwrite the masked entries of x.

Returns:

masked version of x. [batch…, vocab_size]
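For reference, the same input/output contract can be reproduced with a plain NumPy sketch (the t5x version instead binary searches for the k-th largest value, which maps better to accelerators; the helper name here is hypothetical):

```python
import numpy as np

def topk_mask_sketch(x, k, replace_val):
    # k-th largest value per batch element, with a trailing axis of
    # size 1 kept so it broadcasts against x.
    idx = x.shape[-1] - k
    kth = np.partition(x, idx, axis=-1)[..., idx:idx + 1]
    # Keep everything >= the k-th largest value; ties with it are all
    # kept, so more than k entries may survive.
    return np.where(x >= kth, x, replace_val)
```

For example, with `x = [[1., 5., 3., 2.]]`, `k=2`, and `replace_val=-inf`, the two largest values 5 and 3 survive and the rest become `-inf`.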

t5x.binary_search.topp_mask(logits, p, replace_val)[source]#

Applies top-p masking to logits.

Masks logits down to the smallest set of choices whose total probability mass is >= p. Values in this set are left unchanged; all other values are set to replace_val.

Sharding considerations: this function does 33 reductions over the vocab_size axis of the input array. To avoid excessive latency from these reductions, you should ensure that the vocab_size axis is unsharded on input to this function. Prefer to shard the batch axes instead.

Scratchpad memory considerations: this function is most efficient if the entire input array can fit in a fast memory tier. To help ensure this, you may wish to split the batch axes into microbatches and process the microbatches in a sequential loop.

Parameters:
  • logits – Logits before masking. [batch…, vocab_size]

  • p – Minimum probability mass requested.

  • replace_val – Value with which to overwrite the masked entries of logits.

Returns:

masked version of logits. [batch…, vocab_size]
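A NumPy sketch with the same contract (again not the t5x implementation, which binary searches for the cutoff logit instead of sorting; the helper name is hypothetical):

```python
import numpy as np

def topp_mask_sketch(logits, p, replace_val):
    # Softmax probabilities per batch element.
    z = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
    # Sort probabilities descending and find, per batch element, the
    # smallest prefix whose total mass reaches p.
    order = np.argsort(-probs, axis=-1)
    sorted_probs = np.take_along_axis(probs, order, axis=-1)
    csum = np.cumsum(sorted_probs, axis=-1)
    # Keep a sorted position if the mass *before* it is still < p, so
    # the kept set is the smallest one with mass >= p.
    keep_sorted = (csum - sorted_probs) < p
    keep = np.zeros_like(keep_sorted)
    np.put_along_axis(keep, order, keep_sorted, axis=-1)
    return np.where(keep, logits, replace_val)
```

For instance, with per-element probabilities [0.5, 0.3, 0.2] and p = 0.7, the first two choices (mass 0.8 >= 0.7) are kept and the third is replaced.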