diff --git a/.gitignore b/.gitignore index ef79da47..45431d45 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ dist/ *.egg-info/ site/ venv/ +**/.vscode .ipynb_checkpoints examples/notebooks/dataset examples/notebooks/CIFAR10_Dataset diff --git a/src/pytorch_metric_learning/losses/__init__.py b/src/pytorch_metric_learning/losses/__init__.py index 10e841f1..a0ba7407 100644 --- a/src/pytorch_metric_learning/losses/__init__.py +++ b/src/pytorch_metric_learning/losses/__init__.py @@ -8,6 +8,7 @@ from .cross_batch_memory import CrossBatchMemory from .fast_ap_loss import FastAPLoss from .generic_pair_loss import GenericPairLoss +from .histogram_loss import HistogramLoss from .instance_loss import InstanceLoss from .intra_pair_variance_loss import IntraPairVarianceLoss from .large_margin_softmax_loss import LargeMarginSoftmaxLoss diff --git a/src/pytorch_metric_learning/losses/histogram_loss.py b/src/pytorch_metric_learning/losses/histogram_loss.py new file mode 100644 index 00000000..ffd04cf2 --- /dev/null +++ b/src/pytorch_metric_learning/losses/histogram_loss.py @@ -0,0 +1,86 @@ +import torch + +from ..distances import CosineSimilarity +from ..utils import common_functions as c_f +from ..utils import loss_and_miner_utils as lmu +from .base_metric_loss_function import BaseMetricLossFunction + + +def filter_pairs(*tensors: torch.Tensor): + t = torch.stack(tensors) + t, _ = torch.sort(t, dim=0) + t = torch.unique(t, dim=1) + return t.tolist() + + +class HistogramLoss(BaseMetricLossFunction): + def __init__(self, n_bins: int = None, delta: float = None, **kwargs): + super().__init__(**kwargs) + assert ( + delta is None + and n_bins is not None + or delta is not None + and n_bins is None + or delta is not None + and n_bins is not None + ), "delta and n_bins cannot be both None" + + if delta is not None and n_bins is not None: + assert ( + delta == 2 / n_bins + ), f"delta and n_bins must satisfy the equation delta = 2/n_bins.\nPassed values are delta={delta} and n_bins={n_bins}" + + self.delta = delta if delta is not None else 2 / n_bins + self.add_to_recordable_attributes(name="num_bins", is_stat=True) + + def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels): + c_f.labels_or_indices_tuple_required(labels, indices_tuple) + c_f.ref_not_supported(embeddings, labels, ref_emb, ref_labels) + indices_tuple = lmu.convert_to_triplets( + indices_tuple, labels, ref_labels, t_per_anchor="all" + ) + anchor_idx, positive_idx, negative_idx = indices_tuple + if len(anchor_idx) == 0: + return self.zero_losses() + mat = self.distance(embeddings, ref_emb) + + anchor_positive_idx = filter_pairs(anchor_idx, positive_idx) + anchor_negative_idx = filter_pairs(anchor_idx, negative_idx) + ap_dists = mat[anchor_positive_idx] + an_dists = mat[anchor_negative_idx] + + p_pos = self.compute_density(ap_dists) + phi = torch.cumsum(p_pos, dim=0) + + p_neg = self.compute_density(an_dists) + return { + "loss": { + "losses": torch.sum(p_neg * phi), + "indices": None, + "reduction_type": "already_reduced", + } + } + + def compute_density(self, distances): + size = distances.size(0) + r_star = torch.floor( + (distances.float() + 1) / self.delta + ) # Indices of the bins containing the values of the distances + r_star = c_f.to_device(r_star, tensor=distances, dtype=torch.long) + + delta_ijr_a = (distances + 1 - r_star * self.delta) / self.delta + delta_ijr_b = ((r_star + 1) * self.delta - 1 - distances) / self.delta + delta_ijr_a = c_f.to_dtype(delta_ijr_a, tensor=distances) + delta_ijr_b = c_f.to_dtype(delta_ijr_b, tensor=distances) + + density = torch.zeros(round(1 + 2 / self.delta)) + density = c_f.to_device(density, tensor=distances, dtype=distances.dtype) + + # For each node sum the contributions of the bins whose ending node is this one + density.scatter_add_(0, r_star + 1, delta_ijr_a) + # For each node sum the contributions of the bins whose starting node is this one + density.scatter_add_(0, r_star, delta_ijr_b) + return density / size + + def get_default_distance(self): + return CosineSimilarity() diff --git a/tests/losses/test_histogram_loss.py b/tests/losses/test_histogram_loss.py new file mode 100644 index 00000000..adf951d9 --- /dev/null +++ b/tests/losses/test_histogram_loss.py @@ -0,0 +1,164 @@ +import unittest + +import torch +from numpy.testing import assert_almost_equal + +from pytorch_metric_learning.losses import HistogramLoss + +from .. import TEST_DEVICE, TEST_DTYPES +from ..zzz_testing_utils.testing_utils import angle_to_coord + + +###################################### +#######ORIGINAL IMPLEMENTATION######## +###################################### +# DIRECTLY COPIED from https://github.com/valerystrizh/pytorch-histogram-loss/blob/master/losses.py. +# This code is copied from the official PyTorch implementation +# so that we can make sure our implementation returns the same result. +# Some minor changes were made to avoid errors during testing. +# Every change in the original code is reported and explained. +class OriginalImplementationHistogramLoss(torch.nn.Module): + def __init__(self, num_steps, cuda=True): + super(OriginalImplementationHistogramLoss, self).__init__() + self.step = 2 / (num_steps - 1) + self.eps = 1 / num_steps + self.cuda = cuda + self.t = torch.arange(-1, 1 + self.step, self.step).view(-1, 1) + self.tsize = self.t.size()[0] + if self.cuda: + self.t = self.t.cuda() + + def forward(self, features, classes): + def histogram(inds, size): + s_repeat_ = s_repeat.clone() + indsa = ( + (s_repeat_floor - (self.t - self.step) > -self.eps) + & (s_repeat_floor - (self.t - self.step) < self.eps) + & inds + ) + assert ( + indsa.nonzero().size()[0] == size + ), "Another number of bins should be used" + zeros = torch.zeros((1, indsa.size()[1])).byte() + if self.cuda: + zeros = zeros.cuda() + indsb = torch.cat((indsa, zeros))[1:, :].to( + dtype=torch.bool + ) # Added to avoid bug with masks of uint8 + s_repeat_[~(indsb | indsa)] = 0 + # indsa corresponds to the first condition of the second equation of the paper + self.t = self.t.to( + dtype=s_repeat_.dtype + ) # Added to avoid errors when using Half precision + s_repeat_[indsa] = (s_repeat_ - self.t + self.step)[indsa] / self.step + # indsb corresponds to the second condition of the second equation of the paper + s_repeat_[indsb] = (-s_repeat_ + self.t + self.step)[indsb] / self.step + + return s_repeat_.sum(1) / size + + classes_size = classes.size()[0] + classes_eq = ( + classes.repeat(classes_size, 1) + == classes.view(-1, 1).repeat(1, classes_size) + ).data + dists = torch.mm(features, features.transpose(0, 1)) + assert ( + (dists > 1 + self.eps).sum().item() + (dists < -1 - self.eps).sum().item() + ) == 0, "L2 normalization should be used" + s_inds = torch.triu(torch.ones(classes_eq.size()), 1).byte() + if self.cuda: + s_inds = s_inds.cuda() + classes_eq = classes_eq.to( + device=s_inds.device + ) # Added to avoid errors when using only cpu + pos_inds = classes_eq[s_inds].repeat(self.tsize, 1) + neg_inds = ~classes_eq[s_inds].repeat(self.tsize, 1) + pos_size = classes_eq[s_inds].sum().item() + neg_size = (~classes_eq[s_inds]).sum().item() + s = dists[s_inds].view(1, -1) + s_repeat = s.repeat(self.tsize, 1) + s_repeat_floor = (torch.floor(s_repeat.data / self.step) * self.step).float() + + histogram_pos = histogram(pos_inds, pos_size) + assert_almost_equal( + histogram_pos.sum().item(), + 1, + decimal=1, + err_msg="Not good positive histogram", + verbose=True, + ) + histogram_neg = histogram(neg_inds, neg_size) + assert_almost_equal( + histogram_neg.sum().item(), + 1, + decimal=1, + err_msg="Not good negative histogram", + verbose=True, + ) + histogram_pos_repeat = histogram_pos.view(-1, 1).repeat( + 1, histogram_pos.size()[0] + ) + histogram_pos_inds = torch.tril( + torch.ones(histogram_pos_repeat.size()), -1 + ).byte() + if self.cuda: + histogram_pos_inds = histogram_pos_inds.cuda() + histogram_pos_repeat[histogram_pos_inds] = 0 + histogram_pos_cdf = histogram_pos_repeat.sum(0) + loss = torch.sum(histogram_neg * histogram_pos_cdf) + + return loss + + +class TestHistogramLoss(unittest.TestCase): + def test_histogram_loss(self): + for dtype in TEST_DTYPES: + embeddings = torch.randn( + 5, + 32, + requires_grad=True, + dtype=dtype, + ).to( + TEST_DEVICE + ) # 2D embeddings + embeddings = torch.nn.functional.normalize(embeddings) + labels = torch.LongTensor([0, 0, 1, 1, 2]) + + num_steps = 5 if dtype == torch.float16 else 21 + num_bins = num_steps - 1 + loss_func = HistogramLoss(n_bins=num_bins) + + loss = loss_func(embeddings, labels) + + original_loss_func = OriginalImplementationHistogramLoss( + num_steps=num_steps + ) + correct_loss = original_loss_func(embeddings, labels) + + rtol = 1e-2 if dtype == torch.float16 else 1e-5 + self.assertTrue(torch.isclose(loss, correct_loss, rtol=rtol)) + + def test_with_no_valid_triplets(self): + loss_funcA = HistogramLoss(n_bins=4) + for dtype in TEST_DTYPES: + embedding_angles = [0, 20, 40, 60, 80] + embeddings = torch.tensor( + [angle_to_coord(a) for a in embedding_angles], + requires_grad=True, + dtype=dtype, + ).to( + TEST_DEVICE + ) # 2D embeddings + labels = torch.LongTensor([0, 1, 2, 3, 4]) + lossA = loss_funcA(embeddings, labels) + self.assertEqual(lossA, 0) + + def test_assertion_raises(self): + with self.assertRaises(AssertionError): + _ = HistogramLoss() + + with self.assertRaises(AssertionError): + _ = HistogramLoss(n_bins=1, delta=0.5) + + with self.assertRaises(AssertionError): + _ = HistogramLoss(n_bins=10, delta=0.4)