diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py
index 75cb053c1..b5a975a06 100644
--- a/slime/ray/rollout.py
+++ b/slime/ray/rollout.py
@@ -129,6 +129,7 @@ def get_num_rollout_per_epoch(self):
     def generate(self, rollout_id):
         start_time = time.time()
         self.rollout_id = rollout_id
+        self.recover_rollout_engines()
         self.health_monitoring_resume()
         if self.args.ci_test and self.args.use_fault_tolerance and rollout_id >= 2:
             self._try_ci_fault_injection()
diff --git a/slime/utils/health_monitor.py b/slime/utils/health_monitor.py
index e95367e95..6ea8fd55e 100644
--- a/slime/utils/health_monitor.py
+++ b/slime/utils/health_monitor.py
@@ -83,10 +83,17 @@ def stop(self) -> None:
         self._is_checking_enabled = False
 
     def pause(self) -> None:
-        """Pause health checking. Called when engines are offloaded."""
+        """Pause health checking. Called when engines are offloaded.
+
+        Before pausing, performs a final health check to ensure all engines are healthy.
+        Any unhealthy engines will be killed before pausing.
+        """
         if self._pause_event is None:
             return
-        logger.info("Pausing health monitor...")
+        logger.info("Pausing health monitor (running final health check first)...")
+        # Run a final health check before pausing to catch any unhealthy engines
+        if self._is_checking_enabled:
+            self._run_health_checks()
         self._pause_event.set()
         self._is_checking_enabled = False
 
@@ -136,27 +143,57 @@ def _health_monitor_loop(self) -> None:
                 break
 
     def _run_health_checks(self) -> None:
-        for rollout_engine_id, engine in enumerate(self._rollout_manager.rollout_engines):
-            if self._stop_event is not None and self._stop_event.is_set():
-                break
-            if self._pause_event is not None and self._pause_event.is_set():
-                break
-            self._check_engine_health(rollout_engine_id, engine)
+        """Run health checks for all engines in parallel."""
+        engines = self._rollout_manager.rollout_engines
+        if not engines:
+            return
+
+        # Collect all valid engines with their indices
+        engine_tasks = [
+            (i, engine, engine.health_generate.remote(timeout=self._check_timeout))
+            for i, engine in enumerate(engines)
+            if engine is not None
+        ]
 
-    def _check_engine_health(self, rollout_engine_id, engine) -> None:
-        if engine is None:
-            logger.info(f"Skipping health check for engine {rollout_engine_id} (None)")
+        if not engine_tasks:
             return
 
+        # Wait for all health checks in parallel
+        refs = [task for _, _, task in engine_tasks]
         try:
-            ray.get(engine.health_generate.remote(timeout=self._check_timeout))
+            results = ray.get(refs, timeout=self._check_timeout + 5)
+            # All succeeded
+            for (engine_id, _, _), result in zip(engine_tasks, results, strict=True):
+                if result is not True:
+                    logger.error(f"Health check returned non-True for engine {engine_id}: {result}. Killing actor.")
+                    self._kill_engine(rollout_engine_id=engine_id)
+                else:
+                    logger.debug(f"Health check passed for rollout engine {engine_id}")
+        except ray.exceptions.GetTimeoutError:
+            # Timeout - need to check which ones failed
+            logger.warning("Some health checks timed out, checking individual results...")
+            self._check_individual_results(engine_tasks)
         except Exception as e:
-            logger.error(
-                f"Health check failed for rollout engine {rollout_engine_id} (ray timeout or error). Killing actor. Exception: {e}"
-            )
-            self._kill_engine(rollout_engine_id=rollout_engine_id)
-        else:
-            logger.debug(f"Health check passed for rollout engine {rollout_engine_id}")
+            # Some other error - check each one individually
+            logger.warning(f"Batch health check failed with error: {e}, checking individually...")
+            self._check_individual_results(engine_tasks)
+
+    def _check_individual_results(self, engine_tasks: list) -> None:
+        """Check health check results individually after batch failure."""
+        for engine_id, _engine, ref in engine_tasks:
+            try:
+                result = ray.get(ref, timeout=0)  # Non-blocking check
+                if result is not True:
+                    logger.error(f"Health check returned non-True for engine {engine_id}: {result}. Killing actor.")
+                    self._kill_engine(rollout_engine_id=engine_id)
+                else:
+                    logger.debug(f"Health check passed for rollout engine {engine_id}")
+            except ray.exceptions.GetTimeoutError:
+                logger.error(f"Health check timed out for rollout engine {engine_id}. Killing actor.")
+                self._kill_engine(rollout_engine_id=engine_id)
+            except Exception as e:
+                logger.error(f"Health check failed for rollout engine {engine_id}: {e}. Killing actor.")
+                self._kill_engine(rollout_engine_id=engine_id)
 
     def _kill_engine(self, rollout_engine_id: int):
         logger.info(f"Killing engine group {rollout_engine_id}...")
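
For context, here is a minimal standalone sketch of the pattern the new _run_health_checks / _check_individual_results pair follows: fire every health check at once, wait for the whole batch with a single ray.get, and only fall back to per-engine non-blocking gets when the batch call times out or raises. The FakeEngine actor, its sleep-based hang, and the concrete timeouts are hypothetical stand-ins added only to keep the sketch runnable; they are not part of slime.

import time

import ray
from ray.exceptions import GetTimeoutError


# Hypothetical stand-in for a rollout engine exposing a health_generate-like method.
@ray.remote
class FakeEngine:
    def __init__(self, delay_s: float):
        self._delay_s = delay_s

    def health_generate(self, timeout: float) -> bool:
        time.sleep(self._delay_s)  # simulate generation latency (or a hang)
        return True


def find_unhealthy(engines, check_timeout: float = 5.0) -> list[int]:
    """Return the indices of engines whose health check did not pass in time."""
    # Fire every health check at once so they run concurrently across engines.
    tasks = [(i, eng.health_generate.remote(timeout=check_timeout)) for i, eng in enumerate(engines)]
    try:
        # Fast path: one batched ray.get with a grace period on top of the per-check timeout.
        results = ray.get([ref for _, ref in tasks], timeout=check_timeout + 5)
        # Everything returned in time; anything that is not True is still considered unhealthy.
        return [i for (i, _), ok in zip(tasks, results) if ok is not True]
    except Exception:
        # The batched get timed out or an engine raised; probe each ref individually.
        # timeout=0 makes ray.get non-blocking, raising GetTimeoutError for results not ready yet.
        unhealthy = []
        for i, ref in tasks:
            try:
                ray.get(ref, timeout=0)
            except GetTimeoutError:
                unhealthy.append(i)  # still not ready: treat as hung
            except Exception:
                unhealthy.append(i)  # the actor raised or died
        return unhealthy


if __name__ == "__main__":
    ray.init()
    engines = [FakeEngine.remote(0.1), FakeEngine.remote(60.0)]  # second engine "hangs"
    print(find_unhealthy(engines))  # expected: [1]

The two Ray behaviors this relies on: ray.get(refs, timeout=...) raises ray.exceptions.GetTimeoutError if any result is still pending at the deadline, and timeout=0 turns ray.get into a non-blocking probe, which is what lets the fallback path attribute a failure to specific engines without waiting a second time.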