Description
Hi,
I want to use Clay for my research, so I am experimenting with the model to see what its reconstructions look like.
Since there is no sample data for the reconstruction notebook, I copied the reconstruction code into the wall-to-wall example.
My "Run the Model" section now looks like this:
```python
import matplotlib.pyplot as plt
import numpy as np
import torch
from einops import rearrange

# `model`, `datacube`, `pixels`, `mean` and `std` come from the wall-to-wall example.
with torch.no_grad():
    unmsk_patch, unmsk_idx, msk_idx, msk_matrix = model.model.encoder(datacube)

# The first embedding is the class token, which is the
# overall single embedding. We extract that for PCA below.
embeddings = unmsk_patch[:, 0, :].cpu().numpy()

#! My changes here:
# Decode the embeddings back to pixel space
with torch.no_grad():
    reconstructed, _ = model.model.decoder(
        unmsk_patch,
        unmsk_idx,
        msk_idx,
        msk_matrix,
        datacube["time"],
        datacube["latlon"],
        datacube["gsd"],
        datacube["waves"],
    )


def denormalize_images(normalized_images, means, stds):
    # Undo the per-band normalization, broadcasting over (batch, band, height, width)
    means = np.array(means).reshape(1, -1, 1, 1)
    stds = np.array(stds).reshape(1, -1, 1, 1)
    return normalized_images * stds + means


# reconstructed: (batch_size, num_patches, pixel_values_per_patch)
# -> (batch_size, bands, height, width) from a 32 x 32 grid of 8 x 8 patches
reconstructed = rearrange(
    reconstructed,
    "b (h w) (c p1 p2) -> b c (h p1) (w p2)",
    c=pixels.shape[1],
    h=32,
    w=32,
    p1=8,
    p2=8,
)

# Denormalize both the reconstruction and the input
denormalized_images = denormalize_images(reconstructed.cpu().numpy(), means=mean, stds=std)
denormalized_inputs = denormalize_images(pixels.cpu().numpy(), means=mean, stds=std)

# Show sample 4 with bands [2, 1, 0] as RGB, scaled by 1/2000 for display
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].imshow(denormalized_inputs[4, [2, 1, 0], :, :].transpose(1, 2, 0) / 2000)
axs[0].set_title('Original Image')
axs[0].axis('off')
axs[1].imshow(denormalized_images[4, [2, 1, 0], :, :].transpose(1, 2, 0) / 2000)
axs[1].set_title('Reconstructed Image')
axs[1].axis('off')
plt.show()
```

However, my reconstructed images look very wrong, even though the rough shapes seem to be detected. You can even make out every 8x8 patch.

This is even though the embedding analysis looks exactly like the one in the notebook provided in the documentation.
Have you seen similar problems, or can you point me to an error or some debugging steps? I feel like I have double-checked everything I could think of.
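One possible debugging step, based on an assumption worth checking against the Clay source: standard masked autoencoders compute the reconstruction loss on masked patches only, so the decoder output at unmasked positions is not necessarily a faithful reconstruction. Splitting the per-patch error by the mask shows whether that is happening. The helper below is a minimal NumPy sketch; `patch_error_by_mask` is a hypothetical name, not part of Clay, and the "nonzero means masked" convention is assumed.

```python
import numpy as np


def patch_error_by_mask(target, pred, masked):
    """Hypothetical helper: mean absolute error over masked vs. unmasked patches.

    target, pred: arrays of shape (batch, num_patches, values_per_patch)
    masked: array of shape (batch, num_patches); nonzero marks a masked patch
            (assumed convention; verify it against the model's mask output)
    Returns (masked_error, unmasked_error); a group with no patches gives NaN.
    """
    per_patch = np.abs(pred - target).mean(axis=-1)  # (batch, num_patches)
    masked = np.asarray(masked).astype(bool)

    def group_mean(selection):
        return float(per_patch[selection].mean()) if selection.any() else float("nan")

    return group_mean(masked), group_mean(~masked)


# Toy example: 1 image, 4 patches of 6 values each; the first two are "masked".
rng = np.random.default_rng(0)
target = rng.normal(size=(1, 4, 6))
masked = np.array([[1, 1, 0, 0]])
pred = target.copy()
pred[0, 2:] += 0.5  # pretend only the unmasked patches are reconstructed badly
m_err, u_err = patch_error_by_mask(target, pred, masked)
print(f"masked MAE={m_err:.3f}, unmasked MAE={u_err:.3f}")
```

To apply it to the snippet in the issue, `target` would be the normalized `pixels` patchified with `rearrange(pixels, "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=8, p2=8)`, `pred` the decoder output before it is rearranged, and `masked` the `msk_matrix`, all converted with `.cpu().numpy()`. The comparison is only informative if some patches are actually masked; if the encoder runs with a mask ratio of zero, every patch falls in the unmasked group.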