import torch
import torch.utils.data
import torchvision
import torch.nn.functional as F
import tqdm

from unet import UNet
from scheduler import DDIMScheduler


class MNISTDataset(torchvision.datasets.MNIST):
    """MNIST training split that yields only the image tensor.

    Each image goes through ``Resize(image_size)`` followed by ``ToTensor()``,
    so items come out as float tensors with values in [0, 1]. The class label
    returned by the parent dataset is discarded.
    """

    def __init__(self, root, image_size):
        # Resize on the PIL image first, then convert to a float tensor.
        preprocessing = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(image_size),
                torchvision.transforms.ToTensor(),
            ]
        )
        # Always the training split; fetched into ``root`` if not present.
        super().__init__(
            root,
            train=True,
            download=True,
            transform=preprocessing,
        )

    def __getitem__(self, item):
        # The parent yields ``(image, label)``; keep only the image.
        image, _label = super().__getitem__(item)
        return image


if __name__ == '__main__':
    # Train a diffusion model on MNIST. After every epoch, save a checkpoint
    # and write a grid of generated samples.
    root_dir = "MNIST"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Number of channels in the image: 1 for grayscale MNIST (3 for RGB).
    image_channels: int = 1
    epochs: int = 5
    image_size: int = 32
    # Prediction type: 'velocity' or 'epsilon'
    prediction_type: str = 'epsilon'
    # Scheduler eta in [0, 1]
    # 0: no noise (ddim), 1: full noise (ddpm)
    scheduler_eta: float = 1.0

    # Number of diffusion time steps T
    n_steps: int = 1_000
    # Batch size
    batch_size: int = 64
    # Learning rate
    learning_rate: float = 2e-5
    # Number of samples to generate
    n_samples: int = 16

    # Validate the configuration up front, before downloading data or
    # building the model, rather than on the first training batch.
    if prediction_type not in ('velocity', 'epsilon'):
        raise ValueError(f"prediction type given as {prediction_type} is invalid.")

    # Create model, scheduler, dataloader and optimizer.
    # NOTE(review): unet.py / scheduler.py are not visible from this file, so
    # the constructor keyword names below are assumptions — confirm them.
    # The scheduler is presumably told `prediction_type` at construction,
    # since `sample()` below receives no such argument.
    model = UNet(image_channels=image_channels).to(device)
    noise_scheduler = DDIMScheduler(
        num_train_timesteps=n_steps,
        prediction_type=prediction_type,
    )
    dataset = MNISTDataset(root_dir, image_size)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    print("Training Start:")
    for epoch in range(epochs):
        print("Epoch {}".format(epoch))
        # Sampling at the end of the previous epoch put the model in eval mode.
        model.train()
        # Iterate through the dataset
        for data in tqdm.tqdm(data_loader):
            data = data.to(device)
            optimizer.zero_grad()
            bsz = data.shape[0]  # batch size (the last batch may be smaller)

            # step1: random timestep t in [0, T) for each sample in the batch
            t = torch.randint(0, n_steps, (bsz,), device=device, dtype=torch.long)
            # step2: Gaussian noise ~ N(0, I) with the same shape as the images
            noise = torch.randn_like(data)
            # step3: sample x_t from x_0 using add_noise
            # NOTE(review): assumes add_noise(x_0, noise, t), mirroring the
            # get_velocity(x_0, noise, t) call below — confirm in scheduler.py.
            noisy_data = noise_scheduler.add_noise(data, noise, t)
            # step4: predict noise or velocity from (x_t, t)
            # NOTE(review): assumes the UNet forward signature is model(x, t).
            model_pred = model(noisy_data, t)

            # BONUS: velocity prediction
            if prediction_type == 'velocity':
                target = noise_scheduler.get_velocity(data, noise, t)
            else:  # 'epsilon' (validated before training)
                target = noise

            # F.mse_loss is documented as (input, target).
            loss = F.mse_loss(model_pred, target)
            loss.backward()
            optimizer.step()

        # step5: Save model and sample images to check the training progress
        torch.save(model.state_dict(), 'model_{}_epoch{}.pth'.format(prediction_type, epoch))
        # Generation needs no gradients. eval() disables train-only behaviour
        # (e.g. dropout) if the UNet has any.
        model.eval()
        with torch.no_grad():
            noise_scheduler.sample(
                f'sample_eta_{scheduler_eta}_{prediction_type}_epoch_{epoch}.png',
                model=model,
                eta=scheduler_eta,
                n_samples=n_samples,
                image_channels=image_channels,
                image_size=image_size,
                inference_steps=n_steps,
                device=device,
            )
