# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Optional, Union

import torch

from lm_eval.models.huggingface import HFLM as eval_wrapper
from pytorch_tokenizers.hf_tokenizer import HuggingFaceTokenizer
from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer
from pytorch_tokenizers.tiktoken import TiktokenTokenizer as Tiktoken

from torch import nn


class EagerEvalWrapper(eval_wrapper):
    """
    A wrapper class based on GPTFast, providing integration with the lm-evaluation-harness library.

    Adapts an eager ``nn.Module`` plus one of the ``pytorch_tokenizers``
    tokenizers to the ``HFLM`` interface. Batch size is fixed at 1, and
    generation (``_model_generate``) is not supported.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer: Union[SentencePieceTokenizer, Tiktoken, HuggingFaceTokenizer],
        max_seq_length: Optional[int] = None,
        use_kv_cache: bool = False,
    ):
        """
        Args:
            model: The eager model to evaluate; it is called directly on token
                id tensors (see ``_model_call``).
            tokenizer: Tokenizer used by ``tok_encode`` / ``tok_decode``.
            max_seq_length: Value reported by ``max_length``; also used to
                truncate inputs when ``use_kv_cache`` is True. Defaults to 2048
                when None.
            use_kv_cache: If True, the model is called with an extra
                ``{"input_pos": ...}`` argument (see ``_model_call``).
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # NOTE(review): "gpt2" appears to be only a placeholder required by the
        # HFLM constructor; the wrapped model and tokenizer are assigned below.
        super().__init__(device=device, pretrained="gpt2")
        self._model = model
        self._tokenizer = tokenizer
        self._device = torch.device(device)
        self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
        self._use_kv_cache = use_kv_cache

    @property
    def eot_token_id(self):
        """
        End-of-text token id.

        Returns the tokenizer's ``eot_id`` when it defines one. Otherwise it falls
        back to ``eos_id``, because, for example, the stories model has no EOT token.
        """
        if hasattr(self._tokenizer, "eot_id"):
            return self._tokenizer.eot_id
        return self._tokenizer.eos_id

    @property
    def prefix_token_id(self):
        """Prefix token id; reuses ``eot_token_id``."""
        return self.eot_token_id

    @property
    def max_length(self):
        """Maximum sequence length (``max_seq_length``, default 2048)."""
        return self._max_seq_length

    @property
    def max_gen_toks(self):
        """Maximum number of tokens to generate (fixed at 50)."""
        return 50

    @property
    def batch_size(self):
        """Batch size; this wrapper always reports 1."""
        return 1

    @property
    def device(self):
        """The ``torch.device`` chosen at construction (CUDA if available, else CPU)."""
        return self._device

    def tok_encode(self, string: str, **kwargs):  # pyre-ignore
        """
        Encode ``string`` into token ids without adding BOS or EOS tokens.

        Extra keyword arguments passed by the harness are accepted and ignored.
        """
        return self._tokenizer.encode(string, bos=False, eos=False)

    def tok_decode(self, tokens, **kwargs):
        """Decode a single token id or a sequence of token ids into text."""
        # Wrap a lone int so the tokenizer's decode always receives a list.
        if isinstance(tokens, int):
            tokens = [tokens]
        return self._tokenizer.decode(tokens)

    def _model_call(self, inps):
        """
        Run the model on a batch of token ids and return its output (logits).

        Args:
            inps: Token id tensor. It is sliced as ``inps[:, :max_seq_length]``
                in KV-cache mode, so it is assumed to be (batch, seq_len).

        Returns:
            The model's output for ``inps``.
        """
        if not self._use_kv_cache:
            return self._model(inps)
        # Batch process the whole (truncated) sequence in one call starting at
        # position 0.
        pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
        return self._model(inps[:, : self._max_seq_length], {"input_pos": pos_tensor})

    def _model_generate(self, context, max_length, eos_token_id):
        """
        Generation is not supported by this wrapper.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("unimplemented")
