
procrustes alignment #2691

Open
heth27 opened this issue Aug 13, 2024 · 6 comments · May be fixed by #2723
Labels: enhancement (New feature or request), New metric
Milestone: v1.5.0

Comments

heth27 commented Aug 13, 2024

🚀 Feature

Spatial Procrustes alignment, a similarity test for two data sets.

Motivation

Procrustes alignment is a staple when calculating metrics for 3D human pose estimation, but there seems to be no library that offers this function for PyTorch, so I guess everyone just maintains their own version.

Pitch

SciPy provides a variant:
https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.procrustes.html
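
For reference, a minimal sketch of calling the SciPy version on NumPy arrays. SciPy standardizes both inputs (centered, unit Frobenius norm) and returns the standardized first matrix, the second matrix aligned to it, and the disparity (sum of squared differences). The point set, angle, scale and offset below are made-up illustration data.

```python
import numpy as np
from scipy.spatial import procrustes

rng = np.random.default_rng(0)
data1 = rng.random((10, 3))  # made-up 3D point set (e.g. 10 joints)

# Build data2 as a rotated, scaled and translated copy of data1.
theta = 0.3
rotation = np.array([
    [np.cos(theta), -np.sin(theta), 0.0],
    [np.sin(theta), np.cos(theta), 0.0],
    [0.0, 0.0, 1.0],
])
data2 = 2.5 * data1 @ rotation.T + np.array([1.0, -2.0, 0.5])

# mtx1: standardized data1; mtx2: data2 standardized and aligned to mtx1.
mtx1, mtx2, disparity = procrustes(data1, data2)
print(disparity)  # close to 0, since data2 is a similarity transform of data1
```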

Alternatives

Additional context

Here is the implementation I'm using; I don't know if it is any good.

import torch


def procrustes(pts1: torch.Tensor, pts2: torch.Tensor):
    assert pts1.shape == pts2.shape, f"{pts1.shape} != {pts2.shape}"
    assert pts1.ndim == 2 and pts1.shape[-1] == 3, f"{pts1.shape}"
    # Estimate a Sim(3) transformation (scale s, rotation R, translation)
    # that maps the point cloud pts2 onto pts1
    t1 = pts1.mean(dim=0)
    t2 = pts2.mean(dim=0)
    pts1 = pts1 - t1[None, :]
    pts2 = pts2 - t2[None, :]

    # Normalize each point set by its RMS distance to the centroid
    s1 = pts1.square().sum(dim=-1).mean().sqrt()
    s2 = pts2.square().sum(dim=-1).mean().sqrt()
    pts1 = pts1 / s1
    pts2 = pts2 / s2
    try:
        # SVD of the 3x3 cross-covariance: pts1.T @ pts2 = U @ diag(S) @ Vh
        U, _, Vh = torch.linalg.svd((pts1.T @ pts2).double())
    except RuntimeError:  # torch.linalg.LinAlgError is a subclass of RuntimeError
        print("Procrustes failed: SVD did not converge!")
        # Fall back to translation-only alignment, keeping the same return signature
        return 1.0, torch.eye(3, dtype=pts1.dtype, device=pts1.device), t1, t2
    # Build the rotation matrix. Flip the last singular direction (not a column of R)
    # when needed, so that R is a proper rotation with det(R) = +1
    D = torch.eye(3, dtype=U.dtype, device=U.device)
    D[2, 2] = torch.sign(torch.linalg.det(U @ Vh))
    R = (U @ D @ Vh).to(pts1.dtype)
    # Ratio of the RMS sizes (note: not the least-squares-optimal scale)
    s = s1 / s2

    # Apply as s * (pts2 - t2) @ R.T + t1,
    # or equivalently as s * pts2 @ R.T + t with t = t1 - s * t2 @ R.T
    return s, R, t1, t2
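
The SVD step is the closed-form solution of the orthogonal Procrustes (Kabsch) problem. A short derivation, writing P_1, P_2 (N x 3, one point per row) for the centered, normalized point sets and applying the rotation as P_2 R^T, so that the transformed points match the einsum usage below:

```latex
\begin{aligned}
H &= P_1^\top P_2 = U \Sigma V^\top \\
\lVert P_1 - P_2 R^\top \rVert_F^2 &= \lVert P_1 \rVert_F^2 + \lVert P_2 \rVert_F^2 - 2\,\operatorname{tr}(H R^\top),
  \qquad \operatorname{tr}(H R^\top) = \operatorname{tr}(\Sigma\, V^\top R^\top U) \\
R^\star &= U D V^\top, \qquad D = \operatorname{diag}(1, 1, d), \qquad d = \operatorname{sign}\det(U V^\top) \\
c^\star &= \arg\min_c \lVert P_1 - c\, P_2 R^{\star\top} \rVert_F^2 = \frac{\operatorname{tr}(\Sigma D)}{\lVert P_2 \rVert_F^2}
\end{aligned}
```

Only the trace term depends on R, and over proper rotations it is maximized when V^T R^T U = D, which gives R*. With d = 1 this is simply R = U V^T; placing the sign flip between U and V^T (rather than on a column of R) is what keeps det R* = +1 while remaining optimal. The last line is the least-squares scale; by Cauchy-Schwarz it never exceeds the ratio of the point-set sizes and equals it only for a perfect fit.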

Example usage:

s, R, mean_1, mean_2 = procrustes(coords_3d_true, coords_3d_prediction)
procrustes_aligned = torch.einsum("jd, od -> jo", coords_3d_prediction - mean_2, s * R) + mean_1

The problem with this version is that it does not support batched inputs.
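
A sketch of one way to batch the rotation step: torch.linalg.svd and torch.linalg.det both broadcast over leading batch dimensions, so the per-sample computation vectorizes directly over a (B, N, 3) tensor. The helper name kabsch_batch is hypothetical, and it covers only the rotation (no scale), using the same reflection correction as a proper Kabsch solution.

```python
import torch


def kabsch_batch(pts1: torch.Tensor, pts2: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: per-sample rotations R of shape (B, 3, 3) such that
    (pts2 - mean_2) @ R.T approximates (pts1 - mean_1), for (B, N, 3) inputs."""
    x1 = pts1 - pts1.mean(dim=1, keepdim=True)  # center every sample
    x2 = pts2 - pts2.mean(dim=1, keepdim=True)
    h = x1.transpose(1, 2) @ x2  # batched 3x3 cross-covariance, (B, 3, 3)
    u, _, vh = torch.linalg.svd(h)  # batched SVD: h = u @ diag(s) @ vh
    # Flip the last singular direction where needed so that det(R) = +1.
    d = torch.sign(torch.linalg.det(u @ vh))  # (B,)
    ones = torch.ones_like(d)
    correction = torch.diag_embed(torch.stack([ones, ones, d], dim=-1))  # (B, 3, 3)
    return u @ correction @ vh


# Usage: recover known rotations about the z axis from a random batch.
torch.manual_seed(0)
angles = torch.tensor([0.1, 1.0, -2.0], dtype=torch.float64)
c, s = torch.cos(angles), torch.sin(angles)
zeros, ones = torch.zeros_like(angles), torch.ones_like(angles)
r_true = torch.stack([
    torch.stack([c, -s, zeros], dim=-1),
    torch.stack([s, c, zeros], dim=-1),
    torch.stack([zeros, zeros, ones], dim=-1),
], dim=1)  # (3, 3, 3): one rotation matrix per sample
pts2 = torch.rand(3, 17, 3, dtype=torch.float64)
pts1 = pts2 @ r_true.transpose(1, 2) + 5.0  # rotate each sample, then translate
r_est = kabsch_batch(pts1, pts2)
print(torch.allclose(r_est, r_true, atol=1e-8))
```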

heth27 added the enhancement label Aug 13, 2024

Hi! Thanks for your contribution, great first issue!

Borda (Member) commented Aug 21, 2024

> spatial procrustes alignment, a similarity test for two data sets

This sounds good. Would you be interested in adding it to TM (TorchMetrics)? You could create a draft PR and then we can help you finish it... 👼

SkafteNicki (Member) commented:

Hi @heth27, I took a stab at a batched version of your implementation:

import torch
from scipy.spatial import procrustes as procrustes_scipy


def procrustes_batch(data1, data2):
    if data1.shape != data2.shape:
        raise ValueError("data1 and data2 must have the same shape")
    if data1.ndim == 2:  # treat a single (N, D) pair as a batch of size 1
        data1 = data1[None, :, :]
        data2 = data2[None, :, :]

    # Center and scale to unit Frobenius norm (out of place, so the inputs are not modified)
    data1 = data1 - data1.mean(dim=1, keepdim=True)
    data2 = data2 - data2.mean(dim=1, keepdim=True)
    data1 = data1 / torch.linalg.norm(data1, dim=(1, 2), keepdim=True)
    data2 = data2 / torch.linalg.norm(data2, dim=(1, 2), keepdim=True)

    try:
        # Same decomposition as scipy: SVD of (data2^T data1)^T = data1^T data2
        u, w, vh = torch.linalg.svd(torch.matmul(data2.transpose(1, 2), data1).transpose(1, 2), full_matrices=False)
    except RuntimeError as err:  # torch.linalg.LinAlgError subclasses RuntimeError
        raise ValueError("SVD did not converge") from err
    rotation = torch.matmul(u, vh)
    scale = w.sum(1, keepdim=True)
    data2 = scale[:, None] * torch.matmul(data2, rotation.transpose(1, 2))
    disparity = (data1 - data2).square().sum(dim=[1, 2])
    return disparity


coords_3d_true = torch.rand(2, 10, 3)
coords_3d_prediction = torch.rand(2, 10, 3)

p2 = procrustes_batch(coords_3d_true, coords_3d_prediction)
print(p2)

for i in range(2):
    mtx1, mtx2, disparity = procrustes_scipy(coords_3d_true[i].numpy(), coords_3d_prediction[i].numpy())
    print(disparity)

For random inputs, the results seem to match SciPy.
Are you interested in sending a PR, or do you want me to take over?

SkafteNicki added this to the v1.5.0 milestone Aug 28, 2024

heth27 (Author) commented Aug 28, 2024

Hi @SkafteNicki, thank you, please feel free to create a PR. How do you feel about also returning the rotation matrix or the transformed coordinates? Many human pose estimation tasks use them downstream to compute the Procrustes-aligned mean per joint position error (PA-MPJPE).
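
As a sketch of that downstream use: PA-MPJPE aligns each predicted pose to its ground truth with the best similarity transform (scale, rotation, translation) and then averages the per-joint Euclidean error. The function names, the (B, J, 3) layout, and the use of the least-squares-optimal (Umeyama-style) scale instead of the RMS-ratio scale from the first snippet are assumptions of this example.

```python
import torch


def similarity_align(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: align pred onto target per sample with the
    least-squares similarity transform. Both inputs have shape (B, J, 3)."""
    mu_p = pred.mean(dim=1, keepdim=True)
    mu_t = target.mean(dim=1, keepdim=True)
    xp, xt = pred - mu_p, target - mu_t
    h = xt.transpose(1, 2) @ xp  # (B, 3, 3) cross-covariance
    u, sigma, vh = torch.linalg.svd(h)
    d = torch.sign(torch.linalg.det(u @ vh))  # -1 where a reflection must be undone
    ones = torch.ones_like(d)
    signs = torch.stack([ones, ones, d], dim=-1)  # (B, 3)
    rotation = u @ torch.diag_embed(signs) @ vh  # proper rotations, det = +1
    # Least-squares scale: tr(Sigma D) / ||xp||_F^2
    scale = (sigma * signs).sum(dim=-1) / xp.square().sum(dim=(1, 2))
    return scale[:, None, None] * (xp @ rotation.transpose(1, 2)) + mu_t


def mpjpe(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Mean per-joint position error, averaged over joints and batch."""
    return torch.linalg.norm(pred - target, dim=-1).mean()


def pa_mpjpe(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Procrustes-aligned MPJPE."""
    return mpjpe(similarity_align(pred, target), target)


# Usage: a prediction that is an exact similarity transform of the target.
torch.manual_seed(0)
target = torch.rand(4, 17, 3, dtype=torch.float64)
q, _ = torch.linalg.qr(torch.randn(3, 3, dtype=torch.float64))
q = q * torch.sign(torch.linalg.det(q))  # force det(q) = +1
pred = 0.7 * target @ q.T + 0.3
print(mpjpe(pred, target))  # clearly nonzero: raw coordinates differ
print(pa_mpjpe(pred, target))  # ~0: the similarity transform is factored out
```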

SkafteNicki (Member) commented:

@heth27 I would be fine with that. It may make sense to add an argument such as return_all_stats (or something similar) that controls whether this additional information is returned.
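
A minimal sketch of how such a flag could look on top of the batched function above. The flag name return_all_stats comes from the suggestion; the function name procrustes_disparity and the exact return tuple are hypothetical. Like scipy.spatial.procrustes, this version allows reflections (no determinant correction).

```python
from typing import Tuple, Union

import torch


def procrustes_disparity(
    data1: torch.Tensor, data2: torch.Tensor, return_all_stats: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Hypothetical API: batched SciPy-style Procrustes disparity for (B, N, D) inputs.

    With return_all_stats=True, also return the scale (B,), the orthogonal
    matrix (B, D, D) and the standardized, aligned data2 (B, N, D).
    """
    # Center and scale each sample to unit Frobenius norm
    data1 = data1 - data1.mean(dim=1, keepdim=True)
    data2 = data2 - data2.mean(dim=1, keepdim=True)
    data1 = data1 / torch.linalg.norm(data1, dim=(1, 2), keepdim=True)
    data2 = data2 / torch.linalg.norm(data2, dim=(1, 2), keepdim=True)
    u, w, vh = torch.linalg.svd(data1.transpose(1, 2) @ data2, full_matrices=False)
    rotation = u @ vh
    scale = w.sum(dim=-1)
    aligned = scale[:, None, None] * (data2 @ rotation.transpose(1, 2))
    disparity = (data1 - aligned).square().sum(dim=(1, 2))
    if return_all_stats:
        return disparity, scale, rotation, aligned
    return disparity


torch.manual_seed(0)
a = torch.rand(2, 10, 3, dtype=torch.float64)
b = torch.rand(2, 10, 3, dtype=torch.float64)
disparity = procrustes_disparity(a, b)
disparity_all, scale, rotation, aligned = procrustes_disparity(a, b, return_all_stats=True)
print(disparity.shape, rotation.shape, aligned.shape)
```

Returning a tuple only when the flag is set keeps the default call cheap to use as a plain metric value, while the extra outputs remain available for PA-MPJPE-style post-processing.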

This metric does not fit under any of our current subdomains. Do you have a recommendation for which new domain it should go under?

heth27 (Author) commented Aug 30, 2024

Wikipedia suggests shape analysis. When I have more time, I plan to add metrics such as Mean Per Joint Position Error (MPJPE) and Percentage of Correct Keypoints (PCK) as well; those would also fit the domain.
There is also a feature request for the Hausdorff distance (#1990).
I'm not sure whether naming the domain after the general field (shape analysis) or after the most common use case (human pose estimation, in the case of Procrustes) would make these metrics easier to find.
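
For reference, a minimal sketch of the two metrics, assuming predictions and targets of shape (B, J, 3) in the same units. The function names are hypothetical, and the default threshold of 150 (millimetres) is an assumption here, chosen because it is a commonly used value for 3D PCK.

```python
import torch


def mpjpe(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Mean per-joint position error: average Euclidean distance over joints and batch."""
    return torch.linalg.norm(pred - target, dim=-1).mean()


def pck(pred: torch.Tensor, target: torch.Tensor, threshold: float = 150.0) -> torch.Tensor:
    """Fraction of joints whose Euclidean error is below `threshold`
    (same units as the coordinates; 150.0 assumes millimetres)."""
    errors = torch.linalg.norm(pred - target, dim=-1)  # (B, J)
    return (errors < threshold).to(pred.dtype).mean()


target = torch.zeros(1, 4, 3)
pred = torch.tensor([[[0.0, 0.0, 100.0],
                      [0.0, 200.0, 0.0],
                      [30.0, 40.0, 0.0],
                      [0.0, 0.0, 0.0]]])
print(mpjpe(pred, target))  # errors 100, 200, 50, 0 -> mean 87.5
print(pck(pred, target))  # 3 of 4 joints under 150 -> 0.75
```

Multiply the PCK value by 100 to report it as a percentage.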

SkafteNicki linked a pull request Sep 7, 2024 that will close this issue

3 participants