Model problems when reconstructing in wall-to-wall tutorial #353

@erikscheurer

Hi,
I want to use Clay for my research, so I am experimenting with the model to see what its reconstructions look like.
Since there is no dummy data for the reconstruction notebook, I copied the reconstruction code into the wall-to-wall example.

My "Run the Model" chapter therefore looks like this:

with torch.no_grad():
    unmsk_patch, unmsk_idx, msk_idx, msk_matrix = model.model.encoder(datacube)

# The first embedding is the class token, which is the
# overall single embedding. We extract that for PCA below.
embeddings = unmsk_patch[:, 0, :].cpu().numpy()

#! My changes here:
# decode the embeddings
with torch.no_grad():
    reconstructed, _ = model.model.decoder(
        unmsk_patch,
        unmsk_idx,
        msk_idx,
        msk_matrix,
        datacube["time"],
        datacube["latlon"],
        datacube["gsd"],
        datacube["waves"],
    )

def denormalize_images(normalized_images, means, stds):
    # Undo per-channel normalization for arrays shaped (batch, channel, height, width).
    # Reshape the statistics to (1, channel, 1, 1) so they broadcast per channel.
    means = np.array(means).reshape(1, -1, 1, 1)
    stds = np.array(stds).reshape(1, -1, 1, 1)
    return normalized_images * stds + means

# reconstructed: (batch_size, num_patches, pixel_values_per_patch)
# -> (batch_size, channels, height, width)
reconstructed = rearrange(
    reconstructed,
    "b (h w) (c p1 p2) -> b c (h p1) (w p2)",
    c=pixels.shape[1],
    h=32,
    w=32,
    p1=8,
    p2=8,
)
# denormalize the images
denormalized_images = denormalize_images(reconstructed.cpu().numpy(), means=mean, stds=std)
denormalized_inputs = denormalize_images(pixels.cpu().numpy(), means=mean, stds=std)

fig, axs = plt.subplots(1, 2, figsize=(10, 5))
# Show sample 4 with bands [2, 1, 0] as RGB, divided by 2000 for display.
axs[0].imshow(denormalized_inputs[4, [2, 1, 0], :, :].transpose(1, 2, 0) / 2000)
axs[0].set_title('Original Image')
axs[0].axis('off')
axs[1].imshow(denormalized_images[4, [2, 1, 0], :, :].transpose(1, 2, 0) / 2000)
axs[1].set_title('Reconstructed Image')
axs[1].axis('off')
plt.show()
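
A minimal sketch of one check on the `rearrange` pattern above, using NumPy only. It assumes nothing about Clay's decoder: `patchify` and `unpatchify` are hypothetical helpers, and the small shapes are made up for illustration. It shows that unfolding patch vectors with the wrong within-patch layout, `(p1 p2 c)` instead of `(c p1 p2)`, keeps every patch in the right place but scrambles its contents. Whether the decoder actually emits `(c p1 p2)` is not something this snippet can confirm.

```python
import numpy as np


def patchify(images, patch, channels_first=True):
    """Split (b, c, H, W) images into (b, num_patches, c * patch * patch) vectors."""
    b, c, H, W = images.shape
    h, w = H // patch, W // patch
    x = images.reshape(b, c, h, patch, w, patch)
    if channels_first:
        x = x.transpose(0, 2, 4, 1, 3, 5)  # b h w c p1 p2
    else:
        x = x.transpose(0, 2, 4, 3, 5, 1)  # b h w p1 p2 c
    return x.reshape(b, h * w, c * patch * patch)


def unpatchify(tokens, channels, h, w, patch, channels_first=True):
    """Inverse of patchify, given the same within-patch layout."""
    b = tokens.shape[0]
    if channels_first:
        # Same as einops "b (h w) (c p1 p2) -> b c (h p1) (w p2)".
        x = tokens.reshape(b, h, w, channels, patch, patch)
        x = x.transpose(0, 3, 1, 4, 2, 5)  # b c h p1 w p2
    else:
        # Same as einops "b (h w) (p1 p2 c) -> b c (h p1) (w p2)".
        x = tokens.reshape(b, h, w, patch, patch, channels)
        x = x.transpose(0, 5, 1, 3, 2, 4)  # b c h p1 w p2
    return x.reshape(b, channels, h * patch, w * patch)


# Made-up data: 2 images, 4 channels, 16x16 pixels, 8x8 patches.
rng = np.random.default_rng(0)
images = rng.normal(size=(2, 4, 16, 16))
tokens = patchify(images, patch=8, channels_first=True)

ok = unpatchify(tokens, channels=4, h=2, w=2, patch=8, channels_first=True)
bad = unpatchify(tokens, channels=4, h=2, w=2, patch=8, channels_first=False)

roundtrip_ok = np.array_equal(ok, images)
mismatch_differs = not np.array_equal(bad, images)
print("matching layout recovers the input:", roundtrip_ok)
print("mismatched layout differs from the input:", mismatch_differs)
```

Running the same round trip on `pixels` with `h=32, w=32, p1=8, p2=8` would at least confirm that the pattern and the grid size are consistent with the input shape.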

However, my reconstructed images look very wrong, even though the rough shapes seem to be captured. You can even make out every 8x8 patch:
Image

This is despite the embedding analysis looking exactly like the one in the notebook provided in the documentation.
Have you seen similar problems, or can you point me to a likely error or to some debugging steps? I feel like I have double-checked everything I could think of.
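
To put a number on the visible patch grid, here is a rough sketch of a blockiness score. It is not part of Clay: `blockiness` is a hypothetical helper, and the test images are synthetic. It compares the average jump between neighboring pixels that straddle a patch boundary with the average jump inside patches. A ratio near 1 means no visible grid, and a clearly larger ratio means the patch edges stand out.

```python
import numpy as np


def blockiness(image, patch=8):
    """Ratio of mean pixel jumps across patch boundaries to jumps inside patches.

    `image` has shape (..., H, W); `patch` is the assumed patch size in pixels.
    """
    dx = np.abs(np.diff(image, axis=-1))  # horizontal neighbor differences
    dy = np.abs(np.diff(image, axis=-2))  # vertical neighbor differences
    # Index of the right/lower pixel of each pair; a multiple of `patch`
    # means the pair straddles a patch boundary.
    on_x = np.arange(1, image.shape[-1]) % patch == 0
    on_y = np.arange(1, image.shape[-2]) % patch == 0
    boundary = np.concatenate([dx[..., on_x].ravel(), dy[..., on_y, :].ravel()])
    interior = np.concatenate([dx[..., ~on_x].ravel(), dy[..., ~on_y, :].ravel()])
    return boundary.mean() / (interior.mean() + 1e-12)


# Synthetic example: a smooth ramp, and the same ramp with a random
# offset added to each 8x8 patch to imitate patch-wise artifacts.
yy, xx = np.mgrid[0:32, 0:32]
smooth = (xx + yy).astype(float)[None]  # shape (1, 32, 32)
rng = np.random.default_rng(0)
offsets = rng.normal(scale=10.0, size=(4, 4))
blocky = smooth + np.kron(offsets, np.ones((8, 8)))[None]

smooth_score = blockiness(smooth)
blocky_score = blockiness(blocky)
print("smooth:", smooth_score, "blocky:", blocky_score)
```

Applied to `denormalized_inputs[4]` and `denormalized_images[4]`, a much higher score for the reconstruction than for the input would back up the visual impression in the plot.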
