# Source code for otx.algo.utils.support_otx_v1
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Utility functions to guarantee the OTX1.x models."""
from __future__ import annotations
[docs]
class OTXv1Helper:
    """Helper class to support the backward compatibility of OTX v1.

    Every public loader takes a checkpoint dict that holds the model weights
    under ``checkpoint["model"]["state_dict"]``, rewrites that inner mapping
    in place (renaming, dropping or transforming entries) and returns it.
    ``add_prefix`` is prepended to every key that is kept.
    """

    @staticmethod
    def _replace_contents(state_dict: dict, converted: dict) -> dict:
        """Swap the contents of ``state_dict`` for ``converted`` and return it.

        The converted mapping is built separately and only then written back,
        so a renamed key can never collide with an original key that has not
        been processed yet. The original dict object is reused so that callers
        holding a reference to it still observe the converted keys.
        """
        state_dict.clear()
        state_dict.update(converted)
        return state_dict

    @staticmethod
    def load_common_ckpt(state_dict: dict, add_prefix: str = "") -> dict:
        """Load the OTX1.x model checkpoints that don't need special handling.

        Args:
            state_dict: Checkpoint containing ``["model"]["state_dict"]``.
            add_prefix: Prefix prepended to every key.

        Returns:
            The inner state dict, with every key prefixed by ``add_prefix``.

        Raises:
            KeyError: If ``"model"`` or ``"state_dict"`` is missing.
        """
        weights = state_dict["model"]["state_dict"]
        converted = {add_prefix + key: val for key, val in weights.items()}
        return OTXv1Helper._replace_contents(weights, converted)

    @staticmethod
    def load_cls_effnet_b0_ckpt(state_dict: dict, label_type: str, add_prefix: str = "") -> dict:
        """Load the OTX1.x efficientnet b0 classification checkpoints.

        Renaming rules (applied before ``add_prefix`` is prepended):

        * ``features.*`` keys get a ``backbone.`` prefix unless the key
          contains ``"activ"``, in which case the name is kept.
        * ``output.*`` keys have ``"output"`` replaced by ``"head"``; unless
          ``label_type`` is ``"hlabel"``, ``"asl"`` is also replaced by
          ``"fc"`` and the value's ``.t()`` result is stored instead
          (presumably a tensor transpose -- values are assumed to be torch
          tensors).
        * Any other key keeps its name.

        Args:
            state_dict: Checkpoint containing ``["model"]["state_dict"]``.
            label_type: Label type of the task; only ``"hlabel"`` is treated
                specially.
            add_prefix: Prefix prepended to every converted key.

        Returns:
            The inner state dict with converted keys and values.

        Raises:
            KeyError: If ``"model"`` or ``"state_dict"`` is missing.
        """
        weights = state_dict["model"]["state_dict"]
        converted = {}
        for key, val in weights.items():
            # Default for keys that match none of the rules below; without it
            # such a key would be filed under the previous key's new name.
            new_key = key
            new_val = val
            if key.startswith("features."):
                if "activ" not in key:
                    new_key = "backbone." + key
            elif key.startswith("output."):
                new_key = key.replace("output", "head")
                if label_type != "hlabel":
                    new_key = new_key.replace("asl", "fc")
                    new_val = val.t()
            converted[add_prefix + new_key] = new_val
        return OTXv1Helper._replace_contents(weights, converted)

    @staticmethod
    def load_cls_effnet_v2_ckpt(state_dict: dict, label_type: str, add_prefix: str = "") -> dict:
        """Load the OTX1.x efficientnet v2 classification checkpoints.

        Renaming rules (applied before ``add_prefix`` is prepended):

        * ``model.classifier.*`` keys have ``"model.classifier"`` replaced by
          ``"head.fc"``; unless ``label_type`` is ``"hlabel"`` the value's
          ``.t()`` result is stored instead.
        * Other keys starting with ``"model"`` get a ``backbone.`` prefix.
        * Any other key keeps its name.

        Args:
            state_dict: Checkpoint containing ``["model"]["state_dict"]``.
            label_type: Label type of the task; only ``"hlabel"`` is treated
                specially.
            add_prefix: Prefix prepended to every converted key.

        Returns:
            The inner state dict with converted keys and values.

        Raises:
            KeyError: If ``"model"`` or ``"state_dict"`` is missing.
        """
        weights = state_dict["model"]["state_dict"]
        converted = {}
        for key, val in weights.items():
            # Default for keys that match none of the rules below.
            new_key = key
            new_val = val
            if key.startswith("model.classifier."):
                new_key = key.replace("model.classifier", "head.fc")
                if label_type != "hlabel":
                    new_val = val.t()
            elif key.startswith("model"):
                new_key = "backbone." + key
            converted[add_prefix + new_key] = new_val
        return OTXv1Helper._replace_contents(weights, converted)

    @staticmethod
    def load_cls_mobilenet_v3_ckpt(state_dict: dict, label_type: str, add_prefix: str = "") -> dict:
        """Load the OTX1.x mobilenet v3 classification checkpoints.

        Renaming rules (applied before ``add_prefix`` is prepended):

        * ``classifier.*`` keys get a ``head.`` prefix. If the key contains
          ``"4"``, every ``"4"`` in it is replaced by ``"3"`` and, when
          ``label_type`` is ``"multilabel"``, the value's ``.t()`` result is
          stored instead.
        * Keys starting with ``"act"`` get a ``head.`` prefix.
        * Keys already starting with ``backbone.`` keep their name.
        * Every remaining key gets a ``backbone.`` prefix.

        Args:
            state_dict: Checkpoint containing ``["model"]["state_dict"]``.
            label_type: Label type of the task; only ``"multilabel"`` is
                treated specially.
            add_prefix: Prefix prepended to every converted key.

        Returns:
            The inner state dict with converted keys and values.

        Raises:
            KeyError: If ``"model"`` or ``"state_dict"`` is missing.
        """
        weights = state_dict["model"]["state_dict"]
        converted = {}
        for key, val in weights.items():
            # Default covers keys that already carry the "backbone." prefix.
            new_key = key
            new_val = val
            if key.startswith("classifier."):
                if "4" in key:
                    new_key = "head." + key.replace("4", "3")
                    if label_type == "multilabel":
                        new_val = val.t()
                else:
                    new_key = "head." + key
            elif key.startswith("act"):
                new_key = "head." + key
            elif not key.startswith("backbone."):
                new_key = "backbone." + key
            converted[add_prefix + new_key] = new_val
        return OTXv1Helper._replace_contents(weights, converted)

    @staticmethod
    def load_cls_deit_ckpt(state_dict: dict, add_prefix: str = "") -> dict:
        """Load the OTX1.x deit-tiny classification checkpoints.

        Delegates to :meth:`load_common_ckpt`; no key needs special handling.
        """
        return OTXv1Helper.load_common_ckpt(state_dict, add_prefix)

    @staticmethod
    def load_det_ckpt(state_dict: dict, add_prefix: str = "") -> dict:
        """Load the OTX1.x detection model checkpoints.

        Args:
            state_dict: Checkpoint containing ``["model"]["state_dict"]``.
            add_prefix: Prefix prepended to every kept key.

        Returns:
            The inner state dict without the keys that start with ``"ema_"``;
            all remaining keys are prefixed by ``add_prefix``.

        Raises:
            KeyError: If ``"model"`` or ``"state_dict"`` is missing.
        """
        weights = state_dict["model"]["state_dict"]
        converted = {add_prefix + key: val for key, val in weights.items() if not key.startswith("ema_")}
        return OTXv1Helper._replace_contents(weights, converted)

    @staticmethod
    def load_ssd_ckpt(state_dict: dict, add_prefix: str = "") -> dict:
        """Load OTX1.x SSD model checkpoints.

        The top-level ``"anchors"`` entry is removed from ``state_dict`` and
        stored in the inner state dict under ``"anchors"`` (``None`` when the
        checkpoint has no such entry); the result is then converted by
        :meth:`load_det_ckpt`, so the entry ends up under
        ``add_prefix + "anchors"``.
        """
        state_dict["model"]["state_dict"]["anchors"] = state_dict.pop("anchors", None)
        return OTXv1Helper.load_det_ckpt(state_dict, add_prefix)

    @staticmethod
    def load_iseg_ckpt(state_dict: dict, add_prefix: str = "") -> dict:
        """Load the instance segmentation model checkpoints.

        Delegates to :meth:`load_det_ckpt`.
        """
        return OTXv1Helper.load_det_ckpt(state_dict, add_prefix)

    @staticmethod
    def load_seg_segnext_ckpt(state_dict: dict, add_prefix: str = "") -> dict:
        """Load the OTX1.x segnext segmentation checkpoints.

        Args:
            state_dict: Checkpoint containing ``["model"]["state_dict"]``.
            add_prefix: Prefix prepended to every kept key.

        Returns:
            The inner state dict without the keys that contain
            ``"ham.bases"``; all remaining keys are prefixed by
            ``add_prefix``.

        Raises:
            KeyError: If ``"model"`` or ``"state_dict"`` is missing.
        """
        weights = state_dict["model"]["state_dict"]
        converted = {add_prefix + key: val for key, val in weights.items() if "ham.bases" not in key}
        return OTXv1Helper._replace_contents(weights, converted)

    @staticmethod
    def load_seg_lite_hrnet_ckpt(state_dict: dict, add_prefix: str = "") -> dict:
        """Load the OTX1.x lite hrnet segmentation checkpoints.

        Delegates to :meth:`load_common_ckpt`; no key needs special handling.
        """
        return OTXv1Helper.load_common_ckpt(state_dict, add_prefix)

    @staticmethod
    def load_action_ckpt(state_dict: dict, add_prefix: str = "") -> dict:
        """Load the OTX1.x action cls/det model checkpoints.

        Delegates to :meth:`load_common_ckpt`.
        """
        return OTXv1Helper.load_common_ckpt(state_dict, add_prefix)