Skip to content

Commit

Permalink
Fix bug in BatchEasyHardMiner.get_max_per_row where minimum value of …
Browse files Browse the repository at this point in the history
…mat was set being set to 0 instead of neg_inf
  • Loading branch information
KevinMusgrave committed May 28, 2022
1 parent 12ebcea commit 38ddab0
Showing 1 changed file with 26 additions and 22 deletions.
48 changes: 26 additions & 22 deletions src/pytorch_metric_learning/miners/batch_easy_hard_miner.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,35 +134,39 @@ def get_mine_function(self, strategy):

return mine_func

def get_max_per_row(
self, mat, anchor_idx, other_idx, val_range=None, semihard_thresholds=None
):
mask = torch.zeros_like(mat)
mask[anchor_idx, other_idx] = 1
if semihard_thresholds is not None:
mask[mat >= semihard_thresholds.unsqueeze(1)] = 0
if val_range is not None:
mask[(mat > val_range[1]) | (mat < val_range[0])] = 0
mat_masked = mat * mask
non_zero_rows = torch.any(mask != 0, dim=1)
return torch.max(mat_masked, dim=1), non_zero_rows
def get_max_per_row(self, *args, **kwargs):
return self.get_x_per_row("max", *args, **kwargs)

def get_min_per_row(
self, mat, anchor_idx, other_idx, val_range=None, semihard_thresholds=None
def get_min_per_row(self, *args, **kwargs):
return self.get_x_per_row("min", *args, **kwargs)

def get_x_per_row(
self,
xtype,
mat,
anchor_idx,
other_idx,
val_range=None,
semihard_thresholds=None,
):
pos_inf = c_f.pos_inf(mat.dtype)
mask = torch.ones_like(mat) * pos_inf
assert xtype in ["min", "max"]
inf = c_f.pos_inf(mat.dtype) if xtype == "min" else c_f.neg_inf(mat.dtype)
mask = torch.ones_like(mat) * inf
mask[anchor_idx, other_idx] = 1

if semihard_thresholds is not None:
mask[mat <= semihard_thresholds.unsqueeze(1)] = pos_inf
if xtype == "min":
condition = mat <= semihard_thresholds.unsqueeze(1)
else:
condition = mat >= semihard_thresholds.unsqueeze(1)
mask[condition] = inf
if val_range is not None:
mask[(mat > val_range[1]) | (mat < val_range[0])] = pos_inf
mask[(mat > val_range[1]) | (mat < val_range[0])] = inf

non_inf_rows = torch.any(mask != pos_inf, dim=1)
non_inf_rows = torch.any(mask != inf, dim=1)
mat = mat.clone()
mat[mask == pos_inf] = pos_inf
return torch.min(mat, dim=1), non_inf_rows
mat[mask == inf] = inf
dist_fn = torch.min if xtype == "min" else torch.max
return dist_fn(mat, dim=1), non_inf_rows

def set_stats(self, positive_dists, negative_dists):
if self.collect_stats:
Expand Down

0 comments on commit 38ddab0

Please sign in to comment.