# Source code for otx.algo.classification.backbones.torchvision

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Torchvision model's Backbone Class."""

from typing import Literal

import torch
from torch import nn
from torchvision.models import get_model, get_model_weights

# Names of the torchvision models that ``TorchvisionBackbone`` accepts as its
# ``backbone`` argument; the chosen name is handed to
# ``torchvision.models.get_model`` as the model name.
TVModelType = Literal[
    "alexnet",
    "convnext_base",
    "convnext_large",
    "convnext_small",
    "convnext_tiny",
    "efficientnet_b0",
    "efficientnet_b1",
    "efficientnet_b2",
    "efficientnet_b3",
    "efficientnet_b4",
    "efficientnet_b5",
    "efficientnet_b6",
    "efficientnet_b7",
    "efficientnet_v2_l",
    "efficientnet_v2_m",
    "efficientnet_v2_s",
    "googlenet",
    "mobilenet_v3_large",
    "mobilenet_v3_small",
    "regnet_x_16gf",
    "regnet_x_1_6gf",
    "regnet_x_32gf",
    "regnet_x_3_2gf",
    "regnet_x_400mf",
    "regnet_x_800mf",
    "regnet_x_8gf",
    "regnet_y_128gf",
    "regnet_y_16gf",
    "regnet_y_1_6gf",
    "regnet_y_32gf",
    "regnet_y_3_2gf",
    "regnet_y_400mf",
    "regnet_y_800mf",
    "regnet_y_8gf",
    "resnet101",
    "resnet152",
    "resnet18",
    "resnet34",
    "resnet50",
    "resnext101_32x8d",
    "resnext101_64x4d",
    "resnext50_32x4d",
    "swin_b",
    "swin_s",
    "swin_t",
    "swin_v2_b",
    "swin_v2_s",
    "swin_v2_t",
    "vgg11",
    "vgg11_bn",
    "vgg13",
    "vgg13_bn",
    "vgg16",
    "vgg16_bn",
    "vgg19",
    "vgg19_bn",
    "wide_resnet101_2",
    "wide_resnet50_2",
]


def get_in_features(sequential: nn.Module) -> int:
    """Return the input width of the first ``nn.Linear`` or ``nn.Conv2d`` found.

    The module itself is examined first, followed by its direct children in
    registration order. Nested containers are not searched recursively.

    Args:
        sequential: Module to inspect. Usually an ``nn.Sequential`` head, but a
            bare ``nn.Linear`` or ``nn.Conv2d`` is accepted as well.

    Returns:
        ``in_features`` if the first matching layer is an ``nn.Linear``, or
        ``in_channels`` if it is an ``nn.Conv2d``.

    Raises:
        ValueError: If neither the module nor any of its direct children is an
            ``nn.Linear`` or ``nn.Conv2d``.
    """
    # A bare layer has no children to iterate, so the module itself must be a
    # candidate too; otherwise a single-layer head would always be rejected.
    for layer in (sequential, *sequential.children()):
        if isinstance(layer, nn.Linear):
            return layer.in_features
        if isinstance(layer, nn.Conv2d):
            return layer.in_channels
    msg = "No suitable layer found to extract in_features"
    raise ValueError(msg)


class TorchvisionBackbone(nn.Module):
    """Backbone that wraps the ``features`` sub-module of a torchvision model.

    The torchvision model is built with ``torchvision.models.get_model``; only
    its ``features`` sub-module is kept for the forward pass.

    Args:
        backbone: Name of the torchvision model to build.
        pretrained: When True, the value returned by
            ``get_model_weights(backbone)`` is passed to ``get_model`` as
            ``weights``. When False (the default), no ``weights`` argument is
            passed at all.
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.

    Attributes:
        features: The ``features`` sub-module of the built torchvision model.
        in_features: Value returned by ``get_in_features`` for the last child
            module of the built torchvision model.

    Raises:
        AttributeError: If the built torchvision model has no ``features``
            sub-module.
    """

    def __init__(
        self,
        backbone: TVModelType,
        pretrained: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        tv_model_cfg = {"name": backbone}
        if pretrained:
            # NOTE(review): get_model_weights() returns the weights enum class
            # rather than one specific member - confirm that get_model()
            # resolves this to the intended checkpoint.
            tv_model_cfg["weights"] = get_model_weights(backbone)
        net = get_model(**tv_model_cfg)

        # Not every torchvision architecture exposes a `features` sub-module;
        # fail with an explicit message instead of a bare attribute lookup error.
        features = getattr(net, "features", None)
        if features is None:
            msg = (
                f"torchvision model '{backbone}' has no 'features' sub-module, "
                "so it cannot be used as a TorchvisionBackbone"
            )
            raise AttributeError(msg)
        self.features = features

        # The last registered child is inspected to derive the width recorded
        # in `in_features`.
        last_layer = list(net.children())[-1]
        self.in_features = get_in_features(last_layer)

    def forward(self, *args) -> torch.Tensor:
        """Pass the positional inputs through the wrapped ``features`` module."""
        return self.features(*args)