Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tps/rand_tps.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch.nn as nn
import itertools

from tps_stn_pytorch.tps_grid_gen import TPSGridGen
from tps.tps_grid_gen import TPSGridGen
from tps.grid_sample import grid_sample
from torch.autograd import Variable

Expand Down
71 changes: 71 additions & 0 deletions tps/tps_grid_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# encoding: utf-8

import torch
import itertools
import torch.nn as nn
from torch.autograd import Function, Variable

# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
def compute_partial_repr(input_points, control_points):
N = input_points.size(0)
M = control_points.size(0)
pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2)
# original implementation, very slow
# pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
pairwise_diff_square = pairwise_diff * pairwise_diff
pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, 1]
repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
# fix numerical error for 0 * log(0), substitute all nan with 0
mask = repr_matrix != repr_matrix
repr_matrix.masked_fill_(mask, 0)
return repr_matrix

class TPSGridGen(nn.Module):

def __init__(self, target_height, target_width, target_control_points):
super(TPSGridGen, self).__init__()
assert target_control_points.ndimension() == 2
assert target_control_points.size(1) == 2
N = target_control_points.size(0)
self.num_points = N
target_control_points = target_control_points.float()

# create padded kernel matrix
forward_kernel = torch.zeros(N + 3, N + 3)
target_control_partial_repr = compute_partial_repr(target_control_points, target_control_points)
forward_kernel[:N, :N].copy_(target_control_partial_repr)
forward_kernel[:N, -3].fill_(1)
forward_kernel[-3, :N].fill_(1)
forward_kernel[:N, -2:].copy_(target_control_points)
forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))
# compute inverse matrix
inverse_kernel = torch.inverse(forward_kernel)

# create target cordinate matrix
HW = target_height * target_width
target_coordinate = list(itertools.product(range(target_height), range(target_width)))
target_coordinate = torch.Tensor(target_coordinate) # HW x 2
Y, X = target_coordinate.split(1, dim = 1)
Y = Y * 2 / (target_height - 1) - 1
X = X * 2 / (target_width - 1) - 1
target_coordinate = torch.cat([X, Y], dim = 1) # convert from (y, x) to (x, y)
target_coordinate_partial_repr = compute_partial_repr(target_coordinate, target_control_points)
target_coordinate_repr = torch.cat([
target_coordinate_partial_repr, torch.ones(HW, 1), target_coordinate
], dim = 1)

# register precomputed matrices
self.register_buffer('inverse_kernel', inverse_kernel)
self.register_buffer('padding_matrix', torch.zeros(3, 2))
self.register_buffer('target_coordinate_repr', target_coordinate_repr)

def forward(self, source_control_points):
assert source_control_points.ndimension() == 3
assert source_control_points.size(1) == self.num_points
assert source_control_points.size(2) == 2
batch_size = source_control_points.size(0)

Y = torch.cat([source_control_points, Variable(self.padding_matrix.expand(batch_size, 3, 2))], 1)
mapping_matrix = torch.matmul(Variable(self.inverse_kernel), Y)
source_coordinate = torch.matmul(Variable(self.target_coordinate_repr), mapping_matrix)
return source_coordinate