nemo_rl/algorithms/async_utils.py (117 additions, 47 deletions)
@@ -28,6 +28,7 @@
run_async_multi_turn_rollout,
)
from nemo_rl.models.generation.interfaces import GenerationInterface
from nemo_rl.utils.trace import define_collect_trace, new_tracer, Tracer

TokenizerType = PreTrainedTokenizerBase

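The hunks below lean on four names from `nemo_rl.utils.trace`, a module that is not part of this diff. For readers following the call sites, here is a minimal sketch of the interface they appear to assume: `new_tracer`, `define_collect_trace`, and a `Tracer` with `span`, `start_span`/`end_span`, `add_instant_event`, `get_events`, and an `enabled` flag. Only the names come from the diff; the signatures, event schema, and threading behavior below are assumptions, not the real implementation.

```python
# Hypothetical reconstruction of the Tracer API used in this diff.
import threading
import time
from contextlib import contextmanager
from typing import Any, Callable, Iterator, Optional


class Tracer:
    def __init__(self, name: str, enabled: bool = True):
        self.name = name
        self.enabled = enabled  # the diff checks this before keeping per-worker tracers
        self._events: list[dict[str, Any]] = []
        self._open: dict[str, tuple[float, dict[str, Any]]] = {}
        self._lock = threading.Lock()

    def _record(self, kind: str, name: str, start: float, duration: float,
                metadata: dict[str, Any]) -> None:
        if not self.enabled:
            return
        with self._lock:
            self._events.append({"tracer": self.name, "kind": kind, "name": name,
                                 "start": start, "duration": duration,
                                 "metadata": metadata})

    @contextmanager
    def span(self, name: str, metadata: Optional[dict[str, Any]] = None) -> Iterator[None]:
        # Context-manager form: `with tracer.span("sample"): ...`
        start = time.perf_counter()
        try:
            yield
        finally:
            self._record("span", name, start, time.perf_counter() - start, metadata or {})

    def start_span(self, name: str, metadata: Optional[dict[str, Any]] = None) -> None:
        # Imperative form for spans that outlive one lexical block (see _process_batch).
        self._open[name] = (time.perf_counter(), metadata or {})

    def end_span(self, name: str) -> None:
        start, metadata = self._open.pop(name)
        self._record("span", name, start, time.perf_counter() - start, metadata)

    def add_instant_event(self, name: str, metadata: Optional[dict[str, Any]] = None) -> None:
        self._record("instant", name, time.perf_counter(), 0.0, metadata or {})

    def get_events(self) -> list[dict[str, Any]]:
        with self._lock:
            return list(self._events)


def new_tracer(name: str) -> Tracer:
    return Tracer(name)


def define_collect_trace(fn: Callable) -> Callable:
    # Assumed to be a marker/registration decorator for trace-collection methods.
    fn.__is_collect_trace__ = True
    return fn
```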
@@ -52,6 +53,12 @@ def __init__(self, max_size: int):
self.last_target_weight_already_generated = -1
self._lock = _threading.Lock()

self._tracer = new_tracer("replay_buffer")

@define_collect_trace
def collect_trace(self):
return self._tracer.get_events()

def push_with_wait_signal(
self,
trajectory: dict[str, Any],
@@ -65,7 +72,14 @@ def push_with_wait_signal(
weight_version: version of the model weights used for generation
target_weight_version: version of the model weights this trajectory is intended for training
"""
with self._lock:
span = self._tracer.span(
"push_with_wait_signal",
metadata={
"weight_version": weight_version,
"target_weight_version": target_weight_version,
},
)
with span, self._lock:
if len(self.trajectories) >= self.max_size:
return "full"

@@ -115,7 +129,7 @@ def sample(
Returns:
Dictionary with 'trajectories' and 'avg_trajectory_age' keys, or None if insufficient data
"""
with self._lock:
with self._tracer.span("sample"), self._lock:
if not self.trajectories:
return None

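A subtlety in both ReplayBuffer methods: the span is entered before `self._lock`, so the recorded duration includes lock contention, not just the critical section. That is arguably the right choice for a queue-like buffer, since time spent waiting on the lock is part of the observed push/sample latency. The two orderings, using the `Tracer` sketched above:

```python
import threading

tracer = new_tracer("example")  # new_tracer as sketched above
lock = threading.Lock()

# Ordering used in the diff: the span is entered first, so its duration
# includes time spent waiting to acquire the lock.
with tracer.span("sample"), lock:
    pass  # lock wait + critical section both counted

# Alternative ordering: measure only the critical section.
with lock:
    with tracer.span("sample_locked"):
        pass  # critical section only
```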
@@ -291,6 +305,18 @@ def __init__(
# Track which target weights are currently being generated (globally)
self._generating_targets: set[int] = set()

self._tracer = new_tracer("trajectory_collector")
self._loop_tracer = new_tracer("trajectory_collector_loop")
self._worker_tracers: list[Tracer] = []

@define_collect_trace
def collect_trace(self):
events = self._tracer.get_events()
events.extend(self._loop_tracer.get_events())
for worker_tracer in self._worker_tracers:
events.extend(worker_tracer.get_events())
return events

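The collector-level `collect_trace` merges events from the main tracer, the loop tracer, and every per-worker tracer into one flat list; note also that later in this diff (`_process_batch`) worker tracers are only retained when `self._tracer.enabled` is true, which bounds the growth of `_worker_tracers` when tracing is off. A hypothetical consumer (names assumed, schema per the sketch above) might interleave everything into a single timeline:

```python
# Hypothetical consumer; assumes the event schema from the sketch above and a
# `collector` instance of the collector class in this file.
events = collector.collect_trace()       # main + loop + per-worker events
events.sort(key=lambda e: e["start"])    # interleave into one timeline
for e in events:
    print(f'{e["tracer"]:<35} {e["name"]:<30} {e["duration"]:8.3f}s')
```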
def _calculate_target_weights(self, generation_weight_version: int) -> list[int]:
"""Calculate target weight versions for given generation weight version.

@@ -351,6 +377,13 @@ def set_weight_version(self, version: int) -> None:
print(f"🔄 Updated weight version to {version}, resuming collection")
else:
print(f"🔄 Updated weight version to {version}")
self._tracer.add_instant_event(
"set_weight_version",
metadata={
"version": version,
"was_paused": was_paused,
},
)

def _should_pause_for_generation_limits(self) -> bool:
"""Check if collection should be paused due to generation limits."""
@@ -454,6 +487,14 @@ def _process_batch(self, batch: BatchedDataDict[DatumSpec]) -> None:
generation_weight_version = self.current_weight_version
num_generations = self.master_config["grpo"]["num_generations_per_prompt"]
num_prompts = batch.size
self._loop_tracer.start_span(
"process_batch",
metadata={
"generation_weight_version": generation_weight_version,
"num_generations": num_generations,
"num_prompts": num_prompts,
},
)

# Get the next target weight that needs generation
target_weight = self._get_next_target_for_generation(
@@ -491,19 +532,31 @@ def _process_batch(self, batch: BatchedDataDict[DatumSpec]) -> None:
repeated_batch = single_prompt_batch.repeat_interleave(num_generations)

self._inflight_sema.acquire()
worker_tracer = new_tracer(f"trajectory_collector_worker_prompt{prompt_idx}")
if self._tracer.enabled:
self._worker_tracers.append(worker_tracer)
worker = _threading.Thread(
target=self._run_prompt_group_worker,
args=(
repeated_batch,
generation_weight_version,
target_weight,
prompt_idx,
worker_tracer,
),
daemon=True,
)
with self._threads_lock:
self._inflight_threads.add(worker)
worker.start()
self._loop_tracer.add_instant_event(
"launch run_prompt_group_worker",
metadata={
"generation_weight_version": generation_weight_version,
"target_weight_version": target_weight,
"prompt_idx": prompt_idx,
},
)

self._cleanup_finished_threads()

@@ -513,6 +566,9 @@ def _process_batch(self, batch: BatchedDataDict[DatumSpec]) -> None:

traceback.print_exc()

finally:
self._loop_tracer.end_span("process_batch")

def get_weight_version(self) -> int:
return self.current_weight_version

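Note the pattern in `_process_batch` above: the imperative `start_span` at the top, paired with `end_span` in a `finally` clause, guarantees the span closes on the exception path without re-indenting the whole method body under a `with` block, which keeps this diff small. If the context-manager form were preferred, `contextlib.ExitStack` would give the same no-reindent shape; a sketch under the assumed API, not what this PR does:

```python
from contextlib import ExitStack

def process_batch_sketch(loop_tracer, batch) -> None:
    # ExitStack opens the span at the top of the method without re-indenting
    # the body, and still closes it on early return or exception.
    with ExitStack() as stack:
        stack.enter_context(
            loop_tracer.span("process_batch", metadata={"num_prompts": len(batch)})
        )
        # ... existing body unchanged ...
```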
@@ -577,50 +633,52 @@ def prepare_for_refit(self) -> None:

def resume_after_refit(self) -> None:
"""Resume new generation starts after refit is complete."""
print("🔄 Resuming generation starts after refit")

# Invalidate & recompute vLLM caches after the in-flight weight updates if
# recompute_kv_cache_after_weight_updates is True (AREAL-style implementation).
# Otherwise, keep using the stale KV caches (Magistral-style implementation).
async_cfg = self.master_config.get("grpo", {}).get("async_grpo", {})
if async_cfg.get("in_flight_weight_updates", False) and async_cfg.get(
"recompute_kv_cache_after_weight_updates", False
):
try:
print("🔄 Invalidating vLLM prefix/KV caches after weight update")
invalidated = self.policy_generation.invalidate_kv_cache()
if invalidated:
print("✅ Invalidated vLLM prefix/KV caches after weight update")
else:
print(
"⚠️ vLLM cache invalidation reported partial/unsuccessful on some workers"
)
except Exception as e:
print(f"⚠️ Failed to invalidate vLLM caches: {e}")
with self._tracer.span("resume_after_refit"):
print("🔄 Resuming generation starts after refit")

# Invalidate & recompute vLLM caches after the in-flight weight updates if
# recompute_kv_cache_after_weight_updates is True (AREAL-style implementation).
# Otherwise, keep using the stale KV caches (Magistral-style implementation).
async_cfg = self.master_config.get("grpo", {}).get("async_grpo", {})
if async_cfg.get("in_flight_weight_updates", False) and async_cfg.get(
"recompute_kv_cache_after_weight_updates", False
):
try:
print("🔄 Invalidating vLLM prefix/KV caches after weight update")
invalidated = self.policy_generation.invalidate_kv_cache()
if invalidated:
print("✅ Invalidated vLLM prefix/KV caches after weight update")
else:
print(
"⚠️ vLLM cache invalidation reported partial/unsuccessful on some workers"
)
except Exception as e:
print(f"⚠️ Failed to invalidate vLLM caches: {e}")

self._refit_pause_cleared.set()
self._refit_pause_cleared.set()

def wait_for_pending_generations(self) -> None:
"""Wait for all in-flight generation threads to complete."""
start_time = time.time()
with self._tracer.span("wait_for_pending_generations"):
start_time = time.time()

while True:
with self._threads_lock:
finished = {t for t in self._inflight_threads if not t.is_alive()}
for t in finished:
self._inflight_threads.remove(t)
while True:
with self._threads_lock:
finished = {t for t in self._inflight_threads if not t.is_alive()}
for t in finished:
self._inflight_threads.remove(t)

pending_count = len(self._inflight_threads)
pending_count = len(self._inflight_threads)

if pending_count == 0:
print("✅ All generation threads completed")
break
if pending_count == 0:
print("✅ All generation threads completed")
break

elapsed = time.time() - start_time
print(
f"⏳ Waiting for {pending_count} pending generation threads... ({elapsed:.1f}s elapsed)"
)
time.sleep(0.5)
elapsed = time.time() - start_time
print(
f"⏳ Waiting for {pending_count} pending generation threads... ({elapsed:.1f}s elapsed)"
)
time.sleep(0.5)

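The new `wait_for_pending_generations` keeps the original polling structure, now wrapped in a span: sweep for finished threads under the lock, then sleep 0.5 s and log progress. A join-based variant (a sketch, not what this PR does) trades the explicit sweep for bounded joins:

```python
import time

def wait_for_pending(threads: set, poll_s: float = 0.5) -> None:
    # `threads` is a set of threading.Thread, as in _inflight_threads.
    start = time.time()
    while threads:
        t = next(iter(threads))
        t.join(timeout=poll_s)  # bounded join keeps the loop responsive
        if not t.is_alive():
            threads.discard(t)
        elif threads:
            print(f"⏳ Waiting for {len(threads)} threads... ({time.time() - start:.1f}s)")
```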
def get_dataloader_state(self) -> dict:
"""Get the current dataloader state for checkpointing."""
@@ -640,19 +698,29 @@ def _run_prompt_group_worker(
generation_weight_version: int,
target_weight_version: int,
prompt_idx: int,
worker_tracer: Tracer,
) -> None:
worker_tracer.start_span(
"run_prompt_group_worker",
metadata={
"generation_weight_version": generation_weight_version,
"target_weight_version": target_weight_version,
"prompt_idx": prompt_idx,
},
)
try:
# Run rollout for this prompt group
# Async engine supports concurrent generation; avoid locking
final_batch, rollout_metrics = run_async_multi_turn_rollout(
policy_generation=self.policy_generation,
input_batch=repeated_batch,
tokenizer=self.tokenizer,
task_to_env=self.task_to_env,
max_seq_len=self.master_config["policy"]["max_total_sequence_length"],
max_rollout_turns=self.master_config["grpo"]["max_rollout_turns"],
greedy=False,
)
with worker_tracer.span("run_async_multi_turn_rollout"):
final_batch, rollout_metrics = run_async_multi_turn_rollout(
policy_generation=self.policy_generation,
input_batch=repeated_batch,
tokenizer=self.tokenizer,
task_to_env=self.task_to_env,
max_seq_len=self.master_config["policy"]["max_total_sequence_length"],
max_rollout_turns=self.master_config["grpo"]["max_rollout_turns"],
greedy=False,
)

# Move to CPU and push to buffer (avoid blocking on GC/push)
final_batch_cpu = final_batch.to("cpu")
Expand Down Expand Up @@ -728,3 +796,5 @@ def _run_prompt_group_worker(
import traceback

traceback.print_exc()

worker_tracer.end_span("run_prompt_group_worker")
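One last observation on the worker: `end_span` runs after the `except` handler rather than in a `finally`, so the span closes on success and on handled failure, but it would leak if the handler itself raised or if something bypassed the handler shown partially above. A `finally`-based sketch of the same shape, using the assumed API:

```python
import traceback

def run_worker_sketch(worker_tracer, do_rollout) -> None:
    worker_tracer.start_span("run_prompt_group_worker")
    try:
        do_rollout()  # stands in for the rollout + buffer push above
    except Exception:
        traceback.print_exc()
    finally:
        # Runs on success, on handled failure, and on anything re-raised.
        worker_tracer.end_span("run_prompt_group_worker")
```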