# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from argparse import ArgumentParser

from modelscope.cli.base import CLICommand
from modelscope.hub.api import HubApi
from modelscope.hub.constants import DEFAULT_MAX_WORKERS
from modelscope.hub.file_download import (dataset_file_download,
                                          model_file_download)
from modelscope.hub.snapshot_download import (dataset_snapshot_download,
                                              snapshot_download)
from modelscope.hub.utils.utils import convert_patterns
from modelscope.utils.constant import DEFAULT_DATASET_REVISION


def subparser_func(args):
    """Create the ``download`` command object for a parsed namespace.

    This is the callback installed through ``parser.set_defaults(func=...)``
    in ``DownloadCMD.define_args``.
    """
    command = DownloadCMD(args)
    return command


class DownloadCMD(CLICommand):
    """``download`` command: fetch a model/dataset repo or files from it.

    The repo id comes from ``--model`` / ``--dataset`` or from the first
    positional argument (interpreted according to ``--repo-type``). Further
    positional arguments name files inside the repo:

    * exactly one file  -> single-file download;
    * several files     -> snapshot download restricted to those files;
    * no files          -> whole-repo snapshot download, filtered by
      ``--include`` / ``--exclude``.
    """
    name = 'download'

    def __init__(self, args):
        # Parsed argparse namespace. ``execute`` rewrites ``model``,
        # ``dataset`` and ``files`` on it while resolving positionals.
        self.args = args

    @staticmethod
    def define_args(parsers: ArgumentParser):
        """Register the ``download`` sub-command and its arguments.

        Args:
            parsers: object exposing ``add_parser``; presumably the action
                returned by ``ArgumentParser.add_subparsers()`` (the
                annotation says ``ArgumentParser`` -- confirm with caller).
        """
        parser: ArgumentParser = parsers.add_parser(DownloadCMD.name)
        # --model and --dataset name the repo explicitly; at most one of them.
        group = parser.add_mutually_exclusive_group()
        group.add_argument(
            '--model',
            type=str,
            help='The id of the model to be downloaded. For download, '
            'the id of either a model or dataset must be provided.')
        group.add_argument(
            '--dataset',
            type=str,
            help='The id of the dataset to be downloaded. For download, '
            'the id of either a model or dataset must be provided.')
        # Optional positional repo id. When --model/--dataset is given,
        # execute() treats the token captured here as the first file.
        parser.add_argument(
            'repo_id',
            type=str,
            nargs='?',
            default=None,
            help='Optional. '
            'ID of the repo to download. It can also be set by --model or '
            '--dataset.')
        parser.add_argument(
            '--repo-type',
            choices=['model', 'dataset'],
            default='model',
            help="Type of repo to download from (defaults to 'model').",
        )
        parser.add_argument(
            '--token',
            type=str,
            default=None,
            help='Optional. Access token to download controlled entities.')
        parser.add_argument(
            '--revision',
            type=str,
            default=None,
            help='Revision of the entity (e.g., model).')
        parser.add_argument(
            '--cache_dir',
            type=str,
            default=None,
            help='Cache directory to save entity (e.g., model).')
        parser.add_argument(
            '--local_dir',
            type=str,
            default=None,
            help='File will be downloaded to local location specified by '
            'local_dir, in this case, cache_dir parameter will be ignored.')
        parser.add_argument(
            'files',
            type=str,
            default=None,
            nargs='*',
            help='Specify relative path to the repository file(s) to '
            "download (e.g. 'tokenizer.json', 'onnx/decoder_model.onnx').")
        parser.add_argument(
            '--include',
            nargs='*',
            default=None,
            type=str,
            help='Glob patterns to match files to download. '
            'Ignored if file is specified.')
        parser.add_argument(
            '--exclude',
            nargs='*',
            type=str,
            default=None,
            help='Glob patterns to exclude from files to download. '
            'Ignored if file is specified.')
        parser.add_argument(
            '--max-workers',
            type=int,
            default=DEFAULT_MAX_WORKERS,
            help='The maximum number of workers to download files.')

        parser.set_defaults(func=subparser_func)

    def execute(self):
        """Resolve the target repo from the arguments and download it.

        Raises:
            Exception: if ``--repo-type`` is unsupported, or if neither a
                model nor a dataset id could be resolved.
        """
        self._resolve_repo()
        cookies = None
        if self.args.token is not None:
            cookies = HubApi().get_cookies(access_token=self.args.token)

        if self.args.model:
            self._download(
                self.args.model,
                revision=self.args.revision,
                file_download=model_file_download,
                repo_download=snapshot_download,
                cookies=cookies)
            print(f'\nSuccessfully Downloaded from model {self.args.model}.\n')
        else:
            # _resolve_repo guarantees a dataset id here. Datasets fall back
            # to DEFAULT_DATASET_REVISION when no revision was given.
            self._download(
                self.args.dataset,
                revision=self.args.revision or DEFAULT_DATASET_REVISION,
                file_download=dataset_file_download,
                repo_download=dataset_snapshot_download,
                cookies=cookies)
            print(
                f'\nSuccessfully Downloaded from dataset {self.args.dataset}.\n'
            )

    def _resolve_repo(self):
        """Fill ``args.model``/``args.dataset`` and normalize ``args.files``.

        Raises:
            Exception: on an unsupported repo type or when no repo id is set.
        """
        args = self.args
        # argparse yields [] for an empty ``files`` positional, but guard
        # against None (e.g. a namespace constructed programmatically).
        files = list(args.files) if args.files else []
        if args.model or args.dataset:
            # The repo id was given explicitly, so the first positional token
            # (captured by ``repo_id``) is actually the first file to fetch.
            if args.repo_id is not None:
                files.insert(0, args.repo_id)
        elif args.repo_id is not None:
            if args.repo_type == 'model':
                args.model = args.repo_id
            elif args.repo_type == 'dataset':
                args.dataset = args.repo_id
            else:
                raise Exception('Not support repo-type: %s' % args.repo_type)
        args.files = files
        if not args.model and not args.dataset:
            raise Exception('Model or dataset must be set.')

    def _download(self, repo_id, revision, file_download, repo_download,
                  cookies):
        """Download one file, a set of files, or the whole repo.

        Args:
            repo_id: model or dataset id.
            revision: revision forwarded to the download functions.
            file_download: single-file downloader for this repo type.
            repo_download: snapshot downloader for this repo type.
            cookies: auth cookies from ``HubApi.get_cookies`` or None.
        """
        args = self.args
        files = args.files
        if len(files) == 1:  # download single file
            file_download(
                repo_id,
                files[0],
                cache_dir=args.cache_dir,
                local_dir=args.local_dir,
                revision=revision,
                cookies=cookies)
            return
        common = dict(
            revision=revision,
            cache_dir=args.cache_dir,
            local_dir=args.local_dir,
            max_workers=args.max_workers,
            cookies=cookies)
        if files:  # download specified multiple files
            # --include/--exclude are ignored when files are named.
            repo_download(repo_id, allow_file_pattern=files, **common)
        else:  # download the whole repo
            repo_download(
                repo_id,
                allow_file_pattern=convert_patterns(args.include),
                ignore_file_pattern=convert_patterns(args.exclude),
                **common)
