RAM-efficient Retrieval #23

Merged: 7 commits, Feb 19, 2025
110 changes: 69 additions & 41 deletions mmlearn/modules/metrics/retrieval_recall.py
@@ -1,18 +1,20 @@
"""Retrieval Recall@K metric."""

import concurrent.futures
import os
from functools import partial
from typing import Any, Callable, Literal, Optional, Union
from typing import Any, Callable, Dict, Literal, Optional, Union

import torch
import torch.distributed
from hydra_zen import store
from torch.nn import functional as F # noqa: N812
from torchmetrics import Metric
from torchmetrics.retrieval.base import _retrieval_aggregate
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.compute import _safe_matmul
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.distributed import gather_all_tensors
from tqdm import tqdm


@store(group="modules/metrics", provider="mmlearn")
@@ -29,7 +31,7 @@ class RetrievalRecallAtK(Metric):
----------
top_k : int
The number of top elements to consider for computing the Recall@K.
reduction : {"mean", "sum", "none", None}, default="sum"
reduction : {"mean", "sum", "none", None}, optional, default="sum"
Specifies the reduction to apply after computing the pairwise cosine similarity
scores.
aggregation : {"mean", "median", "min", "max"} or callable, default="mean"
@@ -48,7 +50,6 @@ class RetrievalRecallAtK(Metric):
- If the `aggregation` is not one of {"mean", "median", "min", "max"} or a
custom callable function.


"""

is_differentiable: bool = False
Expand All @@ -63,13 +64,14 @@ class RetrievalRecallAtK(Metric):
def __init__(
self,
top_k: int,
reduction: Literal["mean", "sum", "none", None] = None,
reduction: Optional[Literal["mean", "sum", "none"]] = "sum",
aggregation: Union[
Literal["mean", "median", "min", "max"],
Callable[[torch.Tensor, int], torch.Tensor],
] = "mean",
**kwargs: Any,
) -> None:
"""Initialize the metric."""
super().__init__(**kwargs)

if top_k is not None and not (isinstance(top_k, int) and top_k > 0):
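For orientation, here is a minimal usage sketch of the metric as it looks after this change. The import path is inferred from the file path above, and the toy tensors plus the `indexes` convention (one true-match index per query row) are assumptions read off the `update`/`compute` signatures in this diff, not code from the PR:

```python
import torch

from mmlearn.modules.metrics.retrieval_recall import RetrievalRecallAtK

# toy data (assumed shapes): 8 query and 8 candidate embeddings of dimension 16
queries = torch.randn(8, 16)
candidates = torch.randn(8, 16)
indexes = torch.arange(8)  # assumed: indexes[i] is the true candidate for query i

metric = RetrievalRecallAtK(top_k=5, reduction="none", aggregation="mean")
metric.update(queries, candidates, indexes)
recall_at_5 = metric.compute()  # scalar tensor, averaged over all queries
print(recall_at_5)
```

Note that `forward` is intentionally unsupported (see further down), so the metric is driven through `update`/`compute` rather than being called directly.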
@@ -105,12 +107,6 @@ def __init__(
self._to_sync = self.sync_on_compute
self._should_unsync = False

def _is_distributed(self) -> bool:
if self.distributed_available_fn is not None:
distributed_available = self.distributed_available_fn

return distributed_available() if callable(distributed_available) else False

def update(self, x: torch.Tensor, y: torch.Tensor, indexes: torch.Tensor) -> None:
"""Check shape, convert dtypes and add to accumulators.

@@ -181,34 +177,49 @@ def compute(self) -> torch.Tensor:
torch.Tensor
The computed metric.
"""
# compute the cosine similarity
x_norm = F.normalize(dim_zero_cat(self.x), p=2, dim=-1)
y_norm = F.normalize(dim_zero_cat(self.y), p=2, dim=-1)
similarity = _safe_matmul(x_norm, y_norm)
reduction_mapping: dict[
x = dim_zero_cat(self.x)
y = dim_zero_cat(self.y)

# normalize embeddings
x /= x.norm(dim=-1, p=2, keepdim=True)
y /= y.norm(dim=-1, p=2, keepdim=True)

# instantiate reduction function
reduction_mapping: Dict[
Optional[str], Callable[[torch.Tensor], torch.Tensor]
] = {
"sum": partial(torch.sum, dim=-1),
"mean": partial(torch.mean, dim=-1),
"none": lambda x: x,
None: lambda x: x,
}
scores: torch.Tensor = reduction_mapping[self.reduction](similarity)

# concatenate indexes of true pairs
indexes = dim_zero_cat(self.indexes)
positive_pairs = torch.zeros_like(scores, dtype=torch.bool)
positive_pairs[torch.arange(len(scores)), indexes] = True

results = []
for start in range(0, len(scores), self._batch_size):
end = start + self._batch_size
x = scores[start:end]
y = positive_pairs[start:end]
result = _recall_at_k(x, y, self.top_k)
results.append(result)
results: list[torch.Tensor] = []
with concurrent.futures.ThreadPoolExecutor(
max_workers=os.cpu_count() or 1 # use all available CPUs
) as executor:
futures = [
executor.submit(
self._process_batch,
start,
x,
y,
indexes,
reduction_mapping,
self.top_k,
)
for start in tqdm(
range(0, len(x), self._batch_size), desc=f"Recall@{self.top_k}"
)
]
for future in concurrent.futures.as_completed(futures):
results.append(future.result())

return _retrieval_aggregate(
(torch.cat([x.to(scores) for x in results]) > 0).float(), self.aggregation
(torch.cat([x.float() for x in results]) > 0).float(), self.aggregation
)

def forward(self, *args: Any, **kwargs: Any) -> Any:
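To make the memory trade-off concrete, here is a standalone sketch of the same batching idea (not the PR's code): normalize once, then score one chunk of queries at a time against all candidates, so peak memory scales with `batch_size x M` instead of `N x M`. It assumes a single positive candidate per query; the PR's implementation instead builds a boolean `positive_pairs` mask per chunk and additionally fans the chunks out over a `ThreadPoolExecutor`:

```python
import torch
import torch.nn.functional as F


def batched_recall_at_k(
    x: torch.Tensor,        # (N, D) query embeddings
    y: torch.Tensor,        # (M, D) candidate embeddings
    indexes: torch.Tensor,  # (N,) index of the true candidate for each query
    k: int,
    batch_size: int = 256,
) -> torch.Tensor:
    # normalize once so every chunked matmul below yields cosine similarities
    x = F.normalize(x, p=2, dim=-1)
    y = F.normalize(y, p=2, dim=-1)

    hits = []
    for start in range(0, len(x), batch_size):
        sim = x[start : start + batch_size] @ y.T       # (b, M) chunk only
        topk = sim.topk(k, dim=1).indices                # (b, k)
        true_idx = indexes[start : start + batch_size]   # (b,)
        # a query counts as a hit if its true candidate appears in the top-k
        hits.append((topk == true_idx.unsqueeze(1)).any(dim=1))

    return torch.cat(hits).float().mean()  # Recall@K averaged over all queries
```

The thread pool mainly helps when these chunked matmuls run on CPU, where PyTorch releases the GIL during the heavy ops; on a single GPU stream the chunks would largely serialize anyway.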
@@ -223,9 +234,36 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
"RetrievalRecallAtK metric does not support forward method"
)

def _is_distributed(self) -> bool:
if self.distributed_available_fn is not None:
distributed_available = self.distributed_available_fn

return distributed_available() if callable(distributed_available) else False

def _process_batch(
self,
start: int,
x_norm: torch.Tensor,
y_norm: torch.Tensor,
indexes: torch.Tensor,
reduction_mapping: Dict[Optional[str], Callable[[torch.Tensor], torch.Tensor]],
top_k: int,
) -> torch.Tensor:
"""Compute the Recall@K for a batch of samples."""
end = start + self._batch_size
x_norm_batch = x_norm[start:end]
indexes_batch = indexes[start:end]

similarity = _safe_matmul(x_norm_batch, y_norm)
scores: torch.Tensor = reduction_mapping[self.reduction](similarity)

with torch.inference_mode():
positive_pairs = torch.zeros_like(scores, dtype=torch.bool)
positive_pairs[torch.arange(len(scores)), indexes_batch] = True

return _recall_at_k(scores, positive_pairs, top_k)


# modified from:
# https://github.com/LAION-AI/CLIP_benchmark/blob/main/clip_benchmark/metrics/zeroshot_retrieval.py
def _recall_at_k(
scores: torch.Tensor, positive_pairs: torch.Tensor, k: int
) -> torch.Tensor:
@@ -244,20 +282,10 @@ def _recall_at_k(
-------
recall at k averaged over all texts
"""
nb_texts, nb_images = scores.shape
# for each text, sort according to image scores in decreasing order
topk_indices = torch.topk(scores, k, dim=1)[1]
# compute number of positives for each text
nb_positive = positive_pairs.sum(dim=1)
# nb_texts, k, nb_images
topk_indices_onehot = torch.nn.functional.one_hot(
topk_indices, num_classes=nb_images
)
# compute number of true positives
positive_pairs_reshaped = positive_pairs.view(nb_texts, 1, nb_images)
# a true positive means a positive among the topk
nb_true_positive = (topk_indices_onehot * positive_pairs_reshaped).sum(dim=(1, 2))
# compute recall at k
positive_topk = positive_pairs.gather(1, topk_indices)
nb_true_positive = positive_topk.sum(dim=1)
return nb_true_positive / nb_positive
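For reference, a small sketch contrasting the counting step before and after this hunk, with toy shapes assumed: both count how many positives land in the top-k, but the `gather` version skips the intermediate `(texts, k, images)` one-hot tensor that the old version allocated.

```python
import torch

texts, images, k = 4, 10, 5
scores = torch.randn(texts, images)                           # toy similarity scores
positive_pairs = torch.zeros(texts, images, dtype=torch.bool)
positive_pairs[torch.arange(texts), torch.tensor([0, 3, 7, 2])] = True

topk_indices = torch.topk(scores, k, dim=1).indices           # (texts, k)

# old: one-hot the top-k indices, intersect with the positive mask, then sum
one_hot = torch.nn.functional.one_hot(topk_indices, num_classes=images)  # (texts, k, images)
tp_old = (one_hot * positive_pairs.view(texts, 1, images)).sum(dim=(1, 2))

# new: gather the positive mask directly at the top-k positions
tp_new = positive_pairs.gather(1, topk_indices).sum(dim=1)    # (texts,)

assert torch.equal(tp_old, tp_new)
```

Together with the chunked `compute` above, this removes the two largest intermediates: the full `N x M` similarity matrix and the `(texts, k, images)` one-hot tensor.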

