ModelPruning
- class lightning.pytorch.callbacks.ModelPruning(pruning_fn, parameters_to_prune=(), parameter_names=None, use_global_unstructured=True, amount=0.5, apply_pruning=True, make_pruning_permanent=True, use_lottery_ticket_hypothesis=True, resample_parameters=False, pruning_dim=None, pruning_norm=None, verbose=0, prune_on_train_epoch_end=True)
- Bases: Callback
- Model pruning Callback, using PyTorch’s prune utilities. This callback is responsible for pruning network parameters during training.
- To learn more about pruning with PyTorch, please take a look at this tutorial.
- Warning: This is an experimental feature.

      parameters_to_prune = [(model.mlp_1, "weight"), (model.mlp_2, "weight")]

      trainer = Trainer(
          callbacks=[
              ModelPruning(
                  pruning_fn="l1_unstructured",
                  parameters_to_prune=parameters_to_prune,
                  amount=0.01,
                  use_global_unstructured=True,
              )
          ]
      )

- When parameters_to_prune is None or empty (the default), it is populated with all parameters from the model. Override filter_parameters_to_prune to filter which nn.Module instances are pruned.
- Parameters:
- pruning_fn (Union[Callable, str]) – Function from the torch.nn.utils.prune module or your own PyTorch BasePruningMethod subclass. Can also be a string, e.g. "l1_unstructured". See the PyTorch docs for more details.
- parameters_to_prune (Sequence[tuple[Module, str]]) – List of tuples (nn.Module, "parameter_name_string").
- parameter_names (Optional[list[str]]) – List of parameter names to be pruned from the nn.Module. Can be either "weight" or "bias".
- use_global_unstructured (bool) – Whether to apply pruning globally on the model. If parameters_to_prune is provided, global unstructured pruning is restricted to them.
- amount (Union[int, float, Callable[[int], Union[int, float]]]) – Quantity of parameters to prune:
  - float. Between 0.0 and 1.0. Represents the fraction of parameters to prune.
  - int. Represents the absolute number of parameters to prune.
  - Callable. For dynamic values. Will be called every epoch with the epoch number. Should return an int or float.
 
- apply_pruning (Union[bool, Callable[[int], bool]]) – Whether to apply pruning:
  - bool. Always apply it or not.
  - Callable[[epoch], bool]. For dynamic values. Will be called every epoch.
 
- make_pruning_permanent (bool) – Whether to remove all reparameterization pre-hooks and apply masks when training ends or the model is saved.
- use_lottery_ticket_hypothesis (Union[bool, Callable[[int], bool]]) – See The lottery ticket hypothesis:
  - bool. Whether to apply it or not.
  - Callable[[epoch], bool]. For dynamic values. Will be called every epoch.
 
- resample_parameters (bool) – Used with use_lottery_ticket_hypothesis. If True, the model parameters will be resampled; otherwise, the exact original parameters will be used.
- pruning_dim (Optional[int]) – If you are using a structured pruning method, you need to specify the dimension.
- pruning_norm (Optional[int]) – If you are using ln_structured, you need to specify the norm.
- verbose (int) – Verbosity level. 0 to disable, 1 to log overall sparsity, 2 to log per-layer sparsity.
- prune_on_train_epoch_end (bool) – Whether to apply pruning at the end of the training epoch. If this is False, then the check runs at the end of the validation epoch.
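
Both amount and apply_pruning accept a callable that receives the current epoch. The following is a minimal sketch of what such callables might look like. The function names `amount_schedule` and `prune_every_other_epoch`, and the thresholds they use, are hypothetical and chosen only for illustration.

```python
def amount_schedule(epoch: int) -> float:
    """Hypothetical schedule: prune a larger fraction early, then taper off."""
    if epoch < 5:
        return 0.2
    if epoch < 10:
        return 0.1
    return 0.05


def prune_every_other_epoch(epoch: int) -> bool:
    """Hypothetical gate: apply pruning only on even-numbered epochs."""
    return epoch % 2 == 0


# Assumed usage (not executed here), following the constructor signature above:
# ModelPruning(
#     "l1_unstructured",
#     amount=amount_schedule,
#     apply_pruning=prune_every_other_epoch,
# )
```

Keeping the schedules as plain functions of the epoch makes them easy to check in isolation before handing them to the callback.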
 
- Raises:
- MisconfigurationException – If parameter_names is neither "weight" nor "bias", if the provided pruning_fn is not supported, if pruning_dim is not provided for a structured pruning method, if pruning_norm is not provided when using "ln_structured", if pruning_fn is neither str nor torch.nn.utils.prune.BasePruningMethod, or if amount is none of int, float and Callable.
 - apply_lottery_ticket_hypothesis()
- Lottery ticket hypothesis algorithm (see page 2 of the paper):
  1. Randomly initialize a neural network \(f(x; \theta_0)\) (where \(\theta_0 \sim \mathcal{D}_\theta\)).
  2. Train the network for \(j\) iterations, arriving at parameters \(\theta_j\).
  3. Prune \(p\%\) of the parameters in \(\theta_j\), creating a mask \(m\).
  4. Reset the remaining parameters to their values in \(\theta_0\), creating the winning ticket \(f(x; m \odot \theta_0)\).
- This function implements step 4.
- The resample_parameters argument can be used to reset the parameters with a new \(\theta_z \sim \mathcal{D}_\theta\).
- Return type: None
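
The four steps above can be sketched in plain PyTorch on a single layer. This is an illustration under stated assumptions, not the callback's own code: the layer size, the random perturbation standing in for training, and the 50% pruning amount are all made up for the example.

```python
import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)

# Step 1: randomly initialize a layer and keep a copy of theta_0.
layer = nn.Linear(4, 3)
theta_0 = layer.weight.detach().clone()

# Step 2: stand-in for training (assumption): perturb the weights to get theta_j.
with torch.no_grad():
    layer.weight.add_(0.1 * torch.randn_like(layer.weight))

# Step 3: prune 50% of the weights by L1 magnitude. This registers
# `weight_orig` (the unmasked parameter) and `weight_mask` (the mask m).
prune.l1_unstructured(layer, name="weight", amount=0.5)

# Step 4: reset the underlying parameter to its values in theta_0.
with torch.no_grad():
    layer.weight_orig.copy_(theta_0)

# `layer.weight` is only refreshed by the forward pre-hook on the next
# forward call, so compute the winning ticket m * theta_0 explicitly here.
winning_ticket = layer.weight_mask * layer.weight_orig
```

The mask is computed from the trained weights \(\theta_j\), while the values that survive come from \(\theta_0\), which is the point of step 4.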
 - filter_parameters_to_prune(parameters_to_prune=())
- This function can be overridden to control which module to prune. 
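
As a rough sketch of the kind of filtering an override might perform, the standalone function below keeps only the weights of nn.Linear modules from a list of (module, name) tuples. The function name `keep_linear_weights` and the toy model are hypothetical; only torch.nn is used so the logic can be run without a Trainer.

```python
from torch import nn


def keep_linear_weights(parameters_to_prune):
    """Hypothetical filter: keep only the weights of nn.Linear modules."""
    return [
        (module, name)
        for module, name in parameters_to_prune
        if isinstance(module, nn.Linear) and name == "weight"
    ]


# Toy model (assumption): one convolution, one parameter-free layer, one linear layer.
model = nn.Sequential(nn.Conv2d(1, 2, 3), nn.Flatten(), nn.Linear(8, 4))

# Candidate tuples in the same (nn.Module, "parameter_name") form the callback uses.
candidates = [(m, "weight") for m in model if hasattr(m, "weight")]
filtered = keep_linear_weights(candidates)
```

In a ModelPruning subclass, an overridden filter_parameters_to_prune could return the result of a function like this for the tuples it receives.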
 - make_pruning_permanent(module)
- Removes pruning buffers from any pruned modules.
- Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/utils/prune.py#L1118-L1122
- Return type: None
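
To illustrate what making pruning permanent means at the PyTorch level, the sketch below uses torch.nn.utils.prune.remove on a single layer. It shows the plain-PyTorch analogue rather than the callback's own implementation, and the layer size and pruning amount are arbitrary choices for the example.

```python
from torch import nn
from torch.nn.utils import prune

layer = nn.Linear(4, 3)
prune.l1_unstructured(layer, name="weight", amount=0.5)

# While pruning is active, the module carries the reparameterization:
# a `weight_orig` parameter and a `weight_mask` buffer.
has_orig_before = hasattr(layer, "weight_orig")
has_mask_before = "weight_mask" in dict(layer.named_buffers())

# prune.remove bakes the mask into `weight` and drops the pre-hook,
# `weight_orig`, and `weight_mask`.
prune.remove(layer, "weight")

has_orig_after = hasattr(layer, "weight_orig")
has_mask_after = "weight_mask" in dict(layer.named_buffers())
zeros = int((layer.weight == 0).sum())
```

After removal, `weight` is an ordinary parameter again, with the pruned entries stored as zeros.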
 
 - on_save_checkpoint(trainer, pl_module, checkpoint)
- Called when saving a checkpoint to give you a chance to store anything else you might want to save. 
 - on_train_epoch_end(trainer, pl_module)
- Called when the train epoch ends.
- To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:

      import lightning as L
      import torch


      class MyLightningModule(L.LightningModule):
          def __init__(self):
              super().__init__()
              self.training_step_outputs = []

          def training_step(self, batch, batch_idx):
              loss = ...
              self.training_step_outputs.append(loss)
              return loss


      class MyCallback(L.Callback):
          def on_train_epoch_end(self, trainer, pl_module):
              # do something with all training_step outputs, for example:
              epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
              pl_module.log("training_epoch_mean", epoch_mean)
              # free up the memory
              pl_module.training_step_outputs.clear()

- Return type: None