# Source code for otx.algo.action_classification.recognizers.recognizer

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OpenMMLab. All rights reserved.

"""Custom 3D recognizer for OTX."""
from __future__ import annotations

from typing import Any

import torch

from otx.algo.action_classification.utils.data_sample import ActionDataSample
from otx.algo.modules.base_module import BaseModule

class BaseRecognizer(BaseModule):
    """Custom 3D recognizer class for OTX.

    This is for patching the forward function during the export procedure.

    Args:
        backbone (torch.nn.Module): Feature extractor applied to the inputs after
            their first two dimensions have been merged.
        cls_head (torch.nn.Module): Classification head. Besides being callable it
            must provide ``loss``, ``predict`` and ``average_clip``.
        neck (torch.nn.Module | None): Optional module applied after the backbone.
            It is called as ``neck(x)`` / ``neck(x, data_samples=...)`` and must
            return a 2-tuple ``(features, aux)``. Defaults to None.
        test_cfg (dict | None): Optional test-time settings. The keys read here are
            ``fcn_test`` and ``max_testing_views``. Defaults to None.
    """

    def __init__(
        self,
        backbone: torch.nn.Module,
        cls_head: torch.nn.Module,
        neck: torch.nn.Module | None = None,
        test_cfg: dict | None = None,
    ) -> None:
        super().__init__()
        self.backbone = backbone
        self.cls_head = cls_head
        # The neck is only registered when given; ``with_neck`` therefore checks
        # both attribute presence and value.
        if neck is not None:
            self.neck = neck
        self.test_cfg = test_cfg

    @property
    def with_neck(self) -> bool:
        """bool: whether the recognizer has a neck."""
        return hasattr(self, "neck") and self.neck is not None

    @property
    def with_cls_head(self) -> bool:
        """bool: whether the recognizer has a cls_head."""
        return hasattr(self, "cls_head") and self.cls_head is not None

    def extract_feat(
        self,
        inputs: torch.Tensor,
        stage: str = "neck",
        data_samples: list[ActionDataSample] | None = None,
        test_mode: bool = False,
    ) -> tuple:
        """Extract features of different stages.

        Args:
            inputs (torch.Tensor): The input data. Its first two dimensions are
                merged (``[N, num_crops, ...] -> [N * num_crops, ...]``) before
                it is fed to the backbone.
            stage (str): Which stage to output the feature: ``'backbone'``,
                ``'neck'``, or any other value to also run ``cls_head``.
                Ignored when ``test_mode`` is True. Defaults to ``'neck'``.
            data_samples (list[ActionDataSample] | None): Action data samples,
                forwarded to the neck in non-test mode. Defaults to None.
            test_mode (bool): Whether in test mode. Defaults to False.

        Returns:
            tuple: ``(x, kwargs)``. ``x`` is the extracted feature (or the head
            output when ``stage`` is neither ``'backbone'`` nor ``'neck'``).
            ``kwargs`` is a dict recording the kwargs for the downstream
            ``loss``/``predict`` call. In non-test mode it holds ``loss_aux``
            unless ``stage == 'backbone'``. In test mode it holds ``fcn_test``
            when ``test_cfg`` is set.

        Raises:
            TypeError: If ``test_cfg['max_testing_views']`` is set but is not an int.
            ValueError: If ``max_testing_views`` is set but not positive, or if it
                is set and the batch holds more than one sample.
        """
        # Record the kwargs required by `loss` and `predict`.
        loss_predict_kwargs: dict[str, Any] = {}

        num_segs = inputs.shape[1]
        # [N, num_crops, C, T, H, W] -> [N * num_crops, C, T, H, W]
        # `num_crops` is calculated by:
        #   1) `twice_sample` in `SampleFrames`
        #   2) `num_sample_positions` in `DenseSampleFrames`
        #   3) `ThreeCrop/TenCrop` in `test_pipeline`
        #   4) `num_clips` in `SampleFrames` or its subclass if `clip_len != 1`
        inputs = inputs.view((-1,) + inputs.shape[2:])

        if test_mode:
            if self.test_cfg is not None:
                loss_predict_kwargs["fcn_test"] = self.test_cfg.get("fcn_test", False)

            if self.test_cfg is not None and self.test_cfg.get("max_testing_views", False):
                max_testing_views = self.test_cfg.get("max_testing_views")
                if not isinstance(max_testing_views, int):
                    msg = "max_testing_views should be 'int'"
                    raise TypeError(msg)
                # A negative step would never advance through the views.
                if max_testing_views <= 0:
                    msg = "max_testing_views should be a positive 'int'"
                    raise ValueError(msg)

                total_views = inputs.shape[0]
                # After flattening, total_views == N * num_crops, which equals
                # num_segs (== num_crops) only when N == 1.
                if num_segs != total_views:
                    msg = "max_testing_views is only compatible with batch_size == 1"
                    raise ValueError(msg)

                # Run backbone (+ neck) over chunks of at most `max_testing_views`
                # views (presumably to bound peak memory), then stitch the chunk
                # outputs back together along dim 0.
                feats = []
                for view_ptr in range(0, total_views, max_testing_views):
                    batch_imgs = inputs[view_ptr : view_ptr + max_testing_views]
                    feat = self.backbone(batch_imgs)
                    if self.with_neck:
                        feat, _ = self.neck(feat)
                    feats.append(feat)

                x = self._recursively_cat(feats) if isinstance(feats[0], tuple) else torch.cat(feats)
            else:
                x = self.backbone(inputs)
                if self.with_neck:
                    x, _ = self.neck(x)

            return x, loss_predict_kwargs

        # Return features extracted through backbone.
        x = self.backbone(inputs)
        if stage == "backbone":
            return x, loss_predict_kwargs

        loss_aux: dict = {}
        if self.with_neck:
            x, loss_aux = self.neck(x, data_samples=data_samples)

        # Return features extracted through neck.
        loss_predict_kwargs["loss_aux"] = loss_aux
        if stage == "neck":
            return x, loss_predict_kwargs

        # Return raw logits through head.
        x = self.cls_head(x, **loss_predict_kwargs)
        return x, loss_predict_kwargs

    @staticmethod
    def _recursively_cat(
        feats: list[Any] | tuple[Any, ...],
    ) -> tuple[Any, ...]:
        """Concatenate per-chunk nested outputs element-wise along dim 0.

        Args:
            feats: A non-empty sequence of per-chunk outputs that share the same
                nested tuple structure.

        Returns:
            tuple: Same structure as ``feats[0]``. Each tensor leaf is the
            ``torch.cat`` of the corresponding leaves across chunks.
        """
        out_feats: list[Any] = []
        for e_idx, elem in enumerate(feats[0]):
            batch_elem = [feat[e_idx] for feat in feats]
            if isinstance(elem, torch.Tensor):
                out_feats.append(torch.cat(batch_elem))
            else:
                # Not a tensor yet: keep traversing the nested structure.
                out_feats.append(BaseRecognizer._recursively_cat(batch_elem))
        return tuple(out_feats)

    def forward(
        self,
        inputs: torch.Tensor,
        data_samples: list[ActionDataSample] | None = None,
        mode: str = "tensor",
        **kwargs,
    ) -> dict[str, torch.Tensor] | list[ActionDataSample] | tuple[torch.Tensor] | torch.Tensor:
        """The unified entry for a forward process in both training and test.

        The method should accept three modes:

        - ``tensor``: Forward the whole network and return tensor or tuple of
          tensor without any post-processing, same as a common nn.Module.
        - ``predict``: Forward and return the predictions, which are fully
          processed to a list of :obj:`ActionDataSample`.
        - ``loss``: Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor optimizer
        updating, which are done in the :meth:`train_step`.

        Args:
            inputs (torch.Tensor): The input tensor with shape (N, C, ...) in
                general.
            data_samples (list[``ActionDataSample``], optional): The annotation
                data of every sample. Defaults to None.
            mode (str): Return what kind of value. Defaults to ``tensor``.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of ``ActionDataSample``.
            - If ``mode="loss"``, return a dict of tensor.

        Raises:
            RuntimeError: If ``mode`` is not one of the three supported values.
        """
        if mode == "predict":
            return self.predict(inputs, data_samples, **kwargs)
        if mode == "loss":
            return self.loss(inputs, data_samples, **kwargs)
        if mode == "tensor":
            return self._forward(inputs, **kwargs)
        msg = f"Invalid mode '{mode}'. Only supports loss, predict and tensor mode"
        raise RuntimeError(msg)

    def loss(self, inputs: torch.Tensor, data_samples: list[ActionDataSample] | None, **kwargs) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (torch.Tensor): Raw inputs of the recognizer. These should
                usually be mean centered and std scaled.
            data_samples (list[``ActionDataSample``]): The batch data samples.
                It usually includes information such as ``gt_label``.
            **kwargs: Unused.

        Returns:
            dict: The head's loss dict merged with the neck's auxiliary losses.
        """
        feats, loss_kwargs = self.extract_feat(inputs, data_samples=data_samples)

        # loss_aux will be an empty dict if `self.with_neck` is False.
        loss_aux = loss_kwargs.get("loss_aux", {})
        loss_cls = self.cls_head.loss(feats, data_samples, **loss_kwargs)
        return self._merge_dict(loss_cls, loss_aux)

    def predict(
        self,
        inputs: torch.Tensor,
        data_samples: list[ActionDataSample] | None,
        **kwargs,
    ) -> list[ActionDataSample]:
        """Predict results from a batch of inputs and data samples with postprocessing.

        Args:
            inputs (torch.Tensor): Raw inputs of the recognizer. These should
                usually be mean centered and std scaled.
            data_samples (list[``ActionDataSample``]): The batch data samples.
                It usually includes information such as ``gt_label``.
            **kwargs: Unused.

        Returns:
            list[``ActionDataSample``]: Whatever ``cls_head.predict`` returns for
            the test-mode features, typically samples carrying ``pred_scores``.
        """
        feats, predict_kwargs = self.extract_feat(inputs, test_mode=True)
        return self.cls_head.predict(feats, data_samples, **predict_kwargs)

    def _forward(self, inputs: torch.Tensor, stage: str = "backbone", **kwargs) -> torch.Tensor:
        """Network forward process for the export procedure.

        Runs test-mode feature extraction, then the classification head, then
        ``cls_head.average_clip``.

        Args:
            inputs (torch.Tensor): Raw inputs of the recognizer. The first two
                dimensions are merged inside ``extract_feat``.
            stage (str): Kept for interface compatibility only. Test-mode
                ``extract_feat`` ignores it and always returns backbone (+ neck)
                features.
            **kwargs: Unused.

        Returns:
            torch.Tensor: The output of ``cls_head.average_clip``.
        """
        feats, predict_kwargs = self.extract_feat(inputs, test_mode=True)
        cls_scores = self.cls_head(feats, **predict_kwargs)
        # NOTE(review): inputs.shape[1] is num_crops. Assuming the head emits one
        # row per flattened view, this yields the batch size N rather than the
        # crops-per-sample count. The two coincide only in special cases (e.g.
        # N == 1 gives 1). Confirm this matches `average_clip`'s expectations.
        num_segs = cls_scores.shape[0] // inputs.shape[1]
        return self.cls_head.average_clip(cls_scores, num_segs=num_segs)

    @staticmethod
    def _merge_dict(*args) -> dict:
        """Merge all dictionaries into one dictionary.

        Later dictionaries override earlier ones on key collisions (plain
        ``dict.update`` semantics).

        Args:
            *args: Dictionaries to be merged.

        Returns:
            dict: A new dict merged from ``args``.

        Raises:
            TypeError: If any argument is not a ``dict``.
        """
        output = {}
        for item in args:
            if not isinstance(item, dict):
                msg = f"all arguments of merge_dict should be a dict, but got {type(item)}"
                raise TypeError(msg)
            output.update(item)
        return output