t5x.train_state package

Train state for passing around objects during training.

class t5x.train_state.FlaxOptimTrainState(_optimizer, params_axes=None, flax_mutables=FrozenDict({}), flax_mutables_axes=None)

Simple train state holding the parameters, step, and optimizer state.

replace(**updates)

Returns a new object replacing the specified fields with new values.
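This `replace` appears to be the one inherited from Flax's `flax.struct` dataclasses, which wrap the standard library's `dataclasses.replace`. The sketch below illustrates that functional-update pattern on a hypothetical frozen `ToyState` class; it is not the t5x implementation, only the same idea in plain Python.

```python
import dataclasses
from typing import Any, Mapping


# Hypothetical stand-in for an immutable train state; NOT the t5x class.
@dataclasses.dataclass(frozen=True)
class ToyState:
    step: int
    params: Mapping[str, Any]


state = ToyState(step=0, params={"w": 1.0})

# dataclasses.replace builds a new instance; the original is left untouched.
new_state = dataclasses.replace(state, step=state.step + 1)

print(state.step, new_state.step)        # prints "0 1"
# Fields that were not replaced are shared by reference, not copied.
print(new_state.params is state.params)  # prints "True"
```

Because the instance is frozen, assigning `state.step = 5` raises `dataclasses.FrozenInstanceError`; every update goes through `replace` and yields a new object.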

class t5x.train_state.InferenceState(step, params, params_axes=None, flax_mutables=FrozenDict({}), flax_mutables_axes=None)

State compatible with FlaxOptimTrainState without optimizer state.
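InferenceState's constructor takes the step, the parameters, and the optional axes and mutables, but no optimizer. A minimal sketch of that relationship, assuming hypothetical `ToyTrainState` and `ToyInferenceState` classes and a hypothetical `to_inference_state` helper (none of these are t5x API):

```python
import dataclasses
from typing import Any, Dict, Optional


# Hypothetical simplified states; fields loosely mirror the documented
# constructor arguments but these are NOT the t5x classes.
@dataclasses.dataclass(frozen=True)
class ToyTrainState:
    step: int
    params: Dict[str, float]
    param_states: Dict[str, float]  # e.g. per-parameter optimizer momentum
    flax_mutables: Dict[str, Any] = dataclasses.field(default_factory=dict)


@dataclasses.dataclass(frozen=True)
class ToyInferenceState:
    step: int
    params: Dict[str, float]
    params_axes: Optional[Dict[str, Any]] = None
    flax_mutables: Dict[str, Any] = dataclasses.field(default_factory=dict)


def to_inference_state(train_state: ToyTrainState) -> ToyInferenceState:
    """Hypothetical helper: keeps step, params, mutables; drops optimizer state."""
    return ToyInferenceState(
        step=train_state.step,
        params=train_state.params,
        flax_mutables=train_state.flax_mutables,
    )


ts = ToyTrainState(step=100, params={"w": 0.5}, param_states={"w": 0.01})
inf = to_inference_state(ts)
print(inf)
```

Dropping the optimizer state is what makes such a state cheaper to hold for evaluation or decoding, since optimizer accumulators can be as large as the parameters themselves.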

property param_states

The optimizer states of the parameters as a PyTree.

replace(**updates)

Returns a new object replacing the specified fields with new values.

class t5x.train_state.TrainState(*args, **kwargs)

TrainState interface.
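One way to read this interface is as a structural-typing contract. The sketch below restates the documented members as a `typing.Protocol` named `TrainStateLike`; the name, the simplified signatures (with `None` standing in for the `FrozenDict({})` default), and the `MinimalState` class are all hypothetical and only illustrate how any class exposing these members would satisfy the contract.

```python
import dataclasses
from typing import Any, Mapping, MutableMapping, Optional, Protocol, runtime_checkable


@runtime_checkable
class TrainStateLike(Protocol):
    """Hypothetical mirror of the documented TrainState members."""

    @property
    def step(self) -> int: ...

    @property
    def params(self) -> Mapping[str, Any]: ...

    @property
    def param_states(self) -> Mapping[str, Any]: ...

    @property
    def flax_mutables(self) -> Mapping[str, Any]: ...

    def apply_gradient(
        self,
        grads: Mapping[str, Any],
        learning_rate: float,
        flax_mutables: Optional[Mapping[str, Any]] = None,
    ) -> "TrainStateLike": ...

    def as_logical_axes(self) -> "TrainStateLike": ...

    def state_dict(self) -> MutableMapping[str, Any]: ...

    def restore_state(self, state_dict: Mapping[str, Any]) -> "TrainStateLike": ...


# Satisfies the protocol structurally; it never inherits from TrainStateLike.
@dataclasses.dataclass(frozen=True)
class MinimalState:
    step: int = 0
    params: Mapping[str, Any] = dataclasses.field(default_factory=dict)
    param_states: Mapping[str, Any] = dataclasses.field(default_factory=dict)
    flax_mutables: Mapping[str, Any] = dataclasses.field(default_factory=dict)

    def apply_gradient(self, grads, learning_rate, flax_mutables=None):
        raise NotImplementedError  # body omitted in this sketch

    def as_logical_axes(self):
        raise NotImplementedError  # body omitted in this sketch

    def state_dict(self):
        return {"step": self.step, "params": dict(self.params)}

    def restore_state(self, state_dict):
        # Returns a new object rather than mutating self.
        return dataclasses.replace(
            self, step=state_dict["step"], params=state_dict["params"])


print(isinstance(MinimalState(), TrainStateLike))  # prints "True"
```

`runtime_checkable` only checks that the members exist, not their signatures or return types, so a static type checker remains the stronger guarantee.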

apply_gradient(grads, learning_rate, flax_mutables=FrozenDict({}))

Applies the gradients, increments the step, and returns an updated TrainState.
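The update-then-increment contract can be sketched with plain SGD on a dict of floats. `SGDState` is hypothetical: it has no optimizer state and ignores `flax_mutables`, whereas FlaxOptimTrainState presumably delegates the actual update rule to the optimizer it wraps (its `_optimizer` argument).

```python
import dataclasses
from typing import Dict


@dataclasses.dataclass(frozen=True)
class SGDState:
    """Hypothetical toy state using plain SGD, so it has no optimizer state."""
    step: int
    params: Dict[str, float]

    def apply_gradient(self, grads: Dict[str, float],
                       learning_rate: float) -> "SGDState":
        # Plain SGD update for each parameter: p <- p - lr * g.
        new_params = {name: p - learning_rate * grads[name]
                      for name, p in self.params.items()}
        # Return a new state with the step incremented; self is unchanged.
        return dataclasses.replace(self, step=self.step + 1, params=new_params)


state = SGDState(step=0, params={"w": 1.0, "b": 0.5})
state = state.apply_gradient({"w": 0.5, "b": -1.0}, learning_rate=0.5)
print(state.step, state.params)  # prints "1 {'w': 0.75, 'b': 1.0}"
```

Returning a new state instead of mutating in place is what lets a training step be written as a pure function, which JAX transformations such as `jax.jit` require.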

as_logical_axes()

Replaces the params and param states with their logical axis names.
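Logical axis names (for example `"embed"` or `"mlp"`) describe what each array dimension means, so that partitioning rules can later map them onto device mesh axes. A sketch of the tree-shaped substitution, assuming nested dicts with shape tuples standing in for arrays and hypothetical `params`/`params_axes` values; the real method operates on Flax PyTrees and also covers the optimizer states.

```python
from typing import Any, Dict, Tuple

# Hypothetical nested params; shape tuples stand in for the actual arrays.
params = {"encoder": {"kernel": (512, 2048), "bias": (2048,)}}

# Hypothetical logical-axis annotations, nested exactly like params.
params_axes = {"encoder": {"kernel": ("embed", "mlp"), "bias": ("mlp",)}}


def as_logical_axes(tree: Dict[str, Any], axes: Dict[str, Any]) -> Dict[str, Any]:
    """Returns a tree shaped like `tree` whose leaves are axis-name tuples."""
    out: Dict[str, Any] = {}
    for key, value in tree.items():
        if isinstance(value, dict):
            # Recurse into submodules, preserving the nesting.
            out[key] = as_logical_axes(value, axes[key])
        else:
            names: Tuple[str, ...] = axes[key]
            # Expect exactly one logical axis name per array dimension.
            if len(names) != len(value):
                raise ValueError(
                    f"{key}: {len(names)} axis names for a rank-{len(value)} leaf")
            out[key] = names
    return out


print(as_logical_axes(params, params_axes))
# prints {'encoder': {'kernel': ('embed', 'mlp'), 'bias': ('mlp',)}}
```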

property flax_mutables

The Flax mutable collections.

property param_states

The optimizer states of the parameters as a PyTree.

property params

The parameters of the model as a PyTree matching the Flax module.

restore_state(state_dict)

Restores the object state from a state dict.

state_dict()

Returns a mutable representation of the state for checkpointing.
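`state_dict` and `restore_state` form a round trip: the first exports a mutable, checkpoint-friendly dict, and the second builds a state from such a dict. The sketch below assumes a hypothetical `CheckpointableState` class and a `target`/`state` key layout chosen only for illustration; it returns a new object from `restore_state`, in keeping with the immutable `replace` style of these classes.

```python
import copy
import dataclasses
from typing import Any, Dict, Mapping, MutableMapping


@dataclasses.dataclass(frozen=True)
class CheckpointableState:
    """Hypothetical toy state; the dict layout below is an assumption."""
    step: int
    params: Dict[str, float]
    param_states: Dict[str, float]

    def state_dict(self) -> MutableMapping[str, Any]:
        # Deep copy, so callers can mutate the result without touching self.
        return copy.deepcopy({
            "target": self.params,
            "state": {"step": self.step, "param_states": self.param_states},
        })

    def restore_state(self, state_dict: Mapping[str, Any]) -> "CheckpointableState":
        # Builds a new state from the dict; self is left unchanged.
        return dataclasses.replace(
            self,
            step=state_dict["state"]["step"],
            params=dict(state_dict["target"]),
            param_states=dict(state_dict["state"]["param_states"]),
        )


state = CheckpointableState(step=7, params={"w": 0.75}, param_states={"w": 0.125})
saved = state.state_dict()       # e.g. handed to a checkpoint writer
saved["state"]["step"] = 8       # mutating the dict leaves `state` intact
restored = state.restore_state(saved)
print(state.step, restored.step)  # prints "7 8"
```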

property step

The current training step as an integer scalar.