"""
Checkpoint reader functionality for machine learning models.

This module provides classes for reading checkpoints from storage, including
determining checkpoint layout and configuring the reader.
"""

import logging
from itertools import zip_longest
from pathlib import Path
from typing import Any, Optional

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

from .types import RankInfo, STATE_DICT


logger = logging.getLogger(__name__)


class CheckpointReader:
    """
    Handles reading state dictionaries from storage.

    This class is responsible for reading model state dictionaries from storage according
    to the specified checkpoint layout. It supports synchronization barriers to ensure
    all ranks in a distributed setting complete their checkpoint operations.
    """

    def __init__(
        self,
        rank_info: RankInfo,
    ):
        """
        Initialize a CheckpointReader.

        Args:
            rank_info: Information about the current rank in a distributed setting.
        """

        self._rank_info = rank_info

    def read(
        self,
        path: str,
        state_dict: Optional[STATE_DICT] = None,
        *,
        map_location: Any = None,
        **kwargs: Any,
    ) -> tuple[STATE_DICT, list[str]]:
        """
        Reads a state dictionary from storage.

        Args:
            path (str): The path from which to read the checkpoint.
            state_dict (Optional[STATE_DICT]): If provided, only the keys present in
                this state dictionary are read from the checkpoint; if None, the
                full checkpoint is loaded.
            map_location (Any): Device mapping function or device name for relocating tensors.
            **kwargs: Additional keyword arguments passed to torch.load.

        Returns:
            tuple[STATE_DICT, list[str]]: The loaded state dictionary and a list of missing keys.
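
        Example (a minimal sketch; assumes a file checkpoint_<global_rank>.pt was
        written under the given directory, and the path and key names below are
        illustrative):

            reader = CheckpointReader(rank_info)

            # Full load: returns the entire state dictionary and no missing keys.
            state_dict, missing = reader.read("/tmp/ckpt", map_location="cpu")

            # Partial load: only keys present in the template are read; keys
            # absent from the checkpoint are reported in missing.
            template = {"model": {"linear.weight": torch.empty(4, 4)}}
            state_dict, missing = reader.read(
                "/tmp/ckpt", state_dict=template, map_location="cpu"
            )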
        """
        logger.debug(
            "Reading checkpoint from %s for rank %s",
            path,
            self._rank_info.global_rank,
        )

        dir_path = Path(path)
        file_path = dir_path / f"checkpoint_{self._rank_info.global_rank}.pt"

        # Check if the file exists
        if not file_path.exists():
            logger.error("Checkpoint file not found at %s", file_path)
            raise FileNotFoundError(f"Checkpoint file not found at {file_path}")

        if state_dict is None:
            result: tuple[STATE_DICT, list[str]] = (
                torch.load(file_path, map_location=map_location, **kwargs),
                [],
            )
        else:
            result = self._partial_read(
                file_path, state_dict, map_location=map_location, **kwargs
            )
        logger.debug("Successfully read checkpoint file from %s", file_path)
        return result

    def _partial_read(
        self,
        file_path: Path,
        state_dict: STATE_DICT,
        *,
        map_location: Any = None,
        **kwargs: Any,
    ) -> tuple[STATE_DICT, list[str]]:
        """
        Reads only the keys present in state_dict from the checkpoint file.

        This method optimizes checkpoint loading by only loading the tensors that
        are actually needed, based on the keys present in the input state_dict.
        This can significantly reduce memory usage and loading time for large checkpoints
        when only a subset of the model needs to be loaded.

        Args:
            file_path (Path): The path to the checkpoint file.
            state_dict (STATE_DICT): The state dictionary containing keys to load.
            map_location (Any): Device mapping function or device name for relocating tensors.
            **kwargs: Additional keyword arguments passed to torch.load.

        Returns:
            tuple[STATE_DICT, list[str]]: The updated state dictionary with loaded values and a list of missing keys.
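
        Example (an illustrative sketch with hypothetical key names; this method is
        normally invoked through read()):

            # Only "weight" is read from disk; "bias" is absent from the
            # checkpoint, is left untouched, and is returned in missing_keys.
            template = {"weight": torch.empty(2, 2), "bias": None}
            loaded, missing_keys = reader._partial_read(file_path, template)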
        """

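        # torch.load under FakeTensorMode materializes only tensor metadata
        # (shapes, dtypes, and the serializer's recorded byte offsets) without
        # allocating real storage for tensor data.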
        with FakeTensorMode():
            metadata_dict = torch.load(file_path, map_location=map_location, **kwargs)

        missing_keys = []

        with open(file_path, "rb") as file:
            # Helper function to load tensor data from file
            def load_tensor(
                target: Optional[torch.Tensor], source: torch.Tensor, full_key: str
            ) -> torch.Tensor:
                if target is not None and (
                    target.size() != source.size() or target.dtype != source.dtype
                ):
                    raise RuntimeError(
                        f"Target tensor size={target.size()} dtype={target.dtype} does not match "
                        f"source tensor size={source.size()} dtype={source.dtype} for key {full_key}"
                    )

                tensor_offset = source.untyped_storage()._checkpoint_offset

                assert tensor_offset is not None, (
                    "checkpoint_offset for tensor in torch serialized file is not set. "
                    "This could happen if the checkpoint was saved with an older version "
                    "of PyTorch. Please make sure that the checkpoint was saved with "
                    "PyTorch 2.7 or later."
                )

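                # Compute the tensor's byte length, then seek to its data in the
                # file: the storage's byte offset in the checkpoint plus the
                # tensor's element offset within that storage, in bytes.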
                tensor_len = source.nelement() * source.element_size()
                file.seek(
                    tensor_offset + source.element_size() * int(source.storage_offset())
                )
                if target is None:
                    target = torch.empty(
                        source.size(), dtype=source.dtype, device=source.device
                    )

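                # Reinterpret the raw bytes with the source dtype and copy into
                # target, so the result owns its memory instead of aliasing the
                # read-only bytes buffer returned by file.read().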
                buffer = file.read(tensor_len)
                cpu_tensor = torch.frombuffer(buffer, dtype=source.dtype)
                tensor = cpu_tensor.view(source.size())
                target.copy_(tensor)
                return target

            # Helper function to recursively process nested structures
            def process_value(
                target_value: Any, source_value: Any, key_path: str
            ) -> Any:
                source_type = type(source_value)
                if source_type is torch._subclasses.fake_tensor.FakeTensor:
                    source_type = torch.Tensor
                if target_value is not None and not isinstance(
                    target_value, source_type
                ):
                    raise RuntimeError(
                        f"Target value {key_path} is set to {type(target_value)}, but source value is {type(source_value)}"
                    )
                if isinstance(source_value, torch.Tensor):
                    return load_tensor(target_value, source_value, key_path)
                elif isinstance(source_value, dict):
                    if target_value is None:
                        # create a new map with all the keys present in source_value
                        target_value = dict.fromkeys(source_value.keys())

                    for key in list(target_value.keys()):
                        current_path = f"{key_path}.{key}" if key_path else key
                        if key in source_value:
                            target_value[key] = process_value(
                                target_value[key], source_value[key], current_path
                            )
                        else:
                            missing_keys.append(current_path)

                    return target_value
                elif isinstance(source_value, list):
                    if target_value is None:
                        target_value = [None] * len(source_value)
                    result = []
                    for i, (target_item, source_item) in enumerate(
                        zip_longest(target_value, source_value, fillvalue=None)
                    ):
                        current_path = f"{key_path}[{i}]" if key_path else f"[{i}]"
                        result.append(
                            process_value(target_item, source_item, current_path)
                        )
                    return result
                else:
                    return source_value

            # Start recursive processing from the root of the state dictionary
            updated_state_dict = process_value(state_dict, metadata_dict, "")

        if missing_keys:
            if len(missing_keys) > 10:
                logger.warning(
                    "Missing %s keys from checkpoint: %s... (and %s more)",
                    len(missing_keys),
                    missing_keys[:10],
                    len(missing_keys) - 10,
                )
            else:
                logger.warning(
                    "Missing %s keys from checkpoint: %s",
                    len(missing_keys),
                    missing_keys,
                )

        return updated_state_dict, missing_keys
