Commit

adding files
TakeshiMusgrave committed Oct 24, 2019
1 parent 59fa79a commit e95c290
Showing 38 changed files with 2,100 additions and 0 deletions.
6 changes: 6 additions & 0 deletions .gitignore
@@ -0,0 +1,6 @@
__pycache__/
*.py[cod]
.nfs*
build/
dist/
pytorch_metric_learning.egg-info/
21 changes: 21 additions & 0 deletions LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2019 Kevin Musgrave

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
10 changes: 10 additions & 0 deletions losses/__init__.py
@@ -0,0 +1,10 @@
from .angular_loss import AngularLoss, AngularNPairsLoss
from .contrastive_loss import ContrastiveLoss
from .generic_pair_loss import GenericPairLoss
from .lifted_structure_loss import GeneralizedLiftedStructureLoss
from .margin_loss import MarginLoss
from .multi_similarity_loss import MultiSimilarityLoss
from .nca_loss import NCALoss
from .n_pairs_loss import NPairsLoss
from .proxy_losses import ProxyNCALoss
from .triplet_margin_loss import TripletMarginLoss
78 changes: 78 additions & 0 deletions losses/angular_loss.py
@@ -0,0 +1,78 @@
#! /usr/bin/env python3

from . import base_metric_loss_function as bmlf
import numpy as np
import torch
from ..utils import loss_and_miner_utils as lmu


class AngularNPairsLoss(bmlf.BaseMetricLossFunction):
"""
Implementation of https://arxiv.org/abs/1708.01682
Args:
alpha: The angle (as described in the paper), specified in degrees.
"""
def __init__(self, alpha, **kwargs):
self.alpha = torch.tensor(np.radians(alpha))
self.maybe_modify_loss = lambda x: x
self.num_anchors = 0
self.avg_embedding_norm = 0
self.record_these = ["num_anchors", "avg_embedding_norm"]
super().__init__(**kwargs)

def compute_loss(self, embeddings, labels, indices_tuple):
self.avg_embedding_norm = torch.mean(torch.norm(embeddings, p=2, dim=1))
anchor_idx, positive_idx = lmu.convert_to_pos_pairs_with_unique_labels(indices_tuple, labels)
self.num_anchors = len(anchor_idx)
if self.num_anchors == 0:
return 0

anchors, positives = embeddings[anchor_idx], embeddings[positive_idx]
points_per_anchor = (self.num_anchors - 1) * 2

alpha = self.maybe_mask_param(self.alpha, labels[anchor_idx])
sq_tan_alpha = torch.tan(alpha) ** 2

xa_xp = torch.sum(anchors * positives, dim=1, keepdim=True)
term2_multiplier = 2 * (1 + sq_tan_alpha)
term2 = term2_multiplier * xa_xp

a_p_summed = anchors + positives
inside_exp = []
mask = torch.ones(self.num_anchors).to(embeddings.device) - torch.eye(self.num_anchors).to(embeddings.device)
term1_multiplier = 4 * sq_tan_alpha

for negatives in [anchors, positives]:
term1 = term1_multiplier * torch.matmul(a_p_summed, torch.t(negatives))
inside_exp.append(term1 - term2.repeat(1, self.num_anchors))
inside_exp[-1] = inside_exp[-1] * mask

inside_exp_final = torch.zeros((self.num_anchors, points_per_anchor + 1)).to(embeddings.device)

for i in range(self.num_anchors):
indices = np.concatenate((np.arange(0, i), np.arange(i + 1, inside_exp[0].size(1))))
inside_exp_final[i, : points_per_anchor // 2] = inside_exp[0][i, indices]
inside_exp_final[:, points_per_anchor // 2 :] = inside_exp[1]
inside_exp_final = self.maybe_modify_loss(inside_exp_final)

return torch.mean(torch.logsumexp(inside_exp_final, dim=1))

def create_learnable_parameter(self, init_value):
return super().create_learnable_parameter(init_value, unsqueeze=True)


class AngularLoss(AngularNPairsLoss):
def compute_loss(self, embeddings, labels, indices_tuple):
self.avg_embedding_norm = torch.mean(torch.norm(embeddings, p=2, dim=1))
anchor_idx, positive_idx, negative_idx = lmu.convert_to_triplets(indices_tuple, labels)
self.num_anchors = len(anchor_idx)
if self.num_anchors == 0:
return 0
anchors, positives, negatives = embeddings[anchor_idx], embeddings[positive_idx], embeddings[negative_idx]
alpha = self.maybe_mask_param(self.alpha, labels[anchor_idx])
sq_tan_alpha = torch.tan(alpha) ** 2
term1 = 4 * sq_tan_alpha * torch.sum((anchors + positives) * negatives, dim=1, keepdim=True)
term2 = 2 * (1 + sq_tan_alpha) * torch.sum(anchors * positives, dim=1, keepdim=True)
final_form = torch.cat([term1 - term2, torch.zeros(term1.size(0), 1).to(embeddings.device)], dim=1)
final_form = self.maybe_modify_loss(final_form)
return torch.mean(torch.logsumexp(final_form, dim=1))
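
Not part of this commit, but as a sanity check: a minimal sketch of how AngularLoss.compute_loss above reduces for a single triplet, with made-up unit-norm embeddings and a hypothetical alpha of 40 degrees.

import numpy as np
import torch

# Hypothetical unit-norm embeddings for one (anchor, positive, negative) triplet.
anchor = torch.nn.functional.normalize(torch.randn(1, 8), p=2, dim=1)
positive = torch.nn.functional.normalize(torch.randn(1, 8), p=2, dim=1)
negative = torch.nn.functional.normalize(torch.randn(1, 8), p=2, dim=1)

sq_tan_alpha = torch.tan(torch.tensor(np.radians(40.0))) ** 2

# The same two terms as in AngularLoss.compute_loss:
term1 = 4 * sq_tan_alpha * torch.sum((anchor + positive) * negative, dim=1, keepdim=True)
term2 = 2 * (1 + sq_tan_alpha) * torch.sum(anchor * positive, dim=1, keepdim=True)

# Appending a zero column makes logsumexp a smooth upper bound on relu(term1 - term2).
final_form = torch.cat([term1 - term2, torch.zeros(term1.size(0), 1)], dim=1)
loss = torch.mean(torch.logsumexp(final_form, dim=1))
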
82 changes: 82 additions & 0 deletions losses/base_metric_loss_function.py
@@ -0,0 +1,82 @@
#! /usr/bin/env python3

import torch


class BaseMetricLossFunction(torch.nn.Module):
"""
All loss functions extend this class
Args:
normalize_embeddings: type boolean. If True, then normalize embeddings
to have norm = 1 before computing the loss
num_class_per_param: type int. The number of classes, so that each hyperparameter gets one value per class.
If your parameters don't have a separate value for each class,
then leave this at None
learnable_param_names: type list of strings. Each element is the name of
an attribute that should be converted to nn.Parameter
"""
def __init__(
self,
normalize_embeddings=True,
num_class_per_param=None,
learnable_param_names=None
):
super().__init__()
self.normalize_embeddings = normalize_embeddings
self.num_class_per_param = num_class_per_param
self.learnable_param_names = learnable_param_names
self.initialize_learnable_parameters()

def compute_loss(self, embeddings, labels, indices_tuple=None):
"""
This has to be implemented and is what actually computes the loss.
"""
raise NotImplementedError

def forward(self, embeddings, labels, indices_tuple=None):
"""
Args:
embeddings: tensor of size (batch_size, embedding_size)
labels: tensor of size (batch_size)
indices_tuple: tuple of size 3 for triplets (anchors, positives, negatives)
or size 4 for pairs (anchor1, positives, anchor2, negatives)
Can also be left as None
Returns: the loss (float)
"""
labels = labels.to(embeddings.device)
if self.normalize_embeddings:
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
loss = self.compute_loss(embeddings, labels, indices_tuple)
return loss

def initialize_learnable_parameters(self):
"""
To learn hyperparams, create an attribute called learnable_param_names.
This should be a list of strings which are the names of the
hyperparameters to be learned
"""
if self.learnable_param_names is not None:
for k in self.learnable_param_names:
v = getattr(self, k)
setattr(self, k, self.create_learnable_parameter(v))

def create_learnable_parameter(self, init_value, unsqueeze=False):
"""
Returns nn.Parameter with an initial value of init_value
and size of num_labels
"""
vec_len = self.num_class_per_param if self.num_class_per_param else 1
if unsqueeze:
vec_len = (vec_len, 1)
p = torch.nn.Parameter(torch.ones(vec_len) * init_value)
return p

def maybe_mask_param(self, param, labels):
"""
This returns the hyperparameters corresponding to class labels (if applicable).
If there is a hyperparameter for each class, then when computing the loss,
the class hyperparameter has to be matched to the corresponding embedding.
"""
if self.num_class_per_param:
return param[labels]
return param
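
Not part of this commit: a rough usage sketch of the base class. The subclass, its margin hyperparameter, and the import path (assuming the modules are installed as a pytorch_metric_learning package, per the egg-info entry in .gitignore) are all hypothetical; it only illustrates compute_loss, maybe_mask_param, and per-class learnable parameters.

import torch
from pytorch_metric_learning.losses.base_metric_loss_function import BaseMetricLossFunction  # assumed path

class ToyNormLoss(BaseMetricLossFunction):
    """Hypothetical subclass: penalize embedding norms that exceed a (possibly per-class) margin."""
    def __init__(self, margin, **kwargs):
        self.margin = margin  # set before super().__init__ so it can become an nn.Parameter
        super().__init__(**kwargs)

    def compute_loss(self, embeddings, labels, indices_tuple):
        margin = self.maybe_mask_param(self.margin, labels)  # one margin per sample's class
        return torch.mean(torch.relu(torch.norm(embeddings, p=2, dim=1) - margin))

# One learnable margin per class (3 classes), each initialized to 0.5.
loss_fn = ToyNormLoss(margin=0.5,
                      normalize_embeddings=False,
                      num_class_per_param=3,
                      learnable_param_names=["margin"])
embeddings = torch.randn(8, 16)
labels = torch.randint(0, 3, (8,))
loss = loss_fn(embeddings, labels)
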
79 changes: 79 additions & 0 deletions losses/contrastive_loss.py
@@ -0,0 +1,79 @@
#! /usr/bin/env python3

import torch

from . import generic_pair_loss as gpl


class ContrastiveLoss(gpl.GenericPairLoss):
"""
Contrastive loss using either distance or similarity.
Args:
pos_margin: The distance (or similarity) over (under) which positive pairs will contribute to the loss.
neg_margin: The distance (or similarity) under (over) which negative pairs will contribute to the loss.
use_similarity: If True, will use dot product between vectors instead of euclidean distance
power: Each pair's loss will be raised to this power.
avg_non_zero_only: Only pairs that contribute non-zero loss will be used in the final loss.
"""
def __init__(
self,
pos_margin=0,
neg_margin=1,
use_similarity=False,
power=1,
avg_non_zero_only=True,
**kwargs
):
self.pos_margin = pos_margin
self.neg_margin = neg_margin
self.avg_non_zero_only = avg_non_zero_only
self.num_non_zero_pos_pairs = 0
self.num_non_zero_neg_pairs = 0
self.record_these = ["num_non_zero_pos_pairs", "num_non_zero_neg_pairs"]
self.power = power
super().__init__(use_similarity=use_similarity, iterate_through_loss=False, **kwargs)

def pair_based_loss(
self,
pos_pair_dist,
neg_pair_dist,
pos_pair_anchor_labels,
neg_pair_anchor_labels,
):
pos_loss, neg_loss = 0, 0
self.num_non_zero_pos_pairs, self.num_non_zero_neg_pairs = 0, 0
if len(pos_pair_dist) > 0:
pos_loss, self.num_non_zero_pos_pairs = self.mask_margin_and_calculate_loss(
pos_pair_dist, pos_pair_anchor_labels, "pos"
)
if len(neg_pair_dist) > 0:
neg_loss, self.num_non_zero_neg_pairs = self.mask_margin_and_calculate_loss(
neg_pair_dist, neg_pair_anchor_labels, "neg"
)
return pos_loss + neg_loss

def mask_margin_and_calculate_loss(self, pair_dists, labels, pos_or_neg):
loss_calc_func = self.pos_calc if pos_or_neg == "pos" else self.neg_calc
input_margin = self.pos_margin if pos_or_neg == "pos" else self.neg_margin
margin = self.maybe_mask_param(input_margin, labels)
per_pair_loss = loss_calc_func(pair_dists, margin) ** self.power
num_non_zero_pairs = (per_pair_loss > 0).nonzero().size(0)
if self.avg_non_zero_only:
loss = torch.sum(per_pair_loss) / (num_non_zero_pairs + 1e-16)
else:
loss = torch.mean(per_pair_loss)
return loss, num_non_zero_pairs

def pos_calc(self, pos_pair_dist, margin):
return (
torch.nn.functional.relu(margin - pos_pair_dist)
if self.use_similarity
else torch.nn.functional.relu(pos_pair_dist - margin)
)

def neg_calc(self, neg_pair_dist, margin):
return (
torch.nn.functional.relu(neg_pair_dist - margin)
if self.use_similarity
else torch.nn.functional.relu(margin - neg_pair_dist)
)
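
For illustration only (not in this diff): a minimal usage sketch with a made-up batch; the pytorch_metric_learning.losses import path is assumed from the .gitignore egg-info entry. With the defaults above, positive pairs contribute loss when their distance exceeds pos_margin=0, negative pairs when their distance falls below neg_margin=1, and only non-zero pairs are averaged.

import torch
from pytorch_metric_learning.losses import ContrastiveLoss  # assumed install path

loss_fn = ContrastiveLoss(pos_margin=0, neg_margin=1)
embeddings = torch.randn(32, 128, requires_grad=True)  # (batch_size, embedding_size)
labels = torch.randint(0, 5, (32,))                    # pairs come from matching / non-matching labels
loss = loss_fn(embeddings, labels)                     # indices_tuple left as None
loss.backward()
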
75 changes: 75 additions & 0 deletions losses/generic_pair_loss.py
@@ -0,0 +1,75 @@
#! /usr/bin/env python3


import torch
from ..utils import loss_and_miner_utils as lmu
from . import base_metric_loss_function as bmlf


class GenericPairLoss(bmlf.BaseMetricLossFunction):
"""
The function pair_based_loss has to be implemented by the child class.
By default, this class extracts every positive and negative pair within a
batch (based on labels) and passes the pairs to the loss function.
The pairs can be passed to the loss function all at once (self.loss_once)
or pairs can be passed iteratively (self.loss_loop) by going through each
sample in a batch, and selecting just the positive and negative pairs
containing that sample.
Args:
use_similarity: set to True if the loss function uses pairwise similarity
(dot product of each embedding pair). Otherwise,
euclidean distance will be used
iterate_through_loss: set to True to use self.loss_loop and False otherwise
squared_distances: if True, then the euclidean distance will be squared.
"""

def __init__(
self, use_similarity, iterate_through_loss, squared_distances=False, **kwargs
):
self.use_similarity = use_similarity
self.squared_distances = squared_distances
self.loss_method = self.loss_loop if iterate_through_loss else self.loss_once
super().__init__(**kwargs)

def compute_loss(self, embeddings, labels, indices_tuple):
mat = (
lmu.sim_mat(embeddings)
if self.use_similarity
else lmu.dist_mat(embeddings, squared=self.squared_distances)
)
indices_tuple = lmu.convert_to_pairs(indices_tuple, labels)
return self.loss_method(mat, labels, indices_tuple)

def pair_based_loss(
self, pos_pairs, neg_pairs, pos_pair_anchor_labels, neg_pair_anchor_labels
):
raise NotImplementedError

def loss_loop(self, mat, labels, indices_tuple):
loss = torch.tensor(0.0).to(mat.device)
n = 0
(a1_indices, p_indices, a2_indices, n_indices) = indices_tuple
for i in range(mat.size(0)):
pos_pair, neg_pair = [], []
if len(a1_indices) > 0:
p_idx = a1_indices == i
pos_pair = mat[a1_indices[p_idx], p_indices[p_idx]]
if len(a2_indices) > 0:
n_idx = a2_indices == i
neg_pair = mat[a2_indices[n_idx], n_indices[n_idx]]
loss += self.pair_based_loss(
pos_pair, neg_pair, labels[a1_indices[p_idx]], labels[a2_indices[n_idx]]
)
n += 1
return loss / (n if n > 0 else 1)

def loss_once(self, mat, labels, indices_tuple):
(a1_indices, p_indices, a2_indices, n_indices) = indices_tuple
pos_pair, neg_pair = [], []
if len(a1_indices) > 0:
pos_pair = mat[a1_indices, p_indices]
if len(a2_indices) > 0:
neg_pair = mat[a2_indices, n_indices]
return self.pair_based_loss(
pos_pair, neg_pair, labels[a1_indices], labels[a2_indices]
)
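
As an aside (not in this diff): a small illustration, with made-up indices and torch.cdist standing in for lmu.dist_mat, of the 4-tuple of parallel index arrays that loss_once expects and how it gathers pair distances from the matrix.

import torch

embeddings = torch.randn(4, 8)
mat = torch.cdist(embeddings, embeddings)       # pairwise euclidean distances

# Hypothetical pairs for labels [0, 0, 1, 1]:
a1 = torch.tensor([0, 1, 2, 3])                 # anchors of positive pairs
p  = torch.tensor([1, 0, 3, 2])                 # their positives
a2 = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])     # anchors of negative pairs
n  = torch.tensor([2, 3, 2, 3, 0, 1, 0, 1])     # their negatives

pos_pair = mat[a1, p]   # distances of all positive pairs, shape (4,)
neg_pair = mat[a2, n]   # distances of all negative pairs, shape (8,)
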
20 changes: 20 additions & 0 deletions losses/lifted_structure_loss.py
@@ -0,0 +1,20 @@
#! /usr/bin/env python3

import torch

from . import generic_pair_loss as gpl


class GeneralizedLiftedStructureLoss(gpl.GenericPairLoss):
# The 'generalized' lifted structure loss shown on page 4
# of the "In Defense of the Triplet Loss" paper
# https://arxiv.org/pdf/1703.07737.pdf
def __init__(self, neg_margin, **kwargs):
self.neg_margin = neg_margin
super().__init__(use_similarity=False, iterate_through_loss=True, **kwargs)

def pair_based_loss(self, pos_pairs, neg_pairs, pos_pair_anchor_labels, neg_pair_anchor_labels):
neg_margin = self.maybe_mask_param(self.neg_margin, neg_pair_anchor_labels)
per_anchor = torch.logsumexp(pos_pairs, dim=0) + torch.logsumexp(neg_margin - neg_pairs, dim=0)
hinged = torch.relu(per_anchor)
return torch.mean(hinged)
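
Not part of this commit: for a single anchor, the per_anchor expression above is the hinged log-sum-exp from the paper; a tiny sketch with made-up pair distances and neg_margin = 1.0.

import torch

pos_pairs = torch.tensor([0.7, 0.9])           # distances from this anchor to its positives
neg_pairs = torch.tensor([1.1, 0.4, 1.6])      # distances from this anchor to its negatives
neg_margin = 1.0

per_anchor = torch.logsumexp(pos_pairs, dim=0) + torch.logsumexp(neg_margin - neg_pairs, dim=0)
loss = torch.relu(per_anchor)                  # hinged at zero; loss_loop then averages over anchors
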
