t5x.metrics package
T5X Metrics.
Defines Metric objects and collections used by T5X models. These objects are built on the CLU metrics library.
- class t5x.metrics.AveragePerStep(steps=1, total=None)
Represents a per-step average (the total divided by the number of steps).
See also documentation of Step.
- classmethod from_model_output(values, steps=1, **_)
Initializes an AveragePerStep Metric from an array of values (or a single value).
- Parameters:
values – array of values to sum (or a single value).
steps – number of steps, defaults to 1.
- Returns:
AveragePerStep object.
- merge(other)
Returns a Metric that is the accumulation of self and other.
- Parameters:
other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and will thus have an additional dimension compared with the dataclass returned by .from_model_output().
- Returns:
A new Metric that accumulates the value from both self and other.
- replace(**updates)
Returns a new object replacing the specified fields with new values.
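The accumulation semantics described above can be illustrated with a minimal, standard-library-only sketch. It assumes that from_model_output sums its values, that merge adds both totals and step counts, and that the result is read with a compute() step as in CLU metrics. The class name AveragePerStepSketch is hypothetical; the real T5X class is a flax dataclass over JAX arrays, not this code.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Union


# Hypothetical stand-in for t5x.metrics.AveragePerStep (plain Python floats,
# no JAX); field names follow the documented signature (steps=1, total=None).
@dataclass(frozen=True)
class AveragePerStepSketch:
    steps: int = 1
    total: Optional[float] = None

    @classmethod
    def from_model_output(cls, values: Union[float, Sequence[float]], steps: int = 1):
        # Sum an array of values, or accept a single scalar as the total.
        total = sum(values) if isinstance(values, (list, tuple)) else float(values)
        return cls(steps=steps, total=total)

    def merge(self, other: "AveragePerStepSketch") -> "AveragePerStepSketch":
        # Accumulate both the running total and the step count.
        return AveragePerStepSketch(steps=self.steps + other.steps,
                                    total=self.total + other.total)

    def compute(self) -> float:
        # Per-step average: total divided by number of steps.
        return self.total / self.steps


m = AveragePerStepSketch.from_model_output([1.0, 2.0, 3.0])
m = m.merge(AveragePerStepSketch.from_model_output(6.0))
print(m.compute())  # 12.0 total over 2 steps -> 6.0
```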
- class t5x.metrics.Step(steps=1)
Abstract class representing a per-step or steps-per metric.
Tracks the number of steps. The step count must be set manually using replace_steps, because microbatching may otherwise make the computation incorrect.
See also documentation of Metric.
- replace(**updates)
Returns a new object replacing the specified fields with new values.
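Why the step count must be reset manually can be shown with a small sketch. It assumes, hypothetically, that merging the metric once per microbatch adds the step counts, so one optimizer step split into four microbatches reports four steps until replace_steps corrects it. StepMetricSketch and its total field are illustrative only; replace is modeled with dataclasses.replace, which likewise returns a new object.

```python
from dataclasses import dataclass, replace


# Hypothetical Step-derived metric; `steps` mirrors Step(steps=1), `total`
# is an extra illustrative field.
@dataclass(frozen=True)
class StepMetricSketch:
    steps: int = 1
    total: float = 0.0

    def merge(self, other: "StepMetricSketch") -> "StepMetricSketch":
        return StepMetricSketch(self.steps + other.steps, self.total + other.total)

    def replace_steps(self, steps: int) -> "StepMetricSketch":
        # Returns a new object; the original metric is left unchanged.
        return replace(self, steps=steps)


# One training step split into 4 microbatches: merging counts 4 "steps".
merged = StepMetricSketch(total=1.0)
for _ in range(3):
    merged = merged.merge(StepMetricSketch(total=1.0))
print(merged.steps)               # 4, although only one optimizer step ran
fixed = merged.replace_steps(1)
print(fixed.steps, merged.steps)  # 1 4
```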
- class t5x.metrics.StepsPerTime(duration=None, steps=1)
Represents a metric computed as number of steps per time.
See also documentation of Step.
- classmethod from_model_output(steps=1, **_)
Initializes a StepsPerTime Metric.
- Parameters:
steps – number of steps, defaults to 1.
- Returns:
StepsPerTime object.
- merge(other)
Returns a Metric that is the accumulation of self and other.
- Parameters:
other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and will thus have an additional dimension compared with the dataclass returned by .from_model_output().
- Returns:
A new Metric that accumulates the value from both self and other.
- replace(**updates)
Returns a new object replacing the specified fields with new values.
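A minimal sketch of a steps-per-time metric, assuming that merge accumulates step counts, that the duration is attached afterwards on the host, and that the value is steps divided by duration. StepsPerTimeSketch is a hypothetical plain-Python class, not the T5X implementation.

```python
from dataclasses import dataclass, replace
from typing import Optional


# Hypothetical stand-in mirroring StepsPerTime(duration=None, steps=1).
@dataclass(frozen=True)
class StepsPerTimeSketch:
    duration: Optional[float] = None
    steps: int = 1

    @classmethod
    def from_model_output(cls, steps: int = 1) -> "StepsPerTimeSketch":
        return cls(steps=steps)

    def merge(self, other: "StepsPerTimeSketch") -> "StepsPerTimeSketch":
        # Step counts accumulate; the duration is set later, outside the step.
        return StepsPerTimeSketch(steps=self.steps + other.steps)

    def compute(self) -> float:
        if self.duration is None:
            raise ValueError("duration must be set before computing")
        return self.steps / self.duration


m = StepsPerTimeSketch.from_model_output()
for _ in range(9):
    m = m.merge(StepsPerTimeSketch.from_model_output())
m = replace(m, duration=2.0)  # e.g. seconds measured on the host
print(m.compute())            # 10 steps / 2.0 s -> 5.0
```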
- class t5x.metrics.Sum(total)
Computes the sum of a scalar or a batch of tensors.
See also documentation of Metric.
- classmethod from_model_output(values, **_)
Initializes a Sum Metric from an array of values (or a single value).
- Parameters:
values – array of values to sum (or a single value).
- Returns:
A Sum object.
- merge(other)
Returns a Metric that is the accumulation of self and other.
- Parameters:
other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and will thus have an additional dimension compared with the dataclass returned by .from_model_output().
- Returns:
A new Metric that accumulates the value from both self and other.
- replace(**updates)
Returns a new object replacing the specified fields with new values.
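A minimal sketch of sum-style accumulation, assuming from_model_output sums an array (or takes a scalar as-is) and merge adds the totals. SumSketch is hypothetical and uses plain floats instead of JAX arrays.

```python
from dataclasses import dataclass
from typing import Sequence, Union


# Hypothetical stand-in mirroring Sum(total).
@dataclass(frozen=True)
class SumSketch:
    total: float

    @classmethod
    def from_model_output(cls, values: Union[float, Sequence[float]]) -> "SumSketch":
        # Sum an array of values, or accept a single scalar.
        total = sum(values) if isinstance(values, (list, tuple)) else float(values)
        return cls(total=total)

    def merge(self, other: "SumSketch") -> "SumSketch":
        return SumSketch(self.total + other.total)

    def compute(self) -> float:
        return self.total


loss = SumSketch.from_model_output([0.5, 0.25]).merge(SumSketch.from_model_output(1.25))
print(loss.compute())  # 0.75 + 1.25 -> 2.0
```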
- class t5x.metrics.Time(duration=None)
Computes the sum of a float-valued metric over a period of time.
The duration (the denominator) must be set manually, because JAX does not properly support time functions inside compiled functions: calling time.time() inside a compiled function stores the compilation time rather than the run time.
See also documentation of Metric.
- merge(other)
Returns a Metric that is the accumulation of self and other.
- Parameters:
other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and will thus have an additional dimension compared with the dataclass returned by .from_model_output().
- Returns:
A new Metric that accumulates the value from both self and other.
- replace(**updates)
Returns a new object replacing the specified fields with new values.
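Because the duration cannot be measured inside a compiled function, one pattern is to time the call on the host and attach the result afterwards. This is a sketch under assumptions: train_step is a hypothetical plain-Python placeholder for a compiled step, and TimeSketch is a hypothetical stand-in for Time whose value is simply its duration.

```python
import time
from dataclasses import dataclass, replace
from typing import Optional


# Hypothetical stand-in mirroring Time(duration=None).
@dataclass(frozen=True)
class TimeSketch:
    duration: Optional[float] = None

    def compute(self) -> float:
        if self.duration is None:
            raise ValueError("duration must be set manually before computing")
        return self.duration


def train_step() -> int:
    # Hypothetical placeholder for a compiled step. Under JAX compilation,
    # a time.time() call in here would be evaluated at compile time only.
    return sum(i * i for i in range(10_000))


start = time.perf_counter()
train_step()
elapsed = time.perf_counter() - start  # measured in ordinary Python on the host
metric = replace(TimeSketch(), duration=elapsed)
print(metric.compute() >= 0.0)  # True
```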
- class t5x.metrics.TimeRate(duration=None, numerator=None)
Computes the sum of a float-valued metric over a period of time.
The duration (the denominator) must be set using replace_duration, because JAX does not properly support time functions inside compiled functions: calling time.time() inside a compiled function stores the compilation time rather than the run time.
See also documentation of Time and Metric.
- classmethod from_model_output(numerator, **_)
Initializes a TimeRate Metric from a float value (the numerator).
- Parameters:
numerator – a float (the numerator of the metric).
- Returns:
A TimeRate object.
- merge(other)
Returns a Metric that is the accumulation of self and other.
- Parameters:
other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and will thus have an additional dimension compared with the dataclass returned by .from_model_output().
- Returns:
A new Metric that accumulates the value from both self and other.
- replace(**updates)
Returns a new object replacing the specified fields with new values.
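A minimal sketch of a rate metric such as tokens per second, assuming merge adds numerators, replace_duration sets the denominator once after merging, and the value is numerator divided by duration. TimeRateSketch is hypothetical; only the field names follow the documented signature.

```python
from dataclasses import dataclass, replace
from typing import Optional


# Hypothetical stand-in mirroring TimeRate(duration=None, numerator=None).
@dataclass(frozen=True)
class TimeRateSketch:
    duration: Optional[float] = None
    numerator: Optional[float] = None

    @classmethod
    def from_model_output(cls, numerator: float) -> "TimeRateSketch":
        return cls(numerator=numerator)

    def merge(self, other: "TimeRateSketch") -> "TimeRateSketch":
        # Numerators accumulate; the duration is attached once, after merging.
        return TimeRateSketch(numerator=self.numerator + other.numerator)

    def replace_duration(self, duration: float) -> "TimeRateSketch":
        return replace(self, duration=duration)

    def compute(self) -> float:
        if self.duration is None:
            raise ValueError("call replace_duration before compute")
        return self.numerator / self.duration


tokens = TimeRateSketch.from_model_output(4096.0).merge(
    TimeRateSketch.from_model_output(4096.0))
print(tokens.replace_duration(4.0).compute())  # 8192 tokens / 4 s -> 2048.0
```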
- t5x.metrics.create_metrics_dict(float_metrics_dict)
Converts a dict mapping names to floats (dict{str: float}) into a dict mapping the same names to Metric objects (dict{str: Metric}).
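The conversion can be sketched as a dict comprehension. This assumes, without confirmation from this page, that each float is wrapped in a Sum-style metric under the same key; SumSketch and create_metrics_dict_sketch are hypothetical names.

```python
from dataclasses import dataclass
from typing import Dict


# Hypothetical stand-in for a Sum metric.
@dataclass(frozen=True)
class SumSketch:
    total: float


def create_metrics_dict_sketch(float_metrics_dict: Dict[str, float]) -> Dict[str, SumSketch]:
    # Assumption: every float becomes a Sum-style metric with the same name.
    return {name: SumSketch(value) for name, value in float_metrics_dict.items()}


metrics = create_metrics_dict_sketch({"loss": 2.5, "accuracy": 0.75})
print(metrics["loss"].total)  # 2.5
```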
- t5x.metrics.set_step_metrics_num_steps(metrics, num_steps)
Sets the number of steps for every Step object in the metrics pytree.
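A minimal sketch of the selective update: only Step-like metrics receive the new step count and everything else passes through unchanged. A flat dict stands in for the metrics pytree (the real function operates on a JAX pytree), and StepSketch, SumSketch, and set_step_metrics_num_steps_sketch are hypothetical.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class StepSketch:
    steps: int = 1


@dataclass(frozen=True)
class SumSketch:
    total: float


def set_step_metrics_num_steps_sketch(metrics, num_steps):
    # Only Step-like metrics get a new step count; others are returned as-is.
    return {
        name: replace(m, steps=num_steps) if isinstance(m, StepSketch) else m
        for name, m in metrics.items()
    }


metrics = {"avg_loss": StepSketch(steps=4), "loss": SumSketch(3.0)}
updated = set_step_metrics_num_steps_sketch(metrics, 1)
print(updated["avg_loss"].steps, updated["loss"].total)  # 1 3.0
```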
- t5x.metrics.set_time_metrics_duration(metrics, duration)
Sets the duration for every TimeRate object in the metrics pytree.
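The same selective-update pattern applies to durations. In this sketch a flat dict again stands in for the metrics pytree, and TimeRateSketch and set_time_metrics_duration_sketch are hypothetical; non-TimeRate entries are left untouched.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class TimeRateSketch:
    duration: Optional[float] = None
    numerator: Optional[float] = None


def set_time_metrics_duration_sketch(metrics, duration):
    # Attach the host-measured duration to TimeRate-like metrics only.
    return {
        name: replace(m, duration=duration) if isinstance(m, TimeRateSketch) else m
        for name, m in metrics.items()
    }


metrics = {"tokens_per_sec": TimeRateSketch(numerator=1000.0), "step": 7}
updated = set_time_metrics_duration_sketch(metrics, 0.5)
print(updated["tokens_per_sec"].duration, updated["step"])  # 0.5 7
```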
- t5x.metrics.shape_obj_to_defined_obj(obj)
Converts shapes in Metric to zero arrays.
obj should be an instance of a Metric subclass in which each member variable is a ShapeDtypeStruct (from jax.eval_shape). The function returns a new object of the same class in which each member variable is an array of zeros with the same shape and dtype as the corresponding ShapeDtypeStruct.
- Parameters:
obj – a clu.metrics.Metric object where each member variable is a ShapeDtypeStruct (from jax.eval_shape).
- Returns:
A Metric object with member variables initialized as zero arrays.
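The shape-to-zeros conversion can be sketched without JAX. Here ShapeSpec is a hypothetical stand-in for jax.ShapeDtypeStruct (only its shape and dtype attributes are used), nested Python lists play the role of zero-filled arrays, and AverageSketch and shape_obj_to_defined_obj_sketch are illustrative names, not the T5X code.

```python
from dataclasses import dataclass, fields
from typing import Any, Tuple


@dataclass(frozen=True)
class ShapeSpec:
    # Stand-in for jax.ShapeDtypeStruct: only .shape and .dtype are used.
    shape: Tuple[int, ...]
    dtype: type


def zeros(shape: Tuple[int, ...], dtype: type) -> Any:
    # Nested lists stand in for a zero-filled array of the given shape.
    if not shape:
        return dtype(0)
    return [zeros(shape[1:], dtype) for _ in range(shape[0])]


def shape_obj_to_defined_obj_sketch(obj: Any) -> Any:
    # Rebuild an object of the same class, replacing each ShapeSpec field with zeros.
    values = {}
    for f in fields(obj):
        attr = getattr(obj, f.name)
        values[f.name] = zeros(attr.shape, attr.dtype) if isinstance(attr, ShapeSpec) else attr
    return type(obj)(**values)


@dataclass(frozen=True)
class AverageSketch:
    total: Any
    count: Any


abstract = AverageSketch(total=ShapeSpec((), float), count=ShapeSpec((2,), int))
print(shape_obj_to_defined_obj_sketch(abstract))
# AverageSketch(total=0.0, count=[0, 0])
```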