# Copyright (c) Alibaba, Inc. and its affiliates.

import math

import torch

from .dpm_solver import (DPM_Solver, NoiseScheduleVP, model_wrapper,
                         model_wrapper_guided_diffusion)
from .ops.losses import discretized_gaussian_log_likelihood, kl_divergence

# Names exported by `from <this module> import *`; helpers such as
# `load_stable_diffusion_pretrained` and `AddGaussianNoise` are not listed.
__all__ = ['GaussianDiffusion', 'beta_schedule', 'GaussianDiffusion_style']


def _i(tensor, t, x):
    r"""Index tensor using t and format the output according to x.
    """
    shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
    if tensor.device != x.device:
        tensor = tensor.to(x.device)
    return tensor[t].view(shape).to(x)


def beta_schedule(schedule,
                  num_timesteps=1000,
                  init_beta=None,
                  last_beta=None):
    '''Generate a sequence of diffusion noise variances (betas).

    Args:
        schedule (str): Type of schedule: 'linear', 'linear_sd', 'quadratic'
            or 'cosine'.
        num_timesteps (int, optional): Number of betas to generate.
            Default is 1000.
        init_beta (float, optional): First beta value. If not provided (or
            falsy), a schedule-specific default is used. Ignored by 'cosine'.
        last_beta (float, optional): Last beta value. If not provided (or
            falsy), a schedule-specific default is used. Ignored by 'cosine'.

    Schedules:
        1. linear: betas evenly spaced between init_beta and last_beta.
           Defaults are 1e-4 and 0.02, both scaled by 1000 / num_timesteps.
        2. linear_sd: evenly spaced between sqrt(init_beta) and
           sqrt(last_beta), then squared. Defaults are 0.00085 and 0.012
           (the linear_start / linear_end of the Stable Diffusion v1 configs).
        3. quadratic: same construction as 'linear_sd', with defaults
           0.0015 and 0.0195.
        4. cosine: beta_t = 1 - f(t + 1) / f(t) with
           f(u) = cos((u + 0.008) / 1.008 * pi / 2) ** 2, capped at 0.999.

    Returns:
        torch.Tensor: 1-D float64 tensor with ``num_timesteps`` values.

    Raises:
        ValueError: If ``schedule`` is not one of the supported names.
    '''
    if schedule == 'linear':
        scale = 1000.0 / num_timesteps
        init_beta = init_beta or scale * 0.0001
        last_beta = last_beta or scale * 0.02
        return torch.linspace(
            init_beta, last_beta, num_timesteps, dtype=torch.float64)
    elif schedule == 'linear_sd':
        # Without defaults, `None ** 0.5` raised a TypeError.
        init_beta = init_beta or 0.00085
        last_beta = last_beta or 0.012
        return torch.linspace(
            init_beta**0.5, last_beta**0.5, num_timesteps,
            dtype=torch.float64)**2
    elif schedule == 'quadratic':
        init_beta = init_beta or 0.0015
        last_beta = last_beta or 0.0195
        return torch.linspace(
            init_beta**0.5, last_beta**0.5, num_timesteps,
            dtype=torch.float64)**2
    elif schedule == 'cosine':

        def alpha_bar(u):
            # cumulative signal level at normalized time u in [0, 1]
            return math.cos((u + 0.008) / 1.008 * math.pi / 2)**2

        betas = []
        for step in range(num_timesteps):
            t1 = step / num_timesteps
            t2 = (step + 1) / num_timesteps
            # cap at 0.999 so alpha_t never reaches 0 near the end
            betas.append(min(1.0 - alpha_bar(t2) / alpha_bar(t1), 0.999))
        return torch.tensor(betas, dtype=torch.float64)
    else:
        raise ValueError(f'Unsupported schedule: {schedule}')


def load_stable_diffusion_pretrained(state_dict, temporal_attention):
    """Extract and rename the ``diffusion_model`` weights of a checkpoint.

    Only keys containing ``'diffusion_model'`` are kept, and the prefix up to
    ``'diffusion_model.'`` is stripped. The six ``input_blocks.{3,6,9}.0.op``
    keys are renamed to ``input_blocks.{3,6,9}.op``. When
    ``temporal_attention`` is truthy, sub-module index 2 of ``middle_block``,
    ``output_blocks.5`` and ``output_blocks.8`` is shifted to index 3.

    Args:
        state_dict: Mapping of parameter names to values.
        temporal_attention: Whether to apply the index shift above.

    Returns:
        collections.OrderedDict: Renamed entries in the input iteration order.
    """
    from collections import OrderedDict

    downsample_keys = {
        'input_blocks.{}.0.op.{}'.format(idx, part)
        for idx in (3, 6, 9) for part in ('weight', 'bias')
    }
    index_shifts = (
        ('middle_block.2', 'middle_block.3'),
        ('output_blocks.5.2', 'output_blocks.5.3'),
        ('output_blocks.8.2', 'output_blocks.8.3'),
    )

    renamed = OrderedDict()
    for key in list(state_dict.keys()):
        if 'diffusion_model' not in key:
            continue
        new_key = key.split('diffusion_model.')[-1]
        if new_key in downsample_keys:
            new_key = new_key.replace('0.op', 'op')
        if temporal_attention:
            for old, new in index_shifts:
                # str.replace is a no-op when `old` is absent
                new_key = new_key.replace(old, new)
        renamed[new_key] = state_dict[key]
    return renamed


class AddGaussianNoise(object):
    """Callable transform that adds Gaussian noise to a tensor.

    ``out = img + std * N(0, 1) + mean``. Non-floating inputs are promoted to
    float32 for the computation, and the result is cast back to the input
    dtype.
    """

    def __init__(self, mean=0., std=0.1):
        self.mean = mean
        self.std = std

    def __call__(self, img):
        assert isinstance(img, torch.Tensor)
        original_dtype = img.dtype
        work = img if img.is_floating_point() else img.to(torch.float32)
        noisy = work + self.std * torch.randn_like(work) + self.mean
        if noisy.dtype == original_dtype:
            return noisy
        return noisy.to(original_dtype)

    def __repr__(self):
        return '{0}(mean={1}, std={2})'.format(
            type(self).__name__, self.mean, self.std)


class GaussianDiffusion(object):
    r"""Gaussian diffusion process: forward noising, samplers and losses.

    Per-timestep coefficients are precomputed once as float64 tensors from
    ``betas``. ``_i`` gathers them for a batch of timesteps and converts them
    to the dtype/device of the tensor being processed.

    Samplers: ancestral DDPM (``p_sample*``), DDIM (``ddim_sample*``), DDIM
    inversion (``ddim_reverse_sample*``) and PLMS (``plms_sample*``).
    """

    def __init__(self,
                 betas,
                 mean_type='eps',
                 var_type='learned_range',
                 loss_type='mse',
                 epsilon=1e-12,
                 rescale_timesteps=False):
        r"""Validate the configuration and precompute diffusion coefficients.

        Args:
            betas: Noise variance of every forward step. Each value must lie
                in (0, 1]. Converted to a float64 tensor unless it already is
                a ``torch.DoubleTensor``.
            mean_type: What the model predicts: ``'eps'`` (added noise),
                ``'x0'`` (clean sample) or ``'x_{t-1}'`` (posterior mean).
            var_type: ``'learned'`` (model also outputs a log-variance),
                ``'learned_range'`` (model also outputs a fraction that
                interpolates between the fixed small and large
                log-variances), ``'fixed_large'`` or ``'fixed_small'``.
            loss_type: ``'mse'``, ``'rescaled_mse'``, ``'kl'``,
                ``'rescaled_kl'``, ``'l1'``, ``'rescaled_l1'`` or
                ``'charbonnier'``.
            epsilon: Smoothing constant of the Charbonnier loss.
            rescale_timesteps: If True, timesteps passed to the model are
                mapped from ``[0, num_timesteps)`` to ``[0, 1000)``.
        """
        # check input
        if not isinstance(betas, torch.DoubleTensor):
            betas = torch.tensor(betas, dtype=torch.float64)
        assert min(betas) > 0 and max(betas) <= 1
        assert mean_type in ['x0', 'x_{t-1}', 'eps']
        assert var_type in [
            'learned', 'learned_range', 'fixed_large', 'fixed_small'
        ]
        assert loss_type in [
            'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1',
            'charbonnier'
        ]
        self.betas = betas
        self.num_timesteps = len(betas)
        self.mean_type = mean_type
        self.var_type = var_type
        self.loss_type = loss_type
        self.epsilon = epsilon
        self.rescale_timesteps = rescale_timesteps

        # alphas and their cumulative products (alpha_bar)
        alphas = 1 - self.betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)
        # alpha_bar shifted right (prepended with 1) / left (appended with 0)
        self.alphas_cumprod_prev = torch.cat(
            [alphas.new_ones([1]), self.alphas_cumprod[:-1]])
        self.alphas_cumprod_next = torch.cat(
            [self.alphas_cumprod[1:],
             alphas.new_zeros([1])])

        # q(x_t | x_0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0
                                                        - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = torch.log(1.0
                                                      - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod
                                                      - 1)

        # q(x_{t-1} | x_t, x_0)
        self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (
            1.0 - self.alphas_cumprod)
        # posterior_variance[0] is 0 (alphas_cumprod_prev[0] == 1); clamp
        # before the log to keep it finite
        self.posterior_log_variance_clipped = torch.log(
            self.posterior_variance.clamp(1e-20))
        self.posterior_mean_coef1 = betas * torch.sqrt(
            self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        self.posterior_mean_coef2 = (
            1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (
                1.0 - self.alphas_cumprod)

    def q_sample(self, x0, t, noise=None):
        r"""Sample from q(x_t | x_0).

        Returns ``sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise``.
        Fresh standard-normal noise is drawn when ``noise`` is None.
        """
        noise = torch.randn_like(x0) if noise is None else noise
        return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + \
            _i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise

    def q_mean_variance(self, x0, t):
        r"""Distribution of q(x_t | x_0).

        Returns:
            (mean, variance, log_variance), each broadcastable against ``x0``.
        """
        mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0
        var = _i(1.0 - self.alphas_cumprod, t, x0)
        log_var = _i(self.log_one_minus_alphas_cumprod, t, x0)
        return mu, var, log_var

    def q_posterior_mean_variance(self, x0, xt, t):
        r"""Distribution of q(x_{t-1} | x_t, x_0).

        Returns:
            (mean, variance, clipped log_variance).
        """
        mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(
            self.posterior_mean_coef2, t, xt) * xt
        var = _i(self.posterior_variance, t, xt)
        log_var = _i(self.posterior_log_variance_clipped, t, xt)
        return mu, var, log_var

    @torch.no_grad()
    def p_sample(self,
                 xt,
                 t,
                 model,
                 model_kwargs={},
                 clamp=None,
                 percentile=None,
                 condition_fn=None,
                 guide_scale=None):
        r"""Sample from p(x_{t-1} | x_t).
            - condition_fn: for classifier-based guidance (guided-diffusion).
              Its output is added to the mean, scaled by the variance.
            - guide_scale: for classifier-free guidance (glide/dalle-2).

        Returns:
            (x_{t-1}, predicted x0). No noise is added where ``t == 0``.
        """
        # predict distribution of p(x_{t-1} | x_t)
        mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
                                                    clamp, percentile,
                                                    guide_scale)

        # random sample (with optional conditional function)
        noise = torch.randn_like(xt)
        mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
        if condition_fn is not None:
            grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
            mu = mu.float() + var * grad.float()
        xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise
        return xt_1, x0

    @torch.no_grad()
    def p_sample_loop(self,
                      noise,
                      model,
                      model_kwargs={},
                      clamp=None,
                      percentile=None,
                      condition_fn=None,
                      guide_scale=None):
        r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1).

        Runs all ``num_timesteps`` ancestral steps starting from ``noise``
        and returns the final sample.
        """
        # prepare input
        b = noise.size(0)
        xt = noise

        # diffusion process
        for step in torch.arange(self.num_timesteps).flip(0):
            t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
            xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp,
                                  percentile, condition_fn, guide_scale)
        return xt

    def p_mean_variance(self,
                        xt,
                        t,
                        model,
                        model_kwargs={},
                        clamp=None,
                        percentile=None,
                        guide_scale=None):
        r"""Distribution of p(x_{t-1} | x_t).

        Args:
            model_kwargs: Keyword arguments for ``model``. When
                ``guide_scale`` is set, this must be a 2-element list:
                ``[conditional_kwargs, unconditional_kwargs]``.
            clamp: If set (and ``percentile`` is None), the predicted x0 is
                clamped to ``[-clamp, clamp]``.
            percentile: If set, in (0, 1]. The predicted x0 is divided by
                its per-sample |x0| quantile (at least 1) after clipping to
                that quantile (dynamic thresholding).
            guide_scale: Classifier-free guidance scale applied to the mean
                prediction channels.

        Returns:
            (mean, variance, log_variance, predicted x0).
        """
        # predict distribution
        if guide_scale is None:
            out = model(xt, self._scale_timesteps(t), **model_kwargs)
        else:
            # classifier-free guidance
            # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs)
            assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
            y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0])
            u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1])
            # guidance applies only to the mean channels; learned-variance
            # channels (second half) are taken from the conditional output
            dim = y_out.size(1) if self.var_type.startswith(
                'fixed') else y_out.size(1) // 2
            out = torch.cat(
                [
                    u_out[:, :dim] + guide_scale *
                    (y_out[:, :dim] - u_out[:, :dim]),
                    y_out[:, dim:]
                ],
                dim=1)

        # compute variance
        if self.var_type == 'learned':
            out, log_var = out.chunk(2, dim=1)
            var = torch.exp(log_var)
        elif self.var_type == 'learned_range':
            out, fraction = out.chunk(2, dim=1)
            min_log_var = _i(self.posterior_log_variance_clipped, t, xt)
            max_log_var = _i(torch.log(self.betas), t, xt)
            # presumably the model outputs fraction in [-1, 1]; map to [0, 1]
            fraction = (fraction + 1) / 2.0
            log_var = fraction * max_log_var + (1 - fraction) * min_log_var
            var = torch.exp(log_var)
        elif self.var_type == 'fixed_large':
            # betas, except at t == 0 where posterior_variance[1] is used
            var = _i(
                torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t,
                xt)
            log_var = torch.log(var)
        elif self.var_type == 'fixed_small':
            var = _i(self.posterior_variance, t, xt)
            log_var = _i(self.posterior_log_variance_clipped, t, xt)

        # compute mean and x0
        if self.mean_type == 'x_{t-1}':
            mu = out  # x_{t-1}
            x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - \
                _i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, xt) * xt
        elif self.mean_type == 'x0':
            x0 = out
            mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
        elif self.mean_type == 'eps':
            x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
                _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out
            mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)

        # restrict the range of x0
        if percentile is not None:
            assert percentile > 0 and percentile <= 1  # e.g., 0.995
            s = torch.quantile(
                x0.flatten(1).abs(), percentile, dim=1).clamp_(1.0)
            # reshape to (batch, 1, ..., 1) so the threshold broadcasts for
            # inputs of any rank (e.g. 4-D images and 5-D videos)
            s = s.view(-1, *((1, ) * (x0.ndim - 1)))
            x0 = torch.min(s, torch.max(-s, x0)) / s
        elif clamp is not None:
            x0 = x0.clamp(-clamp, clamp)
        return mu, var, log_var, x0

    @torch.no_grad()
    def ddim_sample(self,
                    xt,
                    t,
                    model,
                    model_kwargs={},
                    clamp=None,
                    percentile=None,
                    condition_fn=None,
                    guide_scale=None,
                    ddim_timesteps=20,
                    eta=0.0):
        r"""Sample from p(x_{t-1} | x_t) using DDIM.
            - condition_fn: for classifier-based guidance (guided-diffusion).
            - guide_scale: for classifier-free guidance (glide/dalle-2).
            - eta: scales the injected noise; ``eta == 0`` is deterministic.

        The "previous" timestep is ``t - num_timesteps // ddim_timesteps``,
        clamped at 0.

        Returns:
            (x at the previous timestep, predicted x0).
        """
        stride = self.num_timesteps // ddim_timesteps

        # predict distribution of p(x_{t-1} | x_t)
        _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
                                           percentile, guide_scale)
        if condition_fn is not None:
            # x0 -> eps
            alpha = _i(self.alphas_cumprod, t, xt)
            eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
                _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
            eps = eps - (1 - alpha).sqrt() * condition_fn(
                xt, self._scale_timesteps(t), **model_kwargs)

            # eps -> x0
            x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
                _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps

        # derive variables
        eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
            _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
        alphas = _i(self.alphas_cumprod, t, xt)
        alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
        sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) *
                                  (1 - alphas / alphas_prev))

        # random sample
        noise = torch.randn_like(xt)
        direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps
        mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
        xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise
        return xt_1, x0

    @torch.no_grad()
    def ddim_sample_loop(self,
                         noise,
                         model,
                         model_kwargs={},
                         clamp=None,
                         percentile=None,
                         condition_fn=None,
                         guide_scale=None,
                         ddim_timesteps=20,
                         eta=0.0):
        r"""Run the full DDIM sampling chain from ``noise``.

        Uses about ``ddim_timesteps`` evenly strided steps and returns the
        final sample.
        """
        # prepare input
        b = noise.size(0)
        xt = noise

        # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps)
        steps = (1 + torch.arange(0, self.num_timesteps,
                                  self.num_timesteps // ddim_timesteps)).clamp(
                                      0, self.num_timesteps - 1).flip(0)
        for step in steps:
            t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
            xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp,
                                     percentile, condition_fn, guide_scale,
                                     ddim_timesteps, eta)
        return xt

    @torch.no_grad()
    def ddim_reverse_sample(self,
                            xt,
                            t,
                            model,
                            model_kwargs={},
                            clamp=None,
                            percentile=None,
                            guide_scale=None,
                            ddim_timesteps=20):
        r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic).

        The "next" timestep is ``t + num_timesteps // ddim_timesteps``. When
        it reaches ``num_timesteps``, alpha_bar is taken as 0.

        Returns:
            (x at the next timestep, predicted x0).
        """
        stride = self.num_timesteps // ddim_timesteps

        # predict distribution of p(x_{t-1} | x_t)
        _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
                                           percentile, guide_scale)

        # derive variables
        eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
            _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
        alphas_next = _i(
            torch.cat(
                [self.alphas_cumprod,
                 self.alphas_cumprod.new_zeros([1])]),
            (t + stride).clamp(0, self.num_timesteps), xt)

        # reverse sample
        mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps
        return mu, x0

    @torch.no_grad()
    def ddim_reverse_sample_loop(self,
                                 x0,
                                 model,
                                 model_kwargs={},
                                 clamp=None,
                                 percentile=None,
                                 guide_scale=None,
                                 ddim_timesteps=20):
        r"""Deterministically map ``x0`` to noise with DDIM inversion.

        Steps forward through ``0, stride, 2 * stride, ...`` and returns the
        final latent.
        """
        # prepare input
        b = x0.size(0)
        xt = x0

        # reconstruction steps
        steps = torch.arange(0, self.num_timesteps,
                             self.num_timesteps // ddim_timesteps)
        for step in steps:
            t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
            xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp,
                                             percentile, guide_scale,
                                             ddim_timesteps)
        return xt

    @torch.no_grad()
    def plms_sample(self,
                    xt,
                    t,
                    model,
                    model_kwargs={},
                    clamp=None,
                    percentile=None,
                    condition_fn=None,
                    guide_scale=None,
                    plms_timesteps=20,
                    eps_cache=None):
        r"""Sample from p(x_{t-1} | x_t) using PLMS.
            - condition_fn: for classifier-based guidance (guided-diffusion).
            - guide_scale: for classifier-free guidance (glide/dalle-2).
            - eps_cache: eps predictions of previous steps (most recent
              last). Its length selects the update order: 0 -> improved
              Euler, 1/2/>=3 -> 2nd/3rd/4th-order Adams-Bashforth. It is
              not modified here; None is treated as an empty cache.

        Returns:
            (x_{t-1}, predicted x0, eps at the current step). The caller
            appends eps to its cache.
        """
        eps_cache = [] if eps_cache is None else eps_cache
        stride = self.num_timesteps // plms_timesteps

        # function for compute eps
        def compute_eps(xt, t):
            # predict distribution of p(x_{t-1} | x_t)
            _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
                                               clamp, percentile, guide_scale)

            # condition
            if condition_fn is not None:
                # x0 -> eps
                alpha = _i(self.alphas_cumprod, t, xt)
                eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
                    _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
                eps = eps - (1 - alpha).sqrt() * condition_fn(
                    xt, self._scale_timesteps(t), **model_kwargs)

                # eps -> x0
                x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
                    _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps

            # derive eps
            eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
                _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
            return eps

        # function for compute x_0 and x_{t-1} (always starts from the
        # enclosing `xt`)
        def compute_x0(eps, t):
            # eps -> x0
            x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
                _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps

            # deterministic sample
            alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
            direction = torch.sqrt(1 - alphas_prev) * eps
            xt_1 = torch.sqrt(alphas_prev) * x0 + direction
            return xt_1, x0

        # PLMS sample
        eps = compute_eps(xt, t)
        if len(eps_cache) == 0:
            # 2nd order pseudo improved Euler
            xt_1, x0 = compute_x0(eps, t)
            eps_next = compute_eps(xt_1, (t - stride).clamp(0))
            eps_prime = (eps + eps_next) / 2.0
        elif len(eps_cache) == 1:
            # 2nd order pseudo linear multistep (Adams-Bashforth)
            eps_prime = (3 * eps - eps_cache[-1]) / 2.0
        elif len(eps_cache) == 2:
            # 3rd order pseudo linear multistep (Adams-Bashforth)
            eps_prime = (23 * eps - 16 * eps_cache[-1]
                         + 5 * eps_cache[-2]) / 12.0
        else:
            # 4th order pseudo linear multistep (Adams-Bashforth)
            eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2]
                         - 9 * eps_cache[-3]) / 24.0
        xt_1, x0 = compute_x0(eps_prime, t)
        return xt_1, x0, eps

    @torch.no_grad()
    def plms_sample_loop(self,
                         noise,
                         model,
                         model_kwargs={},
                         clamp=None,
                         percentile=None,
                         condition_fn=None,
                         guide_scale=None,
                         plms_timesteps=20):
        r"""Run the full PLMS sampling chain from ``noise``.

        Keeps the last three eps predictions for the multistep updates and
        returns the final sample.
        """
        # prepare input
        b = noise.size(0)
        xt = noise

        # diffusion process
        steps = (1 + torch.arange(0, self.num_timesteps,
                                  self.num_timesteps // plms_timesteps)).clamp(
                                      0, self.num_timesteps - 1).flip(0)
        eps_cache = []
        for step in steps:
            # PLMS sampling step
            t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
            xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp,
                                          percentile, condition_fn,
                                          guide_scale, plms_timesteps,
                                          eps_cache)

            # update eps cache
            eps_cache.append(eps)
            if len(eps_cache) >= 4:
                eps_cache.pop(0)
        return xt

    def loss(self,
             x0,
             t,
             model,
             model_kwargs={},
             noise=None,
             weight=None,
             use_div_loss=False):
        r"""Per-sample training loss at timesteps ``t``.

        Args:
            x0: Clean samples.
            t: Long tensor of timesteps, one per sample.
            model: Denoising network called as ``model(xt, t, **kwargs)``.
            noise: Optional noise used to build x_t (drawn if None).
            weight: Optional multiplier. For mse/l1 it is applied to the
                per-sample loss; for charbonnier it is applied element-wise
                before averaging.
            use_div_loss: If True (eps prediction only, and ``x0.shape[2] > 1``),
                adds ``0.001 / (mean std of predicted x0 over dim 2 + 1e-4)``.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        noise = torch.randn_like(x0) if noise is None else noise
        xt = self.q_sample(x0, t, noise=noise)

        # compute loss
        if self.loss_type in ['kl', 'rescaled_kl']:
            loss, _ = self.variational_lower_bound(x0, xt, t, model,
                                                   model_kwargs)
            if self.loss_type == 'rescaled_kl':
                loss = loss * self.num_timesteps
        elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']:
            out = model(xt, self._scale_timesteps(t), **model_kwargs)

            # VLB for variation
            loss_vlb = 0.0
            if self.var_type in ['learned', 'learned_range']:
                out, var = out.chunk(2, dim=1)
                # learn var without affecting the prediction of mean
                frozen = torch.cat([out.detach(), var], dim=1)
                loss_vlb, _ = self.variational_lower_bound(
                    x0, xt, t, model=lambda *args, **kwargs: frozen)
                if self.loss_type.startswith('rescaled_'):
                    loss_vlb = loss_vlb * self.num_timesteps / 1000.0

            # MSE/L1 for x0/eps
            target = {
                'eps': noise,
                'x0': x0,
                'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]
            }[self.mean_type]
            loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2
                                      ).abs().flatten(1).mean(dim=1)
            if weight is not None:
                loss = loss * weight

            # div loss
            if use_div_loss and self.mean_type == 'eps' and x0.shape[2] > 1:

                # derive x0
                x0_ = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
                    _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out

                # assumes (n, c, f, h, w) layout: penalize low std over frames
                div_loss = 0.001 / (
                    x0_.std(dim=2).flatten(1).mean(dim=1) + 1e-4)
                loss = loss + div_loss

            # total loss
            loss = loss + loss_vlb
        elif self.loss_type in ['charbonnier']:
            out = model(xt, self._scale_timesteps(t), **model_kwargs)

            # VLB for variation
            loss_vlb = 0.0
            if self.var_type in ['learned', 'learned_range']:
                out, var = out.chunk(2, dim=1)
                frozen = torch.cat([out.detach(), var], dim=1)
                loss_vlb, _ = self.variational_lower_bound(
                    x0, xt, t, model=lambda *args, **kwargs: frozen)
                if self.loss_type.startswith('rescaled_'):
                    loss_vlb = loss_vlb * self.num_timesteps / 1000.0

            # Charbonnier (smooth L1) for x0/eps
            target = {
                'eps': noise,
                'x0': x0,
                'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]
            }[self.mean_type]
            loss = torch.sqrt((out - target)**2 + self.epsilon)
            if weight is not None:
                loss = loss * weight
            loss = loss.flatten(1).mean(dim=1)

            # total loss
            loss = loss + loss_vlb
        return loss

    def variational_lower_bound(self,
                                x0,
                                xt,
                                t,
                                model,
                                model_kwargs={},
                                clamp=None,
                                percentile=None):
        r"""Per-sample VLB term at timesteps ``t``, in bits per dimension.

        Uses the discretized Gaussian NLL of x0 where ``t == 0``, and
        KL(q(x_{t-1} | x_t, x_0) || p(x_{t-1} | x_t)) elsewhere.

        Returns:
            (vlb of shape ``(batch,)``, predicted x0).
        """
        # compute groundtruth and predicted distributions
        mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t)
        mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
                                                    clamp, percentile)

        # compute KL loss
        kl = kl_divergence(mu1, log_var1, mu2, log_var2)
        kl = kl.flatten(1).mean(dim=1) / math.log(2.0)

        # compute discretized NLL loss (for p(x0 | x1) only)
        nll = -discretized_gaussian_log_likelihood(
            x0, mean=mu2, log_scale=0.5 * log_var2)
        nll = nll.flatten(1).mean(dim=1) / math.log(2.0)

        # NLL for p(x0 | x1) and KL otherwise
        vlb = torch.where(t == 0, nll, kl)
        return vlb, x0

    @torch.no_grad()
    def variational_lower_bound_loop(self,
                                     x0,
                                     model,
                                     model_kwargs={},
                                     clamp=None,
                                     percentile=None):
        r"""Compute the entire variational lower bound, measured in bits-per-dim.

        Returns:
            dict with per-timestep ``'vlb'``, ``'mse'``, ``'x0_mse'`` (each
            ``(batch, num_timesteps)``, ordered from the last timestep to
            the first), plus ``'prior_bits_per_dim'`` and
            ``'total_bits_per_dim'`` (each ``(batch,)``).
        """
        # prepare input and output
        b = x0.size(0)
        metrics = {'vlb': [], 'mse': [], 'x0_mse': []}

        # loop
        for step in torch.arange(self.num_timesteps).flip(0):
            # compute VLB
            t = torch.full((b, ), step, dtype=torch.long, device=x0.device)
            noise = torch.randn_like(x0)
            xt = self.q_sample(x0, t, noise)
            vlb, pred_x0 = self.variational_lower_bound(
                x0, xt, t, model, model_kwargs, clamp, percentile)

            # predict eps from x0
            eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
                _i(self.sqrt_recipm1_alphas_cumprod, t, xt)

            # collect metrics
            metrics['vlb'].append(vlb)
            metrics['x0_mse'].append(
                (pred_x0 - x0).square().flatten(1).mean(dim=1))
            metrics['mse'].append(
                (eps - noise).square().flatten(1).mean(dim=1))
        metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()}

        # compute the prior KL term KL(q(x_T | x_0) || N(0, I)) at the LAST
        # timestep; reusing the loop variable `t` would evaluate it at t == 0
        t = torch.full((b, ),
                       self.num_timesteps - 1,
                       dtype=torch.long,
                       device=x0.device)
        mu, _, log_var = self.q_mean_variance(x0, t)
        kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu),
                                 torch.zeros_like(log_var))
        kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0)

        # update metrics
        metrics['prior_bits_per_dim'] = kl_prior
        metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior
        return metrics

    def _scale_timesteps(self, t):
        r"""Map timesteps to ``[0, 1000)`` floats if ``rescale_timesteps``."""
        if self.rescale_timesteps:
            return t.float() * 1000.0 / self.num_timesteps
        return t


class GaussianDiffusion_style(object):

    def __init__(self,
                 betas,
                 mean_type='eps',
                 var_type='fixed_small',
                 loss_type='mse',
                 rescale_timesteps=False):
        # check input
        if not isinstance(betas, torch.DoubleTensor):
            betas = torch.tensor(betas, dtype=torch.float64)
        assert min(betas) > 0 and max(betas) <= 1
        assert mean_type in ['x0', 'x_{t-1}', 'eps']
        assert var_type in [
            'learned', 'learned_range', 'fixed_large', 'fixed_small'
        ]
        assert loss_type in [
            'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1'
        ]
        self.betas = betas
        self.num_timesteps = len(betas)
        self.mean_type = mean_type
        self.var_type = var_type
        self.loss_type = loss_type
        self.rescale_timesteps = rescale_timesteps

        # alphas
        alphas = 1 - self.betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.alphas_cumprod_prev = torch.cat(
            [alphas.new_ones([1]), self.alphas_cumprod[:-1]])
        self.alphas_cumprod_next = torch.cat(
            [self.alphas_cumprod[1:],
             alphas.new_zeros([1])])

        # q(x_t | x_{t-1})
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0
                                                        - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = torch.log(1.0
                                                      - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod
                                                      - 1)

        # q(x_{t-1} | x_t, x_0)
        self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (
            1.0 - self.alphas_cumprod)
        self.posterior_log_variance_clipped = torch.log(
            self.posterior_variance.clamp(1e-20))
        self.posterior_mean_coef1 = betas * torch.sqrt(
            self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        self.posterior_mean_coef2 = (
            1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (
                1.0 - self.alphas_cumprod)

    def q_sample(self, x0, t, noise=None):
        r"""Sample from q(x_t | x_0).
        """
        noise = torch.randn_like(x0) if noise is None else noise
        xt = _i(self.sqrt_alphas_cumprod, t, x0) * x0 + \
             _i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise  # noqa
        return xt.type_as(x0)

    def q_mean_variance(self, x0, t):
        r"""Distribution of q(x_t | x_0).
        """
        mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0
        var = _i(1.0 - self.alphas_cumprod, t, x0)
        log_var = _i(self.log_one_minus_alphas_cumprod, t, x0)
        return mu, var, log_var

    def q_posterior_mean_variance(self, x0, xt, t):
        r"""Distribution of q(x_{t-1} | x_t, x_0).
        """
        mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(
            self.posterior_mean_coef2, t, xt) * xt
        var = _i(self.posterior_variance, t, xt)
        log_var = _i(self.posterior_log_variance_clipped, t, xt)
        return mu, var, log_var

    @torch.no_grad()
    def p_sample(self,
                 xt,
                 t,
                 model,
                 model_kwargs={},
                 clamp=None,
                 percentile=None,
                 condition_fn=None,
                 guide_scale=None):
        r"""Sample from p(x_{t-1} | x_t).
            - condition_fn: for classifier-based guidance (guided-diffusion).
            - guide_scale: for classifier-free guidance (glide/dalle-2).
        """
        dtype = xt.dtype

        # predict distribution of p(x_{t-1} | x_t)
        mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
                                                    clamp, percentile,
                                                    guide_scale)

        # random sample (with optional conditional function)
        noise = torch.randn_like(xt)
        t_mask = t.ne(0).float().view(
            -1,
            *((1, ) *  # noqa
              (xt.ndim - 1)))
        if condition_fn is not None:
            grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
            mu = mu.float() + var * grad.float()
        xt_1 = mu + t_mask * torch.exp(0.5 * log_var) * noise
        return xt_1.type(dtype), x0.type(dtype)

    @torch.no_grad()
    def p_sample_loop(self,
                      noise,
                      model,
                      model_kwargs={},
                      clamp=None,
                      percentile=None,
                      condition_fn=None,
                      guide_scale=None):
        r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1).
        """
        # prepare input
        b = noise.size(0)
        xt = noise

        # diffusion process
        for step in torch.arange(self.num_timesteps).flip(0):
            t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
            xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp,
                                  percentile, condition_fn, guide_scale)
        return xt

    def p_mean_variance(self,
                        xt,
                        t,
                        model,
                        model_kwargs={},
                        clamp=None,
                        percentile=None,
                        guide_scale=None):
        r"""Distribution of p(x_{t-1} | x_t).
        """
        # predict distribution
        if guide_scale is None:
            out = model(xt, t=self._scale_timesteps(t), **model_kwargs)
        else:
            # classifier-free guidance
            # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs)
            assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
            y_out = model(xt, t=self._scale_timesteps(t), **model_kwargs[0])
            if guide_scale != 1.0:
                u_out = model(
                    xt, t=self._scale_timesteps(t), **model_kwargs[1])
                dim = y_out.size(1) if self.var_type.startswith(
                    'fixed') else y_out.size(1) // 2
                out = torch.cat(
                    [
                        u_out[:, :dim] + guide_scale *  # noqa
                        (y_out[:, :dim] - u_out[:, :dim]),
                        y_out[:, dim:]
                    ],
                    dim=1)  # noqa
            else:
                out = y_out

        # compute variance
        if self.var_type == 'learned':
            out, log_var = out.chunk(2, dim=1)
            var = torch.exp(log_var)
        elif self.var_type == 'learned_range':
            out, fraction = out.chunk(2, dim=1)
            min_log_var = _i(self.posterior_log_variance_clipped, t, xt)
            max_log_var = _i(torch.log(self.betas), t, xt)
            fraction = (fraction + 1) / 2.0
            log_var = fraction * max_log_var + (1 - fraction) * min_log_var
            var = torch.exp(log_var)
        elif self.var_type == 'fixed_large':
            var = _i(
                torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t,
                xt)
            log_var = torch.log(var)
        elif self.var_type == 'fixed_small':
            var = _i(self.posterior_variance, t, xt)
            log_var = _i(self.posterior_log_variance_clipped, t, xt)

        # compute mean and x0
        if self.mean_type == 'x_{t-1}':
            mu = out  # x_{t-1}
            x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - \
                 _i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, xt) * xt  # noqa
        elif self.mean_type == 'x0':
            x0 = out
        elif self.mean_type == 'eps':
            x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
                 _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out  # noqa

        # restrict the range of x0
        if percentile is not None:
            assert percentile > 0 and percentile <= 1  # e.g., 0.995
            s = torch.quantile(
                x0.flatten(1).abs(), percentile,
                dim=1).clamp_(1.0).view(-1, 1, 1, 1, 1)
            # s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1).clamp_(1.0).view(-1, 1, 1, 1) # old
            x0 = torch.min(s, torch.max(-s, x0)) / s
        elif clamp is not None:
            x0 = x0.clamp(-clamp, clamp)

        # recompute mu using the restricted x0
        mu, _, _ = self.q_posterior_mean_variance(x0, xt, t)
        return mu, var, log_var, x0

    @torch.no_grad()
    def ddim_sample(self,
                    xt,
                    t,
                    t_prev,
                    model,
                    model_kwargs={},
                    clamp=None,
                    percentile=None,
                    condition_fn=None,
                    guide_scale=None,
                    ddim_timesteps=20,
                    eta=0.0):
        r"""Sample from p(x_{t-1} | x_t) using DDIM.
            - condition_fn: for classifier-based guidance (guided-diffusion).
            - guide_scale: for classifier-free guidance (glide/dalle-2).
            - t_prev: per-sample index of the step being jumped to.
            - eta: scales the injected noise ``sigmas``; with ``eta=0.0`` the
              step is deterministic.
            - ddim_timesteps: not used inside this method.

        Returns ``(x_{t_prev}, x0)`` cast to the dtype of ``xt``.
        """
        dtype = xt.dtype

        # predict distribution of p(x_{t-1} | x_t)
        _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
                                           percentile, guide_scale)
        if condition_fn is not None:
            # x0 -> eps
            alpha = _i(self.alphas_cumprod, t, xt)
            eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
                  _i(self.sqrt_recipm1_alphas_cumprod, t, xt)  # noqa
            # shift eps against the guidance signal, scaled by sqrt(1 - alpha_bar_t)
            eps = eps - (1 - alpha).sqrt() * condition_fn(
                xt, self._scale_timesteps(t), **model_kwargs)

            # eps -> x0
            x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
                 _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps  # noqa

        # derive variables
        eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
              _i(self.sqrt_recipm1_alphas_cumprod, t, xt)  # noqa
        alphas = _i(self.alphas_cumprod, t, xt)
        alphas_prev = _i(self.alphas_cumprod, t_prev, xt)
        # per-sample noise scale of the DDIM update; zero when eta == 0
        sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) *  # noqa
                                  (1 - alphas / alphas_prev))

        # random sample
        noise = torch.randn_like(xt)
        # component pointing along eps, leaving room for the sigmas noise
        direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps
        # noise is suppressed for samples whose current step is 0
        t_mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
        xt_1 = torch.sqrt(
            alphas_prev) * x0 + direction + t_mask * sigmas * noise
        return xt_1.type(dtype), x0.type(dtype)

    @torch.no_grad()
    def ddim_sample_loop(self,
                         noise,
                         model,
                         model_kwargs={},
                         clamp=None,
                         percentile=None,
                         condition_fn=None,
                         guide_scale=None,
                         ddim_timesteps=20,
                         eta=0.0):
        # prepare input
        b = noise.size(0)
        xt = noise

        # diffusion process
        steps = (1 + torch.arange(0, self.num_timesteps,
                                  self.num_timesteps // ddim_timesteps)).clamp(
                                      0, self.num_timesteps - 1).flip(0)
        for i, step in enumerate(steps):
            t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
            t_prev = torch.full((b, ),
                                steps[i + 1] if i < len(steps) - 1 else 0,
                                dtype=torch.long,
                                device=xt.device)
            xt, _ = self.ddim_sample(xt, t, t_prev, model, model_kwargs, clamp,
                                     percentile, condition_fn, guide_scale,
                                     ddim_timesteps, eta)
        return xt

    @torch.no_grad()
    def ddim_reverse_sample(self,
                            xt,
                            t,
                            t_next,
                            model,
                            model_kwargs={},
                            clamp=None,
                            percentile=None,
                            guide_scale=None,
                            ddim_timesteps=20):
        r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic).
        """
        dtype = xt.dtype

        # predict distribution of p(x_{t-1} | x_t)
        _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp,
                                           percentile, guide_scale)

        # derive variables
        eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
              _i(self.sqrt_recipm1_alphas_cumprod, t, xt)  # noqa
        alphas_next = _i(
            torch.cat(
                [self.alphas_cumprod,
                 self.alphas_cumprod.new_zeros([1])]), t_next, xt)

        # reverse sample
        mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps
        return mu.type(dtype), x0.type(dtype)

    @torch.no_grad()
    def ddim_reverse_sample_loop(self,
                                 x0,
                                 model,
                                 model_kwargs={},
                                 clamp=None,
                                 percentile=None,
                                 guide_scale=None,
                                 ddim_timesteps=20):
        # prepare input
        b = x0.size(0)
        xt = x0

        # reconstruction steps
        steps = (1 + torch.arange(0, self.num_timesteps,
                                  self.num_timesteps // ddim_timesteps)).clamp(
                                      0, self.num_timesteps - 1)
        for i, step in enumerate(steps):
            t = torch.full((b, ),
                           steps[i - 1] if i > 0 else 0,
                           dtype=torch.long,
                           device=xt.device)
            t_next = torch.full((b, ),
                                step,
                                dtype=torch.long,
                                device=xt.device)
            xt, _ = self.ddim_reverse_sample(xt, t, t_next, model,
                                             model_kwargs, clamp, percentile,
                                             guide_scale, ddim_timesteps)
        return xt

    @torch.no_grad()
    def plms_sample(self,
                    xt,
                    t,
                    t_prev,
                    model,
                    model_kwargs={},
                    clamp=None,
                    percentile=None,
                    condition_fn=None,
                    guide_scale=None,
                    plms_timesteps=20):
        r"""Sample from p(x_{t-1} | x_t) using PLMS.
            - condition_fn: for classifier-based guidance (guided-diffusion).
            - guide_scale: for classifier-free guidance (glide/dalle-2).
        """

        # function for compute eps
        def compute_eps(xt, t):
            dtype = xt.dtype

            # predict distribution of p(x_{t-1} | x_t)
            _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
                                               clamp, percentile, guide_scale)

            # condition
            if condition_fn is not None:
                # x0 -> eps
                alpha = _i(self.alphas_cumprod, t, xt)
                eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
                      _i(self.sqrt_recipm1_alphas_cumprod, t, xt)  # noqa
                eps = eps - (1 - alpha).sqrt() * condition_fn(
                    xt, self._scale_timesteps(t), **model_kwargs)

                # eps -> x0
                x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
                     _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps  # noqa

            # derive eps
            eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
                  _i(self.sqrt_recipm1_alphas_cumprod, t, xt)  # noqa
            return eps.type(dtype)

        # function for compute x_0 and x_{t-1}
        def compute_x0(eps, t):
            dtype = eps.dtype

            # eps -> x0
            x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
                 _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps  # noqa

            # deterministic sample
            alphas_prev = _i(self.alphas_cumprod, t_prev, xt)
            direction = torch.sqrt(1 - alphas_prev) * eps
            xt_1 = torch.sqrt(alphas_prev) * x0 + direction
            return xt_1.type(dtype), x0.type(dtype)

        # PLMS sample
        eps = compute_eps(xt, t)
        if len(eps_cache) == 0:
            # 2nd order pseudo improved Euler
            xt_1, x0 = compute_x0(eps, t)
            eps_next = compute_eps(xt_1, t_prev)
            eps_prime = (eps + eps_next) / 2.0
        elif len(eps_cache) == 1:
            # 2nd order pseudo linear multistep (Adams-Bashforth)
            eps_prime = (3 * eps - eps_cache[-1]) / 2.0
        elif len(eps_cache) == 2:
            # 3rd order pseudo linear multistep (Adams-Bashforth)
            eps_prime = (23 * eps - 16 * eps_cache[-1]
                         + 5 * eps_cache[-2]) / 12.0
        elif len(eps_cache) >= 3:
            # 4nd order pseudo linear multistep (Adams-Bashforth)
            eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2]
                         - 9 * eps_cache[-3]) / 24.0
        xt_1, x0 = compute_x0(eps_prime, t)
        return xt_1, x0, eps

    @torch.no_grad()
    def plms_sample_loop(self,
                         noise,
                         model,
                         model_kwargs={},
                         clamp=None,
                         percentile=None,
                         condition_fn=None,
                         guide_scale=None,
                         plms_timesteps=20):
        # prepare input
        b = noise.size(0)
        xt = noise

        # diffusion process
        steps = (1 + torch.arange(0, self.num_timesteps,
                                  self.num_timesteps // plms_timesteps)).clamp(
                                      0, self.num_timesteps - 1).flip(0)
        eps_cache = []
        for i, step in enumerate(steps):
            # PLMS sampling step
            t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
            t_prev = torch.full((b, ),
                                steps[i + 1] if i < len(steps) - 1 else 0,
                                dtype=torch.long,
                                device=xt.device)
            xt, _, eps = self.plms_sample(xt, t, t_prev, model, model_kwargs,
                                          clamp, percentile, condition_fn,
                                          guide_scale, plms_timesteps,
                                          eps_cache)

            # update eps cache
            eps_cache.append(eps)
            if len(eps_cache) >= 4:
                eps_cache.pop(0)
        return xt

    @torch.no_grad()
    def dpm_solver_sample_loop(self,
                               noise,
                               model,
                               model_kwargs={},
                               order=2,
                               skip_type='logSNR',
                               method='multistep',
                               clamp=None,
                               percentile=None,
                               condition_fn=None,
                               guide_scale=None,
                               dpm_solver_timesteps=20,
                               algorithm_type='dpmsolver++',
                               t_start=None,
                               t_end=None,
                               lower_order_final=True,
                               denoise_to_zero=False,
                               solver_type='dpmsolver'):
        r"""Sample using DPM-Solver-based method.
            - condition_fn: for classifier-based guidance (guided-diffusion).
            - guide_scale: for classifier-free guidance (glide/dalle-2).

            Please check all the parameters in `dpm_solver.sample` before using.
        """
        assert self.mean_type in ('eps', 'x0')
        assert percentile in (None, 0.995)
        assert clamp is None or percentile is None
        noise_schedule = NoiseScheduleVP(
            schedule='discrete', betas=self.betas.float())
        model_fn = model_wrapper_guided_diffusion(
            model=model,
            noise_schedule=noise_schedule,
            var_type=self.var_type,
            mean_type=self.mean_type,
            model_kwargs=model_kwargs,
            rescale_timesteps=self.rescale_timesteps,
            num_timesteps=self.num_timesteps,
            guide_scale=guide_scale,
            condition_fn=condition_fn)
        dpm_solver = DPM_Solver(
            model_fn=model_fn,
            noise_schedule=noise_schedule,
            algorithm_type=algorithm_type,
            percentile=percentile,
            clamp=clamp)
        xt = dpm_solver.sample(
            noise,
            steps=dpm_solver_timesteps,
            order=order,
            skip_type=skip_type,
            method=method,
            solver_type=solver_type,
            t_start=t_start,
            t_end=t_end,
            lower_order_final=lower_order_final,
            denoise_to_zero=denoise_to_zero)
        return xt

    @torch.no_grad()
    def inpaint_p_sample(self,
                         xt,
                         t,
                         y,
                         mask,
                         model,
                         model_kwargs={},
                         clamp=None,
                         percentile=None,
                         guide_scale=None):
        r"""DDPM sampling step for inpainting.
        """
        dtype = xt.dtype

        # predict distribution of p(x_{t-1} | x_t), conditioned on y and mask
        xt = self.q_sample(y, t) * mask + xt * (1 - mask)
        mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
                                                    clamp, percentile,
                                                    guide_scale)

        # random sample
        t_mask = t.ne(0).float().view(
            -1,
            *((1, ) *  # noqa
              (xt.ndim - 1)))
        xt_1 = mu + t_mask * torch.exp(0.5 * log_var) * torch.randn_like(xt)
        return xt_1.type(dtype), x0.type(dtype)

    @torch.no_grad()
    def inpaint_p_sample_loop(self,
                              noise,
                              y,
                              mask,
                              model,
                              model_kwargs={},
                              clamp=None,
                              percentile=None,
                              guide_scale=None):
        r"""DDPM sampling loop for inpainting.
        """
        # prepare input
        b = noise.size(0)
        xt = noise

        # diffusion process
        for step in torch.arange(self.num_timesteps).flip(0):
            t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
            xt, _ = self.inpaint_p_sample(xt, t, y, mask, model, model_kwargs,
                                          clamp, percentile, guide_scale)
        return xt

    @torch.no_grad()
    def inpaint_mcg_p_sample(self,
                             xt,
                             t,
                             y,
                             mask,
                             model,
                             model_kwargs={},
                             clamp=None,
                             percentile=None,
                             guide_scale=None,
                             mcg_scale=1.0):
        r"""DDPM sampling step for inpainting, with Manifold Constrained Gradient (MCG) correction.
        """
        dtype = xt.dtype

        # predict distribution of p(x_{t-1} | x_t), conditioned on y and mask
        with torch.enable_grad():
            xt.requires_grad_(True)
            mu, var, log_var, x0 = self.p_mean_variance(
                xt, t, model, model_kwargs, clamp, percentile, guide_scale)
            loss = (y * mask - x0 * mask).square().mean()
            grad = torch.autograd.grad(loss, xt)[0]

        # random sample
        t_mask = t.ne(0).float().view(
            -1,
            *((1, ) *  # noqa
              (xt.ndim - 1)))
        xt_1 = mu + t_mask * torch.exp(0.5 * log_var) * torch.randn_like(xt)
        xt_1 = xt_1 - mcg_scale * grad

        # merge foreground and background
        xt_1 = self.q_sample(y, t) * mask + xt_1 * (1 - mask)
        return xt_1.type(dtype), x0.type(dtype)

    @torch.no_grad()
    def inpaint_mcg_p_sample_loop(self,
                                  noise,
                                  y,
                                  mask,
                                  model,
                                  model_kwargs={},
                                  clamp=None,
                                  percentile=None,
                                  guide_scale=None,
                                  mcg_scale=1.0):
        r"""DDPM sampling loop for inpainting, with Manifold Constrained Gradient (MCG) correction.
        """
        # prepare input
        b = noise.size(0)
        xt = noise

        # diffusion process
        for step in torch.arange(self.num_timesteps).flip(0):
            t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
            xt, _ = self.inpaint_mcg_p_sample(xt, t, y, mask, model,
                                              model_kwargs, clamp, percentile,
                                              guide_scale, mcg_scale)
        return xt

    def loss(self,
             x0,
             t,
             model,
             model_kwargs={},
             noise=None,
             input_x0=None,
             reduction='mean'):
        assert reduction in ['mean', 'none']
        noise = torch.randn_like(x0) if noise is None else noise
        input_x0 = x0 if input_x0 is None else input_x0
        xt = self.q_sample(input_x0, t, noise=noise)

        # compute loss
        if self.loss_type in ['kl', 'rescaled_kl']:
            loss, _ = self.variational_lower_bound(x0, xt, t, model,
                                                   model_kwargs)
            if self.loss_type == 'rescaled_kl':
                loss = loss * self.num_timesteps
        elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']:
            out = model(xt, t=self._scale_timesteps(t), **model_kwargs)

            # VLB for variation
            loss_vlb = 0.0
            if self.var_type in ['learned', 'learned_range']:
                out, var = out.chunk(2, dim=1)
                frozen = torch.cat([
                    out.detach(), var
                ], dim=1)  # learn var without affecting the prediction of mean
                loss_vlb, _ = self.variational_lower_bound(
                    x0,
                    xt,
                    t,
                    model=lambda *args, **kwargs: frozen,
                    reduction=reduction)
                if self.loss_type.startswith('rescaled_'):
                    loss_vlb = loss_vlb * self.num_timesteps / 1000.0

            # MSE/L1 for x0/eps
            target = {
                'eps': noise,
                'x0': x0,
                'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]
            }[self.mean_type]
            loss = (
                out
                - target).pow(1 if self.loss_type.endswith('l1') else 2).abs()
            if reduction == 'mean':
                loss = loss.flatten(1).mean(dim=1)

            # total loss
            loss = loss + loss_vlb
        return loss

    def variational_lower_bound(self,
                                x0,
                                xt,
                                t,
                                model,
                                model_kwargs={},
                                clamp=None,
                                percentile=None,
                                reduction='mean'):
        assert reduction in ['mean', 'none']

        # compute groundtruth and predicted distributions
        mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t)
        mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs,
                                                    clamp, percentile)

        # compute KL loss
        kl = kl_divergence(mu1, log_var1, mu2, log_var2) / math.log(2.0)
        if reduction == 'mean':
            kl = kl.flatten(1).mean(dim=1)

        # compute discretized NLL loss (for p(x0 | x1) only)
        nll = -discretized_gaussian_log_likelihood(
            x0, mean=mu2, log_scale=0.5 * log_var2) / math.log(2.0)
        if reduction == 'mean':
            nll = nll.flatten(1).mean(dim=1)

        # NLL for p(x0 | x1) and KL otherwise
        t = t.view(-1, *(1, ) * (nll.ndim - 1))
        vlb = torch.where(t == 0, nll, kl)
        return vlb, x0

    @torch.no_grad()
    def variational_lower_bound_loop(self,
                                     x0,
                                     model,
                                     model_kwargs={},
                                     clamp=None,
                                     percentile=None):
        r"""Compute the entire variational lower bound, measured in bits-per-dim.
        """
        # prepare input and output
        b = x0.size(0)
        metrics = {'vlb': [], 'mse': [], 'x0_mse': []}

        # loop
        for step in torch.arange(self.num_timesteps).flip(0):
            # compute VLB
            t = torch.full((b, ), step, dtype=torch.long, device=x0.device)
            noise = torch.randn_like(x0)
            xt = self.q_sample(x0, t, noise)
            vlb, pred_x0 = self.variational_lower_bound(
                x0, xt, t, model, model_kwargs, clamp, percentile)

            # predict eps from x0
            eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
                  _i(self.sqrt_recipm1_alphas_cumprod, t, xt)  # noqa

            # collect metrics
            metrics['vlb'].append(vlb)
            metrics['x0_mse'].append(
                (pred_x0 - x0).square().flatten(1).mean(dim=1))
            metrics['mse'].append(
                (eps - noise).square().flatten(1).mean(dim=1))
        metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()}

        # compute the prior KL term for VLB, measured in bits-per-dim
        mu, _, log_var = self.q_mean_variance(x0, t)
        kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu),
                                 torch.zeros_like(log_var))
        kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0)

        # update metrics
        metrics['prior_bits_per_dim'] = kl_prior
        metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior
        return metrics

    def _scale_timesteps(self, t):
        if self.rescale_timesteps:
            return t.float() * 1000.0 / self.num_timesteps
        return t
