ModelHooks¶
- class lightning.pytorch.core.hooks.ModelHooks[source]¶
- Bases: object
- Hooks to be used in LightningModule.
- configure_model()[source]¶
- Hook to create modules in a strategy- and precision-aware context.
- This is particularly useful when using sharded strategies (FSDP and DeepSpeed), where we would like to shard the model instantly to save memory and initialization time. For non-sharded strategies, you can choose to override this hook or to initialize your model under the init_module() context manager.
- This hook is called during each of the fit/val/test/predict stages in the same process, so ensure that the implementation of this hook is idempotent, i.e., after the first time the hook is called, subsequent calls to it should be a no-op.
- Return type: None
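- For illustration (a minimal sketch, not part of the original docs): the layers below and the self.model attribute (assumed to be initialized to None in __init__) are placeholders; the guard makes repeated calls a no-op, as required.

      import torch.nn as nn

      def configure_model(self):
          # assumes `self.model = None` was set in __init__; this guard keeps the hook idempotent
          if self.model is not None:
              return
          # modules created here are instantiated in the strategy/precision-aware
          # context, e.g. sharded immediately under FSDP
          self.model = nn.Sequential(
              nn.Linear(32, 64),
              nn.ReLU(),
              nn.Linear(64, 2),
          )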
 - configure_sharded_model()[source]¶
- Deprecated. Use configure_model() instead.
- Return type: None
 - on_after_backward()[source]¶
- Called after loss.backward() and before optimizers are stepped.
- Return type: None
- Note: If using native AMP, the gradients will not be unscaled at this point. Use on_before_optimizer_step if you need the unscaled gradients.
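- For illustration (a sketch, not from the original docs): logging the overall gradient norm from this hook; under native AMP these values are still scaled, and the metric name grad_norm_total is an assumption.

      import torch

      def on_after_backward(self):
          # gradients are populated here, but still scaled if native AMP is used
          grads = [p.grad.detach().norm() for p in self.parameters() if p.grad is not None]
          if grads:
              self.log("grad_norm_total", torch.stack(grads).norm())  # illustrative metric name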
 - on_before_optimizer_step(optimizer)[source]¶
- Called before optimizer.step().
- If using gradient accumulation, the hook is called once the gradients have been accumulated. See: accumulate_grad_batches.
- If using AMP, the loss will be unscaled before calling this hook. See these docs for more information on the scaling of gradients.
- If clipping gradients, the gradients will not have been clipped yet.
- Example:

      def on_before_optimizer_step(self, optimizer):
          # example to inspect gradient information in tensorboard
          if self.trainer.global_step % 25 == 0:  # don't make the tf file huge
              for k, v in self.named_parameters():
                  self.logger.experiment.add_histogram(
                      tag=k, values=v.grad, global_step=self.trainer.global_step
                  )
 - on_before_zero_grad(optimizer)[source]¶
- Called after training_step() and before optimizer.zero_grad().
- Called in the training loop after taking an optimizer step and before zeroing grads. Good place to inspect weight information with weights updated.
- This is where it is called:

      for optimizer in optimizers:
          out = training_step(...)

          model.on_before_zero_grad(optimizer)  # < ---- called here
          optimizer.zero_grad()

          backward()
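- As an illustrative sketch (not from the original docs) of inspecting the freshly updated weights here; the metric name weight_norm_total is an assumption.

      import torch

      def on_before_zero_grad(self, optimizer):
          # the optimizer has just stepped, so parameters hold their updated values
          with torch.no_grad():
              total = torch.stack([p.detach().norm() for p in self.parameters()]).norm()
          self.log("weight_norm_total", total)  # illustrative metric name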
 - on_fit_end()[source]¶
- Called at the very end of fit.
- If on DDP, it is called on every process.
- Return type: None
 - on_fit_start()[source]¶
- Called at the very beginning of fit.
- If on DDP, it is called on every process.
- Return type: None
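- Because this hook runs on every process under DDP, a common pattern (shown as an illustrative sketch, not from the original docs) is to guard one-time work behind the global-zero check:

      def on_fit_start(self):
          # runs on every DDP process; guard side effects that should happen only once
          if self.trainer.is_global_zero:
              print("fit is starting (main process only)")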
 - on_predict_batch_end(outputs, batch, batch_idx, dataloader_idx=0)[source]¶
- Called in the predict loop after the batch. 
 - on_predict_batch_start(batch, batch_idx, dataloader_idx=0)[source]¶
- Called in the predict loop before anything happens for that batch. 
 - on_predict_model_eval()[source]¶
- Called when the predict loop starts.
- The predict loop by default calls .eval() on the LightningModule before it starts. Override this hook to change the behavior.
- Return type: None
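- As an illustrative sketch (not from the original docs) of overriding the default behavior, for example to keep dropout layers active during prediction; whether this is desirable depends entirely on your use case.

      import torch.nn as nn

      def on_predict_model_eval(self):
          # reproduce the default: put the module in eval mode
          self.eval()
          # illustrative tweak: keep dropout active (e.g. for Monte Carlo dropout)
          for m in self.modules():
              if isinstance(m, nn.Dropout):
                  m.train()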
 - on_test_batch_end(outputs, batch, batch_idx, dataloader_idx=0)[source]¶
- Called in the test loop after the batch. 
 - on_test_batch_start(batch, batch_idx, dataloader_idx=0)[source]¶
- Called in the test loop before anything happens for that batch. 
 - on_test_epoch_start()[source]¶
- Called in the test loop at the very beginning of the epoch.
- Return type: None
 - on_test_model_eval()[source]¶
- Called when the test loop starts.
- The test loop by default calls .eval() on the LightningModule before it starts. Override this hook to change the behavior. See also on_test_model_train().
- Return type: None
 - on_test_model_train()[source]¶
- Called when the test loop ends.
- The test loop by default restores the training mode of the LightningModule to what it was before testing started. Override this hook to change the behavior. See also on_test_model_eval().
- Return type: None
 - on_train_batch_end(outputs, batch, batch_idx)[source]¶
- Called in the training loop after the batch.
- Parameters:
  - outputs – The outputs of training_step().
  - batch – The batched data as it is returned by the training dataloader.
  - batch_idx – The index of the batch.
- Return type: None
- Note: The value outputs["loss"] here will be the normalized value w.r.t. accumulate_grad_batches of the loss returned from training_step.
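- For illustration (a sketch, not from the original docs): logging the already-normalized loss from this hook; the metric name train_loss_per_batch is an assumption.

      def on_train_batch_end(self, outputs, batch, batch_idx):
          # outputs["loss"] is already divided by accumulate_grad_batches
          if isinstance(outputs, dict) and "loss" in outputs:
              self.log("train_loss_per_batch", outputs["loss"])  # illustrative metric name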
 - on_train_batch_start(batch, batch_idx)[source]¶
- Called in the training loop before anything happens for that batch.
- If you return -1 here, you will skip training for the rest of the current epoch.
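- As an illustrative sketch (not from the original docs): stopping the rest of the epoch once a time budget is exceeded; epoch_start_time and epoch_time_budget are hypothetical attributes assumed to be set elsewhere (e.g. in __init__ and on_train_epoch_start).

      import time

      def on_train_batch_start(self, batch, batch_idx):
          # `self.epoch_start_time` and `self.epoch_time_budget` are hypothetical attributes
          if time.monotonic() - self.epoch_start_time > self.epoch_time_budget:
              return -1  # skip training for the rest of the current epoch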
 - on_train_end()[source]¶
- Called at the end of training, before the logger experiment is closed.
- Return type: None
 - on_train_epoch_end()[source]¶
- Called in the training loop at the very end of the epoch.
- To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the LightningModule and access them in this hook:

      class MyLightningModule(L.LightningModule):
          def __init__(self):
              super().__init__()
              self.training_step_outputs = []

          def training_step(self, batch, batch_idx):
              loss = ...
              self.training_step_outputs.append(loss)
              return loss

          def on_train_epoch_end(self):
              # do something with all training_step outputs, for example:
              epoch_mean = torch.stack(self.training_step_outputs).mean()
              self.log("training_epoch_mean", epoch_mean)
              # free up the memory
              self.training_step_outputs.clear()

- Return type: None
 - on_train_epoch_start()[source]¶
- Called in the training loop at the very beginning of the epoch.
- Return type: None
 - on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx=0)[source]¶
- Called in the validation loop after the batch. 
 - on_validation_batch_start(batch, batch_idx, dataloader_idx=0)[source]¶
- Called in the validation loop before anything happens for that batch. 
 - on_validation_epoch_end()[source]¶
- Called in the validation loop at the very end of the epoch.
- Return type: None
 - on_validation_epoch_start()[source]¶
- Called in the validation loop at the very beginning of the epoch.
- Return type: None
 - on_validation_model_eval()[source]¶
- Called when the validation loop starts.
- The validation loop by default calls .eval() on the LightningModule before it starts. Override this hook to change the behavior. See also on_validation_model_train().
- Return type: None
 - on_validation_model_train()[source]¶
- Called when the validation loop ends.
- The validation loop by default restores the training mode of the LightningModule to what it was before validation started. Override this hook to change the behavior. See also on_validation_model_eval().
- Return type: None