Commit e95c290 (parent 59fa79a)
TakeshiMusgrave committed Oct 24, 2019
Showing 38 changed files with 2,100 additions and 0 deletions.
.gitignore
@@ -0,0 +1,6 @@
__pycache__/
*.py[cod]
.nfs*
build/
dist/
pytorch_metric_learning.egg-info/
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2019 Kevin Musgrave | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
losses/__init__.py
@@ -0,0 +1,10 @@
from .angular_loss import AngularLoss, AngularNPairsLoss
from .contrastive_loss import ContrastiveLoss
from .generic_pair_loss import GenericPairLoss
from .lifted_structure_loss import GeneralizedLiftedStructureLoss
from .margin_loss import MarginLoss
from .multi_similarity_loss import MultiSimilarityLoss
from .nca_loss import NCALoss
from .n_pairs_loss import NPairsLoss
from .proxy_losses import ProxyNCALoss
from .triplet_margin_loss import TripletMarginLoss
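
For orientation (not part of the commit), a minimal sketch of calling one of these losses. ContrastiveLoss is used because its full signature appears later in this diff, and the (embeddings, labels) call signature comes from the base class shown below:

import torch
from pytorch_metric_learning.losses import ContrastiveLoss

embeddings = torch.randn(32, 128)     # (batch_size, embedding_size)
labels = torch.randint(0, 10, (32,))  # (batch_size,)
loss_func = ContrastiveLoss(pos_margin=0, neg_margin=1)  # the defaults, written out
loss = loss_func(embeddings, labels)  # indices_tuple defaults to None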
losses/angular_loss.py
@@ -0,0 +1,78 @@
#! /usr/bin/env python3

from . import base_metric_loss_function as bmlf
import numpy as np
import torch
from ..utils import loss_and_miner_utils as lmu


class AngularNPairsLoss(bmlf.BaseMetricLossFunction):
    """
    Implementation of https://arxiv.org/abs/1708.01682
    Args:
        alpha: The angle (as described in the paper), specified in degrees.
    """
    def __init__(self, alpha, **kwargs):
        self.alpha = torch.tensor(np.radians(alpha))
        self.maybe_modify_loss = lambda x: x
        self.num_anchors = 0
        self.avg_embedding_norm = 0
        self.record_these = ["num_anchors", "avg_embedding_norm"]
        super().__init__(**kwargs)

    def compute_loss(self, embeddings, labels, indices_tuple):
        self.avg_embedding_norm = torch.mean(torch.norm(embeddings, p=2, dim=1))
        anchor_idx, positive_idx = lmu.convert_to_pos_pairs_with_unique_labels(indices_tuple, labels)
        self.num_anchors = len(anchor_idx)
        if self.num_anchors == 0:
            return 0

        anchors, positives = embeddings[anchor_idx], embeddings[positive_idx]
        points_per_anchor = (self.num_anchors - 1) * 2

        alpha = self.maybe_mask_param(self.alpha, labels[anchor_idx])
        sq_tan_alpha = torch.tan(alpha) ** 2

        xa_xp = torch.sum(anchors * positives, dim=1, keepdim=True)
        term2_multiplier = 2 * (1 + sq_tan_alpha)
        term2 = term2_multiplier * xa_xp

        a_p_summed = anchors + positives
        inside_exp = []
        mask = torch.ones(self.num_anchors).to(embeddings.device) - torch.eye(self.num_anchors).to(embeddings.device)
        term1_multiplier = 4 * sq_tan_alpha

        for negatives in [anchors, positives]:
            term1 = term1_multiplier * torch.matmul(a_p_summed, torch.t(negatives))
            inside_exp.append(term1 - term2.repeat(1, self.num_anchors))
            inside_exp[-1] = inside_exp[-1] * mask

        inside_exp_final = torch.zeros((self.num_anchors, points_per_anchor + 1)).to(embeddings.device)

        for i in range(self.num_anchors):
            indices = np.concatenate((np.arange(0, i), np.arange(i + 1, inside_exp[0].size(1))))
            inside_exp_final[i, : points_per_anchor // 2] = inside_exp[0][i, indices]
        inside_exp_final[:, points_per_anchor // 2 :] = inside_exp[1]
        inside_exp_final = self.maybe_modify_loss(inside_exp_final)

        return torch.mean(torch.logsumexp(inside_exp_final, dim=1))

    def create_learnable_parameter(self, init_value):
        return super().create_learnable_parameter(init_value, unsqueeze=True)


class AngularLoss(AngularNPairsLoss):
    def compute_loss(self, embeddings, labels, indices_tuple):
        self.avg_embedding_norm = torch.mean(torch.norm(embeddings, p=2, dim=1))
        anchor_idx, positive_idx, negative_idx = lmu.convert_to_triplets(indices_tuple, labels)
        self.num_anchors = len(anchor_idx)
        if self.num_anchors == 0:
            return 0
        anchors, positives, negatives = embeddings[anchor_idx], embeddings[positive_idx], embeddings[negative_idx]
        alpha = self.maybe_mask_param(self.alpha, labels[anchor_idx])
        sq_tan_alpha = torch.tan(alpha) ** 2
        term1 = 4 * sq_tan_alpha * torch.sum((anchors + positives) * negatives, dim=1, keepdim=True)
        term2 = 2 * (1 + sq_tan_alpha) * torch.sum(anchors * positives, dim=1, keepdim=True)
        final_form = torch.cat([term1 - term2, torch.zeros(term1.size(0), 1).to(embeddings.device)], dim=1)
        final_form = self.maybe_modify_loss(final_form)
        return torch.mean(torch.logsumexp(final_form, dim=1))
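
For a single triplet (a, p, n) of normalized embeddings, the term stacked by AngularLoss is 4·tan²(alpha)·(a + p)·n − 2·(1 + tan²(alpha))·(a·p), i.e. term1 − term2 above, and the appended zero column makes logsumexp act as a soft hinge. A minimal usage sketch (data is made up; assumes the commit's utils module mines triplets when indices_tuple is None):

import torch
from pytorch_metric_learning.losses import AngularLoss

loss_func = AngularLoss(alpha=40)  # alpha is in degrees, per the docstring
embeddings = torch.randn(16, 64)
labels = torch.randint(0, 4, (16,))
loss = loss_func(embeddings, labels)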
#! /usr/bin/env python3 | ||
|
||
import torch | ||
|
||
|
||
class BaseMetricLossFunction(torch.nn.Module): | ||
""" | ||
All loss functions extend this class | ||
Args: | ||
normalize_embeddings: type boolean. If True then normalize embeddins | ||
to have norm = 1 before computing the loss | ||
num_class_per_param: type int. The number of classes for each parameter. | ||
If your parameters don't have a separate value for each class, | ||
then leave this at None | ||
learnable_param_names: type list of strings. Each element is the name of | ||
attributes that should be converted to nn.Parameter | ||
""" | ||
def __init__( | ||
self, | ||
normalize_embeddings=True, | ||
num_class_per_param=None, | ||
learnable_param_names=None | ||
): | ||
super().__init__() | ||
self.normalize_embeddings = normalize_embeddings | ||
self.num_class_per_param = num_class_per_param | ||
self.learnable_param_names = learnable_param_names | ||
self.initialize_learnable_parameters() | ||
|
||
def compute_loss(self, embeddings, labels, indices_tuple=None): | ||
""" | ||
This has to be implemented and is what actually computes the loss. | ||
""" | ||
raise NotImplementedError | ||
|
||
def forward(self, embeddings, labels, indices_tuple=None): | ||
""" | ||
Args: | ||
embeddings: tensor of size (batch_size, embedding_size) | ||
labels: tensor of size (batch_size) | ||
indices_tuple: tuple of size 3 for triplets (anchors, positives, negatives) | ||
or size 4 for pairs (anchor1, postives, anchor2, negatives) | ||
Can also be left as None | ||
Returns: the loss (float) | ||
""" | ||
labels = labels.to(embeddings.device) | ||
if self.normalize_embeddings: | ||
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) | ||
loss = self.compute_loss(embeddings, labels, indices_tuple) | ||
return loss | ||
|
||
def initialize_learnable_parameters(self): | ||
""" | ||
To learn hyperparams, create an attribute called learnable_param_names. | ||
This should be a list of strings which are the names of the | ||
hyperparameters to be learned | ||
""" | ||
if self.learnable_param_names is not None: | ||
for k in self.learnable_param_names: | ||
v = getattr(self, k) | ||
setattr(self, k, self.create_learnable_parameter(v)) | ||
|
||
def create_learnable_parameter(self, init_value, unsqueeze=False): | ||
""" | ||
Returns nn.Parameter with an initial value of init_value | ||
and size of num_labels | ||
""" | ||
vec_len = self.num_class_per_param if self.num_class_per_param else 1 | ||
if unsqueeze: | ||
vec_len = (vec_len, 1) | ||
p = torch.nn.Parameter(torch.ones(vec_len) * init_value) | ||
return p | ||
|
||
def maybe_mask_param(self, param, labels): | ||
""" | ||
This returns the hyperparameters corresponding to class labels (if applicable). | ||
If there is a hyperparameter for each class, then when computing the loss, | ||
the class hyperparameter has to be matched to the corresponding embedding. | ||
""" | ||
if self.num_class_per_param: | ||
return param[labels] | ||
return param |
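
To illustrate the extension points (compute_loss, learnable_param_names, maybe_mask_param), here is a toy subclass; the class name and the loss logic are invented for this sketch and are not part of the commit:

import torch
from pytorch_metric_learning.losses.base_metric_loss_function import BaseMetricLossFunction

class ToyMarginLoss(BaseMetricLossFunction):
    def __init__(self, margin=0.5, **kwargs):
        self.margin = margin  # plain float here; the base class converts it to nn.Parameter
        super().__init__(learnable_param_names=["margin"], **kwargs)

    def compute_loss(self, embeddings, labels, indices_tuple=None):
        # Hinge each embedding's distance from the batch centroid at the learnable margin.
        dists = torch.norm(embeddings - embeddings.mean(dim=0, keepdim=True), dim=1)
        margin = self.maybe_mask_param(self.margin, labels)
        return torch.mean(torch.relu(dists - margin))

With num_class_per_param set (e.g. to the number of classes), the same code learns one margin per class, and maybe_mask_param selects margin[labels].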
#! /usr/bin/env python3 | ||
|
||
import torch | ||
|
||
from . import generic_pair_loss as gpl | ||
|
||
|
||
class ContrastiveLoss(gpl.GenericPairLoss): | ||
""" | ||
Contrastive loss using either distance or similarity. | ||
Args: | ||
pos_margin: The distance (or similarity) over (under) which positive pairs will contribute to the loss. | ||
neg_margin: The distance (or similarity) under (over) which negative pairs will contribute to the loss. | ||
use_similarity: If True, will use dot product between vectors instead of euclidean distance | ||
power: Each pair's loss will be raised to this power. | ||
avg_non_zero_only: Only pairs that contribute non-zero loss will be used in the final loss. | ||
""" | ||
def __init__( | ||
self, | ||
pos_margin=0, | ||
neg_margin=1, | ||
use_similarity=False, | ||
power=1, | ||
avg_non_zero_only=True, | ||
**kwargs | ||
): | ||
self.pos_margin = pos_margin | ||
self.neg_margin = neg_margin | ||
self.avg_non_zero_only = avg_non_zero_only | ||
self.num_non_zero_pos_pairs = 0 | ||
self.num_non_zero_neg_pairs = 0 | ||
self.record_these = ["num_non_zero_pos_pairs", "num_non_zero_neg_pairs"] | ||
self.power = power | ||
super().__init__(use_similarity=use_similarity, iterate_through_loss=False, **kwargs) | ||
|
||
def pair_based_loss( | ||
self, | ||
pos_pair_dist, | ||
neg_pair_dist, | ||
pos_pair_anchor_labels, | ||
neg_pair_anchor_labels, | ||
): | ||
pos_loss, neg_loss = 0, 0 | ||
self.num_non_zero_pos_pairs, self.num_non_zero_neg_pairs = 0, 0 | ||
if len(pos_pair_dist) > 0: | ||
pos_loss, self.num_non_zero_pos_pairs = self.mask_margin_and_calculate_loss( | ||
pos_pair_dist, pos_pair_anchor_labels, "pos" | ||
) | ||
if len(neg_pair_dist) > 0: | ||
neg_loss, self.num_non_zero_neg_pairs = self.mask_margin_and_calculate_loss( | ||
neg_pair_dist, neg_pair_anchor_labels, "neg" | ||
) | ||
return pos_loss + neg_loss | ||
|
||
def mask_margin_and_calculate_loss(self, pair_dists, labels, pos_or_neg): | ||
loss_calc_func = self.pos_calc if pos_or_neg == "pos" else self.neg_calc | ||
input_margin = self.pos_margin if pos_or_neg == "pos" else self.neg_margin | ||
margin = self.maybe_mask_param(input_margin, labels) | ||
per_pair_loss = loss_calc_func(pair_dists, margin) ** self.power | ||
num_non_zero_pairs = (per_pair_loss > 0).nonzero().size(0) | ||
if self.avg_non_zero_only: | ||
loss = torch.sum(per_pair_loss) / (num_non_zero_pairs + 1e-16) | ||
else: | ||
loss = torch.mean(per_pair_loss) | ||
return loss, num_non_zero_pairs | ||
|
||
def pos_calc(self, pos_pair_dist, margin): | ||
return ( | ||
torch.nn.functional.relu(margin - pos_pair_dist) | ||
if self.use_similarity | ||
else torch.nn.functional.relu(pos_pair_dist - margin) | ||
) | ||
|
||
def neg_calc(self, neg_pair_dist, margin): | ||
return ( | ||
torch.nn.functional.relu(neg_pair_dist - margin) | ||
if self.use_similarity | ||
else torch.nn.functional.relu(margin - neg_pair_dist) | ||
) |
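
A minimal usage sketch of the similarity variant (margin values are illustrative, not recommendations; assumes the commit's utils module):

import torch
from pytorch_metric_learning.losses import ContrastiveLoss

loss_func = ContrastiveLoss(
    pos_margin=0.9,  # penalize positive pairs whose similarity falls below 0.9
    neg_margin=0.3,  # penalize negative pairs whose similarity rises above 0.3
    use_similarity=True,
)
embeddings = torch.randn(16, 64)
labels = torch.randint(0, 4, (16,))
loss = loss_func(embeddings, labels)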
#! /usr/bin/env python3 | ||
|
||
|
||
import torch | ||
from ..utils import loss_and_miner_utils as lmu | ||
from . import base_metric_loss_function as bmlf | ||
|
||
|
||
class GenericPairLoss(bmlf.BaseMetricLossFunction): | ||
""" | ||
The function pair_based_loss has to be implemented by the child class. | ||
By default, this class extracts every positive and negative pair within a | ||
batch (based on labels) and passes the pairs to the loss function. | ||
The pairs can be passed to the loss function all at once (self.loss_once) | ||
or pairs can be passed iteratively (self.loss_loop) by going through each | ||
sample in a batch, and selecting just the positive and negative pairs | ||
containing that sample. | ||
Args: | ||
use_similarity: set to True if the loss function uses pairwise similarity | ||
(dot product of each embedding pair). Otherwise, | ||
euclidean distance will be used | ||
iterate_through_loss: set to True to use self.loss_loop and False otherwise | ||
squared_distances: if True, then the euclidean distance will be squared. | ||
""" | ||
|
||
def __init__( | ||
self, use_similarity, iterate_through_loss, squared_distances=False, **kwargs | ||
): | ||
self.use_similarity = use_similarity | ||
self.squared_distances = squared_distances | ||
self.loss_method = self.loss_loop if iterate_through_loss else self.loss_once | ||
super().__init__(**kwargs) | ||
|
||
def compute_loss(self, embeddings, labels, indices_tuple): | ||
mat = ( | ||
lmu.sim_mat(embeddings) | ||
if self.use_similarity | ||
else lmu.dist_mat(embeddings, squared=self.squared_distances) | ||
) | ||
indices_tuple = lmu.convert_to_pairs(indices_tuple, labels) | ||
return self.loss_method(mat, labels, indices_tuple) | ||
|
||
def pair_based_loss( | ||
self, pos_pairs, neg_pairs, pos_pair_anchor_labels, neg_pair_anchor_labels | ||
): | ||
raise NotImplementedError | ||
|
||
def loss_loop(self, mat, labels, indices_tuple): | ||
loss = torch.tensor(0.0).to(mat.device) | ||
n = 0 | ||
(a1_indices, p_indices, a2_indices, n_indices) = indices_tuple | ||
for i in range(mat.size(0)): | ||
pos_pair, neg_pair = [], [] | ||
if len(a1_indices) > 0: | ||
p_idx = a1_indices == i | ||
pos_pair = mat[a1_indices[p_idx], p_indices[p_idx]] | ||
if len(a2_indices) > 0: | ||
n_idx = a2_indices == i | ||
neg_pair = mat[a2_indices[n_idx], n_indices[n_idx]] | ||
loss += self.pair_based_loss( | ||
pos_pair, neg_pair, labels[a1_indices[p_idx]], labels[a2_indices[n_idx]] | ||
) | ||
n += 1 | ||
return loss / (n if n > 0 else 1) | ||
|
||
def loss_once(self, mat, labels, indices_tuple): | ||
(a1_indices, p_indices, a2_indices, n_indices) = indices_tuple | ||
pos_pair, neg_pair = [], [] | ||
if len(a1_indices) > 0: | ||
pos_pair = mat[a1_indices, p_indices] | ||
if len(a2_indices) > 0: | ||
neg_pair = mat[a2_indices, n_indices] | ||
return self.pair_based_loss( | ||
pos_pair, neg_pair, labels[a1_indices], labels[a2_indices] | ||
) |
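
A hedged sketch of passing indices_tuple explicitly, using ContrastiveLoss (defined earlier in this commit) as the concrete subclass. It assumes lmu.convert_to_pairs passes an already-formed 4-tuple through unchanged, which is not shown in this section; normally a miner would produce these indices.

import torch
from pytorch_metric_learning.losses import ContrastiveLoss

embeddings = torch.randn(8, 32)
labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])

# (anchor1, positives, anchor2, negatives): positive pairs (0,1) and (2,3),
# negative pairs (0,2) and (1,4). Indices are made up for illustration.
indices_tuple = (
    torch.tensor([0, 2]), torch.tensor([1, 3]),
    torch.tensor([0, 1]), torch.tensor([2, 4]),
)
loss = ContrastiveLoss()(embeddings, labels, indices_tuple)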
#! /usr/bin/env python3 | ||
|
||
import torch | ||
|
||
from . import generic_pair_loss as gpl | ||
|
||
|
||
class GeneralizedLiftedStructureLoss(gpl.GenericPairLoss): | ||
# The 'generalized' lifted structure loss shown on page 4 | ||
# of the "in defense of triplet loss" paper | ||
# https://arxiv.org/pdf/1703.07737.pdf | ||
def __init__(self, neg_margin, **kwargs): | ||
self.neg_margin = neg_margin | ||
super().__init__(use_similarity=False, iterate_through_loss=True, **kwargs) | ||
|
||
def pair_based_loss(self, pos_pairs, neg_pairs, pos_pair_anchor_labels, neg_pair_anchor_labels): | ||
neg_margin = self.maybe_mask_param(self.neg_margin, neg_pair_anchor_labels) | ||
per_anchor = torch.logsumexp(pos_pairs, dim=0) + torch.logsumexp(neg_margin - neg_pairs, dim=0) | ||
hinged = torch.relu(per_anchor) | ||
return torch.mean(hinged) |
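
Per anchor, this is logsumexp over positive-pair distances plus logsumexp over (neg_margin minus negative-pair distances), hinged at zero: the smooth-max relaxation from the paper. A minimal usage sketch (values are illustrative):

import torch
from pytorch_metric_learning.losses import GeneralizedLiftedStructureLoss

loss_func = GeneralizedLiftedStructureLoss(neg_margin=1)  # neg_margin is a required argument
embeddings = torch.randn(16, 64)
labels = torch.randint(0, 4, (16,))
loss = loss_func(embeddings, labels)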