"""
-------------------------------------------------
   File Name:   triplet.py
   Author:      Zhonghao Huang
   Date:        2019/9/10
   Description: Triplet loss with batch-hard example mining.
-------------------------------------------------
"""

import torch
import torch.nn as nn


def topk_mask(input, dim, K=10, **kwargs):
    """Return a 0/1 mask marking the top-K entries of `input` along `dim`."""
    index = input.topk(max(1, min(K, input.size(dim))), dim=dim, **kwargs)[1]
    return torch.zeros_like(input).scatter(dim, index, 1.0)


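# Usage sketch (illustrative, not part of the original file): each row of the
# returned mask contains exactly K ones at the positions of its top-K scores.
def _demo_topk_mask():
    scores = torch.randn(3, 6)
    mask = topk_mask(scores, dim=1, K=2)
    print(mask)

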
def pdist(A, squared=False, eps=1e-4):
    """Pairwise (squared) Euclidean distances between the rows of `A`."""
    prod = torch.mm(A, A.t())
    norm = prod.diag().unsqueeze(1).expand_as(prod)
    res = (norm + norm.t() - 2 * prod).clamp(min=0)
    return res if squared else res.clamp(min=eps).sqrt()


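# Usage sketch (illustrative, not part of the original file): the result is a
# symmetric [N, N] matrix; its diagonal is ~sqrt(eps) rather than exactly 0
# because of the numerical-stability clamp.
def _demo_pdist():
    A = torch.randn(5, 8)
    D = pdist(A)
    print(D.shape, torch.allclose(D, D.t()))

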
def normalize(x, axis=-1):
    """Normalize to unit length along the specified dimension.
    Args:
        x: pytorch Tensor
    Returns:
        x: pytorch Tensor, same shape as input
    """
    x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
    return x


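# Usage sketch (illustrative, not part of the original file): after the call,
# every row has (approximately) unit L2 norm.
def _demo_normalize():
    x = torch.randn(4, 16)
    x_n = normalize(x, axis=-1)
    print(x_n.norm(p=2, dim=-1))

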
def euclidean_dist(x, y):
    """
    Args:
        x: pytorch Tensor, with shape [m, d]
        y: pytorch Tensor, with shape [n, d]
    Returns:
        dist: pytorch Tensor, with shape [m, n]
    """
    m, n = x.size(0), y.size(0)
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
    dist = xx + yy
    dist.addmm_(x, y.t(), beta=1, alpha=-2)
    dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
    return dist


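# Usage sketch (illustrative, not part of the original file): cross-checks
# `euclidean_dist` against torch.cdist on random data; the clamp(min=1e-12)
# only matters for near-zero distances, so a loose tolerance is used.
def _demo_euclidean_dist():
    x, y = torch.randn(4, 8), torch.randn(6, 8)
    print(torch.allclose(euclidean_dist(x, y), torch.cdist(x, y), atol=1e-4))

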
def hard_example_mining(dist_mat, labels, margin, return_inds=False):
    """For each anchor, find the hardest positive and negative sample.
    Args:
        dist_mat: pytorch Tensor, pairwise distance between samples, shape [N, N]
        labels: pytorch LongTensor, with shape [N]
        margin: unused by the mining itself; kept for interface compatibility
        return_inds: whether to also return the mined indices. Slightly cheaper if `False`.
    Returns:
        dist_ap: pytorch Tensor, distance(anchor, positive); shape [N]
        dist_an: pytorch Tensor, distance(anchor, negative); shape [N]
        p_inds: pytorch LongTensor, with shape [N];
            indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
        n_inds: pytorch LongTensor, with shape [N];
            indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
    NOTE: Only the case in which all labels have the same number of samples is
        considered, so all anchors can be processed in parallel.
    """
    assert len(dist_mat.size()) == 2
    assert dist_mat.size(0) == dist_mat.size(1)
    N = dist_mat.size(0)

    # shape [N, N]
    is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
    is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())
    # `dist_ap` means distance(anchor, positive)
    # both `dist_ap` and `relative_p_inds` with shape [N, 1]
    dist_ap, relative_p_inds = torch.max(
        dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
    # `dist_an` means distance(anchor, negative)
    # both `dist_an` and `relative_n_inds` with shape [N, 1]
    dist_an, relative_n_inds = torch.min(
        dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
    # shape [N]
    dist_ap = dist_ap.squeeze(1)
    dist_an = dist_an.squeeze(1)

    if return_inds:
        # shape [N, N]
        ind = (labels.new().resize_as_(labels)
               .copy_(torch.arange(0, N).long())
               .unsqueeze(0).expand(N, N))
        # shape [N, 1]
        p_inds = torch.gather(
            ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data)
        n_inds = torch.gather(
            ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data)
        # shape [N]
        p_inds = p_inds.squeeze(1)
        n_inds = n_inds.squeeze(1)
        return dist_ap, dist_an, p_inds, n_inds

    return dist_ap, dist_an


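# Usage sketch (illustrative, not part of the original file): a toy batch with
# two identities and two samples each; every output has shape [N] = [4].
def _demo_hard_example_mining():
    feats = torch.randn(4, 16)
    labels = torch.tensor([0, 0, 1, 1])
    dist_mat = euclidean_dist(feats, feats)
    dist_ap, dist_an, p_inds, n_inds = hard_example_mining(
        dist_mat, labels, margin=0.3, return_inds=True)
    print(dist_ap, dist_an, p_inds, n_inds)

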
class TripletLoss(object):
    """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
    Related Triplet Loss theory can be found in the paper 'In Defense of the
    Triplet Loss for Person Re-Identification'."""

    def __init__(self, margin=None):
        self.margin = margin
        if margin is not None:
            self.ranking_loss = nn.MarginRankingLoss(margin=margin)
        else:
            self.ranking_loss = nn.SoftMarginLoss()

    def __call__(self, global_feat, labels, normalize_feature=False):
        if normalize_feature:
            global_feat = normalize(global_feat, axis=-1)
        dist_mat = euclidean_dist(global_feat, global_feat)
        dist_ap, dist_an = hard_example_mining(dist_mat, labels, self.margin)
        # target y = 1 means dist_an should be ranked larger than dist_ap
        y = dist_an.new().resize_as_(dist_an).fill_(1)
        if self.margin is not None:
            loss = self.ranking_loss(dist_an, dist_ap, y)
        else:
            loss = self.ranking_loss(dist_an - dist_ap, y)
        return loss, dist_ap, dist_an
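

# Usage sketch (illustrative, not part of the original file): how the loss might
# be called on a batch of embeddings; the feature dimension, batch layout, and
# margin value below are assumptions, not values from the original repo.
if __name__ == "__main__":
    criterion = TripletLoss(margin=0.3)
    feats = torch.randn(8, 128)                      # e.g. 4 identities x 2 images
    labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])
    loss, dist_ap, dist_an = criterion(feats, labels, normalize_feature=True)
    print(loss.item(), dist_ap.mean().item(), dist_an.mean().item())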