nncf.torch.sparsity.rb.algo

Classes

RBSparsityController

Controller for the regularization-based (RB) sparsity algorithm in PyTorch.

class nncf.torch.sparsity.rb.algo.RBSparsityController(target_model, sparsified_module_info, config)

Bases: nncf.torch.sparsity.base_algo.BaseSparsityAlgoController

Controller for the regularization-based (RB) sparsity algorithm in PyTorch.

Parameters:
  • target_model (nncf.torch.nncf_network.NNCFNetwork)

  • sparsified_module_info (List[nncf.torch.sparsity.base_algo.SparseModuleInfo])

  • config (nncf.NNCFConfig)

property current_sparsity_level: float

Returns the current sparsity level of the underlying model.

Return type:

float
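
As a rough illustration of what a sparsity level measures, the sketch below computes the fraction of zeroed entries across a set of binary masks. It assumes the level is a zero-count ratio over all sparsified weights; the helper name `sparsity_level` and the nested-list mask layout are hypothetical and are not part of NNCF.

```python
from typing import Sequence


def sparsity_level(masks: Sequence[Sequence[int]]) -> float:
    """Fraction of mask entries equal to 0 across all layers (hypothetical helper)."""
    total = sum(len(mask) for mask in masks)
    if total == 0:
        return 0.0  # no sparsified weights, so nothing is pruned
    zeros = sum(1 for mask in masks for value in mask if value == 0)
    return zeros / total


# Two layers: 2 of 4 weights masked out in the first, 1 of 4 in the second.
layer_masks = [[1, 0, 0, 1], [1, 1, 1, 0]]
level = sparsity_level(layer_masks)
```

With a real controller, the value would be read from the `current_sparsity_level` property instead of being computed by hand.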

property compression_rate

Returns the compression rate as a float in the range from 0 to 1 (e.g. the sparsity level, or the ratio of filters pruned).

set_sparsity_level(sparsity_level, target_sparsified_module_info=None)

Sets the sparsity level that should be applied to the model’s weights.

Parameters:
  • sparsity_level – Sparsity level that should be applied to the model’s weights.

  • target_sparsified_module_info (nncf.torch.sparsity.base_algo.SparseModuleInfo)
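
One plausible reason to call a setter like this is to raise sparsity gradually during training. The sketch below is a minimal illustration under stated assumptions: `linear_sparsity` and `RecordingController` are hypothetical stand-ins (the latter only mimics the `set_sparsity_level(sparsity_level)` call shape documented above), and the linear ramp is just one possible schedule, not something this page prescribes.

```python
def linear_sparsity(epoch: int, num_epochs: int, init: float, target: float) -> float:
    """Linearly interpolate from `init` to `target` over `num_epochs` (hypothetical schedule)."""
    if num_epochs <= 1:
        return target
    progress = min(epoch, num_epochs - 1) / (num_epochs - 1)
    return init + (target - init) * progress


class RecordingController:
    """Stand-in for a sparsity controller; it only records the requested levels."""

    def __init__(self) -> None:
        self.levels = []

    def set_sparsity_level(self, sparsity_level: float) -> None:
        self.levels.append(sparsity_level)


ctrl = RecordingController()
for epoch in range(5):
    # Request a slightly higher sparsity level at the start of every epoch.
    ctrl.set_sparsity_level(linear_sparsity(epoch, num_epochs=5, init=0.1, target=0.5))
```

In real training code the stand-in would be replaced by the controller instance, and the schedule would normally come from the compression configuration rather than a hand-written loop.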

compression_stage()

Returns the compression stage. Use it when saving best checkpoints to distinguish between uncompressed, partially compressed, and fully compressed models.

Returns:

The compression stage of the target model.

Return type:

nncf.api.compression.CompressionStage
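
The checkpointing use case described above can be sketched as follows. This is an illustration only: `Stage` is a hypothetical stand-in enum (it assumes stages can be ordered from uncompressed to fully compressed), and `is_better_checkpoint` is an invented helper, not an NNCF function.

```python
from enum import IntEnum


class Stage(IntEnum):
    """Stand-in for a compression stage; assumes a more compressed stage compares greater."""
    UNCOMPRESSED = 0
    PARTIALLY_COMPRESSED = 1
    FULLY_COMPRESSED = 2


def is_better_checkpoint(stage: Stage, accuracy: float,
                         best_stage: Stage, best_accuracy: float) -> bool:
    """Prefer a later compression stage; within the same stage, prefer higher accuracy."""
    if stage != best_stage:
        return stage > best_stage
    return accuracy > best_accuracy


# A fully compressed model replaces a more accurate but only partially compressed one.
replaces = is_better_checkpoint(Stage.FULLY_COMPRESSED, 0.71, Stage.PARTIALLY_COMPRESSED, 0.74)
# Within the same stage, a lower accuracy does not replace the current best.
keeps = is_better_checkpoint(Stage.FULLY_COMPRESSED, 0.70, Stage.FULLY_COMPRESSED, 0.71)
```

Comparing stages first avoids keeping an uncompressed checkpoint as "best" merely because its accuracy was measured before compression took effect.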

freeze()

Freezes all sparsity masks. Sparsity masks will not be trained after calling this method.

distributed()

Call this method when training is distributed across multiple processes (i.e. after the model is wrapped with DistributedDataParallel). Any special preparation the algorithm needs to support distributed training properly should be made inside this method.

statistics(quickly_collected_only=False)

Returns a Statistics class instance that contains compression algorithm statistics.

Parameters:

quickly_collected_only – If enabled, collects only the statistics that are cheap to compute. This is helpful when statistics need to be tracked on every training batch/step/iteration.

Return type:

nncf.common.statistics.NNCFStatistics
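
The intended use of `quickly_collected_only` can be sketched as a logging pattern: request only the cheap statistics on every step and the full set once per epoch. `FakeController` below is a hypothetical stand-in that mimics only the call shape of `statistics()`; it returns plain dictionaries with made-up values rather than an `NNCFStatistics` instance.

```python
class FakeController:
    """Stand-in exposing only the `statistics(quickly_collected_only)` call shape."""

    def __init__(self) -> None:
        self.quick_calls = 0
        self.full_calls = 0

    def statistics(self, quickly_collected_only: bool = False) -> dict:
        if quickly_collected_only:
            self.quick_calls += 1
            return {"sparsity_level": 0.5}  # placeholder value
        self.full_calls += 1
        return {"sparsity_level": 0.5, "per_layer": [0.4, 0.6]}  # placeholder values


ctrl = FakeController()
for epoch in range(2):
    for step in range(3):
        # Per-step logging: ask only for statistics that are cheap to compute.
        step_stats = ctrl.statistics(quickly_collected_only=True)
    # Per-epoch logging: the full statistics are affordable here.
    epoch_stats = ctrl.statistics()
```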