import torch
import torch.nn.functional as F
import torch.utils.data
from torch import nn
import torch.utils.data
import torchvision
import tqdm

from unet import UNet


class DDIMScheduler:
    r"""Noise schedule plus DDIM forward (noising) and reverse (sampling) steps.

    Notation follows the DDPM/DDIM papers:

    * ``beta_t`` is the per-step noise variance,
    * ``alpha_t = 1 - beta_t``,
    * ``alpha_bar_t = prod_{s<=t} alpha_s``.

    Timesteps are 0-indexed, i.e. valid values of ``t`` are
    ``0 .. n_steps - 1``. A "previous" timestep below zero denotes the clean
    image, for which ``alpha_bar`` is defined as 1.
    """

    def __init__(
        self,
        n_steps: int,
        device: torch.device,
        beta_start: float = 1e-4,
        beta_end: float = 0.02,
    ):
        """
        Args:
            n_steps (int): maximum number of steps (T)
            device (torch.device): device to place constants on
            beta_start (float): first value of the linear beta schedule
            beta_end (float): last value of the linear beta schedule

        Raises:
            ValueError: if ``n_steps`` is smaller than 1
        """
        super().__init__()

        if n_steps < 1:
            raise ValueError(f"n_steps must be >= 1, got {n_steps}")

        # [\beta_1, \dots, \beta_T], size [n_steps].
        # Linear schedule; the defaults are the values commonly used with
        # T = 1000 in DDPM.
        self.beta = torch.linspace(beta_start, beta_end, n_steps, device=device)

        # \alpha_t = 1 - \beta_t, size [n_steps]
        self.alpha = 1.0 - self.beta
        # \bar{\alpha_t} = \prod_{s=1}^t \alpha_s, size [n_steps]
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        # T
        self.n_steps = n_steps

    def _alpha_bar_at(self, t, like: torch.Tensor) -> torch.Tensor:
        """Look up alpha_bar for timesteps ``t``, shaped to broadcast with ``like``.

        Negative timesteps map to 1.0 (no noise). The result has shape
        ``[len(t), 1, ..., 1]`` with as many dimensions as ``like`` and lives
        on the device/dtype of ``like``.
        """
        idx = torch.as_tensor(t, device=self.alpha_bar.device).long().reshape(-1)
        # Clamp before indexing so negative values never wrap around.
        values = self.alpha_bar[idx.clamp(min=0)]
        values = torch.where(idx >= 0, values, torch.ones_like(values))
        values = values.reshape(-1, *([1] * (like.dim() - 1)))
        return values.to(device=like.device, dtype=like.dtype)

    def _stride(self, inference_steps: int) -> int:
        """Return the timestep stride used when sampling with ``inference_steps`` steps."""
        if not 1 <= inference_steps <= self.n_steps:
            raise ValueError(
                f"inference_steps must be in [1, {self.n_steps}], got {inference_steps}"
            )
        return self.n_steps // inference_steps

    def add_noise(
        self,
        x0: torch.FloatTensor,
        t: torch.IntTensor,
        noise: torch.FloatTensor,
    ):
        """Add noise to the original image $x_0$, with timestep t

        Computes ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``.

        Args:
            x0 (torch.Tensor): the original real image, size [batch_size, image_channels, image_size, image_size]
            t (torch.Tensor): timestep, size [batch_size]
            noise (torch.Tensor): noise to be added to the image, size [batch_size, image_channels, image_size, image_size]

        Returns:
            xt (torch.Tensor): noisy image, size [batch_size, image_channels, image_size, image_size]
        """
        alpha_bar_t = self._alpha_bar_at(t, x0)
        return alpha_bar_t.sqrt() * x0 + (1.0 - alpha_bar_t).sqrt() * noise

    ## BONUS: get velocity for velocity prediction
    def get_velocity(
        self,
        sample: torch.FloatTensor,
        noise: torch.FloatTensor,
        t: torch.IntTensor
    ):
        """Compute the velocity target for velocity prediction.

        Computes ``v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample``.

        Args:
            sample (torch.Tensor): the clean image x_0
            noise (torch.Tensor): the noise mixed into x_0, same size as ``sample``
            t (torch.Tensor): timestep, size [batch_size]

        Returns:
            velocity (torch.Tensor): velocity target, same size as ``sample``
        """
        alpha_bar_t = self._alpha_bar_at(t, sample)
        return alpha_bar_t.sqrt() * noise - (1.0 - alpha_bar_t).sqrt() * sample

    def sample(
        self,
        save_path: str,
        model: "UNet",
        eta: float,
        n_samples: int = 16,
        image_channels: int = 1,
        image_size: int = 32,
        inference_steps: int = 1_000,
        device: torch.device = torch.device('cuda'),
    ):
        """Sample images from pure noise and save them as an image grid.

        Args:
            save_path (str): where the image grid is written
            model (UNet): the model to predict the noise or velocity
            eta (float): noise level in [0, 1]; 0 gives deterministic DDIM
            n_samples (int): number of images to generate
            image_channels (int): number of channels per image
            image_size (int): height and width of each image
            inference_steps (int): number of denoising steps, in [1, n_steps]
            device (torch.device): device to run sampling on

        Raises:
            ValueError: if ``inference_steps`` is outside [1, n_steps]
        """
        t_stepsize = self._stride(inference_steps)

        with torch.no_grad():
            # step 1: Create random noise x_T ~ N(0, I)
            x = torch.randn(
                n_samples, image_channels, image_size, image_size, device=device
            )

            print("Sampling...")

            # step 2: Remove noise for inference_steps
            for t_ in tqdm.trange(0, inference_steps):
                # Walk the timesteps from high to low in strides of
                # t_stepsize; the last visited timestep is 0, so the final
                # denoise call steps all the way to the clean image.
                t_value = (inference_steps - 1 - t_) * t_stepsize
                t = torch.full(
                    (n_samples,), t_value, device=device, dtype=torch.long
                )
                # get x using denoise
                x = self.denoise(
                    model,
                    x,
                    t,
                    inference_steps=inference_steps,
                    eta=eta,
                )
            # step 3: Save the image
            torchvision.utils.save_image(x.clamp(0., 1.), save_path, nrow=4)

    def _get_variance(
        self,
        alpha_bar_t: torch.FloatTensor,
        alpha_bar_t_prev: torch.FloatTensor,
    ):
        r"""Compute the variance \tilde{\beta}_t

        ``(1 - alpha_bar_prev) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_prev)``

        Args:
            alpha_bar_t (torch.Tensor): alpha_bar at the current timestep
            alpha_bar_t_prev (torch.Tensor): alpha_bar at the previous timestep

        Returns:
            variance (torch.Tensor): same shape as the (broadcast) inputs
        """
        return (
            (1.0 - alpha_bar_t_prev)
            / (1.0 - alpha_bar_t)
            * (1.0 - alpha_bar_t / alpha_bar_t_prev)
        )

    def denoise(
        self,
        model: "UNet",
        xt: torch.FloatTensor,
        t: torch.IntTensor,
        inference_steps: int = 1_000,
        eta: float = 1.0,
    ):
        """Denoise the noisy image x_t at timestep t, predict x_{t-1}

        The previous timestep is ``t - n_steps // inference_steps``; when it
        falls below zero the step targets the clean image.

        Args:
            model (UNet): the model to predict the noise or velocity
            xt (torch.FloatTensor): the noisy image x_t, size [batch_size, image_channels, image_size, image_size]
            t (torch.IntTensor): timestep, size [batch_size]
            inference_steps (int): number of inference steps
            eta (float): noise level in [0, 1]

        Returns:
            xt_1 (torch.Tensor): denoised image x_{t-1}, size [batch_size, image_channels, image_size, image_size]

        Raises:
            ValueError: if ``inference_steps`` is outside [1, n_steps] or
                ``model.prediction_type`` is not supported
        """

        # step 1: Compute alphas, betas
        stride = self._stride(inference_steps)
        t_prev = torch.as_tensor(t, device=self.alpha_bar.device).long() - stride
        alpha_bar_t = self._alpha_bar_at(t, xt)
        alpha_bar_t_prev = self._alpha_bar_at(t_prev, xt)
        sqrt_alpha_bar_t = alpha_bar_t.sqrt()
        sqrt_one_minus_alpha_bar_t = (1.0 - alpha_bar_t).sqrt()

        # step 2: Compute "predicted x_0" and predicted noise (epsilon)
        model_pred = model(xt, t)
        if model.prediction_type == "epsilon":
            pred_epsilon = model_pred
            pred_x0 = (xt - sqrt_one_minus_alpha_bar_t * pred_epsilon) / sqrt_alpha_bar_t
        elif model.prediction_type == "velocity":
            # Invert v = sqrt(a) * eps - sqrt(1 - a) * x0 together with
            # x_t = sqrt(a) * x0 + sqrt(1 - a) * eps.
            pred_x0 = sqrt_alpha_bar_t * xt - sqrt_one_minus_alpha_bar_t * model_pred
            pred_epsilon = sqrt_alpha_bar_t * model_pred + sqrt_one_minus_alpha_bar_t * xt
        else:
            raise ValueError(
                f"prediction_type given as {model.prediction_type} must be `epsilon` or `velocity`"
            )

        # step 3: Compute "direction pointing to x_t"
        sigma_t_2 = (eta ** 2) * self._get_variance(alpha_bar_t, alpha_bar_t_prev)
        # The clamp only guards against tiny negative values from rounding.
        pred_sample_direction = (
            (1.0 - alpha_bar_t_prev - sigma_t_2).clamp(min=0.0).sqrt() * pred_epsilon
        )

        # step 4: Compute x_{t-1}
        prev_sample = alpha_bar_t_prev.sqrt() * pred_x0 + pred_sample_direction
        if eta > 0:
            prev_sample = prev_sample + sigma_t_2.sqrt() * torch.randn_like(xt)

        return prev_sample