diff --git a/trax/supervised/trainer_lib.py b/trax/supervised/trainer_lib.py
index b52d99524..bfc3d8ffc 100644
--- a/trax/supervised/trainer_lib.py
+++ b/trax/supervised/trainer_lib.py
@@ -82,7 +82,9 @@ def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
                output_dir=None, random_seed=None, n_devices=None,
                checkpoints_at=None, should_save_checkpoints=True,
                should_write_summaries=True,
-               metrics=None, checkpoint_highest=None, checkpoint_lowest=None):
+               metrics=None, checkpoint_highest=None,
+               checkpoint_lowest=None,
+               init_checkpoint=None):
 
     self._is_chief, _, self._n_devices, rng = (
         training.init_host_and_devices(n_devices, random_seed))
@@ -105,6 +107,10 @@ def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
     # Setup the model.
     model_train = model(mode='train')
     model_predict_eval = model(mode='eval')
+    # Initialize weights from a pre-trained checkpoint, e.g. for fine-tuning T5.
+    if init_checkpoint:
+      model_train.init_from_file(init_checkpoint, weights_only=True)
+      model_predict_eval.init_from_file(init_checkpoint, weights_only=True)
     self._model_with_loss = tl.Serial(model_train, loss_fn)
 
     # Setup state.
@@ -523,7 +529,8 @@ def train(output_dir,
           checkpoint_lowest=None,
           use_loop=True,
           loss_chunk_size=0,
-          use_memory_efficient_trainer=False):
+          use_memory_efficient_trainer=False,
+          init_checkpoint=None):
   """Train the model on the inputs.
 
   Args:
@@ -554,7 +561,8 @@ def train(output_dir,
     checkpoint_lowest: save the checkpoint lowest at this metric.
     use_loop: whether to use training.Loop instead of Trainer.
     loss_chunk_size: int, if > 0 chunk loss into these sizes to save memory.
     use_memory_efficient_trainer: whether to use memory-efficient trainer.
+    init_checkpoint: a checkpoint to initialize the model from, for fine-tuning.
 
   Returns:
     trax.TrainerState or training.Loop if use_loop is True
@@ -594,10 +602,17 @@ def train(output_dir,
     permanent_checkpoint_at = None
     if permanent_checkpoints_at is not None:
       permanent_checkpoint_at = (lambda step: step in permanent_checkpoints_at)
+
+    # Setup the model.
+    model_train = model(mode='train')
+    model_predict_eval = model(mode='eval')
+    if init_checkpoint:
+      model_train.init_from_file(init_checkpoint, weights_only=True)
+      model_predict_eval.init_from_file(init_checkpoint, weights_only=True)
     loop = training.Loop(
-        model(mode='train'),
+        model_train,
         [train_task],
-        eval_model=model(mode='eval'),
+        eval_model=model_predict_eval,
         eval_tasks=[eval_task],
         output_dir=output_dir,
         checkpoint_at=checkpoint_at,
@@ -624,7 +639,8 @@ def train(output_dir,
                           checkpoints_at=checkpoints_at,
                           metrics=metrics,
                           checkpoint_lowest=checkpoint_lowest,
-                          checkpoint_highest=checkpoint_highest)
+                          checkpoint_highest=checkpoint_highest,
+                          init_checkpoint=init_checkpoint)
 
   epoch_steps = [steps]  # Only training if eval_frequency is 0 or None
   if eval_frequency and eval_steps > 0:
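
Usage note: with this patch, fine-tuning can be started by passing the new `init_checkpoint` argument to `trainer_lib.train`, which loads weights into both the train and eval models via `init_from_file(..., weights_only=True)` before training begins. Below is a minimal sketch; `my_model`, `my_inputs`, and the paths are illustrative placeholders, not part of this patch:

```python
from trax.supervised import trainer_lib

# Hypothetical fine-tuning run. `my_model` must be a Trax model constructor
# (a callable accepting mode='train' or mode='eval') whose architecture matches
# the checkpoint, and `my_inputs` a Trax inputs object; both are assumed to be
# defined elsewhere.
loop = trainer_lib.train(
    output_dir='/tmp/finetune_run',  # where new checkpoints are written
    model=my_model,
    inputs=my_inputs,
    steps=1000,
    # New in this patch: initialize weights from a pre-trained checkpoint
    # (e.g. a T5 checkpoint) instead of starting from random initialization.
    init_checkpoint='/tmp/pretrained/model.pkl.gz',
)
```

Since `train` is gin-configurable, the same option should also be settable from a gin config, e.g. `train.init_checkpoint = '/tmp/pretrained/model.pkl.gz'`.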