t5x.optimizers package#

T5X Optimizer Support.

Tools for wrapping Optax optimizers and handling SPMD annotations for use with pjit.

Additional support for the legacy Adafactor implementation.

class t5x.optimizers.MultiOptimizer(traversals_and_optimizers)[source]#

Generalized multi-optimizer.

NB: Although this class is provided for legacy support, it is still quite general and works fine with wrapped optax optimizers. Note, however, that the more canonical way of mixing multiple optimizers in optax is to use optax.masked or optax.multi_transform instead.
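
As a point of comparison, here is a minimal sketch of that canonical optax approach (the parameter tree, labels, and learning rates below are illustrative assumptions, not part of this API):

import jax.numpy as jnp
import optax
from flax import traverse_util

# Illustrative parameter tree.
params = {'dense': {'kernel': jnp.ones((2, 2)), 'bias': jnp.zeros((2,))}}

def label_fn(params):
  # Label every leaf 'kernel' or 'bias' based on its key path.
  flat = traverse_util.flatten_dict(params)
  labels = {k: 'kernel' if 'kernel' in k else 'bias' for k in flat}
  return traverse_util.unflatten_dict(labels)

# Route 'kernel' leaves to one Adam instance and 'bias' leaves to another.
tx = optax.multi_transform(
    {'kernel': optax.adam(1e-2), 'bias': optax.adam(1e-1)}, label_fn)
opt_state = tx.init(params)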

A MultiOptimizer is a subclass of OptimizerDef and is useful for applying separate optimizer algorithms to different subsets of the model parameters.

The example below creates two optimizers using flax.traverse_util.ModelParamTraversal: one to optimize kernel parameters and one to optimize bias parameters. Note that each optimizer is created with a different learning rate:

from flax import traverse_util

# Select kernel and bias parameters by the names along their paths.
kernels = traverse_util.ModelParamTraversal(
    lambda path, _: 'kernel' in path)
biases = traverse_util.ModelParamTraversal(lambda path, _: 'bias' in path)
# Use a different learning rate for each subset.
kernel_opt = optimizers.adam(learning_rate=0.01)
bias_opt = optimizers.adam(learning_rate=0.1)
opt_def = MultiOptimizer((kernels, kernel_opt), (biases, bias_opt))
optimizer = opt_def.create(model)

In order to train only a subset of the parameters, you can simply use a single flax.traverse_util.ModelParamTraversal instance.

If you want to update the learning rates of both optimizers online with different learning rate schedules, update the learning rates when applying the gradient. In the following example, the second optimizer performs no updates during the first 1000 steps:

import jax.numpy as jnp

hparams = optimizer.optimizer_def.hyper_params
new_optimizer = optimizer.apply_gradient(
    grads,
    hyper_params=[
      hparams[0].replace(learning_rate=0.2),
      # Zero learning rate for the first 1000 steps, then switch to `lr`.
      hparams[1].replace(learning_rate=jnp.where(step < 1000, 0., lr)),
    ])
apply_gradient(hyper_params, params, state, grads)[source]#

Applies a gradient for a set of parameters.

derive_logical_axes(optimizer, param_logical_axes)[source]#

Derives optimizer logical partitioning from model logical partitions.

set_param_axes(param_logical_axes)[source]#

Derives factorization rules from model parameter logical axes.

update_hyper_params(**hyper_param_overrides)[source]#

Updates the hyper parameters with a set of overrides.

This method is called from Optimizer.apply_gradient() to create the hyper parameters for a specific optimization step. MultiOptimizer applies the overrides to each sub-optimizer.

Parameters:

**hyper_param_overrides – the hyper parameter updates will override the defaults specified in the OptimizerDef. Pass hyper_params=… to replace all hyper parameters.

Returns:

The new hyper parameters.

class t5x.optimizers.OptaxStatePartitionRules[source]#

Collection of rules to partition optax states.

These rules work for optimizers whose states are simply replications of the params, e.g. Adam. Optimizers that save memory by factoring their states, e.g. Adafactor and SM3, are currently not supported.

classmethod derive_optax_logical_axes(optax_state, params_axes)[source]#

Derives logical axes for the optax state.
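
For illustration, a minimal sketch (the PartitionSpec import path and the axis names 'embed'/'mlp' are assumptions for this example) showing how Adam state, which mirrors the parameter tree, picks up the parameters' logical axes:

import jax.numpy as jnp
import optax
from t5x import optimizers, partitioning

params = {'dense': {'kernel': jnp.ones((8, 4))}}
params_axes = {'dense': {'kernel': partitioning.PartitionSpec('embed', 'mlp')}}

# Adam state replicates the parameter tree, so its axes mirror params_axes.
state = optax.adam(1e-3).init(params)
state_axes = optimizers.OptaxStatePartitionRules.derive_optax_logical_axes(
    state, params_axes)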

classmethod derive_params_axes(optax_params, params_axes)[source]#

Derives axes for the params inside the optax state.

class t5x.optimizers.OptaxWrapper(optax_optimizer)[source]#

Wrapper to make optax optimizer compatible with T5X.
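
A minimal usage sketch (illustrative, not from the source), using optax.sgd and the create/apply_gradient interface documented below:

import jax
import jax.numpy as jnp
import optax
from t5x import optimizers

params = {'w': jnp.ones((3,))}
opt_def = optimizers.OptaxWrapper(optax.sgd(learning_rate=0.1))
optimizer = opt_def.create(params)  # an optimizers.Optimizer

# Compute gradients of a toy loss and take one optimizer step.
grads = jax.grad(lambda p: jnp.sum(p['w'] ** 2))(params)
optimizer = optimizer.apply_gradient(grads)  # returns a new Optimizer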

apply_gradient(hyper_params, params, state, grads)[source]#

Applies gradient.

Parameters:
  • hyper_params – Unused hyper parameters.

  • params – PyTree of the parameters.

  • state – A named tuple containing the state of the optimizer.

  • grads – PyTree of the gradients for the parameters.

Returns:

A tuple containing the new parameters and the new optimizer state.

derive_logical_axes(optimizer, param_logical_axes)[source]#

Derives optimizer state logical axes from params logical axes.

Parameters:
  • optimizer – an optimizers.Optimizer instance.

  • param_logical_axes – A PyTree where each leaf is a t5x PartitionSpec.

Returns:

An optimizers.Optimizer instance, with all the leaves replaced by a t5x PartitionSpec or None (no partitioning).

init_state(params)[source]#

Create initial state based on the params to optimize.

Parameters:

params – PyTree of parameters to optimize.

Returns:

Initial optimizer state.
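
For example (a hedged sketch; the wrapped sgd optimizer below is just an illustration):

import jax.numpy as jnp
import optax
from t5x import optimizers

opt_def = optimizers.OptaxWrapper(optax.sgd(learning_rate=0.1))
state = opt_def.init_state({'w': jnp.zeros((2,))})
# `state` is an OptimizerState holding a step counter and per-parameter states.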

restore_state(opt_target, opt_state, state_dict)[source]#

Override to restore empty dicts corresponding to optax.EmptyState.

Parameters:
  • opt_target – the optimizer target.

  • opt_state – the optimizer state.

  • state_dict – the state dict containing the desired new state of the optimizer.

Returns:

a tuple of the optimizer target and state with the restored values from the state dict.

state_dict(target, state)[source]#

Override state dict function.

We need to override this function because many optax transformations use optax.EmptyState, which produces empty dicts in the state dict. This causes the T5 training loop to fail in multiple places. As a remedy, we filter the generated state dict so that the output contains no empty dicts.

The restore_state function is also overridden to reconstruct those empty dicts.

Parameters:
  • target – Pytree of target variables.

  • state – Pytree of optimizer state.

Returns:

A nested state.
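
A minimal round-trip sketch (illustrative; it simply follows the documented state_dict/restore_state signatures):

import jax.numpy as jnp
import optax
from t5x import optimizers

opt_def = optimizers.OptaxWrapper(optax.adam(1e-3))
optimizer = opt_def.create({'w': jnp.zeros((2,))})

# Serialize: entries corresponding to optax.EmptyState are filtered out.
sd = opt_def.state_dict(optimizer.target, optimizer.state)
# Restore: the filtered empty entries are reconstructed.
target, state = opt_def.restore_state(optimizer.target, optimizer.state, sd)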

class t5x.optimizers.Optimizer(optimizer_def, state, target)[source]#

Legacy flax optimizer class.

Optimizer carries the target and optimizer state. The optimizer is updated using the method apply_gradient.

optimizer_def#

The optimizer definition.

Type:

t5x.optimizers.OptimizerDef

state#

The initial state of the optimizer.

Type:

Any

target#

The target to be optimized.

Type:

Any

apply_gradient(grads, **hyper_param_overrides)[source]#

Applies a pytree of gradients to the target.

Parameters:
  • grads – A pytree of gradients.

  • **hyper_param_overrides – the hyper parameters passed to apply_gradient will override the defaults specified in the OptimizerDef. Pass hyper_params=… to replace all hyper parameters.

Returns:

A new optimizer with the updated target and state.

replace(**updates)#

Returns a new object replacing the specified fields with new values.

class t5x.optimizers.OptimizerDef(hyper_params)[source]#

Base class for an optimizer definition.

apply_gradient(hyper_params, params, state, grads)[source]#

Applies a gradient for a set of parameters.

create(target)[source]#

Creates a new optimizer for the given target.

Parameters:

target – the object to be optimized. This is typically a variable dict returned by flax.linen.Module.init(), but it can also be a container of variable dicts; e.g. (v1, v2) and {'var1': v1, 'var2': v2} are valid inputs as well.

Returns:

An instance of Optimizer.
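
As a hedged illustration (using a wrapped optax optimizer as a concrete OptimizerDef), each of the target forms described above is accepted:

import jax.numpy as jnp
import optax
from t5x import optimizers

opt_def = optimizers.OptaxWrapper(optax.sgd(learning_rate=0.1))
v1 = {'w': jnp.ones((2,))}
v2 = {'b': jnp.zeros((2,))}

opt_a = opt_def.create(v1)                        # a single variable dict
opt_b = opt_def.create((v1, v2))                  # a tuple of variable dicts
opt_c = opt_def.create({'var1': v1, 'var2': v2})  # a dict of variable dicts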

restore_state(opt_target, opt_state, state_dict)[source]#

Restore the optimizer target and state from the state dict.

Parameters:
  • opt_target – the optimizer target.

  • opt_state – the optimizer state.

  • state_dict – the state dict containing the desired new state of the optimizer.

Returns:

a tuple of the optimizer target and state with the restored values from the state dict.

update_hyper_params(**hyper_param_overrides)[source]#

Updates the hyper parameters with a set of overrides.

Parameters:

**hyper_param_overrides – the hyper parameter updates will override the defaults specified in the OptimizerDef. Pass hyper_params=… to replace all hyper parameters.

Returns:

The new hyper parameters.

t5x.optimizers.OptimizerDefType#

alias of OptimizerDef

class t5x.optimizers.OptimizerState(step: jax.Array, param_states: Any)[source]#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

t5x.optimizers.OptimizerType#

alias of Optimizer

t5x.optimizers.wrap_optax_optimizer(optax_optimizer)[source]#

Converts an optax optimizer constructor into a wrapped T5X-compatible optimizer.

Parameters:

optax_optimizer – an optax optimizer creation function that returns an optax GradientTransformation.

Returns:

A function that takes the same arguments as the original optax creation function but instead returns a wrapped OptimizerDef-compatible interface for using the optimizer with T5X.
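
For example (a minimal sketch; optax.adamw and its keyword arguments are standard optax, not defined on this page):

import jax.numpy as jnp
import optax
from t5x import optimizers

# Wrap the optax constructor; the result accepts the same arguments.
adamw = optimizers.wrap_optax_optimizer(optax.adamw)
opt_def = adamw(learning_rate=1e-3, weight_decay=1e-2)  # OptimizerDef-compatible
optimizer = opt_def.create({'w': jnp.zeros((4,))})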