t5x.infer binary

This script runs inference on a T5X-compatible model.

class t5x.infer.FailFastThreadPoolExecutor(*args, **kwargs)

Wrapper for ThreadPoolExecutor that crashes main thread on exceptions.

NOTE: this class should be used only from the main thread.


check_for_exceptions(wait=False)

Raises any exceptions from completed futures on the main thread.

shutdown(*args, wait=False, **kwargs)

Clean-up the resources associated with the Executor.

It is safe to call this method several times; however, no other methods can be called after this one.

  • wait – If True then shutdown will not return until all running futures have finished executing and the resources used by the executor have been reclaimed.

  • cancel_futures – If True then shutdown will cancel all pending futures. Futures that are completed or running will not be cancelled.

submit(*args, **kwargs)

Submit function to threadpool, capturing the returned future.
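The fail-fast pattern described above can be sketched with the standard library alone. The class below is a hypothetical illustration of the behavior (capture futures on submit, re-raise worker exceptions on the main thread), not the T5X source:

```python
import concurrent.futures


class FailFastThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
    """Sketch: a ThreadPoolExecutor that surfaces worker exceptions
    on the main thread instead of letting them go unnoticed."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._incomplete_futures = []

    def submit(self, *args, **kwargs):
        """Submits to the pool, capturing the returned future."""
        future = super().submit(*args, **kwargs)
        self._incomplete_futures.append(future)
        return future

    def check_for_exceptions(self, wait=False):
        """Raises any exceptions from completed futures on the main thread."""
        still_incomplete = []
        for future in self._incomplete_futures:
            if wait or future.done():
                future.result()  # Re-raises the worker's exception, if any.
            else:
                still_incomplete.append(future)
        self._incomplete_futures = still_incomplete

    def shutdown(self, *args, wait=False, **kwargs):
        # Surface any pending worker exceptions before releasing resources.
        self.check_for_exceptions(wait=wait)
        super().shutdown(*args, wait=wait, **kwargs)
```

As the note above says, `check_for_exceptions` and `shutdown` must be called from the main thread for the fail-fast behavior to be useful.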

class t5x.infer.SummarizeConfigFn(*args, **kwargs)

t5x.infer.create_task_from_tfexample_file(paths, file_type, inputs_key, targets_key, features, task_id=None)

Registers ad-hoc Task for file-based dataset of TFExamples.

  • paths – Input file paths; all files should have type file_type and contain binary-serialized TFExample protos.

  • file_type – Input file type; e.g., ‘tfrecord’, ‘recordio’, ‘sstable’. For keyed formats like ‘sstable’, we ignore the keys and use only the values.

  • inputs_key – Name of the TFExample feature containing the input text for T5X. The value of this feature should be a UTF-8-encoded string.

  • targets_key – Optional name of a TFExample feature containing the target text (relevant only in scoring mode). The value of this feature should be a UTF-8-encoded string.

  • features – Should have entries for keys ‘inputs’ and (if targets_key is not None) ‘targets’, mapping to seqio.Feature objects that specify attributes like vocabulary, add_eos, etc. These attributes are used for preprocessing and featurizing the input text.

  • task_id – Task name identifier. By default, it is set to a unique and deterministic hash id; it can be overridden via this argument.


Returns: The name of the newly registered Task. This Task has a split named ‘infer’ that contains the preprocessed and featurized input dataset.
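The default task_id is described only as a "unique and deterministic hash id". One way such an identifier can be derived is sketched below; the function name and hashing scheme are illustrative assumptions, not the exact scheme t5x.infer uses:

```python
import hashlib


def default_task_id(paths, file_type, inputs_key, targets_key):
    """Hypothetical sketch: build a deterministic task name from the
    task's defining arguments, so re-registering the same file-based
    dataset always yields the same Task name."""
    fingerprint = "|".join(
        [",".join(sorted(paths)), file_type, inputs_key, str(targets_key)])
    digest = hashlib.md5(fingerprint.encode("utf-8")).hexdigest()[:16]
    return f"infer_task_{digest}"
```

Determinism matters here because the Task may be re-registered across restarts of the same inference job and must resolve to the same name.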

t5x.infer.infer(*, mode, model, dataset_cfg, restore_checkpoint_cfg, partitioner, output_dir, checkpoint_period, shard_id=0, num_shards=1, merge_chunked_results=True, write_fn=<function write_inferences_to_file>, checkpoint_ds_iter=True, train_state_initializer_cls=<class 't5x.utils.TrainStateInitializer'>, fallback_init_rng=None, merge_fn=<function merge_chunks_to_file>, summarize_config_fn=<function summarize_gin_config>, verify_matching_vocabs_fn=<function verify_matching_vocabs>, output_vocab_feature_name='targets', file_extension='jsonl', keep_aux_as_numpy=False, use_orbax=False)

Runs inference on the given model, writing results to output_dir.

  • mode – Either ‘predict’ to decode targets, ‘score’ to compute the log likelihood of given targets, or ‘predict_with_aux’ for both.

  • model – The model object to use for inference.

  • dataset_cfg – Specification for the dataset to infer based on.

  • restore_checkpoint_cfg – Specification for the model parameter checkpoint to load.

  • partitioner – Partitioner for model parameters and data across devices.

  • output_dir – Path to directory to write temporary files and final results.

  • checkpoint_period – Intermediate results and the dataset iterator are checkpointed at each multiple of this number of batches, enabling continuation after a failure.

  • shard_id – Index of dataset shard for this instance to use if splitting the work across multiple jobs.

  • num_shards – Total number of dataset shards to split dataset across.

  • merge_chunked_results – Whether to merge results of all chunks into a single json file.

  • write_fn – Callable used to serialize and write inferences out to files.

  • checkpoint_ds_iter – if True, checkpoints the dataset iterator every checkpoint_period batches to enable faster restores. This must be disabled for certain datasets, e.g. those backed by stateful iterators (such as from a seqio.FunctionTask), which cannot be checkpointed.

  • train_state_initializer_cls – t5x.utils.TrainStateInitializer class for initializing partitioned TrainState from checkpoints or scratch.

  • fallback_init_rng – A random seed used for parameter initialization during model re-loading when utils.RestoreCheckpointConfig.fallback_to_scratch is set to True. If None, parameter initialization is not allowed during model loading and having fallback_to_scratch enabled will result in an error.

  • merge_fn – Callable function used to merge inferences from multiple files.

  • summarize_config_fn – A function that takes in the model directory, an optional SummaryWriter, and the step number, and writes a summary of the configuration. SummaryWriter will be None in most cases.

  • verify_matching_vocabs_fn – Function to validate whether the task vocabulary matches the model vocabulary. Should raise an exception on error.

  • output_vocab_feature_name – The name of the feature corresponding to the output vocabulary.

  • file_extension – File extension used for the output file names.

  • keep_aux_as_numpy – whether to leave aux values as NumPy arrays; can be used to save space when saving bfloat16 values.

  • use_orbax – if True, uses Orbax for checkpointing. Experimental feature.
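The interaction of shard_id/num_shards with checkpoint_period can be illustrated with a small sketch. The interleaved shard assignment below is an assumption for illustration (the real sharding is applied through the dataset configuration); what it shows is that each job processes only its shard, in chunks whose boundaries are the checkpoint points:

```python
def shard_and_chunk(num_examples, shard_id, num_shards, checkpoint_period):
    """Sketch: split example indices across num_shards jobs, then group
    this job's shard into chunks of checkpoint_period batches, each chunk
    ending at a point where intermediate results would be checkpointed."""
    # Interleaved assignment of examples to this shard (an assumption).
    shard = [i for i in range(num_examples) if i % num_shards == shard_id]
    # Chunk boundaries correspond to checkpoint points.
    return [shard[start:start + checkpoint_period]
            for start in range(0, len(shard), checkpoint_period)]
```

After a failure, inference resumes from the last completed chunk rather than from the beginning of the shard.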

t5x.infer.merge_chunks_to_file(output_dir, output_fname, tmp_dir, step)

Merge the predictions from different chunks into a unified file.
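A minimal sketch of such a merge, assuming each chunk was written to a temporary directory as a .jsonl file: concatenate the chunks in sorted order into a single output file. The function name and file-naming conventions here are assumptions (the real function also receives the checkpoint step, which this sketch omits):

```python
import os


def merge_jsonl_chunks(output_dir, output_fname, tmp_dir):
    """Sketch: concatenate per-chunk JSONL files from tmp_dir, in sorted
    filename order, into a single file at output_dir/output_fname."""
    chunk_paths = sorted(
        os.path.join(tmp_dir, name)
        for name in os.listdir(tmp_dir) if name.endswith(".jsonl"))
    out_path = os.path.join(output_dir, output_fname)
    with open(out_path, "w") as out:
        for chunk_path in chunk_paths:
            with open(chunk_path) as chunk:
                out.writelines(chunk)  # JSONL: copy lines verbatim.
    return out_path
```

Sorting by filename preserves chunk order as long as chunk files are named with zero-padded indices.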

t5x.infer.update_measurement_series(series_name, step, value)

Not implemented externally.

t5x.infer.write_inferences_to_file(path, inferences, task_ds, mode, vocabulary=None, json_encoder_cls=<class 'seqio.loggers.TensorAndNumpyEncoder'>, include_all_inputs=False, input_fields_to_include=None, output_ids=False)

Write model predictions, along with pretokenized inputs, to a JSONL file.

  • path – File path to write to.

  • inferences – A tuple containing (predictions, aux_values). If mode is ‘predict’ then the predictions will be token IDs. If it’s ‘score’ then it’ll be a collection of scores. aux_values will be an empty dictionary unless mode is ‘predict_with_aux’, in which case it’ll contain the model’s auxiliary outputs.

  • task_ds – Original task dataset. Features from task with suffix _pretokenized are added to the outputs.

  • mode – Prediction mode, either ‘predict’, ‘score’ or ‘predict_with_aux’.

  • vocabulary – Task output vocabulary. Only used in ‘predict’ mode to decode predicted outputs into strings.

  • json_encoder_cls – a JSON encoder class used to customize JSON serialization via json.dumps.

  • include_all_inputs – if True, will include all model inputs in the output JSONL file (including raw tokens) in addition to the pretokenized inputs.

  • input_fields_to_include – List of input fields to include in the output JSONL file. This list should be None if include_all_inputs is set to True.

  • output_ids – if True, will output the token ID sequence for the output, in addition to the decoded text.
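The JSONL output shape can be sketched with the standard library: one JSON object per line, pairing each inference with its pretokenized input, with a pluggable encoder class in the spirit of json_encoder_cls. The function name and field names below are illustrative assumptions, not the exact keys emitted by t5x.infer.write_inferences_to_file:

```python
import json


def write_inferences_sketch(path, inferences, pretokenized_inputs,
                            json_encoder_cls=json.JSONEncoder):
    """Sketch: write one JSON object per line, pairing each inference with
    its pretokenized input; json_encoder_cls customizes serialization via
    the `cls` argument of json.dumps (field names are hypothetical)."""
    with open(path, "w") as f:
        for inp, pred in zip(pretokenized_inputs, inferences):
            record = {"inputs_pretokenized": inp, "prediction": pred}
            f.write(json.dumps(record, cls=json_encoder_cls) + "\n")
```

A custom encoder class is how non-JSON-native values (e.g. NumPy arrays in the real implementation, via seqio.loggers.TensorAndNumpyEncoder) are made serializable.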