# Copyright (c) Alibaba, Inc. and its affiliates.
import collections.abc
import math
import os.path as osp
import warnings
from itertools import repeat

import numpy as np
import torch
from mmcls.datasets.base_dataset import BaseDataset


def get_trained_checkpoints_name(work_path):
    """Pick the checkpoint file to load from a training work directory.

    Selection order:

    1. A ``.pth`` file containing ``best_`` whose name (minus ``.pth``) ends
       in ``_<digits>``, e.g. ``best_accuracy_top-1_epoch_12.pth``. When
       several match, the one with the largest trailing number wins, so the
       result does not depend on directory listing order.
    2. Otherwise the ``epoch_<N>.pth`` file with the largest ``N``.

    Args:
        work_path (str): Directory to scan (not recursively).

    Returns:
        str | None: The bare file name (not joined with ``work_path``), or
        ``None`` if no file matches either rule.
    """
    import os
    file_list = os.listdir(work_path)

    # 1) Prefer a "best" checkpoint. os.listdir order is arbitrary, so choose
    #    the highest epoch instead of whichever match happens to come first.
    best_name, best_epoch = None, -1
    for f_name in file_list:
        if 'best_' not in f_name or not f_name.endswith('.pth'):
            continue
        suffix = f_name.replace('.pth', '').split('_')[-1]
        # isdecimal (not isdigit): isdigit accepts chars such as '²' that
        # int() rejects, which would crash here.
        if suffix.isdecimal() and int(suffix) > best_epoch:
            best_name, best_epoch = f_name, int(suffix)
    if best_name is not None:
        return best_name

    # 2) Fall back to the most recent periodic checkpoint ``epoch_<N>.pth``.
    #    Starting below 0 lets ``epoch_0.pth`` be selected as well.
    latest_name, latest_epoch = None, -1
    for f_name in file_list:
        if 'epoch_' not in f_name or not f_name.endswith('.pth'):
            continue
        epoch_num = f_name.replace('epoch_', '').replace('.pth', '')
        if epoch_num.isdecimal() and int(epoch_num) > latest_epoch:
            latest_name, latest_epoch = f_name, int(epoch_num)
    return latest_name


def preprocess_transform(cfgs):
    """Normalize ``Resize`` transform configs in place.

    For every entry whose ``type`` is ``'Resize'`` and whose ``size`` is a
    list, ``size`` is replaced by the equivalent tuple. All other entries
    are left untouched.

    Args:
        cfgs: Indexable sequence of transform configs supporting attribute
            access (``.type`` / ``.size``), or ``None``.

    Returns:
        The same ``cfgs`` object after in-place modification, or ``None``
        when ``cfgs`` is ``None``.
    """
    if cfgs is None:
        return None
    for pos, item in enumerate(cfgs):
        needs_tuple = item.type == 'Resize' and isinstance(item.size, list)
        if needs_tuple:
            cfgs[pos].size = tuple(item.size)
    return cfgs


def get_ms_dataset_root(ms_dataset):
    """Derive the extracted-data root directory from a dataset's first record.

    The first record's ``'image:FILE'`` path is split on the literal
    ``'extracted'``; the result is ``<prefix>/extracted/<c1>/<c2>``, where
    ``c1`` and ``c2`` are the first two ``'/'``-separated components that
    follow the marker.

    Args:
        ms_dataset: Indexable collection of dict-like records, or ``None``.

    Returns:
        str | None: The derived root, or ``None`` if ``ms_dataset`` is
        ``None`` or empty.

    Raises:
        ValueError: If the path cannot be derived (missing key, no
            ``'extracted'`` marker, too few components, ...). The original
            exception is chained as ``__cause__``.
    """
    if ms_dataset is None or len(ms_dataset) < 1:
        return None
    try:
        image_file = ms_dataset[0]['image:FILE']
        pieces = image_file.split('extracted')
        data_root = pieces[0]
        # pieces[1] starts with '/', so index 0 of the split is ''.
        path_post = pieces[1].split('/')
        return osp.join(data_root, 'extracted', path_post[1], path_post[2])
    except Exception as e:
        # Chain so the underlying KeyError/IndexError stays in the traceback.
        raise ValueError(f'Dataset Error: {e}') from e


def get_classes(classes=None):
    """Resolve class names from a file path or an explicit sequence.

    Args:
        classes (str | tuple | list): A path to a text file read with
            ``mmcv.list_from_file``, or the class names themselves.

    Returns:
        The list read from the file, or ``classes`` itself (unchanged, not
        copied) when it is a tuple or list.

    Raises:
        ValueError: If ``classes`` is not a str, tuple or list (including
            the default ``None``).
    """
    if isinstance(classes, str):
        # Only the file-path branch needs mmcv; importing it here means
        # callers passing names directly do not require mmcv at all.
        import mmcv
        return mmcv.list_from_file(classes)
    if isinstance(classes, (tuple, list)):
        return classes
    raise ValueError(f'Unsupported type {type(classes)} of classes.')


class MmDataset(BaseDataset):
    """mmcls ``BaseDataset`` adapter over an in-memory record collection.

    ``ms_dataset`` is presumably a ModelScope ``MsDataset`` (TODO confirm);
    each record must provide an ``'image:FILE'`` image path and a
    ``'category'`` label convertible to ``np.int64``.

    Args:
        ms_dataset: Non-empty, iterable collection of dict-like records.
        pipeline: Transform pipeline, forwarded to ``BaseDataset``.
        classes: Class-name specification, forwarded to ``BaseDataset``.
        test_mode (bool): Forwarded to ``BaseDataset``.
        data_prefix (str): Forwarded to ``BaseDataset``; each record's
            ``img_prefix`` is taken from ``self.data_prefix``.

    Raises:
        ValueError: If ``ms_dataset`` is ``None`` or empty.
    """

    def __init__(self,
                 ms_dataset,
                 pipeline,
                 classes=None,
                 test_mode=False,
                 data_prefix=''):
        # Assign before super().__init__(): load_annotations() reads it and is
        # presumably invoked from BaseDataset.__init__ (mmcls) -- confirm.
        self.ms_dataset = ms_dataset
        # Treat None like an empty dataset instead of letting len(None) raise
        # an unhelpful TypeError (same None handling as get_ms_dataset_root).
        if ms_dataset is None or len(ms_dataset) < 1:
            raise ValueError('Dataset Error: dataset is empty')
        super(MmDataset, self).__init__(
            data_prefix=data_prefix,
            pipeline=pipeline,
            classes=classes,
            test_mode=test_mode)

    def load_annotations(self):
        """Build the per-sample info dicts consumed by the mmcls pipeline.

        Returns:
            list[dict]: One dict per record, with keys ``'img_prefix'``,
            ``'img_info'`` (``{'filename': <path>}``) and ``'gt_label'``
            (an ``np.int64`` array).

        Raises:
            ValueError: If ``self.CLASSES`` is ``None``.
        """
        if self.CLASSES is None:
            raise ValueError(
                f'Dataset Error: Not found classesname.txt: {self.CLASSES}')

        data_infos = []
        for record in self.ms_dataset:
            data_infos.append({
                'img_prefix': self.data_prefix,
                'img_info': {'filename': record['image:FILE']},
                'gt_label': np.array(record['category'], dtype=np.int64),
            })
        return data_infos


def _trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
            'The distribution of values may be incorrect.',
            stacklevel=2)

    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    # Get upper and lower cdf values
    v = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)

    # Uniformly fill tensor with values from [v, u], then translate to
    # [2v-1, 2u-1].
    tensor.uniform_(2 * v - 1, 2 * u - 1)

    # Use inverse cdf transform for normal distribution to get truncated
    # standard normal
    tensor.erfinv_()

    # Transform to proper mean, std
    tensor.mul_(std * math.sqrt(2.))
    tensor.add_(mean)

    # Clamp to ensure it's in the proper range
    tensor.clamp_(min=a, max=b)
    return tensor


def trunc_normal_(tensor: torch.Tensor,
                  mean: float = 0.,
                  std: float = 1.,
                  a: float = -2.,
                  b: float = 2.) -> torch.Tensor:
    """Fill ``tensor`` in place with a truncated normal distribution.

    Delegates to ``_trunc_normal_`` inside ``torch.no_grad()``, so the
    in-place fill is not tracked by autograd.

    Args:
        tensor: Tensor to fill; modified in place.
        mean: Mean of the untruncated normal.
        std: Standard deviation of the untruncated normal.
        a: Lower truncation bound.
        b: Upper truncation bound.

    Returns:
        The filled ``tensor`` (whatever ``_trunc_normal_`` returns).
    """
    no_grad_ctx = torch.no_grad()
    with no_grad_ctx:
        filled = _trunc_normal_(tensor, mean, std, a, b)
    return filled


# From PyTorch internals
def _ntuple(n):

    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return x
        return tuple(repeat(x, n))

    return parse


# Ready-made converters for the common tuple sizes; ``to_ntuple(n)`` builds
# one for any other size.
to_1tuple, to_2tuple, to_3tuple, to_4tuple = map(_ntuple, (1, 2, 3, 4))
to_ntuple = _ntuple
