|
2 | 2 | from typing import Optional, Callable
|
3 | 3 |
|
4 | 4 | import psutil
|
| 5 | +import torch |
| 6 | +import torch.distributed as dist |
5 | 7 | from pympler.asizeof import asizeof
|
6 | 8 | from tensorboardX import SummaryWriter
|
7 |
| -from typing import Optional, Callable |
| 9 | + |
| 10 | + |
8 | 11 | import torch
|
| 12 | +import torch.distributed as dist |
| 13 | + |
def is_ddp_enabled():
    """
    Report whether Distributed Data Parallel (DDP) is active, i.e. the
    torch.distributed package is both available in this build and has
    been initialized by the process group setup.
    """
    if not dist.is_available():
        return False
    return dist.is_initialized()
| 20 | + |
def ddp_synchronize():
    """
    Block until every process in the DDP group has reached this point.

    This is a no-op when torch.distributed is unavailable or has not been
    initialized, so it is safe to call unconditionally.
    """
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
| 28 | + |
def ddp_all_reduce_sum(tensor):
    """
    Sum ``tensor`` element-wise across every process in the DDP group
    (in place) and return it. When DDP is not active the tensor is
    returned untouched.

    Arguments:
        - tensor (:obj:`torch.Tensor`): The input tensor to be reduced.

    Returns:
        - torch.Tensor: The reduced tensor, summed across all processes.
    """
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor
| 43 | + |
def calculate_update_per_collect(cfg, new_data):
    """
    Calculate the number of updates to perform per data collection in a
    Distributed Data Parallel (DDP) setting. This ensures that all ranks
    compute the same ``update_per_collect`` value, synchronized across
    processes.

    Arguments:
        - cfg: Configuration object; reads ``policy.update_per_collect``,
          ``policy.game_segment_length`` and ``policy.replay_ratio``.
        - new_data (list): The newly collected data; ``new_data[0]`` is
          the list of game segments.

    Returns:
        - int: The number of updates to perform per collection.
    """
    # If explicitly configured, every rank uses that value verbatim.
    update_per_collect = cfg.policy.update_per_collect
    if update_per_collect is not None:
        return update_per_collect

    # The length of a game_segment (i.e., len(game_segment.action_segment)) can
    # be smaller than cfg.policy.game_segment_length if it is the final segment
    # of the game; otherwise it may exceed it by a padding of roughly
    # unroll_steps + td_steps — hence the clamp with min().
    collected_transitions_num = sum(
        min(len(game_segment), cfg.policy.game_segment_length)
        for game_segment in new_data[0]
    )

    # Fix: gate the collective on DDP actually being initialized rather than on
    # torch.cuda.is_available(). The old guard made single-process CUDA runs pay
    # a needless GPU round-trip, and CPU-only DDP runs (gloo) never synchronized
    # at all, so ranks could diverge on update_per_collect.
    if dist.is_available() and dist.is_initialized():
        # NCCL requires a CUDA tensor; fall back to CPU for gloo-style backends.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        collected_transitions_tensor = torch.tensor(
            collected_transitions_num, dtype=torch.int64, device=device
        )
        # Sum the per-rank counts so every rank derives the same total.
        dist.all_reduce(collected_transitions_tensor, op=dist.ReduceOp.SUM)
        collected_transitions_num = collected_transitions_tensor.item()

    update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
    # Applied to both DDP and non-DDP paths (the original checked only the DDP
    # branch): zero updates would silently stall training.
    assert update_per_collect > 0, "update_per_collect must be positive"
    return update_per_collect
10 | 92 |
|
11 | 93 | def initialize_zeros_batch(observation_shape, batch_size, device):
|
12 | 94 | """
|
|
0 commit comments