t5x.interactive_model package

InteractiveModel class for use in T5X Colabs.

The InteractiveModel can be used to run training, inference, and evaluation on natural text inputs and targets.

class t5x.interactive_model.InferenceType(value)

An enumeration.

class t5x.interactive_model.InteractiveModel(batch_size, task_feature_lengths, output_dir, partitioner, model, dtype, restore_mode, checkpoint_path, input_shapes, input_types=None, init_random_seed=42, add_eos=True, eval_names=None)

Wrapper around T5X components to enable interactive train/infer/eval.

evaluate(examples, metric_fns)

Evaluation function.

Parameters:
  • examples – examples that should be transformed into a tf.data.Dataset. The examples can either take the form of a string (e.g., a single input for inference), or a dictionary mapping “input”/“target” to a string containing that element.

  • metric_fns – list(callable), an optional list of metric functions with a signature that matches one of three possible forms:
      - (targets, scores): scores refers to the score the model assigned the target sequence, given the input.
      - (targets, predictions)
      - (targets, predictions, aux_values): aux_values refers to a dictionary of auxiliary values that the model assigned to each sequence.

Returns:

Mapping of metric names to metric values.
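As an illustration of the (targets, predictions) form, the following is a minimal sketch of a metric function written in plain Python. The name exact_match, the example strings, and the hard-coded predictions are hypothetical; only the argument names and the mapping-style return value follow the description above.

```python
from typing import Mapping, Sequence


def exact_match(targets: Sequence[str],
                predictions: Sequence[str]) -> Mapping[str, float]:
    """Hypothetical metric using the (targets, predictions) form."""
    if len(targets) != len(predictions):
        raise ValueError("targets and predictions must have the same length")
    if not targets:
        return {"exact_match": 0.0}
    # Count predictions that equal their target after trimming whitespace.
    matches = sum(t.strip() == p.strip() for t, p in zip(targets, predictions))
    return {"exact_match": matches / len(targets)}


# Examples in the dictionary form described above (contents are made up).
examples = [
    {"input": "translate English to German: hello", "target": "hallo"},
    {"input": "translate English to German: thank you", "target": "danke"},
]

# Stand-in predictions; in practice these would come from the model.
result = exact_match([ex["target"] for ex in examples], ["hallo", "bitte"])
print(result)
```

Assuming an existing InteractiveModel instance named interactive_model (not constructed here), such a function would be passed as interactive_model.evaluate(examples, metric_fns=[exact_match]).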

evaluate_with_preprocessors(examples, preprocessors, metric_fns, postprocessor=None)

Evaluation function.

Parameters:
  • examples – examples that should be transformed into a tf.data.Dataset. The examples must take the form of a dictionary mapping “input”/“target” to a string containing that element.

  • preprocessors – list(callable), an optional list of functions that receive a tf.data.Dataset and return a tf.data.Dataset. These will be executed sequentially and the final dataset must include features matching self._features.

  • metric_fns – list(callable), an optional list of metric functions with a signature that matches one of three possible forms:
      - (targets, scores): scores refers to the score the model assigned the target sequence, given the input.
      - (targets, predictions)
      - (targets, predictions, aux_values): aux_values refers to a dictionary of auxiliary values that the model assigned to each sequence.

  • postprocessor – callable, an optional function that receives decoded model outputs and converts them to a form that is ready for evaluation using the metric functions in metric_fns.

Returns:

Mapping of metric names to metric values.
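A postprocessor can be as small as a text-normalizing function. The sketch below is hypothetical: the name strip_and_lowercase is invented, and accepting extra keyword arguments is an assumption made so the function tolerates whatever additional context the caller may pass alongside the decoded output.

```python
def strip_and_lowercase(output, **unused_kwargs):
    """Hypothetical postprocessor: normalizes decoded text before metrics run."""
    # Decoded outputs may arrive as bytes or str; handle both.
    if isinstance(output, bytes):
        output = output.decode("utf-8")
    return output.strip().lower()


# Stand-in decoded outputs, used only to show the transformation.
cleaned = [strip_and_lowercase(o) for o in [" Hallo ", b"DANKE\n"]]
print(cleaned)
```

Normalizing in the postprocessor keeps the metric functions themselves free of formatting concerns.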

infer_with_preprocessors(mode, examples, preprocessors, **inference_kwargs)

Infer function.

Parameters:
  • mode – Either ‘score’ to compute the log likelihood of given targets, or ‘predict_with_aux’ to score and decode targets.

  • examples – examples that should be transformed into a tf.data.Dataset. The examples can either take the form of a string (e.g., a single input for inference), or a dictionary mapping “input”/“target” to a string containing that element.

  • preprocessors – list(callable), an optional list of functions that receive a tf.data.Dataset and return a tf.data.Dataset. These will be executed sequentially and the final dataset must include features matching self._features.

  • **inference_kwargs – additional keyword arguments to pass to the inference function (e.g., model.predict_batch_with_aux or score_batch).

Returns:

Returns a tuple of predictions/scores and any auxiliary values.
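The two accepted example forms (a bare string, or a dictionary with “input”/“target” keys) can be brought to one shape before further processing. The sketch below is not the library's implementation; normalize_examples is a hypothetical helper, and treating a bare string as an input with an empty target is an assumption.

```python
from typing import Dict, List, Union

Example = Union[str, Dict[str, str]]


def normalize_examples(examples: List[Example]) -> List[Dict[str, str]]:
    """Hypothetical helper: converts every example to an input/target dict."""
    normalized = []
    for ex in examples:
        if isinstance(ex, str):
            # Assumption: a bare string is an input with no target.
            normalized.append({"input": ex, "target": ""})
        elif isinstance(ex, dict):
            normalized.append({
                "input": ex.get("input", ""),
                "target": ex.get("target", ""),
            })
        else:
            raise TypeError(f"Unsupported example type: {type(ex).__name__}")
    return normalized


normalized = normalize_examples([
    "summarize: the quick brown fox",
    {"input": "2 + 2 =", "target": "4"},
])
print(normalized)
```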

predict_with_aux(examples)

Scores and decodes targets for the given examples, returning predictions along with any auxiliary values.

save_checkpoint()

Saves model checkpoint.

score(examples)

Computes the log likelihood of the given targets for the provided examples.

train_loop(num_steps, eval_period=1, train_batches=None, predict_batches=None, score_batches=None, eval_batches=None, metrics_fns=None)

Runs training, inference, and evaluation for num_steps.

It should be noted that there are many different possible variants of the train_loop function that a user might want to use. The primary goal of the train_loop function is not to cover all the potential training loop variants that a user may want; rather, the goal is to demonstrate how the user could stack the InteractiveModel train, predict, score, and evaluate methods.

Parameters:
  • num_steps – the number of steps to run for training, inference, and evaluation.

  • eval_period – specifies how many steps to take between inference/evaluation.

  • train_batches – an optional list of batches that we should run training on. If no batches are provided, then training will be skipped. If a single batch is provided, we will repeat training on this batch for num_steps.

  • predict_batches – an optional list of batches that we should get predictions for. If no batches are provided, then predicting will be skipped. If a single batch is provided, we will repeatedly get predictions on this batch for num_steps.

  • score_batches – an optional list of batches that we should score. If no batches are provided, then scoring will be skipped. If a single batch is provided, we will repeatedly score this batch for num_steps.

  • eval_batches – an optional list of batches that we should run eval on. If no batches are provided, then evaluation will be skipped. If a single batch is provided, we will repeatedly evaluate this batch for num_steps.

  • metrics_fns – list(callable), an optional list of metric functions with a signature that matches one of three possible forms:
      - (targets, scores): scores refers to the score the model assigned the target sequence, given the input.
      - (targets, predictions)
      - (targets, predictions, aux_values): aux_values refers to a dictionary of auxiliary values that the model assigned to each sequence.

Returns:

Predictions, scores, and metrics for the final step of the training loop.
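The stacking idea described above can be sketched without T5X itself. In the following, RecordingModel is a stand-in that only records which method was called, and select_batch and sketch_train_loop are hypothetical helpers. Cycling through multiple batches by step index, and running inference/evaluation only on steps divisible by eval_period, are assumptions of this sketch and may differ from what train_loop actually does.

```python
from typing import Any, List, Optional, Sequence


class RecordingModel:
    """Stand-in for an InteractiveModel; it only records method calls."""

    def __init__(self):
        self.calls: List[str] = []

    def train_step(self, examples):
        self.calls.append("train_step")

    def predict_with_aux(self, examples):
        self.calls.append("predict_with_aux")
        return ["<prediction>"] * len(examples), {}

    def score(self, examples):
        self.calls.append("score")
        return [0.0] * len(examples)

    def evaluate(self, examples, metric_fns):
        self.calls.append("evaluate")
        return {}


def select_batch(batches: Optional[Sequence[Any]], step: int) -> Optional[Any]:
    """Returns None when no batches exist, else cycles through them by step."""
    if not batches:
        return None
    # With a single batch this always returns batches[0], so the batch repeats.
    return batches[step % len(batches)]


def sketch_train_loop(model, num_steps, eval_period=1, train_batches=None,
                      predict_batches=None, score_batches=None,
                      eval_batches=None, metrics_fns=None):
    """Hypothetical loop that stacks train, predict, score, and evaluate."""
    predictions = scores = metrics = None
    for step in range(num_steps):
        batch = select_batch(train_batches, step)
        if batch is not None:
            model.train_step(batch)
        if step % eval_period != 0:
            continue  # Skip inference/evaluation between eval periods.
        batch = select_batch(predict_batches, step)
        if batch is not None:
            predictions = model.predict_with_aux(batch)
        batch = select_batch(score_batches, step)
        if batch is not None:
            scores = model.score(batch)
        batch = select_batch(eval_batches, step)
        if batch is not None:
            metrics = model.evaluate(batch, metric_fns=metrics_fns)
    # Returns the most recently computed values (None for skipped stages).
    return predictions, scores, metrics


model = RecordingModel()
batch = [{"input": "2 + 2 =", "target": "4"}]
predictions, scores, metrics = sketch_train_loop(
    model, num_steps=4, eval_period=2,
    train_batches=[batch], score_batches=[batch])
print(model.calls)
```

Because no predict or eval batches are passed in this usage, those stages are skipped and their return values stay None, mirroring the "skipped if no batches are provided" behavior described in the parameter list.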

train_step(examples)

Train function.

Parameters:

examples – examples that should be transformed into a tf.data.Dataset. The examples can either take the form of a string (e.g., a single input for inference), or a dictionary mapping “input”/“target” to a string containing that element. At least self._batch_size examples must be provided.

Raises:

ValueError – the user provided fewer than batch_size examples, or self._train_state was set to a sequence of TrainStates, when it should have been a single TrainState.
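The batch-size requirement can be checked before calling train_step. The sketch below is hypothetical and independent of T5X: check_batch_size mirrors the documented error condition, and fill_to_batch_size is one possible workaround (repeating examples), which is an assumption of this sketch and may not be appropriate for every training setup.

```python
from itertools import cycle, islice


def check_batch_size(examples, batch_size):
    """Hypothetical pre-check mirroring the documented ValueError condition."""
    if len(examples) < batch_size:
        raise ValueError(
            f"At least {batch_size} examples are required, "
            f"but only {len(examples)} were provided.")


def fill_to_batch_size(examples, batch_size):
    """Hypothetical workaround: repeats examples until the batch is full."""
    if not examples:
        raise ValueError("Cannot fill a batch from zero examples.")
    if len(examples) >= batch_size:
        return list(examples)
    return list(islice(cycle(examples), batch_size))


examples = [{"input": "2 + 2 =", "target": "4"}]

try:
    check_batch_size(examples, batch_size=8)
    error_message = None
except ValueError as err:
    error_message = str(err)

padded = fill_to_batch_size(examples, batch_size=8)
print(error_message, len(padded))
```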

train_step_from_batch_iterator(iterator)

Runs one training step from a batch iterator.

train_step_with_preprocessors(examples, preprocessors)

Train function.

Parameters:
  • examples – examples that should be transformed into a tf.data.Dataset. The examples can either take the form of a string (e.g., a single input for inference), or a dictionary mapping “input”/“target” to a string containing that element. At least self._batch_size examples must be provided.

  • preprocessors – list(callable), an optional list of functions that receive a tf.data.Dataset and return a tf.data.Dataset. These will be executed sequentially and the final dataset must include features matching self._features.

Raises:

ValueError – the user provided fewer than batch_size examples, or self._train_state was set to a sequence of TrainStates, when it should have been a single TrainState.

class t5x.interactive_model.T5XScriptType(value)

An enumeration.

t5x.interactive_model.get_batches_from_seqio(task_or_mixture_name, split, batch_size, num_batches, get_pretokenized_examples=True, sequence_length=None, **get_dataset_kwargs)

Returns batches of examples from a provided SeqIO task.

Parameters:
  • task_or_mixture_name – the SeqIO task/mixture to read data from.

  • split – the split of the SeqIO task/mixture to read data from.

  • batch_size – how many examples should be in each batch.

  • num_batches – the total number of batches to return.

  • get_pretokenized_examples – a bool, where True indicates that we should return the natural text (pre-tokenization) inputs and targets. Default to True in order to make the examples easy to debug/inspect.

  • sequence_length – dictionary mapping feature key to maximum length (int) for that feature. Used by SeqIO to get the dataset.

  • **get_dataset_kwargs – any additional arguments that should be passed to the SeqIO get_dataset() call.

Returns:

A sequence of batches, where each batch is a sequence of examples. Each example is a dictionary mapping ‘input’ and ‘target’ to the corresponding values for a single example.
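The shape of the return value (a sequence of batches, each a sequence of input/target dictionaries) can be illustrated by grouping a flat list of examples. This sketch does not read from SeqIO; chunk_into_batches is a hypothetical helper, and dropping a trailing incomplete batch is an assumption.

```python
from typing import Dict, List


def chunk_into_batches(examples: List[Dict[str, str]],
                       batch_size: int,
                       num_batches: int) -> List[List[Dict[str, str]]]:
    """Hypothetical helper: groups examples into at most num_batches batches."""
    batches = []
    for i in range(num_batches):
        batch = examples[i * batch_size:(i + 1) * batch_size]
        if len(batch) < batch_size:
            break  # Assumption: incomplete trailing batches are dropped.
        batches.append(batch)
    return batches


# Made-up examples in the 'input'/'target' form described above.
flat_examples = [{"input": f"question {i}", "target": f"answer {i}"}
                 for i in range(5)]
batches = chunk_into_batches(flat_examples, batch_size=2, num_batches=3)
print(len(batches), [len(b) for b in batches])
```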

t5x.interactive_model.get_dataset_from_natural_text_examples(examples, preprocessors, task_feature_lengths, features)

Returns a tf.data.Dataset from a list of examples.

Parameters:
  • examples – a single batch of examples that should be transformed into a tf.data.Dataset. The examples can either take the form of a string (e.g., a single input for inference), or a dictionary mapping “input”/“target” to a string containing that element.

  • preprocessors – an optional list of functions that receive a tf.data.Dataset and return a tf.data.Dataset. These will be executed sequentially and the final dataset must include features matching the features argument.

  • task_feature_lengths – dictionary mapping feature key to maximum length (int) for that feature. If feature is longer than this length after preprocessing, the feature will be truncated. May be set to None to avoid truncation.

  • features – dictionary defining what features should be present in all examples.

Returns:

A tf.data.Dataset.
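The truncation rule for task_feature_lengths can be sketched on plain Python lists. This is not the library's implementation: truncate_features is a hypothetical helper, the feature names "inputs" and "targets" are assumed for illustration, and token IDs are represented as ordinary lists rather than tensors.

```python
from typing import Dict, List, Optional


def truncate_features(
    example: Dict[str, List[int]],
    task_feature_lengths: Optional[Dict[str, int]],
) -> Dict[str, List[int]]:
    """Hypothetical helper: truncates each feature to its maximum length."""
    if task_feature_lengths is None:
        # None disables truncation entirely.
        return dict(example)
    return {
        key: value[:task_feature_lengths[key]]
        if key in task_feature_lengths else value
        for key, value in example.items()
    }


tokenized = {"inputs": [1, 2, 3, 4, 5], "targets": [6, 7, 8]}
truncated = truncate_features(tokenized, {"inputs": 3, "targets": 8})
untouched = truncate_features(tokenized, None)
print(truncated, untouched)
```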

t5x.interactive_model.get_gin_config_from_interactive_model(interactive_model, script_type, task_name, partitioner_config_str, model_config_str, train_steps=1, imports_str='')

Converts an InteractiveModel instance into a Gin config string.

This function will be used to graduate people to the T5X/SeqIO-based train/infer/eval scripts.

Parameters:
  • interactive_model – an instance of the InteractiveModel.

  • script_type – which T5X script the Gin config should function with.

  • task_name – the name of the SeqIO task to be used.

  • partitioner_config_str – a string that defines the Partitioner object in the Gin config.

  • model_config_str – a string that defines the Model object in the Gin config.

  • train_steps – the number of steps to train for, only used if FINETUNING or PRETRAINING is selected as the script type.

  • imports_str – if the model_config_str or partitioner_config_str relies on some other files to be imported, these import statements can be included in the final Gin file by adding them to this string.

Returns:

A string that contains the full Gin file to be used for train/infer/eval.py.

t5x.interactive_model.get_seqio_task_from_examples(task_name, interactive_model, examples, preprocessors, metric_fns=None, add_to_registry=True)

Registers and returns a SeqIO task from the provided inputs.

This function will be used to graduate people to the T5X/SeqIO-based train/infer/eval scripts.

Parameters:
  • task_name – the name of the SeqIO task to be created and registered.

  • interactive_model – an instance of the InteractiveModel.

  • examples – a single batch of examples that should be transformed into a tf.data.Dataset. The examples can either take the form of a string (e.g., a single input for inference), or a dictionary mapping “input”/“target” to a string containing that element.

  • preprocessors – an optional list of functions that receive a tf.data.Dataset and return a tf.data.Dataset. These will be executed sequentially and the final dataset must include features matching self._features.

  • metric_fns – list(callable), an optional list of metric functions with a signature that matches one of three possible forms:
      - (targets, scores): scores refers to the score the model assigned the target sequence, given the input.
      - (targets, predictions)
      - (targets, predictions, aux_values): aux_values refers to a dictionary of auxiliary values that the model assigned to each sequence.

  • add_to_registry – if True, will register the new task.

Returns:

A SeqIO task.
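Since metric functions are distinguished by their signatures, it can help to check which of the three forms a function matches before attaching it to a task. The sketch below uses only the standard inspect module; classify_metric_fn and the labels it returns are hypothetical, and the exact matching logic used by SeqIO may differ.

```python
import inspect

# Hypothetical labels for the three documented signature forms.
_FORMS = {
    ("targets", "scores"): "score",
    ("targets", "predictions"): "predict",
    ("targets", "predictions", "aux_values"): "predict_with_aux",
}


def classify_metric_fn(metric_fn):
    """Hypothetical helper: maps a metric function to its signature form."""
    params = tuple(inspect.signature(metric_fn).parameters)
    if params not in _FORMS:
        raise ValueError(f"Unsupported metric signature: {params}")
    return _FORMS[params]


def mean_score(targets, scores):
    return {"mean_score": sum(scores) / len(scores)}


def accuracy(targets, predictions):
    correct = sum(t == p for t, p in zip(targets, predictions))
    return {"accuracy": correct / len(targets)}


def count_aux(targets, predictions, aux_values):
    return {"num_aux_keys": float(len(aux_values))}


kinds = [classify_metric_fn(fn) for fn in (mean_score, accuracy, count_aux)]
print(kinds)
```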