import torch
import torch.nn as nn
import torch.nn.functional as F


def apply_offset(offset):
    """Convert a per-pixel offset field (B, 2, H, W) into a normalized
    sampling grid (B, H, W, 2) for F.grid_sample; channel 0 holds the
    x (width) displacement and channel 1 the y (height) displacement."""
    sizes = list(offset.size()[2:])
    grid_list = torch.meshgrid(
        [torch.arange(size, device=offset.device) for size in sizes])
    # meshgrid yields (y, x) order; reverse to the (x, y) order expected
    # by grid_sample
    grid_list = reversed(grid_list)
    # apply offset
    grid_list = [
        grid.float().unsqueeze(0) + offset[:, dim, ...]
        for dim, grid in enumerate(grid_list)
    ]
    # normalize pixel coordinates to [-1, 1]
    grid_list = [
        grid / ((size - 1.0) / 2.0) - 1.0
        for grid, size in zip(grid_list, reversed(sizes))
    ]

    return torch.stack(grid_list, dim=-1)
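
# Usage sketch for apply_offset (illustrative shapes, not from the original
# code): a zero offset field is the identity grid under the
# align_corners=True convention implied by the (size - 1) / 2 normalization.
#
#   offset = torch.zeros(2, 2, 4, 4)  # (B, 2, H, W), zero displacement
#   grid = apply_offset(offset)       # (2, 4, 4, 2), values in [-1, 1]
#   x = torch.randn(2, 3, 4, 4)
#   y = F.grid_sample(x, grid, mode='bilinear', padding_mode='border',
#                     align_corners=True)
#   # y equals x up to numerical precision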


# backbone
class ResBlock(nn.Module):

    def __init__(self, in_channels):
        super(ResBlock, self).__init__()
        self.block = nn.Sequential(
            nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels, in_channels, kernel_size=3,
                padding=1, bias=False), nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels, in_channels, kernel_size=3, padding=1,
                bias=False))

    def forward(self, x):
        return self.block(x) + x


class Downsample(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(Downsample, self).__init__()
        self.block = nn.Sequential(
            nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=False))

    def forward(self, x):
        return self.block(x)


class FeatureEncoder(nn.Module):

    def __init__(self, in_channels, chns=[64, 128, 256, 256, 256]):
        # in_channels is 3 for RGB images and larger (e.g., 17 + 1 + 1) for
        # the agnostic representation
        super(FeatureEncoder, self).__init__()
        self.encoders = []
        for i, out_chns in enumerate(chns):
            if i == 0:
                encoder = nn.Sequential(
                    Downsample(in_channels, out_chns), ResBlock(out_chns),
                    ResBlock(out_chns))
            else:
                encoder = nn.Sequential(
                    Downsample(chns[i - 1], out_chns), ResBlock(out_chns),
                    ResBlock(out_chns))

            self.encoders.append(encoder)

        self.encoders = nn.ModuleList(self.encoders)

    def forward(self, x):
        encoder_features = []
        for encoder in self.encoders:
            x = encoder(x)
            encoder_features.append(x)
        return encoder_features
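
# Output shape sketch (hypothetical 256x192 input; the sizes follow from the
# stride-2 Downsample at each stage and the default `chns`):
#
#   feats = FeatureEncoder(3)(torch.randn(1, 3, 256, 192))
#   # shapes: (1, 64, 128, 96), (1, 128, 64, 48), (1, 256, 32, 24),
#   #         (1, 256, 16, 12), (1, 256, 8, 6)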


class RefinePyramid(nn.Module):
    """Top-down FPN: each encoder feature is projected to fpn_dim channels
    and fused with the upsampled coarser level."""

    def __init__(self, chns=[64, 128, 256, 256, 256], fpn_dim=256):
        super(RefinePyramid, self).__init__()
        self.chns = chns

        # adaptive
        self.adaptive = []
        for in_chns in list(reversed(chns)):
            adaptive_layer = nn.Conv2d(in_chns, fpn_dim, kernel_size=1)
            self.adaptive.append(adaptive_layer)
        self.adaptive = nn.ModuleList(self.adaptive)
        # output conv
        self.smooth = []
        for i in range(len(chns)):
            smooth_layer = nn.Conv2d(
                fpn_dim, fpn_dim, kernel_size=3, padding=1)
            self.smooth.append(smooth_layer)
        self.smooth = nn.ModuleList(self.smooth)

    def forward(self, x):
        conv_ftr_list = x

        feature_list = []
        last_feature = None
        for i, conv_ftr in enumerate(list(reversed(conv_ftr_list))):
            # adaptive
            feature = self.adaptive[i](conv_ftr)
            # fuse
            if last_feature is not None:
                feature = feature + F.interpolate(
                    last_feature, scale_factor=2, mode='nearest')
            # smooth
            feature = self.smooth[i](feature)
            last_feature = feature
            feature_list.append(feature)

        return tuple(reversed(feature_list))
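
# Continuing the sketch above: the pyramid keeps the encoder's fine-to-coarse
# order and spatial sizes, but every level now has fpn_dim channels.
#
#   fpn_feats = RefinePyramid()(feats)
#   # shapes: (1, 256, 128, 96), (1, 256, 64, 48), (1, 256, 32, 24),
#   #         (1, 256, 16, 12), (1, 256, 8, 6)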


def DAWarp(feat, offsets, att_maps, sample_k, out_ch):
    """Deformable attention warp: sample `feat` (B, C, H, W) on `sample_k`
    offset grids (B * sample_k, 2, H, W) and fuse the warped copies with
    per-pixel attention maps (B, sample_k, H, W); `out_ch` equals C."""
    # repeat each attention channel out_ch times to align with the
    # (B, sample_k * out_ch, H, W) layout of the warped features below
    att_maps = torch.repeat_interleave(att_maps, out_ch, 1)
    B, C, H, W = feat.size()
    multi_feat = torch.repeat_interleave(feat, sample_k, 0)
    multi_warp_feat = F.grid_sample(
        multi_feat,
        offsets.detach().permute(0, 2, 3, 1),
        mode='bilinear',
        padding_mode='border')
    multi_att_warp_feat = multi_warp_feat.reshape(B, -1, H, W) * att_maps
    # attention-weighted sum over the k warped copies -> (B, out_ch, H, W)
    att_warp_feat = sum(torch.split(multi_att_warp_feat, out_ch, 1))
    return att_warp_feat
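
# Shape example for DAWarp (illustrative values): with B=1, k=2 heads,
# C=256 channels and an 8x6 grid,
#
#   DAWarp(torch.randn(1, 256, 8, 6),  # feat
#          torch.randn(2, 2, 8, 6),    # offsets, B * k leading dim
#          torch.rand(1, 2, 8, 6),     # attention over the 2 heads
#          sample_k=2, out_ch=256)     # -> (1, 256, 8, 6)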


class MFEBlock(nn.Module):
    """Multi-flow estimation (MFE) block: a small conv stack predicting
    `out_channels` per-pixel maps (flow offsets plus attention logits)."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 num_filters=[128, 64, 32]):
        super(MFEBlock, self).__init__()
        layers = []
        for i in range(len(num_filters)):
            if i == 0:
                layers.append(
                    torch.nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=num_filters[i],
                        kernel_size=3,
                        stride=1,
                        padding=1))
            else:
                layers.append(
                    torch.nn.Conv2d(
                        in_channels=num_filters[i - 1],
                        out_channels=num_filters[i],
                        kernel_size=kernel_size,
                        stride=1,
                        padding=kernel_size // 2))
            layers.append(
                torch.nn.LeakyReLU(inplace=False, negative_slope=0.1))
        layers.append(
            torch.nn.Conv2d(
                in_channels=num_filters[-1],
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=kernel_size // 2))
        self.layers = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
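
# Usage sketch (illustrative values): with k = 6 sampling heads,
#
#   mfe = MFEBlock(in_channels=512, out_channels=6 * 3)
#   out = mfe(torch.randn(1, 512, 8, 6))  # (1, 18, 8, 6)
#   # channels 0..11 hold 2k flow components, 12..17 hold k attention logits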


class DAFlowNet(nn.Module):
    """Cascaded deformable attention flow (DAFlow) module: at each pyramid
    level, cross-, self- and refine-MFE blocks update multi-head sampling
    offsets that warp the cloth (source) and model (reference) branches."""

    def __init__(self, num_pyramid, fpn_dim=256, head_nums=1):
        super(DAFlowNet, self).__init__()
        self.Self_MFEs = []
        self.Cross_MFEs = []
        self.Refine_MFEs = []
        self.k = head_nums
        self.out_ch = fpn_dim
        for i in range(num_pyramid):
            # self-MFE for the model image: 2k channels of flow, k channels
            # of attention maps
            Self_MFE_layer = MFEBlock(
                in_channels=2 * fpn_dim,
                out_channels=self.k * 3,
                kernel_size=7)
            # cross-MFE for the cloth image
            Cross_MFE_layer = MFEBlock(
                in_channels=2 * fpn_dim, out_channels=self.k * 3)
            # refine-MFE jointly refining the cloth and model flows
            Refine_MFE_layer = MFEBlock(
                in_channels=2 * fpn_dim, out_channels=self.k * 6)
            self.Self_MFEs.append(Self_MFE_layer)
            self.Cross_MFEs.append(Cross_MFE_layer)
            self.Refine_MFEs.append(Refine_MFE_layer)

        self.Self_MFEs = nn.ModuleList(self.Self_MFEs)
        self.Cross_MFEs = nn.ModuleList(self.Cross_MFEs)
        self.Refine_MFEs = nn.ModuleList(self.Refine_MFEs)

        self.lights_decoder = torch.nn.Sequential(
            torch.nn.Conv2d(64, out_channels=32, kernel_size=1, stride=1),
            torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
            torch.nn.Conv2d(
                in_channels=32,
                out_channels=3,
                kernel_size=3,
                stride=1,
                padding=1))
        self.lights_encoder = torch.nn.Sequential(
            torch.nn.Conv2d(
                3, out_channels=32, kernel_size=3, stride=1, padding=1),
            torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
            torch.nn.Conv2d(
                in_channels=32, out_channels=64, kernel_size=1, stride=1))

    def forward(self,
                source_image,
                reference_image,
                source_feats,
                reference_feats,
                return_all=False,
                warp_feature=True,
                use_light_en_de=True):
        r"""
        Args:
            source_image: cloth rgb image for tryon
            reference_image: model rgb image for try on
            source_feats: cloth FPN features
            reference_feats: model and pose features
            return_all: bool return all intermediate try-on results in training phase
            warp_feature: use DAFlow for both features and images
            use_light_en_de: use shallow encoder and decoder to project the images from RGB to high dimensional space

        """

        # reference branch: the model image, warped with self-DAFlow
        last_multi_self_offsets = None
        # source branch: the cloth image, warped with cross-DAFlow
        last_multi_cross_offsets = None

        if return_all:
            results_all = []

        for i in range(len(source_feats)):

            feat_source = source_feats[len(source_feats) - 1 - i]
            feat_ref = reference_feats[len(reference_feats) - 1 - i]
            B, C, H, W = feat_source.size()

            # Pre-DAWarp: warp the pyramid features with the offsets
            # estimated at the coarser level
            if last_multi_cross_offsets is not None and warp_feature:
                att_source_feat = DAWarp(feat_source, last_multi_cross_offsets,
                                         cross_att_maps, self.k, self.out_ch)
                att_reference_feat = DAWarp(feat_ref, last_multi_self_offsets,
                                            self_att_maps, self.k, self.out_ch)
            else:
                att_source_feat = feat_source
                att_reference_feat = feat_ref
            # Cross-MFE
            input_feat = torch.cat([att_source_feat, feat_ref], 1)
            offsets_att = self.Cross_MFEs[i](input_feat)
            cross_att_maps = F.softmax(
                offsets_att[:, self.k * 2:, :, :], dim=1)
            offsets = apply_offset(offsets_att[:, :self.k * 2, :, :].reshape(
                -1, 2, H, W))
            if last_multi_cross_offsets is not None:
                offsets = F.grid_sample(
                    last_multi_cross_offsets,
                    offsets,
                    mode='bilinear',
                    padding_mode='border')
            else:
                offsets = offsets.permute(0, 3, 1, 2)
            last_multi_cross_offsets = offsets
            att_source_feat = DAWarp(feat_source, last_multi_cross_offsets,
                                     cross_att_maps, self.k, self.out_ch)

            # Self-MFE
            input_feat = torch.cat([att_source_feat, att_reference_feat], 1)
            offsets_att = self.Self_MFEs[i](input_feat)
            self_att_maps = F.softmax(offsets_att[:, self.k * 2:, :, :], dim=1)
            offsets = apply_offset(offsets_att[:, :self.k * 2, :, :].reshape(
                -1, 2, H, W))
            if last_multi_self_offsets is not None:
                offsets = F.grid_sample(
                    last_multi_self_offsets,
                    offsets,
                    mode='bilinear',
                    padding_mode='border')
            else:
                offsets = offsets.permute(0, 3, 1, 2)
            last_multi_self_offsets = offsets
            att_reference_feat = DAWarp(feat_ref, last_multi_self_offsets,
                                        self_att_maps, self.k, self.out_ch)

            # Refine-MFE
            input_feat = torch.cat([att_source_feat, att_reference_feat], 1)
            offsets_att = self.Refine_MFEs[i](input_feat)
            att_maps = F.softmax(offsets_att[:, self.k * 4:, :, :], dim=1)
            cross_offsets = apply_offset(
                offsets_att[:, :self.k * 2, :, :].reshape(-1, 2, H, W))
            self_offsets = apply_offset(
                offsets_att[:,
                            self.k * 2:self.k * 4, :, :].reshape(-1, 2, H, W))
            last_multi_cross_offsets = F.grid_sample(
                last_multi_cross_offsets,
                cross_offsets,
                mode='bilinear',
                padding_mode='border')
            last_multi_self_offsets = F.grid_sample(
                last_multi_self_offsets,
                self_offsets,
                mode='bilinear',
                padding_mode='border')

            # Upsampling
            last_multi_cross_offsets = F.interpolate(
                last_multi_cross_offsets, scale_factor=2, mode='bilinear')
            last_multi_self_offsets = F.interpolate(
                last_multi_self_offsets, scale_factor=2, mode='bilinear')
            self_att_maps = F.interpolate(
                att_maps[:, :self.k, :, :], scale_factor=2, mode='bilinear')
            cross_att_maps = F.interpolate(
                att_maps[:, self.k:, :, :], scale_factor=2, mode='bilinear')

            # Post-DAWarp for source and reference images
            if return_all:
                cur_source_image = F.interpolate(
                    source_image, (H * 2, W * 2), mode='bilinear')
                cur_reference_image = F.interpolate(
                    reference_image, (H * 2, W * 2), mode='bilinear')
                if use_light_en_de:
                    cur_source_image = self.lights_encoder(cur_source_image)
                    cur_reference_image = self.lights_encoder(
                        cur_reference_image)
                    # the light encoder output has 64 channels
                    warp_att_source_image = DAWarp(cur_source_image,
                                                   last_multi_cross_offsets,
                                                   cross_att_maps, self.k, 64)
                    warp_att_reference_image = DAWarp(cur_reference_image,
                                                      last_multi_self_offsets,
                                                      self_att_maps, self.k,
                                                      64)
                    result_tryon = self.lights_decoder(
                        warp_att_source_image + warp_att_reference_image)
                else:
                    warp_att_source_image = DAWarp(cur_source_image,
                                                   last_multi_cross_offsets,
                                                   cross_att_maps, self.k, 3)
                    warp_att_reference_image = DAWarp(cur_reference_image,
                                                      last_multi_self_offsets,
                                                      self_att_maps, self.k, 3)
                    result_tryon = warp_att_source_image + warp_att_reference_image
                results_all.append(result_tryon)

        last_multi_self_offsets = F.interpolate(
            last_multi_self_offsets,
            reference_image.size()[2:],
            mode='bilinear')
        last_multi_cross_offsets = F.interpolate(
            last_multi_cross_offsets, source_image.size()[2:], mode='bilinear')
        self_att_maps = F.interpolate(
            self_att_maps, reference_image.size()[2:], mode='bilinear')
        cross_att_maps = F.interpolate(
            cross_att_maps, source_image.size()[2:], mode='bilinear')
        if use_light_en_de:
            source_image = self.lights_encoder(source_image)
            reference_image = self.lights_encoder(reference_image)
            warp_att_source_image = DAWarp(source_image,
                                           last_multi_cross_offsets,
                                           cross_att_maps, self.k, 64)
            warp_att_reference_image = DAWarp(reference_image,
                                              last_multi_self_offsets,
                                              self_att_maps, self.k, 64)
            result_tryon = self.lights_decoder(warp_att_source_image
                                               + warp_att_reference_image)
        else:
            warp_att_source_image = DAWarp(source_image,
                                           last_multi_cross_offsets,
                                           cross_att_maps, self.k, 3)
            warp_att_reference_image = DAWarp(reference_image,
                                              last_multi_self_offsets,
                                              self_att_maps, self.k, 3)
            result_tryon = warp_att_source_image + warp_att_reference_image

        if return_all:
            return result_tryon, results_all
        return result_tryon


class SDAFNet_Tryon(nn.Module):
    """Single-stage deformable attention flow try-on network: cloth and
    reference encoders with FPNs feed the DAFlowNet warping module."""

    def __init__(self, ref_in_channel, source_in_channel=3, head_nums=6):
        super(SDAFNet_Tryon, self).__init__()
        num_filters = [64, 128, 256, 256, 256]
        self.source_features = FeatureEncoder(source_in_channel, num_filters)
        self.reference_features = FeatureEncoder(ref_in_channel, num_filters)
        self.source_FPN = RefinePyramid(num_filters)
        self.reference_FPN = RefinePyramid(num_filters)
        self.dafnet = DAFlowNet(len(num_filters), head_nums=head_nums)

    def forward(self,
                ref_input,
                source_image,
                ref_image,
                use_light_en_de=True,
                return_all=False,
                warp_feature=True):
        reference_feats = self.reference_FPN(
            self.reference_features(ref_input))
        source_feats = self.source_FPN(self.source_features(source_image))
        result = self.dafnet(
            source_image,
            ref_image,
            source_feats,
            reference_feats,
            use_light_en_de=use_light_en_de,
            return_all=return_all,
            warp_feature=warp_feature)
        return result
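

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original module):
    # ref_in_channel=19 is an assumption based on the "17 + 1 + 1" agnostic
    # representation mentioned in FeatureEncoder, and 256x192 is a common
    # try-on resolution, not a requirement of the network.
    model = SDAFNet_Tryon(ref_in_channel=19).eval()
    ref_input = torch.randn(1, 19, 256, 192)  # pose + agnostic channels
    source_image = torch.randn(1, 3, 256, 192)  # cloth image
    ref_image = torch.randn(1, 3, 256, 192)  # model image
    with torch.no_grad():
        result = model(ref_input, source_image, ref_image)
    print(result.shape)  # expected: torch.Size([1, 3, 256, 192])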
