# Copyright (c) Alibaba, Inc. and its affiliates.
from argparse import ArgumentParser, _SubParsersAction

from modelscope.cli.base import CLICommand
from modelscope.hub.api import HubApi
from modelscope.hub.constants import (Licenses, ModelVisibility, Visibility,
                                      VisibilityMap)
from modelscope.hub.utils.aigc import AigcModel
from modelscope.utils.constant import REPO_TYPE_MODEL, REPO_TYPE_SUPPORT
from modelscope.utils.logger import get_logger

logger = get_logger()


def subparser_func(args):
    """Build the ``create`` command object from the parsed ``args``.

    Registered on the sub-parser through ``parser.set_defaults(func=...)``
    so that the parsed arguments can be turned into a ``CreateCMD``.
    """
    command = CreateCMD(args)
    return command


class CreateCMD(CLICommand):
    """
    Command for creating a new repository, supporting both model and dataset.

    Regular repositories are created through ``HubApi.create_repo``. When
    ``--aigc`` is given (only allowed together with ``--repo_type model``),
    an AIGC model is created through ``HubApi.create_model`` instead.
    """

    name = 'create'

    def __init__(self, args: _SubParsersAction):
        # The parsed command-line arguments of this sub-command; every option
        # registered in ``define_args`` is read back as an attribute.
        self.args = args

    @staticmethod
    def define_args(parsers: _SubParsersAction):
        """Register the ``create`` sub-command and its arguments on ``parsers``."""

        parser: ArgumentParser = parsers.add_parser(CreateCMD.name)

        parser.add_argument(
            'repo_id',
            type=str,
            help='The ID of the repo to create (e.g. `username/repo-name`)')
        parser.add_argument(
            '--token',
            type=str,
            default=None,
            help=
            'A User Access Token generated from https://modelscope.cn/my/myaccesstoken to authenticate the user. '
            'If not provided, the CLI will use the local credentials if available.'
        )
        parser.add_argument(
            '--repo_type',
            choices=REPO_TYPE_SUPPORT,
            default=REPO_TYPE_MODEL,
            help=
            'Type of the repo to create (e.g. `dataset`, `model`). Default to `model`.',
        )
        parser.add_argument(
            '--visibility',
            choices=[
                Visibility.PUBLIC, Visibility.INTERNAL, Visibility.PRIVATE
            ],
            default=Visibility.PUBLIC,
            help='Visibility of the repo to create. Default to `public`.',
        )
        parser.add_argument(
            '--chinese_name',
            type=str,
            default=None,
            help='Optional, Chinese name of the repo. Default to `None`.',
        )
        parser.add_argument(
            '--license',
            type=str,
            choices=Licenses.to_list(),
            default=Licenses.APACHE_V2,
            help=
            'Optional, License of the repo. Default to `Apache License 2.0`.',
        )
        parser.add_argument(
            '--endpoint',
            type=str,
            default=None,
            help='Optional, The modelscope server address. Default to None.',
        )

        # AIGC specific arguments
        aigc_group = parser.add_argument_group(
            'AIGC Model Creation',
            'Arguments for creating an AIGC model. Use --aigc to enable.')
        aigc_group.add_argument(
            '--aigc', action='store_true', help='Enable AIGC model creation.')
        # In --from_json mode only the AIGC-specific options below are
        # ignored; repo_id, --token, --visibility, --license, --chinese_name
        # and --endpoint are still passed to the hub.
        aigc_group.add_argument(
            '--from_json',
            type=str,
            help='Path to a JSON file containing AIGC model configuration. '
            'If used, the other AIGC arguments (--model_path, --aigc_type, '
            '--base_model_type, --revision, etc.) are ignored; repo options '
            'such as --visibility, --license and --chinese_name still apply.')
        aigc_group.add_argument(
            '--model_path', type=str, help='Path to the model file or folder.')
        aigc_group.add_argument(
            '--aigc_type',
            type=str,
            help="AIGC type. Recommended: 'Checkpoint', 'LoRA', 'VAE'.")
        aigc_group.add_argument(
            '--base_model_type',
            type=str,
            help='Base model type, e.g., SD_XL.')
        aigc_group.add_argument(
            '--revision',
            type=str,
            default='v1.0',
            help="Model revision. Defaults to 'v1.0'.")
        aigc_group.add_argument(
            '--base_model_id',
            type=str,
            default='',
            help='Base model ID from ModelScope.')
        aigc_group.add_argument(
            '--description',
            type=str,
            default='This is an AIGC model.',
            help='Model description.')
        aigc_group.add_argument(
            '--path_in_repo',
            type=str,
            default='',
            help='Path in the repository to upload to.')
        aigc_group.add_argument(
            '--model_source',
            type=str,
            default='USER_UPLOAD',
            help=
            'Source of the AIGC model. `USER_UPLOAD`, `TRAINED_FROM_MODELSCOPE` or `TRAINED_FROM_ALIYUN_FC`.'
        )
        aigc_group.add_argument(
            '--base_model_sub_type',
            type=str,
            default='',
            help='Base model sub type, e.g., Qwen_Edit_2509')

        parser.set_defaults(func=subparser_func)

    def execute(self):
        """Create either an AIGC model or a regular repo, depending on --aigc.

        Raises:
            ValueError: If ``--aigc`` is combined with a ``--repo_type`` other
                than ``model``.
        """
        if self.args.aigc:
            if self.args.repo_type != REPO_TYPE_MODEL:
                raise ValueError(
                    'AIGC models can only be created when repo_type is "model".'
                )
            self._create_aigc_model()
        else:
            self._create_regular_repo()

    def _create_regular_repo(self):
        """Create a model or dataset repo via ``HubApi.create_repo``.

        Errors raised by the hub API are not caught and propagate to the
        caller.
        """
        # Check token and login
        # The cookies will be reused if the user has logged in before.
        api = HubApi(endpoint=self.args.endpoint)

        # Create repo; exist_ok=True is forwarded to the API so an already
        # existing repo is accepted rather than rejected.
        api.create_repo(
            repo_id=self.args.repo_id,
            token=self.args.token,
            visibility=self.args.visibility,
            repo_type=self.args.repo_type,
            chinese_name=self.args.chinese_name,
            license=self.args.license,
            exist_ok=True,
            create_default_config=True,
            endpoint=self.args.endpoint,
        )

    def _build_aigc_model(self) -> AigcModel:
        """Build the ``AigcModel`` description from --from_json or CLI args.

        Returns:
            AigcModel: Loaded from the JSON file when ``--from_json`` is set,
            otherwise assembled from the individual AIGC arguments.

        Raises:
            ValueError: If ``--from_json`` is not given and any of
                ``--model_path``, ``--aigc_type`` or ``--base_model_type`` is
                missing or empty.
        """
        if self.args.from_json:
            # Create from JSON file
            logger.info('Creating AIGC model from JSON file: '
                        f'{self.args.from_json}')
            return AigcModel.from_json_file(self.args.from_json)

        # Create from command line arguments
        logger.info('Creating AIGC model from command line arguments...')
        if not all([
                self.args.model_path, self.args.aigc_type,
                self.args.base_model_type
        ]):
            raise ValueError(
                'Error: --model_path, --aigc_type, and '
                '--base_model_type are required when not using '
                '--from_json.')

        return AigcModel(
            model_path=self.args.model_path,
            aigc_type=self.args.aigc_type,
            base_model_type=self.args.base_model_type,
            tag=self.args.revision,
            description=self.args.description,
            base_model_id=self.args.base_model_id,
            path_in_repo=self.args.path_in_repo,
            model_source=self.args.model_source,
            base_model_sub_type=self.args.base_model_sub_type,
        )

    def _create_aigc_model(self):
        """Create an AIGC model repo via ``HubApi.create_model``.

        Prints the resulting model URL on success.

        Raises:
            ValueError: If the required AIGC arguments are missing (see
                ``_build_aigc_model``).
            Exception: Any error from ``HubApi.create_model`` is logged and
                re-raised, so a failed creation is not reported as success.
        """
        api = HubApi(endpoint=self.args.endpoint)
        aigc_model = self._build_aigc_model()

        # Convert visibility string to int for the API call
        reverse_visibility_map = {v: k for k, v in VisibilityMap.items()}
        visibility_idx: int = reverse_visibility_map.get(
            self.args.visibility, ModelVisibility.PUBLIC)

        try:
            model_url = api.create_model(
                model_id=self.args.repo_id,
                token=self.args.token,
                visibility=visibility_idx,
                license=self.args.license,
                chinese_name=self.args.chinese_name,
                aigc_model=aigc_model)
        except Exception as e:
            # Previously the error was only printed and swallowed, making the
            # command look successful; surface it like the regular-repo path.
            logger.error(f'Error creating AIGC model: {e}')
            raise
        print(f'Successfully created AIGC model: {model_url}')
