Skip to content

Commit

Permalink
Merge pull request #543 from KevinMusgrave/dev
Browse files — browse the repository at this point in the history
v1.6.3
  • Loading branch information
Kevin Musgrave authored Nov 1, 2022
2 parents 64f47ba + 0ecb5d2 commit d3feece
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.6.2"
__version__ = "1.6.3"
4 changes: 4 additions & 0 deletions src/pytorch_metric_learning/utils/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ def __init__(self, miner, efficient=False):
self.efficient = efficient

def forward(self, emb, labels, ref_emb=None, ref_labels=None):
world_size = torch.distributed.get_world_size()
if world_size <= 1:
return self.miner(emb, labels, ref_emb, ref_labels)

all_emb, all_labels, all_ref_emb, all_ref_labels, labels = gather_emb_and_ref(
emb, labels, ref_emb, ref_labels
)
Expand Down
15 changes: 15 additions & 0 deletions tests/utils/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,21 @@ def test_distributed_tuple_loss_and_miner_efficient(self):
miner_kwargs={"pos_margin": 0.5, "neg_margin": 0.5},
)

def test_single_proc(self):
    """With world_size == 1, the distributed wrappers should reduce to a
    plain pass-through: wrapped loss/miner results must equal the results
    of calling the underlying loss/miner directly."""
    setup(0, 1)
    base_loss = ContrastiveLoss()
    base_miner = PairMarginMiner()
    wrapped_loss = distributed.DistributedLossWrapper(loss=base_loss)
    wrapped_miner = distributed.DistributedMinerWrapper(miner=base_miner)

    # NOTE(review): emb is placed on TEST_DEVICE but labels stay on CPU —
    # presumably the wrappers/miner tolerate this; confirm on GPU runners.
    emb = torch.randn(32, 128, device=TEST_DEVICE)
    labels = torch.randint(0, 3, size=(32,))
    pairs = wrapped_miner(emb, labels)
    loss = wrapped_loss(emb, labels, indices_tuple=pairs)
    cleanup()

    # The unwrapped pipeline (miner -> loss with explicit indices_tuple)
    # must produce the identical scalar loss.
    self.assertEqual(loss, base_loss(emb, indices_tuple=base_miner(emb, labels)))


if __name__ == "__main__":
unittest.main()

0 comments on commit d3feece

Please sign in to comment.