t5x.metrics package

T5X Metrics.

Defines Metric objects and collections used by T5X models. These objects use the CLU metrics library.

class t5x.metrics.AveragePerStep(steps=1, total=None)

Represents per-step average (total divided by number of steps).

See also documentation of Step.

compute()

Computes final metrics from intermediate values.

classmethod from_model_output(values, steps=1, **_)

Initializes an AveragePerStep Metric from array (or singular) values.

Parameters:
  • values – array of values to sum (or a single value).

  • steps – number of steps, defaults to 1.

Returns:

AveragePerStep object.

merge(other)

Returns Metric that is the accumulation of self and other.

Parameters:

other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().

Returns:

A new Metric that accumulates the value from both self and other.

replace(**updates)

Returns a new object replacing the specified fields with new values.
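The per-step average described above can be pictured with a small stand-in class. This is a minimal sketch in plain Python, not the T5X implementation: the name AveragePerStepSketch is hypothetical, and the merge rule (add totals, add step counts) is an assumption based on the "accumulation" wording.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AveragePerStepSketch:
    """Illustrative stand-in only; not the T5X class."""
    total: float
    steps: int = 1

    @classmethod
    def from_model_output(cls, values, steps=1):
        # Accept a single value or an iterable of values and sum them.
        if isinstance(values, (int, float)):
            values = [values]
        return cls(total=float(sum(values)), steps=steps)

    def merge(self, other):
        # Assumed accumulation rule: add totals and add step counts.
        return AveragePerStepSketch(self.total + other.total,
                                    self.steps + other.steps)

    def compute(self):
        # Per-step average: total divided by number of steps.
        return self.total / self.steps


first = AveragePerStepSketch.from_model_output([1.0, 2.0, 3.0])  # total 6.0
second = AveragePerStepSketch.from_model_output(4.0)             # total 4.0
merged = first.merge(second)
average = merged.compute()  # total 10.0 over 2 steps
```

Under these assumptions the merged metric holds a total of 10.0 over 2 steps, so compute() gives 5.0.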

class t5x.metrics.Step(steps=1)

Abstract class representing a per-step metric (such as AveragePerStep) or a steps-per-unit metric (such as StepsPerTime).

Tracks the number of steps. The step count must be set manually using replace_steps, since the use of microbatches may otherwise cause the computation to be incorrect.

See also documentation of Metric.

compute()

Computes final metrics from intermediate values.

replace(**updates)

Returns a new object replacing the specified fields with new values.
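One plausible reading of the microbatch caveat is sketched below in plain Python. StepSketch is a hypothetical stand-in, and the assumption that merging adds step counts is not confirmed by this page. Under that assumption, accumulating over four microbatches of a single training step would report four steps, which is why the count is overwritten afterwards.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class StepSketch:
    """Illustrative stand-in only; not the T5X class."""
    steps: int = 1

    def merge(self, other):
        # Assumption: merging accumulates step counts.
        return StepSketch(self.steps + other.steps)

    def replace_steps(self, steps):
        # Returns a copy with the step count overwritten.
        return replace(self, steps=steps)


# One training step split into four microbatches, each with the default steps=1.
microbatch_metrics = [StepSketch() for _ in range(4)]
accumulated = microbatch_metrics[0]
for metric in microbatch_metrics[1:]:
    accumulated = accumulated.merge(metric)

# The accumulated count reflects microbatches, not training steps,
# so it is set manually to the true number of steps.
corrected = accumulated.replace_steps(1)
```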

class t5x.metrics.StepsPerTime(duration=None, steps=1)

Represents a metric computed as number of steps per time.

See also documentation of Step.

compute()

Computes final metrics from intermediate values.

classmethod from_model_output(steps=1, **_)

Initializes a StepsPerTime Metric.

Parameters:

steps – number of steps, defaults to 1.

Returns:

StepsPerTime object.

merge(other)

Returns Metric that is the accumulation of self and other.

Parameters:

other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().

Returns:

A new Metric that accumulates the value from both self and other.

replace(**updates)

Returns a new object replacing the specified fields with new values.
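A steps-per-time metric can be pictured as a step count divided by a duration supplied later. The sketch below is a plain-Python stand-in with a hypothetical name; it assumes that merging adds step counts and that compute() needs a duration to have been set first.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class StepsPerTimeSketch:
    """Illustrative stand-in only; not the T5X class."""
    steps: int = 1
    duration: Optional[float] = None

    def merge(self, other):
        # Assumption: step counts accumulate; duration is filled in later.
        return StepsPerTimeSketch(self.steps + other.steps, self.duration)

    def compute(self):
        if self.duration is None:
            raise ValueError("duration must be set before calling compute()")
        return self.steps / self.duration


metric = StepsPerTimeSketch(steps=20).merge(StepsPerTimeSketch(steps=30))
metric = replace(metric, duration=20.0)  # e.g. 20 seconds of wall time
steps_per_second = metric.compute()      # 50 steps over 20 seconds
```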

class t5x.metrics.Sum(total)

Computes the sum of a scalar or a batch of tensors.

See also documentation of Metric.

compute()

Computes final metrics from intermediate values.

classmethod from_model_output(values, **_)

Initializes a Sum Metric from array (or singular) values.

Parameters:

values – array of values to sum (or a single value).

Returns:

A Sum object.

merge(other)

Returns Metric that is the accumulation of self and other.

Parameters:

other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().

Returns:

A new Metric that accumulates the value from both self and other.

replace(**updates)

Returns a new object replacing the specified fields with new values.

class t5x.metrics.Time(duration=None)

Computes the sum of a float-valued metric over a period of time.

Duration (the denominator) must be set manually. This is because JAX does not properly support time functions inside compiled functions. Calling time.time() inside a compiled function results in the stored time being the compilation time, not the run time.

See also documentation of Metric.

compute()

Computes final metrics from intermediate values.

merge(other)

Returns Metric that is the accumulation of self and other.

Parameters:

other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().

Returns:

A new Metric that accumulates the value from both self and other.

replace(**updates)

Returns a new object replacing the specified fields with new values.

replace_duration(duration)

Replaces duration with the given value.

Should be used outside a compiled function to set the duration of the metric.

Parameters:

duration – metric duration

Returns:

A new Time object.
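The pattern of measuring time outside the compiled function and attaching it afterwards can be sketched as follows. This is a plain-Python illustration: TimeSketch and train_step_stand_in are hypothetical names, and no JAX compilation takes place here.

```python
import time
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class TimeSketch:
    """Illustrative stand-in only; not the T5X class."""
    duration: Optional[float] = None

    def replace_duration(self, duration):
        # Returns a new object; the original is left unchanged.
        return replace(self, duration=duration)


def train_step_stand_in():
    """Hypothetical placeholder for a compiled step that returns a metric."""
    return TimeSketch()  # duration is deliberately left unset in here


# Wall-clock time is measured by the caller, outside the "compiled" function.
start = time.perf_counter()
metric = train_step_stand_in()
elapsed = time.perf_counter() - start

timed_metric = metric.replace_duration(elapsed)
```

Keeping the clock in the caller means the value reflects the run itself rather than whatever was captured when the function was traced.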

class t5x.metrics.TimeRate(duration=None, numerator=None)

Computes the sum of a float-valued metric over a period of time.

Duration (the denominator) must be set using replace_duration. This is because JAX does not properly support time functions inside compiled functions. Calling time.time() inside a compiled function results in the stored time being the compilation time, not the run time.

See also documentation of Time and Metric.

compute()

Computes final metrics from intermediate values.

classmethod from_model_output(numerator, **_)

Initializes a TimeRate Metric from a float value (the numerator).

Parameters:

numerator – a float (numerator of the metric)

Returns:

A TimeRate object.

merge(other)

Returns Metric that is the accumulation of self and other.

Parameters:

other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().

Returns:

A new Metric that accumulates the value from both self and other.

replace(**updates)

Returns a new object replacing the specified fields with new values.
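A rate metric of this kind can be pictured as an accumulated numerator divided by a duration that is attached afterwards. The sketch below is a plain-Python stand-in with a hypothetical name; the merge rule (add numerators) and the "tokens" interpretation of the numerator are assumptions made for illustration.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class TimeRateSketch:
    """Illustrative stand-in only; not the T5X class."""
    numerator: float
    duration: Optional[float] = None

    @classmethod
    def from_model_output(cls, numerator):
        return cls(numerator=float(numerator))

    def merge(self, other):
        # Assumption: numerators accumulate; duration is set separately.
        return TimeRateSketch(self.numerator + other.numerator, self.duration)

    def replace_duration(self, duration):
        return replace(self, duration=duration)

    def compute(self):
        if self.duration is None:
            raise ValueError("duration must be set with replace_duration()")
        return self.numerator / self.duration


# Two batches processing 1200 and 800 tokens over 4 seconds in total.
tokens = TimeRateSketch.from_model_output(1200).merge(
    TimeRateSketch.from_model_output(800))
tokens_per_second = tokens.replace_duration(4.0).compute()
```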

t5x.metrics.create_metrics_dict(float_metrics_dict)

Converts a dict mapping names to floats (dict{str: float}) into a dict mapping the same names to Metric objects (dict{str: Metric}).
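The conversion can be sketched as a dict comprehension that wraps each float in a metric object. This page does not state which Metric type is used, so the sum-style wrapper below (SumSketch, a hypothetical stand-in) is an assumption for illustration only.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SumSketch:
    """Illustrative stand-in for a sum-style metric; not the T5X class."""
    total: float

    def compute(self):
        return self.total


def create_metrics_dict_sketch(float_metrics_dict):
    # Keep the keys; wrap each float value in a metric object.
    return {name: SumSketch(total=value)
            for name, value in float_metrics_dict.items()}


wrapped = create_metrics_dict_sketch({"loss": 0.25, "accuracy": 0.9})
```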

t5x.metrics.set_step_metrics_num_steps(metrics, num_steps)

Sets the number of steps for Step objects in a metrics pytree.

t5x.metrics.set_time_metrics_duration(metrics, duration)

Sets the duration for TimeRate objects in a metrics pytree.
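Both setters amount to walking a nested container and replacing one field on the objects of a given type. The sketch below shows the idea with nested dicts standing in for a pytree; StepLike, TimeLike, and the helper functions are hypothetical names, and no JAX tree utilities are used.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class StepLike:
    steps: int = 1


@dataclass(frozen=True)
class TimeLike:
    duration: Optional[float] = None


def map_leaves(tree, fn):
    """Applies fn to every non-dict leaf of a nested dict."""
    if isinstance(tree, dict):
        return {key: map_leaves(value, fn) for key, value in tree.items()}
    return fn(tree)


def set_steps(tree, num_steps):
    def update(leaf):
        # Only step-tracking leaves are touched; others pass through.
        return replace(leaf, steps=num_steps) if isinstance(leaf, StepLike) else leaf
    return map_leaves(tree, update)


def set_duration(tree, duration):
    def update(leaf):
        return replace(leaf, duration=duration) if isinstance(leaf, TimeLike) else leaf
    return map_leaves(tree, update)


metrics = {"loss_per_step": StepLike(), "timing": {"wall": TimeLike()}, "other": 3.0}
updated = set_duration(set_steps(metrics, 8), 2.5)
```

The input tree is not modified; each helper returns a new tree with updated copies of the matching leaves.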

t5x.metrics.shape_obj_to_defined_obj(obj)

Converts shapes in Metric to zero arrays.

obj should be an instance of a Metric subclass where each member variable is a ShapeDtypeStruct (from jax.eval_shape). Returns a new object of the same class where each member variable is an array of zeros with the same shape and type as specified by the corresponding ShapeDtypeStruct.

Parameters:

obj – a clu.metrics.Metric object where each member variable is a ShapeDtypeStruct (from jax.eval_shape)

Returns:

A Metric object with class variables initialized as zero arrays.
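The shape-to-zeros conversion can be illustrated without JAX by using nested lists in place of arrays. In this sketch ShapeSpec is a hypothetical stand-in for a shape and dtype description, MetricSketch is a hypothetical metric class, and zeros are built as nested Python lists rather than real arrays.

```python
from dataclasses import dataclass, fields, replace
from typing import Any, Tuple


@dataclass(frozen=True)
class ShapeSpec:
    """Hypothetical stand-in for a shape/dtype description."""
    shape: Tuple[int, ...]
    dtype: type = float


def zeros(spec):
    """Builds a nested list of zeros with the given shape and element type."""
    def build(dims):
        if not dims:
            return spec.dtype(0)  # scalar case: empty shape
        return [build(dims[1:]) for _ in range(dims[0])]
    return build(spec.shape)


@dataclass(frozen=True)
class MetricSketch:
    """Illustrative stand-in only; not a CLU or T5X class."""
    total: Any
    count: Any


def shape_to_defined(obj):
    # Replace every field's shape description with zeros of that shape.
    updates = {f.name: zeros(getattr(obj, f.name)) for f in fields(obj)}
    return replace(obj, **updates)


shaped = MetricSketch(total=ShapeSpec(()), count=ShapeSpec((2, 3), int))
defined = shape_to_defined(shaped)
```

The result has the same class as the input, with a scalar zero for total and a 2 by 3 grid of integer zeros for count.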