t5x.train_state package
Train state classes for passing the model parameters, the training step, and the optimizer state around during training.
- class t5x.train_state.FlaxOptimTrainState(_optimizer, params_axes=None, flax_mutables=FrozenDict({}), flax_mutables_axes=None)
Simple train state that holds the parameters, the step, and the optimizer state.
- replace(**updates)
Returns a new object replacing the specified fields with new values.
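`replace(**updates)` follows the functional-update pattern familiar from frozen dataclasses: instead of mutating the state in place, it returns a modified copy. The following is a minimal, dependency-free sketch of that behavior using a hypothetical `ToyTrainState` built on the standard library's `dataclasses.replace`. It illustrates the documented semantics only and is not the t5x implementation.

```python
import dataclasses
from typing import Any, Dict


@dataclasses.dataclass(frozen=True)
class ToyTrainState:
  """Hypothetical stand-in for a train state; not the t5x class."""

  step: int
  params: Dict[str, Any]
  param_states: Dict[str, Any]

  def replace(self, **updates: Any) -> "ToyTrainState":
    # Build a new instance with the named fields swapped in; self is unchanged.
    return dataclasses.replace(self, **updates)


state = ToyTrainState(step=0, params={"w": 1.0}, param_states={"w": 0.0})
new_state = state.replace(step=state.step + 1)

print(state.step)      # → 0
print(new_state.step)  # → 1
print(new_state.params is state.params)  # → True (unreplaced fields are shared, not copied)
```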
- class t5x.train_state.InferenceState(step, params, params_axes=None, flax_mutables=FrozenDict({}), flax_mutables_axes=None)
State compatible with FlaxOptimTrainState, but without the optimizer state.
- property param_states
The optimizer states of the parameters as a PyTree.
- replace(**updates)
Returns a new object replacing the specified fields with new values.
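To make the relationship between the two classes concrete, the sketch below converts a train state into an inference state by keeping the step, the parameters, and the mutable collections while dropping the optimizer state. `ToyTrainState`, `ToyInferenceState`, and `to_inference_state` are hypothetical names; the sketch shows the idea and is not how t5x constructs an `InferenceState`.

```python
import dataclasses
from typing import Any, Dict


@dataclasses.dataclass(frozen=True)
class ToyTrainState:
  """Hypothetical train state: step, params, optimizer state, mutables."""

  step: int
  params: Dict[str, Any]
  param_states: Dict[str, Any]
  flax_mutables: Dict[str, Any] = dataclasses.field(default_factory=dict)


@dataclasses.dataclass(frozen=True)
class ToyInferenceState:
  """Hypothetical inference state: the same fields minus the optimizer state."""

  step: int
  params: Dict[str, Any]
  flax_mutables: Dict[str, Any] = dataclasses.field(default_factory=dict)


def to_inference_state(train_state: ToyTrainState) -> ToyInferenceState:
  # Keep what is needed to run the model; drop the optimizer state.
  return ToyInferenceState(
      step=train_state.step,
      params=train_state.params,
      flax_mutables=train_state.flax_mutables,
  )


train_state = ToyTrainState(step=10, params={"w": 2.0}, param_states={"w": 0.5})
inference_state = to_inference_state(train_state)

print(inference_state.step)                      # → 10
print(hasattr(inference_state, "param_states"))  # → False
```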
- class t5x.train_state.TrainState(*args, **kwargs)
TrainState interface.
- apply_gradient(grads, learning_rate, flax_mutables=FrozenDict({}))
Applies the gradients, increments the step, and returns an updated TrainState.
- property flax_mutables
The Flax mutable variable collections.
- property param_states
The optimizer states of the parameters as a PyTree.
- property params
The parameters of the model as a PyTree matching the Flax module.
- property step
The current training step as an integer scalar.
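The `apply_gradient` contract (apply the gradients, increment the step, return a new state) can be illustrated without t5x or JAX. The sketch below is a hypothetical, dependency-free stand-in: `TrainStateLike` and `MomentumTrainState` are not t5x names, the parameters are a flat dict of floats rather than a PyTree, and plain SGD with momentum is swapped in as the update rule, with `param_states` holding one momentum buffer per parameter. It also assumes, as the empty `flax_mutables` default in the signature suggests, that the returned state carries whatever mutable collections are passed to the call.

```python
import dataclasses
from typing import Dict, Mapping, Optional, Protocol


class TrainStateLike(Protocol):
  """Hypothetical structural type mirroring the documented TrainState members."""

  @property
  def step(self) -> int: ...

  @property
  def params(self) -> Mapping[str, float]: ...

  @property
  def param_states(self) -> Mapping[str, float]: ...

  @property
  def flax_mutables(self) -> Mapping[str, float]: ...

  def apply_gradient(
      self,
      grads: Mapping[str, float],
      learning_rate: float,
      flax_mutables: Optional[Mapping[str, float]] = None,
  ) -> "TrainStateLike": ...


@dataclasses.dataclass(frozen=True)
class MomentumTrainState:
  """Hypothetical TrainStateLike using SGD with momentum on a flat dict."""

  step: int
  params: Dict[str, float]
  param_states: Dict[str, float]  # one momentum buffer per parameter
  flax_mutables: Dict[str, float] = dataclasses.field(default_factory=dict)
  beta: float = 0.9  # momentum coefficient

  def apply_gradient(
      self,
      grads: Mapping[str, float],
      learning_rate: float,
      flax_mutables: Optional[Mapping[str, float]] = None,
  ) -> "MomentumTrainState":
    # Update each momentum buffer, then move each parameter along it.
    new_states = {
        name: self.beta * self.param_states[name] + grads[name]
        for name in self.params
    }
    new_params = {
        name: self.params[name] - learning_rate * new_states[name]
        for name in self.params
    }
    # Return a new state with the step incremented; self is left unchanged.
    # Assumption: the new state takes its mutables from the argument and
    # falls back to an empty collection, like the documented default.
    return dataclasses.replace(
        self,
        step=self.step + 1,
        params=new_params,
        param_states=new_states,
        flax_mutables=dict(flax_mutables or {}),
    )


s0 = MomentumTrainState(step=0, params={"w": 1.0}, param_states={"w": 0.0}, beta=0.5)
s1 = s0.apply_gradient({"w": 0.5}, learning_rate=0.5)
s2 = s1.apply_gradient({"w": 0.5}, learning_rate=0.5, flax_mutables={"count": 2.0})

print(s1.step, s1.params["w"])  # → 1 0.75
print(s2.step, s2.params["w"])  # → 2 0.375
print(s0.step)                  # → 0 (s0 is unchanged)
```

Because each update produces a copy, earlier states such as `s0` and `s1` stay valid after later steps, which fits the pure-function style JAX code generally expects.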