Source code for t5x.eval

# Copyright 2023 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint:disable=line-too-long
# pyformat: disable
r"""Runs training- and inference-evaluation on a T5X-compatible model.

"""

# pyformat: enable
# pylint:enable=line-too-long
import functools
import os
import re
from typing import Callable, Collection, Mapping, Optional, Sequence, Set, Tuple, Type

# pylint:disable=g-import-not-at-top
# TODO(adarob): Re-enable once users are notified and tests are updated.
os.environ['FLAX_LAZY_RNG'] = 'no'
from absl import logging
from clu import metric_writers
import jax
import seqio
from t5x import checkpoints
from t5x import gin_utils
from t5x import models
from t5x import partitioning
from t5x import train_state as train_state_lib
from t5x import trainer as trainer_lib
from t5x import utils
import tensorflow as tf
from tensorflow.io import gfile
from typing_extensions import Protocol
# pylint:enable=g-import-not-at-top

# Extra gin search root: the directory two levels above this module file,
# i.e. the directory that contains this module's package. Gin files given by
# relative path are also looked up there.
_DEFAULT_GIN_SEARCH_PATHS = [
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
]


class SummarizeConfigFn(Protocol):
  """Structural type for callables that summarize the active configuration.

  A conforming callable receives the model directory, an optional summary
  writer and a step number, and writes a summary of the configuration. It
  returns nothing.
  """

  def __call__(
      self,
      model_dir: str,
      summary_writer: Optional[metric_writers.SummaryWriter],
      step: int,
  ) -> None:
    """Writes a configuration summary for `model_dir` at `step`."""
class InferenceEvaluator:
  """Runs evaluation of the model against a given SeqIo task."""

  def __init__(
      self,
      infer_eval_dataset_cfg: utils.DatasetConfig,
      inference_evaluator_cls: utils.EvaluatorConstructor,
      model: models.BaseModel,
      partitioner: partitioning.BasePartitioner,
      log_dir: Optional[str] = None,
      verify_matching_vocabs_fn: Optional[
          Callable[[utils.DatasetConfig, models.BaseTransformerModel], None]
      ] = utils.verify_matching_vocabs,
  ):
    """Constructs inference evaluator.

    Args:
      infer_eval_dataset_cfg: Specification for the dataset to evaluate with
        using the inference metrics (e.g., uses sampled decoding). Must not be
        None: its fields (task name, split, caching, seed, feature lengths,
        batch size) are read unconditionally.
      inference_evaluator_cls: seqio.Evaluator class to use for inference
        evaluation, potentially with bound configuration args.
      model: Model to be evaluated.
      partitioner: the partitioner to use.
      log_dir: Parent directory to log evaluation results. When set, results
        are logged under `<log_dir>/inference_eval`. When empty or None, the
        evaluator is constructed with `logger_cls=()`, so no loggers are used.
      verify_matching_vocabs_fn: Function to validate whether the task
        vocabulary matches the model vocabulary. Should raise an exception on
        error. Only invoked when `model` is a `models.BaseTransformerModel`.
        Pass None to skip the check.
    """
    if verify_matching_vocabs_fn and isinstance(
        model, models.BaseTransformerModel
    ):
      verify_matching_vocabs_fn(infer_eval_dataset_cfg, model)

    self._model = model
    self._partitioner = partitioner
    self._infer_eval_dataset_cfg = infer_eval_dataset_cfg

    kwargs = {}
    if log_dir:
      kwargs['log_dir'] = os.path.join(log_dir, 'inference_eval')
    else:
      # Disable loggers if log dir is not provided.
      kwargs['logger_cls'] = ()
    self._seqio_evaluator = inference_evaluator_cls(
        mixture_or_task_name=infer_eval_dataset_cfg.mixture_or_task_name,
        feature_converter=model.FEATURE_CONVERTER_CLS(pack=False),
        eval_split=infer_eval_dataset_cfg.split,
        use_cached=infer_eval_dataset_cfg.use_cached,
        seed=infer_eval_dataset_cfg.seed,
        sequence_length=infer_eval_dataset_cfg.task_feature_lengths,
        use_memory_cache=infer_eval_dataset_cfg.use_memory_cache,
        **kwargs,
    )

    # Lazily initialized upon the first `evaluate` call, because building them
    # needs the `train_state_axes` passed to `evaluate`.
    self._predict_fn = None
    self._predict_with_aux_fn = None
    self._score_fn = None

  @property
  def model_feature_shapes(self) -> Mapping[str, Tuple[int, ...]]:
    """Model feature shapes reported by the underlying SeqIO evaluator."""
    return self._seqio_evaluator.model_feature_shapes

  @property
  def eval_tasks(self) -> Sequence[seqio.Task]:
    """Tasks evaluated by the underlying SeqIO evaluator."""
    return self._seqio_evaluator.eval_tasks

  def close(self):
    """Closes the underlying SeqIO evaluator."""
    self._seqio_evaluator.close()

  def evaluate(
      self,
      train_state: train_state_lib.TrainState,
      train_state_axes: train_state_lib.TrainState,
  ) -> seqio.evaluation.AllMetricsFuture:
    """Runs the prediction based inference eval.

    The predict, predict-with-aux and score functions are built on the first
    call and reused by every later call.

    Args:
      train_state: Training state to run evaluation of.
      train_state_axes: partitioning info for the train state to be used. Only
        consulted on the first call, when the inference functions are built.

    Returns:
      The `AllMetricsFuture` produced by the SeqIO evaluator for the inference
      eval metrics. `compute_metrics` is passed as True only on process 0.
    """
    if self._predict_fn is None:
      # All three inference functions share batch size, axes and partitioner.
      make_infer_fn = functools.partial(
          utils.get_infer_fn,
          batch_size=self._infer_eval_dataset_cfg.batch_size,
          train_state_axes=train_state_axes,
          partitioner=self._partitioner,
      )
      self._predict_fn = make_infer_fn(infer_step=self._model.predict_batch)
      self._predict_with_aux_fn = make_infer_fn(
          infer_step=self._model.predict_batch_with_aux
      )
      self._score_fn = make_infer_fn(infer_step=self._model.score_batch)

    # Both decoding paths are given the fixed key `jax.random.PRNGKey(0)`.
    all_metrics, _ = self._seqio_evaluator.evaluate(
        compute_metrics=jax.process_index() == 0,
        step=int(utils.get_local_data(train_state.step)),
        predict_fn=functools.partial(
            self._predict_fn, train_state=train_state, rng=jax.random.PRNGKey(0)
        ),
        score_fn=functools.partial(self._score_fn, train_state=train_state),
        predict_with_aux_fn=functools.partial(
            self._predict_with_aux_fn,
            train_state=train_state,
            rng=jax.random.PRNGKey(0),
        ),
    )
    return all_metrics
def _sorted_ckpt_paths(ckpt_paths: Collection[str]) -> Sequence[str]:
  """Returns `ckpt_paths` sorted by ascending checkpoint step.

  The step is the integer at the end of each path. It may be prefixed with
  "checkpoint_" or "model.ckpt-" and followed by at most one trailing "/".

  Raises:
    ValueError: if a path does not end in a step number.
  """

  def _extract_ckpt_step(ckpt_path: str) -> int:
    # Steps may be prefixed with "checkpoint_", "model.ckpt-" or nothing.
    match = re.search(r'(checkpoint_|model\.ckpt-)?(\d+)/?$', ckpt_path)
    if match is None:
      raise ValueError(f'Invalid checkpoint path: {ckpt_path}')
    return int(match.group(2))

  return sorted(ckpt_paths, key=_extract_ckpt_step)


def _load_evaluated_ckpt_paths(eval_ckpt_path: str) -> Set[str]:
  """Reads the set of checkpoint paths recorded as already evaluated.

  `evaluate` writes this file as newline-separated paths. It is therefore split
  on line breaks only, not on arbitrary whitespace, so that paths containing
  spaces round-trip intact. Blank lines and surrounding whitespace are ignored.

  Returns:
    The recorded paths, or an empty set if `eval_ckpt_path` does not exist.
  """
  if not gfile.exists(eval_ckpt_path):
    return set()
  with gfile.GFile(eval_ckpt_path, 'r') as f:
    return {line.strip() for line in f.read().splitlines() if line.strip()}
def evaluate(
    *,
    model: models.BaseTransformerModel,
    dataset_cfg: utils.DatasetConfig,
    restore_checkpoint_cfg: utils.RestoreCheckpointConfig,
    partitioner: partitioning.BasePartitioner,
    output_dir: str,
    inference_evaluator_cls: Optional[
        utils.EvaluatorConstructor
    ] = seqio.Evaluator,
    training_evaluator_cls: Optional[Type[trainer_lib.Trainer]] = None,
    summarize_config_fn: SummarizeConfigFn = gin_utils.summarize_gin_config,
    train_state_initializer_cls: Type[
        utils.TrainStateInitializer
    ] = utils.TrainStateInitializer,
    train_eval_get_dataset_fn: utils.GetEvalDatasetCallable = utils.get_training_eval_datasets,
    fallback_init_rng: Optional[int] = None,
    use_orbax: bool = False,
):
  """Evaluation function.

  Builds an `InferenceEvaluator` for `dataset_cfg` and then processes each
  checkpoint selected by `restore_checkpoint_cfg`. For every checkpoint it:
    1. restores the train state without optimizer state;
    2. runs training eval, if `training_evaluator_cls` is given;
    3. runs inference eval and waits for the metrics.
  On process 0, the path of each evaluated checkpoint is appended to the record
  file `<output_dir>/eval.<mixture_or_task_name>.ckpt`. In restore mode 'all',
  checkpoints already listed in that file are skipped.

  Note: this mutates `restore_checkpoint_cfg`. It always sets `strict=False`.
  When skipping already-evaluated checkpoints it also rewrites `mode` to
  'specific' and `path` to the remaining checkpoint paths.

  Args:
    model: The model object to use for inference.
    dataset_cfg: Specification for the dataset to infer based on.
    restore_checkpoint_cfg: Specification for the model parameter checkpoint to
      load.
    partitioner: Partitioner for the model parameters and data across devices.
    output_dir: Path to directory to write temporary files and final results.
    inference_evaluator_cls: seqio.Evaluator class to use for inference
      evaluation, potentially with bound configuration args.
    training_evaluator_cls: an optional Trainer class to use for training
      evaluation, potentially with bound configuration args.
    summarize_config_fn: A function that takes in the model directory, an
      optional SummaryWriter, and the step number, and writes a summary of the
      configuration. SummaryWriter will be None in most cases.
    train_state_initializer_cls: t5x.utils.TrainStateInitializer class for
      initializing partitioned TrainState from checkpoints or scratch.
    train_eval_get_dataset_fn: Optional callable used to get the train-eval
      datasets based on the DatasetConfig and shard information. If missing, it
      defaults to `utils.get_training_eval_datasets`.
    fallback_init_rng: A random seed used for parameter initialization during
      model re-loading when utils.RestoreCheckpointConfig.fallback_to_scratch is
      set to True. If None, parameter initialization is not allowed during model
      loading and having fallback_to_scratch enabled will result in an error.
    use_orbax: if True, uses Orbax for checkpointing. Experimental feature.

  Raises:
    ValueError: in any of these cases:
      * the inference evaluator has no eval tasks for `dataset_cfg`;
      * in restore mode 'all' with an existing record file, a checkpoint path
        is not a directory;
      * a checkpoint fails to restore (the restored train state is None).
  """
  jax.monitoring.record_event('/jax/t5x/evaluate/beacon')
  logging.info('Process ID: %d', jax.process_index())
  # Importing the dataset module presumably registers its SeqIO tasks/mixtures
  # before they are looked up by name below.
  if dataset_cfg.module:
    utils.import_module(dataset_cfg.module)
  batch_size = dataset_cfg.batch_size

  summarize_config_fn(model_dir=output_dir, summary_writer=None, step=0)

  evaluator = InferenceEvaluator(
      dataset_cfg,
      inference_evaluator_cls,
      model,
      partitioner,
      log_dir=output_dir,
  )
  if not evaluator.eval_tasks:
    raise ValueError(
        f"'{dataset_cfg.mixture_or_task_name}' has no metrics for evaluation, "
        "or this mixture/task doesn't have provided split."
    )

  # ----------------------------------------------------------------------------
  # T5X model loading.
  # ----------------------------------------------------------------------------

  # Initialize optimizer from the existing checkpoint.
  # Prepend the batch dimension to each per-example feature shape.
  input_shapes = {
      k: (batch_size,) + s for k, s in evaluator.model_feature_shapes.items()
  }

  train_state_initializer = train_state_initializer_cls(
      optimizer_def=None,  # Do not load optimizer state.
      init_fn=model.get_initial_variables,
      input_shapes=input_shapes,
      partitioner=partitioner,
  )
  train_state_axes = train_state_initializer.train_state_axes
  # Log the variable shapes information and write to a file.
  log_file = os.path.join(output_dir, 'model-info.txt')
  utils.log_model_info(
      log_file, train_state_initializer.global_train_state_shape, partitioner
  )

  # Training eval is optional. Its datasets are sharded per this process's
  # data layout.
  if training_evaluator_cls:
    data_layout = partitioner.get_data_layout(dataset_cfg.batch_size)
    train_eval_datasets = train_eval_get_dataset_fn(  # pytype:disable=missing-parameter
        dataset_cfg,
        data_layout.shard_id,
        data_layout.num_shards,
        feature_converter_cls=model.FEATURE_CONVERTER_CLS,
    )
    train_evaluator = training_evaluator_cls(  # pytype:disable=wrong-arg-types
        model=model,
        train_state=None,  # Will replace later.
        partitioner=partitioner,
        train_state_axes=train_state_axes,
        eval_names=train_eval_datasets.keys(),
        summary_dir=output_dir,
        rng=jax.random.PRNGKey(0),  # unused
        learning_rate_fn=None,  # unused
        num_microbatches=None,  # unused
    )

  def _maybe_run_train_eval(train_state: train_state_lib.TrainState):
    """Runs training eval on `train_state` if a training evaluator exists."""
    if training_evaluator_cls:
      train_evaluator.train_state = train_state
      # tf.data datasets are converted to numpy iterators; any other dataset
      # objects are passed through unchanged.
      train_evaluator.eval(
          {
              task: (
                  ds.as_numpy_iterator()
                  if isinstance(ds, tf.data.Dataset)
                  else ds
              )
              for task, ds in train_eval_datasets.items()
          }
      )

  # Disable strictness since we are dropping the optimizer state.
  restore_checkpoint_cfg.strict = False

  # Skip checkpoints that have already been evaluated.
  eval_ckpt_path = os.path.join(
      output_dir, f'eval.{dataset_cfg.mixture_or_task_name}.ckpt'
  )
  if restore_checkpoint_cfg.mode == 'all' and gfile.exists(eval_ckpt_path):
    logging.info('Found evaluation checkpoint: %s', eval_ckpt_path)
    # `path` may be a single directory or a sequence of directories.
    ckpt_dirs = (
        [restore_checkpoint_cfg.path]
        if isinstance(restore_checkpoint_cfg.path, str)
        else restore_checkpoint_cfg.path
    )
    ckpt_paths = set()
    for ckpt_dir in ckpt_dirs:
      if not gfile.isdir(ckpt_dir):
        raise ValueError(
            f"Checkpoint path '{ckpt_dir}' must be a valid directory when "
            "using restore mode 'all'."
        )
      ckpt_paths.update(
          checkpoints.get_checkpoint_dir(ckpt_dir, step)
          for step in checkpoints.all_steps(ckpt_dir)
      )
    evaluated_ckpt_paths = _load_evaluated_ckpt_paths(eval_ckpt_path)
    logging.info(
        'Skipping evaluated checkpoints:\n %s',
        '\n '.join(_sorted_ckpt_paths(ckpt_paths & evaluated_ckpt_paths)),
    )
    # Convert 'all' into an explicit, step-ordered list of remaining paths.
    restore_checkpoint_cfg.mode = 'specific'
    restore_checkpoint_cfg.path = _sorted_ckpt_paths(
        ckpt_paths - evaluated_ckpt_paths
    )

  # The integer seed is turned into a JAX PRNG key before being handed to the
  # restore helper.
  if fallback_init_rng is not None:
    fallback_init_rng = jax.random.PRNGKey(fallback_init_rng)
  restore_cfg, ckpt_paths = utils.get_first_valid_restore_config_and_paths(
      [restore_checkpoint_cfg]
  )
  for ckpt_path in ckpt_paths:
    train_state, _ = utils.create_checkpoint_manager_and_restore(
        train_state_initializer,
        partitioner,
        restore_cfg,
        ckpt_path,
        fallback_init_rng,
        use_orbax=use_orbax,
    )
    if train_state is None:
      raise ValueError('Failed to restore checkpoint.')

    # ----------------------------------------------------------------------------
    # Main evaluation loop
    # ----------------------------------------------------------------------------

    # Run final evaluation (with decoding) on the full eval dataset.
    host_step = int(utils.get_local_data(train_state.step))
    _maybe_run_train_eval(train_state)
    all_metrics = evaluator.evaluate(train_state, train_state_axes)
    all_metrics.result()  # Ensure metrics are finished being computed.
    # Wait until computations are done before continuing.
    utils.sync_global_devices(f'step_{host_step}:complete')

    # Only process 0 updates the record of evaluated checkpoints.
    if jax.process_index() == 0:
      # Read/write/replace rather than append to avoid filesystem issue.
      evaluated_ckpt_paths = _load_evaluated_ckpt_paths(eval_ckpt_path)
      evaluated_ckpt_paths.add(ckpt_path)
      with gfile.GFile(eval_ckpt_path, 'w') as f:
        f.write('\n'.join(_sorted_ckpt_paths(evaluated_ckpt_paths)))
  logging.info('Finished.')
if __name__ == '__main__':
  # pylint:disable=g-import-not-at-top
  from absl import app
  from absl import flags
  import fiddle as fdl
  import gin
  from t5x import config_utils

  FLAGS = flags.FLAGS

  flags.DEFINE_multi_string(
      'gin_file',
      default=None,
      help=(
          'Path to gin configuration file. Multiple paths may be passed and '
          'will be imported in the given order, with later configurations '
          'overriding earlier ones.'
      ),
  )

  flags.DEFINE_multi_string(
      'gin_bindings', default=[], help='Individual gin bindings.'
  )

  flags.DEFINE_list(
      'gin_search_paths',
      default=['.'],
      help=(
          'Comma-separated list of gin config path prefixes to be prepended '
          'to suffixes given via `--gin_file`. If a file appears in more '
          'than one of these locations, only the first prefix that produces '
          'a valid path for each suffix will be used.'
      ),
  )

  flags.DEFINE_string(
      'tfds_data_dir',
      None,
      'If set, this directory will be used to store datasets prepared by '
      'TensorFlow Datasets that are not available in the public TFDS GCS '
      'bucket. Note that this flag overrides the `tfds_data_dir` attribute of '
      'all `Task`s.',
  )

  def main(argv: Sequence[str]):
    """Wrapper for pdb post mortems."""
    _main(argv)

  def _main(argv: Sequence[str]):
    """True main function."""
    if len(argv) > 1:
      raise app.UsageError('Too many command-line arguments.')

    if FLAGS.tfds_data_dir:
      seqio.set_tfds_data_dir_override(FLAGS.tfds_data_dir)

    # Configuration comes either from Fiddle or from gin, never both.
    if config_utils.using_fdl():
      config = config_utils.config_with_fiddle(evaluate)
      evaluate_using_fiddle = fdl.build(config)
      evaluate_using_fiddle()
    else:
      # Create gin-configurable version of `evaluate`.
      evaluate_using_gin = gin.configurable(evaluate)

      gin_utils.parse_gin_flags(
          # User-provided gin paths take precedence if relative paths conflict.
          FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS,
          FLAGS.gin_file,
          FLAGS.gin_bindings,
      )
      evaluate_using_gin()

  config_utils.run(main)