Skip to content

Commit db4f455

Browse files
Applying Center Loss (#213)
* Applying Center Loss * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed minor typos and errors + Center loss implementation * sphinx version change --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 2371ddf commit db4f455

11 files changed

+126
-3
lines changed

docs/generate_docs_netlify.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ poetry build -f wheel
1313
pip install dist/$(ls -1 dist | grep .whl)
1414
pip install pytorch-metric-learning==1.3.2
1515

16-
pip install sphinx>=5.0.1
16+
pip install sphinx==6.1.3
1717
pip install "git+https://github.com/qdrant/qdrant_sphinx_theme.git@master#egg=qdrant-sphinx-theme"
1818

1919
sphinx-apidoc --force --separate --no-toc -o docs/source quaterion

docs/source/api/index.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,9 @@ Implementations
128128
~softmax_loss.SoftmaxLoss
129129
~triplet_loss.TripletLoss
130130
~circle_loss.CircleLoss
131-
~fastap_loss.FastAPLoss
131+
~fast_ap_loss.FastAPLoss
132132
~cos_face_loss.CosFaceLoss
133+
~center_loss.CenterLoss
133134

134135
Extras
135136
++++++
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
quaterion.loss.center\_loss module
2+
==================================
3+
4+
.. automodule:: quaterion.loss.center_loss
5+
:members:
6+
:undoc-members:
7+
:show-inheritance:
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
quaterion.loss.circle\_loss module
2+
==================================
3+
4+
.. automodule:: quaterion.loss.circle_loss
5+
:members:
6+
:undoc-members:
7+
:show-inheritance:
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
quaterion.loss.cos\_face\_loss module
2+
=====================================
3+
4+
.. automodule:: quaterion.loss.cos_face_loss
5+
:members:
6+
:undoc-members:
7+
:show-inheritance:
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
quaterion.loss.fast\_ap\_loss module
2+
====================================
3+
4+
.. automodule:: quaterion.loss.fast_ap_loss
5+
:members:
6+
:undoc-members:
7+
:show-inheritance:

docs/source/quaterion.loss.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@ Submodules
1616
:maxdepth: 4
1717

1818
quaterion.loss.arcface_loss
19+
quaterion.loss.center_loss
20+
quaterion.loss.circle_loss
1921
quaterion.loss.contrastive_loss
22+
quaterion.loss.cos_face_loss
23+
quaterion.loss.fast_ap_loss
2024
quaterion.loss.group_loss
2125
quaterion.loss.multiple_negatives_ranking_loss
2226
quaterion.loss.online_contrastive_loss

docs/source/tutorials/triplet_loss_trick.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
Triplet Loss: Vector Collapse Prevention
2-
============================
2+
========================================
33

44
Triplet Loss is one of the most widely known loss functions in similarity learning.
55
If you want to deep-dive into the details of its implementations and advantages,

quaterion/loss/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from quaterion.loss.arcface_loss import ArcFaceLoss
2+
from quaterion.loss.center_loss import CenterLoss
23
from quaterion.loss.circle_loss import CircleLoss
34
from quaterion.loss.contrastive_loss import ContrastiveLoss
45
from quaterion.loss.cos_face_loss import CosFaceLoss

quaterion/loss/center_loss.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from typing import Optional
2+
3+
import torch
4+
import torch.nn as nn
5+
import torch.nn.functional as F
6+
from torch import LongTensor, Tensor
7+
8+
from quaterion.loss.group_loss import GroupLoss
9+
from quaterion.utils import l2_norm
10+
11+
12+
class CenterLoss(GroupLoss):
13+
"""
14+
Center Loss as defined in the paper "A Discriminative Feature Learning Approach
15+
for Deep Face Recognition" (http://ydwen.github.io/papers/WenECCV16.pdf)
16+
It aims to minimize the intra-class variations while keeping the features of
17+
different classes separable.
18+
19+
Args:
20+
embedding_size: Output dimension of the encoder.
21+
num_groups: Number of groups (classes) in the dataset.
22+
lambda_c: A regularization parameter that controls the contribution of the center loss.
23+
"""
24+
25+
def __init__(
26+
self, embedding_size: int, num_groups: int, lambda_c: Optional[float] = 0.5
27+
):
28+
super(GroupLoss, self).__init__()
29+
self.num_groups = num_groups
30+
self.centers = nn.Parameter(torch.randn(num_groups, embedding_size))
31+
self.lambda_c = lambda_c
32+
33+
nn.init.xavier_uniform_(self.centers)
34+
35+
def forward(self, embeddings: Tensor, groups: LongTensor) -> Tensor:
36+
"""
37+
Compute the Center Loss value.
38+
39+
Args:
40+
embeddings: shape (batch_size, vector_length) - Output embeddings from the encoder.
41+
groups: shape (batch_size,) - Group (class) ids associated with embeddings.
42+
43+
Returns:
44+
Tensor: loss value.
45+
"""
46+
embeddings = l2_norm(embeddings, 1)
47+
48+
# Gather the center for each embedding's corresponding group
49+
centers_batch = self.centers.index_select(0, groups)
50+
51+
# Calculate the distance between embeddings and their respective class centers
52+
loss = F.mse_loss(embeddings, centers_batch)
53+
54+
# Scale the loss by the regularization parameter
55+
loss *= self.lambda_c
56+
57+
return loss

tests/eval/losses/test_center_loss.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import torch
2+
3+
from quaterion.loss import CenterLoss
4+
5+
6+
class TestCenterLoss:
7+
embeddings = torch.Tensor(
8+
[
9+
[0.0, -1.0, 0.5],
10+
[0.1, 2.0, 0.5],
11+
[0.0, 0.3, 0.2],
12+
[1.0, 0.0, 0.9],
13+
[1.2, -1.2, 0.01],
14+
[-0.7, 0.0, 1.5],
15+
]
16+
)
17+
groups = torch.LongTensor([1, 2, 0, 0, 2, 1])
18+
19+
def test_batch_all(self):
20+
# Initialize the CenterLoss
21+
loss = CenterLoss(embedding_size=self.embeddings.size()[1], num_groups=3)
22+
23+
# Calculate the loss
24+
loss_res = loss.forward(embeddings=self.embeddings, groups=self.groups)
25+
26+
# Assertions to check the output shape and type
27+
assert isinstance(
28+
loss_res, torch.Tensor
29+
), "Loss result should be a torch.Tensor"
30+
assert loss_res.shape == torch.Size(
31+
[]
32+
), "Loss result should be a scalar (0-dimension tensor)"

0 commit comments

Comments
 (0)