Source code for t5x.interactive_model

# Copyright 2023 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""InteractiveModel class for use in T5X Colabs.

The InteractiveModel can be used to run training, inference, and evaluation on
natural text inputs and targets.

"""

import abc
from collections.abc import Mapping, Sequence
import enum
import functools
import inspect
import itertools
import logging
import os
import re
from typing import Any, Callable, Iterator, Optional, Tuple, Union

import clu.data.dataset_iterator
import jax
from jax import random
from jax.experimental import multihost_utils
import numpy as np
import seqio
from t5x import checkpoints
from t5x import models
from t5x import partitioning
from t5x import trainer as trainer_lib
from t5x import utils
from t5x.infer import _extract_tokens_and_aux_values
from t5x.infer import _Inferences
import tensorflow as tf
import tensorflow_datasets as tfds

BatchesType = Union[
    Sequence[Mapping[str, str]], Sequence[Sequence[Mapping[str, str]]]
]


class InferenceType(enum.Enum):
  """Inference modes accepted by `InteractiveModel.infer_with_preprocessors`."""

  # Runs the model's `predict_batch_with_aux` and decodes the predictions.
  PREDICT_WITH_AUX = 1
  # Runs the model's `score_batch` on the provided targets.
  SCORE = 2
class T5XScriptType(enum.Enum):
  """Enumerates the kinds of T5X scripts (finetune, infer, eval, pretrain)."""

  FINETUNING = 1
  INFERENCE = 2
  EVALUATION = 3
  PRETRAINING = 4
class InteractiveModel(abc.ABC):
  """Wrapper around T5X components to enable interactive train/infer/eval."""

  def __init__(
      self,
      batch_size: int,
      task_feature_lengths: Mapping[str, int],
      output_dir: str,
      partitioner: partitioning.BasePartitioner,
      model: models.BaseTransformerModel,
      dtype: Optional[str],
      restore_mode: str,
      checkpoint_path: str,
      input_shapes: Mapping[str, utils.Array],
      input_types: Optional[Mapping[str, utils.DType]] = None,
      init_random_seed: int = 42,
      add_eos: bool = True,
      eval_names: Optional[Sequence[str]] = None,
  ):
    """Init function.

    Configures the output directory, RNGs, and partitioner given the provided
    arguments.

    Args:
      batch_size: number of examples per batch for training, inference, and
        evaluation.
      task_feature_lengths: dictionary mapping feature key to maximum length
        (int) for that feature. If feature is longer than this length after
        preprocessing, the feature will be truncated. May be set to None to
        avoid truncation.
      output_dir: Path to directory where we will write temporary files and
        final results.
      partitioner: the partitioner that defines how we divide and replicate
        machine learning model parameters, activations, and data across the
        accelerator devices (TPU/GPU). See
        https://github.com/google-research/t5x/blob/main/docs/usage.md/partitioning
        for details.
      model: the model object to use for training, inference, and evaluation.
      dtype: The dtype to restore ('float32' or 'bfloat16'), or None to load as
        saved.
      restore_mode: One of 'specific', 'latest', or 'all'. `specific` loads the
        checkpoint specified by `path`. `latest` loads the most recent
        checkpoint in the directory specified by `path`. `all` sequentially
        loads all of checkpoints in the directory `path`.
      checkpoint_path: Path(s) to checkpoint to restore from or directory
        (depending on `restore_mode`).
      input_shapes: a mapping from key to array shape for each feature in the
        global (unsharded) input batch.
      input_types: a mapping from key to array type for each feature in the
        global (unsharded) input batch. If not provided, the type is assumed to
        be `jnp.float32`.
      init_random_seed: the random seed used to initialize all RNGs.
      add_eos: whether or not to add the EOS token to inputs/targets.
      eval_names: names of evaluation datasets, which must match the keys of the
        mapping passed to trainer's `eval` method.

    Raises:
      ValueError: the partitioner has an incorrect submesh, or the checkpoint
        restore function returned a sequence of TrainStates, when it should have
        returned a single TrainState.
    """
    self._batch_size = batch_size
    self._task_feature_lengths = task_feature_lengths
    self._cached_infer_fns = {}
    # No training has run yet; set by `train_step_from_batch_iterator`.
    self._train_summary = None

    # --------------------------------------------------------------------------
    # Configure the output directory
    # --------------------------------------------------------------------------
    self._output_dir = output_dir
    # Remove double-slashes in directory path to avoid inconsistencies.
    self._output_dir = re.sub(r"(?<!gs:)([\/]{2,})", "/", self._output_dir)
    # Use gfile so that nested directories and `gs://` paths (which the regex
    # above explicitly preserves) can be created.
    if not tf.io.gfile.exists(self._output_dir):
      tf.io.gfile.makedirs(self._output_dir)

    # --------------------------------------------------------------------------
    # Initialize RNGs
    # --------------------------------------------------------------------------
    self._init_random_seed = init_random_seed
    random_seed = multihost_utils.broadcast_one_to_all(
        np.int32(self._init_random_seed)
    )
    utils.set_hardware_rng_ops()
    rng = random.PRNGKey(random_seed)
    self._init_rng, self._trainer_rng = random.split(rng, 2)

    # --------------------------------------------------------------------------
    # Initialize the partitioner.
    # --------------------------------------------------------------------------
    if partitioner._model_parallel_submesh:
      num_partitions = np.prod(partitioner._model_parallel_submesh)
    else:
      num_partitions = partitioner._num_partitions
    if jax.device_count() % num_partitions != 0:
      # The fragments are concatenated (no commas) so that the exception carries
      # a single readable message rather than a tuple of strings.
      raise ValueError(
          "The number of devices available must be a multiple of the number of"
          f" partitions. There are {jax.device_count()} devices available, but"
          f" the number of partitions is set to {num_partitions}. Please"
          " provide a different number of partitions."
      )
    self._partitioner = partitioner

    # --------------------------------------------------------------------------
    # Create and save a checkpoint manager.
    # --------------------------------------------------------------------------
    logging.info("Initializing model, optimizer, and step functions.")
    self._model = model
    self._feature_converter = self._model.FEATURE_CONVERTER_CLS(pack=False)
    self._input_shapes = input_shapes
    self._input_types = input_types

    # Save the model vocabulary as features.
    output_features = {
        "inputs": seqio.Feature(
            vocabulary=self._model.input_vocabulary, add_eos=add_eos
        ),
        "targets": seqio.Feature(
            vocabulary=self._model.output_vocabulary, add_eos=add_eos
        ),
    }
    self._features = dict(sorted(output_features.items()))

    # Define restore and save checkpoints.
    if checkpoint_path:
      self._restore_checkpoint_cfg = utils.RestoreCheckpointConfig(
          dtype=dtype, mode=restore_mode, path=checkpoint_path
      )
    else:
      self._restore_checkpoint_cfg = None
    self._save_checkpoint_cfg = utils.SaveCheckpointConfig(
        dtype=dtype, keep=5, save_dataset=False, period=1000
    )
    self._train_state_initializer = utils.TrainStateInitializer(
        optimizer_def=self._model.optimizer_def,
        init_fn=self._model.get_initial_variables,
        input_shapes=self._input_shapes,
        input_types=self._input_types,
        partitioner=self._partitioner,
    )

    # Initialize checkpoint manager.
    self._checkpoint_manager = utils.LegacyCheckpointManager(
        save_cfg=self._save_checkpoint_cfg,
        restore_cfg=self._restore_checkpoint_cfg,
        train_state_shape=(
            self._train_state_initializer.global_train_state_shape
        ),
        partitioner=self._partitioner,
        ds_iter=None,
        model_dir=self._output_dir,
    )

    # --------------------------------------------------------------------------
    # Restore a model from a checkpoint or from scratch.
    # --------------------------------------------------------------------------
    def get_state(rng):
      return self._train_state_initializer.from_scratch(rng).state_dict()

    restore_cfgs = []
    # 1. From a checkpoint specified by `self._restore_checkpoint_cfg.path`, if
    # set.
    if self._restore_checkpoint_cfg:
      restore_cfgs.append(self._restore_checkpoint_cfg)
    # 2. If no checkpoint provided, look for one in the model directory.
    if self._restore_checkpoint_cfg is not None:
      state_transforms_for_restore = [
          functools.partial(fn, is_resuming=True)
          for fn in self._restore_checkpoint_cfg.state_transformation_fns
      ]
    else:
      state_transforms_for_restore = []
    restore_cfgs.append(
        utils.RestoreCheckpointConfig(
            path=self._output_dir,
            mode="latest",
            dtype=self._save_checkpoint_cfg.dtype
            if self._save_checkpoint_cfg
            else "float32",
            checkpointer_cls=self._save_checkpoint_cfg.checkpointer_cls
            if self._save_checkpoint_cfg
            else checkpoints.Checkpointer,
            # Restore dataset state if it is being saved.
            restore_dataset=(
                self._save_checkpoint_cfg
                and self._save_checkpoint_cfg.save_dataset
            ),
            state_transformation_fns=state_transforms_for_restore,
        )
    )

    # Restore the model using a checkpoint.
    valid_restore_cfg, restore_paths = (
        utils.get_first_valid_restore_config_and_paths(restore_cfgs)
    )
    self._train_state = self._checkpoint_manager.restore(
        restore_paths,
        valid_restore_cfg,
        utils.get_fallback_state(valid_restore_cfg, get_state, self._init_rng),
    )

    # 3. If no checkpoint to restore, init from scratch.
    if self._train_state is None:
      self._train_state = self._train_state_initializer.from_scratch(
          self._init_rng
      )
    self._train_state_axes = self._train_state_initializer.train_state_axes

    # Log the variable shapes information and write to a file.
    log_file = os.path.join(self._output_dir, "model-info.txt")
    utils.log_model_info(
        log_file,
        self._train_state_initializer.global_train_state_shape,
        self._partitioner,
    )

    # --------------------------------------------------------------------------
    # Trainer
    # --------------------------------------------------------------------------
    if isinstance(self._train_state, Sequence):
      raise ValueError(
          "Expected a single train state, but instead received a Sequence."
      )
    self._trainer = trainer_lib.Trainer(
        model=self._model,
        train_state=self._train_state,
        partitioner=self._partitioner,
        eval_names=eval_names if eval_names else [],
        summary_dir=self._output_dir,
        train_state_axes=self._train_state_axes,
        rng=self._trainer_rng,
        learning_rate_fn=utils.create_learning_rate_scheduler(),
        num_microbatches=None,
    )

  @property
  def trainer(self):
    """The `Trainer` built in `__init__`."""
    return self._trainer

  @property
  def partitioner(self):
    """The partitioner passed to `__init__`."""
    return self._partitioner

  @property
  def model(self):
    """The model passed to `__init__`."""
    return self._model

  @property
  def train_state(self):
    """The current train state (updated after every train step)."""
    return self._train_state

  @property
  def train_state_axes(self):
    """The train state axes taken from the train state initializer."""
    return self._train_state_axes

  @property
  def train_summary(self):
    """Result of the latest training summary, or None before any training."""
    if self._train_summary is None:
      return None
    return self._train_summary.result()

  @property
  def step(self):
    """The current training step as an int.

    Raises:
      ValueError: `self._train_state` is a sequence of TrainStates.
    """
    if isinstance(self._train_state, Sequence):
      raise ValueError(
          "Expected a single train state, but instead received a Sequence."
      )
    return int(self._train_state.step)

  def train_step(self, examples: Sequence[Union[str, dict[str, str]]]):
    """Train function.

    Args:
      examples: examples that should be transformed into a tf.data.Dataset. The
        examples can either take the form of a string (ex: a single input for
        inference), or a dictionary mapping "input"/"target" to a string
        containing that element. At least `self._batch_size` examples must be
        provided.

    Raises:
      ValueError: the user provided less than `batch_size` examples, or
        `self._train_state` was set to a sequence of TrainStates, when it should
        have been a single TrainState.
    """
    # By default, only tokenize and append EOS.
    preprocessors = [
        seqio.preprocessors.tokenize,
        seqio.preprocessors.append_eos,
    ]
    self.train_step_with_preprocessors(
        examples=examples, preprocessors=preprocessors
    )

  def train_step_with_preprocessors(
      self,
      examples: Sequence[Union[str, dict[str, str]]],
      preprocessors: Sequence[Callable[..., tf.data.Dataset]],
  ):
    """Train function.

    Args:
      examples: examples that should be transformed into a tf.data.Dataset. The
        examples can either take the form of a string (ex: a single input for
        inference), or a dictionary mapping "input"/"target" to a string
        containing that element. At least `self._batch_size` examples must be
        provided.
      preprocessors: list(callable), an optional list of functions that receive
        a tf.data.Dataset and return a tf.data.Dataset. These will be executed
        sequentially and the final dataset must include features matching
        `self._features`.

    Raises:
      ValueError: the user provided less than `batch_size` examples, or
        `self._train_state` was set to a sequence of TrainStates, when it should
        have been a single TrainState.
    """
    # --------------------------------------------------------------------------
    # Initialize dataset and dataset iterator
    # --------------------------------------------------------------------------
    if len(examples) < self._batch_size:
      raise ValueError(
          "At least one batch of data must be provided. Please decrease the "
          "batch_size or provide more examples."
      )

    train_dataset = get_dataset_from_natural_text_examples(
        examples,
        preprocessors=preprocessors,
        task_feature_lengths=self._task_feature_lengths,
        features=self._features,
    )
    train_dataset = self._feature_converter(
        train_dataset, task_feature_lengths=self._task_feature_lengths
    )
    train_dataset = train_dataset.padded_batch(
        self._batch_size, drop_remainder=True
    )
    train_iter = clu.data.dataset_iterator.TfDatasetIterator(
        train_dataset, checkpoint=True
    )

    # --------------------------------------------------------------------------
    # Take 1 train step.
    # --------------------------------------------------------------------------
    # `stop_training` is requested, break out the main loop immediately.
    if self._trainer.stop_training:
      logging.info(
          "Stopping training early since `stop_training` is requested."
      )
      return

    try:
      self.train_step_from_batch_iterator(train_iter)
    except trainer_lib.PreemptionError:
      logging.info("Saving emergency checkpoint.")
      self.save_checkpoint()
      logging.info("Saving emergency checkpoint done.")
      # Bare `raise` re-raises the active exception with its traceback intact.
      raise

    # Save a checkpoint.
    logging.info("Saving checkpoint.")
    self.save_checkpoint()

  def train_step_from_batch_iterator(self, iterator):
    """Runs one training step from a batch iterator.

    Args:
      iterator: the batch iterator handed to the trainer's `train` method.

    Raises:
      ValueError: `self._train_state` is a sequence of TrainStates.
    """
    if isinstance(self._train_state, Sequence):
      raise ValueError(
          "Expected a single train state, but instead received a Sequence."
      )
    first_step = int(utils.get_local_data(self._train_state.step))
    self._train_summary = self._trainer.train(
        iterator, 1, start_step=first_step
    )

    # Wait until computations are done before exiting
    utils.sync_global_devices("complete")
    self._train_state = self._trainer.train_state

  def save_checkpoint(self):
    """Saves model checkpoint."""
    self._checkpoint_manager.save(
        self._trainer.train_state,
        self._save_checkpoint_cfg.state_transformation_fns,
    )

  def infer_with_preprocessors(
      self,
      mode: InferenceType,
      examples: Sequence[Union[str, dict[str, str]]],
      preprocessors: Sequence[Callable[..., tf.data.Dataset]],
      **inference_kwargs,
  ) -> _Inferences:
    """Infer function.

    Args:
      mode: Either 'score' to compute the log likelihood of given targets, or
        'predict_with_aux' to score and decode targets.
      examples: examples that should be transformed into a tf.data.Dataset. The
        examples can either take the form of a string (ex: a single input for
        inference), or a dictionary mapping "input"/"target" to a string
        containing that element.
      preprocessors: list(callable), an optional list of functions that receive
        a tf.data.Dataset and return a tf.data.Dataset. These will be executed
        sequentially and the final dataset must include features matching
        `self._features`.
      **inference_kwargs: additional keyword arguments to pass to the inference
        function (e.g., `model.predict_batch_with_aux` or `score_batch`).

    Returns:
      Returns a tuple of predictions/scores and any auxiliary values.

    Raises:
      ValueError: `mode` is not a supported `InferenceType`.
    """
    # --------------------------------------------------------------------------
    # Parse Mode
    # --------------------------------------------------------------------------
    if mode == InferenceType.PREDICT_WITH_AUX:
      infer_step = self._model.predict_batch_with_aux
    elif mode == InferenceType.SCORE:
      infer_step = self._model.score_batch
    else:
      raise ValueError(
          "Mode must be `predict_with_aux`, or `score`,"
          f" but instead was {mode}."
      )

    # Inference functions are cached per (kwargs, mode) combination so that they
    # are only built once.
    key_array = seqio.utils.flatten_dict(inference_kwargs)
    key_array["mode"] = mode
    infer_fn_key = tuple(key_array.items())
    if infer_fn_key not in self._cached_infer_fns:
      self._cached_infer_fns[infer_fn_key] = utils.get_infer_fn(
          infer_step=functools.partial(infer_step, **inference_kwargs),
          batch_size=self._batch_size,
          train_state_axes=self._train_state_initializer.train_state_axes,
          partitioner=self._partitioner,
      )
    infer_fn = functools.partial(
        self._cached_infer_fns[infer_fn_key],
        train_state=self._train_state,
    )

    # --------------------------------------------------------------------------
    # Construct a dataset and dataset iterator.
    # --------------------------------------------------------------------------
    dataset = get_dataset_from_natural_text_examples(
        examples,
        preprocessors=preprocessors,
        task_feature_lengths=self._task_feature_lengths,
        features=self._features,
    )
    model_dataset = self._feature_converter(
        dataset, task_feature_lengths=self._task_feature_lengths
    )
    # Zip task and model features.
    infer_dataset = tf.data.Dataset.zip((dataset, model_dataset))
    # Create batches and index them.
    infer_dataset = infer_dataset.padded_batch(
        self._batch_size, drop_remainder=False
    ).enumerate()
    infer_dataset_iter: Iterator[Tuple[int, Any]] = iter(
        infer_dataset.prefetch(tf.data.experimental.AUTOTUNE)
    )

    # --------------------------------------------------------------------------
    # Run inference
    # --------------------------------------------------------------------------
    # Main Loop over "batches".
    all_inferences = []
    all_aux_values = {}
    for chunk, chunk_batch in infer_dataset_iter:
      # Load the dataset for the next chunk. We can't use `infer_dataset_iter`
      # directly since `infer_fn` needs to know the exact size of each chunk,
      # which may be smaller for the final one.
      chunk_dataset = tf.data.Dataset.from_tensor_slices(chunk_batch)
      # Dataset transformations return new datasets, so the result must be
      # re-assigned for the cache/prefetch to take effect.
      chunk_dataset = chunk_dataset.cache().prefetch(
          tf.data.experimental.AUTOTUNE
      )

      # Unzip chunk dataset in to pretokenized and model datasets.
      task_dataset = chunk_dataset.map(
          lambda p, m: p, num_parallel_calls=tf.data.experimental.AUTOTUNE
      )
      model_dataset = chunk_dataset.map(
          lambda p, m: m, num_parallel_calls=tf.data.experimental.AUTOTUNE
      )

      # Get a chunk-specific RNG key.
      chunk_rng = jax.random.fold_in(jax.random.PRNGKey(0), chunk)

      inferences = _extract_tokens_and_aux_values(
          infer_fn(model_dataset.enumerate(), rng=chunk_rng)
      )

      predictions, aux_values = inferences
      accumulated_inferences = []
      for idx, inputs in task_dataset.enumerate().as_numpy_iterator():
        prediction = predictions[idx]
        # Decode predictions if applicable.
        if mode == InferenceType.PREDICT_WITH_AUX:
          prediction = (
              self._features["targets"]
              .vocabulary.decode_tf(tf.constant(prediction))
              .numpy()
          )
        accumulated_inferences.append((inputs, prediction))
      all_inferences += accumulated_inferences
      # Accumulate aux values over batches.
      if not all_aux_values:
        all_aux_values = aux_values
      else:
        for key, values in aux_values.items():
          all_aux_values[key] += values

    return all_inferences, all_aux_values

  def predict_with_aux(
      self, examples: Sequence[Union[str, dict[str, str]]]
  ) -> _Inferences:
    """Predict with auxiliary values method.

    Args:
      examples: natural text examples (strings or "input"/"target" dicts).

    Returns:
      The `(inferences, aux_values)` tuple from `infer_with_preprocessors`.
    """
    # By default, only tokenize and append EOS.
    preprocessors = [
        seqio.preprocessors.tokenize,
        seqio.preprocessors.append_eos,
    ]
    return self.infer_with_preprocessors(
        mode=InferenceType.PREDICT_WITH_AUX,
        examples=examples,
        preprocessors=preprocessors,
    )

  def score(
      self, examples: Sequence[Union[str, dict[str, str]]]
  ) -> Sequence[Any]:
    """Score method.

    Args:
      examples: natural text examples (strings or "input"/"target" dicts).

    Returns:
      The scored inferences; auxiliary values are discarded.
    """
    # By default, only tokenize and append EOS.
    preprocessors = [
        seqio.preprocessors.tokenize,
        seqio.preprocessors.append_eos,
    ]
    # Ignore auxiliary values.
    scores, _ = self.infer_with_preprocessors(
        mode=InferenceType.SCORE, examples=examples, preprocessors=preprocessors
    )
    return scores

  def _compute_metrics(
      self,
      targets: Sequence[Any],
      predictions: Sequence[Any],
      aux_values: Sequence[Any],
      scores: Sequence[Any],
      predict_metric_fns: Sequence[seqio.dataset_providers.MetricFnCallable],
      predict_with_aux_metric_fns: Sequence[
          seqio.dataset_providers.MetricFnCallable
      ],
      score_metric_fns: Sequence[seqio.dataset_providers.MetricFnCallable],
  ):
    """Computes the metrics specified in the metric_fns lists.

    Args:
      targets: the evaluation targets.
      predictions: model predictions, passed to the predict metric functions.
      aux_values: auxiliary values, passed to the predict-with-aux metric
        functions.
      scores: model scores, passed to the score metric functions.
      predict_metric_fns: metric functions called as `fn(targets, predictions)`.
      predict_with_aux_metric_fns: metric functions called as `fn(targets,
        predictions, aux_values)`.
      score_metric_fns: metric functions called as `fn(targets, scores)`.

    Returns:
      A dictionary merging the outputs of all metric functions; an empty
      dictionary on every process other than process 0.

    Raises:
      ValueError: the number of scores does not match the number of targets, or
        two metric functions produced the same metric key.
    """
    # Only compute metrics once
    # NOTE(review): non-zero processes return before reaching the
    # `sync_global_devices` barrier below -- confirm this cannot stall
    # multi-host runs.
    if jax.process_index() != 0:
      return {}

    def compute_metrics_fn():
      task_metrics = []
      if predict_metric_fns:
        task_metrics.extend(
            [
                metric_fn(targets, predictions)
                for metric_fn in predict_metric_fns
            ]
        )

      if predict_with_aux_metric_fns:
        task_metrics.extend(
            [
                metric_fn(targets, predictions, aux_values)
                for metric_fn in predict_with_aux_metric_fns
            ]
        )

      if score_metric_fns:
        # When `scores` is a tuple, its first element holds the scores.
        if isinstance(scores, tuple):
          num_scores = len(scores[0])
        else:
          num_scores = len(scores)
        if len(targets) != num_scores:
          raise ValueError(
              f"len(targets)({len(targets)}) != "
              f"len(output_scores)({num_scores})"
          )
        task_metrics.extend(
            [metric_fn(targets, scores) for metric_fn in score_metric_fns]
        )

      all_metrics = {}
      for k, v in itertools.chain(*[m.items() for m in task_metrics]):
        if k in all_metrics:
          raise ValueError(f"Duplicate metric key '{k}' in Task.")
        all_metrics[k] = v
      return all_metrics

    if not tf.executing_eagerly():

      def wrap_graph(fn):
        graph = tf.compat.v1.get_default_graph()

        def wrapped_fn():
          with graph.as_default():
            return fn()

        return wrapped_fn

      compute_metrics_fn = wrap_graph(compute_metrics_fn)

    all_metrics = compute_metrics_fn()
    # Wait until computations are done before continuing.
    utils.sync_global_devices("Completed.")
    return all_metrics

  def evaluate(
      self,
      examples: Sequence[Union[str, dict[str, str]]],
      metric_fns: Sequence[seqio.dataset_providers.MetricFnCallable],
  ) -> Mapping[Any, Any]:
    """Evaluation function.

    Args:
      examples: examples that should be transformed into a tf.data.Dataset. The
        examples can either take the form of a string (ex: a single input for
        inference), or a dictionary mapping "input"/"target" to a string
        containing that element.
      metric_fns: list(callable), an optional list of metric functions with a
        signature that matches one of three possible forms: - (targets, scores)
        - Note that `scores` refers to the score the model assigned the target
        sequence, given the input. - (targets, predictions) - (targets,
        predictions, aux_values) - Note that `aux_values` refers to a dictionary
        of auxiliary values that the model assigned to each sequence.

    Returns:
      Mapping of metrics names to metrics values.
    """
    # By default, only tokenize and append EOS.
    preprocessors = [
        seqio.preprocessors.tokenize,
        seqio.preprocessors.append_eos,
    ]
    return self.evaluate_with_preprocessors(
        examples=examples,
        preprocessors=preprocessors,
        metric_fns=metric_fns,
        postprocessor=None,
    )

  def evaluate_with_preprocessors(
      self,
      examples: Sequence[dict[str, str]],
      preprocessors: Sequence[Callable[..., tf.data.Dataset]],
      metric_fns: Sequence[seqio.dataset_providers.MetricFnCallable],
      postprocessor: Optional[Callable[..., Any]] = None,
  ) -> Mapping[Any, Any]:
    """Evaluation function.

    Args:
      examples: examples that should be transformed into a tf.data.Dataset. The
        examples must take the form of a dictionary mapping "input"/"target" to
        a string containing that element.
      preprocessors: list(callable), an optional list of functions that receive
        a tf.data.Dataset and return a tf.data.Dataset. These will be executed
        sequentially and the final dataset must include features matching
        `self._features`.
      metric_fns: list(callable), an optional list of metric functions with a
        signature that matches one of three possible forms: - (targets, scores)
        - Note that `scores` refers to the score the model assigned the target
        sequence, given the input. - (targets, predictions) - (targets,
        predictions, aux_values) - Note that `aux_values` refers to a dictionary
        of auxiliary values that the model assigned to each sequence.
      postprocessor: callable, an optional function that receives decoded model
        outputs and converts them to a form that is ready for evaluation using
        the metric functions in `metric_fns`. It is applied to both the targets
        (`is_target=True`) and the predictions (`is_target=False`).

    Returns:
      Mapping of metrics names to metrics values.

    Raises:
      ValueError: a metric function has an unsupported signature.
    """
    # --------------------------------------------------------------------------
    # Parse Metrics functions
    # --------------------------------------------------------------------------
    predict_metric_fns = []
    predict_with_aux_metric_fns = []
    score_metric_fns = []

    for metric_fn in metric_fns:
      pos_args = tuple(
          key
          for key, param in inspect.signature(metric_fn).parameters.items()
          if param.default == inspect.Parameter.empty
      )
      if pos_args == ("targets", "scores"):
        score_metric_fns.append(metric_fn)
      elif pos_args == ("targets", "predictions"):
        predict_metric_fns.append(metric_fn)
      elif pos_args == ("targets", "predictions", "aux_values"):
        predict_with_aux_metric_fns.append(metric_fn)
      else:
        raise ValueError(
            "Metric functions must have positional arguments matching either "
            "('targets', 'scores'), ('targets', 'predictions') or "
            "('targets', 'predictions', 'aux_values'). "
            f"Got: {pos_args}"
        )

    # ------------------------------------------------------------------------
    # Get targets, predictions, and scores
    # ------------------------------------------------------------------------
    dataset = get_dataset_from_natural_text_examples(
        examples,
        preprocessors=preprocessors,
        task_feature_lengths=self._task_feature_lengths,
        features=self._features,
    )

    # Get targets.
    def postprocess_fn(decoded_model_output: Any, **postprocess_kwargs) -> Any:
      """Returns the model output after applying the postprocess function."""
      if postprocessor:
        return postprocessor(decoded_model_output, **postprocess_kwargs)
      return decoded_model_output

    targets = []
    for ex in tfds.as_numpy(dataset):
      targets.append(
          postprocess_fn(
              decoded_model_output=ex["targets_pretokenized"],
              example=ex,
              is_target=True,
          )
      )

    # Get predictions.
    predictions = []
    # Initialized up front so that it is defined even when only score metrics
    # are requested and the prediction branch below is skipped.
    aux_values = {}
    if predict_with_aux_metric_fns or predict_metric_fns:
      predictions, aux_values = self.infer_with_preprocessors(
          mode=InferenceType.PREDICT_WITH_AUX,
          examples=examples,
          preprocessors=preprocessors,
      )
      # Run the decoded predictions through the same postprocessor as the
      # targets so that metric functions compare like with like.
      predictions = [
          postprocess_fn(
              decoded_model_output=prediction.decode("utf-8"),
              example=example,
              is_target=False,
          )
          for example, prediction in predictions
      ]

    # Get scores.
    scores = []
    if score_metric_fns:
      scores, _ = self.infer_with_preprocessors(
          mode=InferenceType.SCORE,
          examples=examples,
          preprocessors=preprocessors,
      )
      scores = [score for _, score in scores]

    return self._compute_metrics(
        targets,
        predictions,
        aux_values,
        scores,  # pytype: disable=wrong-arg-types  # mapping-is-not-sequence
        predict_metric_fns,
        predict_with_aux_metric_fns,
        score_metric_fns,
    )

  @staticmethod
  def _as_list_of_batches(batches):
    """Normalizes `batches` (None, one batch, or many) to a list of batches."""
    if not batches:
      return []
    # A batch is a sequence of examples, and an example is a mapping or string;
    # so if the first element is an example, `batches` is a single batch.
    if isinstance(batches[0], (Mapping, str)):
      return [batches]
    return list(batches)

  def train_loop(
      self,
      num_steps: int,
      eval_period: Optional[int] = 1,
      train_batches: Optional[BatchesType] = None,
      predict_batches: Optional[BatchesType] = None,
      score_batches: Optional[BatchesType] = None,
      eval_batches: Optional[BatchesType] = None,
      metrics_fns: Optional[
          Sequence[seqio.dataset_providers.MetricFnCallable]
      ] = None,
  ):
    """Runs training, inference, and evaluation for `num_steps`.

    It should be noted that there are many different possible variants of the
    `train_loop` function that a user might want to use. The primary goal of the
    `train_loop` function is not to cover all the potential training loop
    variants that a user may want; rather, the goal is to demonstrate how the
    user could stack the `InteractiveModel` train, predict, score, and evaluate
    methods.

    Args:
      num_steps: the number of steps to run for training, inference, and
        evaluation.
      eval_period: specifies how many steps to take between
        inference/evaluation. If None or 0, inference/evaluation is skipped.
      train_batches: an optional list of batches that we should run training on.
        If no batches are provided, then training will be skipped. If a single
        batch is provided, we will repeat training on this batch for
        `num_steps`.
      predict_batches: an optional list of batches that we should get
        predictions for. If no batches are provided, then predicting will be
        skipped. If a single batch is provided, we will repeatedly get
        predictions on this batch for `num_steps`.
      score_batches: an optional list of batches that we should score. If no
        batches are provided, then scoring will be skipped. If a single batch is
        provided, we will repeatedly score this batch for `num_steps`.
      eval_batches: an optional list of batches that we should run eval on. If
        no batches are provided, then evaluation will be skipped. If a single
        batch is provided, we will repeatedly evaluate this batch for
        `num_steps`.
      metrics_fns: list(callable), an optional list of metric functions with a
        signature that matches one of three possible forms: - (targets, scores)
        - Note that `scores` refers to the score the model assigned the target
        sequence, given the input. - (targets, predictions) - (targets,
        predictions, aux_values) - Note that `aux_values` refers to a dictionary
        of auxiliary values that the model assigned to each sequence.

    Returns:
      Predictions, scores, and metrics for the final step of the training loop.
    """
    # Ensure all batches are `num_steps` in length
    train_batches = _get_equal_length_batches(train_batches, num_steps)
    # A single batch is wrapped into a one-element list so that the loops below
    # always iterate over batches, never over the examples of one batch.
    predict_batches = self._as_list_of_batches(predict_batches)
    score_batches = self._as_list_of_batches(score_batches)
    eval_batches = self._as_list_of_batches(eval_batches)
    metric_fns = metrics_fns if metrics_fns is not None else []

    predictions = None
    scores = None
    metrics = None
    for step_num, train_batch in enumerate(train_batches):
      if train_batch:
        self.train_step(train_batch)

      # Run inference/evaluation every `eval_period` steps.
      if not eval_period or step_num % eval_period != 0:
        continue

      # Run on all batches for inference/evaluation.
      for predict_batch in predict_batches:
        predictions, _ = self.predict_with_aux(predict_batch)  # pytype: disable=wrong-arg-types  # mapping-is-not-sequence
      for score_batch in score_batches:
        scores = self.score(score_batch)  # pytype: disable=wrong-arg-types  # mapping-is-not-sequence
      for eval_batch in eval_batches:
        metrics = self.evaluate(eval_batch, metric_fns)  # pytype: disable=wrong-arg-types  # mapping-is-not-sequence
    return predictions, scores, metrics
def get_dataset_from_natural_text_examples(
    examples: Sequence[Union[str, dict[str, str]]],
    preprocessors: Sequence[Callable[..., tf.data.Dataset]],
    task_feature_lengths: Mapping[str, int],
    features: Mapping[str, Any],
) -> tf.data.Dataset:
  """Returns a tf.data.Dataset from a list of examples.

  Args:
    examples: a single batch of examples that should be transformed into a
      tf.data.Dataset. The examples can either take the form of a string (ex: a
      single input for inference), or a dictionary mapping "input"/"target" to a
      string containing that element. A dictionary without a "target" key is
      given an empty target, just like a plain string example.
    preprocessors: an optional list of functions that receive a tf.data.Dataset
      and return a tf.data.Dataset. These will be executed sequentially and the
      final dataset must include features matching `features`.
    task_feature_lengths: dictionary mapping feature key to maximum length (int)
      for that feature. If feature is longer than this length after
      preprocessing, the feature will be truncated. May be set to None to avoid
      truncation.
    features: dictionary defining what features should be present in all
      examples.

  Returns:
    A tf.data.Dataset.

  Raises:
    ValueError: the preprocessed dataset is missing a required feature, or a
      feature has the incorrect type/rank.
  """
  # ------------------------------------------------------------------------
  # Construct a `tf.data.Dataset` from the provided examples
  # ------------------------------------------------------------------------
  inputs = []
  targets = []
  for example in examples:
    if isinstance(example, Mapping):
      inputs.append(example["input"])
      # Inference-only examples may omit the target; default to an empty string
      # exactly as is done for plain string examples below.
      targets.append(example.get("target", ""))
    else:
      # If the provided example is just a string, add an empty target string
      inputs.append(example)
      targets.append("")
  dataset = tf.data.Dataset.from_tensor_slices(
      {"inputs": inputs, "targets": targets}
  )

  # Define `ShardInfo` that doesn't shard the data pipeline.
  shard_info = seqio.ShardInfo(0, 1)
  dataset = dataset.shard(shard_info.num_shards, shard_info.index)

  # ------------------------------------------------------------------------
  # Preprocess data
  # ------------------------------------------------------------------------
  for prep_fn in preprocessors:
    # prep_fn must not rely on variable length keyword args such as **kwargs.
    fn_args = set(inspect.signature(prep_fn).parameters.keys())
    kwargs = {}
    if "sequence_length" in fn_args:
      kwargs["sequence_length"] = task_feature_lengths
    if "output_features" in fn_args:
      kwargs["output_features"] = features
    dataset = prep_fn(dataset, **kwargs)

  def _validate_preprocessing(dataset: tf.data.Dataset) -> tf.data.Dataset:
    """Validates preprocessed dataset, raising Exceptions if needed.

    Args:
      dataset: a tf.data.Dataset to validate.

    Returns:
      a validated tf.data.Dataset.

    Raises:
      ValueError: dataset has missing feature or the incorrect type/rank for a
        feature.
    """
    actual_specs = dataset.element_spec
    for feat, feat_spec in features.items():
      if feat not in actual_specs:
        if feat_spec.required:
          raise ValueError(
              "Task dataset is missing expected output feature after "
              f"preprocessing: {feat}"
          )
        else:
          # It's ok that this feature does not exist.
          continue
      actual_spec = actual_specs[feat]
      if feat_spec.dtype != actual_spec.dtype:
        raise ValueError(
            f"Task dataset has incorrect type for feature '{feat}' after "
            f"preprocessing: Got {actual_spec.dtype.name}, expected "
            f"{feat_spec.dtype.name}"
        )
      if feat_spec.rank != actual_spec.shape.rank:
        raise ValueError(
            f"Task dataset has incorrect rank for feature '{feat}' after "
            f"preprocessing: Got {actual_spec.shape.rank}, expected "
            f"{feat_spec.rank}"
        )

    return dataset

  dataset = _validate_preprocessing(dataset)
  dataset = seqio.utils.trim_dataset(dataset, task_feature_lengths, features)
  return dataset.prefetch(tf.data.experimental.AUTOTUNE)
def _get_equal_length_batches( batches: BatchesType, length: int ) -> Sequence[Any]: """Produces a list of batches that is `length` batches long. Given a single batch, repeat the batch `length` times. Given a list of batches, either repeat the batches to get `length` total batches or take the first 'length' batches. Args: batches: either a single batch of examples, or a list of batches. length: the total number of batches that should be present in the final list. Returns: A list of batches. """ # Given a list of batches, return a list of batches that is `length` long, # either by repeating the batches or taking the first `length` batches if not batches: return [None] * length if isinstance(batches[0], Mapping): return [batches for i in range(length)] if len(batches) < length: batches = batches * (length // len(batches)) # If multiple batches are provided, only use the first `length` batches. logging.warning( "We will only use the first %s batches provided for training.", length ) return batches[:length]
def get_batches_from_seqio(
    task_or_mixture_name: str,
    split: str,
    batch_size: int,
    num_batches: int,
    get_pretokenized_examples: bool = True,
    sequence_length: Optional[Mapping[str, int]] = None,
    **get_dataset_kwargs,
) -> Sequence[Sequence[Mapping[str, str]]]:
  """Returns batches of examples read from a provided SeqIO task.

  Args:
    task_or_mixture_name: the SeqIO task/mixture to read data from.
    split: the split of the SeqIO task/mixture to read data from.
    batch_size: how many examples should be in each batch.
    num_batches: the total number of batches to return.
    get_pretokenized_examples: a bool, where True indicates that we should
      return the natural text (pre-tokenization) inputs and targets, read from
      the "inputs_pretokenized"/"targets_pretokenized" features. If False, the
      "inputs"/"targets" features are read and converted to Python lists.
      Defaults to True in order to make the examples easy to debug/inspect.
    sequence_length: dictionary mapping feature key to maximum length (int) for
      that feature. Used by SeqIO to get the dataset.
    **get_dataset_kwargs: any additional arguments that should be passed to the
      SeqIO `get_dataset()` call.

  Returns:
    A sequence of batches, where each batch is a sequence of examples. Each
    example is a dictionary mapping 'input' and 'target' to the corresponding
    values for a single example.

  Raises:
    ValueError: the task/mixture yields fewer than `batch_size * num_batches`
      examples.
  """
  task_or_mixture = seqio.get_mixture_or_task(task_or_mixture_name)
  total_examples_requested = batch_size * num_batches
  dataset = task_or_mixture.get_dataset(
      sequence_length=sequence_length, split=split, **get_dataset_kwargs
  )

  # Nothing to collect. This guard also matters because a negative count passed
  # to `take()` would select the entire dataset rather than nothing.
  if total_examples_requested <= 0:
    return []

  if get_pretokenized_examples:
    input_key, target_key = "inputs_pretokenized", "targets_pretokenized"
  else:
    input_key, target_key = "inputs", "targets"

  all_batches = []
  current_batch = []
  total_examples_seen = 0

  # It should be noted that we could replace the following loop with tf.Dataset
  # operations (like
  # `list(dataset.batch(batch_size).take(num_batches).as_numpy_iterator())`),
  # but this would require us to pad batches first or represent the token IDs
  # as ragged tensors. These approaches are currently overkill for the
  # InteractiveModel, but may be investigated in the future.
  #
  # `take()` caps the iteration at `total_examples_requested`, which is a
  # multiple of `batch_size`, so every emitted batch is full.
  dataset = dataset.take(total_examples_requested)
  for element in dataset.as_numpy_iterator():
    total_examples_seen += 1

    example_input = element[input_key]
    example_target = element[target_key]
    if not get_pretokenized_examples:
      # Token ID arrays are converted to plain Python lists.
      example_input = example_input.tolist()
      example_target = example_target.tolist()

    current_batch.append({"input": example_input, "target": example_target})

    # If we've collected `batch_size` examples, save the current batch and
    # start a new batch.
    if len(current_batch) == batch_size:
      all_batches.append(current_batch)
      current_batch = []

  if total_examples_seen < total_examples_requested:
    raise ValueError(
        "Not enough examples in Task/Mixture. User requested "
        f"{num_batches} batches of size {batch_size} for a total "
        f"of {total_examples_requested} examples. Only "
        f"{total_examples_seen} available in "
        "Task/Mixture."
    )

  return all_batches
def get_seqio_task_from_examples(
    task_name: str,
    interactive_model: InteractiveModel,
    examples: Sequence[Union[str, dict[str, str]]],
    preprocessors: Sequence[Callable[..., tf.data.Dataset]],
    metric_fns: Optional[
        Sequence[seqio.dataset_providers.MetricFnCallable]
    ] = None,
    add_to_registry: bool = True,
) -> Union[seqio.Task, seqio.Mixture]:
  """Wraps in-memory examples in a SeqIO task and returns it.

  This function will be used to graduate people to the T5X/SeqIO-based
  train/infer/eval scripts.

  Args:
    task_name: the name of the SeqIO task to be created and registered.
    interactive_model: an instance of the InteractiveModel; its task feature
      lengths and output features are used to configure the task.
    examples: a single batch of examples, served for both the "train" and
      "validation" splits. Each example is either a string (ex: a single input
      for inference) or a dictionary mapping "input"/"target" to a string.
    preprocessors: a list of functions that receive a tf.data.Dataset and
      return a tf.data.Dataset. They are attached to the registered task rather
      than applied when the raw dataset is built.
    metric_fns: list(callable), an optional list of metric functions with a
      signature that matches one of three possible forms:
      - (targets, scores) - Note that `scores` refers to the score the model
        assigned the target sequence, given the input.
      - (targets, predictions)
      - (targets, predictions, aux_values) - Note that `aux_values` refers to a
        dictionary of auxiliary values that the model assigned to each
        sequence.
    add_to_registry: if True, will register the new task. The task is looked up
      by `task_name` in either case, so when False a task with that name is
      presumably expected to be registered already.

  Returns:
    The SeqIO task (or mixture) registered under `task_name`.
  """

  # SeqIO calls the data source with `split` and `shuffle_files`; both are
  # ignored because every split serves the same in-memory examples.
  def dataset_fn(split, shuffle_files):
    del split, shuffle_files
    # No preprocessors or features here: this only produces the raw text
    # dataset; preprocessing is configured on the task below.
    return get_dataset_from_natural_text_examples(
        examples,
        preprocessors=[],
        task_feature_lengths=interactive_model._task_feature_lengths,  # pylint: disable=protected-access
        features={},
    )

  source = seqio.FunctionDataSource(
      dataset_fn=dataset_fn, splits=["train", "validation"]
  )

  if add_to_registry:
    seqio.TaskRegistry.add(
        task_name,
        source,
        preprocessors=preprocessors,
        output_features=interactive_model._features,  # pylint: disable=protected-access
        metric_fns=metric_fns,
    )

  return seqio.get_mixture_or_task(task_name)
# pylint: disable=protected-access
def get_gin_config_from_interactive_model(
    interactive_model: InteractiveModel,
    script_type: T5XScriptType,
    task_name: str,
    partitioner_config_str: str,
    model_config_str: str,
    train_steps: int = 1,
    imports_str: str = "",
):
  """Converts an InteractiveModel instance into a Gin config string.

  This function will be used to graduate people to the T5X/SeqIO-based
  train/infer/eval scripts.

  Args:
    interactive_model: an instance of the InteractiveModel.
    script_type: which T5X script the Gin config should function with.
    task_name: the name of the SeqIO task to be used.
    partitioner_config_str: a string that defines the Partitioner object in the
      Gin config.
    model_config_str: a string that defines the Model object in the Gin config.
    train_steps: the number of steps to train for. It is written to the config
      for FINETUNING, PRETRAINING and EVALUATION; for EVALUATION the generated
      `train_script.train` binding additionally sets `total_steps = 0`.
    imports_str: if the `model_config_str` or `partitioner_config_str` relies
      on some other files to be imported, these import statements can be
      included in the final Gin file by adding them to this string.

  Returns:
    A string that contains the full Gin file to be used for train/infer/eval.py.

  Raises:
    ValueError: INFERENCE was requested but the model has no checkpoint to
      restore from, or `script_type` is not a supported T5XScriptType.
  """
  restore_cfg = interactive_model._restore_checkpoint_cfg

  restore_config_str = ""
  if restore_cfg:
    # `dtype` may legitimately be None ("load as saved"). Rendering it with
    # `!r` emits a quoted string for real dtypes and a bare `None` otherwise,
    # instead of the string literal 'None'.
    restore_config_str = f"""CHECKPOINT_PATH = '{restore_cfg.path}'
utils.RestoreCheckpointConfig:
  path = %CHECKPOINT_PATH
  mode = '{restore_cfg.mode}'
  dtype = {restore_cfg.dtype!r}"""

  # Bindings shared by every script type.
  base_config_str = f"""
{imports_str}

MODEL_DIR = "{interactive_model._output_dir}"
MIXTURE_OR_TASK_NAME = "{task_name}"
TASK_FEATURE_LENGTHS = {interactive_model._task_feature_lengths}
USE_CACHED_TASKS = False
SHUFFLE_TRAIN_EXAMPLES = False
BATCH_SIZE = {interactive_model._batch_size}

{model_config_str}

{partitioner_config_str}

{restore_config_str}"""

  if script_type == T5XScriptType.INFERENCE:
    if not restore_cfg:
      raise ValueError("A checkpoint must be provided to run inference.")
    gin_config = f"""
include 't5x/configs/runs/infer.gin'

{base_config_str}

INFER_OUTPUT_DIR = %MODEL_DIR

utils.DatasetConfig:
  use_cached = %USE_CACHED_TASKS
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 0
  pack = False
"""
  elif script_type in (
      T5XScriptType.FINETUNING,
      T5XScriptType.PRETRAINING,
      T5XScriptType.EVALUATION,
  ):
    save_cfg = interactive_model._save_checkpoint_cfg
    gin_config = f"""
from __gin__ import dynamic_registration

import __main__ as train_script
from t5x import utils

include 't5x/configs/runs/pretrain.gin'

{base_config_str}

utils.SaveCheckpointConfig:
  period = {save_cfg.period}
  dtype = '{save_cfg.dtype}'
  keep = {save_cfg.keep}
  save_dataset = {save_cfg.save_dataset}

TRAIN_STEPS = {train_steps}
SHUFFLE_TRAIN_EXAMPLES = False
DROPOUT_RATE = 0.0

train/utils.DatasetConfig:
  pack = False

train_eval/utils.DatasetConfig:
  pack = False
"""
    if script_type == T5XScriptType.EVALUATION:
      # Evaluation reuses the training script: evaluate once up front and
      # take no training steps.
      gin_config += """

train_script.train:
  run_eval_before_training = True
  eval_period = 0
  total_steps = 0
"""
  else:
    # Fail with a clear message rather than an unbound-variable error below.
    raise ValueError(f"Unsupported script type: {script_type}")

  return gin_config
# pylint: enable=protected-access