# The implementation here is modified from timm (originally Apache 2.0 License)
# and is publicly available at
# https://github.com/naver-ai/vidt/blob/vidt-plus/methods/swin_w_ram.py

import math
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_


class Mlp(nn.Module):
    """ Multilayer perceptron."""

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def masked_sin_pos_encoding(x,
                            mask,
                            num_pos_feats,
                            temperature=10000,
                            scale=2 * math.pi):
    """ Masked Sinusoidal Positional Encoding

    Args:
        x: [PATCH] tokens
        mask: the padding mask for [PATCH] tokens
        num_pos_feats: the size of channel dimension
        temperature: the temperature value
        scale: the normalization scale

    Returns:
        pos: Sinusoidal positional encodings
    """

    num_pos_feats = num_pos_feats // 2
    not_mask = ~mask

    y_embed = not_mask.cumsum(1, dtype=torch.float32)
    x_embed = not_mask.cumsum(2, dtype=torch.float32)

    eps = 1e-6
    y_embed = y_embed / (y_embed[:, -1:, :] + eps) * scale
    x_embed = x_embed / (x_embed[:, :, -1:] + eps) * scale

    dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=x.device)
    dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats)

    pos_x = x_embed[:, :, :, None] / dim_t
    pos_y = y_embed[:, :, :, None] / dim_t

    pos_x = torch.stack(
        (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
        dim=4).flatten(3)
    pos_y = torch.stack(
        (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
        dim=4).flatten(3)
    pos = torch.cat((pos_y, pos_x), dim=3)

    return pos
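

# Illustrative sketch (not part of the original ViDT code): a toy call to the
# masked sinusoidal encoding. The shapes below are assumptions for the example;
# `x` is only used to pick the device.
def _example_masked_sin_pos_encoding():
    x = torch.randn(2, 8 * 8, 256)  # [PATCH] tokens
    mask = torch.zeros(2, 8, 8, dtype=torch.bool)  # True marks padded positions
    mask[:, :, 6:] = True  # pretend the last two columns are padding
    pos = masked_sin_pos_encoding(x, mask, num_pos_feats=256)
    return pos  # shape: (2, 8, 8, 256)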


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size,
               C)
    windows = x.permute(0, 1, 3, 2, 4,
                        5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size,
                     window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
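

# Illustrative sketch (not part of the original ViDT code): window_partition and
# window_reverse invert each other when H and W are multiples of the window
# size. The toy shapes below are assumptions for the example.
def _example_window_roundtrip():
    x = torch.randn(2, 14, 14, 96)  # (B, H, W, C)
    windows = window_partition(x, window_size=7)  # (2 * 2 * 2, 7, 7, 96)
    restored = window_reverse(windows, window_size=7, H=14, W=14)
    assert torch.equal(restored, x)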


class ReconfiguredAttentionModule(nn.Module):
    """ Window based multi-head self attention (W-MSA) module with relative position bias -> extended with RAM.
    It supports both shifted and non-shifted windows.

    !!!!!!!!!!! IMPORTANT !!!!!!!!!!!
    The original attention module in Swin is replaced with the reconfigured attention module (RAM) described in Section 3 of the ViDT paper.
    All the Args are shared, so only the forward function is modified.
    See https://arxiv.org/pdf/2110.03921.pdf
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self,
                 dim,
                 window_size,
                 num_heads,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                        num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :,
                                         None] - coords_flatten[:,
                                                                None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(
            1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :,
                        0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer('relative_position_index',
                             relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self,
                x,
                det,
                mask=None,
                cross_attn=False,
                cross_attn_mask=None):
        """ Forward function.
        The RAM module receives [PATCH] and [DET] tokens and returns the calibrated tokens

        Args:
            x: [PATCH] tokens
            det: [DET] tokens
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None -> mask for shifted window attention

            "additional inputs for RAM"
            cross_attn: whether to use cross-attention [det x patch] (for selective cross-attention)
            cross_attn_mask: mask for cross-attention

        Returns:
            patch_x: the calibrated [PATCH] tokens
            det_x: the calibrated [DET] tokens
        """

        assert self.window_size[0] == self.window_size[1]
        window_size = self.window_size[0]
        local_map_size = window_size * window_size

        # projection before window partitioning
        if not cross_attn:
            B, H, W, C = x.shape
            N = H * W
            x = x.view(B, N, C)
            x = torch.cat([x, det], dim=1)
            full_qkv = self.qkv(x)
            patch_qkv, det_qkv = full_qkv[:, :N, :], full_qkv[:, N:, :]
        else:
            B, H, W, C = x[0].shape
            N = H * W
            _, ori_H, ori_W, _ = x[1].shape
            ori_N = ori_H * ori_W

            shifted_x = x[0].view(B, N, C)
            cross_x = x[1].view(B, ori_N, C)
            x = torch.cat([shifted_x, cross_x, det], dim=1)
            full_qkv = self.qkv(x)
            patch_qkv, cross_patch_qkv, det_qkv = \
                full_qkv[:, :N, :], full_qkv[:, N:N + ori_N, :], full_qkv[:, N + ori_N:, :]
        patch_qkv = patch_qkv.view(B, H, W, -1)

        # window partitioning for [PATCH] tokens
        patch_qkv = window_partition(
            patch_qkv, window_size)  # nW*B, window_size, window_size, C
        B_ = patch_qkv.shape[0]
        patch_qkv = patch_qkv.reshape(B_, window_size * window_size, 3,
                                      self.num_heads, C // self.num_heads)
        _patch_qkv = patch_qkv.permute(2, 0, 3, 1, 4)
        patch_q, patch_k, patch_v = _patch_qkv[0], _patch_qkv[1], _patch_qkv[2]

        # [PATCH x PATCH] self-attention using window partitions
        patch_q = patch_q * self.scale
        patch_attn = (patch_q @ patch_k.transpose(-2, -1))
        # add relative pos bias for [patch x patch] self-attention
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1],
                -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        patch_attn = patch_attn + relative_position_bias.unsqueeze(0)

        # if shifted window is used, it needs to apply the mask
        if mask is not None:
            nW = mask.shape[0]
            tmp0 = patch_attn.view(B_ // nW, nW, self.num_heads,
                                   local_map_size, local_map_size)
            tmp1 = mask.unsqueeze(1).unsqueeze(0)
            patch_attn = tmp0 + tmp1
            patch_attn = patch_attn.view(-1, self.num_heads, local_map_size,
                                         local_map_size)

        patch_attn = self.softmax(patch_attn)
        patch_attn = self.attn_drop(patch_attn)
        patch_x = (patch_attn @ patch_v).transpose(1, 2).reshape(
            B_, window_size, window_size, C)

        # extract qkv for [DET] tokens
        det_qkv = det_qkv.view(B, -1, 3, self.num_heads, C // self.num_heads)
        det_qkv = det_qkv.permute(2, 0, 3, 1, 4)
        det_q, det_k, det_v = det_qkv[0], det_qkv[1], det_qkv[2]

        # if cross-attention is activated
        if cross_attn:

            # reconstruct the spatial form of [PATCH] tokens for global [DET x PATCH] attention
            cross_patch_qkv = cross_patch_qkv.view(B, ori_H, ori_W, 3,
                                                   self.num_heads,
                                                   C // self.num_heads)
            patch_kv = cross_patch_qkv[:, :, :,
                                       1:, :, :].permute(3, 0, 4, 1, 2,
                                                         5).contiguous()
            patch_kv = patch_kv.view(2, B, self.num_heads, ori_H * ori_W, -1)

            # extract "key and value" of [PATCH] tokens for cross-attention
            cross_patch_k, cross_patch_v = patch_kv[0], patch_kv[1]

            # bind key and value of [PATCH] and [DET] tokens for [DET X [PATCH, DET]] attention
            det_k, det_v = torch.cat([cross_patch_k, det_k],
                                     dim=2), torch.cat([cross_patch_v, det_v],
                                                       dim=2)

        # [DET x DET] self-attention or bound [DET x [PATCH, DET]] attention
        det_q = det_q * self.scale
        det_attn = (det_q @ det_k.transpose(-2, -1))
        # apply cross-attention mask if available
        if cross_attn_mask is not None:
            det_attn = det_attn + cross_attn_mask
        det_attn = self.softmax(det_attn)
        det_attn = self.attn_drop(det_attn)
        det_x = (det_attn @ det_v).transpose(1, 2).reshape(B, -1, C)

        # reverse window for [PATCH] tokens <- the output of [PATCH x PATCH] self attention
        patch_x = window_reverse(patch_x, window_size, H, W)

        # projection for outputs from multi-head
        x = torch.cat([patch_x.view(B, H * W, C), det_x], dim=1)
        x = self.proj(x)
        x = self.proj_drop(x)

        # decompose the projected output back into [PATCH] and [DET] tokens
        patch_x = x[:, :H * W, :].view(B, H, W, C)
        det_x = x[:, H * W:, :]

        return patch_x, det_x
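

# Illustrative sketch (not part of the original ViDT code): shape flow through
# RAM without cross-attention, using a toy input whose spatial size equals the
# window size so that no shifted-window mask is needed. All sizes below are
# assumptions for the example.
def _example_ram_shapes():
    ram = ReconfiguredAttentionModule(dim=96, window_size=(7, 7), num_heads=3)
    patches = torch.randn(2, 7, 7, 96)  # [PATCH] tokens, (B, H, W, C)
    det = torch.randn(2, 100, 96)  # [DET] tokens, (B, det_token_num, C)
    patch_x, det_x = ram(patches, det)
    return patch_x.shape, det_x.shape  # (2, 7, 7, 96), (2, 100, 96)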


class SwinTransformerBlock(nn.Module):
    """ Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=7,
                 shift_size=0,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, 'shift_size must be in [0, window_size)'

        self.norm1 = norm_layer(dim)
        self.attn = ReconfiguredAttentionModule(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)

        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop)

        self.H = None
        self.W = None

    def forward(self, x, mask_matrix, pos, cross_attn, cross_attn_mask):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W + DET, C), i.e., bound [PATCH, DET] tokens
            H, W: Spatial resolution of the input feature.
            mask_matrix: Attention mask for cyclic shift.

            "additional inputs'
            pos: (patch_pos, det_pos)
            cross_attn: whether to use cross attn [det x [det + patch]]
            cross_attn_mask: attention mask for cross-attention

        Returns:
            x: calibrated & bound [PATCH, DET] tokens
        """

        B, L, C = x.shape
        H, W = self.H, self.W

        assert L == H * W + self.det_token_num, 'input feature has wrong size'

        shortcut = x
        x = self.norm1(x)
        x, det = x[:, :H * W, :], x[:, H * W:, :]
        x = x.view(B, H, W, C)
        orig_x = x

        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # projection for det positional encodings: make the channel size suitable for the current layer
        patch_pos, det_pos = pos
        det_pos = self.det_pos_linear(det_pos)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(
                x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None

        # prepare cross-attn and add positional encodings
        if cross_attn:
            # patch token (for cross-attention) + Sinusoidal pos encoding
            cross_patch = orig_x + patch_pos
            # det token + learnable pos encoding
            det = det + det_pos
            shifted_x = (shifted_x, cross_patch)
        else:
            # if cross_attn is deactivated, only [PATCH] and [DET] self-attention are performed
            det = det + det_pos
            shifted_x = shifted_x

        # W-MSA/SW-MSA
        shifted_x, det = self.attn(
            shifted_x,
            mask=attn_mask,
            # additional args
            det=det,
            cross_attn=cross_attn,
            cross_attn_mask=cross_attn_mask)

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(
                shifted_x,
                shifts=(self.shift_size, self.shift_size),
                dims=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)
        x = torch.cat([x, det], dim=1)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x
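

# Illustrative sketch (not part of the original ViDT code): a single block run
# on random tokens. `det_token_num` and `det_pos_linear` are normally attached
# by SwinTransformer.finetune_det(); here they are wired up by hand only to
# show the expected shapes, and all sizes are assumptions.
def _example_swin_block_shapes():
    blk = SwinTransformerBlock(dim=96, num_heads=3, window_size=7, shift_size=0)
    blk.det_token_num = 100
    blk.det_pos_linear = nn.Linear(256, 96)
    blk.H, blk.W = 14, 14

    x = torch.randn(2, 14 * 14 + 100, 96)  # bound [PATCH, DET] tokens
    det_pos = torch.randn(1, 100, 256)  # [DET] positional encoding
    out = blk(
        x,
        mask_matrix=None,
        pos=(None, det_pos),
        cross_attn=False,
        cross_attn_mask=None)
    return out.shape  # (2, 14 * 14 + 100, 96)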


class PatchMerging(nn.Module):
    """ Patch Merging Layer

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm, expand=True):
        super().__init__()
        self.dim = dim

        # if expand is True, the channel size is doubled; otherwise, it is set to 256
        expand_dim = 2 * dim if expand else 256
        self.reduction = nn.Linear(4 * dim, expand_dim, bias=False)
        self.norm = norm_layer(4 * dim)

        # added for detection token [please ignore, not used for training]
        # not implemented yet.
        self.expansion = nn.Linear(dim, expand_dim, bias=False)
        self.norm2 = norm_layer(dim)

    def forward(self, x, H, W):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W + DET, C), i.e., bound [PATCH, DET] tokens
            H, W: Spatial resolution of the input feature.

        Returns:
            x: merged [PATCH, DET] tokens;
            only [PATCH] tokens are reduced in spatial dim, while [DET] tokens keep a fixed scale
        """

        B, L, C = x.shape
        assert L == H * W + self.det_token_num, 'input feature has wrong size'

        x, det = x[:, :H * W, :], x[:, H * W:, :]
        x = x.view(B, H, W, C)

        # padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        # simply repeat the [DET] tokens along the channel dim to match the merged [PATCH] tokens
        det = det.repeat(1, 1, 4)

        x = torch.cat([x, det], dim=1)
        x = self.norm(x)
        x = self.reduction(x)

        return x
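

# Illustrative sketch (not part of the original ViDT code): PatchMerging halves
# the [PATCH] resolution and doubles the channels while the number of [DET]
# tokens stays fixed. `det_token_num` is normally attached by finetune_det();
# it is set by hand here, and the sizes are assumptions.
def _example_patch_merging_shapes():
    merge = PatchMerging(dim=96)
    merge.det_token_num = 100
    x = torch.randn(2, 14 * 14 + 100, 96)  # bound [PATCH, DET] tokens
    out = merge(x, H=14, W=14)
    return out.shape  # (2, 7 * 7 + 100, 192)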


class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of feature channels.
        depth (int): Depth (number of blocks) of this stage.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 dim,
                 depth,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 last=False,
                 use_checkpoint=False):

        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        self.dim = dim
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i]
                if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer) for i in range(depth)
        ])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                dim=dim, norm_layer=norm_layer, expand=(not last))
        else:
            self.downsample = None

    def forward(self, x, H, W, det_pos, input_mask, cross_attn=False):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W + DET, C), i.e., bound [PATCH, DET] tokens.
            H, W: Spatial resolution of the input feature.
            det_pos: pos encoding for det token
            input_mask: padding mask for inputs
            cross_attn: whether to use cross attn [det x [det + patch]]
        """

        B = x.shape[0]

        # calculate attention mask for SW-MSA
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size,
                          -self.shift_size), slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size,
                          -self.shift_size), slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        # mask for cyclic shift
        mask_windows = window_partition(img_mask, self.window_size)
        mask_windows = mask_windows.view(-1,
                                         self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0,
                                          float(-100.0)).masked_fill(
                                              attn_mask == 0, float(0.0))

        # compute sinusoidal pos encoding and cross-attn mask here to avoid redundant computation
        if cross_attn:

            _H, _W = input_mask.shape[1:]
            if not (_H == H and _W == W):
                input_mask = F.interpolate(
                    input_mask[None].float(), size=(H, W)).to(torch.bool)[0]

            # sinusoidal pos encoding for [PATCH] tokens used in cross-attention
            patch_pos = masked_sin_pos_encoding(x, input_mask, self.dim)

            # attention padding mask due to the zero padding in inputs
            # the zero (padded) area is masked by 1.0 in 'input_mask'
            cross_attn_mask = input_mask.float()
            cross_attn_mask = cross_attn_mask.masked_fill(cross_attn_mask != 0.0, float(-100.0)). \
                masked_fill(cross_attn_mask == 0.0, float(0.0))

            # pad for the detection tokens (this padding is required to process the bound [PATCH, DET] attention)
            cross_attn_mask = cross_attn_mask.view(
                B, H * W).unsqueeze(1).unsqueeze(2)
            cross_attn_mask = F.pad(
                cross_attn_mask, (0, self.det_token_num), value=0)

        else:
            patch_pos = None
            cross_attn_mask = None

        # zip pos encodings
        pos = (patch_pos, det_pos)

        for n_blk, blk in enumerate(self.blocks):
            blk.H, blk.W = H, W

            # for selective cross-attention
            if cross_attn:
                _cross_attn = True
                _cross_attn_mask = cross_attn_mask
                _pos = pos  # i.e., (patch_pos, det_pos)
            else:
                _cross_attn = False
                _cross_attn_mask = None
                _pos = (None, det_pos)

            if self.use_checkpoint:
                x = checkpoint.checkpoint(
                    blk,
                    x,
                    attn_mask,
                    # additional inputs
                    pos=_pos,
                    cross_attn=_cross_attn,
                    cross_attn_mask=_cross_attn_mask)
            else:
                x = blk(
                    x,
                    attn_mask,
                    # additional inputs
                    pos=_pos,
                    cross_attn=_cross_attn,
                    cross_attn_mask=_cross_attn_mask)

        # reduce the number of [PATCH] tokens while keeping the number of [DET] tokens fixed
        # meanwhile, the channel dim increases by a factor of 2
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W
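

# Illustrative sketch (not part of the original ViDT code): one stage run
# without cross-attention. The detection-related attributes are normally
# attached by SwinTransformer.finetune_det(); they are wired up by hand here,
# and all sizes are assumptions.
def _example_basic_layer_shapes():
    layer = BasicLayer(
        dim=96, depth=2, num_heads=3, window_size=7, downsample=PatchMerging)
    layer.det_token_num = 100
    layer.downsample.det_token_num = 100
    for blk in layer.blocks:
        blk.det_token_num = 100
        blk.det_pos_linear = nn.Linear(256, 96)

    x = torch.randn(2, 56 * 56 + 100, 96)  # bound [PATCH, DET] tokens
    det_pos = torch.randn(1, 100, 256)
    input_mask = torch.zeros(2, 56, 56, dtype=torch.bool)
    x_out, H, W, x_down, Wh, Ww = layer(
        x, 56, 56, det_pos=det_pos, input_mask=input_mask, cross_attn=False)
    # x_out: (2, 56 * 56 + 100, 96), x_down: (2, 28 * 28 + 100, 192)
    return x_out.shape, x_down.shape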


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding

    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Forward function."""

        # padding
        _, _, H, W = x.size()
        if W % self.patch_size[1] != 0:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = F.pad(x,
                      (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))

        x = self.proj(x)  # B C Wh Ww
        if self.norm is not None:
            Wh, Ww = x.size(2), x.size(3)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)

        return x
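

# Illustrative sketch (not part of the original ViDT code): patch embedding of
# a toy batch; inputs whose sides are not multiples of the patch size are
# padded first. The sizes are assumptions.
def _example_patch_embed_shapes():
    embed = PatchEmbed(patch_size=4, in_chans=3, embed_dim=96)
    images = torch.randn(2, 3, 224, 224)
    out = embed(images)
    return out.shape  # (2, 96, 56, 56)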


class SwinTransformer(nn.Module):
    """ Swin Transformer backbone.
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030

    Args:
        pretrain_img_size (int): Input image size for training the pretrained model,
            used in absolute position embedding. Default 224.
        patch_size (int | tuple(int)): Patch size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        depths (tuple[int]): Depths of each Swin Transformer stage.
        num_heads (tuple[int]): Number of attention head of each stage.
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
            self,
            pretrain_img_size=224,
            patch_size=4,
            in_chans=3,
            embed_dim=96,
            depths=[2, 2, 6, 2],
            num_heads=[3, 6, 12, 24],
            window_size=7,
            mlp_ratio=4.,
            qkv_bias=True,
            qk_scale=None,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.2,
            norm_layer=nn.LayerNorm,
            ape=False,
            patch_norm=True,
            out_indices=[1, 2,
                         3],  # not used in the current version, please ignore.
            frozen_stages=-1,
            use_checkpoint=False):
        super().__init__()

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            pretrain_img_size = to_2tuple(pretrain_img_size)
            patch_size = to_2tuple(patch_size)
            patches_resolution = [
                pretrain_img_size[0] // patch_size[0],
                pretrain_img_size[1] // patch_size[1]
            ]

            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, embed_dim, patches_resolution[0],
                            patches_resolution[1]))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2**i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                # modified by ViDT
                downsample=PatchMerging if
                (i_layer < self.num_layers) else None,
                last=None if (i_layer < self.num_layers - 1) else True,
                #
                use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
        self.num_features = num_features

        # add a norm layer for each output
        # Not used in the current version -> please ignore; this will be fixed later.
        # We leave these lines so that the pre-trained model can be loaded.
        for i_layer in out_indices:
            layer = norm_layer(num_features[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self._freeze_stages()

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 1 and self.ape:
            self.absolute_pos_embed.requires_grad = False

        if self.frozen_stages >= 2:
            self.pos_drop.eval()
            for i in range(0, self.frozen_stages - 1):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'det_pos_embed', 'det_token'}

    def finetune_det(self,
                     method,
                     det_token_num=100,
                     pos_dim=256,
                     cross_indices=[3]):
        """ A function to add necessary (leanable) variables to Swin Transformer for object detection

            Args:
                method: vidt or vidt_wo_neck
                det_token_num: the number of objects to detect, i.e., the number of object queries
                pos_dim: the channel dimension of positional encodings for [DET] and [PATCH] tokens
                cross_indices: the indices where to use the [DET X PATCH] cross-attention
                    there are four possible stages in [0, 1, 2, 3]. 3 indicates Stage 4 in the ViDT paper.
        """

        # which method?
        self.method = method

        # how many objects to detect?
        self.det_token_num = det_token_num
        self.det_token = nn.Parameter(
            torch.zeros(1, det_token_num, self.num_features[0]))
        self.det_token = trunc_normal_(self.det_token, std=.02)

        # dim size of pos encoding
        self.pos_dim = pos_dim

        # learnable positional encoding for detection tokens
        det_pos_embed = torch.zeros(1, det_token_num, pos_dim)
        det_pos_embed = trunc_normal_(det_pos_embed, std=.02)
        self.det_pos_embed = torch.nn.Parameter(det_pos_embed)

        # info for detection
        self.num_channels = [
            self.num_features[i + 1]
            for i in range(len(self.num_features) - 1)
        ]
        if method == 'vidt':
            self.num_channels.append(
                self.pos_dim)  # default: 256 (same to the default pos_dim)
        self.cross_indices = cross_indices
        # divisor to reduce the spatial size of the mask
        self.mask_divisor = 2**(len(self.layers) - len(self.cross_indices))

        # projection matrix for the det pos encoding in each Swin stage (there are 4 stages)
        for layer in self.layers:
            layer.det_token_num = det_token_num
            if layer.downsample is not None:
                layer.downsample.det_token_num = det_token_num
            for block in layer.blocks:
                block.det_token_num = det_token_num
                block.det_pos_linear = nn.Linear(pos_dim, block.dim)

        # the neck-free model does not require downsampling at the last stage.
        if method == 'vidt_wo_neck':
            self.layers[-1].downsample = None

    def forward(self, x, mask):
        """ Forward function.

            Args:
                x: input rgb images
                mask: input padding masks [0: rgb values, 1: padded values]

            Returns:
                patch_outs: multi-scale [PATCH] tokens (four scales are used)
                    these tokens are the first input of the neck decoder
                det_tgt: final [DET] tokens obtained at the last stage
                    these tokens are the second input of the neck decoder
                det_pos: the learnable pos encoding for [DET] tokens.
                    these encodings are used to generate reference points in deformable attention
        """

        # batch size of the original input
        B = x.shape[0]

        # patch embedding
        x = self.patch_embed(x)

        Wh, Ww = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)

        # expand det_token for all examples in the batch
        det_token = self.det_token.expand(B, -1, -1)

        # det pos encoding -> will be projected in each block
        det_pos = self.det_pos_embed

        # prepare a mask for cross attention
        mask = F.interpolate(
            mask[None].float(),
            size=(Wh // self.mask_divisor,
                  Ww // self.mask_divisor)).to(torch.bool)[0]

        patch_outs = []
        for stage in range(self.num_layers):
            layer = self.layers[stage]

            # whether to use cross-attention
            cross_attn = stage in self.cross_indices

            # concat input
            x = torch.cat([x, det_token], dim=1)

            # inference
            x_out, H, W, x, Wh, Ww = layer(
                x,
                Wh,
                Ww,
                # additional input for VIDT
                input_mask=mask,
                det_pos=det_pos,
                cross_attn=cross_attn)

            x, det_token = (x[:, :-self.det_token_num, :],
                            x[:, -self.det_token_num:, :])

            # Aggregate intermediate outputs
            if stage > 0:
                patch_out = x_out[:, :-self.det_token_num, :].view(
                    B, H, W, -1).permute(0, 3, 1, 2)
                patch_outs.append(patch_out)

        # patch token reduced from last stage output
        patch_outs.append(x.view(B, Wh, Ww, -1).permute(0, 3, 1, 2))

        # det token
        det_tgt = x_out[:, -self.det_token_num:, :].permute(0, 2, 1)

        # det token pos encoding
        det_pos = det_pos.permute(0, 2, 1)

        features_0, features_1, features_2, features_3 = patch_outs
        return features_0, features_1, features_2, features_3, det_tgt, det_pos

    def train(self, mode=True):
        """Convert the model into training mode while keep layers freezed."""
        super(SwinTransformer, self).train(mode)
        self._freeze_stages()

    # not working in the current version: `patch_embed.flops()`,
    # `layer.flops()`, `self.patches_resolution`, and `self.num_classes`
    # are not defined in this modified backbone
    def flops(self):
        raise NotImplementedError(
            'flops() is not supported in this ViDT-modified Swin backbone.')
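

# Illustrative usage sketch (not part of the original ViDT code): build the
# backbone, attach the [DET] tokens with finetune_det(), and run a forward
# pass on random inputs. The batch size, image size, and token count are
# assumptions for the example.
def _example_backbone_forward():
    backbone = SwinTransformer(
        embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24])
    backbone.finetune_det(method='vidt', det_token_num=100, pos_dim=256)
    backbone.eval()

    images = torch.randn(2, 3, 224, 224)
    # padding mask: False (0) for valid pixels, True (1) for padded pixels
    masks = torch.zeros(2, 224, 224, dtype=torch.bool)

    with torch.no_grad():
        f0, f1, f2, f3, det_tgt, det_pos = backbone(images, masks)
    # f0 .. f3: multi-scale [PATCH] feature maps,
    # det_tgt: (2, 768, 100), det_pos: (1, 256, 100)
    return det_tgt.shape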
