14 changes: 11 additions & 3 deletions skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py
@@ -999,7 +999,17 @@ async def _save_lora_adapters_and_sync(
with open(os.path.join(lora_sync_path, "adapter_config.json"), "w", encoding="utf-8") as f:
    json.dump(adapter_config, f, ensure_ascii=False, indent=4)

# Send LoRA disk loading request to inference engine.
# Sync after rank-0 disk write so all ranks see consistent state on
# Weka. Critically, this barrier is BEFORE the rank-0 HTTP fan-out to
# inference engines — fan-out to 4 engines on multi-engine configs can
# take >gloo-timeout (~10s) which would hang ranks 1-3 at the barrier.
Member

Isn't ~10s too low for the Gloo timeout?

For Megatron, we init process group here:

https://github.com/hershg/SkyRL/blob/976bfe3a45f6f1e3db3b5e84b5a3d5d485d0eb67/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py#L512-L521

The default timeout is SKYRL_WORKER_NCCL_TIMEOUT_IN_S, which is 600s.

Are you sure this fix is needed? Have you overridden SKYRL_WORKER_NCCL_TIMEOUT_IN_S in some way?
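
One way to answer the question above is to print the effective timeout on a worker and compare it with the observed fan-out duration. The following is a minimal sketch, not SkyRL's actual code: it assumes the setting can be overridden through an environment variable of the same name (the thread does not confirm this), and `resolve_pg_timeout` and `fanout_exceeds_timeout` are hypothetical helper names.

```python
import os
from datetime import timedelta

# Assumption: 600 s default, taken from the value quoted in the comment above.
DEFAULT_TIMEOUT_S = 600


def resolve_pg_timeout(env=None) -> timedelta:
    """Return the process-group timeout, honoring an override if one is set."""
    env = os.environ if env is None else env
    raw = env.get("SKYRL_WORKER_NCCL_TIMEOUT_IN_S")
    return timedelta(seconds=int(raw) if raw else DEFAULT_TIMEOUT_S)


def fanout_exceeds_timeout(fanout_seconds: float, timeout: timedelta) -> bool:
    """True if a fan-out of this duration would outlast the collective timeout."""
    return fanout_seconds >= timeout.total_seconds()


if __name__ == "__main__":
    timeout = resolve_pg_timeout()
    # torch.distributed.init_process_group accepts a timedelta as `timeout`.
    print(f"effective timeout: {timeout.total_seconds():.0f}s")
    print("a 10 s fan-out would time out:", fanout_exceeds_timeout(10.0, timeout))
```

If the effective value really is 600 s, a fan-out of roughly 10 s would not trip it, which is the point the question raises.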

torch.distributed.barrier()

# Rank-0 HTTP fan-out happens AFTER the barrier, so ranks 1-3 don't
# block on its duration. The caller's outer barrier
# (broadcast_to_inference_engines, line ~1029) still gates the next
# training phase on the fan-out completing.
if torch.distributed.get_rank() == 0:
    from skyrl.backends.skyrl_train.inference_servers.remote_inference_client import (
        RemoteInferenceClient,
    )
@@ -1010,8 +1020,6 @@ async def _save_lora_adapters_and_sync(
lora_request = LoraLoadRequest(lora_path=lora_sync_path, lora_name=lora_name)
await inference_engine_client.update_named_weights(lora_request)

torch.distributed.barrier()

async def broadcast_to_inference_engines(
    self,
    inference_engine_client: "InferenceEngineInterface",
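
The ordering change in this diff (barrier first, then the rank-0 fan-out) can be illustrated without torch. The sketch below is an analogy only: it uses threads and `threading.Barrier` from the standard library as stand-ins for ranks and `torch.distributed.barrier()`, and a `time.sleep` as a stand-in for the HTTP fan-out. `run_ranks` and its parameters are hypothetical, and real collective timeouts behave differently in detail.

```python
import threading
import time


def run_ranks(world_size, fanout_s, barrier_timeout_s, barrier_before_fanout):
    """Simulate ranks as threads; return the ranks that saw a broken barrier."""
    barrier = threading.Barrier(world_size, timeout=barrier_timeout_s)
    failed = []
    lock = threading.Lock()

    def worker(rank):
        try:
            if barrier_before_fanout:
                barrier.wait()            # all ranks meet right after the write
                if rank == 0:
                    time.sleep(fanout_s)  # stand-in for the rank-0 HTTP fan-out
            else:
                if rank == 0:
                    time.sleep(fanout_s)  # other ranks are already waiting
                barrier.wait()
        except threading.BrokenBarrierError:
            with lock:
                failed.append(rank)

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return sorted(failed)


if __name__ == "__main__":
    print("barrier after fan-out :", run_ranks(4, 0.6, 0.3, False))
    print("barrier before fan-out:", run_ranks(4, 0.6, 0.3, True))
```

With a fan-out longer than the barrier timeout, the barrier-after ordering breaks the barrier for every simulated rank, while the barrier-first ordering breaks it for none. Whether the real timeout is short enough for this to matter is exactly what the review comment asks.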