diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 58061e9c3c..d237bd64d8 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -999,7 +999,17 @@ async def _save_lora_adapters_and_sync( with open(os.path.join(lora_sync_path, "adapter_config.json"), "w", encoding="utf-8") as f: json.dump(adapter_config, f, ensure_ascii=False, indent=4) - # Send LoRA disk loading request to inference engine. + # Sync after rank-0 disk write so all ranks see consistent state on + # Weka. Critically, this barrier is BEFORE the rank-0 HTTP fan-out to + # inference engines — fan-out to 4 engines on multi-engine configs can + # take longer than the gloo timeout (~10s), which would hang ranks 1-3 at the barrier. + torch.distributed.barrier() + + # Rank-0 HTTP fan-out happens AFTER the barrier, so ranks 1-3 don't + # block on its duration. The caller's outer barrier + # (in broadcast_to_inference_engines) still gates the next + # training phase on the fan-out completing. + if torch.distributed.get_rank() == 0: from skyrl.backends.skyrl_train.inference_servers.remote_inference_client import ( RemoteInferenceClient, ) @@ -1010,8 +1020,6 @@ async def _save_lora_adapters_and_sync( lora_request = LoraLoadRequest(lora_path=lora_sync_path, lora_name=lora_name) await inference_engine_client.update_named_weights(lora_request) - torch.distributed.barrier() - async def broadcast_to_inference_engines( self, inference_engine_client: "InferenceEngineInterface",