Unable to reconstruct images using a DDPM model


I trained a DDPM (diffusion) model and saved checkpoints. To evaluate it, I loaded a checkpoint and fed it images from my test set. The idea is to add noise to a test image, denoise it back, and then compute the MSE between the input image and the reconstructed (denoised) image. However, as you can see in the image below, when I feed one image through the code below, I get a different image after denoising. Can someone point out what I am doing wrong?

[Image: the input test image next to the intermediate images captured during denoising]
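For reference, the noising step described above is the DDPM forward process, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, where alpha_bar_t is the cumulative product of (1 - beta_s) up to t. The sketch below prints how much of the clean image survives at a few timesteps. It assumes the linear schedule from 1e-4 to 0.02 over 1000 steps used in the DDPM paper; the actual `scheduler.betas` may differ. At large t almost none of x_0 remains, so denoising from there produces a new sample rather than a reconstruction.

```python
import numpy as np

# Assumed schedule: linear betas from 1e-4 to 0.02 over 1000 steps.
# Substitute the scheduler's own betas if they differ.
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas_cumprod = np.cumprod(1.0 - betas)

# x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
signal_coef = np.sqrt(alphas_cumprod)       # weight on the clean image
noise_coef = np.sqrt(1.0 - alphas_cumprod)  # weight on the Gaussian noise

for t in (0, 100, 250, 500, 999):
    print(f"t={t:4d}  signal={signal_coef[t]:.4f}  noise={noise_coef[t]:.4f}")
```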

```python
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm

# `model`, `scheduler` and `test_loader` are assumed to be defined elsewhere
# (the trained DDPM checkpoint, its noise scheduler, and the test DataLoader).
model.eval()

# Select one image from the test dataset
with torch.no_grad():
    x_real, _ = next(iter(test_loader))  # Get the first batch of images
    x_real = x_real.cuda()
    x = x_real[0].unsqueeze(0)  # Keep only the first image, shape (1, C, H, W)

    # Add noise at a randomly chosen timestep t
    z = torch.randn_like(x)
    t = torch.randint(0, len(scheduler.betas), size=(len(x),)).cuda()
    x_noisy = scheduler.add_noise(x, z, t)

    # Steps at which to capture intermediate images
    capture_steps = np.linspace(0, len(scheduler.betas) - 1, 4, dtype=int)

    # Store the original image first (as numpy, for visualization)
    images = [x.cpu().squeeze().numpy()]

    # Denoise, running the reverse loop from the last timestep down to 0
    for step in tqdm(range(len(scheduler.betas) - 1, -1, -1)):
        if step in capture_steps:
            images.append(x_noisy.cpu().squeeze().numpy())

        t_tensor = torch.full((len(x_noisy),), step, dtype=torch.long, device=x_noisy.device)
        z_pred = model(x_noisy, t_tensor)['sample']  # predicted noise
        x_noisy = scheduler.step(z_pred, step, x_noisy)['prev_sample']

# Compute the reconstruction error between the denoised output and the input
reconstruction_error = F.mse_loss(x_noisy, x)
print(f"Reconstruction Error: {reconstruction_error.item()}")

# Visualize the images in a grid
fig, axs = plt.subplots(1, len(images), figsize=(15, 5))
titles = ['Original Image'] + [f'Step: {step}' for step in capture_steps]
for ax, image, title in zip(axs, images, titles):
    ax.imshow(image, cmap='gray')
    ax.set_title(title)
    ax.axis('off')  # Hide the axes

plt.show()
```
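A possible cause worth checking: `t` is drawn at random, but the loop always starts the reverse process at `len(scheduler.betas) - 1`, so an image noised only to a small `t` is treated as if it were pure noise and the model generates a new sample instead of recovering the input. The NumPy sketch below illustrates the alternative under stated assumptions: noise the image to a fixed `t_start`, then run the DDPM reverse update only from `t_start` down to 0. It assumes the linear 1e-4 to 0.02 schedule over 1000 steps. `predict_noise` is a hypothetical stand-in for the trained network's noise prediction (the `model(x_noisy, t_tensor)['sample']` call in the question), and `ddpm_step` is a plain implementation of the standard DDPM ancestral step, not the library's `scheduler.step`, so it omits extras such as clipping of the predicted clean image that a library scheduler may apply.

```python
import numpy as np

# Assumed schedule: linear betas, 1000 steps (as in the DDPM paper).
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)


def add_noise(x0, eps, t):
    """Forward process: sample x_t from q(x_t | x_0)."""
    ab = alphas_cumprod[t]
    return np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * eps


def ddpm_step(eps_hat, t, x_t, rng):
    """One ancestral DDPM step x_t -> x_{t-1} from the predicted noise."""
    ab_t = alphas_cumprod[t]
    ab_prev = alphas_cumprod[t - 1] if t > 0 else 1.0
    # Clean-image estimate implied by the noise prediction
    x0_hat = (x_t - np.sqrt(1.0 - ab_t) * eps_hat) / np.sqrt(ab_t)
    # Mean of the posterior q(x_{t-1} | x_t, x0_hat)
    coef_x0 = np.sqrt(ab_prev) * betas[t] / (1.0 - ab_t)
    coef_xt = np.sqrt(alphas[t]) * (1.0 - ab_prev) / (1.0 - ab_t)
    mean = coef_x0 * x0_hat + coef_xt * x_t
    if t == 0:
        return mean  # no noise is added on the final step
    var = (1.0 - ab_prev) / (1.0 - ab_t) * betas[t]
    return mean + np.sqrt(var) * rng.standard_normal(x_t.shape)


def reconstruct(x0, predict_noise, t_start, seed=0):
    """Noise x0 to t_start, then denoise from t_start (not T - 1) down to 0."""
    rng = np.random.default_rng(seed)
    x_t = add_noise(x0, rng.standard_normal(x0.shape), t_start)
    for t in range(t_start, -1, -1):
        x_t = ddpm_step(predict_noise(x_t, t), t, x_t, rng)
    return x_t
```

With a trained model, `predict_noise` would wrap the network call, and comparing the MSE for a few values of `t_start` is a more informative test than a single random `t`. Because every reverse step except the last adds fresh Gaussian noise, the output is only expected to approximate the input, typically more loosely as `t_start` grows.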