from argparse import Action, ArgumentDefaultsHelpFormatter, ArgumentParser
from dataclasses import fields
from typing import List


class CliArgumentParser(ArgumentParser):
    """Argument parser that defines and parses command-line args for training.

    One ``--<field name>`` option is registered for every dataclass field of
    ``training_args``; the current value of the field is used as the
    option's default.

    While parsing, two attributes are recorded on the parser:

    * ``model_id``: the value given to ``--model`` on the command line,
      falling back to ``namespace.model`` (or ``None``).
    * ``manual_args``: names of all ``--option`` tokens found on the
      command line.

    Args:
        training_args: A dataclass instance describing the training
            parameters (it is inspected with ``dataclasses.fields``).
            ``None`` registers no options.
        **kwargs: Forwarded to ``argparse.ArgumentParser``. When
            ``formatter_class`` is not given,
            ``ArgumentDefaultsHelpFormatter`` is used.
    """

    def __init__(self, training_args=None, **kwargs):
        kwargs.setdefault('formatter_class', ArgumentDefaultsHelpFormatter)
        super().__init__(**kwargs)
        self.training_args = training_args
        self.define_args()

    @staticmethod
    def _parse_bool(value):
        """Convert a command-line string to ``bool``.

        ``bool('false')`` is ``True`` because any non-empty string is
        truthy, so 'true'/'false' (case-insensitive) are mapped explicitly.
        Every other input keeps the plain ``bool(value)`` result.
        """
        if isinstance(value, str) and value.lower() in ('true', 'false'):
            return value.lower() == 'true'
        return bool(value)

    def get_manual_args(self, args):
        """Return the names of all ``--option`` tokens found in ``args``.

        Both ``--name value`` and ``--name=value`` spellings yield ``name``.
        """
        return [
            arg[2:].split('=', 1)[0] for arg in args if arg.startswith('--')
        ]

    def _parse_known_args(self,
                          args: List = None,
                          namespace=None,
                          *args_extra,
                          **kwargs):
        """Record ``model_id`` and ``manual_args``, then defer to argparse.

        Extra positional/keyword arguments are passed through untouched so
        the override keeps working if the parent signature carries
        additional parameters.
        """
        if args is None:
            args = []
        # The namespace only has a ``model`` attribute when such an option
        # (or default) was registered, so do not assume it exists.
        self.model_id = getattr(namespace, 'model', None)
        for idx, arg in enumerate(args):
            if arg == '--model':
                # A trailing ``--model`` has no value; leave the fallback
                # in place and let argparse report the problem.
                if idx + 1 < len(args):
                    self.model_id = args[idx + 1]
                break
            if arg.startswith('--model='):
                self.model_id = arg[len('--model='):]
                break
        self.manual_args = self.get_manual_args(args)
        return super()._parse_known_args(args, namespace, *args_extra,
                                         **kwargs)

    def print_help(self, file=None):
        """Print the help message; simply delegates to ``ArgumentParser``."""
        return super().print_help(file)

    def define_args(self):
        """Register one ``--<name>`` option per field of ``training_args``.

        For each field the option uses the field's annotation as ``type``,
        ``metadata['help']`` as help text (optional),
        ``metadata['choices']`` as choices (optional), the current attribute
        value as default and ``SingleAction`` as the action.
        """
        if self.training_args is None:
            return
        for f in fields(self.training_args):
            kwargs = dict(
                # ``bool`` would turn every non-empty string (including
                # 'false') into True, so use a string-aware converter.
                type=self._parse_bool if f.type is bool else f.type,
                # Help text is optional; a missing entry must not abort
                # parser construction.
                help=f.metadata.get('help'),
                default=getattr(self.training_args, f.name),
                action=SingleAction,
            )
            if 'choices' in f.metadata:
                kwargs['choices'] = f.metadata['choices']
            self.add_argument(f'--{f.name}', **kwargs)


class DictAction(Action):
    """
    argparse action to split each argument of KEY=VALUE form on the first =
    and store the pairs in a dictionary. List options can be passed as
    comma separated values, i.e 'KEY=V1,V2,V3', or with explicit brackets,
    i.e. 'KEY=[V1,V2,V3]'. It also supports nested brackets to build
    list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'

    Each invocation of the action stores a fresh dictionary on the
    namespace attribute ``dest``.
    """

    @staticmethod
    def parse_int_float_bool_str(val):
        """Convert a scalar string to int, float, bool or None if possible.

        Args:
            val (str): Value string.
        Returns:
            The int or float value when ``val`` parses as one, a bool for
            'true'/'false' (case-insensitive), ``None`` for the exact
            string 'None', otherwise ``val`` unchanged.
        """
        for convert in (int, float):
            try:
                return convert(val)
            except ValueError:
                pass
        lowered = val.lower()
        if lowered in ('true', 'false'):
            return lowered == 'true'
        if val == 'None':
            return None
        return val

    @staticmethod
    def parse_iterable(val):
        """Parse iterable values in the string.
        All elements inside '()' or '[]' are treated as iterable values.
        Quotes around the string are stripped and all spaces are removed
        before parsing.
        Args:
            val (str): Value string.
        Returns:
            list | tuple | scalar: The expanded list or tuple from the
            string, or a single scalar when the string contains no
            top-level comma and is not wrapped in brackets.
        Raises:
            AssertionError: If the numbers of opening and closing brackets
                differ in a string that has to be split.
        Examples:
            >>> DictAction.parse_iterable('1,2,3')
            [1, 2, 3]
            >>> DictAction.parse_iterable('[a, b, c]')
            ['a', 'b', 'c']
            >>> DictAction.parse_iterable('[(1, 2, 3), [a, b], c]')
            [(1, 2, 3), ['a', 'b'], 'c']
            >>> DictAction.parse_iterable('(1, 2), (3, 4)')
            [(1, 2), (3, 4)]
        """

        def find_next_comma(string):
            """Find the position of next comma in the string.
            If no ',' is found in the string, return the string length. All
            chars inside '()' and '[]' are treated as one element and thus ','
            inside these brackets are ignored.
            """
            # Raised explicitly (instead of ``assert``) so the check is not
            # stripped under ``python -O``; the exception type is kept.
            if (string.count('(') != string.count(')')
                    or string.count('[') != string.count(']')):
                raise AssertionError(
                    f'Imbalanced brackets exist in {string}')
            # Running open-minus-close counters replace re-counting the
            # whole prefix for every character.
            parens = brackets = 0
            for idx, char in enumerate(string):
                if char == ',' and parens == 0 and brackets == 0:
                    return idx
                if char == '(':
                    parens += 1
                elif char == ')':
                    parens -= 1
                elif char == '[':
                    brackets += 1
                elif char == ']':
                    brackets -= 1
            return len(string)

        def is_enclosed(string, opener, closer):
            """Whether the bracket opened by the first char is closed by
            the last char, i.e. one bracket pair wraps the whole string.
            """
            if len(string) < 2 or string[0] != opener \
                    or string[-1] != closer:
                return False
            depth = 0
            for idx, char in enumerate(string):
                if char in '([':
                    depth += 1
                elif char in ')]':
                    depth -= 1
                    if depth == 0:
                        # First time the leading bracket is closed.
                        return idx == len(string) - 1
            return False

        # Strip ' and " characters and replace whitespace.
        val = val.strip('\'\"').replace(' ', '')
        # Only unwrap when the outer brackets really belong together;
        # '(1,2),(3,4)' starts and ends with a bracket but is a plain
        # comma separated list of two tuples.
        is_tuple = is_enclosed(val, '(', ')')
        if is_tuple or is_enclosed(val, '[', ']'):
            val = val[1:-1]
        elif ',' not in val or find_next_comma(val) == len(val):
            # val is a single value: either it has no comma at all, or all
            # its commas sit inside brackets (e.g. 'f(1,2)') so it cannot
            # be split further. Recursing on it would never terminate.
            return DictAction.parse_int_float_bool_str(val)

        values = []
        while val:
            comma_idx = find_next_comma(val)
            values.append(DictAction.parse_iterable(val[:comma_idx]))
            val = val[comma_idx + 1:]
        return tuple(values) if is_tuple else values

    def __call__(self, parser, namespace, values, option_string=None):
        """Parse ``KEY=VALUE`` items and store them as a dict on namespace.

        Args:
            parser: The parser invoking this action (unused).
            namespace: Object receiving the resulting dict as attribute
                ``self.dest``.
            values (list[str] | str): ``KEY=VALUE`` items; a single string
                is treated as one item.
            option_string: The option string used on the command line
                (unused).
        Raises:
            ValueError: If an item does not contain '='.
        """
        if isinstance(values, str):
            # Without ``nargs`` argparse hands over a bare string;
            # iterating it would yield single characters.
            values = [values]
        options = {}
        for kv in values:
            key, sep, val = kv.partition('=')
            if not sep:
                raise ValueError(
                    f'Expected an argument of the form KEY=VALUE, '
                    f'got {kv!r}')
            options[key] = self.parse_iterable(val)
        setattr(namespace, self.dest, options)


class SingleAction(DictAction):
    """Argparse action that parses a single option value into a Python object.

    A string value is handed to ``parse_iterable``, so comma separated input
    such as 'V1,V2,V3' or bracketed input such as '[V1,V2,V3]' becomes a
    list, and nested brackets such as '[(V1,V2),(V3,V4)]' become nested
    list/tuple structures. Values that are not strings are stored as given.
    """

    def __call__(self, parser, namespace, value, option_string):
        # Only strings need parsing; anything else (for instance a value
        # already converted by the option's ``type`` callable) is kept.
        parsed = self.parse_iterable(value) if isinstance(value,
                                                          str) else value
        setattr(namespace, self.dest, parsed)
