t5x.checkpoint_importer package#
T5 Checkpoint Importer.
- class t5x.checkpoint_importer.CheckpointTranslator[source]#
Utility class for defining mapping rules from one flatdict to another.
We assume a checkpoint is loaded as a dictionary with flattened keys of the form: ‘name0/name1/name2/…/nameN’
A rule is added with the ‘add’ decorator, which takes a regex matching rule and wraps a conversion function, feeding it (opts, key, val, **regex_groups), where opts is a dict containing apply-time keyword options for use by the conversion functions.
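The rule-registration pattern described above can be sketched as follows. This is a minimal illustrative stand-in, not the real CheckpointTranslator: the class name, the `apply` method, and the silent dropping of unmatched keys are assumptions made for brevity.

```python
import re

# Minimal sketch of the decorator-based rule registry described above.
# Names here (MiniTranslator, apply) are illustrative, not the t5x API.
class MiniTranslator:
    def __init__(self):
        self.rules = []  # list of (compiled regex, conversion function)

    def add(self, pattern):
        """Decorator: register a conversion function for keys matching `pattern`."""
        def register(fn):
            self.rules.append((re.compile(pattern), fn))
            return fn
        return register

    def apply(self, flatdict, **opts):
        out = {}
        for key, val in flatdict.items():
            for pat, fn in self.rules:
                m = pat.fullmatch(key)
                if m:
                    # Named regex groups become keyword arguments.
                    new_key, new_val = fn(opts, key, val, **m.groupdict())
                    out[new_key] = new_val
                    break
            # Unmatched keys are silently dropped in this sketch.
        return out

tr = MiniTranslator()

@tr.add(r'encoder/layer_(?P<num>\d+)/kernel')
def rename_kernel(opts, key, val, num):
    return f'encoder/layers_{num}/kernel', val

print(tr.apply({'encoder/layer_0/kernel': 1.0}))
# → {'encoder/layers_0/kernel': 1.0}
```

The decorator receives the named groups from the regex (`num` here), which is how the real conversion functions below get arguments such as `blocknum` and `slot`.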
- class t5x.checkpoint_importer.LazyArray(shape, dtype, get_fn)[source]#
Lazily and asynchronously loads an array.
LazyArray behaves in the same way as a numpy or jax.numpy array, but loads its data lazily. All properties, including shape, dtype, and nbytes, are set when the LazyArray is created, but no data is materialized until get or get_async is called. Data is materialized using the specified get_fn.
This class can be used to implement lazy restoration in checkpointing APIs, where the data is only read from disk when explicitly needed by the user.
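The deferred-loading idea can be sketched with a tiny stand-in class. This is a hedged illustration, not the t5x implementation: the class name is made up, the 4-byte element size in `nbytes` is an assumption, and plain lists replace real arrays so the sketch stays self-contained.

```python
# Sketch of lazy restoration: metadata is fixed at construction time,
# data is produced only when get() invokes the stored get_fn.
class MiniLazyArray:
    def __init__(self, shape, dtype, get_fn):
        self.shape = shape
        self.dtype = dtype
        self._get_fn = get_fn  # called only when data is requested

    @property
    def nbytes(self):
        # Illustrative only: element count times an assumed 4-byte element.
        n = 1
        for d in self.shape:
            n *= d
        return n * 4

    def get(self):
        return self._get_fn()

reads = []
def slow_read():
    reads.append('disk')           # record that the "disk read" happened
    return [[0.0, 1.0], [2.0, 3.0]]

arr = MiniLazyArray((2, 2), 'float32', slow_read)
print(arr.shape, arr.nbytes, reads)  # metadata available, no read yet
print(arr.get(), reads)              # data materialized on demand
```

The point of the pattern is visible in the two prints: shape and nbytes are usable before any read occurs, which is what makes `lazy_parameters=True` in restore_from_t5_checkpoint memory-friendly.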
- class t5x.checkpoint_importer.LazyAwaitableArray(shape, dtype, get_fn)[source]#
Lazily and asynchronously loads an array when the get_fn is async.
Note
The synchronous load method .get obtains the asyncio event loop and calls .run_until_complete on it. This is not supported when the event loop is already running (for example, from inside another async function).
Note
Currently, this class has a few helper classmethods for creating a LazyAwaitableArray when the input could be either an array or a TensorStore spec. Most code that reads from TensorStore is async, so the classmethods live here; if a blocking TensorStore read path is ever used, they can be moved to the LazyArray base class.
- classmethod from_array(array, get_fn, dtype=None)[source]#
Create a LazyAwaitableArray based on an array or python number.
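The async variant can be sketched as below. This is an assumption-laden illustration, not the t5x code: the class name and the fake coroutine are invented, and a fresh event loop is created per call so the sketch runs anywhere, whereas the real .get relies on the current event loop (hence the Note above about it failing inside already-running async code).

```python
import asyncio

# Sketch of the awaitable variant: get_fn is a coroutine function,
# get_async hands back its coroutine, and the blocking get() drives
# an event loop itself via run_until_complete.
class MiniLazyAwaitableArray:
    def __init__(self, shape, dtype, get_fn):
        self.shape = shape
        self.dtype = dtype
        self._get_fn = get_fn  # async callable

    def get_async(self):
        return self._get_fn()  # a coroutine; the caller awaits it

    def get(self):
        # Blocking load. Calling run_until_complete is what makes this
        # unusable from inside an already-running event loop.
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(self.get_async())
        finally:
            loop.close()

async def fake_tensorstore_read():
    await asyncio.sleep(0)  # stand-in for an async TensorStore read
    return [1.0, 2.0, 3.0]

arr = MiniLazyAwaitableArray((3,), 'float32', fake_tensorstore_read)
print(arr.get())
# → [1.0, 2.0, 3.0]
```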
- class t5x.checkpoint_importer.LazyThreadPoolArray(shape, dtype, get_fn)[source]#
Lazily and asynchronously loads an array when the get_fn blocks.
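When get_fn blocks (for example, a synchronous file read), asynchrony has to come from a thread pool rather than an event loop. A minimal sketch of that idea, with invented names and a module-level pool as an assumed design choice:

```python
import concurrent.futures

# Sketch of the thread-pool variant: get_fn is a plain blocking callable,
# so get_async submits it to a worker thread instead of an event loop.
_POOL = concurrent.futures.ThreadPoolExecutor(max_workers=4)

class MiniLazyThreadPoolArray:
    def __init__(self, shape, dtype, get_fn):
        self.shape = shape
        self.dtype = dtype
        self._get_fn = get_fn  # blocking callable

    def get_async(self):
        return _POOL.submit(self._get_fn)  # a Future; caller calls .result()

    def get(self):
        return self.get_async().result()

def blocking_read():
    return [4.0, 5.0]

arr = MiniLazyThreadPoolArray((2,), 'float32', blocking_read)
print(arr.get())
# → [4.0, 5.0]
```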
- t5x.checkpoint_importer.attention_layers(opts, key, val, encdec, blocknum, attntype, qkvo, slot)[source]#
Process attention layers.
- t5x.checkpoint_importer.final_layernorms(opts, key, val, encdec, slot)[source]#
Process final layer norms.
- t5x.checkpoint_importer.layernorms(opts, key, val, encdec, blocknum, lyrnum, slot)[source]#
Process layer norms assuming that they are pre-layernorms.
- t5x.checkpoint_importer.load_tf_ckpt(path)[source]#
Load a TF checkpoint as a flat dictionary of numpy arrays.
- t5x.checkpoint_importer.mlpblock(opts, key, val, encdec, blocknum, io_name, io_num, slot)[source]#
Process MLP blocks.
- t5x.checkpoint_importer.rel_embeddings(opts, key, val, encdec, blocknum, slot)[source]#
Process relpos bias assuming that they are not shared across layers.
- t5x.checkpoint_importer.restore_from_t5_checkpoint(state_dict, path, lazy_parameters=False, strict=True, translator=None)[source]#
Load T5 checkpoint and update Adafactor optimizer and T5 model from it.
We require that the final translated checkpoint structure exactly matches that of the Flax Adafactor + Transformer data, up to shape agreement of the leaves.
- Parameters:
state_dict – Flax Adafactor Optimizer for T5 transformer encoder-decoder.
path – a path to a checkpoint file or directory.
lazy_parameters – whether to leave the parameters as LazyArrays to preserve memory.
strict – If True, requires that the optimizer and t5_data mappings contain the same set of names (variables). If False, updating will succeed even if t5_data contains variables not present in the optimizer. If the optimizer has variables not in t5_data, this function will still fail.
translator – The mapping rules for conversion. If None, then default T5 conversion rules will be used.
- Returns:
Adafactor optimizer updated with parameters and optimizer state from T5 checkpoint.
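The strict/non-strict matching rule above can be illustrated with a small stand-alone function. This is not the real restore_from_t5_checkpoint; the function name and error messages are invented to show only the asymmetry: extra checkpoint variables are tolerated when strict=False, while variables missing from the checkpoint always fail.

```python
# Illustrative sketch of the strict-matching semantics described above
# (not the t5x implementation).
def update_from_checkpoint(optimizer_vars, t5_data, strict=True):
    missing = set(optimizer_vars) - set(t5_data)
    extra = set(t5_data) - set(optimizer_vars)
    if missing:
        # Variables the optimizer needs but the checkpoint lacks: always an error.
        raise ValueError(f'checkpoint is missing variables: {sorted(missing)}')
    if strict and extra:
        # Extra checkpoint variables are only an error in strict mode.
        raise ValueError(f'checkpoint has unexpected variables: {sorted(extra)}')
    return {k: t5_data[k] for k in optimizer_vars}

opt = {'encoder/kernel': None}
ckpt = {'encoder/kernel': 1.0, 'unused/stat': 2.0}
print(update_from_checkpoint(opt, ckpt, strict=False))
# → {'encoder/kernel': 1.0}
```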