t5x.checkpoint_importer package#

T5 Checkpoint Importer.

class t5x.checkpoint_importer.CheckpointTranslator[source]#

Utility class for defining mapping rules from one flatdict to another.

We assume a checkpoint is loaded as a dictionary with flattened keys of the form: ‘name0/name1/name2/…/nameN’

A rule is added with the ‘add’ decorator, which takes a regex matching rule and wraps a conversion function, feeding it (opts, key, val, **regex_groups) where opts is a dict containing apply-time keyword options for use by the conversion functions.
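To make the flat-key form concrete, here is a minimal sketch (pure Python, not part of t5x) of how a nested checkpoint dictionary maps to the 'name0/name1/…/nameN' keys this class assumes; the `flatten` helper and the example names are hypothetical:

```python
def flatten(nested, prefix=''):
    """Sketch: flatten a nested dict into 'name0/name1/.../nameN' keys,
    the form CheckpointTranslator assumes for a loaded checkpoint."""
    flat = {}
    for name, value in nested.items():
        key = f'{prefix}/{name}' if prefix else name
        if isinstance(value, dict):
            flat.update(flatten(value, key))   # recurse into sub-dicts
        else:
            flat[key] = value                  # leaf: emit joined key
    return flat

flat = flatten({'encoder': {'block_0': {'weight': 1.0}}, 'step': 10})
# flat == {'encoder/block_0/weight': 1.0, 'step': 10}
```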

add(pattern)[source]#

Adds a new key/value conversion rule.

Parameters:

pattern – regex with capture groups for matching given sets of model variables. We terminate all regexes with ‘$’ to force complete matches.

Returns:

Translation function decorator for associating with the provided pattern.

apply(flatdict, **opts)[source]#

Applies rules to a flattened dictionary.

Parameters:
  • flatdict – flat-key dictionary of variables.

  • **opts – additional config options for translation rules supplied at application time.

Returns:

Checkpoint data with translated key/values in flat-key dict format.
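The add/apply pattern described above can be sketched in plain Python. This `MiniTranslator` is an illustrative stand-in, not the t5x class; the example rule and key names are hypothetical:

```python
import re

class MiniTranslator:
    """Sketch of the CheckpointTranslator pattern: regex rules over flat keys."""

    def __init__(self):
        self.rules = []  # list of (compiled pattern, conversion function)

    def add(self, pattern):
        """Register a conversion function for keys matching `pattern`.

        The pattern is terminated with '$' to force a complete match.
        """
        def register(fn):
            self.rules.append((re.compile(pattern + '$'), fn))
            return fn
        return register

    def apply(self, flatdict, **opts):
        """Run the first matching rule on each key/value pair."""
        out = {}
        for key, val in flatdict.items():
            for pat, fn in self.rules:
                m = pat.match(key)
                if m:
                    new_key, new_val = fn(opts, key, val, **m.groupdict())
                    out[new_key] = new_val
                    break
            else:
                out[key] = val  # illustrative choice: unmatched keys pass through
        return out

translator = MiniTranslator()

@translator.add(r'encoder/block_(?P<blocknum>\d+)/weight')
def rename_block(opts, key, val, blocknum):
    # Hypothetical rule: re-root encoder block weights under 'target/'.
    return f'target/encoder/layers_{blocknum}/kernel', val

result = translator.apply({'encoder/block_0/weight': 1.0, 'other/var': 2.0})
# result == {'target/encoder/layers_0/kernel': 1.0, 'other/var': 2.0}
```

Note how the regex capture group `blocknum` is forwarded to the conversion function as a keyword argument, matching the `(opts, key, val, **regex_groups)` calling convention described above.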

class t5x.checkpoint_importer.LazyArray(shape, dtype, get_fn)[source]#

Lazily and asynchronously loads an array.

LazyArray behaves like a numpy or jax.numpy array while materializing its data lazily. All metadata properties, including shape, dtype, and nbytes, are available as soon as the LazyArray is constructed, but no data is materialized until get or get_async is called. Data is materialized using the supplied get_fn.

This class can be used to implement lazy restoration in checkpointing APIs, where the data is only read from disk when explicitly needed by the user.
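A minimal sketch of the eager-metadata / deferred-data idea (pure Python, not the t5x class; the fixed 4-byte element size and the `read_from_disk` stand-in are illustrative assumptions):

```python
class MiniLazyArray:
    """Sketch of the LazyArray idea: metadata is eager, data is deferred."""

    def __init__(self, shape, dtype, get_fn):
        self.shape = shape
        self.dtype = dtype
        self._get_fn = get_fn  # called only when data is actually needed

    @property
    def nbytes(self):
        # Illustrative assumption: a fixed 4-byte element size.
        size = 1
        for dim in self.shape:
            size *= dim
        return size * 4

    def get(self):
        return self._get_fn()  # data is materialized here, not earlier

loads = []

def read_from_disk():
    loads.append('read')       # side effect lets us observe the laziness
    return [[0.0] * 3] * 2     # stand-in for the real array data

arr = MiniLazyArray(shape=(2, 3), dtype='float32', get_fn=read_from_disk)
meta = (arr.shape, arr.dtype, arr.nbytes)   # available without any read
data = arr.get()                            # the read happens only here
```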

class t5x.checkpoint_importer.LazyAwaitableArray(shape, dtype, get_fn)[source]#

Lazily and asynchronously loads an array when the get_fn is async.

Note

The synchronous load method .get requires an asyncio event loop and calls .run_until_complete on it. This is not supported when the event loop is already running (for example, from inside another async function).

Note

Currently, this class provides helper classmethods for creating a LazyAwaitableArray when the input could be either an array or a TensorStore spec. Because TensorStore is usually accessed with async code, these classmethods live here; if a blocking function is ever used to read from TensorStore, they can be moved to the LazyArray base class.
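The async loading pattern, and the event-loop caveat from the first note, can be sketched as follows (pure Python, not the t5x class; `fetch` is a hypothetical stand-in for an async TensorStore read):

```python
import asyncio

class MiniLazyAwaitableArray:
    """Sketch of LazyAwaitableArray: an async get_fn, awaited on demand."""

    def __init__(self, shape, dtype, get_fn):
        self.shape, self.dtype = shape, dtype
        self._get_fn = get_fn

    async def get_async(self):
        return await self._get_fn()   # works inside any running event loop

    def get(self):
        # Synchronous wrapper: drives the loop itself, so it fails if an
        # event loop is already running (the caveat in the note above).
        return asyncio.get_event_loop().run_until_complete(self.get_async())

async def fetch():
    await asyncio.sleep(0)   # stand-in for an async TensorStore read
    return [1, 2, 3]

arr = MiniLazyAwaitableArray((3,), 'int32', fetch)
data = asyncio.run(arr.get_async())   # the async path; safe under asyncio.run
```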

classmethod from_array(array, get_fn, dtype=None)[source]#

Create a LazyAwaitableArray based on an array or Python number.

classmethod from_tensor_store_spec(ts_spec, get_fn, dtype=None)[source]#

Create a LazyAwaitableArray based on a tensorstore.Spec.

classmethod from_tensor_store_spec_or_array(maybe_ts_spec, get_fn, dtype=None)[source]#

Create a LazyAwaitableArray based on an array or a tensorstore.Spec.

class t5x.checkpoint_importer.LazyThreadPoolArray(shape, dtype, get_fn)[source]#

Lazily and asynchronously loads an array when the get_fn blocks.
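The blocking-get_fn variant can be sketched with a thread pool (pure Python, not the t5x class; returning a `concurrent.futures.Future` from `get_async` is an illustrative simplification, and `blocking_read` is a hypothetical stand-in for a blocking disk read):

```python
from concurrent.futures import ThreadPoolExecutor

# Illustrative shared pool so blocking reads don't stall the caller.
_pool = ThreadPoolExecutor(max_workers=1)

class MiniLazyThreadPoolArray:
    """Sketch: a blocking get_fn, offloaded to a thread pool for async use."""

    def __init__(self, shape, dtype, get_fn):
        self.shape, self.dtype = shape, dtype
        self._get_fn = get_fn

    def get(self):
        return self._get_fn()               # synchronous, blocking path

    def get_async(self):
        return _pool.submit(self._get_fn)   # returns a Future immediately

def blocking_read():
    return [4, 5, 6]   # stand-in for a blocking disk read

arr = MiniLazyThreadPoolArray((3,), 'int32', blocking_read)
future = arr.get_async()   # returns without waiting for the read
data = future.result()     # blocks only when the data is needed
```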

t5x.checkpoint_importer.attention_layers(opts, key, val, encdec, blocknum, attntype, qkvo, slot)[source]#

Process attention layers.

t5x.checkpoint_importer.final_layernorms(opts, key, val, encdec, slot)[source]#

Process final layer norms.

t5x.checkpoint_importer.layernorms(opts, key, val, encdec, blocknum, lyrnum, slot)[source]#

Process layer norms assuming that they are pre-layernorms.

t5x.checkpoint_importer.load_tf_ckpt(path)[source]#

Load a TF checkpoint as a flat dictionary of numpy arrays.

t5x.checkpoint_importer.mlpblock(opts, key, val, encdec, blocknum, io_name, io_num, slot)[source]#

Process MLP blocks.

t5x.checkpoint_importer.rel_embeddings(opts, key, val, encdec, blocknum, slot)[source]#

Process relpos biases assuming that they are not shared across layers.

t5x.checkpoint_importer.restore_from_t5_checkpoint(state_dict, path, lazy_parameters=False, strict=True, translator=None)[source]#

Load a T5 checkpoint and update the Adafactor optimizer and T5 model from it.

We require that the final translated checkpoint structure exactly matches that of the Flax Adafactor + Transformer data, up to shape agreement of the leaves.

Parameters:
  • state_dict – Flax Adafactor Optimizer for T5 transformer encoder-decoder.

  • path – a path to a checkpoint file or directory.

  • lazy_parameters – whether to leave the parameters as LazyArrays to preserve memory.

  • strict – If True, requires that the optimizer and t5_data mappings contain the same set of names (variables). If False, updating succeeds even if t5_data contains variables not present in the optimizer. If the optimizer has variables not in t5_data, this function fails regardless of this flag.

  • translator – The mapping rules for conversion. If None, then default T5 conversion rules will be used.

Returns:

Adafactor optimizer updated with parameters and optimizer state from T5 checkpoint.
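The semantics of the `strict` flag can be sketched as two set comparisons over flattened variable names (pure Python, not the t5x implementation; `check_strict` and both name sets are hypothetical stand-ins):

```python
def check_strict(optimizer_names, t5_names, strict=True):
    """Sketch of the name-set checks the `strict` flag describes.

    `optimizer_names` / `t5_names` stand in for the flattened variable
    names of the optimizer state and of the translated checkpoint.
    """
    missing_in_ckpt = optimizer_names - t5_names
    if missing_in_ckpt:
        # Optimizer variables absent from the checkpoint always fail.
        raise ValueError(f'Checkpoint is missing: {sorted(missing_in_ckpt)}')
    extra_in_ckpt = t5_names - optimizer_names
    if strict and extra_in_ckpt:
        # Extra checkpoint variables fail only in strict mode.
        raise ValueError(f'Unexpected in checkpoint: {sorted(extra_in_ckpt)}')
    return True

# Non-strict mode tolerates extra checkpoint variables:
ok = check_strict({'a', 'b'}, {'a', 'b', 'c'}, strict=False)
```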