Skip to content

Commit 527e8b8

Browse files
jeffkbkimfacebook-github-bot
authored andcommitted
TWA Fused Tasks (#3317)
Summary: Pull Request resolved: #3317 TensorWeightedAvgMetric currently does not support FUSED_TASKS computation. With this patch, TWA supports FUSED_TASKS mode Updated unit tests and created new ones for FUSED mode Reviewed By: iamzainhuda Differential Revision: D77958663 fbshipit-source-id: bea413046706c6f1a09f0bbfe1eda7281a481311
1 parent a4832f4 commit 527e8b8

File tree

4 files changed

+383
-125
lines changed

4 files changed

+383
-125
lines changed

torchrec/metrics/rec_metric.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,27 @@ def _update(
623623
labels, torch.Tensor
624624
)
625625

626+
# Metrics such as TensorWeightedAvgMetric will have tensors that we also need to stack.
627+
# Stack in task order: (n_tasks, batch_size)
628+
if "required_inputs" in kwargs:
629+
target_tensors: list[torch.Tensor] = []
630+
for task in self._tasks:
631+
if (
632+
task.tensor_name
633+
and task.tensor_name in kwargs["required_inputs"]
634+
):
635+
target_tensors.append(
636+
kwargs["required_inputs"][task.tensor_name]
637+
)
638+
639+
if target_tensors:
640+
stacked_tensor = torch.stack(target_tensors)
641+
642+
# Reshape the stacked_tensor to size([len(self._tasks), self._batch_size])
643+
stacked_tensor = stacked_tensor.view(len(self._tasks), -1)
644+
assert isinstance(stacked_tensor, torch.Tensor)
645+
kwargs["required_inputs"]["target_tensor"] = stacked_tensor
646+
626647
predictions = (
627648
# Reshape the predictions to size([len(self._tasks), self._batch_size])
628649
predictions.view(len(self._tasks), -1)

torchrec/metrics/tensor_weighted_avg.py

Lines changed: 95 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -30,23 +30,29 @@ class TensorWeightedAvgMetricComputation(RecMetricComputation):
3030
3131
It is a sibling to WeightedAvgMetricComputation, but it computes the weighted average of a tensor
3232
passed in as a required input instead of the predictions tensor.
33+
34+
FUSED_TASKS_COMPUTATION:
35+
This class requires all target tensors from tasks to be stacked together in RecMetrics._update().
36+
During TensorWeightedAvgMetricComputation.update(), the weighted sum and weighted num samples are
37+
computed per stacked tensor.
3338
"""
3439

3540
def __init__(
3641
self,
3742
*args: Any,
38-
tensor_name: Optional[str] = None,
39-
weighted: bool = True,
43+
tasks: List[RecTaskInfo],
4044
description: Optional[str] = None,
4145
**kwargs: Any,
4246
) -> None:
4347
super().__init__(*args, **kwargs)
44-
if tensor_name is None:
45-
raise RecMetricException(
46-
f"TensorWeightedAvgMetricComputation expects tensor_name to not be None got {tensor_name}"
47-
)
48-
self.tensor_name: str = tensor_name
49-
self.weighted: bool = weighted
48+
self.tasks: List[RecTaskInfo] = tasks
49+
50+
for task in self.tasks:
51+
if task.tensor_name is None:
52+
raise RecMetricException(
53+
"TensorWeightedAvgMetricComputation expects all tasks to have tensor_name, but got None."
54+
)
55+
5056
self._add_state(
5157
"weighted_sum",
5258
torch.zeros(self._n_tasks, dtype=torch.double),
@@ -63,6 +69,13 @@ def __init__(
6369
)
6470
self._description = description
6571

72+
self.weighted_mask: torch.Tensor = torch.tensor(
73+
[task.weighted for task in self.tasks]
74+
).unsqueeze(dim=-1)
75+
76+
if torch.cuda.is_available():
77+
self.weighted_mask = self.weighted_mask.cuda()
78+
6679
def update(
6780
self,
6881
*,
@@ -71,25 +84,54 @@ def update(
7184
weights: Optional[torch.Tensor],
7285
**kwargs: Dict[str, Any],
7386
) -> None:
74-
if (
75-
"required_inputs" not in kwargs
76-
or self.tensor_name not in kwargs["required_inputs"]
77-
):
87+
88+
target_tensor: torch.Tensor
89+
90+
if "required_inputs" not in kwargs:
7891
raise RecMetricException(
79-
f"TensorWeightedAvgMetricComputation expects {self.tensor_name} in the required_inputs"
92+
"TensorWeightedAvgMetricComputation expects 'required_inputs' to exist."
8093
)
94+
else:
95+
if len(self.tasks) > 1:
96+
# In FUSED mode, RecMetric._update() always creates "target_tensor" for the stacked tensor.
97+
# Note that RecMetric._update() only stacks if the tensor_name exists in kwargs["required_inputs"].
98+
target_tensor = cast(
99+
torch.Tensor,
100+
kwargs["required_inputs"]["target_tensor"],
101+
)
102+
elif len(self.tasks) == 1:
103+
# UNFUSED_TASKS_COMPUTATION
104+
tensor_name = self.tasks[0].tensor_name
105+
if tensor_name not in kwargs["required_inputs"]:
106+
raise RecMetricException(
107+
f"TensorWeightedAvgMetricComputation expects required_inputs to contain target tensor {self.tasks[0].tensor_name}"
108+
)
109+
else:
110+
target_tensor = cast(
111+
torch.Tensor,
112+
kwargs["required_inputs"][tensor_name],
113+
)
114+
81115
num_samples = labels.shape[0]
82-
target_tensor = cast(torch.Tensor, kwargs["required_inputs"][self.tensor_name])
83116
weights = cast(torch.Tensor, weights)
117+
118+
# Vectorized computation using masks
119+
weighted_values = torch.where(
120+
self.weighted_mask, target_tensor * weights, target_tensor
121+
)
122+
123+
weighted_counts = torch.where(
124+
self.weighted_mask, weights, torch.ones_like(weights)
125+
)
126+
127+
# Sum across batch dimension to Shape(n_tasks,)
128+
weighted_sum = weighted_values.sum(dim=-1)
129+
weighted_num_samples = weighted_counts.sum(dim=-1)
130+
131+
# Update states
84132
states = {
85-
"weighted_sum": (
86-
target_tensor * weights if self.weighted else target_tensor
87-
).sum(dim=-1),
88-
"weighted_num_samples": (
89-
weights.sum(dim=-1)
90-
if self.weighted
91-
else torch.ones(weights.shape).sum(dim=-1).to(device=weights.device)
92-
),
133+
"weighted_sum": weighted_sum,
134+
"weighted_num_samples": weighted_num_samples,
93135
}
94136
for state_name, state_value in states.items():
95137
state = getattr(self, state_name)
@@ -126,23 +168,40 @@ class TensorWeightedAvgMetric(RecMetric):
126168
def _get_task_kwargs(
127169
self, task_config: Union[RecTaskInfo, List[RecTaskInfo]]
128170
) -> Dict[str, Any]:
129-
if not isinstance(task_config, RecTaskInfo):
130-
raise RecMetricException(
131-
f"TensorWeightedAvgMetric expects task_config to be RecTaskInfo not {type(task_config)}. Check the FUSED_TASKS_COMPUTATION settings."
132-
)
171+
all_tasks = (
172+
[task_config] if isinstance(task_config, RecTaskInfo) else task_config
173+
)
133174
return {
134-
"tensor_name": task_config.tensor_name,
135-
"weighted": task_config.weighted,
175+
"tasks": all_tasks,
136176
}
137177

138178
def _get_task_required_inputs(
139179
self, task_config: Union[RecTaskInfo, List[RecTaskInfo]]
140180
) -> Set[str]:
141-
if not isinstance(task_config, RecTaskInfo):
142-
raise RecMetricException(
143-
f"TensorWeightedAvgMetric expects task_config to be RecTaskInfo not {type(task_config)}. Check the FUSED_TASKS_COMPUTATION settings."
144-
)
145-
required_inputs = set()
146-
if task_config.tensor_name is not None:
147-
required_inputs.add(task_config.tensor_name)
148-
return required_inputs
181+
"""
182+
Returns the required inputs for the task.
183+
184+
FUSED_TASKS_COMPUTATION:
185+
- Given two tasks with the same tensor_name, assume the same tensor reference
186+
- For a given tensor_name, assume all tasks have the same weighted flag
187+
"""
188+
all_tasks = (
189+
[task_config] if isinstance(task_config, RecTaskInfo) else task_config
190+
)
191+
192+
required_inputs: dict[str, bool] = {}
193+
for task in all_tasks:
194+
if task.tensor_name is not None:
195+
if (
196+
task.tensor_name in required_inputs
197+
and task.weighted is not required_inputs[task.tensor_name]
198+
):
199+
existing_weighted_flag = required_inputs[task.tensor_name]
200+
raise RecMetricException(
201+
f"This target tensor was already registered as weighted={existing_weighted_flag}. "
202+
f"This target tensor cannot be re-registered with weighted={task.weighted}"
203+
)
204+
else:
205+
required_inputs[str(task.tensor_name)] = task.weighted
206+
207+
return set(required_inputs.keys())

torchrec/metrics/test_utils/__init__.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import random
1313
import tempfile
1414
import uuid
15-
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
15+
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
1616
from unittest.mock import Mock, patch
1717

1818
import torch
@@ -45,7 +45,9 @@ def gen_test_batch(
4545
mask: Optional[torch.Tensor] = None,
4646
n_classes: Optional[int] = None,
4747
seed: Optional[int] = None,
48+
device: Optional[Union[str, torch.device]] = None,
4849
) -> Dict[str, torch.Tensor]:
50+
device = torch.device(device or "cpu")
4951
if seed is not None:
5052
torch.manual_seed(seed)
5153
if label_value is not None:
@@ -65,14 +67,14 @@ def gen_test_batch(
6567
else:
6668
weight = torch.rand(batch_size, dtype=torch.double)
6769
test_batch = {
68-
label_name: label,
69-
prediction_name: prediction,
70-
weight_name: weight,
71-
tensor_name: torch.rand(batch_size, dtype=torch.double),
70+
label_name: label.to(device),
71+
prediction_name: prediction.to(device),
72+
weight_name: weight.to(device),
73+
tensor_name: torch.rand(batch_size, dtype=torch.double).to(device),
7274
}
7375
if mask_tensor_name is not None:
7476
if mask is None:
75-
mask = torch.ones(batch_size, dtype=torch.double)
77+
mask = torch.ones(batch_size, dtype=torch.double).to(device)
7678
test_batch[mask_tensor_name] = mask
7779

7880
return test_batch
@@ -240,8 +242,10 @@ def rec_metric_value_test_helper(
240242
n_classes: Optional[int] = None,
241243
zero_weights: bool = False,
242244
zero_labels: bool = False,
245+
device: Optional[Union[str, torch.device]] = None,
243246
**kwargs: Any,
244247
) -> Tuple[Dict[str, torch.Tensor], Tuple[Dict[str, torch.Tensor], ...]]:
248+
device = torch.device(device or "cpu")
245249
tasks = gen_test_tasks(task_names)
246250
model_outs = []
247251
for _ in range(nsteps):
@@ -263,6 +267,7 @@ def rec_metric_value_test_helper(
263267
n_classes=n_classes,
264268
weight_value=weight_value,
265269
label_value=label_value,
270+
device=device,
266271
)
267272
for task in tasks
268273
]
@@ -293,7 +298,8 @@ def get_target_rec_metric_value(
293298
compute_on_all_ranks=compute_on_all_ranks,
294299
should_validate_update=should_validate_update,
295300
**kwargs,
296-
)
301+
).to(device)
302+
297303
for i in range(nsteps):
298304
# Get required_inputs_list from the target metric
299305
required_inputs_list = list(target_metric_obj.get_required_inputs())
@@ -381,6 +387,7 @@ def rec_metric_gpu_sync_test_launcher(
381387
entry_point: Callable[..., None],
382388
batch_size: int = BATCH_SIZE,
383389
batch_window_size: int = BATCH_WINDOW_SIZE,
390+
device: Optional[Union[str, torch.device]] = None,
384391
**kwargs: Dict[str, Any],
385392
) -> None:
386393
with tempfile.TemporaryDirectory() as tmpdir:
@@ -402,6 +409,8 @@ def rec_metric_gpu_sync_test_launcher(
402409
batch_size,
403410
batch_window_size,
404411
kwargs.get("n_classes", None),
412+
False,
413+
torch.device(device or "cpu"),
405414
)
406415

407416

@@ -419,8 +428,10 @@ def sync_test_helper(
419428
batch_window_size: int = BATCH_WINDOW_SIZE,
420429
n_classes: Optional[int] = None,
421430
zero_weights: bool = False,
431+
device: Optional[Union[str, torch.device]] = None,
422432
**kwargs: Dict[str, Any],
423433
) -> None:
434+
device = torch.device(device or "cpu")
424435
rank = int(os.environ["RANK"])
425436
world_size = int(os.environ["WORLD_SIZE"])
426437
dist.init_process_group(
@@ -444,7 +455,7 @@ def sync_test_helper(
444455
window_size=batch_window_size * world_size,
445456
# pyre-ignore[6]: Incompatible parameter type
446457
**kwargs,
447-
)
458+
).to(device)
448459

449460
weight_value: Optional[torch.Tensor] = None
450461

@@ -458,6 +469,7 @@ def sync_test_helper(
458469
n_classes=n_classes,
459470
weight_value=weight_value,
460471
seed=42, # we set seed because of how test metric places tensors on ranks
472+
device=device,
461473
)
462474
for task in tasks
463475
]
@@ -575,6 +587,7 @@ def rec_metric_value_test_launcher(
575587
n_classes: Optional[int] = None,
576588
zero_weights: bool = False,
577589
zero_labels: bool = False,
590+
device: Optional[Union[str, torch.device]] = None,
578591
**kwargs: Any,
579592
) -> None:
580593
with tempfile.TemporaryDirectory() as tmpdir:
@@ -600,6 +613,7 @@ def rec_metric_value_test_launcher(
600613
n_classes=n_classes,
601614
zero_weights=zero_weights,
602615
zero_labels=zero_labels,
616+
device=device,
603617
**kwargs,
604618
)
605619

@@ -616,6 +630,7 @@ def rec_metric_value_test_launcher(
616630
n_classes,
617631
test_nsteps,
618632
zero_weights,
633+
device,
619634
)
620635

621636

@@ -642,6 +657,7 @@ def metric_test_helper(
642657
n_classes: Optional[int] = None,
643658
nsteps: int = 1,
644659
zero_weights: bool = False,
660+
device: Optional[Union[str, torch.device]] = None,
645661
is_time_dependent: bool = False,
646662
time_dependent_metric: Optional[Dict[Type[RecMetric], str]] = None,
647663
**kwargs: Any,
@@ -670,6 +686,7 @@ def metric_test_helper(
670686
is_time_dependent=is_time_dependent,
671687
time_dependent_metric=time_dependent_metric,
672688
zero_weights=zero_weights,
689+
device=device,
673690
**kwargs,
674691
)
675692

0 commit comments

Comments
 (0)