t5x.adafactor package#

Adafactor Optimizer.

Specialized Adafactor implementation for T5X with:
  • custom factorization specification rules.

  • support for stacked parameters from scanned layers and parameter fusions.

Why do we need custom factorization? The Adafactor paper considers scalar, vector, and matrix parameters. This is sufficiently general because higher-dimensional parameters can always be reshaped, but in practice there are situations where keeping the higher-dimensional form is desirable. For example, the projection kernels of multi-headed attention are naturally represented as 3-dimensional arrays of shape [d_model, num_head, head_dim]. Keeping the 3-dimensional structure can be beneficial for performance, e.g., by giving the compiler an additional degree of freedom for layout optimization.
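
As a concrete illustration (the array below is a hypothetical attention projection kernel, not part of the t5x API), a query projection in multi-headed attention can be kept as a 3-D parameter rather than being reshaped into a matrix:

```python
import jax
import jax.numpy as jnp

d_model, num_heads, head_dim = 512, 8, 64

# Hypothetical 3-D query-projection kernel: [d_model, num_heads, head_dim].
# Reshaping to [d_model, num_heads * head_dim] would turn it into a matrix,
# but keeping the 3-D structure leaves the compiler free to choose a layout.
key = jax.random.PRNGKey(0)
wq = jax.random.normal(key, (d_model, num_heads, head_dim)) / jnp.sqrt(d_model)

# Applying it to activations of shape [batch, seq, d_model]:
x = jnp.ones((2, 16, d_model))
q = jnp.einsum('bsd,dnh->bsnh', x, wq)  # -> [batch, seq, num_heads, head_dim]
print(q.shape)
```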

The default heuristic behavior for the second-moment estimator can lead to unexpected results because it assumes that parameters are matrices (vectors and scalars are not factored). The dimensions are sorted, the smaller dimension is assigned to the row dim and the larger to the col dim (unless the two largest dims have equal size, in which case the original ordering of the dimensions is used). Then v_row (i.e., the optimizer state for the rows) is obtained by removing the col dim; in other words, rank(v_row) = rank(v) - 1. If the parameter is higher dimensional, v_row and v_col are therefore also higher dimensional, and the outer product of v_row and v_col does not necessarily correspond to the low-rank approximation that minimizes the generalized Kullback-Leibler divergence (the original Adafactor formulation).
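
The following is a minimal NumPy sketch of this default heuristic, independent of the actual t5x implementation: for a matrix, averaging the squared gradient over the column dimension gives v_row and averaging over the row dimension gives v_col, and a rescaled outer product recovers the rank-1 estimate from the paper; for a 3-D parameter the same recipe leaves v_row and v_col as matrices, so their "outer product" loses that interpretation.

```python
import numpy as np

def default_factored_stats(g_sq):
    # Default heuristic: the smaller dimension becomes the row dim and the
    # larger one the col dim; v_row / v_col each drop exactly one dimension.
    dims = sorted(range(g_sq.ndim), key=lambda d: g_sq.shape[d])
    row_dim, col_dim = dims[0], dims[-1]
    v_row = g_sq.mean(axis=col_dim)   # rank(v_row) = rank(v) - 1
    v_col = g_sq.mean(axis=row_dim)   # rank(v_col) = rank(v) - 1
    return v_row, v_col

# Matrix case: v_row and v_col are vectors, and
# outer(v_row, v_col) / mean(v_row) equals the rank-1 estimate
# r_i * c_j / total from the Adafactor paper.
g2 = np.square(np.random.randn(4, 8))
v_row, v_col = default_factored_stats(g2)
est = np.outer(v_row, v_col) / v_row.mean()
print(v_row.shape, v_col.shape, est.shape)   # (4,) (8,) (4, 8)

# 3-D case: v_row and v_col are themselves matrices, so their "outer
# product" is no longer a rank-1 reconstruction of the second moment.
g2_3d = np.square(np.random.randn(512, 8, 64))
v_row, v_col = default_factored_stats(g2_3d)
print(v_row.shape, v_col.shape)              # (8, 64) (512, 64)
```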

This Adafactor implementation generalizes the default behavior so that the correct second-moment estimator is obtained even for higher-dimensional parameters.
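
The generalization can be pictured as tagging each parameter dimension with a role and reducing over the opposite group. The sketch below is a hypothetical NumPy illustration of that idea only, not the t5x implementation (which specifies the roles via factor_map / logical_factor_rules).

```python
import numpy as np

# Hypothetical per-dimension roles for a [d_model, num_heads, head_dim]
# kernel: treat d_model as the "row" group and (num_heads, head_dim) jointly
# as the "column" group, mirroring the matrix obtained by fusing the head
# dimensions.
ROW, COL = 'row', 'col'
roles = (ROW, COL, COL)

def generalized_factored_stats(g_sq, roles):
    row_dims = tuple(d for d, r in enumerate(roles) if r == ROW)
    col_dims = tuple(d for d, r in enumerate(roles) if r == COL)
    v_row = g_sq.mean(axis=col_dims)   # reduce over the column group
    v_col = g_sq.mean(axis=row_dims)   # reduce over the row group
    return v_row, v_col

g2 = np.square(np.random.randn(512, 8, 64))
v_row, v_col = generalized_factored_stats(g2, roles)
print(v_row.shape, v_col.shape)        # (512,) (8, 64)
```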

class t5x.adafactor.Adafactor(learning_rate=None, factored=True, multiply_by_parameter_scale=True, beta1=None, decay_rate=0.8, step_offset=0, clipping_threshold=1.0, weight_decay_rate=None, min_dim_size_to_factor=128, epsilon1=1e-30, epsilon2=0.001, dtype_momentum=<class 'jax.numpy.float32'>, factor_map=None, logical_factor_rules=None, weight_decay_rate_lr_exponent=None, global_norm_clip_threshold=None, max_parameter_scale=None, skip_nan_updates=False)[source]#

Adafactor optimizer.

Adafactor is described in https://arxiv.org/abs/1804.04235.
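
An illustrative construction, using only parameters documented in the signature above; the values shown mirror the defaults and are not tuning recommendations.

```python
from t5x import adafactor

# decay_rate and step_offset are the knobs most commonly set in T5X configs;
# everything else is left at its documented default.
optimizer_def = adafactor.Adafactor(
    decay_rate=0.8,
    step_offset=0,
    multiply_by_parameter_scale=True,
    clipping_threshold=1.0,
)
```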

apply_gradient(hyper_params, params, state, grads)[source]#

Applies a gradient for a set of parameters.

Parameters:
  • hyper_params – a named tuple of hyper parameters.

  • params – the parameters that should be updated.

  • state – a named tuple containing the state of the optimizer.

  • grads – the gradient tensors for the parameters.

Returns:

A tuple containing the new parameters and the new optimizer state.
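
In normal T5X training this method is invoked by the surrounding trainer/optimizer machinery; the hand-driven sketch below only illustrates the calling convention implied by the parameter list. The `init_state` and `hyper_params` members are assumed to come from the flax-style OptimizerDef base class.

```python
import jax.numpy as jnp
from t5x import adafactor

optimizer_def = adafactor.Adafactor(learning_rate=1e-3)

# Hypothetical parameter and gradient pytrees.
params = {'kernel': jnp.ones((4, 8)), 'bias': jnp.zeros((8,))}
grads = {'kernel': jnp.full((4, 8), 0.1), 'bias': jnp.full((8,), 0.1)}

# init_state and hyper_params are assumed base-class helpers; apply_gradient
# returns (new_params, new_state) as documented above.
state = optimizer_def.init_state(params)
new_params, new_state = optimizer_def.apply_gradient(
    optimizer_def.hyper_params, params, state, grads)
```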

derive_logical_axes(optimizer_state, param_logical_axes)[source]#

Derives optimizer logical partitioning from model logical partitions.

set_param_axes(param_logical_axes)[source]#

Sets Adafactor factorization map from logical axis names tree.
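
A hedged illustration of the expected input: a pytree of per-dimension logical axis names that mirrors the parameter tree. The specific axis names, rule patterns, and FactorDim member names (ROW, COLUMN) used below are assumptions for illustration, not documented values.

```python
from t5x import adafactor

# Hypothetical logical-axis factor rules (regex pattern -> factoring role);
# the FactorDim member names used here are assumed, not taken from the docs.
logical_factor_rules = adafactor.HParamMap([
    ('embed', adafactor.FactorDim.ROW),
    ('heads', adafactor.FactorDim.COLUMN),
    ('kv', adafactor.FactorDim.COLUMN),
    ('joined_kv', adafactor.FactorDim.COLUMN),
])

optimizer_def = adafactor.Adafactor(logical_factor_rules=logical_factor_rules)

# Pytree of per-dimension logical axis names, mirroring the parameter tree.
param_logical_axes = {
    'attention': {
        'query': {'kernel': ('embed', 'heads', 'kv')},
        'out': {'kernel': ('joined_kv', 'embed')},
    },
}
optimizer_def.set_param_axes(param_logical_axes)
```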

class t5x.adafactor.FactorDim(value)[source]#

An enumeration.

class t5x.adafactor.HParamMap(rules)[source]#

Maps parameter path names to hparams.

Names of parameters nested in a PyTree (e.g., an Optimizer) are formed by joining the names along the path to the parameter leaf with ‘/’.
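
A hedged usage sketch, assuming rules are (regex pattern, value) pairs matched in order against the '/'-joined parameter path; this rule format is inferred from the description above rather than documented here.

```python
from t5x import adafactor

# Hypothetical rules: (regex, value) pairs; the first pattern that matches
# the '/'-joined parameter path determines the returned hparam value.
per_param_setting = adafactor.HParamMap([
    ('.*/embedding', 0.1),   # e.g. a special setting for embedding tables
    ('.*', 1.0),             # default for everything else
])

print(per_param_setting['token_embedder/embedding'])                 # 0.1
print(per_param_setting['encoder/layers_0/attention/query/kernel'])  # 1.0
```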

class t5x.adafactor.HeuristicRule(value)[source]#

An enumeration.