t5x.state_utils package

Utilities for processing optimizer states.

t5x.state_utils.apply_assignment_map(ckpt_optimizer_state, optimizer_state, assignment_map, require_all_rules_match=True, *, is_resuming=False)

Applies an assignment map to a checkpoint optimizer state.

In contrast to previous implementations, this function has a switch that controls whether all rules are required to match, and it applies custom but sensible replacement rules:

  1. old keys that are matched are removed.

  2. old keys that don’t match are retained.

  3. if two new keys map to the same old key, they both get assigned to its value.

  4. if a new key isn’t mapped but is in the checkpoint, it is copied over.

  5. new keys with None-valued replacement patterns are removed.

Parameters:
  • ckpt_optimizer_state – Optimizer state in the checkpoint (usually from a previous model).

  • optimizer_state – Optimizer state in the current model.

  • assignment_map – List of tuples (matcher, replacement) where matcher is a regex, and replacement is a string replacement (possibly with regex-compatible group match codes) or None if the matching state should be dropped.

  • require_all_rules_match – Whether to require that all rules match.

  • is_resuming – Whether we are resuming a training run (True) or initializing a new one (False).

Returns:

New, remapped optimizer state.
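The five replacement rules above can be illustrated with a minimal sketch over flattened, "/"-joined keys. This is not the library's implementation: `sketch_apply_assignment_map` and the example keys are hypothetical, the sketch assumes each matcher is matched against the whole new key, and it ignores `require_all_rules_match` and `is_resuming`.

```python
import re


def sketch_apply_assignment_map(flat_ckpt, new_keys, assignment_map):
    """Hypothetical illustration of the rules above, on flat dicts only."""
    unmapped_old = dict(flat_ckpt)  # old keys not yet claimed by a rule
    result = {}
    dropped = set()
    for key in new_keys:
        for pattern, repl in assignment_map:
            match = re.fullmatch(pattern, key)
            if match is None:
                continue
            if repl is None:
                # Rule 5: a None replacement drops the new key.
                dropped.add(key)
            else:
                old_key = match.expand(repl)
                # Rule 3: several new keys may read the same old key.
                result[key] = flat_ckpt[old_key]
                # Rule 1: a matched old key is not carried over by name.
                unmapped_old.pop(old_key, None)
            break  # first matching rule wins in this sketch
    # Rules 2 and 4: unmatched old keys are kept under their own names.
    for key, value in unmapped_old.items():
        if key not in result and key not in dropped:
            result[key] = value
    return result


flat_ckpt = {
    "enc/layer_0/kernel": 1,
    "enc/layer_1/kernel": 2,
    "step": 10,
    "old_only": 7,
}
new_keys = [
    "encoder/layer_0/kernel",
    "encoder/layer_1/kernel",
    "encoder/copy/kernel",
    "step",
    "head/bias",
]
assignment_map = [
    (r"encoder/layer_(\d+)/kernel", r"enc/layer_\1/kernel"),
    (r"encoder/copy/kernel", r"enc/layer_0/kernel"),
    (r"head/.*", None),
]
remapped = sketch_apply_assignment_map(flat_ckpt, new_keys, assignment_map)
```

In this example, "encoder/layer_0/kernel" and "encoder/copy/kernel" both receive the value stored under "enc/layer_0/kernel", "head/bias" is dropped, and "step" and "old_only" are carried over unchanged.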

t5x.state_utils.flatten_state_dict(state_dict, keep_empty_nodes=False)

Flatten a dictionary until an array or tensorstore is reached.

Parameters:
  • state_dict – Optimizer state as nested dictionary.

  • keep_empty_nodes – Whether to keep empty nodes, for example, empty param states from simple optimizers or untouched parameter states in a multioptimizer.

Returns:

Flattened dictionary, with any tensorstore state kept unflattened.
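As a rough illustration of the flattening described above, the following pure-Python sketch joins nested keys with "/" and optionally keeps empty nodes. `sketch_flatten` is a hypothetical helper, not the library function; it treats every non-dict value as a leaf and does not detect tensorstore specs.

```python
def sketch_flatten(tree, keep_empty_nodes=False, prefix=()):
    """Hypothetical flattening of a nested dict into "/"-joined keys."""
    flat = {}
    for name, value in tree.items():
        path = prefix + (name,)
        if isinstance(value, dict):
            if value:
                # Non-empty dict: descend one level further.
                flat.update(sketch_flatten(value, keep_empty_nodes, path))
            elif keep_empty_nodes:
                # Empty dict: keep a placeholder only when requested.
                flat["/".join(path)] = {}
        else:
            flat["/".join(path)] = value
    return flat


state = {
    "state": {
        "step": 3,
        "param_states": {"dense": {"kernel": {"m": 0.1, "v": 0.2}, "bias": {}}},
    },
    "target": {"dense": {"kernel": [1.0, 2.0]}},
}
flat = sketch_flatten(state)
flat_with_empty = sketch_flatten(state, keep_empty_nodes=True)
```

With `keep_empty_nodes=True` the empty "bias" entry survives as a placeholder; otherwise it disappears from the flattened result.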

t5x.state_utils.get_name_tree(state_dict, keep_empty_nodes=False)

Returns new state_dict with leaves as joined path keys separated by “/”.
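A minimal sketch of what such a name tree looks like, assuming plain nested dicts. `sketch_name_tree` is a hypothetical stand-in that ignores `keep_empty_nodes`; it is meant only to show the shape of the result.

```python
def sketch_name_tree(tree, prefix=()):
    """Hypothetical: same nesting as `tree`, leaves replaced by their paths."""
    names = {}
    for name, value in tree.items():
        path = prefix + (name,)
        if isinstance(value, dict):
            names[name] = sketch_name_tree(value, path)
        else:
            names[name] = "/".join(path)
    return names


state = {"state": {"step": 3}, "target": {"dense": {"kernel": [1.0]}}}
name_tree = sketch_name_tree(state)
```

Here `name_tree["target"]["dense"]["kernel"]` is the string "target/dense/kernel", which can be handy when checking which keys an assignment-map regex would see.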

t5x.state_utils.intersect_state(state_dict, intersect_state_dict)

Drops non-matching entries from state_dict.

Parameters:
  • state_dict – nested dict of optimizer state

  • intersect_state_dict – nested dict of entries to keep

Returns:

nested dict like state_dict but with extra keys removed
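The intersection can be pictured with the following sketch, which keeps only the entries of one nested dict whose paths also appear in another. `sketch_intersect` is hypothetical and recursive, so its handling of edge cases such as empty subtrees may differ from the library's.

```python
def sketch_intersect(state, keep):
    """Hypothetical: drop entries of `state` whose path is absent from `keep`."""
    result = {}
    for name, value in state.items():
        if name not in keep:
            continue  # no counterpart in `keep`, so the entry is dropped
        if isinstance(value, dict) and isinstance(keep[name], dict):
            result[name] = sketch_intersect(value, keep[name])
        else:
            result[name] = value  # values always come from `state`
    return result


state = {"dense": {"kernel": 1, "bias": 2}, "head": {"kernel": 3}}
keep = {"dense": {"kernel": None}}
intersected = sketch_intersect(state, keep)
```

Only the structure of `keep` matters in this sketch; its leaf values are never read.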

t5x.state_utils.merge_state(state_dict, from_scratch_state, overwrite=False)

Inserts new entries into state_dict.

Parameters:
  • state_dict – nested dict of optimizer state

  • from_scratch_state – nested dict of entries to insert

  • overwrite – if True, values present in both state_dict and from_scratch_state will be present in the result with the value taken from from_scratch_state.

Returns:

a nested dict like state_dict but with extra entries from from_scratch_state inserted
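The effect of `overwrite` can be sketched as follows. `sketch_merge` is a hypothetical recursive helper written for illustration, not the library's code: missing entries are always inserted, and entries present on both sides are replaced only when `overwrite` is true.

```python
def sketch_merge(state, from_scratch, overwrite=False):
    """Hypothetical: insert entries of `from_scratch` into a copy of `state`."""
    result = dict(state)
    for name, value in from_scratch.items():
        both_dicts = isinstance(result.get(name), dict) and isinstance(value, dict)
        if both_dicts:
            # Shared subtree: merge one level down.
            result[name] = sketch_merge(result[name], value, overwrite)
        elif name not in result or overwrite:
            result[name] = value
    return result


state = {"dense": {"kernel": 1}}
scratch = {"dense": {"kernel": 0, "bias": 0}, "head": {"kernel": 0}}
merged = sketch_merge(state, scratch)
merged_overwrite = sketch_merge(state, scratch, overwrite=True)
```

Without `overwrite`, the existing "dense/kernel" value is preserved and only the new "bias" and "head" entries are added; with `overwrite=True`, "dense/kernel" also takes the from-scratch value.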

t5x.state_utils.tensorstore_leaf(_, value)

Detect if the node is a serialized tensorstore spec.

Parameters:
  • _ – The unused name of the current item.

  • value – The value of the possible leaf.

Returns:

True if the value represents a tensorstore spec, False otherwise.
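One plausible way to write such a predicate is sketched below. It assumes, as an illustration only, that a serialized spec is a dict containing at least the keys "driver", "kvstore", and "metadata"; the documentation above does not state which keys are checked, and `sketch_tensorstore_leaf` and the example values are hypothetical.

```python
# Assumed (not confirmed by the text above): keys that mark a serialized spec.
ASSUMED_SPEC_KEYS = {"driver", "kvstore", "metadata"}


def sketch_tensorstore_leaf(_, value):
    """Hypothetical check for a dict that looks like a serialized spec."""
    return isinstance(value, dict) and ASSUMED_SPEC_KEYS <= set(value)


spec_like = {
    "driver": "zarr",
    "kvstore": {"driver": "file", "path": "/tmp/example/kernel"},
    "metadata": {"shape": [2, 2]},
    "dtype": "float32",  # extra keys do not prevent a match
}
plain_subtree = {"kernel": {"m": 0.1, "v": 0.2}}

is_spec = sketch_tensorstore_leaf("kernel", spec_like)
is_plain_spec = sketch_tensorstore_leaf("dense", plain_subtree)
```

A predicate of this shape lets a flattening routine stop at the spec dict and treat it as a single leaf instead of descending into its keys.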