Source code for otx.algo.classification.backbones.vision_transformer
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OpenMMLab. All rights reserved.
"""Copy from mmpretrain/models/backbones/vision_transformer.py."""
from __future__ import annotations
import math
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Literal
import torch
from timm.layers import (
LayerType,
Mlp,
PatchDropout,
PatchEmbed,
SwiGLUPacked,
get_act_layer,
get_norm_layer,
resample_abs_pos_embed,
resample_patch_embed,
trunc_normal_,
)
from timm.models._manipulate import adapt_input_conv
from timm.models.vision_transformer import Attention, Block
from torch import nn
from otx.algo.modules.base_module import BaseModule
if TYPE_CHECKING:
from pathlib import Path
import numpy as np
VIT_ARCH_TYPE = Literal[
"vit-t",
"vit-tiny",
"vit-s",
"vit-small",
"vit-b",
"vit-base",
"vit-l",
"vit-large",
"vit-h",
"vit-huge",
"dinov2-s",
"dinov2-small",
"dinov2-small-seg",
"dinov2-b",
"dinov2-base",
"dinov2-l",
"dinov2-large",
"dinov2-g",
"dinov2-giant",
]
class VisionTransformer(BaseModule):
"""Implementation of Vision Transformer from Timm.
    A PyTorch implementation of: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
- https://arxiv.org/abs/2010.11929
- https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
Args:
arch: Vision Transformer architecture.
img_size: Input image size.
patch_size: Patch size.
in_chans: Number of image input channels.
        num_classes: Number of classes for classification head.
embed_dim: Transformer embedding dimension.
depth: Depth of transformer.
num_heads: Number of attention heads.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
        qkv_bias: Enable bias for qkv projections if True.
        qk_norm: Apply normalization to the query and key projections if True.
init_values: Layer-scale init values (layer-scale enabled if not None).
class_token: Use class token.
no_embed_class: Don't include position embeddings for class (or reg) tokens.
reg_tokens: Number of register tokens.
        pre_norm: Apply a normalization layer before the transformer blocks (as in CLIP-style ViTs).
        dynamic_img_size: Support variable input sizes by resampling position embeddings at forward time.
        dynamic_img_pad: Pad inputs so that their spatial size is divisible by the patch size.
        pos_drop_rate: Position embedding dropout rate.
        patch_drop_rate: Patch token dropout rate (applied after patch embedding).
        proj_drop_rate: Dropout rate for the projection layers inside each block.
        attn_drop_rate: Attention dropout rate.
        drop_path_rate: Stochastic depth rate.
embed_layer: Patch embedding layer.
norm_layer: Normalization layer.
act_layer: MLP activation layer.
        block_fn: Transformer block layer.
        mlp_layer: MLP block implementation used inside each transformer block.
        interpolate_offset: Work-around offset applied when interpolating positional embeddings.
lora: Enable LoRA training.
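
    Example:
        A minimal usage sketch (illustrative; assumes the default timm
        ``PatchEmbed``/``Block`` layers and randomly initialized weights):

        >>> import torch
        >>> backbone = VisionTransformer(arch="vit-base", img_size=224)
        >>> images = torch.randn(2, 3, 224, 224)
        >>> (cls_feat,) = backbone(images)  # default out_type="cls_token"
        >>> cls_feat.shape
        torch.Size([2, 768])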
"""
arch_zoo: dict[str, dict] = { # noqa: RUF012
**dict.fromkeys(
["vit-t", "vit-tiny"],
{
"patch_size": 16,
"embed_dim": 192,
"depth": 12,
"num_heads": 3,
},
),
**dict.fromkeys(
["vit-s", "vit-small"],
{
"patch_size": 16,
"embed_dim": 384,
"depth": 12,
"num_heads": 6,
},
),
**dict.fromkeys(
["vit-b", "vit-base"],
{
"patch_size": 16,
"embed_dim": 768,
"depth": 12,
"num_heads": 12,
},
),
**dict.fromkeys(
["vit-l", "vit-large"],
{
"patch_size": 16,
"embed_dim": 1024,
"depth": 24,
"num_heads": 16,
},
),
**dict.fromkeys(
["vit-h", "vit-huge"],
{
# The same as the implementation in MAE
# <https://arxiv.org/abs/2111.06377>
"patch_size": 16,
"embed_dim": 1280,
"depth": 32,
"num_heads": 16,
},
),
**dict.fromkeys(
["dinov2-s", "dinov2-small"],
{
"patch_size": 14,
"embed_dim": 384,
"depth": 12,
"num_heads": 6,
"reg_tokens": 4,
"no_embed_class": True,
},
),
**dict.fromkeys(
["dinov2-small-seg"], # segmentation
{
"patch_size": 14,
"embed_dim": 384,
"depth": 12,
"num_heads": 6,
"reg_tokens": 0,
"no_embed_class": False,
"init_values": 1e-5,
},
),
**dict.fromkeys(
["dinov2-b", "dinov2-base"],
{
"patch_size": 14,
"embed_dim": 768,
"depth": 12,
"num_heads": 12,
"reg_tokens": 4,
"no_embed_class": True,
"init_values": 1e-5,
},
),
**dict.fromkeys(
["dinov2-l", "dinov2-large"],
{
"patch_size": 14,
"embed_dim": 1024,
"depth": 24,
"num_heads": 16,
"reg_tokens": 4,
"no_embed_class": True,
"init_values": 1e-5,
},
),
**dict.fromkeys(
["dinov2-g", "dinov2-giant"],
{
"patch_size": 14,
"embed_dim": 1536,
"depth": 40,
"num_heads": 24,
"reg_tokens": 4,
"no_embed_class": True,
"init_values": 1e-5,
"mlp_ratio": 2.66667 * 2,
"mlp_layer": SwiGLUPacked,
"act_layer": nn.SiLU,
},
),
}
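
    # Note: explicit constructor arguments take precedence over these presets.
    # For example (illustrative), VisionTransformer(arch="vit-small", patch_size=8)
    # keeps vit-small's embed_dim/depth/num_heads but overrides the preset 16x16
    # patch size, because __init__ resolves each value as
    # `explicit argument or arch_settings.get(...)`.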
def __init__( # noqa: PLR0913
self,
arch: VIT_ARCH_TYPE | str = "vit-base",
img_size: int | tuple[int, int] = 224,
patch_size: int | None = None,
in_chans: int = 3,
num_classes: int = 1000,
embed_dim: int | None = None,
depth: int | None = None,
num_heads: int | None = None,
mlp_ratio: float | None = None,
qkv_bias: bool = True,
qk_norm: bool = False,
init_values: float | None = None,
class_token: bool = True,
no_embed_class: bool | None = None,
reg_tokens: int | None = None,
pre_norm: bool = False,
dynamic_img_size: bool = False,
dynamic_img_pad: bool = False,
pos_drop_rate: float = 0.0,
patch_drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
embed_layer: Callable = PatchEmbed,
block_fn: nn.Module = Block,
mlp_layer: nn.Module | None = None,
act_layer: LayerType | None = None,
norm_layer: LayerType | None = None,
interpolate_offset: float = 0.1,
lora: bool = False,
) -> None:
super().__init__()
if isinstance(arch, str):
if arch not in set(self.arch_zoo):
msg = f"Arch {arch} is not in default archs {set(self.arch_zoo)}"
raise ValueError(msg)
arch_settings: dict[str, Any] = self.arch_zoo[arch]
self.img_size: int | tuple[int, int] = img_size
self.patch_size: int = patch_size or arch_settings.get("patch_size", 16)
self.embed_dim = embed_dim or arch_settings.get("embed_dim", 768)
depth = depth or arch_settings.get("depth", 12)
num_heads = num_heads or arch_settings.get("num_heads", 12)
no_embed_class = no_embed_class or arch_settings.get("no_embed_class", False)
reg_tokens = reg_tokens or arch_settings.get("reg_tokens", 0)
init_values = init_values or arch_settings.get("init_values", None)
mlp_layer = mlp_layer or arch_settings.get("mlp_layer", Mlp)
mlp_ratio = mlp_ratio or arch_settings.get("mlp_ratio", 4.0)
norm_layer = get_norm_layer(norm_layer) or arch_settings.get("norm_layer", partial(nn.LayerNorm, eps=1e-6))
act_layer = get_act_layer(act_layer) or arch_settings.get("act_layer", nn.GELU)
self.num_classes = num_classes
self.num_prefix_tokens = 1 if class_token else 0
self.num_prefix_tokens += reg_tokens
self.num_reg_tokens = reg_tokens
self.has_class_token = class_token
self.no_embed_class = no_embed_class # don't embed prefix positions (includes reg)
self.dynamic_img_size = dynamic_img_size
self.grad_checkpointing = False
self.interpolate_offset = interpolate_offset
embed_args = {}
if dynamic_img_size:
# flatten deferred until after pos embed
embed_args.update({"strict_img_size": False, "output_fmt": "NHWC"})
self.patch_embed = embed_layer(
img_size=self.img_size,
patch_size=self.patch_size,
in_chans=in_chans,
embed_dim=self.embed_dim,
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
dynamic_img_pad=dynamic_img_pad,
**embed_args,
)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) if class_token else None
self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, self.embed_dim)) if reg_tokens else None
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
self.pos_embed = nn.Parameter(torch.zeros(1, embed_len, self.embed_dim))
self.pos_drop = nn.Dropout(p=pos_drop_rate)
if patch_drop_rate > 0:
self.patch_drop = PatchDropout(
patch_drop_rate,
num_prefix_tokens=self.num_prefix_tokens,
)
else:
self.patch_drop = nn.Identity()
self.norm_pre = norm_layer(self.embed_dim) if pre_norm else nn.Identity()
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.Sequential(
*[
block_fn(
dim=self.embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
init_values=init_values,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
mlp_layer=mlp_layer,
)
for i in range(depth)
],
)
self.norm = norm_layer(self.embed_dim)
self.lora = lora
if self.lora:
lora_rank = 8
lora_alpha = 1.0
assign_lora = partial(AttentionWithLoRA, rank=lora_rank, alpha=lora_alpha)
for block in self.blocks:
block.attn.qkv = assign_lora(block.attn.qkv)
# Freeze all params
for param in self.parameters():
param.requires_grad = False
# Unfreeze LoRA layers
for block in self.blocks:
for param in block.attn.qkv.lora_q.parameters():
param.requires_grad = True
for param in block.attn.qkv.lora_v.parameters():
param.requires_grad = True
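
        # Trainability sketch (illustrative): with lora=True every base weight is
        # frozen above and only the low-rank adapters stay trainable, e.g.
        #
        #     vit = VisionTransformer(arch="vit-base", lora=True)
        #     trainable = {n for n, p in vit.named_parameters() if p.requires_grad}
        #     # -> only "blocks.<i>.attn.qkv.lora_q.*" and "...lora_v.*" parameters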
def init_weights(self) -> None:
"""Initializes the weights of the VisionTransformer."""
super().init_weights()
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=0.02)
@torch.jit.ignore()
def load_pretrained(self, checkpoint_path: Path, prefix: str = "") -> None:
"""Loads the pretrained weight to the VisionTransformer."""
checkpoint_ext = checkpoint_path.suffix
if checkpoint_ext == ".npz": # deit models
self._load_npz_weights(self, checkpoint_path, prefix)
elif checkpoint_ext == ".pth": # dinov2 models
def resize_positional_embeddings(pos_embed: torch.Tensor, new_shape: tuple[int, int]) -> torch.Tensor:
# Resize the embeddings using bilinear interpolation.
                pos_embed = pos_embed.permute(0, 2, 1).reshape(1, -1, 37, 37)  # 37 = 518 (pretraining img_size) // 14 (patch_size)
pos_embed_resized = nn.functional.interpolate(
pos_embed,
size=(new_shape[0], new_shape[1]),
mode="bilinear",
)
return pos_embed_resized.reshape(1, -1, new_shape[0] * new_shape[1]).permute(0, 2, 1)
# convert dinov2 pretrained weights
state_dict = torch.load(checkpoint_path)
state_dict.pop("mask_token", None)
if "reg_token" in state_dict:
state_dict["reg_token"] = state_dict.pop("register_tokens")
state_dict["cls_token"] = state_dict.pop("cls_token") + state_dict["pos_embed"][:, 0]
img_size = (self.img_size, self.img_size) if isinstance(self.img_size, int) else self.img_size
patch_size = (self.patch_size, self.patch_size)
if state_dict["pos_embed"].shape != self.pos_embed.shape:
state_dict["pos_embed"] = resize_positional_embeddings(
state_dict.pop("pos_embed")[:, 1:],
(img_size[0] // patch_size[0], img_size[1] // patch_size[1]),
)
self.load_state_dict(state_dict, strict=False)
else:
msg = f"Unsupported `checkpoint_extension` {checkpoint_ext}, please choose from 'npz' or 'pth'."
raise ValueError(msg)
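
    # Usage sketch (the checkpoint filename below is hypothetical):
    #
    #     backbone = VisionTransformer(arch="dinov2-small-seg", img_size=518)
    #     backbone.load_pretrained(Path("dinov2_vits14_pretrain.pth"))
    #
    # ".npz" checkpoints are routed to _load_npz_weights, while ".pth" checkpoints
    # go through the dinov2 key-conversion branch above.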
def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
"""Implements positional embedding."""
if self.dynamic_img_size:
b, h, w, c = x.shape
pos_embed = resample_abs_pos_embed(
self.pos_embed,
(h, w),
num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
)
x = x.view(b, -1, c)
else:
pos_embed = self.pos_embed
to_cat = []
if self.cls_token is not None:
to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
if self.reg_token is not None:
to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
if self.no_embed_class:
# deit-3, updated JAX (big vision)
# position embedding does not overlap with class token, add then concat
x = x + pos_embed
if to_cat:
x = torch.cat(to_cat + [x], dim=1) # noqa: RUF005
else:
# original timm, JAX, and deit vit impl
# pos_embed has entry for class token, concat then add
if to_cat:
x = torch.cat(to_cat + [x], dim=1) # noqa: RUF005
x = x + pos_embed
return self.pos_drop(x)
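
    # Token-layout note: with class_token=True the prefix produced above is ordered
    # [cls | reg_0 .. reg_{R-1} | patch tokens], i.e. num_prefix_tokens = 1 + reg_tokens
    # leading entries followed by the patch sequence.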
def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor:
"""Interpolates the positional encoding to match the input dimensions.
Args:
x (torch.Tensor): Input tensor.
w (int): Width of the input image.
h (int): Height of the input image.
Returns:
torch.Tensor: Tensor with interpolated positional encoding.
"""
previous_dtype = x.dtype
npatch = x.shape[1]
n = self.pos_embed.shape[1]
if npatch == n and w == h:
return self.pos_embed
pos_embed = self.pos_embed.float()
class_pos_embed = pos_embed[:, 0]
patch_pos_embed = pos_embed[:, 1:]
dim = x.shape[-1]
w0 = w // self.patch_size
h0 = h // self.patch_size
m = int(math.sqrt(n)) # Recover the number of patches in each dimension
if m * m != n:
msg = f"Expected m * m to equal n, but got m={m}, n={n}"
raise ValueError(msg)
kwargs = {}
if self.interpolate_offset:
# fix float error by introducing small offset
sx = float(w0 + self.interpolate_offset) / m
sy = float(h0 + self.interpolate_offset) / m
kwargs["scale_factor"] = (sx, sy)
else:
# Simply specify an output size instead of a scale factor
kwargs["size"] = (w0, h0)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, m, m, dim).permute(0, 3, 1, 2),
mode="bicubic",
**kwargs,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
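
    # Numeric sketch (assuming self.pos_embed holds an m x m patch grid, e.g.
    # m = 37 for a grid pretrained at 518x518 with patch_size=14): for a 224x224
    # input, w0 = h0 = 224 // 14 = 16, so the m x m grid is bicubically resized to
    # 16x16 (via a slightly offset scale factor when interpolate_offset > 0) before
    # being re-attached to the class position embedding.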
def prepare_tokens_with_masks(self, x: torch.Tensor, masks: torch.Tensor | None = None) -> torch.Tensor:
"""Prepare tokens with optional masks.
Args:
x (torch.Tensor): Input tensor.
masks (torch.Tensor | None): Optional masks tensor.
Returns:
torch.Tensor: Tensor with prepared tokens.
"""
_, _, w, h = x.shape
x = self.patch_embed(x)
if masks is not None:
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + self.interpolate_pos_encoding(x, w, h)
if self.reg_token is not None:
x = torch.cat(
(
x[:, :1],
self.reg_token.expand(x.shape[0], -1, -1),
x[:, 1:],
),
dim=1,
)
return x
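
    # Shape sketch (illustrative): for the "dinov2-small-seg" preset at a 518x518
    # input, patch_embed yields (518 // 14) ** 2 = 1369 patch tokens; prepending the
    # cls token gives a 1370-token sequence (this preset defines no register tokens,
    # so nothing is spliced in after the cls token).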
def _get_intermediate_layers_not_chunked(self, x: torch.Tensor, n: int = 1) -> list[torch.Tensor]:
"""Get intermediate layers without chunking.
Args:
x (torch.Tensor): Input tensor.
            n (int | list[int]): Number of last blocks to take, or the indices of the blocks to take.
Returns:
list[torch.Tensor]: List of intermediate layer outputs.
"""
x = self.prepare_tokens_with_masks(x)
# If n is an int, take the n last blocks. If it's a list, take them
output, total_block_len = [], len(self.blocks)
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in blocks_to_take:
output.append(x)
if len(output) != len(blocks_to_take):
msg = f"only {len(output)} / {len(blocks_to_take)} blocks found"
raise RuntimeError(msg)
return output
def get_intermediate_layers(
self,
x: torch.Tensor,
        n: int = 1,  # number of last blocks to take, or explicit block indices
reshape: bool = False,
return_class_token: bool = False,
norm: bool = True,
) -> tuple:
"""Get intermediate layers of the VisionTransformer.
Args:
x (torch.Tensor): Input tensor.
            n (int | list[int]): Number of last blocks to take, or the indices of the blocks to take.
reshape (bool): Whether to reshape the output feature maps.
return_class_token (bool): Whether to return the class token.
norm (bool): Whether to apply normalization to the outputs.
Returns:
tuple: A tuple containing the intermediate layer outputs.
"""
outputs = self._get_intermediate_layers_not_chunked(x, n)
if norm:
outputs = [self.norm(out) for out in outputs]
class_tokens = [out[:, 0] for out in outputs]
outputs = [out[:, 1 + self.num_reg_tokens :] for out in outputs]
if reshape:
b, _, w, h = x.shape
outputs = [
out.reshape(b, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
for out in outputs
]
if return_class_token:
return tuple(zip(outputs, class_tokens))
return tuple(outputs)
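
    # Usage sketch (illustrative, e.g. a "dinov2-small-seg" backbone fed at its
    # native input resolution):
    #
    #     feats = backbone.get_intermediate_layers(images, n=4, reshape=True)
    #     # feats is a 4-tuple of (B, embed_dim, H // patch_size, W // patch_size)
    #     # feature maps taken from the last four transformer blocks.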
def forward(
self,
x: torch.Tensor,
out_type: Literal["raw", "cls_token", "featmap", "avg_featmap"] = "cls_token",
) -> tuple:
"""Forward pass of the VisionTransformer model."""
x = self.patch_embed(x)
x = self._pos_embed(x)
x = self.patch_drop(x)
x = self.norm_pre(x)
x = self.blocks(x)
x = self.norm(x)
if out_type == "raw":
return (x,)
if out_type == "cls_token":
return (x[:, 0],)
        msg = f"Unsupported `out_type` {out_type}, please choose from 'raw' or 'cls_token'."
raise ValueError(msg)
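
    # Output sketch (illustrative, "vit-base" with 224x224 inputs):
    #
    #     (tokens,) = backbone(images, out_type="raw")  # (B, 197, 768) token sequence
    #     (cls_feat,) = backbone(images)                # (B, 768) class-token features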
@torch.no_grad()
def _load_npz_weights( # noqa: C901
self,
model: VisionTransformer,
checkpoint_path: str,
prefix: str = "",
) -> None:
"""Load weights from .npz checkpoints for official Google Brain Flax implementation."""
import numpy as np
def _n2p(w: np.ndarray, t: bool = True, idx: int | None = None) -> torch.Tensor:
if idx is not None:
w = w[idx]
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
w = w.flatten()
if t:
if w.ndim == 4:
w = w.transpose([3, 2, 0, 1])
elif w.ndim == 3:
w = w.transpose([2, 0, 1])
elif w.ndim == 2:
w = w.transpose([1, 0])
return torch.from_numpy(w)
w = np.load(checkpoint_path)
interpolation = "bilinear"
antialias = False
big_vision = False
if not prefix:
if "opt/target/embedding/kernel" in w:
prefix = "opt/target/"
elif "params/embedding/kernel" in w:
prefix = "params/"
big_vision = True
elif "params/img/embedding/kernel" in w:
prefix = "params/img/"
big_vision = True
if hasattr(model.patch_embed, "backbone"):
# hybrid
backbone = model.patch_embed.backbone
stem_only = not hasattr(backbone, "stem")
stem = backbone if stem_only else backbone.stem
stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f"{prefix}conv_root/kernel"])))
stem.norm.weight.copy_(_n2p(w[f"{prefix}gn_root/scale"]))
stem.norm.bias.copy_(_n2p(w[f"{prefix}gn_root/bias"]))
if not stem_only:
for i, stage in enumerate(backbone.stages):
for j, block in enumerate(stage.blocks):
bp = f"{prefix}block{i + 1}/unit{j + 1}/"
for r in range(3):
getattr(block, f"conv{r + 1}").weight.copy_(_n2p(w[f"{bp}conv{r + 1}/kernel"]))
getattr(block, f"norm{r + 1}").weight.copy_(_n2p(w[f"{bp}gn{r + 1}/scale"]))
getattr(block, f"norm{r + 1}").bias.copy_(_n2p(w[f"{bp}gn{r + 1}/bias"]))
if block.downsample is not None:
block.downsample.conv.weight.copy_(_n2p(w[f"{bp}conv_proj/kernel"]))
block.downsample.norm.weight.copy_(_n2p(w[f"{bp}gn_proj/scale"]))
block.downsample.norm.bias.copy_(_n2p(w[f"{bp}gn_proj/bias"]))
embed_conv_w = _n2p(w[f"{prefix}embedding/kernel"])
else:
embed_conv_w = adapt_input_conv(
model.patch_embed.proj.weight.shape[1],
_n2p(w[f"{prefix}embedding/kernel"]),
)
if embed_conv_w.shape[-2:] != model.patch_embed.proj.weight.shape[-2:]:
embed_conv_w = resample_patch_embed(
embed_conv_w,
model.patch_embed.proj.weight.shape[-2:],
interpolation=interpolation,
antialias=antialias,
verbose=True,
)
model.patch_embed.proj.weight.copy_(embed_conv_w)
model.patch_embed.proj.bias.copy_(_n2p(w[f"{prefix}embedding/bias"]))
if model.cls_token is not None:
model.cls_token.copy_(_n2p(w[f"{prefix}cls"], t=False))
if big_vision:
pos_embed_w = _n2p(w[f"{prefix}pos_embedding"], t=False)
else:
pos_embed_w = _n2p(w[f"{prefix}Transformer/posembed_input/pos_embedding"], t=False)
if pos_embed_w.shape != model.pos_embed.shape:
num_prefix_tokens = 0 if getattr(model, "no_embed_class", False) else getattr(model, "num_prefix_tokens", 1)
pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights
pos_embed_w,
new_size=model.patch_embed.grid_size,
num_prefix_tokens=num_prefix_tokens,
interpolation=interpolation,
antialias=antialias,
verbose=True,
)
model.pos_embed.copy_(pos_embed_w)
model.norm.weight.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/scale"]))
model.norm.bias.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/bias"]))
mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2)
idx: int | None = None
for i, block in enumerate(model.blocks.children()):
if f"{prefix}Transformer/encoderblock/LayerNorm_0/scale" in w:
block_prefix = f"{prefix}Transformer/encoderblock/"
idx = i
else:
                block_prefix = f"{prefix}Transformer/encoderblock_{i}/"
                idx = None
mha_prefix = block_prefix + f"MultiHeadDotProductAttention_{mha_sub}/"
block.norm1.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/scale"], idx=idx))
block.norm1.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/bias"], idx=idx))
if not self.lora:
block.attn.qkv.weight.copy_(
torch.cat(
[
_n2p(w[f"{mha_prefix}{n}/kernel"], t=False, idx=idx).flatten(1).T
for n in ("query", "key", "value")
],
),
)
block.attn.qkv.bias.copy_(
torch.cat(
[
_n2p(w[f"{mha_prefix}{n}/bias"], t=False, idx=idx).reshape(-1)
for n in ("query", "key", "value")
],
),
)
else:
block.attn.qkv.qkv.weight.copy_(
torch.cat(
[
_n2p(w[f"{mha_prefix}{n}/kernel"], t=False, idx=idx).flatten(1).T
for n in ("query", "key", "value")
],
),
)
block.attn.qkv.qkv.bias.copy_(
torch.cat(
[
_n2p(w[f"{mha_prefix}{n}/bias"], t=False, idx=idx).reshape(-1)
for n in ("query", "key", "value")
],
),
)
block.attn.proj.weight.copy_(_n2p(w[f"{mha_prefix}out/kernel"], idx=idx).flatten(1))
block.attn.proj.bias.copy_(_n2p(w[f"{mha_prefix}out/bias"], idx=idx))
block.norm2.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_{ln1_sub}/scale"], idx=idx))
block.norm2.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_{ln1_sub}/bias"], idx=idx))
for r in range(2):
getattr(block.mlp, f"fc{r + 1}").weight.copy_(
_n2p(w[f"{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel"], idx=idx),
)
getattr(block.mlp, f"fc{r + 1}").bias.copy_(
_n2p(w[f"{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias"], idx=idx),
)
class LoRALayer(torch.nn.Module):
"""LoRA layer implementation for computing A, B composition."""
def __init__(self, in_dim: int, out_dim: int, rank: int, alpha: float):
super().__init__()
std = torch.sqrt(torch.tensor(rank).float())
self.A = torch.nn.Parameter(torch.randn(in_dim, rank) / std)
self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
self.alpha = alpha
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass of the LoRA layer."""
return self.alpha * (x @ self.A @ self.B)
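
    # Math sketch: the adapter contributes alpha * (x @ A @ B) on top of a frozen base
    # projection, so the effective weight of a wrapped nn.Linear (computing x @ W.T)
    # becomes W + alpha * (A @ B).T. B is zero-initialized, so the wrapped module
    # initially behaves exactly like the frozen original, e.g.
    #
    #     lora = LoRALayer(in_dim=768, out_dim=768, rank=8, alpha=1.0)
    #     delta = lora(torch.randn(2, 197, 768))  # (2, 197, 768), all zeros at init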
class AttentionWithLoRA(torch.nn.Module):
"""Add LoRA layer into QKV attention layer in VisionTransformer."""
def __init__(self, qkv: Attention, rank: int, alpha: float):
super().__init__()
self.qkv = qkv
self.dim = qkv.in_features
self.lora_q = LoRALayer(self.dim, self.dim, rank, alpha)
self.lora_v = LoRALayer(self.dim, self.dim, rank, alpha)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass of the AttentionWithLoRA."""
qkv = self.qkv(x)
qkv[:, :, : self.dim] += self.lora_q(x)
qkv[:, :, -self.dim :] += self.lora_v(x)
return qkv
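
# Usage sketch (illustrative; the nn.Linear below stands in for the fused qkv
# projection of a timm attention block, which is what __init__ wraps when lora=True):
#
#     qkv = nn.Linear(768, 768 * 3)
#     wrapped = AttentionWithLoRA(qkv, rank=8, alpha=1.0)
#     out = wrapped(torch.randn(2, 197, 768))  # (2, 197, 2304); the LoRA deltas are
#                                              # added only to the query and value slices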