# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.

import os

import torch

from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from .backbone import SwinTransformer
from .deformable_transformer import DeformableTransformer
from .fpn_fusion import FPNFusionModule
from .head import Detector


@MODELS.register_module(Tasks.image_object_detection, module_name=Models.vidt)
class VidtModel(TorchModel):
    """
        The implementation of 'ViDT for joint-learning of object detection and instance segmentation'.
        This model is dynamically initialized with the following parts:
            - 'backbone': pre-trained backbone model with parameters.
            - 'head': detection and segentation head with fine-tuning.
    """

    def __init__(self, model_dir: str, **kwargs):
        """ Initialize a Vidt Model.
        Args:
          model_dir: model id or path, where model_dir/pytorch_model.pt contains:
                    - 'backbone_weights': parameters of backbone.
                    - 'head_weights': parameters of head.
        """
        super(VidtModel, self).__init__()

        model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
        model_dict = torch.load(
            model_path, map_location='cpu', weights_only=True)

        # build backbone
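        # embed_dim=96, depths=[2, 2, 6, 2] and num_heads=[3, 6, 12, 24]
        # match the Swin-Tiny configuration.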
        backbone = SwinTransformer(
            pretrain_img_size=[224, 224],
            embed_dim=96,
            depths=[2, 2, 6, 2],
            num_heads=[3, 6, 12, 24],
            window_size=7,
            drop_path_rate=0.2)
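        # Reconfigure the backbone for detection following the ViDT scheme:
        # attach 300 learnable [DET] tokens with 256-d positional encodings and
        # apply cross-attention with them only at the last stage (cross_indices=[3]).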
        backbone.finetune_det(
            method='vidt', det_token_num=300, pos_dim=256, cross_indices=[3])
        self.backbone = backbone
        self.backbone.load_state_dict(
            model_dict['backbone_weights'], strict=True)

        # build head
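        # EPFF (ViDT+): fuses the multi-scale backbone features into a
        # 256-d feature pyramid before decoding.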
        epff = FPNFusionModule(backbone.num_channels, fuse_dim=256)
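        # Encoder-free neck: a deformable transformer decoder that uses the
        # [DET] tokens from the backbone as queries over the fused features.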
        deform_transformers = DeformableTransformer(
            d_model=256,
            nhead=8,
            num_decoder_layers=6,
            dim_feedforward=1024,
            dropout=0.1,
            activation='relu',
            return_intermediate_dec=True,
            num_feature_levels=4,
            dec_n_points=4,
            token_label=False)
        head = Detector(
            backbone,
            deform_transformers,
            num_classes=2,
            num_queries=300,
            # two essential techniques used in ViDT
            aux_loss=True,
            with_box_refine=True,
            # an epff module for ViDT+
            epff=epff,
            # a UQR module for ViDT+ (disabled here)
            with_vector=False,
            processor_dct=None,
            # two additional losses for ViDT+
            iou_aware=True,
            token_label=False,
            vector_hidden_dim=256,
            # knowledge distillation (disabled here)
            distil=False)
        self.head = head
        self.head.load_state_dict(model_dict['head_weights'], strict=True)

    def forward(self, x, mask):
        """ Dynamic forward function of VidtModel.
        Args:
            x: input images (B, 3, H, W)
            mask: input padding masks (B, H, W)
        """
        features_0, features_1, features_2, features_3, det_tgt, det_pos = self.backbone(
            x, mask)
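        # The head fuses the feature maps and decodes the [DET] tokens into
        # per-query class logits and box predictions.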
        out_pred_logits, out_pred_boxes = self.head(features_0, features_1,
                                                    features_2, features_3,
                                                    det_tgt, det_pos, mask)
        return out_pred_logits, out_pred_boxes
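
# Example usage (a minimal sketch; '/path/to/model_dir' is a placeholder for a
# local directory whose pytorch_model.pt holds 'backbone_weights' and
# 'head_weights'; input sizes are illustrative):
#
#     model = VidtModel(model_dir='/path/to/model_dir')
#     model.eval()
#     images = torch.randn(1, 3, 800, 800)                # padded image batch
#     masks = torch.zeros(1, 800, 800, dtype=torch.bool)  # True marks padded pixels
#     with torch.no_grad():
#         logits, boxes = model(images, masks)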
