Merge pull request #503 from NoTody/master
add compatibility of DistributedWrapper for two-stream input
Kevin Musgrave authored Sep 3, 2022
2 parents a8cb8fa + 798ceb2 commit a5cede7
Showing 2 changed files with 201 additions and 71 deletions.
93 changes: 69 additions & 24 deletions src/pytorch_metric_learning/utils/distributed.py
@@ -25,28 +25,36 @@ def all_gather(x):


 # modified from https://github.com/JohnGiorgi/DeCLUTR
-def all_gather_embeddings_and_labels(embeddings, labels):
+def all_gather_embeddings_and_labels(emb, labels):
     # If we are not using distributed training, this is a no-op.
     if not is_distributed():
         return None, None
-    ref_emb = all_gather(embeddings)
+    ref_emb = all_gather(emb)
     ref_labels = all_gather(labels)
     return ref_emb, ref_labels


-def gather(embeddings, labels):
-    device = embeddings.device
+def gather(emb, labels):
+    device = emb.device
     labels = c_f.to_device(labels, device=device)
-    rank = torch.distributed.get_rank()
-    dist_ref_emb, dist_ref_labels = all_gather_embeddings_and_labels(embeddings, labels)
-    all_emb = torch.cat([embeddings, dist_ref_emb], dim=0)
-    all_labels = torch.cat([labels, dist_ref_labels], dim=0)
-    return all_emb, all_labels, labels, device
+    dist_emb, dist_labels = all_gather_embeddings_and_labels(emb, labels)
+    all_emb = torch.cat([emb, dist_emb], dim=0)
+    all_labels = torch.cat([labels, dist_labels], dim=0)
+    return all_emb, all_labels, labels


-def get_indices_tuple(
-    labels, ref_labels, device, embeddings=None, ref_emb=None, miner=None
-):
+def gather_emb_and_ref(emb, labels, ref_emb=None, ref_labels=None):
+    all_emb, all_labels, labels = gather(emb, labels)
+    all_ref_emb, all_ref_labels = None, None
+
+    if ref_emb is not None and ref_labels is not None:
+        all_ref_emb, all_ref_labels, _ = gather(ref_emb, ref_labels)
+
+    return all_emb, all_labels, all_ref_emb, all_ref_labels, labels
+
+
+def get_indices_tuple(labels, ref_labels, embeddings=None, ref_emb=None, miner=None):
+    device = labels.device
     curr_batch_idx = torch.arange(len(labels), device=device)
     if miner:
         indices_tuple = miner(embeddings, labels, ref_emb, ref_labels)
@@ -55,6 +63,10 @@ def get_indices_tuple(
     return lmu.remove_self_comparisons(indices_tuple, curr_batch_idx, len(ref_labels))


+def select_ref_or_regular(regular, ref):
+    return regular if ref is None else ref
+
+
 class DistributedLossWrapper(torch.nn.Module):
     def __init__(self, loss, efficient=False):
         super().__init__()
@@ -69,22 +81,53 @@ def __init__(self, loss, efficient=False):
         self.loss = loss
         self.efficient = efficient

-    def forward(self, embeddings, labels, indices_tuple=None):
+    def forward(self, emb, labels, indices_tuple=None, ref_emb=None, ref_labels=None):
         world_size = torch.distributed.get_world_size()
+        common_args = [emb, labels, indices_tuple, ref_emb, ref_labels, world_size]
+        if isinstance(self.loss, CrossBatchMemory):
+            return self.forward_cross_batch(*common_args)
+        return self.forward_regular_loss(*common_args)
+
+    def forward_regular_loss(
+        self, emb, labels, indices_tuple, ref_emb, ref_labels, world_size
+    ):
         if world_size <= 1:
-            return self.loss(embeddings, labels, indices_tuple)
+            return self.loss(emb, labels, indices_tuple, ref_emb, ref_labels)

-        all_emb, all_labels, labels, device = gather(embeddings, labels)
+        all_emb, all_labels, all_ref_emb, all_ref_labels, labels = gather_emb_and_ref(
+            emb, labels, ref_emb, ref_labels
+        )
+
         if self.efficient:
+            all_labels = select_ref_or_regular(all_labels, all_ref_labels)
+            all_emb = select_ref_or_regular(all_emb, all_ref_emb)
             if indices_tuple is None:
-                indices_tuple = get_indices_tuple(labels, all_labels, device)
-            loss = self.loss(embeddings, labels, indices_tuple, all_emb, all_labels)
+                indices_tuple = get_indices_tuple(labels, all_labels)
+            loss = self.loss(emb, labels, indices_tuple, all_emb, all_labels)
         else:
-            loss = self.loss(all_emb, all_labels, indices_tuple)
+            loss = self.loss(
+                all_emb, all_labels, indices_tuple, all_ref_emb, all_ref_labels
+            )

         return loss * world_size

+    def forward_cross_batch(
+        self, emb, labels, indices_tuple, ref_emb, ref_labels, world_size
+    ):
+        if ref_emb is not None or ref_labels is not None:
+            raise ValueError(
+                "CrossBatchMemory is not compatible with ref_emb and ref_labels"
+            )
+
+        if world_size <= 1:
+            return self.loss(emb, labels, indices_tuple)
+
+        all_emb, all_labels, _, _, _ = gather_emb_and_ref(
+            emb, labels, ref_emb, ref_labels
+        )
+        loss = self.loss(all_emb, all_labels, indices_tuple)
+        return loss * world_size


 class DistributedMinerWrapper(torch.nn.Module):
     def __init__(self, miner, efficient=False):
@@ -94,11 +137,13 @@ def __init__(self, miner, efficient=False):
         self.miner = miner
         self.efficient = efficient

-    def forward(self, embeddings, labels):
-        all_emb, all_labels, labels, device = gather(embeddings, labels)
+    def forward(self, emb, labels, ref_emb=None, ref_labels=None):
+        all_emb, all_labels, all_ref_emb, all_ref_labels, labels = gather_emb_and_ref(
+            emb, labels, ref_emb, ref_labels
+        )
         if self.efficient:
-            return get_indices_tuple(
-                labels, all_labels, device, embeddings, all_emb, self.miner
-            )
+            all_labels = select_ref_or_regular(all_labels, all_ref_labels)
+            all_emb = select_ref_or_regular(all_emb, all_ref_emb)
+            return get_indices_tuple(labels, all_labels, emb, all_emb, self.miner)
         else:
-            return self.miner(all_emb, all_labels)
+            return self.miner(all_emb, all_labels, all_ref_emb, all_ref_labels)
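
For context, a minimal usage sketch (not part of the commit) of the two-stream call pattern that these changes enable. It assumes torch.distributed is already initialized and that each process runs this inside its own training loop; the trunk model, query/reference batches, and labels below are placeholder names.

# Hypothetical sketch of the two-stream (ref_emb / ref_labels) usage enabled by
# this commit. Assumes torch.distributed is initialized; trunk, batches, and
# labels are placeholders.
from pytorch_metric_learning import losses, miners
from pytorch_metric_learning.utils import distributed as pml_dist

loss_fn = pml_dist.DistributedLossWrapper(loss=losses.ContrastiveLoss())
miner = pml_dist.DistributedMinerWrapper(miner=miners.MultiSimilarityMiner())


def training_step(trunk, query_batch, query_labels, ref_batch, ref_labels):
    # Two input streams: "query" embeddings and a separate "reference" set.
    query_emb = trunk(query_batch)
    ref_emb = trunk(ref_batch)
    # Both streams are all-gathered across processes inside the wrappers.
    indices_tuple = miner(query_emb, query_labels, ref_emb=ref_emb, ref_labels=ref_labels)
    return loss_fn(
        query_emb, query_labels, indices_tuple, ref_emb=ref_emb, ref_labels=ref_labels
    )

As the diff shows, CrossBatchMemory losses cannot be combined with ref_emb/ref_labels, and with efficient=True the loss is computed between each process's local embeddings and the gathered set rather than between all gathered pairs.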