otx.algorithms.common.adapters.torch.utils

Utils for modules using torch.

Functions

model_from_timm(model)

Check whether a model comes from the timm module.

convert_sync_batchnorm(model)

Convert BatchNorm layers to SyncBatchNorm layers.

sync_batchnorm_2_batchnorm(module[, dim])

Convert the SyncBatchNorm layers in a model to regular BatchNorm layers.

Classes

BsSearchAlgo(train_func, train_func_kwargs, ...)

Algorithm class to find optimal batch size.

class otx.algorithms.common.adapters.torch.utils.BsSearchAlgo(train_func: Callable, train_func_kwargs: Dict[str, Any], default_bs: int, max_bs: int)

Bases: object

Algorithm class to find optimal batch size.

Parameters:
  • train_func (Callable[[int], None]) – Training function that takes a single argument, the batch size.

  • train_func_kwargs (Dict[str, Any]) – Keyword arguments for train_func.

  • default_bs (int) – Default batch size. Must be greater than 0.

  • max_bs (int) – Maximum batch size. Must be greater than 0.

auto_decrease_batch_size() → int

Decrease the batch size if the default batch size does not fit on the current device.

Returns:

A proper batch size, decreased from the default if the default does not fit.

Return type:

int
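The page does not describe how the decrease is performed. The following is a minimal sketch of one plausible strategy, a binary search between the largest batch size known to fit and the smallest known not to fit; it is not the OTX implementation. `fits` is a hypothetical predicate standing in for "a short training run with this batch size finished without running out of memory", and `auto_decrease_batch_size_sketch` is a hypothetical name.

```python
from typing import Callable


def auto_decrease_batch_size_sketch(fits: Callable[[int], bool], default_bs: int) -> int:
    """Return the largest batch size <= default_bs for which ``fits`` is True.

    ``fits`` is a hypothetical stand-in for "one short training run with this
    batch size finished without running out of memory".
    """
    if fits(default_bs):
        return default_bs  # the default already fits; nothing to decrease
    lowest_failing = default_bs  # known to run out of memory
    highest_fitting = 0  # nothing known to fit yet
    # Binary search between the largest known-good and smallest known-bad sizes.
    while lowest_failing - highest_fitting > 1:
        candidate = (lowest_failing + highest_fitting) // 2
        if fits(candidate):
            highest_fitting = candidate
        else:
            lowest_failing = candidate
    if highest_fitting == 0:
        # Choice made by this sketch: give up when no batch size fits at all.
        raise RuntimeError("No batch size fits on the current device.")
    return highest_fitting


# Simulated device on which anything above 37 samples runs out of memory.
print(auto_decrease_batch_size_sketch(lambda bs: bs <= 37, default_bs=64))  # 37
```

Binary search needs only a logarithmic number of trial runs, which matters when each probe is a real training step.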

find_big_enough_batch_size(drop_last: bool = False) → int

Find a big enough batch size.

This function finds a big enough batch size by training with various batch sizes. It estimates a suitable batch size using an equation fitted to the training history. It is called "big enough" because it does not search for the maximum batch size, only for one whose memory usage lies between a lower and an upper bound.

Parameters:

drop_last (bool) – Whether to drop the last incomplete batch.

Raises:

RuntimeError – If training cannot run even with batch size 2.

Returns:

Big enough batch size.

Return type:

int
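The docstring says the batch size is estimated with an equation fitted to the training history. The sketch below assumes that equation is a straight line (peak memory ≈ slope × batch size + intercept) fitted by least squares, and that each new estimate aims at the middle of the [lower, upper] memory band; both are assumptions. The `measure` callback (peak memory of a short training run), `fit_line` and `find_big_enough_sketch` are hypothetical names, not the OTX API.

```python
from typing import Callable, List, Tuple


def fit_line(history: List[Tuple[int, float]]) -> Tuple[float, float]:
    """Least-squares fit of memory = slope * batch_size + intercept."""
    n = len(history)
    mean_x = sum(bs for bs, _ in history) / n
    mean_y = sum(mem for _, mem in history) / n
    var_x = sum((bs - mean_x) ** 2 for bs, _ in history)
    cov_xy = sum((bs - mean_x) * (mem - mean_y) for bs, mem in history)
    slope = cov_xy / var_x
    return slope, mean_y - slope * mean_x


def find_big_enough_sketch(
    measure: Callable[[int], float],
    max_bs: int,
    lower: float,
    upper: float,
    max_trials: int = 10,
) -> int:
    """Return a batch size whose measured memory falls in [lower, upper] if one is found."""
    history = [(2, measure(2))]  # smallest batch size the search accepts
    if history[0][1] > upper:
        raise RuntimeError("Training cannot run even with batch size 2.")
    bs = min(4, max_bs)
    if bs == 2:
        return 2
    target = (lower + upper) / 2  # aim for the middle of the memory band
    for _ in range(max_trials):
        mem = measure(bs)
        history.append((bs, mem))
        if lower <= mem <= upper:
            return bs  # "big enough": memory use is inside the band
        slope, intercept = fit_line(history)
        if slope <= 0:
            next_bs = max_bs  # memory does not grow with batch size; jump to the cap
        else:
            next_bs = int((target - intercept) / slope)
        next_bs = max(2, min(max_bs, next_bs))
        if next_bs == bs:
            break  # the estimate stopped moving
        bs = next_bs
    # Fall back to the largest tried batch size that stayed under the upper bound.
    return max(b for b, m in history if m <= upper)


# Simulated memory model: 1.0 unit of overhead plus 0.5 units per sample.
print(find_big_enough_sketch(lambda bs: 1.0 + 0.5 * bs, max_bs=1000, lower=30.0, upper=35.0))  # 63
```

Stopping anywhere inside the band, rather than pushing to the exact limit, keeps the number of trial runs small and leaves some memory headroom.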

otx.algorithms.common.adapters.torch.utils.convert_sync_batchnorm(model: Module)

Convert BatchNorm layers to SyncBatchNorm layers.

Parameters:

model (Module) – Model containing BatchNorm layers.
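The page does not show how the conversion is done. For reference, PyTorch ships a helper, `torch.nn.SyncBatchNorm.convert_sync_batchnorm`, that recursively replaces every `BatchNorm*D` layer with `SyncBatchNorm`; whether the OTX wrapper delegates to it is an assumption. The toy model below is illustrative only.

```python
from torch import nn

model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU())
# Recursively swaps each BatchNorm*D layer for SyncBatchNorm, copying the affine
# parameters and running statistics; other layers are left untouched.
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(type(sync_model[1]).__name__)  # SyncBatchNorm
```

SyncBatchNorm is meant for distributed data-parallel training, where it computes batch statistics across all processes instead of per device.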

otx.algorithms.common.adapters.torch.utils.model_from_timm(model: Module) → bool

Check whether a model comes from the timm module.

Parameters:

model (Module) – Model to check for timm origin.

Returns:

Whether the model comes from timm.

Return type:

bool
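One minimal way to implement such a check, assuming "comes from timm" means that the model's class, or the class of any of its submodules, is defined inside a `timm` package. This is a sketch with a hypothetical name, not the OTX source; `FakeTimmBlock` only pretends to live in timm so the example runs without timm installed.

```python
from torch import nn


def model_from_timm_sketch(model: nn.Module) -> bool:
    """Return True if the model or any of its submodules is defined in a timm package."""
    # ``m.__module__`` is the dotted name of the Python module defining the layer's
    # class, e.g. "timm.models.resnet"; ``model.modules()`` includes the model itself.
    return any("timm" in m.__module__.split(".") for m in model.modules())


class FakeTimmBlock(nn.Module):
    """Stand-in layer that pretends to be defined inside timm."""


FakeTimmBlock.__module__ = "timm.models.fake"

print(model_from_timm_sketch(nn.Sequential(nn.Linear(4, 4), FakeTimmBlock())))  # True
print(model_from_timm_sketch(nn.Sequential(nn.Linear(4, 4))))  # False
```

Splitting on "." and matching the whole component avoids false positives from module paths that merely contain the substring "timm".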

otx.algorithms.common.adapters.torch.utils.sync_batchnorm_2_batchnorm(module, dim=2)

Convert the SyncBatchNorm layers in a model to regular BatchNorm layers.
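The page does not document `dim`. The sketch below assumes it selects `BatchNorm1d`, `BatchNorm2d` or `BatchNorm3d` as the replacement class, and that parameters and running statistics are carried over, mirroring what PyTorch's `SyncBatchNorm.convert_sync_batchnorm` does in the opposite direction. The function name `sync_bn_to_bn_sketch` is hypothetical; this is not the OTX source.

```python
import torch
from torch import nn

# Assumption: ``dim`` picks the spatial rank of the replacement BatchNorm class.
_BN_BY_DIM = {1: nn.BatchNorm1d, 2: nn.BatchNorm2d, 3: nn.BatchNorm3d}


def sync_bn_to_bn_sketch(module: nn.Module, dim: int = 2) -> nn.Module:
    """Recursively replace SyncBatchNorm layers with BatchNorm{dim}d layers."""
    if dim not in _BN_BY_DIM:
        raise ValueError(f"dim must be 1, 2 or 3, got {dim}")
    output = module
    if isinstance(module, nn.SyncBatchNorm):
        output = _BN_BY_DIM[dim](
            module.num_features,
            eps=module.eps,
            momentum=module.momentum,
            affine=module.affine,
            track_running_stats=module.track_running_stats,
        )
        if module.affine:
            with torch.no_grad():
                output.weight.copy_(module.weight)
                output.bias.copy_(module.bias)
        # Reuse the running statistics so evaluation-mode outputs are unchanged.
        output.running_mean = module.running_mean
        output.running_var = module.running_var
        output.num_batches_tracked = module.num_batches_tracked
    # Recurse into children and re-register whatever they were converted to.
    for name, child in module.named_children():
        output.add_module(name, sync_bn_to_bn_sketch(child, dim))
    return output


model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
plain_model = sync_bn_to_bn_sketch(sync_model, dim=2)
print(type(plain_model[1]).__name__)  # BatchNorm2d
```

A reverse conversion like this is typically useful when a model trained with distributed SyncBatchNorm has to run on a single device, for example for export or inference.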