# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse

import json


def str2bool(v):
    """ Convert a truth-like command-line string to a ``bool``.

    Intended for use as an ``argparse`` ``type=`` callable.

    Args:
        v: Either a ``bool`` (returned unchanged) or a string. Matching is
            case-insensitive: 'yes', 'true', 't', 'y', '1' map to ``True``;
            'no', 'false', 'f', 'n', '0' map to ``False``.

    Returns:
        The corresponding boolean value.

    Raises:
        argparse.ArgumentTypeError: If the string is not one of the
            recognised spellings.
    """
    # Real booleans have no ``.lower()``; pass them through so the function
    # can also be called directly with an already-converted value.
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Unsupported value encountered.')


class HParams(dict):
    """ Hyper-parameters class

    Store hyper-parameters in training / infer / ... scripts.

    This is a ``dict`` subclass whose entries can also be read and written
    with attribute syntax (``hp.lr`` is ``hp['lr']``). Attribute reads that
    miss at the top level fall back to searching one level of nested
    ``HParams`` values.
    """

    def __getattr__(self, name):
        """ Resolve ``self.name`` from the stored entries.

        Python only calls this after normal attribute lookup has failed, so
        real ``dict`` methods are never shadowed by stored keys.

        Args:
            name: Attribute name being looked up.

        Returns:
            ``self[name]`` if the key exists at the top level; otherwise the
            value from the first nested ``HParams`` (in insertion order)
            that contains ``name``.

        Raises:
            AttributeError: If the name is found neither at the top level
                nor in any directly nested ``HParams``.
        """
        if name in self:
            return self[name]
        for v in self.values():
            if isinstance(v, HParams) and name in v:
                return v[name]
        raise AttributeError(f"'HParams' object has no attribute '{name}'")

    def __setattr__(self, name, value):
        """ Store ``value`` under key ``name`` (always at the top level). """
        self[name] = value

    def save(self, filename):
        """ Write the hyper-parameters to ``filename`` as UTF-8 JSON.

        Args:
            filename: Path of the file to create or overwrite.
        """
        with open(filename, 'w', encoding='utf-8') as fp:
            json.dump(self, fp, ensure_ascii=False, indent=4, sort_keys=False)

    def load(self, filename):
        """ Merge hyper-parameters from a JSON file into this object.

        Scalar entries overwrite the current value. A JSON object is merged
        into the existing group of the same name when there is one, keeping
        keys the file does not mention; otherwise it is stored as a new
        ``HParams`` group.

        Args:
            filename: Path of a UTF-8 JSON file whose top level is an object.
        """
        with open(filename, 'r', encoding='utf-8') as fp:
            params_dict = json.load(fp)
        for k, v in params_dict.items():
            if not isinstance(v, dict):
                self[k] = v
                continue
            current = self.get(k)
            if isinstance(current, dict):
                current.update(HParams(v))
            else:
                # The group is missing (or currently holds a non-dict value):
                # create it rather than failing, so a file written by
                # ``save`` can be loaded into an empty instance.
                self[k] = HParams(v)


def parse_args(parser):
    """ Parse hyper-parameters from cmdline.

    Runs ``parser.parse_args()`` and repackages the result as ``HParams``.
    Options from the parser's second action group (``_action_groups[1]``,
    the built-in optional-arguments group) are stored at the top level.
    Each later group that yields at least one value is stored as a nested
    ``HParams`` keyed by the group's title. Arguments in the first group
    (``_action_groups[0]``) are not collected.

    Args:
        parser: A configured ``argparse.ArgumentParser``.

    Returns:
        An ``HParams`` holding the parsed values.
    """
    parsed = parser.parse_args()
    args = HParams()

    def _collect(actions, target):
        # Copy only destinations that exist on the namespace. Actions such
        # as ``-h/--help`` or ``--version`` store nothing, so they are
        # skipped by presence instead of by assuming help is always first
        # (which drops a real option when ``add_help=False``).
        for action in actions:
            arg_name = action.dest
            if hasattr(parsed, arg_name):
                target[arg_name] = getattr(parsed, arg_name)

    _collect(parser._action_groups[1]._group_actions, args)
    for group in parser._action_groups[2:]:
        group_args = HParams()
        _collect(group._group_actions, group_args)
        if group_args:
            args[group.title] = group_args
    return args
