import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms as T
from tqdm import tqdm

from .ray_utils import *


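# Spherical-pose helpers (standard NeRF convention): the functions below
# compose a 4x4 camera-to-world matrix for a camera on a sphere around the
# scene origin.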
def trans_t(t):
    return torch.Tensor([[1, 0, 0, 0],
                         [0, 1, 0, 0],
                         [0, 0, 1, t],
                         [0, 0, 0, 1]]).float()


def rot_phi(phi):
    return torch.Tensor([[1, 0, 0, 0],
                         [0, np.cos(phi), -np.sin(phi), 0],
                         [0, np.sin(phi), np.cos(phi), 0],
                         [0, 0, 0, 1]]).float()


def rot_theta(th):
    return torch.Tensor([[np.cos(th), 0, -np.sin(th), 0],
                         [0, 1, 0, 0],
                         [np.sin(th), 0, np.cos(th), 0],
                         [0, 0, 0, 1]]).float()


def pose_spherical(theta, phi, radius):
    # Translate to `radius`, tilt by elevation `phi`, spin by azimuth `theta`
    # (both in degrees), then swap axes into the dataset's world convention.
    c2w = trans_t(radius)
    c2w = rot_phi(phi / 180. * np.pi) @ c2w
    c2w = rot_theta(theta / 180. * np.pi) @ c2w
    c2w = torch.Tensor(
        np.array([[-1, 0, 0, 0],
                  [0, 0, 1, 0],
                  [0, 1, 0, 0],
                  [0, 0, 0, 1]])) @ c2w
    return c2w


class NSVF(Dataset):
    """NSVF Generic Dataset."""

    def __init__(self,
                 datadir,
                 split='train',
                 downsample=1.0,
                 wh=(800, 800),
                 is_stack=False):
        self.root_dir = datadir
        self.split = split
        self.is_stack = is_stack
        self.downsample = downsample
        self.img_wh = (int(wh[0] / downsample), int(wh[1] / downsample))
        self.define_transforms()

        self.white_bg = True
        self.near_far = [0.5, 6.0]
        self.scene_bbox = torch.from_numpy(
            np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2, 3)
        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0],
                                        [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.define_proj_mat()

        self.center = torch.mean(self.scene_bbox, dim=0).float().view(1, 1, 3)
        self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)

    def bbox2corners(self):
        # Expand the (min, max) bbox rows into all 8 corner points: copy 3
        # keeps (min, max) unchanged, while copies 0-2 swap min/max along one
        # axis each, which together enumerate every corner exactly once.
        corners = self.scene_bbox.unsqueeze(0).repeat(4, 1, 1)
        for i in range(3):
            corners[i, [0, 1], i] = corners[i, [1, 0], i]
        return corners.view(-1, 3)

    def read_meta(self):
        with open(os.path.join(self.root_dir, 'intrinsics.txt')) as f:
            focal = float(f.readline().split()[0])
        # The principal point (400, 400) assumes the original 800x800
        # resolution; rescale focal and center to the actual image size.
        self.intrinsics = np.array([[focal, 0, 400.0], [0, focal, 400.0],
                                    [0, 0, 1]])
        self.intrinsics[:2] *= (np.array(self.img_wh)
                                / np.array([800, 800])).reshape(2, 1)

        pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose')))
        img_files = sorted(os.listdir(os.path.join(self.root_dir, 'rgb')))

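        # NSVF filename convention: the 0_ / 1_ / 2_ prefixes mark the
        # train / val / test splits; scenes without a 2_ test split fall
        # back to the val images for testing.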
        if self.split == 'train':
            pose_files = [x for x in pose_files if x.startswith('0_')]
            img_files = [x for x in img_files if x.startswith('0_')]
        elif self.split == 'val':
            pose_files = [x for x in pose_files if x.startswith('1_')]
            img_files = [x for x in img_files if x.startswith('1_')]
        elif self.split == 'test':
            test_pose_files = [x for x in pose_files if x.startswith('2_')]
            test_img_files = [x for x in img_files if x.startswith('2_')]
            if len(test_pose_files) == 0:
                test_pose_files = [x for x in pose_files if x.startswith('1_')]
                test_img_files = [x for x in img_files if x.startswith('1_')]
            pose_files = test_pose_files
            img_files = test_img_files

        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(
            self.img_wh[1], self.img_wh[0],
            [self.intrinsics[0, 0], self.intrinsics[1, 1]],
            center=self.intrinsics[:2, 2])  # (h, w, 3)
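        # normalize so that t along each ray measures Euclidean distance,
        # matching the metric near/far bounds above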
        self.directions = self.directions / torch.norm(
            self.directions, dim=-1, keepdim=True)

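        # 360-degree turntable path for free-viewpoint rendering: 40 poses
        # at elevation -30 degrees and radius 4.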
        self.render_path = torch.stack([
            pose_spherical(angle, -30.0, 4.0)
            for angle in np.linspace(-180, 180, 40 + 1)[:-1]
        ], 0)

        self.poses = []
        self.all_rays = []
        self.all_rgbs = []

        assert len(img_files) == len(pose_files)
        for img_fname, pose_fname in tqdm(
                zip(img_files, pose_files),
                total=len(img_files),
                desc=f'Loading data {self.split} ({len(img_files)})'):
            image_path = os.path.join(self.root_dir, 'rgb', img_fname)
            img = Image.open(image_path)
            if self.downsample != 1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (3|4, h, w)
            img = img.view(img.shape[0], -1).permute(1, 0)  # (h*w, 3|4)
            if img.shape[-1] == 4:
                # composite RGBA onto white: rgb * a + 1 * (1 - a)
                img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])
            self.all_rgbs += [img]

            c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname))
            c2w = torch.FloatTensor(c2w)
            self.poses.append(c2w)  # C2W
            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)

        self.poses = torch.stack(self.poses)
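        # is_stack=True keeps rays/rgbs grouped per image (N, h, w, C) for
        # image-wise use; otherwise all pixels are flattened into one buffer
        # so training can sample random rays across every image.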
        if self.split == 'train':
            if self.is_stack:
                self.all_rays = torch.stack(self.all_rays, 0).reshape(
                    -1, *self.img_wh[::-1], 6)  # (N, h, w, 6)
                self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(
                    -1, *self.img_wh[::-1], 3)  # (N, h, w, 3)
            else:
                self.all_rays = torch.cat(self.all_rays, 0)  # (N*h*w, 6)
                self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (N*h*w, 3)
        else:
            self.all_rays = torch.stack(self.all_rays, 0)  # (N, h*w, 6)
            self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(
                -1, *self.img_wh[::-1], 3)  # (N, h, w, 3)

    def define_transforms(self):
        self.transform = T.ToTensor()

    def define_proj_mat(self):
        # Per-image 3x4 projection matrices: K @ w2c, where w2c is the top
        # three rows of each inverted camera-to-world pose.
        self.proj_mat = torch.from_numpy(
            self.intrinsics[:3, :3]).unsqueeze(0).float() @ torch.inverse(
                self.poses)[:, :3]

    def world2ndc(self, points):
        # Map world points into the bbox-normalized cube [-1, 1]^3
        # (an axis-aligned normalization, not perspective NDC).
        device = points.device
        return (points - self.center.to(device)) / self.radius.to(device)

    def __len__(self):
        if self.split == 'train':
            return len(self.all_rays)
        return len(self.all_rgbs)

    def __getitem__(self, idx):
        if self.split == 'train':  # sample a single ray from the buffers
            sample = {'rays': self.all_rays[idx], 'rgbs': self.all_rgbs[idx]}
        else:  # return one full image worth of rays
            img = self.all_rgbs[idx]
            rays = self.all_rays[idx]
            sample = {'rays': rays, 'rgbs': img}
        return sample

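
# Minimal usage sketch (hypothetical scene path; assumes the standard NSVF
# layout with bbox.txt, intrinsics.txt, pose/ and rgb/ under the scene dir):
#
#     from torch.utils.data import DataLoader
#     train_set = NSVF('data/Synthetic_NSVF/Lego', split='train')
#     loader = DataLoader(train_set, batch_size=4096, shuffle=True)
#     batch = next(iter(loader))
#     rays, rgbs = batch['rays'], batch['rgbs']  # (B, 6) and (B, 3)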