From e95c29056d12c351927158ff63d380e79ae83910 Mon Sep 17 00:00:00 2001 From: TakeshiMusgrave Date: Thu, 24 Oct 2019 16:29:51 -0400 Subject: [PATCH] adding files --- .gitignore | 6 + LICENSE | 21 ++ losses/__init__.py | 10 + losses/angular_loss.py | 78 ++++++ losses/base_metric_loss_function.py | 82 +++++++ losses/contrastive_loss.py | 79 +++++++ losses/generic_pair_loss.py | 75 ++++++ losses/lifted_structure_loss.py | 20 ++ losses/margin_loss.py | 46 ++++ losses/multi_similarity_loss.py | 37 +++ losses/n_pairs_loss.py | 35 +++ losses/nca_loss.py | 32 +++ losses/proxy_losses.py | 18 ++ losses/triplet_margin_loss.py | 68 ++++++ miners/__init__.py | 7 + miners/base_miner.py | 101 ++++++++ miners/distance_weighted_miner.py | 45 ++++ ...embeddings_already_packaged_as_triplets.py | 18 ++ miners/hdc_miner.py | 90 +++++++ miners/maximum_loss_miner.py | 29 +++ miners/multi_similarity_miner.py | 50 ++++ miners/pair_margin_miner.py | 49 ++++ samplers/__init__.py | 2 + samplers/fixed_set_of_triplets.py | 44 ++++ samplers/m_per_class_sampler.py | 45 ++++ setup.py | 22 ++ trainers/__init__.py | 4 + trainers/base_trainer.py | 222 ++++++++++++++++++ trainers/cascaded_embeddings.py | 29 +++ trainers/deep_adversarial_metric_learning.py | 141 +++++++++++ trainers/metric_loss_only.py | 21 ++ trainers/train_with_classifier.py | 33 +++ utils/__init__.py | 0 utils/common_functions.py | 221 +++++++++++++++++ utils/loss_and_miner_utils.py | 159 +++++++++++++ utils/loss_tracker.py | 33 +++ utils/misc_models.py | 23 ++ utils/record_keeper.py | 105 +++++++++ 38 files changed, 2100 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 losses/__init__.py create mode 100644 losses/angular_loss.py create mode 100644 losses/base_metric_loss_function.py create mode 100644 losses/contrastive_loss.py create mode 100644 losses/generic_pair_loss.py create mode 100644 losses/lifted_structure_loss.py create mode 100644 losses/margin_loss.py create mode 100644 losses/multi_similarity_loss.py create mode 100644 losses/n_pairs_loss.py create mode 100644 losses/nca_loss.py create mode 100644 losses/proxy_losses.py create mode 100644 losses/triplet_margin_loss.py create mode 100644 miners/__init__.py create mode 100644 miners/base_miner.py create mode 100644 miners/distance_weighted_miner.py create mode 100644 miners/embeddings_already_packaged_as_triplets.py create mode 100644 miners/hdc_miner.py create mode 100644 miners/maximum_loss_miner.py create mode 100644 miners/multi_similarity_miner.py create mode 100644 miners/pair_margin_miner.py create mode 100644 samplers/__init__.py create mode 100644 samplers/fixed_set_of_triplets.py create mode 100644 samplers/m_per_class_sampler.py create mode 100644 setup.py create mode 100644 trainers/__init__.py create mode 100644 trainers/base_trainer.py create mode 100644 trainers/cascaded_embeddings.py create mode 100644 trainers/deep_adversarial_metric_learning.py create mode 100644 trainers/metric_loss_only.py create mode 100644 trainers/train_with_classifier.py create mode 100644 utils/__init__.py create mode 100644 utils/common_functions.py create mode 100644 utils/loss_and_miner_utils.py create mode 100644 utils/loss_tracker.py create mode 100644 utils/misc_models.py create mode 100644 utils/record_keeper.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..39801679 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +__pycache__/ +*.py[cod] +.nfs* +build/ +dist/ +pytorch_metric_learning.egg-info/ \ No newline at end of file diff 
--git a/LICENSE b/LICENSE new file mode 100644 index 00000000..56d56cb3 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Kevin Musgrave + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/losses/__init__.py b/losses/__init__.py new file mode 100644 index 00000000..3af687bb --- /dev/null +++ b/losses/__init__.py @@ -0,0 +1,10 @@ +from .angular_loss import AngularLoss, AngularNPairsLoss +from .contrastive_loss import ContrastiveLoss +from .generic_pair_loss import GenericPairLoss +from .lifted_structure_loss import GeneralizedLiftedStructureLoss +from .margin_loss import MarginLoss +from .multi_similarity_loss import MultiSimilarityLoss +from .nca_loss import NCALoss +from .n_pairs_loss import NPairsLoss +from .proxy_losses import ProxyNCALoss +from .triplet_margin_loss import TripletMarginLoss \ No newline at end of file diff --git a/losses/angular_loss.py b/losses/angular_loss.py new file mode 100644 index 00000000..44d8afe9 --- /dev/null +++ b/losses/angular_loss.py @@ -0,0 +1,78 @@ +#! /usr/bin/env python3 + +from . import base_metric_loss_function as bmlf +import numpy as np +import torch +from ..utils import loss_and_miner_utils as lmu + + +class AngularNPairsLoss(bmlf.BaseMetricLossFunction): + """ + Implementation of https://arxiv.org/abs/1708.01682 + Args: + alpha: The angle (as described in the paper), specified in degrees. 
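# A minimal usage sketch for AngularNPairsLoss as defined in this file (not part of
# the patch itself). It assumes the losses package is importable as written below and
# that passing indices_tuple=None makes loss_and_miner_utils fall back to forming all
# positive pairs with unique labels from `labels`.
import torch
from losses import AngularNPairsLoss  # import path is an assumption

embeddings = torch.randn(32, 128)        # (batch_size, embedding_size)
labels = torch.arange(8).repeat(4)       # 8 classes, 4 samples each
loss_func = AngularNPairsLoss(alpha=40)  # alpha is specified in degrees
loss = loss_func(embeddings, labels)     # indices_tuple defaults to None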
+ """ + def __init__(self, alpha, **kwargs): + self.alpha = torch.tensor(np.radians(alpha)) + self.maybe_modify_loss = lambda x: x + self.num_anchors = 0 + self.avg_embedding_norm = 0 + self.record_these = ["num_anchors", "avg_embedding_norm"] + super().__init__(**kwargs) + + def compute_loss(self, embeddings, labels, indices_tuple): + self.avg_embedding_norm = torch.mean(torch.norm(embeddings, p=2, dim=1)) + anchor_idx, positive_idx = lmu.convert_to_pos_pairs_with_unique_labels(indices_tuple, labels) + self.num_anchors = len(anchor_idx) + if self.num_anchors == 0: + return 0 + + anchors, positives = embeddings[anchor_idx], embeddings[positive_idx] + points_per_anchor = (self.num_anchors - 1) * 2 + + alpha = self.maybe_mask_param(self.alpha, labels[anchor_idx]) + sq_tan_alpha = torch.tan(alpha) ** 2 + + xa_xp = torch.sum(anchors * positives, dim=1, keepdim=True) + term2_multiplier = 2 * (1 + sq_tan_alpha) + term2 = term2_multiplier * xa_xp + + a_p_summed = anchors + positives + inside_exp = [] + mask = torch.ones(self.num_anchors).to(embeddings.device) - torch.eye(self.num_anchors).to(embeddings.device) + term1_multiplier = 4 * sq_tan_alpha + + for negatives in [anchors, positives]: + term1 = term1_multiplier * torch.matmul(a_p_summed, torch.t(negatives)) + inside_exp.append(term1 - term2.repeat(1, self.num_anchors)) + inside_exp[-1] = inside_exp[-1] * mask + + inside_exp_final = torch.zeros((self.num_anchors, points_per_anchor + 1)).to(embeddings.device) + + for i in range(self.num_anchors): + indices = np.concatenate((np.arange(0, i), np.arange(i + 1, inside_exp[0].size(1)))) + inside_exp_final[i, : points_per_anchor // 2] = inside_exp[0][i, indices] + inside_exp_final[:, points_per_anchor // 2 :] = inside_exp[1] + inside_exp_final = self.maybe_modify_loss(inside_exp_final) + + return torch.mean(torch.logsumexp(inside_exp_final, dim=1)) + + def create_learnable_parameter(self, init_value): + return super().create_learnable_parameter(init_value, unsqueeze=True) + + +class AngularLoss(AngularNPairsLoss): + def compute_loss(self, embeddings, labels, indices_tuple): + self.avg_embedding_norm = torch.mean(torch.norm(embeddings, p=2, dim=1)) + anchor_idx, positive_idx, negative_idx = lmu.convert_to_triplets(indices_tuple, labels) + self.num_anchors = len(anchor_idx) + if self.num_anchors == 0: + return 0 + anchors, positives, negatives = embeddings[anchor_idx], embeddings[positive_idx], embeddings[negative_idx] + alpha = self.maybe_mask_param(self.alpha, labels[anchor_idx]) + sq_tan_alpha = torch.tan(alpha) ** 2 + term1 = 4 * sq_tan_alpha * torch.sum((anchors + positives) * negatives, dim=1, keepdim=True) + term2 = 2 * (1 + sq_tan_alpha) * torch.sum(anchors * positives, dim=1, keepdim=True) + final_form = torch.cat([term1 - term2, torch.zeros(term1.size(0), 1).to(embeddings.device)], dim=1) + final_form = self.maybe_modify_loss(final_form) + return torch.mean(torch.logsumexp(final_form, dim=1)) diff --git a/losses/base_metric_loss_function.py b/losses/base_metric_loss_function.py new file mode 100644 index 00000000..7efc90fe --- /dev/null +++ b/losses/base_metric_loss_function.py @@ -0,0 +1,82 @@ +#! /usr/bin/env python3 + +import torch + + +class BaseMetricLossFunction(torch.nn.Module): + """ + All loss functions extend this class + Args: + normalize_embeddings: type boolean. If True then normalize embeddins + to have norm = 1 before computing the loss + num_class_per_param: type int. The number of classes for each parameter. 
+ If your parameters don't have a separate value for each class, + then leave this at None + learnable_param_names: type list of strings. Each element is the name of + attributes that should be converted to nn.Parameter + """ + def __init__( + self, + normalize_embeddings=True, + num_class_per_param=None, + learnable_param_names=None + ): + super().__init__() + self.normalize_embeddings = normalize_embeddings + self.num_class_per_param = num_class_per_param + self.learnable_param_names = learnable_param_names + self.initialize_learnable_parameters() + + def compute_loss(self, embeddings, labels, indices_tuple=None): + """ + This has to be implemented and is what actually computes the loss. + """ + raise NotImplementedError + + def forward(self, embeddings, labels, indices_tuple=None): + """ + Args: + embeddings: tensor of size (batch_size, embedding_size) + labels: tensor of size (batch_size) + indices_tuple: tuple of size 3 for triplets (anchors, positives, negatives) + or size 4 for pairs (anchor1, postives, anchor2, negatives) + Can also be left as None + Returns: the loss (float) + """ + labels = labels.to(embeddings.device) + if self.normalize_embeddings: + embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) + loss = self.compute_loss(embeddings, labels, indices_tuple) + return loss + + def initialize_learnable_parameters(self): + """ + To learn hyperparams, create an attribute called learnable_param_names. + This should be a list of strings which are the names of the + hyperparameters to be learned + """ + if self.learnable_param_names is not None: + for k in self.learnable_param_names: + v = getattr(self, k) + setattr(self, k, self.create_learnable_parameter(v)) + + def create_learnable_parameter(self, init_value, unsqueeze=False): + """ + Returns nn.Parameter with an initial value of init_value + and size of num_labels + """ + vec_len = self.num_class_per_param if self.num_class_per_param else 1 + if unsqueeze: + vec_len = (vec_len, 1) + p = torch.nn.Parameter(torch.ones(vec_len) * init_value) + return p + + def maybe_mask_param(self, param, labels): + """ + This returns the hyperparameters corresponding to class labels (if applicable). + If there is a hyperparameter for each class, then when computing the loss, + the class hyperparameter has to be matched to the corresponding embedding. + """ + if self.num_class_per_param: + return param[labels] + return param diff --git a/losses/contrastive_loss.py b/losses/contrastive_loss.py new file mode 100644 index 00000000..baa2bee0 --- /dev/null +++ b/losses/contrastive_loss.py @@ -0,0 +1,79 @@ +#! /usr/bin/env python3 + +import torch + +from . import generic_pair_loss as gpl + + +class ContrastiveLoss(gpl.GenericPairLoss): + """ + Contrastive loss using either distance or similarity. + Args: + pos_margin: The distance (or similarity) over (under) which positive pairs will contribute to the loss. + neg_margin: The distance (or similarity) under (over) which negative pairs will contribute to the loss. + use_similarity: If True, will use dot product between vectors instead of euclidean distance + power: Each pair's loss will be raised to this power. + avg_non_zero_only: Only pairs that contribute non-zero loss will be used in the final loss. 
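# A minimal usage sketch for ContrastiveLoss (not part of the patch). With
# indices_tuple left as None, the loss is assumed to fall back to every positive and
# negative pair in the batch; the margins below are the constructor defaults.
import torch
from losses import ContrastiveLoss  # import path is an assumption

embeddings = torch.randn(32, 128)
labels = torch.arange(8).repeat(4)
loss_func = ContrastiveLoss(pos_margin=0, neg_margin=1)  # distance-based by default
loss = loss_func(embeddings, labels)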
+ """ + def __init__( + self, + pos_margin=0, + neg_margin=1, + use_similarity=False, + power=1, + avg_non_zero_only=True, + **kwargs + ): + self.pos_margin = pos_margin + self.neg_margin = neg_margin + self.avg_non_zero_only = avg_non_zero_only + self.num_non_zero_pos_pairs = 0 + self.num_non_zero_neg_pairs = 0 + self.record_these = ["num_non_zero_pos_pairs", "num_non_zero_neg_pairs"] + self.power = power + super().__init__(use_similarity=use_similarity, iterate_through_loss=False, **kwargs) + + def pair_based_loss( + self, + pos_pair_dist, + neg_pair_dist, + pos_pair_anchor_labels, + neg_pair_anchor_labels, + ): + pos_loss, neg_loss = 0, 0 + self.num_non_zero_pos_pairs, self.num_non_zero_neg_pairs = 0, 0 + if len(pos_pair_dist) > 0: + pos_loss, self.num_non_zero_pos_pairs = self.mask_margin_and_calculate_loss( + pos_pair_dist, pos_pair_anchor_labels, "pos" + ) + if len(neg_pair_dist) > 0: + neg_loss, self.num_non_zero_neg_pairs = self.mask_margin_and_calculate_loss( + neg_pair_dist, neg_pair_anchor_labels, "neg" + ) + return pos_loss + neg_loss + + def mask_margin_and_calculate_loss(self, pair_dists, labels, pos_or_neg): + loss_calc_func = self.pos_calc if pos_or_neg == "pos" else self.neg_calc + input_margin = self.pos_margin if pos_or_neg == "pos" else self.neg_margin + margin = self.maybe_mask_param(input_margin, labels) + per_pair_loss = loss_calc_func(pair_dists, margin) ** self.power + num_non_zero_pairs = (per_pair_loss > 0).nonzero().size(0) + if self.avg_non_zero_only: + loss = torch.sum(per_pair_loss) / (num_non_zero_pairs + 1e-16) + else: + loss = torch.mean(per_pair_loss) + return loss, num_non_zero_pairs + + def pos_calc(self, pos_pair_dist, margin): + return ( + torch.nn.functional.relu(margin - pos_pair_dist) + if self.use_similarity + else torch.nn.functional.relu(pos_pair_dist - margin) + ) + + def neg_calc(self, neg_pair_dist, margin): + return ( + torch.nn.functional.relu(neg_pair_dist - margin) + if self.use_similarity + else torch.nn.functional.relu(margin - neg_pair_dist) + ) diff --git a/losses/generic_pair_loss.py b/losses/generic_pair_loss.py new file mode 100644 index 00000000..95f0ff1e --- /dev/null +++ b/losses/generic_pair_loss.py @@ -0,0 +1,75 @@ +#! /usr/bin/env python3 + + +import torch +from ..utils import loss_and_miner_utils as lmu +from . import base_metric_loss_function as bmlf + + +class GenericPairLoss(bmlf.BaseMetricLossFunction): + """ + The function pair_based_loss has to be implemented by the child class. + By default, this class extracts every positive and negative pair within a + batch (based on labels) and passes the pairs to the loss function. + The pairs can be passed to the loss function all at once (self.loss_once) + or pairs can be passed iteratively (self.loss_loop) by going through each + sample in a batch, and selecting just the positive and negative pairs + containing that sample. + Args: + use_similarity: set to True if the loss function uses pairwise similarity + (dot product of each embedding pair). Otherwise, + euclidean distance will be used + iterate_through_loss: set to True to use self.loss_loop and False otherwise + squared_distances: if True, then the euclidean distance will be squared. 
+ """ + + def __init__( + self, use_similarity, iterate_through_loss, squared_distances=False, **kwargs + ): + self.use_similarity = use_similarity + self.squared_distances = squared_distances + self.loss_method = self.loss_loop if iterate_through_loss else self.loss_once + super().__init__(**kwargs) + + def compute_loss(self, embeddings, labels, indices_tuple): + mat = ( + lmu.sim_mat(embeddings) + if self.use_similarity + else lmu.dist_mat(embeddings, squared=self.squared_distances) + ) + indices_tuple = lmu.convert_to_pairs(indices_tuple, labels) + return self.loss_method(mat, labels, indices_tuple) + + def pair_based_loss( + self, pos_pairs, neg_pairs, pos_pair_anchor_labels, neg_pair_anchor_labels + ): + raise NotImplementedError + + def loss_loop(self, mat, labels, indices_tuple): + loss = torch.tensor(0.0).to(mat.device) + n = 0 + (a1_indices, p_indices, a2_indices, n_indices) = indices_tuple + for i in range(mat.size(0)): + pos_pair, neg_pair = [], [] + if len(a1_indices) > 0: + p_idx = a1_indices == i + pos_pair = mat[a1_indices[p_idx], p_indices[p_idx]] + if len(a2_indices) > 0: + n_idx = a2_indices == i + neg_pair = mat[a2_indices[n_idx], n_indices[n_idx]] + loss += self.pair_based_loss( + pos_pair, neg_pair, labels[a1_indices[p_idx]], labels[a2_indices[n_idx]] + ) + n += 1 + return loss / (n if n > 0 else 1) + + def loss_once(self, mat, labels, indices_tuple): + (a1_indices, p_indices, a2_indices, n_indices) = indices_tuple + pos_pair, neg_pair = [], [] + if len(a1_indices) > 0: + pos_pair = mat[a1_indices, p_indices] + if len(a2_indices) > 0: + neg_pair = mat[a2_indices, n_indices] + return self.pair_based_loss( + pos_pair, neg_pair, labels[a1_indices], labels[a2_indices] + ) diff --git a/losses/lifted_structure_loss.py b/losses/lifted_structure_loss.py new file mode 100644 index 00000000..594fb6b0 --- /dev/null +++ b/losses/lifted_structure_loss.py @@ -0,0 +1,20 @@ +#! /usr/bin/env python3 + +import torch + +from . import generic_pair_loss as gpl + + +class GeneralizedLiftedStructureLoss(gpl.GenericPairLoss): + # The 'generalized' lifted structure loss shown on page 4 + # of the "in defense of triplet loss" paper + # https://arxiv.org/pdf/1703.07737.pdf + def __init__(self, neg_margin, **kwargs): + self.neg_margin = neg_margin + super().__init__(use_similarity=False, iterate_through_loss=True, **kwargs) + + def pair_based_loss(self, pos_pairs, neg_pairs, pos_pair_anchor_labels, neg_pair_anchor_labels): + neg_margin = self.maybe_mask_param(self.neg_margin, neg_pair_anchor_labels) + per_anchor = torch.logsumexp(pos_pairs, dim=0) + torch.logsumexp(neg_margin - neg_pairs, dim=0) + hinged = torch.relu(per_anchor) + return torch.mean(hinged) diff --git a/losses/margin_loss.py b/losses/margin_loss.py new file mode 100644 index 00000000..e048dcf1 --- /dev/null +++ b/losses/margin_loss.py @@ -0,0 +1,46 @@ +#! /usr/bin/env python3 + +from . 
import base_metric_loss_function as bmlf +import torch +from ..utils import loss_and_miner_utils as lmu, common_functions as c_f + + +class MarginLoss(bmlf.BaseMetricLossFunction): + + def __init__(self, margin, nu, beta, **kwargs): + self.margin = margin + self.nu = nu + self.beta = beta + self.num_pos_pairs = 0 + self.num_neg_pairs = 0 + self.record_these = ["num_pos_pairs", "num_neg_pairs"] + super().__init__(**kwargs) + + def compute_loss(self, embeddings, labels, indices_tuple): + anchor_idx, positive_idx, negative_idx = lmu.convert_to_triplets(indices_tuple, labels) + if len(anchor_idx) == 0: + self.num_pos_pairs = 0 + self.num_neg_pairs = 0 + return 0 + anchors, positives, negatives = embeddings[anchor_idx], embeddings[positive_idx], embeddings[negative_idx] + beta = self.maybe_mask_param(self.beta, labels[anchor_idx]) + beta_reg_loss = self.compute_reg_loss(beta) + + d_ap = torch.nn.functional.pairwise_distance(positives, anchors, p=2) + d_an = torch.nn.functional.pairwise_distance(negatives, anchors, p=2) + + pos_loss = torch.nn.functional.relu(d_ap - beta + self.margin) + neg_loss = torch.nn.functional.relu(beta - d_an + self.margin) + + self.num_pos_pairs = (pos_loss > 0.0).nonzero().size(0) + self.num_neg_pairs = (neg_loss > 0.0).nonzero().size(0) + + pair_count = self.num_pos_pairs + self.num_neg_pairs + + return (torch.sum(pos_loss + neg_loss) + beta_reg_loss) / (pair_count + 1e-16) + + def compute_reg_loss(self, beta): + if self.nu > 0: + beta_mean = c_f.try_torch_operation(torch.mean, beta) + return beta_mean * self.nu + return 0 \ No newline at end of file diff --git a/losses/multi_similarity_loss.py b/losses/multi_similarity_loss.py new file mode 100644 index 00000000..140c711d --- /dev/null +++ b/losses/multi_similarity_loss.py @@ -0,0 +1,37 @@ +#! /usr/bin/env python3 + +import torch + +from . import generic_pair_loss as gpl +from ..utils import common_functions as c_f + +class MultiSimilarityLoss(gpl.GenericPairLoss): + """ + modified from https://github.com/MalongTech/research-ms-loss/ + Args: + alpha: The exponential weight for positive pairs + beta: The exponential weight for negative pairs + base: The shift in the exponent applied to both positive and negative pairs + """ + def __init__(self, alpha, beta, base=0.5, **kwargs): + self.alpha = alpha + self.beta = beta + self.base = base + super().__init__(use_similarity=True, iterate_through_loss=True, **kwargs) + + def pair_based_loss( + self, pos_pairs, neg_pairs, pos_pair_anchor_labels, neg_pair_anchor_labels + ): + pos_loss, neg_loss = 0, 0 + if len(pos_pairs) > 0: + alpha = self.maybe_mask_param(self.alpha, pos_pair_anchor_labels) + pos_loss = self.exp_loss(pos_pairs, -alpha, 1.0/alpha) + if len(neg_pairs) > 0: + beta = self.maybe_mask_param(self.beta, neg_pair_anchor_labels) + neg_loss = self.exp_loss(neg_pairs, beta, 1.0/beta) + return pos_loss + neg_loss + + def exp_loss(self, pair, exp_weight, scaler): + scaler = c_f.try_torch_operation(torch.mean, scaler) + inside_exp = exp_weight * (pair - self.base) + return scaler * torch.log(1 + torch.sum(torch.exp(inside_exp))) diff --git a/losses/n_pairs_loss.py b/losses/n_pairs_loss.py new file mode 100644 index 00000000..51f9e61a --- /dev/null +++ b/losses/n_pairs_loss.py @@ -0,0 +1,35 @@ +#! /usr/bin/env python3 + +from . 
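# A usage sketch for MarginLoss (not part of the patch), showing the learnable_param_names
# mechanism from BaseMetricLossFunction: listing "beta" turns that attribute into an
# nn.Parameter, so it can be optimized alongside the model. Hyperparameter values are
# illustrative; with indices_tuple=None, triplets are assumed to be formed from the labels.
import torch
from losses import MarginLoss  # import path is an assumption

loss_func = MarginLoss(margin=0.2, nu=0.1, beta=1.2, learnable_param_names=["beta"])
loss_optimizer = torch.optim.SGD(loss_func.parameters(), lr=1e-3)  # will update beta

embeddings = torch.randn(64, 128)
labels = torch.arange(8).repeat(8)
loss = loss_func(embeddings, labels)
loss.backward()
loss_optimizer.step()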
import base_metric_loss_function as bmlf +import torch +from ..utils import loss_and_miner_utils as lmu + + +class NPairsLoss(bmlf.BaseMetricLossFunction): + """ + Implementation of https://arxiv.org/abs/1708.01682 + Args: + l2_reg_weight: The L2 regularizer weight (multiplier) + """ + def __init__(self, l2_reg_weight=0, **kwargs): + self.cross_entropy = torch.nn.CrossEntropyLoss() + self.l2_reg_weight = l2_reg_weight + self.num_pairs = 0 + self.avg_embedding_norm = 0 + self.record_these = ["num_pairs", "avg_embedding_norm"] + super().__init__(**kwargs) + + def compute_loss(self, embeddings, labels, indices_tuple): + self.avg_embedding_norm = torch.mean(torch.norm(embeddings, p=2, dim=1)) + anchor_idx, positive_idx = lmu.convert_to_pos_pairs_with_unique_labels(indices_tuple, labels) + self.num_pairs = len(anchor_idx) + if self.num_pairs == 0: + return 0 + anchors, positives = embeddings[anchor_idx], embeddings[positive_idx] + targets = torch.arange(self.num_pairs).to(embeddings.device) + sim_mat = torch.matmul(anchors, positives.t()) + s_loss = self.cross_entropy(sim_mat, targets) + if self.l2_reg_weight > 0: + l2_reg = torch.mean(torch.norm(embeddings, p=2, dim=1)) + return s_loss + l2_reg * self.l2_reg_weight + return s_loss diff --git a/losses/nca_loss.py b/losses/nca_loss.py new file mode 100644 index 00000000..befd4bbb --- /dev/null +++ b/losses/nca_loss.py @@ -0,0 +1,32 @@ +#! /usr/bin/env python3 + +from . import base_metric_loss_function as bmlf +from ..utils import loss_and_miner_utils as lmu +import torch + + +class NCALoss(bmlf.BaseMetricLossFunction): + # modified from https://github.com/microsoft/snca.pytorch/blob/master/lib/NCA.py + # https://www.cs.toronto.edu/~hinton/absps/nca.pdf + def compute_loss(self, embeddings, labels, *_): + return self.nca_computation(embeddings, embeddings, labels, labels) + + def nca_computation(self, query, reference, query_labels, reference_labels): + query_batch_size = len(query) + reference_batch_size = len(reference) + x = lmu.dist_mat(query, reference, squared=True) + exp = torch.exp(-x) + + if query is reference: + exp = exp - torch.diag(exp.diag()) + repeated_labels = query_labels.view(query_batch_size, 1).repeat(1, reference_batch_size) + same_labels = (repeated_labels == reference_labels).float() + + # sum over all positive neighbors of each anchor + p = torch.sum(exp * same_labels, dim=1) + # sum over all neighbors of each anchor (excluding the anchor) + Z = torch.sum(exp * (1 - same_labels), dim=1) + prob = p / Z + non_zero_prob = torch.masked_select(prob, prob != 0) + + return -torch.mean(torch.log(non_zero_prob)) \ No newline at end of file diff --git a/losses/proxy_losses.py b/losses/proxy_losses.py new file mode 100644 index 00000000..783f7be9 --- /dev/null +++ b/losses/proxy_losses.py @@ -0,0 +1,18 @@ +#! /usr/bin/env python3 + +from . 
import nca_loss as nca +import torch +from ..utils import loss_and_miner_utils as lmu + +class ProxyNCALoss(nca.NCALoss): + def __init__(self, num_labels, embedding_size, **kwargs): + self.proxies = torch.nn.Parameter(torch.randn(num_labels, embedding_size)) + self.proxy_labels = torch.arange(num_labels) + super().__init__(**kwargs) + + def compute_loss(self, embeddings, labels, *_): + if self.normalize_embeddings: + prox = torch.nn.functional.normalize(self.proxies, p=2, dim=1) + else: + prox = self.proxies + return self.nca_computation(embeddings, prox, labels, self.proxy_labels.to(labels.device)) \ No newline at end of file diff --git a/losses/triplet_margin_loss.py b/losses/triplet_margin_loss.py new file mode 100644 index 00000000..585a78c8 --- /dev/null +++ b/losses/triplet_margin_loss.py @@ -0,0 +1,68 @@ +#! /usr/bin/env python3 + +from . import base_metric_loss_function as bmlf +import torch +import torch.nn.functional as F +from ..utils import loss_and_miner_utils as lmu + + +class TripletMarginLoss(bmlf.BaseMetricLossFunction): + """ + Args: + margin: The desired difference between the anchor-positive distance and the + anchor-negative distance. + distance_norm: The norm used when calculating distance between embeddings + power: Each pair's loss will be raised to this power. + swap: Use the positive-negative distance instead of anchor-negative distance, + if it violates the margin more. + smooth_loss: Use the log-exp version of the triplet loss + avg_non_zero_only: Only pairs that contribute non-zero loss will be used in the final loss. + """ + def __init__( + self, + margin=0.05, + distance_norm=2, + power=1, + swap=False, + smooth_loss=False, + avg_non_zero_only=True, + **kwargs + ): + self.margin = margin + self.distance_norm = distance_norm + self.power = power + self.swap = swap + self.smooth_loss = smooth_loss + self.avg_non_zero_only = avg_non_zero_only + self.num_non_zero_triplets = 0 + self.record_these = ["num_non_zero_triplets"] + self.maybe_modify_loss = lambda x: x + super().__init__(**kwargs) + + def compute_loss(self, embeddings, labels, indices_tuple): + indices_tuple = lmu.convert_to_triplets(indices_tuple, labels) + anchor_idx, positive_idx, negative_idx = indices_tuple + if len(anchor_idx) == 0: + self.num_non_zero_triplets = 0 + return 0 + anchors, positives, negatives = embeddings[anchor_idx], embeddings[positive_idx], embeddings[negative_idx] + a_p_dist = F.pairwise_distance(anchors, positives, self.distance_norm) + a_n_dist = F.pairwise_distance(anchors, negatives, self.distance_norm) + if self.swap: + p_n_dist = F.pairwise_distance(positives, negatives, self.distance_norm) + a_n_dist = torch.min(a_n_dist, p_n_dist) + a_p_dist = a_p_dist ** self.power + a_n_dist = a_n_dist ** self.power + if self.smooth_loss: + inside_exp = a_p_dist - a_n_dist + inside_exp = self.maybe_modify_loss(inside_exp) + return torch.mean(torch.log(1 + torch.exp(inside_exp))) + else: + dist = a_p_dist - a_n_dist + margin = self.maybe_mask_param(self.margin, labels[anchor_idx]) + loss_modified = self.maybe_modify_loss(dist + margin) + relued = torch.nn.functional.relu(loss_modified) + self.num_non_zero_triplets = (relued > 0).nonzero().size(0) + if self.avg_non_zero_only: + return torch.sum(relued) / (self.num_non_zero_triplets + 1e-16) + return torch.mean(relued) diff --git a/miners/__init__.py b/miners/__init__.py new file mode 100644 index 00000000..62511997 --- /dev/null +++ b/miners/__init__.py @@ -0,0 +1,7 @@ +from .distance_weighted_miner import DistanceWeightedMiner +from 
.embeddings_already_packaged_as_triplets import EmbeddingsAlreadyPackagedAsTriplets +from .hdc_miner import HDCMiner +from .maximum_loss_miner import MaximumLossMiner +from .multi_similarity_miner import MultiSimilarityMiner +from .pair_margin_miner import PairMarginMiner + diff --git a/miners/base_miner.py b/miners/base_miner.py new file mode 100644 index 00000000..c391c727 --- /dev/null +++ b/miners/base_miner.py @@ -0,0 +1,101 @@ +#! /usr/bin/env python3 + +import torch + + +class BaseMiner(torch.nn.Module): + def __init__(self, normalize_embeddings=True): + super().__init__() + self.normalize_embeddings = normalize_embeddings + + def mine(self, embeddings, labels): + """ + Args: + embeddings: tensor of size (batch_size, embedding_size) + labels: tensor of size (batch_size) + Returns: a tuple where each element is an array of indices. + """ + raise NotImplementedError + + def output_assertion(self, output): + raise NotImplementedError + + def forward(self, embeddings, labels): + """ + Args: + embeddings: tensor of size (batch_size, embedding_size) + labels: tensor of size (batch_size) + Does any necessary preprocessing, then does mining, and then checks the + shape of the mining output before returning it + """ + labels = labels.to(embeddings.device) + with torch.no_grad(): + if self.normalize_embeddings: + embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) + mining_output = self.mine(embeddings, labels) + self.output_assertion(mining_output) + return mining_output + +class BasePostGradientMiner(BaseMiner): + """ + A post-gradient miner is used after gradients have already been computed. + In other words, the composition of the batch has already been decided, + and the miner will find pairs or triplets within the batch that should + be used to compute the loss. + Args: + normalize_embeddings: type boolean, if True then normalize embeddings + to have norm = 1 before mining + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.num_pos_pairs = 0 + self.num_neg_pairs = 0 + self.num_triplets = 0 + record_these = ["num_pos_pairs", "num_neg_pairs", "num_triplets"] + if hasattr(self, "record_these"): + self.record_these += record_these + else: + self.record_these = record_these + + + def output_assertion(self, output): + """ + Args: + output: the output of self.mine + This asserts that the mining function is outputting + properly formatted indices. The default is to require a tuple representing + a,p,n indices or a1,p,a2,n indices within a batch of embeddings. + For example, a tuple of (anchors, positives, negatives) will be + (torch.tensor, torch.tensor, torch.tensor) + """ + if len(output) == 3: + self.num_triplets = len(output[0]) + assert self.num_triplets == len(output[1]) == len(output[2]) + elif len(output) == 4: + self.num_pos_pairs = len(output[0]) + self.num_neg_pairs = len(output[2]) + assert self.num_pos_pairs == len(output[1]) + assert self.num_neg_pairs == len(output[3]) + else: + raise BaseException + + +class BasePreGradientMiner(BaseMiner): + """ + A pre-gradient miner is used before gradients have been computed. + The miner finds a subset of the sampled batch for which gradients will + then need to be computed. + Args: + output_batch_size: type int. The size of the subset that the miner + will output. 
+ normalize_embeddings: type boolean, if True then normalize embeddings + to have norm = 1 before mining + """ + + def __init__(self, output_batch_size, **kwargs): + super().__init__(**kwargs) + self.output_batch_size = output_batch_size + + def output_assertion(self, output): + assert len(output) == self.output_batch_size \ No newline at end of file diff --git a/miners/distance_weighted_miner.py b/miners/distance_weighted_miner.py new file mode 100644 index 00000000..065e1305 --- /dev/null +++ b/miners/distance_weighted_miner.py @@ -0,0 +1,45 @@ +#! /usr/bin/env python3 + +from . import base_miner as b_m +import torch +from ..utils import loss_and_miner_utils as lmu + + +# adapted from +# https://github.com/chaoyuaw/incubator-mxnet/blob/master/example/gluon/ +# /embedding_learning/model.py +class DistanceWeightedMiner(b_m.BasePostGradientMiner): + def __init__(self, cutoff, nonzero_loss_cutoff, **kwargs): + super().__init__(**kwargs) + self.cutoff = cutoff + self.nonzero_loss_cutoff = nonzero_loss_cutoff + + def mine(self, embeddings, labels): + label_set = torch.unique(labels) + n, d = embeddings.size() + + dist_mat = lmu.dist_mat(embeddings) + dist_mat = dist_mat + torch.eye(dist_mat.size(0)).to(embeddings.device) + # so that we don't get log(0). We mask the diagonal out later anyway + # Cut off to avoid high variance. + dist_mat = torch.max(dist_mat, torch.tensor(self.cutoff).to(dist_mat.device)) + + # Subtract max(log(distance)) for stability. + # See the first equation from Section 4 of the paper + log_weights = (2.0 - float(d)) * torch.log(dist_mat) - ( + float(d - 3) / 2 + ) * torch.log(1.0 - 0.25 * (dist_mat ** 2.0)) + weights = torch.exp(log_weights - torch.max(log_weights)) + + # Sample only negative examples by setting weights of + # the same-class examples to 0. + mask = torch.ones(weights.size()).to(embeddings.device) + for i in label_set: + idx = (labels == i).nonzero() + mask[torch.meshgrid(idx.squeeze(1), idx.squeeze(1))] = 0 + + weights = weights * mask * ((dist_mat < self.nonzero_loss_cutoff).float()) + weights = weights / torch.sum(weights, dim=1, keepdim=True) + + np_weights = weights.cpu().numpy() + return lmu.get_random_triplet_indices(labels, weights=np_weights) diff --git a/miners/embeddings_already_packaged_as_triplets.py b/miners/embeddings_already_packaged_as_triplets.py new file mode 100644 index 00000000..9963342b --- /dev/null +++ b/miners/embeddings_already_packaged_as_triplets.py @@ -0,0 +1,18 @@ +#! /usr/bin/env python3 + +from . import base_miner as b_m +import torch + + +class EmbeddingsAlreadyPackagedAsTriplets(b_m.BasePostGradientMiner): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # If the embeddings are grouped by triplet, + # then use this miner to force the loss function to use the already-formed triplets + def mine(self, embeddings, labels): + batch_size = embeddings.size(0) + a = torch.arange(0, batch_size, 3) + p = torch.arange(1, batch_size, 3) + n = torch.arange(2, batch_size, 3) + return a, p, n diff --git a/miners/hdc_miner.py b/miners/hdc_miner.py new file mode 100644 index 00000000..f372f09d --- /dev/null +++ b/miners/hdc_miner.py @@ -0,0 +1,90 @@ +#! /usr/bin/env python3 +from . 
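# A usage sketch (not part of the patch) showing how a post-gradient miner feeds a loss
# through the indices_tuple argument. The cutoff values are those commonly used with
# distance-weighted sampling, not defaults defined in this file.
import torch
from losses import TripletMarginLoss      # import paths are assumptions
from miners import DistanceWeightedMiner

miner = DistanceWeightedMiner(cutoff=0.5, nonzero_loss_cutoff=1.4)
loss_func = TripletMarginLoss(margin=0.1)

embeddings = torch.randn(64, 128)
labels = torch.arange(8).repeat(8)                    # 8 classes, 8 samples each
indices_tuple = miner(embeddings, labels)             # (anchor, positive, negative) indices
loss = loss_func(embeddings, labels, indices_tuple)   # only the mined triplets contribute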
import base_miner as b_m +import torch +from ..utils import loss_and_miner_utils as lmu + + +# mining method used in Hard Aware Deeply Cascaded Embeddings +# https://arxiv.org/abs/1611.05720 +class HDCMiner(b_m.BasePostGradientMiner): + def __init__(self, filter_amounts, use_sim_mat=False, **kwargs): + super().__init__(**kwargs) + self.num_pos_pairs_per_round = torch.zeros(len(filter_amounts)) + self.num_neg_pairs_per_round = torch.zeros(len(filter_amounts)) + self.record_these = ["num_pos_pairs_per_round", "num_neg_pairs_per_round"] + self.filter_amounts = filter_amounts + self.use_sim_mat = use_sim_mat + self.i = 0 + self.reset_prev_idx() + + def mine(self, embeddings, labels): + if self.i == 0: + self.num_pos_pairs_per_round *= 0 + self.num_neg_pairs_per_round *= 0 + if self.use_sim_mat: + self.sim_mat = lmu.sim_mat(embeddings) + else: + self.dist_mat = lmu.dist_mat(embeddings, squared=False) + self.a1_idx, self.p_idx, self.a2_idx, self.n_idx = lmu.get_all_pairs_indices( + labels + ) + self.reset_prev_idx() + + self.maybe_set_to_prev() + curr_filter = self.filter_amounts[self.i] + if curr_filter != 1: + mat = self.sim_mat if self.use_sim_mat else self.dist_mat + pos_pair_ = mat[self.a1_idx, self.p_idx] + neg_pair_ = mat[self.a2_idx, self.n_idx] + + a1, p, a2, n = [], [], [], [] + + for name, v in {"pos": pos_pair_, "neg": neg_pair_}.items(): + num_pairs = len(v) + k = int(curr_filter * num_pairs) + largest = self.should_select_largest(name) + _, idx = torch.topk(v, k=k, largest=largest) + self.append_original_indices(name, idx, a1, p, a2, n) + + self.a1_idx = torch.cat(a1) + self.p_idx = torch.cat(p) + self.a2_idx = torch.cat(a2) + self.n_idx = torch.cat(n) + + self.num_pos_pairs_per_round[self.i] = len(self.a1_idx) + self.num_neg_pairs_per_round[self.i] = len(self.a2_idx) + self.set_prev_idx() + self.i = (self.i + 1) % len(self.filter_amounts) + return self.a1_idx, self.p_idx, self.a2_idx, self.n_idx + + def should_select_largest(self, name): + if self.use_sim_mat: + return False if name == "pos" else True + return True if name == "pos" else False + + def append_original_indices(self, name, idx, a1, p, a2, n): + if name == "pos": + a1.append(self.a1_idx[idx]) + p.append(self.p_idx[idx]) + else: + a2.append(self.a2_idx[idx]) + n.append(self.n_idx[idx]) + + def maybe_set_to_prev(self): + if self.prev_a1 is not None: + self.a1_idx = self.prev_a1 + self.p_idx = self.prev_p + self.a2_idx = self.prev_a2 + self.n_idx = self.prev_n + + def reset_prev_idx(self): + self.prev_a1 = None + self.prev_p = None + self.prev_a2 = None + self.prev_n = None + + def set_prev_idx(self, reset=False): + self.prev_a1 = self.a1_idx.clone() + self.prev_p = self.p_idx.clone() + self.prev_a2 = self.a2_idx.clone() + self.prev_n = self.n_idx.clone() diff --git a/miners/maximum_loss_miner.py b/miners/maximum_loss_miner.py new file mode 100644 index 00000000..bdbe4ed2 --- /dev/null +++ b/miners/maximum_loss_miner.py @@ -0,0 +1,29 @@ +#! /usr/bin/env python3 + + +from . 
import base_miner as b_m +from ..utils import loss_and_miner_utils as lmu +import numpy as np +import torch + +class MaximumLossMiner(b_m.BasePreGradientMiner): + def __init__(self, loss_function, mining_function=None, num_trials=5, **kwargs): + super().__init__(**kwargs) + self.loss_function = loss_function + self.mining_function = mining_function + self.num_trials = num_trials + + def mine(self, embeddings, labels): + losses = [] + rand_subset_idx = torch.randint(0, len(embeddings), size=(self.num_trials, self.output_batch_size)) + for i in range(self.num_trials): + curr_embeddings, curr_labels = embeddings[rand_subset_idx[i]], labels[rand_subset_idx[i]] + indices_tuple = self.inner_miner(curr_embeddings, curr_labels) + losses.append(self.loss_function(curr_embeddings, curr_labels, indices_tuple)) + max_loss_idx = np.argmax(losses) + return rand_subset_idx[max_loss_idx] + + def inner_miner(self, embeddings, labels): + if self.mining_function: + return self.mining_function(embeddings, labels) + return None \ No newline at end of file diff --git a/miners/multi_similarity_miner.py b/miners/multi_similarity_miner.py new file mode 100644 index 00000000..864a4ee2 --- /dev/null +++ b/miners/multi_similarity_miner.py @@ -0,0 +1,50 @@ +#! /usr/bin/env python3 + +from . import base_miner as b_m +from ..utils import loss_and_miner_utils as lmu +import torch + + +class MultiSimilarityMiner(b_m.BasePostGradientMiner): + def __init__(self, epsilon, **kwargs): + super().__init__(**kwargs) + self.epsilon = epsilon + + def mine(self, embeddings, labels): + self.n = embeddings.size(0) + self.index_list = torch.arange(self.n).to(embeddings.device) + self.sim_mat = lmu.sim_mat(embeddings) + return self.compute_indices(labels) + + def compute_indices(self, labels): + empty_tensor = torch.tensor([]).long().to(labels.device) + a1_idx, p_idx, a2_idx, n_idx = [empty_tensor], [empty_tensor], [empty_tensor], [empty_tensor] + for i in range(self.n): + pos_indices = ( + ((labels == labels[i]) & (self.index_list != i)).nonzero().flatten() + ) + neg_indices = (labels != labels[i]).nonzero().flatten() + + if pos_indices.size(0) == 0 or neg_indices.size(0) == 0: + continue + + pos_sorted, pos_sorted_idx = torch.sort(self.sim_mat[i][pos_indices]) + neg_sorted, neg_sorted_idx = torch.sort(self.sim_mat[i][neg_indices]) + neg_sorted_filtered_idx = ( + (neg_sorted + self.epsilon > pos_sorted[0]).nonzero().flatten() + ) + pos_sorted_filtered_idx = ( + (pos_sorted - self.epsilon < neg_sorted[-1]).nonzero().flatten() + ) + + pos_indices = pos_indices[pos_sorted_idx][pos_sorted_filtered_idx] + neg_indices = neg_indices[neg_sorted_idx][neg_sorted_filtered_idx] + + if len(pos_indices) > 0: + a1_idx.append(torch.ones_like(pos_indices) * i) + p_idx.append(pos_indices) + if len(neg_indices) > 0: + a2_idx.append(torch.ones_like(neg_indices) * i) + n_idx.append(neg_indices) + + return [torch.cat(idx) for idx in [a1_idx, p_idx, a2_idx, n_idx]] diff --git a/miners/pair_margin_miner.py b/miners/pair_margin_miner.py new file mode 100644 index 00000000..8a823ed9 --- /dev/null +++ b/miners/pair_margin_miner.py @@ -0,0 +1,49 @@ +#! /usr/bin/env python3 + +from . 
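# A usage sketch (not part of the patch) pairing MultiSimilarityMiner with
# MultiSimilarityLoss: the miner returns (a1, p, a2, n) pair indices, which a pair-based
# loss consumes directly. The epsilon/alpha/beta values are illustrative, not defaults.
import torch
from losses import MultiSimilarityLoss      # import paths are assumptions
from miners import MultiSimilarityMiner

miner = MultiSimilarityMiner(epsilon=0.1)
loss_func = MultiSimilarityLoss(alpha=2, beta=50, base=0.5)

embeddings = torch.randn(64, 128)
labels = torch.arange(8).repeat(8)                    # 8 classes, 8 samples each
pair_indices = miner(embeddings, labels)              # (anchor1, positive, anchor2, negative)
loss = loss_func(embeddings, labels, pair_indices)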
import base_miner as b_m +from ..utils import loss_and_miner_utils as lmu +import torch + + +class PairMarginMiner(b_m.BasePostGradientMiner): + """ + Returns positive pairs that have distance greater than a margin and negative + pairs that have distance less than a margin + """ + + def __init__( + self, pos_margin, neg_margin, use_similarity, squared_mat=False, **kwargs + ): + super().__init__(**kwargs) + self.pos_margin = pos_margin + self.neg_margin = neg_margin + self.use_similarity = use_similarity + self.squared_mat = squared_mat + self.pos_pair_dist = 0 + self.neg_pair_dist = 0 + self.record_these += ["pos_pair_dist", "neg_pair_dist"] + + def mine(self, embeddings, labels): + mat = ( + lmu.sim_mat(embeddings) + if self.use_similarity + else lmu.dist_mat(embeddings, squared=self.squared_mat) + ) + a1, p, a2, n = lmu.get_all_pairs_indices(labels) + pos_pair = mat[a1, p] + neg_pair = mat[a2, n] + self.pos_pair_dist = torch.mean(pos_pair).item() + self.neg_pair_dist = torch.mean(neg_pair).item() + pos_mask_condition = self.pos_filter(pos_pair, self.pos_margin) + neg_mask_condition = self.neg_filter(neg_pair, self.neg_margin) + a1 = torch.masked_select(a1, pos_mask_condition) + p = torch.masked_select(p, pos_mask_condition) + a2 = torch.masked_select(a2, neg_mask_condition) + n = torch.masked_select(n, neg_mask_condition) + return a1, p, a2, n + + def pos_filter(self, pos_pair, margin): + return pos_pair < margin if self.use_similarity else pos_pair > margin + + def neg_filter(self, neg_pair, margin): + return neg_pair > margin if self.use_similarity else neg_pair < margin diff --git a/samplers/__init__.py b/samplers/__init__.py new file mode 100644 index 00000000..b34306af --- /dev/null +++ b/samplers/__init__.py @@ -0,0 +1,2 @@ +from .m_per_class_sampler import MPerClassSampler +from .fixed_set_of_triplets import FixedSetOfTriplets \ No newline at end of file diff --git a/samplers/fixed_set_of_triplets.py b/samplers/fixed_set_of_triplets.py new file mode 100644 index 00000000..970a135c --- /dev/null +++ b/samplers/fixed_set_of_triplets.py @@ -0,0 +1,44 @@ +from torch.utils.data.sampler import Sampler +from ..utils import common_functions as c_f +import numpy as np + +class FixedSetOfTriplets(Sampler): + """ + Upon initialization, this will create num_triplets triplets based on + the labels provided in labels_to_indices. This is for experimental purposes, + to see how algorithms perform when the only ground truth is a set of + triplets, rather than having explicit labels. + """ + + def __init__(self, labels_to_indices, num_triplets, hierarchy_level=0): + self.create_fixed_set_of_triplets( + labels_to_indices[hierarchy_level], num_triplets + ) + + def __len__(self): + return self.fixed_set_of_triplets.shape[0] * 3 + + def __iter__(self): + np.random.shuffle(self.fixed_set_of_triplets) + flattened = self.fixed_set_of_triplets.flatten().tolist() + return iter(flattened) + + def create_fixed_set_of_triplets(self, labels_to_indices, num_triplets): + """ + This creates self.fixed_set_of_triplets, which is a numpy array of size + (num_triplets, 3). Each row is a triplet of indices: (a, p, n), where + a=anchor, p=positive, and n=negative. Each triplet is created by first + randomly sampling two classes, then randomly sampling an anchor, positive, + and negative. 
+ """ + num_triplets = int(num_triplets) + assert num_triplets > 0 + self.fixed_set_of_triplets = np.ones((num_triplets, 3), dtype=np.int) * -1 + label_list = list(labels_to_indices.keys()) + for i in range(num_triplets): + anchor_label, negative_label = random.sample(label_list, 2) + anchor_list = labels_to_indices[anchor_label] + negative_list = labels_to_indices[negative_label] + anchor, positive = c_f.safe_random_choice(anchor_list, size=2) + negative = np.random.choice(negative_list, replace=False) + self.fixed_set_of_triplets[i, :] = np.array([anchor, positive, negative]) \ No newline at end of file diff --git a/samplers/m_per_class_sampler.py b/samplers/m_per_class_sampler.py new file mode 100644 index 00000000..2011ccd7 --- /dev/null +++ b/samplers/m_per_class_sampler.py @@ -0,0 +1,45 @@ +from torch.utils.data.sampler import Sampler +from ..utils import common_functions as c_f +import numpy as np + +# modified from +# https://raw.githubusercontent.com/bnulihaixia/Deep_metric/master/utils/sampler.py +class MPerClassSampler(Sampler): + """ + At every iteration, this will return m samples per class. For example, + if dataloader's batchsize is 100, and m = 5, then 20 classes with 5 samples + each will be returned + Args: + labels_to_indices: a dictionary mapping dataset labels to lists of + indices that have that label + m: the number of samples per class to fetch at every iteration. If a + class has less than m samples, then there will be duplicates + in the returned batch + hierarchy_level: which level of labels will be used to form each batch. + The default is 0, because most use-cases will have + 1 label per datapoint. But for example, iNat has 7 + labels per datapoint, in which case hierarchy_level could + be set to a number between 0 and 6. 
+ """ + + def __init__(self, labels_to_indices, m, hierarchy_level=0): + self.m_per_class = int(m) + self.labels_to_indices = labels_to_indices + self.set_hierarchy_level(hierarchy_level) + + def __len__(self): + return len(self.labels) * self.m_per_class + + def __iter__(self): + ret = [] + for _ in range(1000): + np.random.shuffle(self.labels) + for label in self.labels: + t = self.curr_labels_to_indices[label] + t = c_f.safe_random_choice(t, size=self.m_per_class) + ret.extend(t) + return iter(ret) + + def set_hierarchy_level(self, hierarchy_level): + self.curr_labels_to_indices = self.labels_to_indices[hierarchy_level] + self.labels = list(self.curr_labels_to_indices.keys()) \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..3c4ef166 --- /dev/null +++ b/setup.py @@ -0,0 +1,22 @@ +import setuptools + +with open("README.md", "r") as fh: + long_description = fh.read() + +setuptools.setup( + name="pytorch_metric_learning", + version="0.9.0", + author="Kevin Musgrave", + author_email="tkm45@cornell.com", + description="A flexible and extensible metric learning library, written in PyTorch.", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/TakeshiMusgrave/pytorch_metric_learning", + packages=setuptools.find_packages(), + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + ], + python_requires='>=3.7', +) \ No newline at end of file diff --git a/trainers/__init__.py b/trainers/__init__.py new file mode 100644 index 00000000..48fe5fe2 --- /dev/null +++ b/trainers/__init__.py @@ -0,0 +1,4 @@ +from .metric_loss_only import MetricLossOnly +from .train_with_classifier import TrainWithClassifier +from .cascaded_embeddings import CascadedEmbeddings +from .deep_adversarial_metric_learning import DeepAdversarialMetricLearning \ No newline at end of file diff --git a/trainers/base_trainer.py b/trainers/base_trainer.py new file mode 100644 index 00000000..ffe343f9 --- /dev/null +++ b/trainers/base_trainer.py @@ -0,0 +1,222 @@ +#! 
/usr/bin/env python3 + +import torch +from ..utils import common_functions as c_f, loss_tracker as l_t +import tqdm + +class BaseTrainer: + def __init__( + self, + models, + optimizers, + batch_size, + loss_funcs, + mining_funcs, + num_epochs, + iterations_per_epoch, + dataset, + data_device=None, + loss_weights=None, + label_mapper=None, + sampler=None, + collate_fn=None, + record_keeper=None, + lr_schedulers=None, + gradient_clippers=None, + freeze_trunk_batchnorm=False, + label_hierarchy_level=0, + dataloader_num_workers=32, + post_processor=None, + start_epoch=1, + ): + self.models = models + self.optimizers = optimizers + self.batch_size = batch_size + self.loss_funcs = loss_funcs + self.mining_funcs = mining_funcs + self.label_mapper = label_mapper + self.num_epochs = num_epochs + self.iterations_per_epoch = iterations_per_epoch + self.dataset = dataset + self.data_device = data_device + self.sampler = sampler + self.collate_fn = collate_fn + self.record_keeper = record_keeper + self.lr_schedulers = lr_schedulers + self.gradient_clippers = gradient_clippers + self.freeze_trunk_batchnorm = freeze_trunk_batchnorm + self.label_hierarchy_level = label_hierarchy_level + self.dataloader_num_workers = dataloader_num_workers + self.post_processor = post_processor + self.epoch = start_epoch + self.loss_weights = loss_weights + self.custom_setup() + self.initialize_data_device() + self.initialize_label_mapper() + self.initialize_post_processor() + self.initialize_loss_tracker() + self.initialize_dataloader() + self.initialize_loss_weights() + + def custom_setup(self): + pass + + def calculate_loss(self): + raise NotImplementedError + + def loss_names(self): + raise NotImplementedError + + def update_loss_weights(self): + pass + + def train(self): + self.set_to_train() + while self.epoch <= self.num_epochs: + print("TRAINING EPOCH %d" % self.epoch) + pbar = tqdm.tqdm(range(self.iterations_per_epoch)) + for self.iteration in pbar: + self.forward_and_backward() + pbar.set_description("total_loss=%.5f" % self.losses["total_loss"]) + self.step_lr_schedulers() + self.update_records(end_of_epoch=True) + self.epoch += 1 + + def initialize_dataloader(self): + self.dataloader = c_f.get_dataloader( + self.dataset, + self.batch_size, + self.sampler, + self.dataloader_num_workers, + self.collate_fn, + ) + self.dataloader_iter = iter(self.dataloader) + + def forward_and_backward(self): + self.zero_losses() + self.zero_grad() + self.update_loss_weights() + self.calculate_loss(self.get_batch()) + self.loss_tracker.update(self.loss_weights) + self.update_records() + self.backward() + self.clip_gradients() + self.step_optimizers() + + def zero_losses(self): + for k in self.losses.keys(): + self.losses[k] = 0 + + def zero_grad(self): + for v in self.models.values(): + v.zero_grad() + for v in self.optimizers.values(): + v.zero_grad() + + def get_batch(self): + self.dataloader_iter, curr_batch = c_f.try_next_on_generator( + self.dataloader_iter, self.dataloader) + curr_batch["label"] = c_f.process_label( + curr_batch["label"], self.label_hierarchy_level, self.label_mapper + ) + curr_batch = self.maybe_do_pre_gradient_mining(curr_batch) + return c_f.try_keys(curr_batch, ["data", "image"]), curr_batch["label"] + + def compute_embeddings(self, data, labels): + trunk_output = self.get_trunk_output(data) + embeddings = self.get_final_embeddings(trunk_output) + embeddings, labels = self.post_processor(embeddings, labels) + return embeddings, labels + + def get_final_embeddings(self, base_output): + return 
self.models["embedder"](base_output) + + def get_trunk_output(self, data): + return c_f.pass_data_to_model(self.models["trunk"], data, self.data_device) + + def maybe_mine_embeddings(self, embeddings, labels): + if "post_gradient_miner" in self.mining_funcs: + return self.mining_funcs["post_gradient_miner"](embeddings, labels) + return None + + def maybe_do_pre_gradient_mining(self, curr_batch): + if "pre_gradient_miner" in self.mining_funcs: + with torch.no_grad(): + self.set_to_eval() + data = c_f.try_keys(curr_batch, ["data", "image"]) + labels = curr_batch["label"] + embeddings, labels = self.compute_embeddings(data, labels) + idx = self.mining_funcs["pre_gradient_miner"](embeddings, labels) + self.set_to_train() + curr_batch = {"data": data[idx], "label": labels[idx]} + return curr_batch + + def backward(self): + if self.losses["total_loss"] > 0.0: + self.losses["total_loss"].backward() + + def get_global_iteration(self): + return self.iteration + self.iterations_per_epoch * (self.epoch - 1) + + def step_lr_schedulers(self): + if self.lr_schedulers is not None: + for v in self.lr_schedulers.values(): + v.step() + + def step_optimizers(self): + for v in self.optimizers.values(): + v.step() + + def clip_gradients(self): + if self.gradient_clippers is not None: + for v in self.gradient_clippers.values(): + v() + + def maybe_freeze_trunk_batchnorm(self): + if self.freeze_trunk_batchnorm: + self.models["trunk"].apply(c_f.set_layers_to_eval("BatchNorm")) + + def initialize_post_processor(self): + if self.post_processor is None: + self.post_processor = lambda embeddings, labels: (embeddings, labels) + + def initialize_data_device(self): + if self.data_device is None: + self.data_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + def initialize_label_mapper(self): + if self.label_mapper is None: + self.label_mapper = lambda x: x + + def initialize_loss_tracker(self): + self.loss_tracker = l_t.LossTracker(self.loss_names()) + self.losses = self.loss_tracker.losses + + def set_to_train(self): + for k, v in self.models.items(): + self.models[k] = v.train() + self.maybe_freeze_trunk_batchnorm() + + def set_to_eval(self): + for k, v in self.models.items(): + self.models[k] = v.eval() + + def initialize_loss_weights(self): + if self.loss_weights is None: + self.loss_weights = {k: 1 for k in self.loss_names()} + + def update_records(self, end_of_epoch=False): + if self.record_keeper is not None: + if end_of_epoch: + self.record_keeper.maybe_add_custom_figures_to_tensorboard(self.get_global_iteration()) + else: + for record, kwargs in self.record_these(): + self.record_keeper.update_records(record, self.get_global_iteration(), **kwargs) + + def record_these(self): + return [[self.loss_tracker.losses, {"input_group_name_for_non_objects": "loss_histories"}], + [self.loss_tracker.loss_weights, {"input_group_name_for_non_objects": "loss_weights"}], + [self.loss_funcs, {}], + [self.mining_funcs, {}], + [self.models, {}], + [self.optimizers, {"custom_attr_func": lambda x: {"lr": x.param_groups[0]["lr"]}}]] diff --git a/trainers/cascaded_embeddings.py b/trainers/cascaded_embeddings.py new file mode 100644 index 00000000..4e5688d4 --- /dev/null +++ b/trainers/cascaded_embeddings.py @@ -0,0 +1,29 @@ +#! /usr/bin/env python3 + + +from . 
import train_with_classifier as twc + + +class CascadedEmbeddings(twc.TrainWithClassifier): + def __init__(self, embedding_sizes, logit_sizes=None, **kwargs): + super(CascadedEmbeddings, self).__init__(**kwargs) + self.embedding_sizes = embedding_sizes + self.logit_sizes = logit_sizes + + def calculate_loss(self, curr_batch): + data, labels = curr_batch + embeddings, labels = self.compute_embeddings(data, labels) + s = 0 + for curr_size in self.embedding_sizes: + e = embeddings[:, s : s + curr_size] + indices_tuple = self.maybe_mine_embeddings(e, labels) + self.losses["metric_loss"] += self.maybe_get_metric_loss(e, labels, indices_tuple) + s += curr_size + + logits = self.maybe_get_logits(embeddings) + if logits is not None: + s = 0 + for curr_size in self.logit_sizes: + L = logits[:, s : s + curr_size] + self.losses["classifier_loss"] += self.maybe_get_classifier_loss(L, labels) + s += curr_size \ No newline at end of file diff --git a/trainers/deep_adversarial_metric_learning.py b/trainers/deep_adversarial_metric_learning.py new file mode 100644 index 00000000..d694225b --- /dev/null +++ b/trainers/deep_adversarial_metric_learning.py @@ -0,0 +1,141 @@ +#! /usr/bin/env python3 + +from .. import miners +import torch +from ..utils import common_functions as c_f, loss_and_miner_utils as lmu + +from . import train_with_classifier as twc +import copy + +class DeepAdversarialMetricLearning(twc.TrainWithClassifier): + def __init__( + self, + metric_alone_epochs=0, + g_alone_epochs=0, + **kwargs + ): + super().__init__(**kwargs) + self.original_loss_weights = copy.deepcopy(self.loss_weights) + self.metric_alone_epochs = metric_alone_epochs + self.g_alone_epochs = g_alone_epochs + self.loss_funcs["G_neg_adv"].maybe_modify_loss = lambda x: x * -1 + + def custom_setup(self): + synth_packaged_as_triplets = miners.EmbeddingsAlreadyPackagedAsTriplets( + normalize_embeddings=False) + self.mining_funcs["synth_packaged_as_triplets"] = synth_packaged_as_triplets + + def calculate_loss(self, curr_batch): + data, labels = curr_batch + penultimate_embeddings = self.get_trunk_output(data) + + if self.do_metric: + authentic_final_embeddings = self.get_final_embeddings(penultimate_embeddings) + authentic_final_embeddings, labels = self.post_processor(authentic_final_embeddings, labels) + indices_tuple = self.maybe_mine_embeddings(authentic_final_embeddings, labels) + self.losses["metric_loss"] = self.loss_funcs["metric_loss"]( + authentic_final_embeddings, labels, indices_tuple + ) + logits = self.maybe_get_logits(authentic_final_embeddings) + self.losses["classifier_loss"] = self.maybe_get_classifier_loss(logits, labels) + + if self.do_adv: + self.calculate_synth_loss(penultimate_embeddings, labels) + + def loss_names(self): + return ["metric_loss", "classifier_loss", "synth_loss", "G_neg_hard", "G_neg_reg", "G_neg_adv"] + + def update_loss_weights(self): + self.do_metric_alone = self.epoch <= self.metric_alone_epochs + self.do_adv_alone = self.metric_alone_epochs < self.epoch <= self.g_alone_epochs + self.do_both = not self.do_adv_alone and not self.do_metric_alone + self.do_adv = self.do_adv_alone or self.do_both + self.do_metric = self.do_metric_alone or self.do_both + + non_zero_weight_list = [] + if self.do_adv: + non_zero_weight_list += ["G_neg_hard", "G_neg_reg", "G_neg_adv"] + if self.do_metric: + non_zero_weight_list += ["metric_loss", "classifier_loss"] + if self.do_both: + non_zero_weight_list += ["synth_loss"] + + for k in self.loss_weights.keys(): + if k in non_zero_weight_list: + 
self.loss_weights[k] = self.original_loss_weights[k] + else: + self.loss_weights[k] = 0 + + self.maybe_exclude_networks_from_gradient() + + def maybe_exclude_networks_from_gradient(self): + self.set_to_train() + self.maybe_freeze_trunk_batchnorm() + if self.do_adv_alone: + no_grad_list = ["trunk", "classifier"] + elif self.do_metric_alone: + no_grad_list = ["G_neg_model"] + else: + no_grad_list = [] + for k in self.models.keys(): + if k in no_grad_list: + c_f.set_requires_grad(self.models[k], requires_grad=False) + self.models[k].eval() + else: + c_f.set_requires_grad(self.models[k], requires_grad=True) + + + def step_optimizers(self): + step_list = [] + if self.do_metric: + step_list += ["trunk_optimizer", "embedder_optimizer", "classifier_optimizer"] + if self.do_adv: + step_list += ["G_neg_model_optimizer"] + for k in self.optimizers.keys(): + if k in step_list: + self.optimizers[k].step() + + def calculate_synth_loss(self, penultimate_embeddings, labels): + a_indices, p_indices, n_indices = lmu.get_random_triplet_indices(labels, t_per_anchor=10) + real_anchors = penultimate_embeddings[a_indices] + real_positives = penultimate_embeddings[p_indices] + real_negatives = penultimate_embeddings[n_indices] + penultimate_embeddings_cat = torch.cat([real_anchors, real_positives, real_negatives], dim=1) + synthetic_negatives = c_f.pass_data_to_model( + self.models["G_neg_model"], penultimate_embeddings_cat, self.data_device + ) + penultimate_embeddings_with_negative_synth = c_f.unslice_by_n( + [real_anchors, real_positives, synthetic_negatives] + ) + final_embeddings = self.get_final_embeddings(penultimate_embeddings_with_negative_synth) + + labels = torch.tensor( + [ + val + for tup in zip( + *[labels[a_indices], labels[p_indices], labels[n_indices]] + ) + for val in tup + ] + ) + + final_embeddings, labels = self.post_processor(final_embeddings, labels) + + indices_tuple = self.mining_funcs["synth_packaged_as_triplets"](final_embeddings, labels) + + if self.do_both: + self.losses["synth_loss"] = self.loss_funcs["synth_loss"]( + final_embeddings, labels, indices_tuple + ) + + self.losses["G_neg_adv"] = self.loss_funcs["G_neg_adv"]( + final_embeddings, labels, indices_tuple + ) + self.losses["G_neg_hard"] = torch.nn.functional.mse_loss( + torch.nn.functional.normalize(synthetic_negatives, p=2, dim=1), + torch.nn.functional.normalize(real_anchors, p=2, dim=1), + ) + self.losses["G_neg_reg"] = torch.nn.functional.mse_loss( + torch.nn.functional.normalize(synthetic_negatives, p=2, dim=1), + torch.nn.functional.normalize(real_negatives, p=2, dim=1), + ) diff --git a/trainers/metric_loss_only.py b/trainers/metric_loss_only.py new file mode 100644 index 00000000..6ebe19c0 --- /dev/null +++ b/trainers/metric_loss_only.py @@ -0,0 +1,21 @@ +#! /usr/bin/env python3 + + +from . 
import base_trainer as b_t + + +class MetricLossOnly(b_t.BaseTrainer): + def loss_names(self): + return ["metric_loss"] + + def calculate_loss(self, curr_batch): + data, labels = curr_batch + embeddings, labels = self.compute_embeddings(data, labels) + indices_tuple = self.maybe_mine_embeddings(embeddings, labels) + self.losses["metric_loss"] = self.maybe_get_metric_loss(embeddings, labels, indices_tuple) + + def maybe_get_metric_loss(self, embeddings, labels, indices_tuple): + if self.loss_weights.get("metric_loss", 0) > 0: + return self.loss_funcs["metric_loss"](embeddings, labels, indices_tuple) + return 0 + diff --git a/trainers/train_with_classifier.py b/trainers/train_with_classifier.py new file mode 100644 index 00000000..fb9ae1ad --- /dev/null +++ b/trainers/train_with_classifier.py @@ -0,0 +1,33 @@ +#! /usr/bin/env python3 + +from utils import common_functions as c_f + +from . import metric_loss_only as mlo + + +class TrainWithClassifier(mlo.MetricLossOnly): + def loss_names(self): + return ["metric_loss", "classifier_loss"] + + def calculate_loss(self, curr_batch): + data, labels = curr_batch + embeddings, labels = self.compute_embeddings(data, labels) + logits = self.maybe_get_logits(embeddings) + indices_tuple = self.maybe_mine_embeddings(embeddings, labels) + self.losses["metric_loss"] = self.maybe_get_metric_loss(embeddings, labels, indices_tuple) + self.losses["classifier_loss"] = self.maybe_get_classifier_loss(logits, labels) + + def maybe_get_classifier_loss(self, logits, labels): + if logits is not None: + return self.loss_funcs["classifier_loss"](logits, labels.to(logits.device)) + return 0 + + def maybe_get_logits(self, embeddings): + if self.loss_weights.get("classifier_loss",0) > 0: + return self.models["classifier"](embeddings) + return None + + + + + diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/utils/common_functions.py b/utils/common_functions.py new file mode 100644 index 00000000..da48e1fa --- /dev/null +++ b/utils/common_functions.py @@ -0,0 +1,221 @@ +import collections +import torch +from torch.autograd import Variable +import numpy as np +import pickle +import csv + +def save_pkl(obj, filename, protocol=None): + # https://stackoverflow.com/a/19201448 + if protocol is None: + protocol = pickle.HIGHEST_PROTOCOL + with open(filename, "wb") as f: + pickle.dump(obj, f, protocol) + + +def load_pkl(filename): + with open(filename, "rb") as f: + return pickle.load(f) + + +def write_dict_of_lists_to_csv(obj, filename): + # https://stackoverflow.com/a/23613603 + with open(filename, "w") as outfile: + writer = csv.writer(outfile) + writer.writerow(obj.keys()) + writer.writerows(zip(*obj.values())) + + +def convert_to_scalar(v): + try: + return v.detach().item() # pytorch + except: + try: + return v[0] # list or numpy + except: + return v # already a scalar + + +def convert_to_list(v): + try: + return v.detach().tolist() # pytorch + except: + try: + return list(v) # list or numpy + except: + return [v] # already a scalar + + +def try_get_len(v): + try: + return len(v) # most things + except: + try: + return v.nelement() # 0-d tensor + except: + return 0 # not a list + +def is_list_and_has_more_than_one_element(input_val): + return isinstance(input_val, collections.Sized) and try_get_len(input_val) > 1 + + +def try_keys(input_dict, keys): + for k in keys: + try: + return input_dict[k] + except BaseException: + pass + return None + + +def try_next_on_generator(gen, iterable): + try: + return gen, next(gen) + 
except StopIteration: + gen = iter(iterable) + return gen, next(gen) + + +def apply_func_to_dict(input, f): + if isinstance(input, collections.Mapping): + for k, v in input.items(): + input[k] = f(v) + return input + else: + return f(input) + + +def wrap_variable(batch_data, device): + def f(x): + return Variable(x).to(device) + + return apply_func_to_dict(batch_data, f) + + +def get_hierarchy_label(batch_labels, hierarchy_level): + def f(v): + try: + if v.ndim == 2: + v = v[:, hierarchy_level] + return v + except BaseException: + return v + + return apply_func_to_dict(batch_labels, f) + + +def numpy_to_torch(input): + def f(v): + try: + return torch.from_numpy(v) + except BaseException: + return v + + return apply_func_to_dict(input, f) + + +def torch_to_numpy(input): + def f(v): + try: + return v.cpu().numpy() + except BaseException: + return v + + return apply_func_to_dict(input, f) + + +def process_label(labels, hierarchy_level, label_map): + labels = get_hierarchy_label(labels, hierarchy_level) + labels = torch_to_numpy(labels) + labels = label_map(labels, hierarchy_level) + labels = numpy_to_torch(labels) + return labels + + +def pass_data_to_model(model, data, device, **kwargs): + if isinstance(data, collections.Mapping): + base_output = {} + for k, v in data.items(): + base_output[k] = model(wrap_variable(v, device), k=k, **kwargs) + return base_output + else: + return model(wrap_variable(data, device), **kwargs) + +def set_requires_grad(model, requires_grad): + for param in model.parameters(): + param.requires_grad = requires_grad + + +def try_getting_dataparallel_module(input_obj): + try: + return input_obj.module + except BaseException: + return input_obj + + +def copy_params_to_another_model(from_model, to_model): + params1 = from_model.named_parameters() + params2 = to_model.named_parameters() + dict_params2 = dict(params2) + for name1, param1 in params1: + if name1 in dict_params2: + dict_params2[name1].data.copy_(param1.data) + + +def safe_random_choice(input_data, size): + """ + Randomly samples without replacement from a sequence. 
It is "safe" because + if len(input_data) < size, it will randomly sample WITH replacement + Args: + input_data is a sequence, like a torch tensor, numpy array, + python list, tuple etc + size is the number of elements to randomly sample from input_data + Returns: + An array of size "size", randomly sampled from input_data + """ + replace = len(input_data) < size + return np.random.choice(input_data, size=size, replace=replace) + + +def longest_list(list_of_lists): + return max(list_of_lists, key=len) + + +def slice_by_n(input_array, n): + output = [] + for i in range(n): + output.append(input_array[i::n]) + return output + + +def unslice_by_n(input_tensors): + n = len(input_tensors) + rows, cols = input_tensors[0].size() + output = torch.zeros((rows * n, cols)).to(input_tensors[0].device) + for i in range(n): + output[i::n] = input_tensors[i] + return output + + +def set_layers_to_eval(layer_name): + def set_to_eval(m): + classname = m.__class__.__name__ + if classname.find(layer_name) != -1: + m.eval() + return set_to_eval + + +def get_dataloader(dataset, batch_size, sampler, num_workers, collate_fn): + return torch.utils.data.DataLoader( + dataset, + batch_size=int(batch_size), + sampler=sampler, + drop_last=True, + num_workers=num_workers, + collate_fn=collate_fn, + shuffle=sampler is None + ) + + +def try_torch_operation(torch_op, input_val): + return torch_op(input_val) if torch.is_tensor(input_val) else input_val \ No newline at end of file diff --git a/utils/loss_and_miner_utils.py b/utils/loss_and_miner_utils.py new file mode 100644 index 00000000..56a910d7 --- /dev/null +++ b/utils/loss_and_miner_utils.py @@ -0,0 +1,159 @@ +import torch +import numpy as np +import math +from . import common_functions as c_f + + +def sim_mat(x): + """ + returns a matrix where entry (i,j) is the dot product of x[i] and x[j] + """ + return torch.matmul(x, x.t()) + + +# https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/7 +def dist_mat(x, y=None, eps=1e-16, squared=False): + """ + Input: x is a Nxd matrix + y is an optional Mxd matirx + Output: dist is a NxM matrix where dist[i,j] + is the square norm between x[i,:] and y[j,:] + if y is not given then use 'y=x'. + i.e. dist[i,j] = ||x[i,:]-y[j,:]|| + """ + x_norm = (x ** 2).sum(1).view(-1, 1) + if y is not None: + y_t = torch.transpose(y, 0, 1) + y_norm = (y ** 2).sum(1).view(1, -1) + else: + y_t = torch.transpose(x, 0, 1) + y_norm = x_norm.view(1, -1) + + dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t) + # Ensure diagonal is zero if x=y + if y is None: + dist = dist - torch.diag(dist.diag()) + dist = torch.clamp(dist, 0.0, np.inf) + if not squared: + mask = (dist == 0).float() + dist = dist + mask * eps + dist = torch.sqrt(dist) + dist = dist * (1.0 - mask) + return dist + + +def get_all_pairs_indices(labels): + """ + Given a tensor of labels, this will return 4 tensors. 
+ The first 2 tensors are the indices which form all positive pairs + The second 2 tensors are the indices which form all negative pairs + """ + labels1 = labels.unsqueeze(1).expand(labels.size(0), labels.size(0)) + labels2 = labels.unsqueeze(0).expand(labels.size(0), labels.size(0)) + matches = (labels1 == labels2).byte() + diffs = 1 - matches + matches -= torch.eye(matches.size(0)).byte().to(labels.device) + a1_idx = matches.nonzero()[:, 0].flatten() + p_idx = matches.nonzero()[:, 1].flatten() + a2_idx = diffs.nonzero()[:, 0].flatten() + n_idx = diffs.nonzero()[:, 1].flatten() + return a1_idx, p_idx, a2_idx, n_idx + + +def convert_to_pairs(indices_tuple, labels): + """ + This returns anchor-positive and anchor-negative indices, + regardless of what the input indices_tuple is + Args: + indices_tuple: tuple of tensors. Each tensor is 1d and specifies indices + within a batch + labels: a tensor which has the label for each element in a batch + """ + if indices_tuple is None: + return get_all_pairs_indices(labels) + elif len(indices_tuple) == 4: + return indices_tuple + else: + a, p, n = indices_tuple + return a, p, a, n + + +def convert_to_pos_pairs_with_unique_labels(indices_tuple, labels): + a, p, _, _ = convert_to_pairs(indices_tuple, labels) + _, unique_idx = np.unique(labels[a].cpu().numpy(), return_index=True) + return a[unique_idx], p[unique_idx] + + +# sample triplets, with a weighted distribution if weights is specified. +def get_random_triplet_indices(labels, t_per_anchor=None, weights=None): + a_idx, p_idx, n_idx = [], [], [] + batch_size = labels.size(0) + labels = labels.cpu().numpy() + label_count = dict(zip(*np.unique(labels, return_counts=True))) + indices = np.arange(batch_size) + for i, label in enumerate(labels): + curr_label_count = label_count[label] + if curr_label_count == 1: + continue + k = curr_label_count - 1 if t_per_anchor is None else t_per_anchor + + if weights is not None and not np.any(np.isnan(weights[i])): + n_idx += np.random.choice(batch_size, k, p=weights[i]).tolist() + else: + possible_n_idx = list(np.where(labels != label)[0]) + n_idx += np.random.choice(possible_n_idx, k).tolist() + + a_idx.extend([i] * k) + curr_p_idx = c_f.safe_random_choice(np.where((labels == label) & (indices != i))[0], k) + p_idx.extend(curr_p_idx.tolist()) + + return ( + torch.LongTensor(a_idx), + torch.LongTensor(p_idx), + torch.LongTensor(n_idx), + ) + + +def repeat_to_match_size(smaller_set, larger_size, smaller_size): + num_repeat = math.ceil(float(larger_size) / float(smaller_size)) + return smaller_set.repeat(num_repeat)[:larger_size] + + +def matched_size_indices(curr_p_idx, curr_n_idx): + num_pos_pairs = len(curr_p_idx) + num_neg_pairs = len(curr_n_idx) + if num_pos_pairs > num_neg_pairs: + n_idx = repeat_to_match_size(curr_n_idx, num_pos_pairs, num_neg_pairs) + p_idx = curr_p_idx + else: + p_idx = repeat_to_match_size(curr_p_idx, num_neg_pairs, num_pos_pairs) + n_idx = curr_n_idx + return p_idx, n_idx + + +def convert_to_triplets(indices_tuple, labels): + """ + This returns anchor-positive-negative triplets + regardless of what the input indices_tuple is + """ + if indices_tuple is None: + return get_random_triplet_indices(labels, t_per_anchor=10) + elif len(indices_tuple) == 3: + return indices_tuple + else: + a_out, p_out, n_out = [], [], [] + a1, p, a2, n = indices_tuple + if len(a1) == 0 or len(a2) == 0: + return [torch.tensor([]).to(labels.device)] * 3 + for i in range(len(labels)): + pos_idx = (a1 == i).nonzero().flatten() + neg_idx = (a2 == 
i).nonzero().flatten() + if len(pos_idx) > 0 and len(neg_idx) > 0: + p_idx = p[pos_idx] + n_idx = n[neg_idx] + p_idx, n_idx = matched_size_indices(p_idx, n_idx) + a_idx = torch.ones_like(c_f.longest_list([p_idx, n_idx])) * i + a_out.append(a_idx) + p_out.append(p_idx) + n_out.append(n_idx) + return [torch.cat(x, dim=0) for x in [a_out, p_out, n_out]] diff --git a/utils/loss_tracker.py b/utils/loss_tracker.py new file mode 100644 index 00000000..95aa6376 --- /dev/null +++ b/utils/loss_tracker.py @@ -0,0 +1,33 @@ +#! /usr/bin/env python3 + + +class LossTracker: + def __init__(self, loss_names): + if "total_loss" not in loss_names: + loss_names.append("total_loss") + self.losses = {key: 0 for key in loss_names} + self.loss_weights = {key: 1 for key in loss_names} + + def weight_the_losses(self, exclude_loss=("total_loss")): + for k, _ in self.losses.items(): + if k not in exclude_loss: + self.losses[k] *= self.loss_weights[k] + + def get_total_loss(self, exclude_loss=("total_loss")): + self.losses["total_loss"] = 0 + for k, v in self.losses.items(): + if k not in exclude_loss: + self.losses["total_loss"] += v + + def set_loss_weights(self, loss_weight_dict): + for k, _ in self.losses.items(): + if k in loss_weight_dict: + w = loss_weight_dict[k] + else: + w = 1.0 + self.loss_weights[k] = w + + def update(self, loss_weight_dict): + self.set_loss_weights(loss_weight_dict) + self.weight_the_losses() + self.get_total_loss() diff --git a/utils/misc_models.py b/utils/misc_models.py new file mode 100644 index 00000000..c0e3df91 --- /dev/null +++ b/utils/misc_models.py @@ -0,0 +1,23 @@ +import torch.nn as nn +import torch + + +class ListOfModels(nn.Module): + def __init__(self, list_of_models, input_sizes=None): + super().__init__() + self.list_of_models = nn.ModuleList(list_of_models) + self.input_sizes = input_sizes + + def forward(self, x): + outputs = [] + if self.input_sizes is None: + for m in self.list_of_models: + outputs.append(m(x)) + return torch.cat(outputs, dim=-1) + else: + s = 0 + for i, y in enumerate(self.input_sizes): + curr_input = x[:, s : s + y] + outputs.append(self.list_of_models[i](curr_input)) + s += y + return torch.cat(outputs, dim=-1) diff --git a/utils/record_keeper.py b/utils/record_keeper.py new file mode 100644 index 00000000..5e978438 --- /dev/null +++ b/utils/record_keeper.py @@ -0,0 +1,105 @@ +#! /usr/bin/env python3 + +from . 
import common_functions as c_f +import collections +import matplotlib.pyplot as plt +import numpy as np + + +class RecordKeeper: + def __init__(self, tensorboard_writer=None, pickler_and_csver=None): + self.tensorboard_writer = tensorboard_writer + self.pickler_and_csver = pickler_and_csver + + def append_data(self, group_name, series_name, value, iteration): + if self.tensorboard_writer is not None: + tag_name = '%s/%s' % (group_name, series_name) + if not c_f.is_list_and_has_more_than_one_element(value): + self.tensorboard_writer.add_scalar(tag_name, value, iteration) + if self.pickler_and_csver is not None: + self.pickler_and_csver.append(group_name, series_name, value) + + def update_records(self, record_these, global_iteration, custom_attr_func=None, input_group_name_for_non_objects=None): + for name_in_dict, input_obj in record_these.items(): + + if input_group_name_for_non_objects is not None: + group_name = input_group_name_for_non_objects + self.append_data(group_name, name_in_dict, input_obj, global_iteration) + else: + the_obj = c_f.try_getting_dataparallel_module(input_obj) + attr_list = self.get_attr_list_for_record_keeper(the_obj) + for k in attr_list: + v = getattr(the_obj, k) + name = self.get_record_name(name_in_dict, the_obj) + self.append_data(name, k, v, global_iteration) + if custom_attr_func is not None: + for k, v in custom_attr_func(the_obj).items(): + name = self.get_record_name(name_in_dict, the_obj) + self.append_data(name, k, v, global_iteration) + + + def get_attr_list_for_record_keeper(self, input_obj): + attr_list = [] + obj_attr_list_names = ["record_these", "learnable_param_names"] + for k in obj_attr_list_names: + if (hasattr(input_obj, k)) and (getattr(input_obj, k) is not None): + attr_list += getattr(input_obj, k) + return attr_list + + def get_record_name(self, name_in_dict, input_obj, key_name=None): + record_name = "%s_%s" % (name_in_dict, type(input_obj).__name__) + if key_name: + record_name += '_%s' % key_name + return record_name + + def maybe_add_custom_figures_to_tensorboard(self, global_iteration): + if self.pickler_and_csver is not None: + for group_name, dict_of_lists in self.pickler_and_csver.records.items(): + for series_name, v in dict_of_lists.data.items(): + if isinstance(v[0], list): + tag_name = '%s/%s' % (group_name, series_name) + figure = self.multi_line_plot(v) + self.tensorboard_writer.add_figure(tag_name, figure, global_iteration) + + def multi_line_plot(self, list_of_lists): + # Each sublist represents a snapshot at an iteration. + # Transpose so that each row covers many iterations. 
+ numpified = np.transpose(np.array(list_of_lists)) + fig = plt.figure() + for sublist in numpified: + plt.plot(np.arange(numpified.shape[1]), sublist) + return fig + + +class DictOfLists: + def __init__(self): + self.data = collections.defaultdict(list) + + def append(self, series_name, input_val): + if c_f.is_list_and_has_more_than_one_element(input_val): + self.data[series_name].append(c_f.convert_to_list(input_val)) + else: + self.data[series_name].append(c_f.convert_to_scalar(input_val)) + + +class PicklerAndCSVer: + def __init__(self, folder): + self.records = collections.defaultdict(DictOfLists) + self.folder = folder + + def append(self, group_name, series_name, input_val): + self.records[group_name].append(series_name, input_val) + + def save_records(self): + for k, v in self.records.items(): + base_filename = "%s/%s" % (self.folder, k) + c_f.save_pkl(v.data, base_filename+".pkl") + c_f.write_dict_of_lists_to_csv(v.data, base_filename+".csv") + + def load_records(self, num_records_to_load=None): + for k, _ in self.records.items(): + filename = "%s/%s.pkl"%(self.folder,k) + self.records[k].data = c_f.load_pkl(filename) + if num_records_to_load is not None: + for zzz, _ in self.records[k].data.items(): + self.records[k].data[zzz] = self.records[k].data[zzz][:num_records_to_load]
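
The following usage sketches are not part of the patch; they only illustrate the helpers added above. This first one exercises the pair/triplet index helpers in utils/loss_and_miner_utils.py and the LossTracker in utils/loss_tracker.py. It assumes the repository root is on sys.path (so that "utils" resolves as a package) and that torch and numpy are installed; the labels and loss values are purely illustrative.

import torch
from utils import loss_and_miner_utils as lmu
from utils.loss_tracker import LossTracker

labels = torch.tensor([0, 0, 1, 1, 2, 2])

# All anchor-positive and anchor-negative pair indices for this batch of labels.
a1, p, a2, n = lmu.get_all_pairs_indices(labels)

# Randomly sampled (anchor, positive, negative) triplet indices, 5 per anchor.
anc, pos, neg = lmu.get_random_triplet_indices(labels, t_per_anchor=5)

# Combine named losses into "total_loss", as the trainers' LossTracker does.
tracker = LossTracker(["metric_loss", "classifier_loss"])
tracker.losses["metric_loss"] = torch.tensor(0.7)
tracker.losses["classifier_loss"] = torch.tensor(1.2)
tracker.update({"classifier_loss": 0.5})   # down-weight the classifier term
print(tracker.losses["total_loss"])        # 0.7 * 1.0 + 1.2 * 0.5 = 1.3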
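
The adversarial trainer above rebuilds interleaved (anchor, positive, synthetic-negative) batches with c_f.unslice_by_n. This round-trip sketch (again not part of the patch, using an illustrative random batch) shows the interleaving convention shared by slice_by_n and unslice_by_n.

import torch
from utils import common_functions as c_f

batch = torch.randn(12, 4)   # 12 embeddings of dimension 4, packed a, p, n, a, p, n, ...
anchors, positives, negatives = c_f.slice_by_n(batch, 3)
rebuilt = c_f.unslice_by_n([anchors, positives, negatives])
assert torch.equal(rebuilt, batch)   # rows are re-interleaved in their original order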
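
Finally, a sketch of the record-keeping utilities using only the pickle/CSV backend (tensorboard_writer=None). It is not part of the patch and assumes matplotlib is installed (utils/record_keeper.py imports it at module level), a Python version where the collections.Sized/collections.Mapping aliases used in common_functions.py still resolve (true at the time of this patch), and an arbitrary output folder name.

import os
from utils.record_keeper import RecordKeeper, PicklerAndCSVer

os.makedirs("example_records", exist_ok=True)
record_keeper = RecordKeeper(tensorboard_writer=None,
                             pickler_and_csver=PicklerAndCSVer("example_records"))

# With input_group_name_for_non_objects set, plain values are logged directly under
# that group name, which is how the trainer records its loss histories.
for iteration, loss_val in enumerate([0.9, 0.7, 0.5]):
    record_keeper.update_records({"metric_loss": loss_val}, iteration,
                                 input_group_name_for_non_objects="loss_histories")

# Writes example_records/loss_histories.pkl and example_records/loss_histories.csv.
record_keeper.pickler_and_csver.save_records()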