BasePredictionWriter
- class lightning.pytorch.callbacks.BasePredictionWriter(write_interval='batch')
- Bases: Callback

  Base class to implement how the predictions should be stored.

  Example:

      import os

      import torch
      from lightning.pytorch import Trainer
      from lightning.pytorch.callbacks import BasePredictionWriter
      from lightning.pytorch.demos.boring_classes import BoringModel


      class CustomWriter(BasePredictionWriter):
          def __init__(self, output_dir, write_interval):
              super().__init__(write_interval)
              self.output_dir = output_dir

          def write_on_batch_end(
              self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx
          ):
              # `dataloader_idx` is an int, so convert it before using it as a path component
              torch.save(prediction, os.path.join(self.output_dir, str(dataloader_idx), f"{batch_idx}.pt"))

          def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
              torch.save(predictions, os.path.join(self.output_dir, "predictions.pt"))


      pred_writer = CustomWriter(output_dir="pred_path", write_interval="epoch")
      trainer = Trainer(callbacks=[pred_writer])
      model = BoringModel()
      trainer.predict(model, return_predictions=False)

  Example:

      # multi-device inference example

      import os

      import torch
      from lightning.pytorch import Trainer
      from lightning.pytorch.callbacks import BasePredictionWriter
      from lightning.pytorch.demos.boring_classes import BoringModel


      class CustomWriter(BasePredictionWriter):
          def __init__(self, output_dir, write_interval):
              super().__init__(write_interval)
              self.output_dir = output_dir

          def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
              # this will create N (num processes) files in `output_dir`, each containing
              # the predictions of its respective rank
              torch.save(predictions, os.path.join(self.output_dir, f"predictions_{trainer.global_rank}.pt"))

              # optionally, you can also save `batch_indices` to get the information about the data index
              # from your prediction data
              torch.save(batch_indices, os.path.join(self.output_dir, f"batch_indices_{trainer.global_rank}.pt"))


      # or you can set `write_interval="batch"` and override `write_on_batch_end` to save
      # predictions at batch level
      pred_writer = CustomWriter(output_dir="pred_path", write_interval="epoch")
      trainer = Trainer(accelerator="gpu", strategy="ddp", devices=8, callbacks=[pred_writer])
      model = BoringModel()
      trainer.predict(model, return_predictions=False)
on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)
  Called when the predict batch ends.

  Return type: None
 
 - setup(trainer, pl_module, stage)[source]¶
  Called when fit, validate, test, predict, or tune begins.

  Return type: None