From 91ac2f664fa690bef84f5dd146b161b540a8a15e Mon Sep 17 00:00:00 2001 From: Zilin Zhu Date: Thu, 4 Jun 2026 02:10:44 +0000 Subject: [PATCH 01/17] Fully support --rollout-external-engine-addrs --- docs/en/advanced/sglang-config.md | 6 + docs/zh/advanced/sglang-config.md | 5 + slime/backends/sglang_utils/arguments.py | 5 + slime/backends/sglang_utils/external.py | 156 +++++++++++++ slime/backends/sglang_utils/sglang_engine.py | 6 + slime/ray/placement_group.py | 3 + slime/ray/rollout.py | 231 +++++++++++++++++-- slime/rollout/fully_async_rollout.py | 8 +- slime/rollout/sglang_rollout.py | 6 +- slime/utils/arguments.py | 4 + slime/utils/http_utils.py | 17 +- tests/utils/test_external_sglang_engines.py | 93 ++++++++ 12 files changed, 511 insertions(+), 29 deletions(-) create mode 100644 slime/backends/sglang_utils/external.py create mode 100644 tests/utils/test_external_sglang_engines.py diff --git a/docs/en/advanced/sglang-config.md b/docs/en/advanced/sglang-config.md index 70af72d2ec..5532e77627 100644 --- a/docs/en/advanced/sglang-config.md +++ b/docs/en/advanced/sglang-config.md @@ -275,6 +275,12 @@ python train.py \ ... ``` +slime queries each external engine's `/server_info` endpoint to infer +`rollout_num_gpus`, per-engine GPU counts, SGLang parallel sizes, and +prefill/decode worker types. If no `--sglang-router-ip/--sglang-router-port` +is provided, slime launches its own router and registers the external engines +to it. + > **Note:** `--sglang-config` and `--rollout-external` are mutually exclusive. Use `--sglang-config` when you want slime to manage the full engine lifecycle; use `--rollout-external` when engines are pre-deployed. --- diff --git a/docs/zh/advanced/sglang-config.md b/docs/zh/advanced/sglang-config.md index e68c52c843..807f2b0448 100644 --- a/docs/zh/advanced/sglang-config.md +++ b/docs/zh/advanced/sglang-config.md @@ -275,6 +275,11 @@ python train.py \ ... ``` +slime 会请求每个外部引擎的 `/server_info`,自动推断 +`rollout_num_gpus`、单个 engine 的 GPU 数、SGLang 并行参数,以及 +prefill/decode worker 类型。如果没有提供 `--sglang-router-ip/--sglang-router-port`, +slime 会自己启动 router,并把这些外部引擎注册进去。 + > **注意:** `--sglang-config` 和 `--rollout-external` 互斥。当你希望 slime 管理完整的引擎生命周期时,使用 `--sglang-config`;当引擎已预部署时,使用 `--rollout-external`。 --- diff --git a/slime/backends/sglang_utils/arguments.py b/slime/backends/sglang_utils/arguments.py index 0a4801743f..75fc9b34ce 100644 --- a/slime/backends/sglang_utils/arguments.py +++ b/slime/backends/sglang_utils/arguments.py @@ -157,6 +157,11 @@ def validate_args(args): if getattr(args, "sglang_router_ip", None): args.sglang_router_ip = _wrap_ipv6(args.sglang_router_ip) + if getattr(args, "rollout_external", False) and args.sglang_router_ip is not None: + assert ( + args.sglang_router_port is not None + ), "--sglang-router-port must be set with --sglang-router-ip in --rollout-external mode." + # Mutual-exclusion checks for PD disaggregation / sglang-config. assert not ( getattr(args, "prefill_num_servers", None) is not None and args.rollout_external diff --git a/slime/backends/sglang_utils/external.py b/slime/backends/sglang_utils/external.py new file mode 100644 index 0000000000..d410cc2a09 --- /dev/null +++ b/slime/backends/sglang_utils/external.py @@ -0,0 +1,156 @@ +"""Helpers for pre-launched external SGLang engines.""" + +from __future__ import annotations + +import dataclasses +from urllib.parse import urlparse + +import requests + + +@dataclasses.dataclass(frozen=True) +class ExternalEngineInfo: + url: str + host: str + port: int + worker_type: str + num_gpus: int + tp_size: int + pp_size: int + dp_size: int + ep_size: int + disaggregation_bootstrap_port: int | None = None + server_info: dict = dataclasses.field(default_factory=dict) + + @property + def is_pd_worker(self) -> bool: + return self.worker_type in ("prefill", "decode") + + def to_dict(self) -> dict: + return dataclasses.asdict(self) + + +def normalize_external_engine_addr(addr: str) -> str: + """Normalize ``host:port`` or ``http://host:port`` to an HTTP base URL.""" + if "://" not in addr: + addr = f"http://{addr}" + addr = addr.rstrip("/") + parsed = urlparse(addr) + if parsed.scheme != "http" or parsed.hostname is None or parsed.port is None: + raise ValueError( + f"Invalid external SGLang engine address {addr!r}. " + "Use host:port or http://host:port (IPv6 must be bracketed)." + ) + return addr + + +def external_engine_info_from_dict(data: dict) -> ExternalEngineInfo: + return ExternalEngineInfo(**data) + + +def _positive_int(value, default: int) -> int: + if value is None: + return default + value = int(value) + return value if value > 0 else default + + +def _get_server_info(url: str, timeout: float = 30.0) -> dict: + errors = [] + for endpoint in ("/server_info", "/get_server_info"): + try: + response = requests.get(f"{url}{endpoint}", timeout=timeout) + response.raise_for_status() + return response.json() + except Exception as exc: + errors.append(f"{endpoint}: {exc}") + raise RuntimeError(f"Failed to fetch SGLang server info from {url}: {'; '.join(errors)}") + + +def _infer_worker_type(server_info: dict) -> str: + if server_info.get("encoder_only"): + return "encoder" + mode = server_info.get("disaggregation_mode") + if mode in ("prefill", "decode"): + return mode + return "regular" + + +def discover_external_engines(addrs: list[str], timeout: float = 30.0) -> list[ExternalEngineInfo]: + infos = [] + for addr in addrs: + url = normalize_external_engine_addr(addr) + parsed = urlparse(url) + assert parsed.hostname is not None and parsed.port is not None + server_info = _get_server_info(url, timeout=timeout) + + pp_size = _positive_int(server_info.get("pp_size") or server_info.get("pipeline_parallel_size"), 1) + tp_size = _positive_int(server_info.get("tp_size") or server_info.get("tensor_parallel_size"), 1) + dp_size = _positive_int(server_info.get("dp_size") or server_info.get("data_parallel_size"), 1) + ep_size = _positive_int(server_info.get("ep_size") or server_info.get("expert_parallel_size"), 1) + num_gpus = _positive_int( + server_info.get("num_gpus") or server_info.get("num_gpus_per_engine"), + tp_size * pp_size, + ) + bootstrap_port = server_info.get("disaggregation_bootstrap_port") + bootstrap_port = int(bootstrap_port) if bootstrap_port is not None else None + + infos.append( + ExternalEngineInfo( + url=url, + host=parsed.hostname, + port=parsed.port, + worker_type=_infer_worker_type(server_info), + num_gpus=num_gpus, + tp_size=tp_size, + pp_size=pp_size, + dp_size=dp_size, + ep_size=ep_size, + disaggregation_bootstrap_port=bootstrap_port, + server_info=server_info, + ) + ) + return infos + + +def apply_external_engine_info_to_args(args, logger=None) -> None: + """Detect external engines and store the derived topology on ``args``.""" + if not getattr(args, "rollout_external", False): + return + + addrs = getattr(args, "rollout_external_engine_addrs", None) + if not addrs: + raise ValueError("--rollout-external requires --rollout-external-engine-addrs.") + + infos = discover_external_engines(addrs) + if not infos: + raise ValueError("--rollout-external-engine-addrs did not contain any engines.") + + args.rollout_external_engine_infos = [info.to_dict() for info in infos] + args.rollout_num_engines = len(infos) + args.rollout_num_gpus = sum(info.num_gpus for info in infos) + + # Keep legacy homogeneous fields meaningful for code paths that still read + # them. Per-group rollout startup uses the exact per-engine values below. + first = infos[0] + args.rollout_num_gpus_per_engine = first.num_gpus + args.sglang_pipeline_parallel_size = first.pp_size + args.sglang_data_parallel_size = first.dp_size + args.sglang_expert_parallel_size = first.ep_size + if any(info.dp_size > 1 for info in infos): + args.sglang_enable_dp_attention = True + + if logger is not None: + summary = [ + { + "url": info.url, + "worker_type": info.worker_type, + "num_gpus": info.num_gpus, + "tp_size": info.tp_size, + "pp_size": info.pp_size, + "dp_size": info.dp_size, + "ep_size": info.ep_size, + } + for info in infos + ] + logger.info(f"Detected external SGLang engines: {summary}") diff --git a/slime/backends/sglang_utils/sglang_engine.py b/slime/backends/sglang_utils/sglang_engine.py index 607f4c0d07..de5c107861 100644 --- a/slime/backends/sglang_utils/sglang_engine.py +++ b/slime/backends/sglang_utils/sglang_engine.py @@ -651,8 +651,14 @@ def _compute_server_args( "model_path", "trust_remote_code", "random_seed", + "host", + "port", "nccl_port", + "nnodes", + "node_rank", "dist_init_addr", + "gpu_id_step", + "base_gpu_id", "skip_server_warmup", "enable_draft_weights_cpu_backup", "enable_metrics", diff --git a/slime/ray/placement_group.py b/slime/ray/placement_group.py index e8a778030e..d6af2e8102 100644 --- a/slime/ray/placement_group.py +++ b/slime/ray/placement_group.py @@ -90,6 +90,9 @@ def create_placement_groups(args): elif args.colocate: num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node rollout_offset = 0 + elif args.rollout_external: + num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node + rollout_offset = num_gpus else: num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node + args.rollout_num_gpus rollout_offset = args.actor_num_nodes * args.actor_num_gpus_per_node diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index 4998c82329..f55cea3bca 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -10,10 +10,18 @@ import numpy as np import ray +import requests +import sglang_router import torch +from packaging.version import parse from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS +from slime.backends.sglang_utils.external import ( + ExternalEngineInfo, + discover_external_engines, + external_engine_info_from_dict, +) from slime.backends.sglang_utils.sglang_config import ModelConfig, ServerGroupConfig, SglangConfig from slime.backends.sglang_utils.sglang_engine import SGLangEngine from slime.rollout.base_types import call_rollout_fn @@ -58,14 +66,19 @@ class ServerGroup: model_path: str | None = None # checkpoint path for update_weights_from_disk router_ip: str | None = None router_port: int | None = None + external_worker_specs: list[ExternalEngineInfo] = dataclasses.field(default_factory=list) @property def nodes_per_engine(self): + if self.args.rollout_external: + return 1 return max(1, self.num_gpus_per_engine // self.args.num_gpus_per_node) @property def engines(self): """Node-0 engines only (for multi-node serving).""" + if self.args.rollout_external: + return [engine for engine in self.all_engines if engine is not None] return self.all_engines[:: self.nodes_per_engine] def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[list, dict[int, int]]: @@ -85,6 +98,9 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis self.num_new_engines = 0 return [], port_cursors + if self.args.rollout_external: + return self._start_external_proxy_engines(port_cursors) + num_gpu_per_engine = min(self.num_gpus_per_engine, self.args.num_gpus_per_node) pg, reordered_bundle_indices, reordered_gpu_ids = self.pg @@ -158,26 +174,62 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis if self.num_new_engines == 0: return [], port_cursors - if self.args.rollout_external: - addr_and_ports = _allocate_rollout_engine_addr_and_ports_external( - args=self.args, rollout_engines=rollout_engines + # Compute base_port from the maximum cursor across all nodes that + # this group's engines may land on (conservative: just use global max). + base_port = max(port_cursors.values()) if port_cursors else 15000 + addr_and_ports, port_cursors = _allocate_rollout_engine_addr_and_ports_normal( + args=self.args, + rollout_engines=rollout_engines, + worker_type=self.worker_type, + num_gpus_per_engine=self.num_gpus_per_engine, + rank_offset=self.rank_offset, + base_port=base_port, + ) + + init_handles = [ + engine.init.remote( + **(addr_and_ports[rank]), + router_ip=self.router_ip, + router_port=self.router_port, ) - else: - # Compute base_port from the maximum cursor across all nodes that - # this group's engines may land on (conservative: just use global max). - base_port = max(port_cursors.values()) if port_cursors else 15000 - addr_and_ports, port_cursors = _allocate_rollout_engine_addr_and_ports_normal( - args=self.args, - rollout_engines=rollout_engines, - worker_type=self.worker_type, + for rank, engine in rollout_engines + ] + return init_handles, port_cursors + + def _start_external_proxy_engines(self, port_cursors: dict[int, int]) -> tuple[list, dict[int, int]]: + """Create CPU-only proxy actors for pre-launched external workers.""" + assert self.external_worker_specs, "external_worker_specs must be populated for rollout_external." + + RolloutRayActor = ray.remote(SGLangEngine) + rollout_engines = [] + for i, spec in enumerate(self.external_worker_specs): + if self.all_engines[i] is not None: + continue + + global_rank = self.rank_offset + i + rollout_engine = RolloutRayActor.options(num_cpus=0.2, num_gpus=0).remote( + self.args, + rank=global_rank, + worker_type=spec.worker_type, + base_gpu_id=0, + sglang_overrides=self.sglang_overrides, num_gpus_per_engine=self.num_gpus_per_engine, - rank_offset=self.rank_offset, - base_port=base_port, ) + rollout_engines.append((global_rank, rollout_engine)) + self.all_engines[i] = rollout_engine + self.num_new_engines = len(rollout_engines) + if self.num_new_engines == 0: + return [], port_cursors + + addr_and_ports = _allocate_rollout_engine_addr_and_ports_external( + self.external_worker_specs, + rollout_engines, + rank_offset=self.rank_offset, + ) init_handles = [ engine.init.remote( - **(addr_and_ports[rank]), + **addr_and_ports[rank], router_ip=self.router_ip, router_port=self.router_port, ) @@ -845,17 +897,24 @@ def _validate_rollout_id_annotated(node, depth=0): _validate_rollout_id_annotated(item, depth + 1) -def _allocate_rollout_engine_addr_and_ports_external(args, rollout_engines): +def _allocate_rollout_engine_addr_and_ports_external( + external_worker_specs: list[ExternalEngineInfo], + rollout_engines, + *, + rank_offset: int = 0, +): addr_and_ports = {} for rank, _ in rollout_engines: - addr = args.rollout_external_engine_addrs[rank] - [host, port] = addr.split(":") + spec = external_worker_specs[rank - rank_offset] + addr = f"{spec.host}:{spec.port}" addr_and_ports[rank] = dict( dist_init_addr=addr, nccl_port=None, - host=host, - port=int(port), + host=spec.host, + port=spec.port, ) + if spec.worker_type == "prefill": + addr_and_ports[rank]["disaggregation_bootstrap_port"] = spec.disaggregation_bootstrap_port return addr_and_ports @@ -1018,6 +1077,137 @@ def _compute_megatron_num_gpus(args) -> int: return num +def _external_engine_infos_from_args(args) -> list[ExternalEngineInfo]: + raw_infos = getattr(args, "rollout_external_engine_infos", None) + if raw_infos is None: + addrs = getattr(args, "rollout_external_engine_addrs", None) + if not addrs: + raise RuntimeError("--rollout-external requires --rollout-external-engine-addrs.") + infos = discover_external_engines(addrs) + args.rollout_external_engine_infos = [info.to_dict() for info in infos] + args.rollout_num_engines = len(infos) + args.rollout_num_gpus = sum(info.num_gpus for info in infos) + if infos: + args.rollout_num_gpus_per_engine = infos[0].num_gpus + return infos + return [external_engine_info_from_dict(info) if isinstance(info, dict) else info for info in raw_infos] + + +def _get_registered_router_worker_urls(router_ip: str, router_port: int) -> set[str]: + router_addr = f"http://{router_ip}:{router_port}" + for endpoint in ("/workers", "/list_workers"): + try: + response = requests.get(f"{router_addr}{endpoint}", timeout=30) + response.raise_for_status() + payload = response.json() + except Exception: + continue + if "workers" in payload: + return {worker["url"] if isinstance(worker, dict) else worker for worker in payload["workers"]} + if "urls" in payload: + return set(payload["urls"]) + return set() + + +def _register_external_workers_to_router(args, engine_infos: list[ExternalEngineInfo]) -> None: + if not engine_infos: + return + + router_addr = f"http://{args.sglang_router_ip}:{args.sglang_router_port}" + registered_urls = _get_registered_router_worker_urls(args.sglang_router_ip, args.sglang_router_port) + has_pd = any(info.is_pd_worker for info in engine_infos) + + if parse(sglang_router.__version__) <= parse("0.2.1"): + assert not has_pd, "PD disaggregation for external engines requires sglang_router > 0.2.1." + for info in engine_infos: + if info.url in registered_urls: + continue + response = requests.post(f"{router_addr}/add_worker?url={info.url}") + response.raise_for_status() + return + + for info in engine_infos: + if info.worker_type == "encoder": + continue + if info.url in registered_urls: + continue + payload = { + "url": info.url, + "worker_type": info.worker_type, + } + if info.worker_type == "prefill": + if info.disaggregation_bootstrap_port is None: + raise RuntimeError( + f"External prefill worker {info.url} did not report disaggregation_bootstrap_port " + "from /server_info; cannot register it to the PD router." + ) + payload["bootstrap_port"] = info.disaggregation_bootstrap_port + response = requests.post(f"{router_addr}/workers", json=payload) + response.raise_for_status() + + +def _start_external_rollout_servers(args, pg) -> dict[str, RolloutServer]: + engine_infos = _external_engine_infos_from_args(args) + has_pd = any(info.is_pd_worker for info in engine_infos) + router_ip, router_port = _start_router(args, has_pd_disaggregation=has_pd) + args.sglang_router_ip = router_ip + args.sglang_router_port = router_port + _register_external_workers_to_router(args, engine_infos) + + specs_by_topology: dict[tuple, list[ExternalEngineInfo]] = {} + for info in engine_infos: + key = (info.worker_type, info.num_gpus, info.tp_size, info.pp_size, info.dp_size, info.ep_size) + specs_by_topology.setdefault(key, []).append(info) + + server_groups = [] + engine_offset = 0 + gpu_offset = 0 + init_handles = [] + for (worker_type, num_gpus, tp_size, pp_size, dp_size, ep_size), group_specs in specs_by_topology.items(): + overrides = { + "tp_size": tp_size, + "pp_size": pp_size, + "dp_size": dp_size, + "ep_size": ep_size, + } + group = ServerGroup( + args=args, + pg=pg, + all_engines=[None] * len(group_specs), + num_gpus_per_engine=num_gpus, + num_new_engines=0, + worker_type=worker_type, + rank_offset=engine_offset, + gpu_offset=gpu_offset, + sglang_overrides=overrides, + needs_offload=False, + model_path=args.hf_checkpoint, + router_ip=router_ip, + router_port=router_port, + external_worker_specs=group_specs, + ) + handles, _ = group.start_engines({}) + init_handles.extend(handles) + server_groups.append(group) + + engine_offset += len(group_specs) + gpu_offset += len(group_specs) * num_gpus + + if init_handles: + ray.get(init_handles) + + args.sglang_model_routers = {"default": (router_ip, router_port)} + return { + "default": RolloutServer( + server_groups=server_groups, + router_ip=router_ip, + router_port=router_port, + model_name="default", + update_weights=True, + ) + } + + def start_rollout_servers(args, pg) -> dict[str, RolloutServer]: """Start rollout servers: one per model, each with its own router. @@ -1031,6 +1221,9 @@ def start_rollout_servers(args, pg) -> dict[str, RolloutServer]: Note: ``init_http_client`` should be called separately before this, as the HTTP client is shared across all servers. """ + if args.rollout_external: + return _start_external_rollout_servers(args, pg) + config = _resolve_sglang_config(args) servers: dict[str, RolloutServer] = {} diff --git a/slime/rollout/fully_async_rollout.py b/slime/rollout/fully_async_rollout.py index a54f4083aa..c301075c5c 100644 --- a/slime/rollout/fully_async_rollout.py +++ b/slime/rollout/fully_async_rollout.py @@ -11,8 +11,8 @@ :func:`generate_and_rm_group` which dispatches to those. Concurrency is sourced from ``args.sglang_server_concurrency`` and scaled by -the number of sglang engines (``rollout_num_gpus // rollout_num_gpus_per_engine``) -to match the per-sample semaphore cap in :mod:`slime.rollout.sglang_rollout`. +the number of sglang engines to match the per-sample semaphore cap in +:mod:`slime.rollout.sglang_rollout`. The worker is intentionally oblivious to slime's higher-level pause / weight-update signalling (e.g. ``GenerateState.aborted``). Each in-flight @@ -34,6 +34,7 @@ from slime.rollout.sglang_rollout import GenerateState, generate_and_rm_group from slime.utils.async_utils import run +from slime.utils.http_utils import get_rollout_num_engines from slime.utils.types import Sample __all__ = [ @@ -54,9 +55,8 @@ def _get_global_worker(args, data_buffer) -> AsyncRolloutWorker: with _worker_lock: if _global_worker is None or not _global_worker.worker_thread.is_alive(): logger.info("starting fully-async rollout worker") - num_engines = max(1, args.rollout_num_gpus // args.rollout_num_gpus_per_engine) _global_worker = AsyncRolloutWorker( - args, data_buffer, concurrency=args.sglang_server_concurrency * num_engines + args, data_buffer, concurrency=args.sglang_server_concurrency * get_rollout_num_engines(args) ) _global_worker.start() return _global_worker diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py index c7f86b98ca..eee3a13680 100644 --- a/slime/rollout/sglang_rollout.py +++ b/slime/rollout/sglang_rollout.py @@ -19,7 +19,7 @@ from slime.utils.async_utils import run from slime.utils.data import Dataset from slime.utils.eval_config import EvalDatasetConfig -from slime.utils.http_utils import get, post +from slime.utils.http_utils import get, get_rollout_num_engines, post from slime.utils.misc import SingletonMeta, load_function from slime.utils.processing_utils import ( build_processor_kwargs, @@ -91,9 +91,7 @@ def __init__(self, args: Namespace) -> None: self.tokenizer = load_tokenizer(args.hf_checkpoint, trust_remote_code=True) self.processor = load_processor(args.hf_checkpoint, trust_remote_code=True) - self.semaphore = asyncio.Semaphore( - args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine - ) + self.semaphore = asyncio.Semaphore(args.sglang_server_concurrency * get_rollout_num_engines(args)) self.sampling_params: dict[str, Any] = dict( temperature=args.rollout_temperature, top_p=args.rollout_top_p, diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index d7f8634550..f9d787fd5d 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -10,6 +10,7 @@ from slime.backends.sglang_utils.arguments import sglang_parse_args from slime.backends.sglang_utils.arguments import validate_args as sglang_validate_args +from slime.backends.sglang_utils.external import apply_external_engine_info_to_args from slime.utils.eval_config import EvalDatasetConfig, build_eval_dataset_configs, ensure_dataset_list from slime.utils.logging_utils import configure_logger @@ -1786,6 +1787,9 @@ def slime_validate_args(args): ) args.debug_train_only = True + if not args.debug_train_only: + apply_external_engine_info_to_args(args, logger=logger) + args.use_critic = args.advantage_estimator == "ppo" # Critic always uses the same GPU count as actor. args.critic_num_gpus_per_node = args.actor_num_gpus_per_node diff --git a/slime/utils/http_utils.py b/slime/utils/http_utils.py index 7ce395c4db..ced387e573 100644 --- a/slime/utils/http_utils.py +++ b/slime/utils/http_utils.py @@ -198,13 +198,26 @@ async def _post(client, url, payload, max_retries=60, headers=None): return output +def get_rollout_num_engines(args) -> int: + """Return the number of rollout HTTP engines behind the router.""" + if (num_engines := getattr(args, "rollout_num_engines", None)) is not None: + return int(num_engines) + + rollout_num_gpus = getattr(args, "rollout_num_gpus", None) or 0 + rollout_num_gpus_per_engine = getattr(args, "rollout_num_gpus_per_engine", None) or 1 + if rollout_num_gpus <= 0: + return 0 + return max(1, rollout_num_gpus // rollout_num_gpus_per_engine) + + def init_http_client(args): """Initialize HTTP client and optionally enable distributed POST via Ray.""" global _http_client, _client_concurrency, _distributed_post_enabled - if not args.rollout_num_gpus: + num_engines = get_rollout_num_engines(args) + if num_engines <= 0: return - _client_concurrency = args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine + _client_concurrency = args.sglang_server_concurrency * num_engines if _http_client is None: _http_client = httpx.AsyncClient( limits=httpx.Limits(max_connections=_client_concurrency), diff --git a/tests/utils/test_external_sglang_engines.py b/tests/utils/test_external_sglang_engines.py new file mode 100644 index 0000000000..50dd468975 --- /dev/null +++ b/tests/utils/test_external_sglang_engines.py @@ -0,0 +1,93 @@ +from argparse import Namespace + +from slime.backends.sglang_utils.external import apply_external_engine_info_to_args, discover_external_engines +from slime.utils.http_utils import get_rollout_num_engines + + +class _Response: + def __init__(self, payload, status_code=200): + self.payload = payload + self.status_code = status_code + + def raise_for_status(self): + if self.status_code >= 400: + raise RuntimeError(f"HTTP {self.status_code}") + + def json(self): + return self.payload + + +def test_discover_external_engines_reads_server_info(monkeypatch): + def fake_get(url, timeout): + assert timeout == 30.0 + assert url == "http://host1:10090/server_info" + return _Response( + { + "tp_size": 4, + "pp_size": 2, + "dp_size": 1, + "ep_size": 4, + "disaggregation_mode": "null", + } + ) + + monkeypatch.setattr("slime.backends.sglang_utils.external.requests.get", fake_get) + + infos = discover_external_engines(["host1:10090"]) + + assert len(infos) == 1 + info = infos[0] + assert info.url == "http://host1:10090" + assert info.host == "host1" + assert info.port == 10090 + assert info.worker_type == "regular" + assert info.num_gpus == 8 + assert info.tp_size == 4 + assert info.pp_size == 2 + assert info.dp_size == 1 + assert info.ep_size == 4 + + +def test_apply_external_engine_info_handles_pd(monkeypatch): + payloads = { + "http://prefill:10090/server_info": { + "tp_size": 2, + "pp_size": 1, + "dp_size": 1, + "ep_size": 1, + "disaggregation_mode": "prefill", + "disaggregation_bootstrap_port": 12090, + }, + "http://decode:10091/server_info": { + "tp_size": 4, + "pp_size": 1, + "dp_size": 2, + "ep_size": 2, + "disaggregation_mode": "decode", + }, + } + + def fake_get(url, timeout): + return _Response(payloads[url]) + + monkeypatch.setattr("slime.backends.sglang_utils.external.requests.get", fake_get) + args = Namespace( + rollout_external=True, + rollout_external_engine_addrs=["prefill:10090", "decode:10091"], + rollout_num_gpus=None, + rollout_num_gpus_per_engine=1, + sglang_pipeline_parallel_size=1, + sglang_data_parallel_size=1, + sglang_expert_parallel_size=1, + sglang_enable_dp_attention=False, + ) + + apply_external_engine_info_to_args(args) + + assert args.rollout_num_gpus == 6 + assert args.rollout_num_engines == 2 + assert get_rollout_num_engines(args) == 2 + assert args.rollout_num_gpus_per_engine == 2 + assert args.sglang_enable_dp_attention is True + assert [info["worker_type"] for info in args.rollout_external_engine_infos] == ["prefill", "decode"] + assert args.rollout_external_engine_infos[0]["disaggregation_bootstrap_port"] == 12090 From 773f7167437856d4071d709382e00f6f77ae7cc0 Mon Sep 17 00:00:00 2001 From: Zilin Zhu Date: Thu, 4 Jun 2026 02:15:53 +0000 Subject: [PATCH 02/17] remove --rollout-external --- docs/en/advanced/sglang-config.md | 9 ++++----- docs/zh/advanced/sglang-config.md | 9 ++++----- slime/backends/sglang_utils/arguments.py | 10 +++++----- slime/backends/sglang_utils/external.py | 7 +++---- slime/ray/rollout.py | 2 +- slime/utils/arguments.py | 11 +++++++++-- tests/utils/test_external_sglang_engines.py | 10 +++++++++- 7 files changed, 35 insertions(+), 23 deletions(-) diff --git a/docs/en/advanced/sglang-config.md b/docs/en/advanced/sglang-config.md index 5532e77627..6e9bb5a504 100644 --- a/docs/en/advanced/sglang-config.md +++ b/docs/en/advanced/sglang-config.md @@ -257,7 +257,7 @@ Overrides take **highest priority**, overriding both the base `--sglang-*` CLI a ### 7. Standalone SGLang Launcher -While `--sglang-config` is designed for slime's training pipeline, it also works as a powerful launcher for pure inference scenarios using the `--rollout-external` pattern or by configuring slime to focus solely on serving. +While `--sglang-config` is designed for slime's training pipeline, it also works as a powerful launcher for pure inference scenarios using external engine addresses or by configuring slime to focus solely on serving. **Using external engines with a pre-launched topology:** @@ -270,7 +270,6 @@ python -m sglang.launch_server --model-path /path/to/model --port 10091 ... # Step 2: Connect slime to external engines python train.py \ - --rollout-external \ --rollout-external-engine-addrs host1:10090 host2:10091 \ ... ``` @@ -281,7 +280,7 @@ prefill/decode worker types. If no `--sglang-router-ip/--sglang-router-port` is provided, slime launches its own router and registers the external engines to it. -> **Note:** `--sglang-config` and `--rollout-external` are mutually exclusive. Use `--sglang-config` when you want slime to manage the full engine lifecycle; use `--rollout-external` when engines are pre-deployed. +> **Note:** `--sglang-config` and `--rollout-external-engine-addrs` are mutually exclusive. Use `--sglang-config` when you want slime to manage the full engine lifecycle; use `--rollout-external-engine-addrs` when engines are pre-deployed. --- @@ -338,7 +337,7 @@ When the config is loaded, slime applies the following resolution cascade: | Flag | Conflict Reason | |------|----------------| | `--prefill-num-servers` | PD disaggregation is configured via `server_groups` in the YAML | -| `--rollout-external` | External engines have their own topology; config manages the lifecycle internally | +| `--rollout-external-engine-addrs` | External engines have their own topology; config manages the lifecycle internally | --- @@ -452,7 +451,7 @@ Use `get_model_url(args, "model_name", "/endpoint")` from `slime.rollout.sglang_ ### Q: Can I use `--sglang-config` without training (inference only)? -While `--sglang-config` is designed for slime's training loop, you can effectively use it for inference-only scenarios by configuring a rollout-only run. For fully standalone SGLang serving, consider using SGLang's native `launch_server` directly or the `--rollout-external` mode for connecting to pre-deployed engines. +While `--sglang-config` is designed for slime's training loop, you can effectively use it for inference-only scenarios by configuring a rollout-only run. For fully standalone SGLang serving, consider using SGLang's native `launch_server` directly or `--rollout-external-engine-addrs` for connecting to pre-deployed engines. ### Q: What is the relationship between `--sglang-config` and `--prefill-num-servers`? diff --git a/docs/zh/advanced/sglang-config.md b/docs/zh/advanced/sglang-config.md index 807f2b0448..ce05d36002 100644 --- a/docs/zh/advanced/sglang-config.md +++ b/docs/zh/advanced/sglang-config.md @@ -257,7 +257,7 @@ sglang: ### 7. 独立 SGLang 启动器 -虽然 `--sglang-config` 是为 slime 的训练流水线设计的,但它也可以作为纯推理场景的强大启动器,通过 `--rollout-external` 模式或配置 slime 仅关注推理服务。 +虽然 `--sglang-config` 是为 slime 的训练流水线设计的,但它也可以作为纯推理场景的强大启动器,通过外部 engine 地址或配置 slime 仅关注推理服务。 **使用预启动的外部引擎:** @@ -270,7 +270,6 @@ python -m sglang.launch_server --model-path /path/to/model --port 10091 ... # 步骤 2:将 slime 连接到外部引擎 python train.py \ - --rollout-external \ --rollout-external-engine-addrs host1:10090 host2:10091 \ ... ``` @@ -280,7 +279,7 @@ slime 会请求每个外部引擎的 `/server_info`,自动推断 prefill/decode worker 类型。如果没有提供 `--sglang-router-ip/--sglang-router-port`, slime 会自己启动 router,并把这些外部引擎注册进去。 -> **注意:** `--sglang-config` 和 `--rollout-external` 互斥。当你希望 slime 管理完整的引擎生命周期时,使用 `--sglang-config`;当引擎已预部署时,使用 `--rollout-external`。 +> **注意:** `--sglang-config` 和 `--rollout-external-engine-addrs` 互斥。当你希望 slime 管理完整的引擎生命周期时,使用 `--sglang-config`;当引擎已预部署时,使用 `--rollout-external-engine-addrs`。 --- @@ -337,7 +336,7 @@ slime 自动为每个 sample 分配一个唯一的 `session_id`(存储在 `sam | 选项 | 冲突原因 | |------|----------| | `--prefill-num-servers` | PD 分离通过 YAML 中的 `server_groups` 配置 | -| `--rollout-external` | 外部引擎有自己的拓扑;config 在内部管理生命周期 | +| `--rollout-external-engine-addrs` | 外部引擎有自己的拓扑;config 在内部管理生命周期 | --- @@ -451,7 +450,7 @@ async def generate_with_models(args, sample, sampling_params): ### Q: 可以不训练,只用 `--sglang-config` 做推理吗? -虽然 `--sglang-config` 是为 slime 的训练循环设计的,但你可以通过配置仅 rollout 的运行来实现纯推理场景。对于完全独立的 SGLang 推理服务,建议直接使用 SGLang 原生的 `launch_server`,或使用 `--rollout-external` 模式连接预部署的引擎。 +虽然 `--sglang-config` 是为 slime 的训练循环设计的,但你可以通过配置仅 rollout 的运行来实现纯推理场景。对于完全独立的 SGLang 推理服务,建议直接使用 SGLang 原生的 `launch_server`,或使用 `--rollout-external-engine-addrs` 连接预部署的引擎。 ### Q: `--sglang-config` 和 `--prefill-num-servers` 是什么关系? diff --git a/slime/backends/sglang_utils/arguments.py b/slime/backends/sglang_utils/arguments.py index 75fc9b34ce..17385bd2ea 100644 --- a/slime/backends/sglang_utils/arguments.py +++ b/slime/backends/sglang_utils/arguments.py @@ -160,16 +160,16 @@ def validate_args(args): if getattr(args, "rollout_external", False) and args.sglang_router_ip is not None: assert ( args.sglang_router_port is not None - ), "--sglang-router-port must be set with --sglang-router-ip in --rollout-external mode." + ), "--sglang-router-port must be set with --sglang-router-ip when using --rollout-external-engine-addrs." # Mutual-exclusion checks for PD disaggregation / sglang-config. assert not ( - getattr(args, "prefill_num_servers", None) is not None and args.rollout_external - ), "prefill_num_servers cannot be set when rollout_external is set." + getattr(args, "prefill_num_servers", None) is not None and getattr(args, "rollout_external", False) + ), "prefill_num_servers cannot be set with --rollout-external-engine-addrs." assert not ( - getattr(args, "sglang_config", None) is not None and args.rollout_external - ), "sglang_config cannot be set when rollout_external is set." + getattr(args, "sglang_config", None) is not None and getattr(args, "rollout_external", False) + ), "sglang_config cannot be set with --rollout-external-engine-addrs." assert not ( getattr(args, "sglang_config", None) is not None and getattr(args, "prefill_num_servers", None) is not None diff --git a/slime/backends/sglang_utils/external.py b/slime/backends/sglang_utils/external.py index d410cc2a09..b61ae6ad4f 100644 --- a/slime/backends/sglang_utils/external.py +++ b/slime/backends/sglang_utils/external.py @@ -115,13 +115,12 @@ def discover_external_engines(addrs: list[str], timeout: float = 30.0) -> list[E def apply_external_engine_info_to_args(args, logger=None) -> None: """Detect external engines and store the derived topology on ``args``.""" - if not getattr(args, "rollout_external", False): - return - addrs = getattr(args, "rollout_external_engine_addrs", None) if not addrs: - raise ValueError("--rollout-external requires --rollout-external-engine-addrs.") + args.rollout_external = False + return + args.rollout_external = True infos = discover_external_engines(addrs) if not infos: raise ValueError("--rollout-external-engine-addrs did not contain any engines.") diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index f55cea3bca..4addd539e7 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -1082,7 +1082,7 @@ def _external_engine_infos_from_args(args) -> list[ExternalEngineInfo]: if raw_infos is None: addrs = getattr(args, "rollout_external_engine_addrs", None) if not addrs: - raise RuntimeError("--rollout-external requires --rollout-external-engine-addrs.") + raise RuntimeError("External rollout requires --rollout-external-engine-addrs.") infos = discover_external_engines(addrs) args.rollout_external_engine_infos = [info.to_dict() for info in infos] args.rollout_num_engines = len(infos) diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index f9d787fd5d..cd107fce1f 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -510,8 +510,8 @@ def add_rollout_arguments(parser): parser.add_argument( "--rollout-external", action="store_true", - default=False, - help="Use external SGLang instances instead of launching them inside the framework.", + default=None, + help=argparse.SUPPRESS, ) parser.add_argument( "--rollout-external-engine-addrs", @@ -1787,6 +1787,13 @@ def slime_validate_args(args): ) args.debug_train_only = True + if getattr(args, "rollout_external", None) is not None: + logger.warning( + "--rollout-external is deprecated and ignored. " + "Set --rollout-external-engine-addrs to use pre-launched external SGLang engines." + ) + args.rollout_external = args.rollout_external_engine_addrs is not None + if not args.debug_train_only: apply_external_engine_info_to_args(args, logger=logger) diff --git a/tests/utils/test_external_sglang_engines.py b/tests/utils/test_external_sglang_engines.py index 50dd468975..a794b819d8 100644 --- a/tests/utils/test_external_sglang_engines.py +++ b/tests/utils/test_external_sglang_engines.py @@ -72,7 +72,6 @@ def fake_get(url, timeout): monkeypatch.setattr("slime.backends.sglang_utils.external.requests.get", fake_get) args = Namespace( - rollout_external=True, rollout_external_engine_addrs=["prefill:10090", "decode:10091"], rollout_num_gpus=None, rollout_num_gpus_per_engine=1, @@ -84,6 +83,7 @@ def fake_get(url, timeout): apply_external_engine_info_to_args(args) + assert args.rollout_external is True assert args.rollout_num_gpus == 6 assert args.rollout_num_engines == 2 assert get_rollout_num_engines(args) == 2 @@ -91,3 +91,11 @@ def fake_get(url, timeout): assert args.sglang_enable_dp_attention is True assert [info["worker_type"] for info in args.rollout_external_engine_infos] == ["prefill", "decode"] assert args.rollout_external_engine_infos[0]["disaggregation_bootstrap_port"] == 12090 + + +def test_apply_external_engine_info_no_addrs_disables_external(): + args = Namespace(rollout_external_engine_addrs=None) + + apply_external_engine_info_to_args(args) + + assert args.rollout_external is False From b99bcbe356f4ad3ed8a4c96539e741b702831ac1 Mon Sep 17 00:00:00 2001 From: Zilin Zhu Date: Thu, 4 Jun 2026 02:36:13 +0000 Subject: [PATCH 03/17] fix --- .github/workflows/pr-test.yml | 2 +- .github/workflows/pr-test.yml.j2 | 1 + slime/backends/sglang_utils/external.py | 10 - slime/ray/rollout.py | 2 - tests/test_qwen3.5_0.8B_external_pd.py | 263 ++++++++++++++++++++ tests/utils/test_external_sglang_engines.py | 4 +- 6 files changed, 267 insertions(+), 15 deletions(-) create mode 100644 tests/test_qwen3.5_0.8B_external_pd.py diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index f329d10c2b..4c8e74517a 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -205,7 +205,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"enable_eval": "0", "num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_glm4.7_30B_A3B_pd_mooncake.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"num_gpus": 8, "test_file": "test_qwen3.6_35B_A3B_pd_mooncake.py", "use_deepep": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo_disaggregate.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo_train_critic_only.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "test_mimo_7B_mtp_only_grad.py"}, {"num_gpus": 8, "test_file": "test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 8, "test_file": "test_qwen2.5_0.5B_debug_rollout_then_train.py"}, {"num_gpus": 8, "test_file": "test_qwen2.5_0.5B_opd_sglang.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_fully_async_short.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_streaming_partial_rollout.py"}, {"num_gpus": 4, "test_file": "test_qwen3.5_0.8B_gsm8k_short.py"}, {"num_gpus": 4, "test_file": "test_qwen3.5_0.8B_gsm8k_async_short.py"}, {"num_gpus": 8, "test_args": "--save-optimizer gpu --load-optimizer gpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer gpu --load-optimizer cpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer cpu --load-optimizer cpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer cpu --load-optimizer gpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--async-save", "test_file": "test_qwen3_4B_ckpt.py"}] + info: [{"enable_eval": "0", "num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_glm4.7_30B_A3B_pd_mooncake.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"num_gpus": 8, "test_file": "test_qwen3.6_35B_A3B_pd_mooncake.py", "use_deepep": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo_disaggregate.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo_train_critic_only.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "test_mimo_7B_mtp_only_grad.py"}, {"num_gpus": 8, "test_file": "test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 8, "test_file": "test_qwen2.5_0.5B_debug_rollout_then_train.py"}, {"num_gpus": 8, "test_file": "test_qwen2.5_0.5B_opd_sglang.py"}, {"num_gpus": 8, "test_file": "test_qwen3.5_0.8B_external_pd.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_fully_async_short.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_streaming_partial_rollout.py"}, {"num_gpus": 4, "test_file": "test_qwen3.5_0.8B_gsm8k_short.py"}, {"num_gpus": 4, "test_file": "test_qwen3.5_0.8B_gsm8k_async_short.py"}, {"num_gpus": 8, "test_args": "--save-optimizer gpu --load-optimizer gpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer gpu --load-optimizer cpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer cpu --load-optimizer cpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer cpu --load-optimizer gpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--async-save", "test_file": "test_qwen3_4B_ckpt.py"}] defaults: run: working-directory: ${{ github.workspace }} diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index 0122182bdd..9c898a5094 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -35,6 +35,7 @@ {'test_file': 'test_qwen3_0.6B_parallel_check.py', 'num_gpus': 8}, {'test_file': 'test_qwen2.5_0.5B_debug_rollout_then_train.py', 'num_gpus': 8}, {'test_file': 'test_qwen2.5_0.5B_opd_sglang.py', 'num_gpus': 8}, + {'test_file': 'test_qwen3.5_0.8B_external_pd.py', 'num_gpus': 8}, {'test_file': 'test_qwen2.5_0.5B_fully_async_short.py', 'num_gpus': 4}, {'test_file': 'test_qwen3_4B_streaming_partial_rollout.py', 'num_gpus': 8}, {'test_file': 'test_qwen3.5_0.8B_gsm8k_short.py', 'num_gpus': 4}, diff --git a/slime/backends/sglang_utils/external.py b/slime/backends/sglang_utils/external.py index b61ae6ad4f..02adaa7e6d 100644 --- a/slime/backends/sglang_utils/external.py +++ b/slime/backends/sglang_utils/external.py @@ -129,16 +129,6 @@ def apply_external_engine_info_to_args(args, logger=None) -> None: args.rollout_num_engines = len(infos) args.rollout_num_gpus = sum(info.num_gpus for info in infos) - # Keep legacy homogeneous fields meaningful for code paths that still read - # them. Per-group rollout startup uses the exact per-engine values below. - first = infos[0] - args.rollout_num_gpus_per_engine = first.num_gpus - args.sglang_pipeline_parallel_size = first.pp_size - args.sglang_data_parallel_size = first.dp_size - args.sglang_expert_parallel_size = first.ep_size - if any(info.dp_size > 1 for info in infos): - args.sglang_enable_dp_attention = True - if logger is not None: summary = [ { diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index 4addd539e7..9480a87720 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -1087,8 +1087,6 @@ def _external_engine_infos_from_args(args) -> list[ExternalEngineInfo]: args.rollout_external_engine_infos = [info.to_dict() for info in infos] args.rollout_num_engines = len(infos) args.rollout_num_gpus = sum(info.num_gpus for info in infos) - if infos: - args.rollout_num_gpus_per_engine = infos[0].num_gpus return infos return [external_engine_info_from_dict(info) if isinstance(info, dict) else info for info in raw_infos] diff --git a/tests/test_qwen3.5_0.8B_external_pd.py b/tests/test_qwen3.5_0.8B_external_pd.py new file mode 100644 index 0000000000..9a8f550ab9 --- /dev/null +++ b/tests/test_qwen3.5_0.8B_external_pd.py @@ -0,0 +1,263 @@ +"""E2E test for --rollout-external-engine-addrs with a mixed external fleet. + +Spawns three SGLang servers out-of-band on a single 8-GPU box: +- 1 prefill (tp=2, 2 GPUs, ``--disaggregation-mode prefill``) +- 1 decode (tp=1, 1 GPU, ``--disaggregation-mode decode``) +- 1 regular (tp=1, 1 GPU, no disaggregation) + +and points slime at all three via ``--rollout-external-engine-addrs ...``. +The remaining 4 GPUs train. slime queries ``/server_info`` on each engine to +infer per-engine TP / GPU counts and registers them to its (PD-enabled) router. +""" + +import os +import subprocess +import time +import urllib.request + +import slime.utils.external_utils.command_utils as U + +MODEL_NAME = "Qwen3.5-0.8B" +MODEL_TYPE = "qwen3.5-0.8B" +NUM_GPUS = 8 +NUM_TRAIN_GPUS = 4 +PREFILL_TP = 2 + +EXTERNAL_HOST = "127.0.0.1" +PREFILL_PORT = 13150 +DECODE_PORT = 13151 +REGULAR_PORT = 13153 +BOOTSTRAP_PORT = 13152 + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/gsm8k") + U.convert_checkpoint( + model_name=MODEL_NAME, + megatron_model_type=MODEL_TYPE, + num_gpus_per_node=NUM_TRAIN_GPUS, + dir_dst="/dev/shm", + ) + + +def _get_gpu_split(): + """Partition the 8 visible GPUs: 4 train + 2 prefill + 1 decode + 1 regular.""" + all_gpus = os.environ.get("CUDA_VISIBLE_DEVICES", ",".join(str(i) for i in range(NUM_GPUS))).split(",") + assert len(all_gpus) >= NUM_GPUS, f"Expected at least {NUM_GPUS} GPUs, got {len(all_gpus)}" + train_gpus = all_gpus[:NUM_TRAIN_GPUS] + cursor = NUM_TRAIN_GPUS + prefill_gpus = all_gpus[cursor : cursor + PREFILL_TP] + cursor += PREFILL_TP + decode_gpu = all_gpus[cursor] + cursor += 1 + regular_gpu = all_gpus[cursor] + return train_gpus, prefill_gpus, decode_gpu, regular_gpu + + +def _launch_sglang_server( + *, + gpus: list[str], + port: int, + tp: int, + log_path: str, + disaggregation_mode: str | None = None, + disaggregation_bootstrap_port: int | None = None, +) -> subprocess.Popen: + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = ",".join(gpus) + + cmd = [ + "python3", + "-m", + "sglang.launch_server", + "--model-path", + f"/root/models/{MODEL_NAME}", + "--host", + "0.0.0.0", + "--port", + str(port), + "--tp", + str(tp), + "--mem-fraction-static", + "0.6", + "--trust-remote-code", + ] + if disaggregation_mode is not None: + cmd += [ + "--disaggregation-mode", + disaggregation_mode, + "--disaggregation-transfer-backend", + "mooncake", + ] + if disaggregation_bootstrap_port is not None: + cmd += ["--disaggregation-bootstrap-port", str(disaggregation_bootstrap_port)] + + log_file = open(log_path, "w") + process = subprocess.Popen(cmd, env=env, stdout=log_file, stderr=subprocess.STDOUT) + label = disaggregation_mode or "regular" + print( + f"Starting external sglang {label} server on GPUs {gpus} " + f"port={port} tp={tp} (pid={process.pid}), log: {log_path}" + ) + + # Wait up to ~10 minutes for /server_info to come up. /health_generate + # is unreliable for prefill/decode-only nodes, so we poll /server_info + # — that's what slime's discover_external_engines uses anyway. + deadline = time.time() + 600 + while time.time() < deadline: + if process.poll() is not None: + raise RuntimeError(f"{label} server exited with code {process.returncode}; check {log_path}") + try: + req = urllib.request.urlopen(f"http://{EXTERNAL_HOST}:{port}/server_info", timeout=2) + if req.status == 200: + print(f"External sglang {label} server is ready on GPUs {gpus}") + return process + except Exception: + pass + time.sleep(5) + + process.kill() + raise RuntimeError(f"{label} server failed to start within timeout; check {log_path}") + + +def execute(): + train_gpus, prefill_gpus, decode_gpu, regular_gpu = _get_gpu_split() + processes: list[subprocess.Popen] = [] + + # Restrict CUDA_VISIBLE_DEVICES to training GPUs before Ray starts so + # ray's bundle allocator doesn't try to claim the external sglang GPUs. + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(train_gpus) + + def launch_external_engines(): + processes.append( + _launch_sglang_server( + gpus=prefill_gpus, + port=PREFILL_PORT, + tp=PREFILL_TP, + disaggregation_mode="prefill", + disaggregation_bootstrap_port=BOOTSTRAP_PORT, + log_path="/tmp/sglang_external_prefill.log", + ) + ) + processes.append( + _launch_sglang_server( + gpus=[decode_gpu], + port=DECODE_PORT, + tp=1, + disaggregation_mode="decode", + log_path="/tmp/sglang_external_decode.log", + ) + ) + processes.append( + _launch_sglang_server( + gpus=[regular_gpu], + port=REGULAR_PORT, + tp=1, + log_path="/tmp/sglang_external_regular.log", + ) + ) + + try: + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /dev/shm/{MODEL_NAME}_torch_dist " + + rollout_args = ( + "--prompt-data /root/datasets/gsm8k/train.parquet " + "--input-key messages " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + "--num-rollout 2 " + "--rollout-batch-size 4 " + "--n-samples-per-prompt 4 " + "--rollout-max-response-len 512 " + "--rollout-temperature 0.8 " + "--global-batch-size 16 " + ) + + perf_args = ( + "--tensor-model-parallel-size 1 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 1 " + "--expert-model-parallel-size 1 " + "--expert-tensor-parallel-size 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 9216 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + "--entropy-coef 0.00 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + # No --rollout-num-gpus / --rollout-num-gpus-per-engine: those are + # inferred from /server_info on each external engine (heterogeneous + # topology — 2-GPU prefill, 1-GPU decode, 1-GPU regular). + external_args = ( + "--rollout-external-engine-addrs " + f"{EXTERNAL_HOST}:{PREFILL_PORT} " + f"{EXTERNAL_HOST}:{DECODE_PORT} " + f"{EXTERNAL_HOST}:{REGULAR_PORT} " + ) + + ci_args = "--ci-test " + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--loss-mask-type qwen3_5 " + "--actor-num-nodes 1 " + f"--actor-num-gpus-per-node {NUM_TRAIN_GPUS} " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{external_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_TRAIN_GPUS, + megatron_model_type=MODEL_TYPE, + before_ray_job_submit=launch_external_engines, + ) + finally: + for p in processes: + if p.poll() is None: + p.kill() + p.wait() + U.exec_command("pkill -9 sglang; true") + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute() diff --git a/tests/utils/test_external_sglang_engines.py b/tests/utils/test_external_sglang_engines.py index a794b819d8..432b941747 100644 --- a/tests/utils/test_external_sglang_engines.py +++ b/tests/utils/test_external_sglang_engines.py @@ -87,9 +87,9 @@ def fake_get(url, timeout): assert args.rollout_num_gpus == 6 assert args.rollout_num_engines == 2 assert get_rollout_num_engines(args) == 2 - assert args.rollout_num_gpus_per_engine == 2 - assert args.sglang_enable_dp_attention is True assert [info["worker_type"] for info in args.rollout_external_engine_infos] == ["prefill", "decode"] + assert [info["num_gpus"] for info in args.rollout_external_engine_infos] == [2, 4] + assert [info["dp_size"] for info in args.rollout_external_engine_infos] == [1, 2] assert args.rollout_external_engine_infos[0]["disaggregation_bootstrap_port"] == 12090 From 9741b6269218d97e90c50688d709de344c4fd02d Mon Sep 17 00:00:00 2001 From: Zilin Zhu Date: Thu, 4 Jun 2026 02:42:15 +0000 Subject: [PATCH 04/17] fix --- slime/ray/rollout.py | 43 +++++++++++++------------------------------ 1 file changed, 13 insertions(+), 30 deletions(-) diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index 9480a87720..3b9c6c5d27 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -215,25 +215,17 @@ def _start_external_proxy_engines(self, port_cursors: dict[int, int]) -> tuple[l sglang_overrides=self.sglang_overrides, num_gpus_per_engine=self.num_gpus_per_engine, ) - rollout_engines.append((global_rank, rollout_engine)) + rollout_engines.append((global_rank, rollout_engine, spec)) self.all_engines[i] = rollout_engine self.num_new_engines = len(rollout_engines) - if self.num_new_engines == 0: - return [], port_cursors - - addr_and_ports = _allocate_rollout_engine_addr_and_ports_external( - self.external_worker_specs, - rollout_engines, - rank_offset=self.rank_offset, - ) init_handles = [ engine.init.remote( - **addr_and_ports[rank], + **_external_engine_init_kwargs(spec), router_ip=self.router_ip, router_port=self.router_port, ) - for rank, engine in rollout_engines + for _rank, engine, spec in rollout_engines ] return init_handles, port_cursors @@ -897,25 +889,16 @@ def _validate_rollout_id_annotated(node, depth=0): _validate_rollout_id_annotated(item, depth + 1) -def _allocate_rollout_engine_addr_and_ports_external( - external_worker_specs: list[ExternalEngineInfo], - rollout_engines, - *, - rank_offset: int = 0, -): - addr_and_ports = {} - for rank, _ in rollout_engines: - spec = external_worker_specs[rank - rank_offset] - addr = f"{spec.host}:{spec.port}" - addr_and_ports[rank] = dict( - dist_init_addr=addr, - nccl_port=None, - host=spec.host, - port=spec.port, - ) - if spec.worker_type == "prefill": - addr_and_ports[rank]["disaggregation_bootstrap_port"] = spec.disaggregation_bootstrap_port - return addr_and_ports +def _external_engine_init_kwargs(spec: ExternalEngineInfo) -> dict: + init_kwargs = { + "dist_init_addr": f"{spec.host}:{spec.port}", + "nccl_port": None, + "host": spec.host, + "port": spec.port, + } + if spec.worker_type == "prefill": + init_kwargs["disaggregation_bootstrap_port"] = spec.disaggregation_bootstrap_port + return init_kwargs def _allocate_rollout_engine_addr_and_ports_normal( From f2e03d9b111d6e5775a7649fe7f880399bb6240d Mon Sep 17 00:00:00 2001 From: Zilin Zhu Date: Thu, 4 Jun 2026 02:49:54 +0000 Subject: [PATCH 05/17] fix --- slime/backends/sglang_utils/external.py | 15 +-------------- slime/backends/sglang_utils/sglang_engine.py | 4 ++++ slime/ray/rollout.py | 12 +++--------- tests/utils/test_external_sglang_engines.py | 10 +++++----- 4 files changed, 13 insertions(+), 28 deletions(-) diff --git a/slime/backends/sglang_utils/external.py b/slime/backends/sglang_utils/external.py index 02adaa7e6d..a8a8b52729 100644 --- a/slime/backends/sglang_utils/external.py +++ b/slime/backends/sglang_utils/external.py @@ -15,10 +15,6 @@ class ExternalEngineInfo: port: int worker_type: str num_gpus: int - tp_size: int - pp_size: int - dp_size: int - ep_size: int disaggregation_bootstrap_port: int | None = None server_info: dict = dataclasses.field(default_factory=dict) @@ -86,8 +82,6 @@ def discover_external_engines(addrs: list[str], timeout: float = 30.0) -> list[E pp_size = _positive_int(server_info.get("pp_size") or server_info.get("pipeline_parallel_size"), 1) tp_size = _positive_int(server_info.get("tp_size") or server_info.get("tensor_parallel_size"), 1) - dp_size = _positive_int(server_info.get("dp_size") or server_info.get("data_parallel_size"), 1) - ep_size = _positive_int(server_info.get("ep_size") or server_info.get("expert_parallel_size"), 1) num_gpus = _positive_int( server_info.get("num_gpus") or server_info.get("num_gpus_per_engine"), tp_size * pp_size, @@ -102,10 +96,6 @@ def discover_external_engines(addrs: list[str], timeout: float = 30.0) -> list[E port=parsed.port, worker_type=_infer_worker_type(server_info), num_gpus=num_gpus, - tp_size=tp_size, - pp_size=pp_size, - dp_size=dp_size, - ep_size=ep_size, disaggregation_bootstrap_port=bootstrap_port, server_info=server_info, ) @@ -135,10 +125,7 @@ def apply_external_engine_info_to_args(args, logger=None) -> None: "url": info.url, "worker_type": info.worker_type, "num_gpus": info.num_gpus, - "tp_size": info.tp_size, - "pp_size": info.pp_size, - "dp_size": info.dp_size, - "ep_size": info.ep_size, + "disaggregation_bootstrap_port": info.disaggregation_bootstrap_port, } for info in infos ] diff --git a/slime/backends/sglang_utils/sglang_engine.py b/slime/backends/sglang_utils/sglang_engine.py index de5c107861..a2328664dd 100644 --- a/slime/backends/sglang_utils/sglang_engine.py +++ b/slime/backends/sglang_utils/sglang_engine.py @@ -659,6 +659,10 @@ def _compute_server_args( "dist_init_addr", "gpu_id_step", "base_gpu_id", + "tp_size", + "dp_size", + "pp_size", + "ep_size", "skip_server_warmup", "enable_draft_weights_cpu_backup", "enable_metrics", diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index 3b9c6c5d27..7f3b9d739a 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -1137,20 +1137,14 @@ def _start_external_rollout_servers(args, pg) -> dict[str, RolloutServer]: specs_by_topology: dict[tuple, list[ExternalEngineInfo]] = {} for info in engine_infos: - key = (info.worker_type, info.num_gpus, info.tp_size, info.pp_size, info.dp_size, info.ep_size) + key = (info.worker_type, info.num_gpus) specs_by_topology.setdefault(key, []).append(info) server_groups = [] engine_offset = 0 gpu_offset = 0 init_handles = [] - for (worker_type, num_gpus, tp_size, pp_size, dp_size, ep_size), group_specs in specs_by_topology.items(): - overrides = { - "tp_size": tp_size, - "pp_size": pp_size, - "dp_size": dp_size, - "ep_size": ep_size, - } + for (worker_type, num_gpus), group_specs in specs_by_topology.items(): group = ServerGroup( args=args, pg=pg, @@ -1160,7 +1154,7 @@ def _start_external_rollout_servers(args, pg) -> dict[str, RolloutServer]: worker_type=worker_type, rank_offset=engine_offset, gpu_offset=gpu_offset, - sglang_overrides=overrides, + sglang_overrides={}, needs_offload=False, model_path=args.hf_checkpoint, router_ip=router_ip, diff --git a/tests/utils/test_external_sglang_engines.py b/tests/utils/test_external_sglang_engines.py index 432b941747..62db7b385f 100644 --- a/tests/utils/test_external_sglang_engines.py +++ b/tests/utils/test_external_sglang_engines.py @@ -42,10 +42,10 @@ def fake_get(url, timeout): assert info.port == 10090 assert info.worker_type == "regular" assert info.num_gpus == 8 - assert info.tp_size == 4 - assert info.pp_size == 2 - assert info.dp_size == 1 - assert info.ep_size == 4 + assert info.server_info["tp_size"] == 4 + assert info.server_info["pp_size"] == 2 + assert info.server_info["dp_size"] == 1 + assert info.server_info["ep_size"] == 4 def test_apply_external_engine_info_handles_pd(monkeypatch): @@ -89,7 +89,7 @@ def fake_get(url, timeout): assert get_rollout_num_engines(args) == 2 assert [info["worker_type"] for info in args.rollout_external_engine_infos] == ["prefill", "decode"] assert [info["num_gpus"] for info in args.rollout_external_engine_infos] == [2, 4] - assert [info["dp_size"] for info in args.rollout_external_engine_infos] == [1, 2] + assert [info["server_info"]["dp_size"] for info in args.rollout_external_engine_infos] == [1, 2] assert args.rollout_external_engine_infos[0]["disaggregation_bootstrap_port"] == 12090 From 958d174811bdea1cf968ab322674bb428e3cf506 Mon Sep 17 00:00:00 2001 From: Zilin Zhu Date: Thu, 4 Jun 2026 03:09:24 +0000 Subject: [PATCH 06/17] fix test --- tests/test_qwen3.5_0.8B_external_pd.py | 52 +++++++++++++------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/tests/test_qwen3.5_0.8B_external_pd.py b/tests/test_qwen3.5_0.8B_external_pd.py index 9a8f550ab9..afbf4611f7 100644 --- a/tests/test_qwen3.5_0.8B_external_pd.py +++ b/tests/test_qwen3.5_0.8B_external_pd.py @@ -1,11 +1,11 @@ """E2E test for --rollout-external-engine-addrs with a mixed external fleet. -Spawns three SGLang servers out-of-band on a single 8-GPU box: -- 1 prefill (tp=2, 2 GPUs, ``--disaggregation-mode prefill``) -- 1 decode (tp=1, 1 GPU, ``--disaggregation-mode decode``) -- 1 regular (tp=1, 1 GPU, no disaggregation) +Spawns four SGLang servers out-of-band on a single 8-GPU box (all tp=1): +- 1 prefill (``--disaggregation-mode prefill``) +- 1 decode (``--disaggregation-mode decode``) +- 2 regular (no disaggregation) -and points slime at all three via ``--rollout-external-engine-addrs ...``. +and points slime at all four via ``--rollout-external-engine-addrs ...``. The remaining 4 GPUs train. slime queries ``/server_info`` on each engine to infer per-engine TP / GPU counts and registers them to its (PD-enabled) router. """ @@ -21,12 +21,12 @@ MODEL_TYPE = "qwen3.5-0.8B" NUM_GPUS = 8 NUM_TRAIN_GPUS = 4 -PREFILL_TP = 2 +NUM_REGULAR_ENGINES = 2 EXTERNAL_HOST = "127.0.0.1" PREFILL_PORT = 13150 DECODE_PORT = 13151 -REGULAR_PORT = 13153 +REGULAR_PORTS = [13153, 13154] BOOTSTRAP_PORT = 13152 @@ -43,17 +43,17 @@ def prepare(): def _get_gpu_split(): - """Partition the 8 visible GPUs: 4 train + 2 prefill + 1 decode + 1 regular.""" + """Partition the 8 visible GPUs: 4 train + 1 prefill + 1 decode + 2 regular.""" all_gpus = os.environ.get("CUDA_VISIBLE_DEVICES", ",".join(str(i) for i in range(NUM_GPUS))).split(",") assert len(all_gpus) >= NUM_GPUS, f"Expected at least {NUM_GPUS} GPUs, got {len(all_gpus)}" train_gpus = all_gpus[:NUM_TRAIN_GPUS] cursor = NUM_TRAIN_GPUS - prefill_gpus = all_gpus[cursor : cursor + PREFILL_TP] - cursor += PREFILL_TP + prefill_gpu = all_gpus[cursor] + cursor += 1 decode_gpu = all_gpus[cursor] cursor += 1 - regular_gpu = all_gpus[cursor] - return train_gpus, prefill_gpus, decode_gpu, regular_gpu + regular_gpus = all_gpus[cursor : cursor + NUM_REGULAR_ENGINES] + return train_gpus, prefill_gpu, decode_gpu, regular_gpus def _launch_sglang_server( @@ -123,7 +123,7 @@ def _launch_sglang_server( def execute(): - train_gpus, prefill_gpus, decode_gpu, regular_gpu = _get_gpu_split() + train_gpus, prefill_gpu, decode_gpu, regular_gpus = _get_gpu_split() processes: list[subprocess.Popen] = [] # Restrict CUDA_VISIBLE_DEVICES to training GPUs before Ray starts so @@ -133,9 +133,9 @@ def execute(): def launch_external_engines(): processes.append( _launch_sglang_server( - gpus=prefill_gpus, + gpus=[prefill_gpu], port=PREFILL_PORT, - tp=PREFILL_TP, + tp=1, disaggregation_mode="prefill", disaggregation_bootstrap_port=BOOTSTRAP_PORT, log_path="/tmp/sglang_external_prefill.log", @@ -150,14 +150,15 @@ def launch_external_engines(): log_path="/tmp/sglang_external_decode.log", ) ) - processes.append( - _launch_sglang_server( - gpus=[regular_gpu], - port=REGULAR_PORT, - tp=1, - log_path="/tmp/sglang_external_regular.log", + for idx, (gpu, port) in enumerate(zip(regular_gpus, REGULAR_PORTS, strict=True)): + processes.append( + _launch_sglang_server( + gpus=[gpu], + port=port, + tp=1, + log_path=f"/tmp/sglang_external_regular_{idx}.log", + ) ) - ) try: ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /dev/shm/{MODEL_NAME}_torch_dist " @@ -208,13 +209,12 @@ def launch_external_engines(): ) # No --rollout-num-gpus / --rollout-num-gpus-per-engine: those are - # inferred from /server_info on each external engine (heterogeneous - # topology — 2-GPU prefill, 1-GPU decode, 1-GPU regular). + # inferred from /server_info on each external engine (1 prefill + + # 1 decode + 2 regular, all tp=1). external_args = ( "--rollout-external-engine-addrs " f"{EXTERNAL_HOST}:{PREFILL_PORT} " - f"{EXTERNAL_HOST}:{DECODE_PORT} " - f"{EXTERNAL_HOST}:{REGULAR_PORT} " + f"{EXTERNAL_HOST}:{DECODE_PORT} " + " ".join(f"{EXTERNAL_HOST}:{port}" for port in REGULAR_PORTS) + " " ) ci_args = "--ci-test " From bc949b75637e4ee8fd40f1cfb9b494bca7762d1c Mon Sep 17 00:00:00 2001 From: Zilin Zhu Date: Thu, 4 Jun 2026 03:19:25 +0000 Subject: [PATCH 07/17] use delta weight update --- tests/test_qwen3.5_0.8B_external_pd.py | 32 +++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/tests/test_qwen3.5_0.8B_external_pd.py b/tests/test_qwen3.5_0.8B_external_pd.py index afbf4611f7..6cf2bfbb15 100644 --- a/tests/test_qwen3.5_0.8B_external_pd.py +++ b/tests/test_qwen3.5_0.8B_external_pd.py @@ -8,12 +8,20 @@ and points slime at all four via ``--rollout-external-engine-addrs ...``. The remaining 4 GPUs train. slime queries ``/server_info`` on each engine to infer per-engine TP / GPU counts and registers them to its (PD-enabled) router. + +Weight sync uses ``--update-weight-mode delta --update-weight-transport disk`` +so the post-train sync writes sparse safetensors to a shared dir and the +external engines load them via ``update_weights_from_disk(load_format=delta)`` +— that's the only sync path that actually works for pre-launched workers (no +NCCL group between trainer and external engines). """ import os import subprocess +import tempfile import time import urllib.request +from pathlib import Path import slime.utils.external_utils.command_utils as U @@ -160,6 +168,8 @@ def launch_external_engines(): ) ) + delta_dir_cm = tempfile.TemporaryDirectory(prefix="slime_external_pd_delta_") + delta_dir = delta_dir_cm.name try: ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /dev/shm/{MODEL_NAME}_torch_dist " @@ -194,7 +204,10 @@ def launch_external_engines(): "--use-kl-loss " "--kl-loss-coef 0.00 " "--kl-loss-type low_var_kl " - "--entropy-coef 0.00 " + # Nonzero entropy coef guarantees a nonzero gradient even when all + # rewards in a group tie (advantages=0), so the delta sync writes + # real sparse files instead of an empty no-op. + "--entropy-coef 0.01 " "--eps-clip 0.2 " "--eps-clip-high 0.28 " ) @@ -217,6 +230,18 @@ def launch_external_engines(): f"{EXTERNAL_HOST}:{DECODE_PORT} " + " ".join(f"{EXTERNAL_HOST}:{port}" for port in REGULAR_PORTS) + " " ) + # External engines have no NCCL group with the trainer, so weight + # updates have to go through the disk-backed delta path: the trainer + # writes sparse safetensors per sync, the engines pull via + # update_weights_from_disk(load_format="delta", files=...). + delta_args = ( + "--update-weight-mode delta " + "--update-weight-transport disk " + "--update-weight-encoding deltas " + f"--update-weight-delta-dir {delta_dir} " + "--update-weight-delta-keep-files " + ) + ci_args = "--ci-test " misc_args = ( @@ -238,6 +263,7 @@ def launch_external_engines(): f"{U.get_default_wandb_args(__file__)} " f"{perf_args} " f"{external_args} " + f"{delta_args} " f"{ci_args} " f"{misc_args} " ) @@ -248,12 +274,16 @@ def launch_external_engines(): megatron_model_type=MODEL_TYPE, before_ray_job_submit=launch_external_engines, ) + + delta_files = list(Path(delta_dir).glob("weight_v*/*.safetensors")) + assert delta_files, f"No disk delta safetensors were written under {delta_dir}" finally: for p in processes: if p.poll() is None: p.kill() p.wait() U.exec_command("pkill -9 sglang; true") + delta_dir_cm.cleanup() if __name__ == "__main__": From 609331f8feee92e3c87ba533c1ee452192d404fa Mon Sep 17 00:00:00 2001 From: Zilin Zhu Date: Thu, 4 Jun 2026 04:06:55 +0000 Subject: [PATCH 08/17] fix --- slime/backends/sglang_utils/sglang_engine.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/slime/backends/sglang_utils/sglang_engine.py b/slime/backends/sglang_utils/sglang_engine.py index a2328664dd..2e63869f29 100644 --- a/slime/backends/sglang_utils/sglang_engine.py +++ b/slime/backends/sglang_utils/sglang_engine.py @@ -190,6 +190,17 @@ def _sanity_check_server_args(actual_server_args, expect_server_args): actual_server_args = _get_actual_server_args() _sanity_check_server_args(actual_server_args, expect_server_args) + # Pin the external engine's weight version to the trainer's initial value + # (0) so the --ci-test version-equality check in actor.update_weights does + # not trip on the seed call. SGLang's default `server_args.weight_version` + # is "default"; without this we would need the user to remember + # `--weight-version 0` when pre-launching every external server. + if self.worker_type != "encoder" and self.node_rank == 0: + requests.post( + f"http://{self.server_host}:{self.server_port}/update_weight_version", + json={"new_version": "0"}, + ).raise_for_status() + def _init_normal(self, server_args_dict): logger.info(f"Launch HttpServerEngineAdapter at: {self.server_host}:{self.server_port}") self.process = launch_server_process(ServerArgs(**server_args_dict)) From f2c0f3c016febfd93cf3ea10623d08ff6ee7b676 Mon Sep 17 00:00:00 2001 From: Zilin Zhu Date: Thu, 4 Jun 2026 05:22:03 +0000 Subject: [PATCH 09/17] fix ci --- .claude/skills/add-tests-and-ci/SKILL.md | 28 ++++- .github/workflows/pr-test.yml | 2 +- .github/workflows/pr-test.yml.j2 | 2 + slime/backends/megatron_utils/actor.py | 2 +- .../update_weight_from_distributed_delta.py | 5 - slime/backends/sglang_utils/sglang_engine.py | 11 -- slime/ray/actor_group.py | 26 +++-- slime/ray/placement_group.py | 40 ++++--- slime/ray/train_actor.py | 7 +- .../test_external_sglang_engines.py | 15 +++ tests/test_placement_group.py | 51 +++++++++ tests/test_qwen3.5_0.8B_external_pd.py | 101 ++++++++---------- 12 files changed, 182 insertions(+), 108 deletions(-) rename tests/{utils => }/test_external_sglang_engines.py (92%) create mode 100644 tests/test_placement_group.py diff --git a/.claude/skills/add-tests-and-ci/SKILL.md b/.claude/skills/add-tests-and-ci/SKILL.md index d4a0a7397c..000ea983b9 100644 --- a/.claude/skills/add-tests-and-ci/SKILL.md +++ b/.claude/skills/add-tests-and-ci/SKILL.md @@ -40,15 +40,33 @@ if __name__ == "__main__": - `run-ci-changed` extracts a top-level `NUM_GPUS = ` constant from added/modified `tests/test_*.py` and `tests/plugin_contracts/test_*.py`; if missing, it defaults to 8 GPUs. Set `NUM_GPUS = 0` for CPU-only tests. - For GPU/e2e tests, follow the nearby file pattern (`prepare()`, `execute()`, `NUM_GPUS`, and any model/dataset constants). -### Step 3: Run Local Validation +### Step 3: Register Tests in GitHub CI + +Whenever adding, moving, or renaming a test file, update the GitHub workflow template before finishing: + +1. Add the test to the appropriate matrix in `.github/workflows/pr-test.yml.j2`. + - CPU-only pytest/unit tests usually belong in `cpu-unittest` with `num_gpus: 0`. + - GPU/e2e tests should be placed beside the nearest similar model/path test with the matching `num_gpus` and environment fields. +2. Regenerate workflows: + +```bash +python .github/workflows/generate_github_workflows.py +``` + +3. Include both `.github/workflows/pr-test.yml.j2` and the generated `.github/workflows/pr-test.yml` in the change set. + +Only skip fixed matrix registration when the test is intentionally helper-only or manually invoked; state that reason in the final response. + +### Step 4: Run Local Validation - Run the exact existing test files you changed, if any. +- For new registered tests, run the same shape CI will use, for example `python tests/test_new_file.py`. - Run repository-wide checks only when they are already part of the task or workflow. - Avoid documenting placeholder test commands that may not exist in the current tree. -### Step 4: Update Workflow Template Correctly +### Step 5: Keep Workflow Template as Source of Truth -For CI workflow changes: +For CI workflow changes unrelated to a new, moved, or renamed test: 1. Edit `.github/workflows/pr-test.yml.j2` 2. Regenerate workflows: @@ -59,11 +77,12 @@ python .github/workflows/generate_github_workflows.py 3. Include both the template and generated workflow file in the change set (`.j2` and `.yml`). If the user asked for a commit, commit both. -### Step 5: Provide Verifiable PR Notes +### Step 6: Provide Verifiable PR Notes Include: - Which tests were added/changed +- Where each new/renamed test was registered in `.github/workflows/pr-test.yml.j2` - Exact commands executed - GPU assumptions for each test path - Why this coverage protects against regression @@ -71,6 +90,7 @@ Include: ## Common Mistakes - Editing generated workflow file only +- Relying on `run-ci-changed` discovery for a new test that should run in the regular PR matrix - Forgetting `NUM_GPUS = 0` on a CPU-only changed test, causing `run-ci-changed` to default to 8 GPUs - Adding a CPU pytest file that passes under `pytest tests/foo.py` but fails under CI's `python tests/foo.py` - Adding tests without following existing constants/conventions diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 4c8e74517a..4d47f10c10 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -454,7 +454,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_agent_trajectory.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}] + info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_agent_trajectory.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}] defaults: run: working-directory: ${{ github.workspace }} diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index 9c898a5094..611c2cd53b 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -84,6 +84,8 @@ {'test_file': 'test_sample.py', 'num_gpus': 0}, {'test_file': 'test_agent_trajectory.py', 'num_gpus': 0}, {'test_file': 'test_rollout_validation.py', 'num_gpus': 0}, + {'test_file': 'test_placement_group.py', 'num_gpus': 0}, + {'test_file': 'test_external_sglang_engines.py', 'num_gpus': 0}, {'test_file': 'utils/test_hf_checkpoint_saver.py', 'num_gpus': 0}, {'test_file': 'plugin_contracts/test_plugin_rollout_contracts.py', 'num_gpus': 0}, {'test_file': 'plugin_contracts/test_plugin_runtime_hook_contracts.py', 'num_gpus': 0}, diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py index 49eee715e0..74680e2ada 100644 --- a/slime/backends/megatron_utils/actor.py +++ b/slime/backends/megatron_utils/actor.py @@ -628,7 +628,7 @@ def update_weights(self) -> None: self.weight_updater.update_weights() print_memory("after update_weights") - if self.args.ci_test and len(rollout_engines) > 0: + if self.args.ci_test and len(rollout_engines) > 0 and self.weight_updater.weight_version > 0: engine = random.choice(rollout_engines) engine_version = ray.get(engine.get_weight_version.remote()) if str(engine_version) != str(self.weight_updater.weight_version): diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py index f5ae3cf334..aed9d086eb 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py @@ -577,11 +577,6 @@ def update_weights(self) -> None: if not self._snapshot_seeded: self._seed_snapshot() self._snapshot_seeded = True - # Pin the engine's recorded version to ours (0) on the seed call so the - # CI version-equality check holds before any real sync has happened. - if dist.get_rank() == 0 and self.transport == "disk" and self.rollout_engines: - weight_version = str(self.weight_version) - ray.get([engine.set_weight_version.remote(weight_version) for engine in self.rollout_engines]) return self.weight_version += 1 diff --git a/slime/backends/sglang_utils/sglang_engine.py b/slime/backends/sglang_utils/sglang_engine.py index 2e63869f29..a2328664dd 100644 --- a/slime/backends/sglang_utils/sglang_engine.py +++ b/slime/backends/sglang_utils/sglang_engine.py @@ -190,17 +190,6 @@ def _sanity_check_server_args(actual_server_args, expect_server_args): actual_server_args = _get_actual_server_args() _sanity_check_server_args(actual_server_args, expect_server_args) - # Pin the external engine's weight version to the trainer's initial value - # (0) so the --ci-test version-equality check in actor.update_weights does - # not trip on the seed call. SGLang's default `server_args.weight_version` - # is "default"; without this we would need the user to remember - # `--weight-version 0` when pre-launching every external server. - if self.worker_type != "encoder" and self.node_rank == 0: - requests.post( - f"http://{self.server_host}:{self.server_port}/update_weight_version", - json={"new_version": "0"}, - ).raise_for_status() - def _init_normal(self, server_args_dict): logger.info(f"Launch HttpServerEngineAdapter at: {self.server_host}:{self.server_port}") self.process = launch_server_process(ServerArgs(**server_args_dict)) diff --git a/slime/ray/actor_group.py b/slime/ray/actor_group.py index 27ad610ad9..15519ccebe 100644 --- a/slime/ray/actor_group.py +++ b/slime/ray/actor_group.py @@ -44,11 +44,14 @@ def __init__( self._allocate_gpus_for_actor(pg, num_gpus_per_actor) def _allocate_gpus_for_actor(self, pg, num_gpus_per_actor): - world_size = self._num_nodes * self._num_gpus_per_node + world_size = 1 if self.args.debug_rollout_only else self._num_nodes * self._num_gpus_per_node # Use placement group to lock resources for models of same type - assert pg is not None - pg, reordered_bundle_indices, _reordered_gpu_ids = pg + if self.args.debug_rollout_only: + pg, reordered_bundle_indices = None, [] + else: + assert pg is not None + pg, reordered_bundle_indices, _reordered_gpu_ids = pg env_vars = { # because sglang will always set NCCL_CUMEM_ENABLE to 0 @@ -89,20 +92,23 @@ def _allocate_gpus_for_actor(self, pg, num_gpus_per_actor): actor_impl = MegatronTrainRayActor - TrainRayActor = ray.remote(num_gpus=1, runtime_env={"env_vars": env_vars})(actor_impl) + default_num_gpus = 0 if self.args.debug_rollout_only else 1 + TrainRayActor = ray.remote(num_gpus=default_num_gpus, runtime_env={"env_vars": env_vars})(actor_impl) # Create worker actors self._actor_handlers = [] master_addr, master_port = None, None for rank in range(world_size): - actor = TrainRayActor.options( - num_cpus=num_gpus_per_actor, - num_gpus=num_gpus_per_actor, - scheduling_strategy=PlacementGroupSchedulingStrategy( + actor_options = { + "num_cpus": num_gpus_per_actor, + "num_gpus": 0 if self.args.debug_rollout_only else num_gpus_per_actor, + } + if not self.args.debug_rollout_only: + actor_options["scheduling_strategy"] = PlacementGroupSchedulingStrategy( placement_group=pg, placement_group_bundle_index=reordered_bundle_indices[rank], - ), - ).remote(world_size, rank, master_addr, master_port) + ) + actor = TrainRayActor.options(**actor_options).remote(world_size, rank, master_addr, master_port) if rank == 0: master_addr, master_port = ray.get(actor.get_master_addr_and_port.remote()) self._actor_handlers.append(actor) diff --git a/slime/ray/placement_group.py b/slime/ray/placement_group.py index d6af2e8102..96014928d7 100644 --- a/slime/ray/placement_group.py +++ b/slime/ray/placement_group.py @@ -41,6 +41,9 @@ def sort_key(x): def _create_placement_group(num_gpus): """Create a placement group with the specified number of GPUs.""" + if num_gpus == 0: + return None, [], [] + bundles = [{"GPU": 1, "CPU": 1} for _ in range(num_gpus)] pg = placement_group(bundles, strategy="PACK") num_bundles = len(bundles) @@ -77,25 +80,30 @@ def _create_placement_group(num_gpus): return pg, pg_reordered_bundle_indices, pg_reordered_gpu_ids +def _get_placement_group_layout(args) -> tuple[int, int]: + actor_num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node + + if args.debug_train_only: + return actor_num_gpus, 0 + + if args.rollout_external: + if args.debug_rollout_only: + return 0, 0 + return actor_num_gpus, actor_num_gpus + + if args.debug_rollout_only: + return args.rollout_num_gpus, 0 + + if args.colocate: + return actor_num_gpus, 0 + + return actor_num_gpus + args.rollout_num_gpus, actor_num_gpus + + def create_placement_groups(args): """Create placement groups for actor, critic, and rollout engines.""" - num_gpus = 0 - if args.debug_train_only: - num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node - rollout_offset = 0 - elif args.debug_rollout_only: - num_gpus = args.rollout_num_gpus - rollout_offset = 0 - elif args.colocate: - num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node - rollout_offset = 0 - elif args.rollout_external: - num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node - rollout_offset = num_gpus - else: - num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node + args.rollout_num_gpus - rollout_offset = args.actor_num_nodes * args.actor_num_gpus_per_node + num_gpus, rollout_offset = _get_placement_group_layout(args) logger.info(f"Creating placement group with {num_gpus} GPUs...") pg, actor_pg_reordered_bundle_indices, actor_pg_reordered_gpu_ids = _create_placement_group(num_gpus) diff --git a/slime/ray/train_actor.py b/slime/ray/train_actor.py index a8ba6ddc64..b84b89d92e 100644 --- a/slime/ray/train_actor.py +++ b/slime/ray/train_actor.py @@ -18,11 +18,14 @@ def get_local_gpu_id(): + gpu_ids = ray.get_gpu_ids() + if not gpu_ids: + return 0 cvd = os.environ.get("CUDA_VISIBLE_DEVICES", None) if cvd is None: - return ray.get_gpu_ids()[0] + return gpu_ids[0] else: - return cvd.split(",").index(str(ray.get_gpu_ids()[0])) + return cvd.split(",").index(str(gpu_ids[0])) class TrainRayActor(RayActor): diff --git a/tests/utils/test_external_sglang_engines.py b/tests/test_external_sglang_engines.py similarity index 92% rename from tests/utils/test_external_sglang_engines.py rename to tests/test_external_sglang_engines.py index 62db7b385f..4e0b6c1371 100644 --- a/tests/utils/test_external_sglang_engines.py +++ b/tests/test_external_sglang_engines.py @@ -1,9 +1,20 @@ +import sys from argparse import Namespace +from pathlib import Path + +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) from slime.backends.sglang_utils.external import apply_external_engine_info_to_args, discover_external_engines from slime.utils.http_utils import get_rollout_num_engines +NUM_GPUS = 0 + + class _Response: def __init__(self, payload, status_code=200): self.payload = payload @@ -99,3 +110,7 @@ def test_apply_external_engine_info_no_addrs_disables_external(): apply_external_engine_info_to_args(args) assert args.rollout_external is False + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__])) diff --git a/tests/test_placement_group.py b/tests/test_placement_group.py new file mode 100644 index 0000000000..8f918d4a74 --- /dev/null +++ b/tests/test_placement_group.py @@ -0,0 +1,51 @@ +import sys +from argparse import Namespace +from pathlib import Path + +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from slime.ray.placement_group import _create_placement_group, _get_placement_group_layout + + +NUM_GPUS = 0 + + +def _args(**overrides): + values = { + "actor_num_nodes": 2, + "actor_num_gpus_per_node": 8, + "rollout_num_gpus": 32, + "debug_train_only": False, + "debug_rollout_only": False, + "colocate": False, + "rollout_external": False, + } + values.update(overrides) + return Namespace(**values) + + +@pytest.mark.parametrize( + ("overrides", "expected"), + [ + pytest.param({}, (48, 16), id="normal_non_colocate"), + pytest.param({"debug_train_only": True}, (16, 0), id="debug_train_only"), + pytest.param({"debug_rollout_only": True}, (32, 0), id="debug_rollout_only"), + pytest.param({"colocate": True}, (16, 0), id="colocate"), + pytest.param({"rollout_external": True}, (16, 16), id="external"), + pytest.param({"rollout_external": True, "debug_rollout_only": True}, (0, 0), id="external_debug_rollout"), + ], +) +def test_placement_group_layout(overrides, expected): + assert _get_placement_group_layout(_args(**overrides)) == expected + + +def test_create_zero_gpu_placement_group_is_empty(): + assert _create_placement_group(0) == (None, [], []) + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__])) diff --git a/tests/test_qwen3.5_0.8B_external_pd.py b/tests/test_qwen3.5_0.8B_external_pd.py index 6cf2bfbb15..0764908fe5 100644 --- a/tests/test_qwen3.5_0.8B_external_pd.py +++ b/tests/test_qwen3.5_0.8B_external_pd.py @@ -1,13 +1,12 @@ -"""E2E test for --rollout-external-engine-addrs with a mixed external fleet. +"""E2E test for --rollout-external-engine-addrs with a pure-PD external fleet. Spawns four SGLang servers out-of-band on a single 8-GPU box (all tp=1): -- 1 prefill (``--disaggregation-mode prefill``) -- 1 decode (``--disaggregation-mode decode``) -- 2 regular (no disaggregation) +- 2 prefill (``--disaggregation-mode prefill``, mooncake transfer backend) +- 2 decode (``--disaggregation-mode decode``, mooncake transfer backend) and points slime at all four via ``--rollout-external-engine-addrs ...``. The remaining 4 GPUs train. slime queries ``/server_info`` on each engine to -infer per-engine TP / GPU counts and registers them to its (PD-enabled) router. +infer per-engine TP / GPU counts and registers them to its PD-enabled router. Weight sync uses ``--update-weight-mode delta --update-weight-transport disk`` so the post-train sync writes sparse safetensors to a shared dir and the @@ -29,13 +28,13 @@ MODEL_TYPE = "qwen3.5-0.8B" NUM_GPUS = 8 NUM_TRAIN_GPUS = 4 -NUM_REGULAR_ENGINES = 2 +NUM_PREFILL_ENGINES = 2 +NUM_DECODE_ENGINES = 2 EXTERNAL_HOST = "127.0.0.1" -PREFILL_PORT = 13150 -DECODE_PORT = 13151 -REGULAR_PORTS = [13153, 13154] -BOOTSTRAP_PORT = 13152 +PREFILL_PORTS = [13150, 13151] +DECODE_PORTS = [13152, 13153] +BOOTSTRAP_PORTS = [13160, 13161] def prepare(): @@ -51,17 +50,15 @@ def prepare(): def _get_gpu_split(): - """Partition the 8 visible GPUs: 4 train + 1 prefill + 1 decode + 2 regular.""" + """Partition the 8 visible GPUs: 4 train + 2 prefill + 2 decode.""" all_gpus = os.environ.get("CUDA_VISIBLE_DEVICES", ",".join(str(i) for i in range(NUM_GPUS))).split(",") assert len(all_gpus) >= NUM_GPUS, f"Expected at least {NUM_GPUS} GPUs, got {len(all_gpus)}" train_gpus = all_gpus[:NUM_TRAIN_GPUS] cursor = NUM_TRAIN_GPUS - prefill_gpu = all_gpus[cursor] - cursor += 1 - decode_gpu = all_gpus[cursor] - cursor += 1 - regular_gpus = all_gpus[cursor : cursor + NUM_REGULAR_ENGINES] - return train_gpus, prefill_gpu, decode_gpu, regular_gpus + prefill_gpus = all_gpus[cursor : cursor + NUM_PREFILL_ENGINES] + cursor += NUM_PREFILL_ENGINES + decode_gpus = all_gpus[cursor : cursor + NUM_DECODE_ENGINES] + return train_gpus, prefill_gpus, decode_gpus def _launch_sglang_server( @@ -70,7 +67,7 @@ def _launch_sglang_server( port: int, tp: int, log_path: str, - disaggregation_mode: str | None = None, + disaggregation_mode: str, disaggregation_bootstrap_port: int | None = None, ) -> subprocess.Popen: env = os.environ.copy() @@ -91,22 +88,18 @@ def _launch_sglang_server( "--mem-fraction-static", "0.6", "--trust-remote-code", + "--disaggregation-mode", + disaggregation_mode, + "--disaggregation-transfer-backend", + "mooncake", ] - if disaggregation_mode is not None: - cmd += [ - "--disaggregation-mode", - disaggregation_mode, - "--disaggregation-transfer-backend", - "mooncake", - ] if disaggregation_bootstrap_port is not None: cmd += ["--disaggregation-bootstrap-port", str(disaggregation_bootstrap_port)] log_file = open(log_path, "w") process = subprocess.Popen(cmd, env=env, stdout=log_file, stderr=subprocess.STDOUT) - label = disaggregation_mode or "regular" print( - f"Starting external sglang {label} server on GPUs {gpus} " + f"Starting external sglang {disaggregation_mode} server on GPUs {gpus} " f"port={port} tp={tp} (pid={process.pid}), log: {log_path}" ) @@ -116,22 +109,22 @@ def _launch_sglang_server( deadline = time.time() + 600 while time.time() < deadline: if process.poll() is not None: - raise RuntimeError(f"{label} server exited with code {process.returncode}; check {log_path}") + raise RuntimeError(f"{disaggregation_mode} server exited with code {process.returncode}; check {log_path}") try: req = urllib.request.urlopen(f"http://{EXTERNAL_HOST}:{port}/server_info", timeout=2) if req.status == 200: - print(f"External sglang {label} server is ready on GPUs {gpus}") + print(f"External sglang {disaggregation_mode} server is ready on GPUs {gpus}") return process except Exception: pass time.sleep(5) process.kill() - raise RuntimeError(f"{label} server failed to start within timeout; check {log_path}") + raise RuntimeError(f"{disaggregation_mode} server failed to start within timeout; check {log_path}") def execute(): - train_gpus, prefill_gpu, decode_gpu, regular_gpus = _get_gpu_split() + train_gpus, prefill_gpus, decode_gpus = _get_gpu_split() processes: list[subprocess.Popen] = [] # Restrict CUDA_VISIBLE_DEVICES to training GPUs before Ray starts so @@ -139,32 +132,27 @@ def execute(): os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(train_gpus) def launch_external_engines(): - processes.append( - _launch_sglang_server( - gpus=[prefill_gpu], - port=PREFILL_PORT, - tp=1, - disaggregation_mode="prefill", - disaggregation_bootstrap_port=BOOTSTRAP_PORT, - log_path="/tmp/sglang_external_prefill.log", - ) - ) - processes.append( - _launch_sglang_server( - gpus=[decode_gpu], - port=DECODE_PORT, - tp=1, - disaggregation_mode="decode", - log_path="/tmp/sglang_external_decode.log", + for idx, (gpu, port, bootstrap_port) in enumerate( + zip(prefill_gpus, PREFILL_PORTS, BOOTSTRAP_PORTS, strict=True) + ): + processes.append( + _launch_sglang_server( + gpus=[gpu], + port=port, + tp=1, + disaggregation_mode="prefill", + disaggregation_bootstrap_port=bootstrap_port, + log_path=f"/tmp/sglang_external_prefill_{idx}.log", + ) ) - ) - for idx, (gpu, port) in enumerate(zip(regular_gpus, REGULAR_PORTS, strict=True)): + for idx, (gpu, port) in enumerate(zip(decode_gpus, DECODE_PORTS, strict=True)): processes.append( _launch_sglang_server( gpus=[gpu], port=port, tp=1, - log_path=f"/tmp/sglang_external_regular_{idx}.log", + disaggregation_mode="decode", + log_path=f"/tmp/sglang_external_decode_{idx}.log", ) ) @@ -222,13 +210,10 @@ def launch_external_engines(): ) # No --rollout-num-gpus / --rollout-num-gpus-per-engine: those are - # inferred from /server_info on each external engine (1 prefill + - # 1 decode + 2 regular, all tp=1). - external_args = ( - "--rollout-external-engine-addrs " - f"{EXTERNAL_HOST}:{PREFILL_PORT} " - f"{EXTERNAL_HOST}:{DECODE_PORT} " + " ".join(f"{EXTERNAL_HOST}:{port}" for port in REGULAR_PORTS) + " " - ) + # inferred from /server_info on each external engine (2 prefill + + # 2 decode, all tp=1). + all_addrs = [f"{EXTERNAL_HOST}:{port}" for port in (*PREFILL_PORTS, *DECODE_PORTS)] + external_args = "--rollout-external-engine-addrs " + " ".join(all_addrs) + " " # External engines have no NCCL group with the trainer, so weight # updates have to go through the disk-backed delta path: the trainer From 4d5513df334830ddce99d64ed8d30fba22b2cbad Mon Sep 17 00:00:00 2001 From: Zilin Zhu Date: Thu, 4 Jun 2026 09:27:40 +0000 Subject: [PATCH 10/17] update --- slime/backends/sglang_utils/external.py | 72 ++++++++++++++++++++++++- slime/ray/rollout.py | 41 +++++++------- tests/test_external_sglang_engines.py | 69 +++++++++++++++++++++++- 3 files changed, 159 insertions(+), 23 deletions(-) diff --git a/slime/backends/sglang_utils/external.py b/slime/backends/sglang_utils/external.py index a8a8b52729..f85e6a1175 100644 --- a/slime/backends/sglang_utils/external.py +++ b/slime/backends/sglang_utils/external.py @@ -26,6 +26,55 @@ def to_dict(self) -> dict: return dataclasses.asdict(self) +@dataclasses.dataclass(frozen=True) +class ExternalServerGroupInfo: + worker_type: str + num_gpus_per_engine: int + engine_infos: tuple[ExternalEngineInfo, ...] + + @property + def num_gpus(self) -> int: + return sum(info.num_gpus for info in self.engine_infos) + + def to_dict(self) -> dict: + return { + "worker_type": self.worker_type, + "num_gpus": self.num_gpus, + "num_gpus_per_engine": self.num_gpus_per_engine, + "engine_infos": [info.to_dict() for info in self.engine_infos], + } + + +@dataclasses.dataclass(frozen=True) +class ExternalModelInfo: + name: str + server_groups: tuple[ExternalServerGroupInfo, ...] + update_weights: bool = True + + @property + def has_pd_disaggregation(self) -> bool: + return any(g.worker_type in ("prefill", "decode") for g in self.server_groups) + + @property + def engine_infos(self) -> list[ExternalEngineInfo]: + return [info for group in self.server_groups for info in group.engine_infos] + + @property + def total_num_gpus(self) -> int: + return sum(group.num_gpus for group in self.server_groups) + + @property + def num_engines(self) -> int: + return sum(len(group.engine_infos) for group in self.server_groups) + + def to_dict(self) -> dict: + return { + "name": self.name, + "server_groups": [group.to_dict() for group in self.server_groups], + "update_weights": self.update_weights, + } + + def normalize_external_engine_addr(addr: str) -> str: """Normalize ``host:port`` or ``http://host:port`` to an HTTP base URL.""" if "://" not in addr: @@ -44,6 +93,23 @@ def external_engine_info_from_dict(data: dict) -> ExternalEngineInfo: return ExternalEngineInfo(**data) +def build_external_model_info(infos: list[ExternalEngineInfo], name: str = "default") -> ExternalModelInfo: + specs_by_topology: dict[tuple[str, int], list[ExternalEngineInfo]] = {} + for info in infos: + key = (info.worker_type, info.num_gpus) + specs_by_topology.setdefault(key, []).append(info) + + server_groups = tuple( + ExternalServerGroupInfo( + worker_type=worker_type, + num_gpus_per_engine=num_gpus_per_engine, + engine_infos=tuple(group_infos), + ) + for (worker_type, num_gpus_per_engine), group_infos in specs_by_topology.items() + ) + return ExternalModelInfo(name=name, server_groups=server_groups) + + def _positive_int(value, default: int) -> int: if value is None: return default @@ -116,8 +182,10 @@ def apply_external_engine_info_to_args(args, logger=None) -> None: raise ValueError("--rollout-external-engine-addrs did not contain any engines.") args.rollout_external_engine_infos = [info.to_dict() for info in infos] - args.rollout_num_engines = len(infos) - args.rollout_num_gpus = sum(info.num_gpus for info in infos) + model_info = build_external_model_info(infos) + args.rollout_external_model_info = model_info.to_dict() + args.rollout_num_engines = model_info.num_engines + args.rollout_num_gpus = model_info.total_num_gpus if logger is not None: summary = [ diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index 7f3b9d739a..37da20f709 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -19,6 +19,8 @@ from slime.backends.sglang_utils.external import ( ExternalEngineInfo, + ExternalModelInfo, + build_external_model_info, discover_external_engines, external_engine_info_from_dict, ) @@ -1060,7 +1062,7 @@ def _compute_megatron_num_gpus(args) -> int: return num -def _external_engine_infos_from_args(args) -> list[ExternalEngineInfo]: +def _external_model_from_args(args) -> ExternalModelInfo: raw_infos = getattr(args, "rollout_external_engine_infos", None) if raw_infos is None: addrs = getattr(args, "rollout_external_engine_addrs", None) @@ -1068,10 +1070,13 @@ def _external_engine_infos_from_args(args) -> list[ExternalEngineInfo]: raise RuntimeError("External rollout requires --rollout-external-engine-addrs.") infos = discover_external_engines(addrs) args.rollout_external_engine_infos = [info.to_dict() for info in infos] - args.rollout_num_engines = len(infos) - args.rollout_num_gpus = sum(info.num_gpus for info in infos) - return infos - return [external_engine_info_from_dict(info) if isinstance(info, dict) else info for info in raw_infos] + else: + infos = [external_engine_info_from_dict(info) if isinstance(info, dict) else info for info in raw_infos] + model_info = build_external_model_info(infos) + args.rollout_external_model_info = model_info.to_dict() + args.rollout_num_engines = model_info.num_engines + args.rollout_num_gpus = model_info.total_num_gpus + return model_info def _get_registered_router_worker_urls(router_ip: str, router_port: int) -> set[str]: @@ -1090,16 +1095,18 @@ def _get_registered_router_worker_urls(router_ip: str, router_port: int) -> set[ return set() -def _register_external_workers_to_router(args, engine_infos: list[ExternalEngineInfo]) -> None: +def _register_external_workers_to_router(args, model_info: ExternalModelInfo) -> None: + engine_infos = model_info.engine_infos if not engine_infos: return router_addr = f"http://{args.sglang_router_ip}:{args.sglang_router_port}" registered_urls = _get_registered_router_worker_urls(args.sglang_router_ip, args.sglang_router_port) - has_pd = any(info.is_pd_worker for info in engine_infos) if parse(sglang_router.__version__) <= parse("0.2.1"): - assert not has_pd, "PD disaggregation for external engines requires sglang_router > 0.2.1." + assert ( + not model_info.has_pd_disaggregation + ), "PD disaggregation for external engines requires sglang_router > 0.2.1." for info in engine_infos: if info.url in registered_urls: continue @@ -1128,30 +1135,26 @@ def _register_external_workers_to_router(args, engine_infos: list[ExternalEngine def _start_external_rollout_servers(args, pg) -> dict[str, RolloutServer]: - engine_infos = _external_engine_infos_from_args(args) - has_pd = any(info.is_pd_worker for info in engine_infos) - router_ip, router_port = _start_router(args, has_pd_disaggregation=has_pd) + model_cfg = _external_model_from_args(args) + router_ip, router_port = _start_router(args, has_pd_disaggregation=model_cfg.has_pd_disaggregation) args.sglang_router_ip = router_ip args.sglang_router_port = router_port - _register_external_workers_to_router(args, engine_infos) - - specs_by_topology: dict[tuple, list[ExternalEngineInfo]] = {} - for info in engine_infos: - key = (info.worker_type, info.num_gpus) - specs_by_topology.setdefault(key, []).append(info) + _register_external_workers_to_router(args, model_cfg) server_groups = [] engine_offset = 0 gpu_offset = 0 init_handles = [] - for (worker_type, num_gpus), group_specs in specs_by_topology.items(): + for group_cfg in model_cfg.server_groups: + group_specs = list(group_cfg.engine_infos) + num_gpus = group_cfg.num_gpus_per_engine group = ServerGroup( args=args, pg=pg, all_engines=[None] * len(group_specs), num_gpus_per_engine=num_gpus, num_new_engines=0, - worker_type=worker_type, + worker_type=group_cfg.worker_type, rank_offset=engine_offset, gpu_offset=gpu_offset, sglang_overrides={}, diff --git a/tests/test_external_sglang_engines.py b/tests/test_external_sglang_engines.py index 4e0b6c1371..74286d20cb 100644 --- a/tests/test_external_sglang_engines.py +++ b/tests/test_external_sglang_engines.py @@ -8,10 +8,14 @@ if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) -from slime.backends.sglang_utils.external import apply_external_engine_info_to_args, discover_external_engines +from slime.backends.sglang_utils.external import ( + ExternalEngineInfo, + apply_external_engine_info_to_args, + build_external_model_info, + discover_external_engines, +) from slime.utils.http_utils import get_rollout_num_engines - NUM_GPUS = 0 @@ -59,6 +63,40 @@ def fake_get(url, timeout): assert info.server_info["ep_size"] == 4 +def test_build_external_model_info_groups_engines_by_topology(): + infos = [ + ExternalEngineInfo( + url="http://prefill-0:10090", + host="prefill-0", + port=10090, + worker_type="prefill", + num_gpus=2, + ), + ExternalEngineInfo( + url="http://prefill-1:10091", + host="prefill-1", + port=10091, + worker_type="prefill", + num_gpus=2, + ), + ExternalEngineInfo(url="http://decode-0:10092", host="decode-0", port=10092, worker_type="decode", num_gpus=4), + ] + + model_info = build_external_model_info(infos) + + assert model_info.name == "default" + assert model_info.has_pd_disaggregation is True + assert model_info.num_engines == 3 + assert model_info.total_num_gpus == 8 + assert [ + (group.worker_type, group.num_gpus, group.num_gpus_per_engine, len(group.engine_infos)) + for group in model_info.server_groups + ] == [ + ("prefill", 4, 2, 2), + ("decode", 4, 4, 1), + ] + + def test_apply_external_engine_info_handles_pd(monkeypatch): payloads = { "http://prefill:10090/server_info": { @@ -90,11 +128,13 @@ def fake_get(url, timeout): sglang_data_parallel_size=1, sglang_expert_parallel_size=1, sglang_enable_dp_attention=False, + router_pd_disaggregation=False, ) apply_external_engine_info_to_args(args) assert args.rollout_external is True + assert args.router_pd_disaggregation is False assert args.rollout_num_gpus == 6 assert args.rollout_num_engines == 2 assert get_rollout_num_engines(args) == 2 @@ -104,6 +144,31 @@ def fake_get(url, timeout): assert args.rollout_external_engine_infos[0]["disaggregation_bootstrap_port"] == 12090 +def test_apply_external_engine_info_preserves_router_pd_flag(monkeypatch): + def fake_get(url, timeout): + assert url == "http://regular:10090/server_info" + return _Response( + { + "tp_size": 2, + "pp_size": 1, + "disaggregation_mode": "null", + } + ) + + monkeypatch.setattr("slime.backends.sglang_utils.external.requests.get", fake_get) + args = Namespace( + rollout_external_engine_addrs=["regular:10090"], + router_pd_disaggregation=True, + ) + + apply_external_engine_info_to_args(args) + + assert args.rollout_external is True + assert args.router_pd_disaggregation is True + assert args.rollout_num_gpus == 2 + assert args.rollout_num_engines == 1 + + def test_apply_external_engine_info_no_addrs_disables_external(): args = Namespace(rollout_external_engine_addrs=None) From a595d9b5107709599ffbcc44329d7421ad109826 Mon Sep 17 00:00:00 2001 From: Zilin Zhu Date: Thu, 4 Jun 2026 09:37:44 +0000 Subject: [PATCH 11/17] fix --- .github/workflows/pr-test.yml | 4 ++-- .github/workflows/pr-test.yml.j2 | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 4d47f10c10..c8d9c776f4 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -481,7 +481,7 @@ jobs: shell: bash run: | pip install torch --index-url https://download.pytorch.org/whl/cpu - pip install pytest numpy packaging pyyaml omegaconf tqdm httpx pybase64 pylatexenc sympy aiohttp pillow safetensors + pip install pytest numpy packaging pyyaml omegaconf tqdm httpx requests pybase64 pylatexenc sympy aiohttp pillow safetensors - name: Install @@ -547,7 +547,7 @@ jobs: shell: bash run: | pip install torch --index-url https://download.pytorch.org/whl/cpu - pip install pytest numpy packaging pyyaml omegaconf tqdm httpx pybase64 pylatexenc sympy aiohttp pillow safetensors + pip install pytest numpy packaging pyyaml omegaconf tqdm httpx requests pybase64 pylatexenc sympy aiohttp pillow safetensors pip install openai openai-agents anthropic diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index 611c2cd53b..c86c1572aa 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -197,7 +197,7 @@ jobs: shell: bash run: | pip install torch --index-url https://download.pytorch.org/whl/cpu - pip install pytest numpy packaging pyyaml omegaconf tqdm httpx pybase64 pylatexenc sympy aiohttp pillow safetensors + pip install pytest numpy packaging pyyaml omegaconf tqdm httpx requests pybase64 pylatexenc sympy aiohttp pillow safetensors <% if config.get('extra_pip_deps') %> pip install << config.extra_pip_deps >> <% endif %> From 349d5162b9afa2b3068cf0302a4514615e33f7f8 Mon Sep 17 00:00:00 2001 From: Zilin Zhu Date: Thu, 4 Jun 2026 10:14:48 +0000 Subject: [PATCH 12/17] fix --- .github/workflows/pr-test.yml | 6 +- .github/workflows/pr-test.yml.j2 | 4 +- slime/backends/sglang_utils/sglang_engine.py | 16 +- slime/ray/rollout.py | 59 ------- ...nal_pd.py => test_qwen3_4B_external_pd.py} | 146 +++++++++++++++--- 5 files changed, 142 insertions(+), 89 deletions(-) rename tests/{test_qwen3.5_0.8B_external_pd.py => test_qwen3_4B_external_pd.py} (65%) diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index c8d9c776f4..2cf852c586 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -205,7 +205,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"enable_eval": "0", "num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_glm4.7_30B_A3B_pd_mooncake.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"num_gpus": 8, "test_file": "test_qwen3.6_35B_A3B_pd_mooncake.py", "use_deepep": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo_disaggregate.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo_train_critic_only.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "test_mimo_7B_mtp_only_grad.py"}, {"num_gpus": 8, "test_file": "test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 8, "test_file": "test_qwen2.5_0.5B_debug_rollout_then_train.py"}, {"num_gpus": 8, "test_file": "test_qwen2.5_0.5B_opd_sglang.py"}, {"num_gpus": 8, "test_file": "test_qwen3.5_0.8B_external_pd.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_fully_async_short.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_streaming_partial_rollout.py"}, {"num_gpus": 4, "test_file": "test_qwen3.5_0.8B_gsm8k_short.py"}, {"num_gpus": 4, "test_file": "test_qwen3.5_0.8B_gsm8k_async_short.py"}, {"num_gpus": 8, "test_args": "--save-optimizer gpu --load-optimizer gpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer gpu --load-optimizer cpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer cpu --load-optimizer cpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer cpu --load-optimizer gpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--async-save", "test_file": "test_qwen3_4B_ckpt.py"}] + info: [{"enable_eval": "0", "num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_glm4.7_30B_A3B_pd_mooncake.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"num_gpus": 8, "test_file": "test_qwen3.6_35B_A3B_pd_mooncake.py", "use_deepep": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo_disaggregate.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo_train_critic_only.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "test_mimo_7B_mtp_only_grad.py"}, {"num_gpus": 8, "test_file": "test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 8, "test_file": "test_qwen2.5_0.5B_debug_rollout_then_train.py"}, {"num_gpus": 8, "test_file": "test_qwen2.5_0.5B_opd_sglang.py"}, {"num_gpus": 6, "test_file": "test_qwen3_4B_external_pd.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_fully_async_short.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_streaming_partial_rollout.py"}, {"num_gpus": 4, "test_file": "test_qwen3.5_0.8B_gsm8k_short.py"}, {"num_gpus": 4, "test_file": "test_qwen3.5_0.8B_gsm8k_async_short.py"}, {"num_gpus": 8, "test_args": "--save-optimizer gpu --load-optimizer gpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer gpu --load-optimizer cpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer cpu --load-optimizer cpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer cpu --load-optimizer gpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--async-save", "test_file": "test_qwen3_4B_ckpt.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -481,7 +481,7 @@ jobs: shell: bash run: | pip install torch --index-url https://download.pytorch.org/whl/cpu - pip install pytest numpy packaging pyyaml omegaconf tqdm httpx requests pybase64 pylatexenc sympy aiohttp pillow safetensors + pip install pytest numpy packaging pyyaml omegaconf tqdm httpx requests ray pybase64 pylatexenc sympy aiohttp pillow safetensors - name: Install @@ -547,7 +547,7 @@ jobs: shell: bash run: | pip install torch --index-url https://download.pytorch.org/whl/cpu - pip install pytest numpy packaging pyyaml omegaconf tqdm httpx requests pybase64 pylatexenc sympy aiohttp pillow safetensors + pip install pytest numpy packaging pyyaml omegaconf tqdm httpx requests ray pybase64 pylatexenc sympy aiohttp pillow safetensors pip install openai openai-agents anthropic diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index c86c1572aa..58c627e0e7 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -35,7 +35,7 @@ {'test_file': 'test_qwen3_0.6B_parallel_check.py', 'num_gpus': 8}, {'test_file': 'test_qwen2.5_0.5B_debug_rollout_then_train.py', 'num_gpus': 8}, {'test_file': 'test_qwen2.5_0.5B_opd_sglang.py', 'num_gpus': 8}, - {'test_file': 'test_qwen3.5_0.8B_external_pd.py', 'num_gpus': 8}, + {'test_file': 'test_qwen3_4B_external_pd.py', 'num_gpus': 6}, {'test_file': 'test_qwen2.5_0.5B_fully_async_short.py', 'num_gpus': 4}, {'test_file': 'test_qwen3_4B_streaming_partial_rollout.py', 'num_gpus': 8}, {'test_file': 'test_qwen3.5_0.8B_gsm8k_short.py', 'num_gpus': 4}, @@ -197,7 +197,7 @@ jobs: shell: bash run: | pip install torch --index-url https://download.pytorch.org/whl/cpu - pip install pytest numpy packaging pyyaml omegaconf tqdm httpx requests pybase64 pylatexenc sympy aiohttp pillow safetensors + pip install pytest numpy packaging pyyaml omegaconf tqdm httpx requests ray pybase64 pylatexenc sympy aiohttp pillow safetensors <% if config.get('extra_pip_deps') %> pip install << config.extra_pip_deps >> <% endif %> diff --git a/slime/backends/sglang_utils/sglang_engine.py b/slime/backends/sglang_utils/sglang_engine.py index a2328664dd..e04eb20b5d 100644 --- a/slime/backends/sglang_utils/sglang_engine.py +++ b/slime/backends/sglang_utils/sglang_engine.py @@ -189,27 +189,37 @@ def _sanity_check_server_args(actual_server_args, expect_server_args): ) actual_server_args = _get_actual_server_args() _sanity_check_server_args(actual_server_args, expect_server_args) + self._register_to_router(expect_server_args) def _init_normal(self, server_args_dict): logger.info(f"Launch HttpServerEngineAdapter at: {self.server_host}:{self.server_port}") self.process = launch_server_process(ServerArgs(**server_args_dict)) + self._register_to_router(server_args_dict) + def _register_to_router(self, server_args_dict): if self.worker_type == "encoder": return if self.node_rank == 0 and self.router_ip and self.router_port: + worker_url = f"http://{self.server_host}:{self.server_port}" if parse(sglang_router.__version__) <= parse("0.2.1"): assert self.worker_type == "regular", "pd disaggregation is not supported in old router." response = requests.post( - f"http://{self.router_ip}:{self.router_port}/add_worker?url=http://{self.server_host}:{self.server_port}", + f"http://{self.router_ip}:{self.router_port}/add_worker?url={worker_url}", ) else: payload = { - "url": f"http://{self.server_host}:{self.server_port}", + "url": worker_url, "worker_type": self.worker_type, } if self.worker_type == "prefill": - payload["bootstrap_port"] = server_args_dict["disaggregation_bootstrap_port"] + bootstrap_port = server_args_dict.get("disaggregation_bootstrap_port") + if bootstrap_port is None: + raise RuntimeError( + f"Prefill worker {worker_url} does not have disaggregation_bootstrap_port; " + "cannot register it to the PD router." + ) + payload["bootstrap_port"] = bootstrap_port response = requests.post( f"http://{self.router_ip}:{self.router_port}/workers", json=payload, diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index 37da20f709..ac57f2b05a 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -10,10 +10,7 @@ import numpy as np import ray -import requests -import sglang_router import torch -from packaging.version import parse from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS @@ -1079,67 +1076,11 @@ def _external_model_from_args(args) -> ExternalModelInfo: return model_info -def _get_registered_router_worker_urls(router_ip: str, router_port: int) -> set[str]: - router_addr = f"http://{router_ip}:{router_port}" - for endpoint in ("/workers", "/list_workers"): - try: - response = requests.get(f"{router_addr}{endpoint}", timeout=30) - response.raise_for_status() - payload = response.json() - except Exception: - continue - if "workers" in payload: - return {worker["url"] if isinstance(worker, dict) else worker for worker in payload["workers"]} - if "urls" in payload: - return set(payload["urls"]) - return set() - - -def _register_external_workers_to_router(args, model_info: ExternalModelInfo) -> None: - engine_infos = model_info.engine_infos - if not engine_infos: - return - - router_addr = f"http://{args.sglang_router_ip}:{args.sglang_router_port}" - registered_urls = _get_registered_router_worker_urls(args.sglang_router_ip, args.sglang_router_port) - - if parse(sglang_router.__version__) <= parse("0.2.1"): - assert ( - not model_info.has_pd_disaggregation - ), "PD disaggregation for external engines requires sglang_router > 0.2.1." - for info in engine_infos: - if info.url in registered_urls: - continue - response = requests.post(f"{router_addr}/add_worker?url={info.url}") - response.raise_for_status() - return - - for info in engine_infos: - if info.worker_type == "encoder": - continue - if info.url in registered_urls: - continue - payload = { - "url": info.url, - "worker_type": info.worker_type, - } - if info.worker_type == "prefill": - if info.disaggregation_bootstrap_port is None: - raise RuntimeError( - f"External prefill worker {info.url} did not report disaggregation_bootstrap_port " - "from /server_info; cannot register it to the PD router." - ) - payload["bootstrap_port"] = info.disaggregation_bootstrap_port - response = requests.post(f"{router_addr}/workers", json=payload) - response.raise_for_status() - - def _start_external_rollout_servers(args, pg) -> dict[str, RolloutServer]: model_cfg = _external_model_from_args(args) router_ip, router_port = _start_router(args, has_pd_disaggregation=model_cfg.has_pd_disaggregation) args.sglang_router_ip = router_ip args.sglang_router_port = router_port - _register_external_workers_to_router(args, model_cfg) server_groups = [] engine_offset = 0 diff --git a/tests/test_qwen3.5_0.8B_external_pd.py b/tests/test_qwen3_4B_external_pd.py similarity index 65% rename from tests/test_qwen3.5_0.8B_external_pd.py rename to tests/test_qwen3_4B_external_pd.py index 0764908fe5..0158403bee 100644 --- a/tests/test_qwen3.5_0.8B_external_pd.py +++ b/tests/test_qwen3_4B_external_pd.py @@ -1,11 +1,11 @@ """E2E test for --rollout-external-engine-addrs with a pure-PD external fleet. -Spawns four SGLang servers out-of-band on a single 8-GPU box (all tp=1): -- 2 prefill (``--disaggregation-mode prefill``, mooncake transfer backend) -- 2 decode (``--disaggregation-mode decode``, mooncake transfer backend) +Spawns two SGLang servers out-of-band on a single GPU box (all tp=1): +- 1 prefill (``--disaggregation-mode prefill``, mooncake transfer backend) +- 1 decode (``--disaggregation-mode decode``, mooncake transfer backend) -and points slime at all four via ``--rollout-external-engine-addrs ...``. -The remaining 4 GPUs train. slime queries ``/server_info`` on each engine to +and points slime at both via ``--rollout-external-engine-addrs ...``. +The first 4 GPUs train. slime queries ``/server_info`` on each engine to infer per-engine TP / GPU counts and registers them to its PD-enabled router. Weight sync uses ``--update-weight-mode delta --update-weight-transport disk`` @@ -16,6 +16,7 @@ """ import os +import socket import subprocess import tempfile import time @@ -24,17 +25,97 @@ import slime.utils.external_utils.command_utils as U -MODEL_NAME = "Qwen3.5-0.8B" -MODEL_TYPE = "qwen3.5-0.8B" -NUM_GPUS = 8 +MODEL_NAME = "Qwen3-4B" +MODEL_TYPE = "qwen3-4B" +NUM_GPUS = 6 NUM_TRAIN_GPUS = 4 -NUM_PREFILL_ENGINES = 2 -NUM_DECODE_ENGINES = 2 +NUM_PREFILL_ENGINES = 1 +NUM_DECODE_ENGINES = 1 EXTERNAL_HOST = "127.0.0.1" -PREFILL_PORTS = [13150, 13151] -DECODE_PORTS = [13152, 13153] -BOOTSTRAP_PORTS = [13160, 13161] +PREFILL_PORTS = [13150] +DECODE_PORTS = [13151] +BOOTSTRAP_PORTS = [13160] + + +def _get_bond_ipv4(): + net_root = Path("/sys/class/net") + if not net_root.exists(): + return None + + bond_ifaces = [ + path.name for path in net_root.iterdir() if path.name.startswith("bond") and path.name[4:].isdigit() + ] + bond_ifaces.sort(key=lambda name: int(name[4:])) + for iface in bond_ifaces: + try: + output = subprocess.check_output(["ip", "-o", "-4", "addr", "show", "dev", iface], text=True) + except (OSError, subprocess.CalledProcessError): + continue + fields = output.split() + for idx, field in enumerate(fields): + if field == "inet" and idx + 1 < len(fields): + return fields[idx + 1].split("/", 1)[0] + return None + + +def _get_external_host(): + env_value = os.environ.get("SLIME_TEST_EXTERNAL_PD_HOST") + if env_value and env_value not in ("127.0.0.1", "localhost"): + return env_value + + bond_host = _get_bond_ipv4() + if bond_host is not None: + return bond_host + + master_addr = os.environ.get("MASTER_ADDR") + if master_addr and master_addr not in ("127.0.0.1", "localhost"): + return master_addr + + try: + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock: + sock.connect(("8.8.8.8", 80)) + host = sock.getsockname()[0] + if host and not host.startswith("127."): + return host + except OSError: + pass + + return EXTERNAL_HOST + + +def _get_disaggregation_ib_device(): + env_value = os.environ.get("SLIME_TEST_DISAGGREGATION_IB_DEVICE") + if env_value is not None: + return env_value.strip() or None + + ib_root = Path("/sys/class/infiniband") + if not ib_root.exists(): + return None + + active_devices = [] + for device in ib_root.iterdir(): + for state_file in device.glob("ports/*/state"): + try: + if "ACTIVE" in state_file.read_text(): + active_devices.append(device.name) + break + except OSError: + continue + + bond_devices = [] + numeric_mlx5_devices = [] + for device in active_devices: + prefix, _, suffix = device.partition("_") + if prefix == "mlx5" and suffix.startswith("bond_") and suffix[5:].isdigit(): + bond_devices.append(device) + elif prefix == "mlx5" and suffix.isdigit(): + numeric_mlx5_devices.append(device) + bond_devices.sort(key=lambda name: int(name.rsplit("_", 1)[1])) + numeric_mlx5_devices.sort(key=lambda name: int(name.rsplit("_", 1)[1])) + + devices = bond_devices or numeric_mlx5_devices or sorted(active_devices) + return ",".join(devices) if devices else None def prepare(): @@ -50,7 +131,7 @@ def prepare(): def _get_gpu_split(): - """Partition the 8 visible GPUs: 4 train + 2 prefill + 2 decode.""" + """Partition visible GPUs: 4 train + 1 prefill + 1 decode.""" all_gpus = os.environ.get("CUDA_VISIBLE_DEVICES", ",".join(str(i) for i in range(NUM_GPUS))).split(",") assert len(all_gpus) >= NUM_GPUS, f"Expected at least {NUM_GPUS} GPUs, got {len(all_gpus)}" train_gpus = all_gpus[:NUM_TRAIN_GPUS] @@ -69,6 +150,8 @@ def _launch_sglang_server( log_path: str, disaggregation_mode: str, disaggregation_bootstrap_port: int | None = None, + disaggregation_ib_device: str | None = None, + external_host: str = EXTERNAL_HOST, ) -> subprocess.Popen: env = os.environ.copy() env["CUDA_VISIBLE_DEVICES"] = ",".join(gpus) @@ -93,8 +176,13 @@ def _launch_sglang_server( "--disaggregation-transfer-backend", "mooncake", ] + if disaggregation_ib_device is not None: + cmd += ["--disaggregation-ib-device", disaggregation_ib_device] if disaggregation_bootstrap_port is not None: cmd += ["--disaggregation-bootstrap-port", str(disaggregation_bootstrap_port)] + cmd += ["--load-balance-method", "follow_bootstrap_room"] + else: + cmd += ["--prefill-round-robin-balance"] log_file = open(log_path, "w") process = subprocess.Popen(cmd, env=env, stdout=log_file, stderr=subprocess.STDOUT) @@ -111,7 +199,7 @@ def _launch_sglang_server( if process.poll() is not None: raise RuntimeError(f"{disaggregation_mode} server exited with code {process.returncode}; check {log_path}") try: - req = urllib.request.urlopen(f"http://{EXTERNAL_HOST}:{port}/server_info", timeout=2) + req = urllib.request.urlopen(f"http://{external_host}:{port}/server_info", timeout=2) if req.status == 200: print(f"External sglang {disaggregation_mode} server is ready on GPUs {gpus}") return process @@ -125,6 +213,10 @@ def _launch_sglang_server( def execute(): train_gpus, prefill_gpus, decode_gpus = _get_gpu_split() + external_host = _get_external_host() + disaggregation_ib_device = _get_disaggregation_ib_device() + print(f"Using external host for SGLang workers: {external_host}") + print(f"Using SGLang disaggregation IB device: {disaggregation_ib_device}") processes: list[subprocess.Popen] = [] # Restrict CUDA_VISIBLE_DEVICES to training GPUs before Ray starts so @@ -142,6 +234,8 @@ def launch_external_engines(): tp=1, disaggregation_mode="prefill", disaggregation_bootstrap_port=bootstrap_port, + disaggregation_ib_device=disaggregation_ib_device, + external_host=external_host, log_path=f"/tmp/sglang_external_prefill_{idx}.log", ) ) @@ -152,6 +246,8 @@ def launch_external_engines(): port=port, tp=1, disaggregation_mode="decode", + disaggregation_ib_device=disaggregation_ib_device, + external_host=external_host, log_path=f"/tmp/sglang_external_decode_{idx}.log", ) ) @@ -168,16 +264,16 @@ def launch_external_engines(): "--apply-chat-template " "--rollout-shuffle " "--rm-type math " - "--num-rollout 2 " + "--num-rollout 3 " "--rollout-batch-size 4 " "--n-samples-per-prompt 4 " - "--rollout-max-response-len 512 " + "--rollout-max-response-len 1024 " "--rollout-temperature 0.8 " "--global-batch-size 16 " ) perf_args = ( - "--tensor-model-parallel-size 1 " + "--tensor-model-parallel-size 2 " "--sequence-parallel " "--pipeline-model-parallel-size 1 " "--context-parallel-size 1 " @@ -185,6 +281,9 @@ def launch_external_engines(): "--expert-tensor-parallel-size 1 " "--use-dynamic-batch-size " "--max-tokens-per-gpu 9216 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " ) grpo_args = ( @@ -210,9 +309,9 @@ def launch_external_engines(): ) # No --rollout-num-gpus / --rollout-num-gpus-per-engine: those are - # inferred from /server_info on each external engine (2 prefill + - # 2 decode, all tp=1). - all_addrs = [f"{EXTERNAL_HOST}:{port}" for port in (*PREFILL_PORTS, *DECODE_PORTS)] + # inferred from /server_info on each external engine (1 prefill + + # 1 decode, all tp=1). + all_addrs = [f"{external_host}:{port}" for port in (*PREFILL_PORTS, *DECODE_PORTS)] external_args = "--rollout-external-engine-addrs " + " ".join(all_addrs) + " " # External engines have no NCCL group with the trainer, so weight @@ -235,7 +334,6 @@ def launch_external_engines(): "--accumulate-allreduce-grads-in-fp32 " "--attention-softmax-in-fp32 " "--attention-backend flash " - "--loss-mask-type qwen3_5 " "--actor-num-nodes 1 " f"--actor-num-gpus-per-node {NUM_TRAIN_GPUS} " ) @@ -258,6 +356,10 @@ def launch_external_engines(): num_gpus_per_node=NUM_TRAIN_GPUS, megatron_model_type=MODEL_TYPE, before_ray_job_submit=launch_external_engines, + extra_env_vars={ + "no_proxy": f"127.0.0.1,localhost,{external_host}", + "NO_PROXY": f"127.0.0.1,localhost,{external_host}", + }, ) delta_files = list(Path(delta_dir).glob("weight_v*/*.safetensors")) From 130027af63e9f855293111dfe6bd03605fb472c4 Mon Sep 17 00:00:00 2001 From: Zilin Zhu Date: Thu, 4 Jun 2026 10:42:24 +0000 Subject: [PATCH 13/17] fix --- slime/backends/sglang_utils/external.py | 26 ++++++-------------- slime/backends/sglang_utils/sglang_engine.py | 13 ++-------- slime/ray/actor_group.py | 26 ++++++++------------ slime/ray/placement_group.py | 3 ++- slime/ray/train_actor.py | 7 ++---- slime/utils/arguments.py | 13 +--------- tests/test_external_sglang_engines.py | 9 ++++--- tests/test_qwen3_4B_external_pd.py | 4 +-- 8 files changed, 31 insertions(+), 70 deletions(-) diff --git a/slime/backends/sglang_utils/external.py b/slime/backends/sglang_utils/external.py index f85e6a1175..573c6b5a5a 100644 --- a/slime/backends/sglang_utils/external.py +++ b/slime/backends/sglang_utils/external.py @@ -110,14 +110,7 @@ def build_external_model_info(infos: list[ExternalEngineInfo], name: str = "defa return ExternalModelInfo(name=name, server_groups=server_groups) -def _positive_int(value, default: int) -> int: - if value is None: - return default - value = int(value) - return value if value > 0 else default - - -def _get_server_info(url: str, timeout: float = 30.0) -> dict: +def get_server_info(url: str, timeout: float = 30.0) -> dict: errors = [] for endpoint in ("/server_info", "/get_server_info"): try: @@ -144,14 +137,11 @@ def discover_external_engines(addrs: list[str], timeout: float = 30.0) -> list[E url = normalize_external_engine_addr(addr) parsed = urlparse(url) assert parsed.hostname is not None and parsed.port is not None - server_info = _get_server_info(url, timeout=timeout) + server_info = get_server_info(url, timeout=timeout) - pp_size = _positive_int(server_info.get("pp_size") or server_info.get("pipeline_parallel_size"), 1) - tp_size = _positive_int(server_info.get("tp_size") or server_info.get("tensor_parallel_size"), 1) - num_gpus = _positive_int( - server_info.get("num_gpus") or server_info.get("num_gpus_per_engine"), - tp_size * pp_size, - ) + pp_size = int(server_info.get("pp_size") or server_info.get("pipeline_parallel_size") or 1) + tp_size = int(server_info.get("tp_size") or server_info.get("tensor_parallel_size") or 1) + num_gpus = int(server_info.get("num_gpus") or server_info.get("num_gpus_per_engine") or tp_size * pp_size) bootstrap_port = server_info.get("disaggregation_bootstrap_port") bootstrap_port = int(bootstrap_port) if bootstrap_port is not None else None @@ -171,12 +161,10 @@ def discover_external_engines(addrs: list[str], timeout: float = 30.0) -> list[E def apply_external_engine_info_to_args(args, logger=None) -> None: """Detect external engines and store the derived topology on ``args``.""" - addrs = getattr(args, "rollout_external_engine_addrs", None) + addrs = args.rollout_external_engine_addrs if not addrs: - args.rollout_external = False - return + raise ValueError("apply_external_engine_info_to_args requires --rollout-external-engine-addrs.") - args.rollout_external = True infos = discover_external_engines(addrs) if not infos: raise ValueError("--rollout-external-engine-addrs did not contain any engines.") diff --git a/slime/backends/sglang_utils/sglang_engine.py b/slime/backends/sglang_utils/sglang_engine.py index e04eb20b5d..a4366d1180 100644 --- a/slime/backends/sglang_utils/sglang_engine.py +++ b/slime/backends/sglang_utils/sglang_engine.py @@ -13,6 +13,7 @@ from sglang.srt.utils import kill_process_tree from urllib3.exceptions import NewConnectionError +from slime.backends.sglang_utils.external import get_server_info from slime.ray.ray_actor import RayActor from slime.utils.http_utils import get_host_info @@ -169,11 +170,6 @@ def _format_v6_uri(addr): def _init_external(self, expect_server_args, external_engine_need_check_fields): logger.info(f"Use external SGLang engine (rank={self.rank}, expect_server_args={expect_server_args})") - def _get_actual_server_args(): - response = requests.get(f"http://{self.server_host}:{self.server_port}/get_server_info") - response.raise_for_status() - return response.json() - def _sanity_check_server_args(actual_server_args, expect_server_args): for name in external_engine_need_check_fields: expect_value = expect_server_args.get(name) @@ -182,12 +178,7 @@ def _sanity_check_server_args(actual_server_args, expect_server_args): actual_value == expect_value ), f"{name=} {expect_value=} {actual_value=} {expect_server_args=} {actual_server_args=}" - _wait_server_healthy( - base_url=f"http://{self.server_host}:{self.server_port}", - api_key=None, - is_process_alive=lambda: True, - ) - actual_server_args = _get_actual_server_args() + actual_server_args = get_server_info(f"http://{self.server_host}:{self.server_port}") _sanity_check_server_args(actual_server_args, expect_server_args) self._register_to_router(expect_server_args) diff --git a/slime/ray/actor_group.py b/slime/ray/actor_group.py index 15519ccebe..27ad610ad9 100644 --- a/slime/ray/actor_group.py +++ b/slime/ray/actor_group.py @@ -44,14 +44,11 @@ def __init__( self._allocate_gpus_for_actor(pg, num_gpus_per_actor) def _allocate_gpus_for_actor(self, pg, num_gpus_per_actor): - world_size = 1 if self.args.debug_rollout_only else self._num_nodes * self._num_gpus_per_node + world_size = self._num_nodes * self._num_gpus_per_node # Use placement group to lock resources for models of same type - if self.args.debug_rollout_only: - pg, reordered_bundle_indices = None, [] - else: - assert pg is not None - pg, reordered_bundle_indices, _reordered_gpu_ids = pg + assert pg is not None + pg, reordered_bundle_indices, _reordered_gpu_ids = pg env_vars = { # because sglang will always set NCCL_CUMEM_ENABLE to 0 @@ -92,23 +89,20 @@ def _allocate_gpus_for_actor(self, pg, num_gpus_per_actor): actor_impl = MegatronTrainRayActor - default_num_gpus = 0 if self.args.debug_rollout_only else 1 - TrainRayActor = ray.remote(num_gpus=default_num_gpus, runtime_env={"env_vars": env_vars})(actor_impl) + TrainRayActor = ray.remote(num_gpus=1, runtime_env={"env_vars": env_vars})(actor_impl) # Create worker actors self._actor_handlers = [] master_addr, master_port = None, None for rank in range(world_size): - actor_options = { - "num_cpus": num_gpus_per_actor, - "num_gpus": 0 if self.args.debug_rollout_only else num_gpus_per_actor, - } - if not self.args.debug_rollout_only: - actor_options["scheduling_strategy"] = PlacementGroupSchedulingStrategy( + actor = TrainRayActor.options( + num_cpus=num_gpus_per_actor, + num_gpus=num_gpus_per_actor, + scheduling_strategy=PlacementGroupSchedulingStrategy( placement_group=pg, placement_group_bundle_index=reordered_bundle_indices[rank], - ) - actor = TrainRayActor.options(**actor_options).remote(world_size, rank, master_addr, master_port) + ), + ).remote(world_size, rank, master_addr, master_port) if rank == 0: master_addr, master_port = ray.get(actor.get_master_addr_and_port.remote()) self._actor_handlers.append(actor) diff --git a/slime/ray/placement_group.py b/slime/ray/placement_group.py index 96014928d7..9499cdd57f 100644 --- a/slime/ray/placement_group.py +++ b/slime/ray/placement_group.py @@ -7,7 +7,6 @@ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from .actor_group import RayTrainGroup -from .rollout import RolloutManager logger = logging.getLogger(__name__) @@ -196,6 +195,8 @@ def create_training_models(args, pgs, rollout_manager): def create_rollout_manager(args, pg): + from .rollout import RolloutManager + rollout_manager = RolloutManager.options( num_cpus=1, num_gpus=0, diff --git a/slime/ray/train_actor.py b/slime/ray/train_actor.py index b84b89d92e..a8ba6ddc64 100644 --- a/slime/ray/train_actor.py +++ b/slime/ray/train_actor.py @@ -18,14 +18,11 @@ def get_local_gpu_id(): - gpu_ids = ray.get_gpu_ids() - if not gpu_ids: - return 0 cvd = os.environ.get("CUDA_VISIBLE_DEVICES", None) if cvd is None: - return gpu_ids[0] + return ray.get_gpu_ids()[0] else: - return cvd.split(",").index(str(gpu_ids[0])) + return cvd.split(",").index(str(ray.get_gpu_ids()[0])) class TrainRayActor(RayActor): diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index cd107fce1f..f562f8c474 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -507,12 +507,6 @@ def add_rollout_arguments(parser): "It may be helpful for updating loss mask." ), ) - parser.add_argument( - "--rollout-external", - action="store_true", - default=None, - help=argparse.SUPPRESS, - ) parser.add_argument( "--rollout-external-engine-addrs", type=str, @@ -1787,14 +1781,9 @@ def slime_validate_args(args): ) args.debug_train_only = True - if getattr(args, "rollout_external", None) is not None: - logger.warning( - "--rollout-external is deprecated and ignored. " - "Set --rollout-external-engine-addrs to use pre-launched external SGLang engines." - ) args.rollout_external = args.rollout_external_engine_addrs is not None - if not args.debug_train_only: + if args.rollout_external and not args.debug_train_only: apply_external_engine_info_to_args(args, logger=logger) args.use_critic = args.advantage_estimator == "ppo" diff --git a/tests/test_external_sglang_engines.py b/tests/test_external_sglang_engines.py index 74286d20cb..ec3bd38a1f 100644 --- a/tests/test_external_sglang_engines.py +++ b/tests/test_external_sglang_engines.py @@ -121,6 +121,7 @@ def fake_get(url, timeout): monkeypatch.setattr("slime.backends.sglang_utils.external.requests.get", fake_get) args = Namespace( + rollout_external=True, rollout_external_engine_addrs=["prefill:10090", "decode:10091"], rollout_num_gpus=None, rollout_num_gpus_per_engine=1, @@ -157,6 +158,7 @@ def fake_get(url, timeout): monkeypatch.setattr("slime.backends.sglang_utils.external.requests.get", fake_get) args = Namespace( + rollout_external=True, rollout_external_engine_addrs=["regular:10090"], router_pd_disaggregation=True, ) @@ -169,12 +171,11 @@ def fake_get(url, timeout): assert args.rollout_num_engines == 1 -def test_apply_external_engine_info_no_addrs_disables_external(): +def test_apply_external_engine_info_requires_addrs(): args = Namespace(rollout_external_engine_addrs=None) - apply_external_engine_info_to_args(args) - - assert args.rollout_external is False + with pytest.raises(ValueError, match="rollout-external-engine-addrs"): + apply_external_engine_info_to_args(args) if __name__ == "__main__": diff --git a/tests/test_qwen3_4B_external_pd.py b/tests/test_qwen3_4B_external_pd.py index 0158403bee..d39f22c625 100644 --- a/tests/test_qwen3_4B_external_pd.py +++ b/tests/test_qwen3_4B_external_pd.py @@ -27,6 +27,7 @@ MODEL_NAME = "Qwen3-4B" MODEL_TYPE = "qwen3-4B" +TORCH_DIST_CKPT = f"/root/{MODEL_NAME}_torch_dist" NUM_GPUS = 6 NUM_TRAIN_GPUS = 4 NUM_PREFILL_ENGINES = 1 @@ -126,7 +127,6 @@ def prepare(): model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_TRAIN_GPUS, - dir_dst="/dev/shm", ) @@ -255,7 +255,7 @@ def launch_external_engines(): delta_dir_cm = tempfile.TemporaryDirectory(prefix="slime_external_pd_delta_") delta_dir = delta_dir_cm.name try: - ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load /dev/shm/{MODEL_NAME}_torch_dist " + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load {TORCH_DIST_CKPT} " rollout_args = ( "--prompt-data /root/datasets/gsm8k/train.parquet " From cdc81401758485c554a3e64b834ff1890423345e Mon Sep 17 00:00:00 2001 From: Zilin Zhu Date: Thu, 4 Jun 2026 11:14:25 +0000 Subject: [PATCH 14/17] minor refactor --- slime/backends/sglang_utils/external.py | 186 +++++++++++++++--------- slime/ray/rollout.py | 146 ++----------------- tests/test_external_sglang_engines.py | 41 +----- 3 files changed, 131 insertions(+), 242 deletions(-) diff --git a/slime/backends/sglang_utils/external.py b/slime/backends/sglang_utils/external.py index 573c6b5a5a..9f7a307b17 100644 --- a/slime/backends/sglang_utils/external.py +++ b/slime/backends/sglang_utils/external.py @@ -3,10 +3,13 @@ from __future__ import annotations import dataclasses +import logging from urllib.parse import urlparse import requests +logger = logging.getLogger(__name__) + @dataclasses.dataclass(frozen=True) class ExternalEngineInfo: @@ -26,55 +29,6 @@ def to_dict(self) -> dict: return dataclasses.asdict(self) -@dataclasses.dataclass(frozen=True) -class ExternalServerGroupInfo: - worker_type: str - num_gpus_per_engine: int - engine_infos: tuple[ExternalEngineInfo, ...] - - @property - def num_gpus(self) -> int: - return sum(info.num_gpus for info in self.engine_infos) - - def to_dict(self) -> dict: - return { - "worker_type": self.worker_type, - "num_gpus": self.num_gpus, - "num_gpus_per_engine": self.num_gpus_per_engine, - "engine_infos": [info.to_dict() for info in self.engine_infos], - } - - -@dataclasses.dataclass(frozen=True) -class ExternalModelInfo: - name: str - server_groups: tuple[ExternalServerGroupInfo, ...] - update_weights: bool = True - - @property - def has_pd_disaggregation(self) -> bool: - return any(g.worker_type in ("prefill", "decode") for g in self.server_groups) - - @property - def engine_infos(self) -> list[ExternalEngineInfo]: - return [info for group in self.server_groups for info in group.engine_infos] - - @property - def total_num_gpus(self) -> int: - return sum(group.num_gpus for group in self.server_groups) - - @property - def num_engines(self) -> int: - return sum(len(group.engine_infos) for group in self.server_groups) - - def to_dict(self) -> dict: - return { - "name": self.name, - "server_groups": [group.to_dict() for group in self.server_groups], - "update_weights": self.update_weights, - } - - def normalize_external_engine_addr(addr: str) -> str: """Normalize ``host:port`` or ``http://host:port`` to an HTTP base URL.""" if "://" not in addr: @@ -93,21 +47,16 @@ def external_engine_info_from_dict(data: dict) -> ExternalEngineInfo: return ExternalEngineInfo(**data) -def build_external_model_info(infos: list[ExternalEngineInfo], name: str = "default") -> ExternalModelInfo: - specs_by_topology: dict[tuple[str, int], list[ExternalEngineInfo]] = {} - for info in infos: - key = (info.worker_type, info.num_gpus) - specs_by_topology.setdefault(key, []).append(info) - - server_groups = tuple( - ExternalServerGroupInfo( - worker_type=worker_type, - num_gpus_per_engine=num_gpus_per_engine, - engine_infos=tuple(group_infos), - ) - for (worker_type, num_gpus_per_engine), group_infos in specs_by_topology.items() - ) - return ExternalModelInfo(name=name, server_groups=server_groups) +def external_engine_init_kwargs(info: ExternalEngineInfo) -> dict: + init_kwargs = { + "dist_init_addr": f"{info.host}:{info.port}", + "nccl_port": None, + "host": info.host, + "port": info.port, + } + if info.worker_type == "prefill": + init_kwargs["disaggregation_bootstrap_port"] = info.disaggregation_bootstrap_port + return init_kwargs def get_server_info(url: str, timeout: float = 30.0) -> dict: @@ -170,10 +119,8 @@ def apply_external_engine_info_to_args(args, logger=None) -> None: raise ValueError("--rollout-external-engine-addrs did not contain any engines.") args.rollout_external_engine_infos = [info.to_dict() for info in infos] - model_info = build_external_model_info(infos) - args.rollout_external_model_info = model_info.to_dict() - args.rollout_num_engines = model_info.num_engines - args.rollout_num_gpus = model_info.total_num_gpus + args.rollout_num_engines = len(infos) + args.rollout_num_gpus = sum(info.num_gpus for info in infos) if logger is not None: summary = [ @@ -186,3 +133,106 @@ def apply_external_engine_info_to_args(args, logger=None) -> None: for info in infos ] logger.info(f"Detected external SGLang engines: {summary}") + + +@dataclasses.dataclass +class ExternalRolloutServer: + """Rollout server backed by pre-launched external SGLang engines.""" + + engines: list + engine_gpu_counts: list[int] + engine_gpu_offsets: list[int] + router_ip: str | None = None + router_port: int | None = None + model_name: str = "default" + update_weights: bool = True + num_new_engines: int = 0 + server_groups: list = dataclasses.field(default_factory=list) + + @property + def all_engines(self): + return self.engines + + def recover(self): + logger.warning("Fault tolerance is not supported for external rollout engines; skip recover.") + + def offload(self): + return [] + + def onload(self, tags: list[str] | None = None): + return [] + + def onload_weights(self): + return [] + + def onload_kv(self): + return [] + + +def external_engine_infos_from_args(args) -> list[ExternalEngineInfo]: + raw_infos = getattr(args, "rollout_external_engine_infos", None) + if raw_infos is None: + addrs = getattr(args, "rollout_external_engine_addrs", None) + if not addrs: + raise RuntimeError("External rollout requires --rollout-external-engine-addrs.") + infos = discover_external_engines(addrs) + args.rollout_external_engine_infos = [info.to_dict() for info in infos] + else: + infos = [external_engine_info_from_dict(info) if isinstance(info, dict) else info for info in raw_infos] + args.rollout_num_engines = len(infos) + args.rollout_num_gpus = sum(info.num_gpus for info in infos) + return infos + + +def start_external_rollout_servers(args, *, start_router) -> dict[str, ExternalRolloutServer]: + import ray + + from slime.backends.sglang_utils.sglang_engine import SGLangEngine + + infos = external_engine_infos_from_args(args) + router_ip, router_port = start_router(args, has_pd_disaggregation=any(info.is_pd_worker for info in infos)) + args.sglang_router_ip = router_ip + args.sglang_router_port = router_port + + engines = [] + engine_gpu_counts = [] + engine_gpu_offsets = [] + init_handles = [] + RolloutRayActor = ray.remote(SGLangEngine) + gpu_offset = 0 + for rank, info in enumerate(infos): + rollout_engine = RolloutRayActor.options(num_cpus=0.2, num_gpus=0).remote( + args=args, + rank=rank, + worker_type=info.worker_type, + base_gpu_id=0, + num_gpus_per_engine=info.num_gpus, + ) + engines.append(rollout_engine) + engine_gpu_counts.append(info.num_gpus) + engine_gpu_offsets.append(gpu_offset) + gpu_offset += info.num_gpus + init_handles.append( + rollout_engine.init.remote( + **external_engine_init_kwargs(info), + router_ip=router_ip, + router_port=router_port, + ) + ) + + if init_handles: + ray.get(init_handles) + + args.sglang_model_routers = {"default": (router_ip, router_port)} + return { + "default": ExternalRolloutServer( + engines=engines, + engine_gpu_counts=engine_gpu_counts, + engine_gpu_offsets=engine_gpu_offsets, + router_ip=router_ip, + router_port=router_port, + model_name="default", + update_weights=True, + num_new_engines=len(engines), + ) + } diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index ac57f2b05a..9743d12235 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -14,13 +14,7 @@ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS -from slime.backends.sglang_utils.external import ( - ExternalEngineInfo, - ExternalModelInfo, - build_external_model_info, - discover_external_engines, - external_engine_info_from_dict, -) +from slime.backends.sglang_utils.external import start_external_rollout_servers from slime.backends.sglang_utils.sglang_config import ModelConfig, ServerGroupConfig, SglangConfig from slime.backends.sglang_utils.sglang_engine import SGLangEngine from slime.rollout.base_types import call_rollout_fn @@ -65,19 +59,14 @@ class ServerGroup: model_path: str | None = None # checkpoint path for update_weights_from_disk router_ip: str | None = None router_port: int | None = None - external_worker_specs: list[ExternalEngineInfo] = dataclasses.field(default_factory=list) @property def nodes_per_engine(self): - if self.args.rollout_external: - return 1 return max(1, self.num_gpus_per_engine // self.args.num_gpus_per_node) @property def engines(self): """Node-0 engines only (for multi-node serving).""" - if self.args.rollout_external: - return [engine for engine in self.all_engines if engine is not None] return self.all_engines[:: self.nodes_per_engine] def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[list, dict[int, int]]: @@ -97,9 +86,6 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis self.num_new_engines = 0 return [], port_cursors - if self.args.rollout_external: - return self._start_external_proxy_engines(port_cursors) - num_gpu_per_engine = min(self.num_gpus_per_engine, self.args.num_gpus_per_node) pg, reordered_bundle_indices, reordered_gpu_ids = self.pg @@ -195,39 +181,6 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis ] return init_handles, port_cursors - def _start_external_proxy_engines(self, port_cursors: dict[int, int]) -> tuple[list, dict[int, int]]: - """Create CPU-only proxy actors for pre-launched external workers.""" - assert self.external_worker_specs, "external_worker_specs must be populated for rollout_external." - - RolloutRayActor = ray.remote(SGLangEngine) - rollout_engines = [] - for i, spec in enumerate(self.external_worker_specs): - if self.all_engines[i] is not None: - continue - - global_rank = self.rank_offset + i - rollout_engine = RolloutRayActor.options(num_cpus=0.2, num_gpus=0).remote( - self.args, - rank=global_rank, - worker_type=spec.worker_type, - base_gpu_id=0, - sglang_overrides=self.sglang_overrides, - num_gpus_per_engine=self.num_gpus_per_engine, - ) - rollout_engines.append((global_rank, rollout_engine, spec)) - self.all_engines[i] = rollout_engine - - self.num_new_engines = len(rollout_engines) - init_handles = [ - engine.init.remote( - **_external_engine_init_kwargs(spec), - router_ip=self.router_ip, - router_port=self.router_port, - ) - for _rank, engine, spec in rollout_engines - ] - return init_handles, port_cursors - def offload(self): """Fire release_memory_occupation on all engines (non-blocking). @@ -427,7 +380,7 @@ def __init__(self, args, pg): logger.info(f"import {self.args.eval_function_path} as eval_generate_rollout function.") if self.args.debug_train_only: - self.servers: dict[str, RolloutServer] = {} + self.servers: dict[str, Any] = {} else: init_http_client(args) self.servers = start_rollout_servers(args, pg) @@ -470,7 +423,12 @@ def _try_ci_fault_injection(self): # Only inject fault once self._ci_fault_injection_pending = False - if self.server and self.server.server_groups[0].all_engines and self.server.server_groups[0].all_engines[0]: + if ( + self.server + and self.server.server_groups + and self.server.server_groups[0].all_engines + and self.server.server_groups[0].all_engines[0] + ): logger.info("CI Fault Injection: Simulating crash on engine 0 during generate") try: # This will cause the ray actor to exit @@ -489,13 +447,13 @@ def dispose(self): logging_utils.finish_tracking(self.args) @property - def server(self) -> RolloutServer | None: + def server(self) -> Any | None: """Default server (first model). For backward compatibility.""" if not self.servers: return None return next(iter(self.servers.values())) - def _get_updatable_server(self) -> RolloutServer | None: + def _get_updatable_server(self) -> Any | None: """Return the server with ``update_weights=True``. When multiple updatable servers exist, returns the first one @@ -888,18 +846,6 @@ def _validate_rollout_id_annotated(node, depth=0): _validate_rollout_id_annotated(item, depth + 1) -def _external_engine_init_kwargs(spec: ExternalEngineInfo) -> dict: - init_kwargs = { - "dist_init_addr": f"{spec.host}:{spec.port}", - "nccl_port": None, - "host": spec.host, - "port": spec.port, - } - if spec.worker_type == "prefill": - init_kwargs["disaggregation_bootstrap_port"] = spec.disaggregation_bootstrap_port - return init_kwargs - - def _allocate_rollout_engine_addr_and_ports_normal( *, args, @@ -1059,75 +1005,7 @@ def _compute_megatron_num_gpus(args) -> int: return num -def _external_model_from_args(args) -> ExternalModelInfo: - raw_infos = getattr(args, "rollout_external_engine_infos", None) - if raw_infos is None: - addrs = getattr(args, "rollout_external_engine_addrs", None) - if not addrs: - raise RuntimeError("External rollout requires --rollout-external-engine-addrs.") - infos = discover_external_engines(addrs) - args.rollout_external_engine_infos = [info.to_dict() for info in infos] - else: - infos = [external_engine_info_from_dict(info) if isinstance(info, dict) else info for info in raw_infos] - model_info = build_external_model_info(infos) - args.rollout_external_model_info = model_info.to_dict() - args.rollout_num_engines = model_info.num_engines - args.rollout_num_gpus = model_info.total_num_gpus - return model_info - - -def _start_external_rollout_servers(args, pg) -> dict[str, RolloutServer]: - model_cfg = _external_model_from_args(args) - router_ip, router_port = _start_router(args, has_pd_disaggregation=model_cfg.has_pd_disaggregation) - args.sglang_router_ip = router_ip - args.sglang_router_port = router_port - - server_groups = [] - engine_offset = 0 - gpu_offset = 0 - init_handles = [] - for group_cfg in model_cfg.server_groups: - group_specs = list(group_cfg.engine_infos) - num_gpus = group_cfg.num_gpus_per_engine - group = ServerGroup( - args=args, - pg=pg, - all_engines=[None] * len(group_specs), - num_gpus_per_engine=num_gpus, - num_new_engines=0, - worker_type=group_cfg.worker_type, - rank_offset=engine_offset, - gpu_offset=gpu_offset, - sglang_overrides={}, - needs_offload=False, - model_path=args.hf_checkpoint, - router_ip=router_ip, - router_port=router_port, - external_worker_specs=group_specs, - ) - handles, _ = group.start_engines({}) - init_handles.extend(handles) - server_groups.append(group) - - engine_offset += len(group_specs) - gpu_offset += len(group_specs) * num_gpus - - if init_handles: - ray.get(init_handles) - - args.sglang_model_routers = {"default": (router_ip, router_port)} - return { - "default": RolloutServer( - server_groups=server_groups, - router_ip=router_ip, - router_port=router_port, - model_name="default", - update_weights=True, - ) - } - - -def start_rollout_servers(args, pg) -> dict[str, RolloutServer]: +def start_rollout_servers(args, pg) -> dict[str, Any]: """Start rollout servers: one per model, each with its own router. Each model defined in the sglang config gets its own router and set @@ -1141,7 +1019,7 @@ def start_rollout_servers(args, pg) -> dict[str, RolloutServer]: as the HTTP client is shared across all servers. """ if args.rollout_external: - return _start_external_rollout_servers(args, pg) + return start_external_rollout_servers(args, start_router=_start_router) config = _resolve_sglang_config(args) diff --git a/tests/test_external_sglang_engines.py b/tests/test_external_sglang_engines.py index ec3bd38a1f..704bcc7a80 100644 --- a/tests/test_external_sglang_engines.py +++ b/tests/test_external_sglang_engines.py @@ -8,12 +8,7 @@ if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) -from slime.backends.sglang_utils.external import ( - ExternalEngineInfo, - apply_external_engine_info_to_args, - build_external_model_info, - discover_external_engines, -) +from slime.backends.sglang_utils.external import apply_external_engine_info_to_args, discover_external_engines from slime.utils.http_utils import get_rollout_num_engines NUM_GPUS = 0 @@ -63,40 +58,6 @@ def fake_get(url, timeout): assert info.server_info["ep_size"] == 4 -def test_build_external_model_info_groups_engines_by_topology(): - infos = [ - ExternalEngineInfo( - url="http://prefill-0:10090", - host="prefill-0", - port=10090, - worker_type="prefill", - num_gpus=2, - ), - ExternalEngineInfo( - url="http://prefill-1:10091", - host="prefill-1", - port=10091, - worker_type="prefill", - num_gpus=2, - ), - ExternalEngineInfo(url="http://decode-0:10092", host="decode-0", port=10092, worker_type="decode", num_gpus=4), - ] - - model_info = build_external_model_info(infos) - - assert model_info.name == "default" - assert model_info.has_pd_disaggregation is True - assert model_info.num_engines == 3 - assert model_info.total_num_gpus == 8 - assert [ - (group.worker_type, group.num_gpus, group.num_gpus_per_engine, len(group.engine_infos)) - for group in model_info.server_groups - ] == [ - ("prefill", 4, 2, 2), - ("decode", 4, 4, 1), - ] - - def test_apply_external_engine_info_handles_pd(monkeypatch): payloads = { "http://prefill:10090/server_info": { From 95fe1a3a03f9dc7a57df70dee2ef79ff2228e786 Mon Sep 17 00:00:00 2001 From: Zilin Zhu Date: Thu, 4 Jun 2026 11:20:16 +0000 Subject: [PATCH 15/17] cleanup --- slime/backends/sglang_utils/arguments.py | 5 ----- slime/backends/sglang_utils/external.py | 15 +++++---------- 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/slime/backends/sglang_utils/arguments.py b/slime/backends/sglang_utils/arguments.py index 17385bd2ea..761086d8bb 100644 --- a/slime/backends/sglang_utils/arguments.py +++ b/slime/backends/sglang_utils/arguments.py @@ -157,11 +157,6 @@ def validate_args(args): if getattr(args, "sglang_router_ip", None): args.sglang_router_ip = _wrap_ipv6(args.sglang_router_ip) - if getattr(args, "rollout_external", False) and args.sglang_router_ip is not None: - assert ( - args.sglang_router_port is not None - ), "--sglang-router-port must be set with --sglang-router-ip when using --rollout-external-engine-addrs." - # Mutual-exclusion checks for PD disaggregation / sglang-config. assert not ( getattr(args, "prefill_num_servers", None) is not None and getattr(args, "rollout_external", False) diff --git a/slime/backends/sglang_utils/external.py b/slime/backends/sglang_utils/external.py index 9f7a307b17..23b3a2c564 100644 --- a/slime/backends/sglang_utils/external.py +++ b/slime/backends/sglang_utils/external.py @@ -172,16 +172,11 @@ def onload_kv(self): def external_engine_infos_from_args(args) -> list[ExternalEngineInfo]: raw_infos = getattr(args, "rollout_external_engine_infos", None) if raw_infos is None: - addrs = getattr(args, "rollout_external_engine_addrs", None) - if not addrs: - raise RuntimeError("External rollout requires --rollout-external-engine-addrs.") - infos = discover_external_engines(addrs) - args.rollout_external_engine_infos = [info.to_dict() for info in infos] - else: - infos = [external_engine_info_from_dict(info) if isinstance(info, dict) else info for info in raw_infos] - args.rollout_num_engines = len(infos) - args.rollout_num_gpus = sum(info.num_gpus for info in infos) - return infos + raise RuntimeError( + "External rollout engine info is missing. " + "apply_external_engine_info_to_args must run before starting external rollout servers." + ) + return [external_engine_info_from_dict(info) if isinstance(info, dict) else info for info in raw_infos] def start_external_rollout_servers(args, *, start_router) -> dict[str, ExternalRolloutServer]: From ba252c44c8e7e9893ce8be90681bc6ac4f29bcfb Mon Sep 17 00:00:00 2001 From: Zilin Zhu Date: Thu, 4 Jun 2026 11:27:36 +0000 Subject: [PATCH 16/17] cleanup --- slime/backends/sglang_utils/external.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/slime/backends/sglang_utils/external.py b/slime/backends/sglang_utils/external.py index 23b3a2c564..7499bb907d 100644 --- a/slime/backends/sglang_utils/external.py +++ b/slime/backends/sglang_utils/external.py @@ -43,10 +43,6 @@ def normalize_external_engine_addr(addr: str) -> str: return addr -def external_engine_info_from_dict(data: dict) -> ExternalEngineInfo: - return ExternalEngineInfo(**data) - - def external_engine_init_kwargs(info: ExternalEngineInfo) -> dict: init_kwargs = { "dist_init_addr": f"{info.host}:{info.port}", @@ -176,7 +172,7 @@ def external_engine_infos_from_args(args) -> list[ExternalEngineInfo]: "External rollout engine info is missing. " "apply_external_engine_info_to_args must run before starting external rollout servers." ) - return [external_engine_info_from_dict(info) if isinstance(info, dict) else info for info in raw_infos] + return [ExternalEngineInfo(**info) if isinstance(info, dict) else info for info in raw_infos] def start_external_rollout_servers(args, *, start_router) -> dict[str, ExternalRolloutServer]: From 49ac0c9fa9d7f57aea6f98ebf4fec76d8fb21fde Mon Sep 17 00:00:00 2001 From: Zilin Zhu Date: Thu, 4 Jun 2026 11:42:25 +0000 Subject: [PATCH 17/17] fix test --- tests/test_qwen3_4B_external_pd.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_qwen3_4B_external_pd.py b/tests/test_qwen3_4B_external_pd.py index d39f22c625..68078603c6 100644 --- a/tests/test_qwen3_4B_external_pd.py +++ b/tests/test_qwen3_4B_external_pd.py @@ -27,7 +27,7 @@ MODEL_NAME = "Qwen3-4B" MODEL_TYPE = "qwen3-4B" -TORCH_DIST_CKPT = f"/root/{MODEL_NAME}_torch_dist" +TORCH_DIST_CKPT = f"/root/models/{MODEL_NAME}_torch_dist" NUM_GPUS = 6 NUM_TRAIN_GPUS = 4 NUM_PREFILL_ENGINES = 1 @@ -127,6 +127,7 @@ def prepare(): model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_TRAIN_GPUS, + dir_dst="/root/models", )