# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from argparse import ArgumentParser, _SubParsersAction

from modelscope.cli.base import CLICommand
from modelscope.hub.api import HubApi
from modelscope.hub.utils.utils import convert_patterns, get_endpoint
from modelscope.utils.constant import REPO_TYPE_MODEL, REPO_TYPE_SUPPORT


def subparser_func(args):
    """ Function which will be called for a specific sub parser.
    """
    return UploadCMD(args)


class UploadCMD(CLICommand):

    name = 'upload'

    def __init__(self, args: _SubParsersAction):
        self.args = args

    @staticmethod
    def define_args(parsers: _SubParsersAction):

        parser: ArgumentParser = parsers.add_parser(UploadCMD.name)

        parser.add_argument(
            'repo_id',
            type=str,
            help='The ID of the repo to upload to (e.g. `username/repo-name`)')
        parser.add_argument(
            'local_path',
            type=str,
            nargs='?',
            default=None,
            help='Optional, '
            'Local path to the file or folder to upload. Defaults to current directory.'
        )
        parser.add_argument(
            'path_in_repo',
            type=str,
            nargs='?',
            default=None,
            help='Optional, '
            'Path of the file or folder in the repo. Defaults to the relative path of the file or folder.'
        )
        parser.add_argument(
            '--repo-type',
            choices=REPO_TYPE_SUPPORT,
            default=REPO_TYPE_MODEL,
            help=
            'Type of the repo to upload to (e.g. `dataset`, `model`). Defaults to be `model`.',
        )
        parser.add_argument(
            '--include',
            nargs='*',
            type=str,
            help='Glob patterns to match files to upload.')
        parser.add_argument(
            '--exclude',
            nargs='*',
            type=str,
            help='Glob patterns to exclude from files to upload.')
        parser.add_argument(
            '--commit-message',
            type=str,
            default=None,
            help='The message of commit. Default to be `None`.')
        parser.add_argument(
            '--commit-description',
            type=str,
            default=None,
            help=
            'The description of the generated commit. Default to be `None`.')
        parser.add_argument(
            '--token',
            type=str,
            default=None,
            help=
            'A User Access Token generated from https://modelscope.cn/my/myaccesstoken'
        )
        parser.add_argument(
            '--max-workers',
            type=int,
            default=min(8,
                        os.cpu_count() + 4),
            help='The number of workers to use for uploading files.')
        parser.add_argument(
            '--endpoint',
            type=str,
            default=get_endpoint(),
            help='Endpoint for ModelScope service.')

        parser.set_defaults(func=subparser_func)

    def execute(self):

        assert self.args.repo_id, '`repo_id` is required'
        assert self.args.repo_id.count(
            '/') == 1, 'repo_id should be in format of username/repo-name'
        repo_name: str = self.args.repo_id.split('/')[-1]
        self.repo_id = self.args.repo_id

        # Check path_in_repo
        if self.args.local_path is None and os.path.isfile(repo_name):
            # Case 1: modelscope upload owner_name/test_repo
            self.local_path = repo_name
            self.path_in_repo = repo_name
        elif self.args.local_path is None and os.path.isdir(repo_name):
            # Case 2: modelscope upload owner_name/test_repo  (run command line in the `repo_name` dir)
            # => upload all files in current directory to remote root path
            self.local_path = repo_name
            self.path_in_repo = '.'
        elif self.args.local_path is None:
            # Case 3: user provided only a repo_id that does not match a local file or folder
            # => the user must explicitly provide a local_path => raise exception
            raise ValueError(
                f"'{repo_name}' is not a local file or folder. Please set `local_path` explicitly."
            )
        elif self.args.path_in_repo is None and os.path.isfile(
                self.args.local_path):
            # Case 4: modelscope upload owner_name/test_repo /path/to/your_file.csv
            # => upload it to remote root path with same name
            self.local_path = self.args.local_path
            self.path_in_repo = os.path.basename(self.args.local_path)
        elif self.args.path_in_repo is None:
            # Case 5: modelscope upload owner_name/test_repo /path/to/your_folder
            # => upload all files in current directory to remote root path
            self.local_path = self.args.local_path
            self.path_in_repo = ''
        else:
            # Finally, if both paths are explicit
            self.local_path = self.args.local_path
            self.path_in_repo = self.args.path_in_repo

        api = HubApi(endpoint=self.args.endpoint)

        if os.path.isfile(self.local_path):
            api.upload_file(
                path_or_fileobj=self.local_path,
                path_in_repo=self.path_in_repo,
                repo_id=self.repo_id,
                repo_type=self.args.repo_type,
                commit_message=self.args.commit_message,
                commit_description=self.args.commit_description,
                token=self.args.token,
            )
        elif os.path.isdir(self.local_path):
            api.upload_folder(
                repo_id=self.repo_id,
                folder_path=self.local_path,
                path_in_repo=self.path_in_repo,
                commit_message=self.args.commit_message,
                commit_description=self.args.commit_description,
                repo_type=self.args.repo_type,
                allow_patterns=convert_patterns(self.args.include),
                ignore_patterns=convert_patterns(self.args.exclude),
                max_workers=self.args.max_workers,
                token=self.args.token,
            )
        else:
            raise ValueError(f'{self.local_path} is not a valid local path')

        print(f'Finished uploading to {self.repo_id}')
