diff --git a/.claude/skills/add-tests-and-ci/SKILL.md b/.claude/skills/add-tests-and-ci/SKILL.md index d4a0a7397..000ea983b 100644 --- a/.claude/skills/add-tests-and-ci/SKILL.md +++ b/.claude/skills/add-tests-and-ci/SKILL.md @@ -40,15 +40,33 @@ if __name__ == "__main__": - `run-ci-changed` extracts a top-level `NUM_GPUS = ` constant from added/modified `tests/test_*.py` and `tests/plugin_contracts/test_*.py`; if missing, it defaults to 8 GPUs. Set `NUM_GPUS = 0` for CPU-only tests. - For GPU/e2e tests, follow the nearby file pattern (`prepare()`, `execute()`, `NUM_GPUS`, and any model/dataset constants). -### Step 3: Run Local Validation +### Step 3: Register Tests in GitHub CI + +Whenever adding, moving, or renaming a test file, update the GitHub workflow template before finishing: + +1. Add the test to the appropriate matrix in `.github/workflows/pr-test.yml.j2`. + - CPU-only pytest/unit tests usually belong in `cpu-unittest` with `num_gpus: 0`. + - GPU/e2e tests should be placed beside the nearest similar model/path test with the matching `num_gpus` and environment fields. +2. Regenerate workflows: + +```bash +python .github/workflows/generate_github_workflows.py +``` + +3. Include both `.github/workflows/pr-test.yml.j2` and the generated `.github/workflows/pr-test.yml` in the change set. + +Only skip fixed matrix registration when the test is intentionally helper-only or manually invoked; state that reason in the final response. + +### Step 4: Run Local Validation - Run the exact existing test files you changed, if any. +- For new registered tests, run the same shape CI will use, for example `python tests/test_new_file.py`. - Run repository-wide checks only when they are already part of the task or workflow. - Avoid documenting placeholder test commands that may not exist in the current tree. -### Step 4: Update Workflow Template Correctly +### Step 5: Keep Workflow Template as Source of Truth -For CI workflow changes: +For CI workflow changes unrelated to a new, moved, or renamed test: 1. Edit `.github/workflows/pr-test.yml.j2` 2. Regenerate workflows: @@ -59,11 +77,12 @@ python .github/workflows/generate_github_workflows.py 3. Include both the template and generated workflow file in the change set (`.j2` and `.yml`). If the user asked for a commit, commit both. -### Step 5: Provide Verifiable PR Notes +### Step 6: Provide Verifiable PR Notes Include: - Which tests were added/changed +- Where each new/renamed test was registered in `.github/workflows/pr-test.yml.j2` - Exact commands executed - GPU assumptions for each test path - Why this coverage protects against regression @@ -71,6 +90,7 @@ Include: ## Common Mistakes - Editing generated workflow file only +- Relying on `run-ci-changed` discovery for a new test that should run in the regular PR matrix - Forgetting `NUM_GPUS = 0` on a CPU-only changed test, causing `run-ci-changed` to default to 8 GPUs - Adding a CPU pytest file that passes under `pytest tests/foo.py` but fails under CI's `python tests/foo.py` - Adding tests without following existing constants/conventions diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index f329d10c2..2cf852c58 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -205,7 +205,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"enable_eval": "0", "num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_glm4.7_30B_A3B_pd_mooncake.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"num_gpus": 8, "test_file": "test_qwen3.6_35B_A3B_pd_mooncake.py", "use_deepep": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo_disaggregate.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo_train_critic_only.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "test_mimo_7B_mtp_only_grad.py"}, {"num_gpus": 8, "test_file": "test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 8, "test_file": "test_qwen2.5_0.5B_debug_rollout_then_train.py"}, {"num_gpus": 8, "test_file": "test_qwen2.5_0.5B_opd_sglang.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_fully_async_short.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_streaming_partial_rollout.py"}, {"num_gpus": 4, "test_file": "test_qwen3.5_0.8B_gsm8k_short.py"}, {"num_gpus": 4, "test_file": "test_qwen3.5_0.8B_gsm8k_async_short.py"}, {"num_gpus": 8, "test_args": "--save-optimizer gpu --load-optimizer gpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer gpu --load-optimizer cpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer cpu --load-optimizer cpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer cpu --load-optimizer gpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--async-save", "test_file": "test_qwen3_4B_ckpt.py"}] + info: [{"enable_eval": "0", "num_gpus": 8, "test_file": "test_quick_start_glm4_9B.py"}, {"num_gpus": 8, "test_file": "test_glm4.7_30B_A3B_pd_mooncake.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"num_gpus": 8, "test_file": "test_qwen3.6_35B_A3B_pd_mooncake.py", "use_deepep": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_30B_A3B_r3.py", "use_deepep": "1", "use_fp8_rollout": "1"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo_disaggregate.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_qwen3_4B_ppo_train_critic_only.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_moonlight_16B_A3B.py"}, {"enable_eval": "0", "num_gpus": 8, "test_file": "test_moonlight_16B_A3B_r3.py"}, {"num_gpus": 8, "test_file": "test_mimo_7B_mtp_only_grad.py"}, {"num_gpus": 8, "test_file": "test_qwen3_0.6B_parallel_check.py"}, {"num_gpus": 8, "test_file": "test_qwen2.5_0.5B_debug_rollout_then_train.py"}, {"num_gpus": 8, "test_file": "test_qwen2.5_0.5B_opd_sglang.py"}, {"num_gpus": 6, "test_file": "test_qwen3_4B_external_pd.py"}, {"num_gpus": 4, "test_file": "test_qwen2.5_0.5B_fully_async_short.py"}, {"num_gpus": 8, "test_file": "test_qwen3_4B_streaming_partial_rollout.py"}, {"num_gpus": 4, "test_file": "test_qwen3.5_0.8B_gsm8k_short.py"}, {"num_gpus": 4, "test_file": "test_qwen3.5_0.8B_gsm8k_async_short.py"}, {"num_gpus": 8, "test_args": "--save-optimizer gpu --load-optimizer gpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer gpu --load-optimizer cpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer cpu --load-optimizer cpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--save-optimizer cpu --load-optimizer gpu", "test_file": "test_qwen3_4B_ckpt.py"}, {"num_gpus": 8, "test_args": "--async-save", "test_file": "test_qwen3_4B_ckpt.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -454,7 +454,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_agent_trajectory.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}] + info: [{"num_gpus": 0, "test_file": "test_megatron_argument_validation.py"}, {"num_gpus": 0, "test_file": "test_dp_schedule.py"}, {"num_gpus": 0, "test_file": "test_cp_utils.py"}, {"num_gpus": 0, "test_file": "test_metric_report.py"}, {"num_gpus": 0, "test_file": "test_metric_report_dist.py"}, {"num_gpus": 0, "test_file": "test_loss_cp_invariance.py"}, {"num_gpus": 0, "test_file": "test_value_temperature.py"}, {"num_gpus": 0, "test_file": "test_rm_f1.py"}, {"num_gpus": 0, "test_file": "test_rm_gpqa.py"}, {"num_gpus": 0, "test_file": "test_rm_math.py"}, {"num_gpus": 0, "test_file": "test_rm_math_dapo.py"}, {"num_gpus": 0, "test_file": "test_rm_deepscaler.py"}, {"num_gpus": 0, "test_file": "test_sample.py"}, {"num_gpus": 0, "test_file": "test_agent_trajectory.py"}, {"num_gpus": 0, "test_file": "test_rollout_validation.py"}, {"num_gpus": 0, "test_file": "test_placement_group.py"}, {"num_gpus": 0, "test_file": "test_external_sglang_engines.py"}, {"num_gpus": 0, "test_file": "utils/test_hf_checkpoint_saver.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_rollout_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_runtime_hook_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_path_loading_contracts.py"}, {"num_gpus": 0, "test_file": "plugin_contracts/test_plugin_generate_contracts.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -481,7 +481,7 @@ jobs: shell: bash run: | pip install torch --index-url https://download.pytorch.org/whl/cpu - pip install pytest numpy packaging pyyaml omegaconf tqdm httpx pybase64 pylatexenc sympy aiohttp pillow safetensors + pip install pytest numpy packaging pyyaml omegaconf tqdm httpx requests ray pybase64 pylatexenc sympy aiohttp pillow safetensors - name: Install @@ -547,7 +547,7 @@ jobs: shell: bash run: | pip install torch --index-url https://download.pytorch.org/whl/cpu - pip install pytest numpy packaging pyyaml omegaconf tqdm httpx pybase64 pylatexenc sympy aiohttp pillow safetensors + pip install pytest numpy packaging pyyaml omegaconf tqdm httpx requests ray pybase64 pylatexenc sympy aiohttp pillow safetensors pip install openai openai-agents anthropic diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index 0122182bd..58c627e0e 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -35,6 +35,7 @@ {'test_file': 'test_qwen3_0.6B_parallel_check.py', 'num_gpus': 8}, {'test_file': 'test_qwen2.5_0.5B_debug_rollout_then_train.py', 'num_gpus': 8}, {'test_file': 'test_qwen2.5_0.5B_opd_sglang.py', 'num_gpus': 8}, + {'test_file': 'test_qwen3_4B_external_pd.py', 'num_gpus': 6}, {'test_file': 'test_qwen2.5_0.5B_fully_async_short.py', 'num_gpus': 4}, {'test_file': 'test_qwen3_4B_streaming_partial_rollout.py', 'num_gpus': 8}, {'test_file': 'test_qwen3.5_0.8B_gsm8k_short.py', 'num_gpus': 4}, @@ -83,6 +84,8 @@ {'test_file': 'test_sample.py', 'num_gpus': 0}, {'test_file': 'test_agent_trajectory.py', 'num_gpus': 0}, {'test_file': 'test_rollout_validation.py', 'num_gpus': 0}, + {'test_file': 'test_placement_group.py', 'num_gpus': 0}, + {'test_file': 'test_external_sglang_engines.py', 'num_gpus': 0}, {'test_file': 'utils/test_hf_checkpoint_saver.py', 'num_gpus': 0}, {'test_file': 'plugin_contracts/test_plugin_rollout_contracts.py', 'num_gpus': 0}, {'test_file': 'plugin_contracts/test_plugin_runtime_hook_contracts.py', 'num_gpus': 0}, @@ -194,7 +197,7 @@ jobs: shell: bash run: | pip install torch --index-url https://download.pytorch.org/whl/cpu - pip install pytest numpy packaging pyyaml omegaconf tqdm httpx pybase64 pylatexenc sympy aiohttp pillow safetensors + pip install pytest numpy packaging pyyaml omegaconf tqdm httpx requests ray pybase64 pylatexenc sympy aiohttp pillow safetensors <% if config.get('extra_pip_deps') %> pip install << config.extra_pip_deps >> <% endif %> diff --git a/docs/en/advanced/sglang-config.md b/docs/en/advanced/sglang-config.md index 70af72d2e..6e9bb5a50 100644 --- a/docs/en/advanced/sglang-config.md +++ b/docs/en/advanced/sglang-config.md @@ -257,7 +257,7 @@ Overrides take **highest priority**, overriding both the base `--sglang-*` CLI a ### 7. Standalone SGLang Launcher -While `--sglang-config` is designed for slime's training pipeline, it also works as a powerful launcher for pure inference scenarios using the `--rollout-external` pattern or by configuring slime to focus solely on serving. +While `--sglang-config` is designed for slime's training pipeline, it also works as a powerful launcher for pure inference scenarios using external engine addresses or by configuring slime to focus solely on serving. **Using external engines with a pre-launched topology:** @@ -270,12 +270,17 @@ python -m sglang.launch_server --model-path /path/to/model --port 10091 ... # Step 2: Connect slime to external engines python train.py \ - --rollout-external \ --rollout-external-engine-addrs host1:10090 host2:10091 \ ... ``` -> **Note:** `--sglang-config` and `--rollout-external` are mutually exclusive. Use `--sglang-config` when you want slime to manage the full engine lifecycle; use `--rollout-external` when engines are pre-deployed. +slime queries each external engine's `/server_info` endpoint to infer +`rollout_num_gpus`, per-engine GPU counts, SGLang parallel sizes, and +prefill/decode worker types. If no `--sglang-router-ip/--sglang-router-port` +is provided, slime launches its own router and registers the external engines +to it. + +> **Note:** `--sglang-config` and `--rollout-external-engine-addrs` are mutually exclusive. Use `--sglang-config` when you want slime to manage the full engine lifecycle; use `--rollout-external-engine-addrs` when engines are pre-deployed. --- @@ -332,7 +337,7 @@ When the config is loaded, slime applies the following resolution cascade: | Flag | Conflict Reason | |------|----------------| | `--prefill-num-servers` | PD disaggregation is configured via `server_groups` in the YAML | -| `--rollout-external` | External engines have their own topology; config manages the lifecycle internally | +| `--rollout-external-engine-addrs` | External engines have their own topology; config manages the lifecycle internally | --- @@ -446,7 +451,7 @@ Use `get_model_url(args, "model_name", "/endpoint")` from `slime.rollout.sglang_ ### Q: Can I use `--sglang-config` without training (inference only)? -While `--sglang-config` is designed for slime's training loop, you can effectively use it for inference-only scenarios by configuring a rollout-only run. For fully standalone SGLang serving, consider using SGLang's native `launch_server` directly or the `--rollout-external` mode for connecting to pre-deployed engines. +While `--sglang-config` is designed for slime's training loop, you can effectively use it for inference-only scenarios by configuring a rollout-only run. For fully standalone SGLang serving, consider using SGLang's native `launch_server` directly or `--rollout-external-engine-addrs` for connecting to pre-deployed engines. ### Q: What is the relationship between `--sglang-config` and `--prefill-num-servers`? diff --git a/docs/zh/advanced/sglang-config.md b/docs/zh/advanced/sglang-config.md index e68c52c84..ce05d3600 100644 --- a/docs/zh/advanced/sglang-config.md +++ b/docs/zh/advanced/sglang-config.md @@ -257,7 +257,7 @@ sglang: ### 7. 独立 SGLang 启动器 -虽然 `--sglang-config` 是为 slime 的训练流水线设计的,但它也可以作为纯推理场景的强大启动器,通过 `--rollout-external` 模式或配置 slime 仅关注推理服务。 +虽然 `--sglang-config` 是为 slime 的训练流水线设计的,但它也可以作为纯推理场景的强大启动器,通过外部 engine 地址或配置 slime 仅关注推理服务。 **使用预启动的外部引擎:** @@ -270,12 +270,16 @@ python -m sglang.launch_server --model-path /path/to/model --port 10091 ... # 步骤 2:将 slime 连接到外部引擎 python train.py \ - --rollout-external \ --rollout-external-engine-addrs host1:10090 host2:10091 \ ... ``` -> **注意:** `--sglang-config` 和 `--rollout-external` 互斥。当你希望 slime 管理完整的引擎生命周期时,使用 `--sglang-config`;当引擎已预部署时,使用 `--rollout-external`。 +slime 会请求每个外部引擎的 `/server_info`,自动推断 +`rollout_num_gpus`、单个 engine 的 GPU 数、SGLang 并行参数,以及 +prefill/decode worker 类型。如果没有提供 `--sglang-router-ip/--sglang-router-port`, +slime 会自己启动 router,并把这些外部引擎注册进去。 + +> **注意:** `--sglang-config` 和 `--rollout-external-engine-addrs` 互斥。当你希望 slime 管理完整的引擎生命周期时,使用 `--sglang-config`;当引擎已预部署时,使用 `--rollout-external-engine-addrs`。 --- @@ -332,7 +336,7 @@ slime 自动为每个 sample 分配一个唯一的 `session_id`(存储在 `sam | 选项 | 冲突原因 | |------|----------| | `--prefill-num-servers` | PD 分离通过 YAML 中的 `server_groups` 配置 | -| `--rollout-external` | 外部引擎有自己的拓扑;config 在内部管理生命周期 | +| `--rollout-external-engine-addrs` | 外部引擎有自己的拓扑;config 在内部管理生命周期 | --- @@ -446,7 +450,7 @@ async def generate_with_models(args, sample, sampling_params): ### Q: 可以不训练,只用 `--sglang-config` 做推理吗? -虽然 `--sglang-config` 是为 slime 的训练循环设计的,但你可以通过配置仅 rollout 的运行来实现纯推理场景。对于完全独立的 SGLang 推理服务,建议直接使用 SGLang 原生的 `launch_server`,或使用 `--rollout-external` 模式连接预部署的引擎。 +虽然 `--sglang-config` 是为 slime 的训练循环设计的,但你可以通过配置仅 rollout 的运行来实现纯推理场景。对于完全独立的 SGLang 推理服务,建议直接使用 SGLang 原生的 `launch_server`,或使用 `--rollout-external-engine-addrs` 连接预部署的引擎。 ### Q: `--sglang-config` 和 `--prefill-num-servers` 是什么关系? diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py index 49eee715e..74680e2ad 100644 --- a/slime/backends/megatron_utils/actor.py +++ b/slime/backends/megatron_utils/actor.py @@ -628,7 +628,7 @@ def update_weights(self) -> None: self.weight_updater.update_weights() print_memory("after update_weights") - if self.args.ci_test and len(rollout_engines) > 0: + if self.args.ci_test and len(rollout_engines) > 0 and self.weight_updater.weight_version > 0: engine = random.choice(rollout_engines) engine_version = ray.get(engine.get_weight_version.remote()) if str(engine_version) != str(self.weight_updater.weight_version): diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py index f5ae3cf33..aed9d086e 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed_delta.py @@ -577,11 +577,6 @@ def update_weights(self) -> None: if not self._snapshot_seeded: self._seed_snapshot() self._snapshot_seeded = True - # Pin the engine's recorded version to ours (0) on the seed call so the - # CI version-equality check holds before any real sync has happened. - if dist.get_rank() == 0 and self.transport == "disk" and self.rollout_engines: - weight_version = str(self.weight_version) - ray.get([engine.set_weight_version.remote(weight_version) for engine in self.rollout_engines]) return self.weight_version += 1 diff --git a/slime/backends/sglang_utils/arguments.py b/slime/backends/sglang_utils/arguments.py index 0a4801743..761086d8b 100644 --- a/slime/backends/sglang_utils/arguments.py +++ b/slime/backends/sglang_utils/arguments.py @@ -159,12 +159,12 @@ def validate_args(args): # Mutual-exclusion checks for PD disaggregation / sglang-config. assert not ( - getattr(args, "prefill_num_servers", None) is not None and args.rollout_external - ), "prefill_num_servers cannot be set when rollout_external is set." + getattr(args, "prefill_num_servers", None) is not None and getattr(args, "rollout_external", False) + ), "prefill_num_servers cannot be set with --rollout-external-engine-addrs." assert not ( - getattr(args, "sglang_config", None) is not None and args.rollout_external - ), "sglang_config cannot be set when rollout_external is set." + getattr(args, "sglang_config", None) is not None and getattr(args, "rollout_external", False) + ), "sglang_config cannot be set with --rollout-external-engine-addrs." assert not ( getattr(args, "sglang_config", None) is not None and getattr(args, "prefill_num_servers", None) is not None diff --git a/slime/backends/sglang_utils/external.py b/slime/backends/sglang_utils/external.py new file mode 100644 index 000000000..7499bb907 --- /dev/null +++ b/slime/backends/sglang_utils/external.py @@ -0,0 +1,229 @@ +"""Helpers for pre-launched external SGLang engines.""" + +from __future__ import annotations + +import dataclasses +import logging +from urllib.parse import urlparse + +import requests + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass(frozen=True) +class ExternalEngineInfo: + url: str + host: str + port: int + worker_type: str + num_gpus: int + disaggregation_bootstrap_port: int | None = None + server_info: dict = dataclasses.field(default_factory=dict) + + @property + def is_pd_worker(self) -> bool: + return self.worker_type in ("prefill", "decode") + + def to_dict(self) -> dict: + return dataclasses.asdict(self) + + +def normalize_external_engine_addr(addr: str) -> str: + """Normalize ``host:port`` or ``http://host:port`` to an HTTP base URL.""" + if "://" not in addr: + addr = f"http://{addr}" + addr = addr.rstrip("/") + parsed = urlparse(addr) + if parsed.scheme != "http" or parsed.hostname is None or parsed.port is None: + raise ValueError( + f"Invalid external SGLang engine address {addr!r}. " + "Use host:port or http://host:port (IPv6 must be bracketed)." + ) + return addr + + +def external_engine_init_kwargs(info: ExternalEngineInfo) -> dict: + init_kwargs = { + "dist_init_addr": f"{info.host}:{info.port}", + "nccl_port": None, + "host": info.host, + "port": info.port, + } + if info.worker_type == "prefill": + init_kwargs["disaggregation_bootstrap_port"] = info.disaggregation_bootstrap_port + return init_kwargs + + +def get_server_info(url: str, timeout: float = 30.0) -> dict: + errors = [] + for endpoint in ("/server_info", "/get_server_info"): + try: + response = requests.get(f"{url}{endpoint}", timeout=timeout) + response.raise_for_status() + return response.json() + except Exception as exc: + errors.append(f"{endpoint}: {exc}") + raise RuntimeError(f"Failed to fetch SGLang server info from {url}: {'; '.join(errors)}") + + +def _infer_worker_type(server_info: dict) -> str: + if server_info.get("encoder_only"): + return "encoder" + mode = server_info.get("disaggregation_mode") + if mode in ("prefill", "decode"): + return mode + return "regular" + + +def discover_external_engines(addrs: list[str], timeout: float = 30.0) -> list[ExternalEngineInfo]: + infos = [] + for addr in addrs: + url = normalize_external_engine_addr(addr) + parsed = urlparse(url) + assert parsed.hostname is not None and parsed.port is not None + server_info = get_server_info(url, timeout=timeout) + + pp_size = int(server_info.get("pp_size") or server_info.get("pipeline_parallel_size") or 1) + tp_size = int(server_info.get("tp_size") or server_info.get("tensor_parallel_size") or 1) + num_gpus = int(server_info.get("num_gpus") or server_info.get("num_gpus_per_engine") or tp_size * pp_size) + bootstrap_port = server_info.get("disaggregation_bootstrap_port") + bootstrap_port = int(bootstrap_port) if bootstrap_port is not None else None + + infos.append( + ExternalEngineInfo( + url=url, + host=parsed.hostname, + port=parsed.port, + worker_type=_infer_worker_type(server_info), + num_gpus=num_gpus, + disaggregation_bootstrap_port=bootstrap_port, + server_info=server_info, + ) + ) + return infos + + +def apply_external_engine_info_to_args(args, logger=None) -> None: + """Detect external engines and store the derived topology on ``args``.""" + addrs = args.rollout_external_engine_addrs + if not addrs: + raise ValueError("apply_external_engine_info_to_args requires --rollout-external-engine-addrs.") + + infos = discover_external_engines(addrs) + if not infos: + raise ValueError("--rollout-external-engine-addrs did not contain any engines.") + + args.rollout_external_engine_infos = [info.to_dict() for info in infos] + args.rollout_num_engines = len(infos) + args.rollout_num_gpus = sum(info.num_gpus for info in infos) + + if logger is not None: + summary = [ + { + "url": info.url, + "worker_type": info.worker_type, + "num_gpus": info.num_gpus, + "disaggregation_bootstrap_port": info.disaggregation_bootstrap_port, + } + for info in infos + ] + logger.info(f"Detected external SGLang engines: {summary}") + + +@dataclasses.dataclass +class ExternalRolloutServer: + """Rollout server backed by pre-launched external SGLang engines.""" + + engines: list + engine_gpu_counts: list[int] + engine_gpu_offsets: list[int] + router_ip: str | None = None + router_port: int | None = None + model_name: str = "default" + update_weights: bool = True + num_new_engines: int = 0 + server_groups: list = dataclasses.field(default_factory=list) + + @property + def all_engines(self): + return self.engines + + def recover(self): + logger.warning("Fault tolerance is not supported for external rollout engines; skip recover.") + + def offload(self): + return [] + + def onload(self, tags: list[str] | None = None): + return [] + + def onload_weights(self): + return [] + + def onload_kv(self): + return [] + + +def external_engine_infos_from_args(args) -> list[ExternalEngineInfo]: + raw_infos = getattr(args, "rollout_external_engine_infos", None) + if raw_infos is None: + raise RuntimeError( + "External rollout engine info is missing. " + "apply_external_engine_info_to_args must run before starting external rollout servers." + ) + return [ExternalEngineInfo(**info) if isinstance(info, dict) else info for info in raw_infos] + + +def start_external_rollout_servers(args, *, start_router) -> dict[str, ExternalRolloutServer]: + import ray + + from slime.backends.sglang_utils.sglang_engine import SGLangEngine + + infos = external_engine_infos_from_args(args) + router_ip, router_port = start_router(args, has_pd_disaggregation=any(info.is_pd_worker for info in infos)) + args.sglang_router_ip = router_ip + args.sglang_router_port = router_port + + engines = [] + engine_gpu_counts = [] + engine_gpu_offsets = [] + init_handles = [] + RolloutRayActor = ray.remote(SGLangEngine) + gpu_offset = 0 + for rank, info in enumerate(infos): + rollout_engine = RolloutRayActor.options(num_cpus=0.2, num_gpus=0).remote( + args=args, + rank=rank, + worker_type=info.worker_type, + base_gpu_id=0, + num_gpus_per_engine=info.num_gpus, + ) + engines.append(rollout_engine) + engine_gpu_counts.append(info.num_gpus) + engine_gpu_offsets.append(gpu_offset) + gpu_offset += info.num_gpus + init_handles.append( + rollout_engine.init.remote( + **external_engine_init_kwargs(info), + router_ip=router_ip, + router_port=router_port, + ) + ) + + if init_handles: + ray.get(init_handles) + + args.sglang_model_routers = {"default": (router_ip, router_port)} + return { + "default": ExternalRolloutServer( + engines=engines, + engine_gpu_counts=engine_gpu_counts, + engine_gpu_offsets=engine_gpu_offsets, + router_ip=router_ip, + router_port=router_port, + model_name="default", + update_weights=True, + num_new_engines=len(engines), + ) + } diff --git a/slime/backends/sglang_utils/sglang_engine.py b/slime/backends/sglang_utils/sglang_engine.py index 607f4c0d0..a4366d118 100644 --- a/slime/backends/sglang_utils/sglang_engine.py +++ b/slime/backends/sglang_utils/sglang_engine.py @@ -13,6 +13,7 @@ from sglang.srt.utils import kill_process_tree from urllib3.exceptions import NewConnectionError +from slime.backends.sglang_utils.external import get_server_info from slime.ray.ray_actor import RayActor from slime.utils.http_utils import get_host_info @@ -169,11 +170,6 @@ def _format_v6_uri(addr): def _init_external(self, expect_server_args, external_engine_need_check_fields): logger.info(f"Use external SGLang engine (rank={self.rank}, expect_server_args={expect_server_args})") - def _get_actual_server_args(): - response = requests.get(f"http://{self.server_host}:{self.server_port}/get_server_info") - response.raise_for_status() - return response.json() - def _sanity_check_server_args(actual_server_args, expect_server_args): for name in external_engine_need_check_fields: expect_value = expect_server_args.get(name) @@ -182,34 +178,39 @@ def _sanity_check_server_args(actual_server_args, expect_server_args): actual_value == expect_value ), f"{name=} {expect_value=} {actual_value=} {expect_server_args=} {actual_server_args=}" - _wait_server_healthy( - base_url=f"http://{self.server_host}:{self.server_port}", - api_key=None, - is_process_alive=lambda: True, - ) - actual_server_args = _get_actual_server_args() + actual_server_args = get_server_info(f"http://{self.server_host}:{self.server_port}") _sanity_check_server_args(actual_server_args, expect_server_args) + self._register_to_router(expect_server_args) def _init_normal(self, server_args_dict): logger.info(f"Launch HttpServerEngineAdapter at: {self.server_host}:{self.server_port}") self.process = launch_server_process(ServerArgs(**server_args_dict)) + self._register_to_router(server_args_dict) + def _register_to_router(self, server_args_dict): if self.worker_type == "encoder": return if self.node_rank == 0 and self.router_ip and self.router_port: + worker_url = f"http://{self.server_host}:{self.server_port}" if parse(sglang_router.__version__) <= parse("0.2.1"): assert self.worker_type == "regular", "pd disaggregation is not supported in old router." response = requests.post( - f"http://{self.router_ip}:{self.router_port}/add_worker?url=http://{self.server_host}:{self.server_port}", + f"http://{self.router_ip}:{self.router_port}/add_worker?url={worker_url}", ) else: payload = { - "url": f"http://{self.server_host}:{self.server_port}", + "url": worker_url, "worker_type": self.worker_type, } if self.worker_type == "prefill": - payload["bootstrap_port"] = server_args_dict["disaggregation_bootstrap_port"] + bootstrap_port = server_args_dict.get("disaggregation_bootstrap_port") + if bootstrap_port is None: + raise RuntimeError( + f"Prefill worker {worker_url} does not have disaggregation_bootstrap_port; " + "cannot register it to the PD router." + ) + payload["bootstrap_port"] = bootstrap_port response = requests.post( f"http://{self.router_ip}:{self.router_port}/workers", json=payload, @@ -651,8 +652,18 @@ def _compute_server_args( "model_path", "trust_remote_code", "random_seed", + "host", + "port", "nccl_port", + "nnodes", + "node_rank", "dist_init_addr", + "gpu_id_step", + "base_gpu_id", + "tp_size", + "dp_size", + "pp_size", + "ep_size", "skip_server_warmup", "enable_draft_weights_cpu_backup", "enable_metrics", diff --git a/slime/ray/placement_group.py b/slime/ray/placement_group.py index e8a778030..9499cdd57 100644 --- a/slime/ray/placement_group.py +++ b/slime/ray/placement_group.py @@ -7,7 +7,6 @@ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from .actor_group import RayTrainGroup -from .rollout import RolloutManager logger = logging.getLogger(__name__) @@ -41,6 +40,9 @@ def sort_key(x): def _create_placement_group(num_gpus): """Create a placement group with the specified number of GPUs.""" + if num_gpus == 0: + return None, [], [] + bundles = [{"GPU": 1, "CPU": 1} for _ in range(num_gpus)] pg = placement_group(bundles, strategy="PACK") num_bundles = len(bundles) @@ -77,22 +79,30 @@ def _create_placement_group(num_gpus): return pg, pg_reordered_bundle_indices, pg_reordered_gpu_ids +def _get_placement_group_layout(args) -> tuple[int, int]: + actor_num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node + + if args.debug_train_only: + return actor_num_gpus, 0 + + if args.rollout_external: + if args.debug_rollout_only: + return 0, 0 + return actor_num_gpus, actor_num_gpus + + if args.debug_rollout_only: + return args.rollout_num_gpus, 0 + + if args.colocate: + return actor_num_gpus, 0 + + return actor_num_gpus + args.rollout_num_gpus, actor_num_gpus + + def create_placement_groups(args): """Create placement groups for actor, critic, and rollout engines.""" - num_gpus = 0 - if args.debug_train_only: - num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node - rollout_offset = 0 - elif args.debug_rollout_only: - num_gpus = args.rollout_num_gpus - rollout_offset = 0 - elif args.colocate: - num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node - rollout_offset = 0 - else: - num_gpus = args.actor_num_nodes * args.actor_num_gpus_per_node + args.rollout_num_gpus - rollout_offset = args.actor_num_nodes * args.actor_num_gpus_per_node + num_gpus, rollout_offset = _get_placement_group_layout(args) logger.info(f"Creating placement group with {num_gpus} GPUs...") pg, actor_pg_reordered_bundle_indices, actor_pg_reordered_gpu_ids = _create_placement_group(num_gpus) @@ -185,6 +195,8 @@ def create_training_models(args, pgs, rollout_manager): def create_rollout_manager(args, pg): + from .rollout import RolloutManager + rollout_manager = RolloutManager.options( num_cpus=1, num_gpus=0, diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index 4998c8232..9743d1223 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -14,6 +14,7 @@ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS +from slime.backends.sglang_utils.external import start_external_rollout_servers from slime.backends.sglang_utils.sglang_config import ModelConfig, ServerGroupConfig, SglangConfig from slime.backends.sglang_utils.sglang_engine import SGLangEngine from slime.rollout.base_types import call_rollout_fn @@ -158,22 +159,17 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis if self.num_new_engines == 0: return [], port_cursors - if self.args.rollout_external: - addr_and_ports = _allocate_rollout_engine_addr_and_ports_external( - args=self.args, rollout_engines=rollout_engines - ) - else: - # Compute base_port from the maximum cursor across all nodes that - # this group's engines may land on (conservative: just use global max). - base_port = max(port_cursors.values()) if port_cursors else 15000 - addr_and_ports, port_cursors = _allocate_rollout_engine_addr_and_ports_normal( - args=self.args, - rollout_engines=rollout_engines, - worker_type=self.worker_type, - num_gpus_per_engine=self.num_gpus_per_engine, - rank_offset=self.rank_offset, - base_port=base_port, - ) + # Compute base_port from the maximum cursor across all nodes that + # this group's engines may land on (conservative: just use global max). + base_port = max(port_cursors.values()) if port_cursors else 15000 + addr_and_ports, port_cursors = _allocate_rollout_engine_addr_and_ports_normal( + args=self.args, + rollout_engines=rollout_engines, + worker_type=self.worker_type, + num_gpus_per_engine=self.num_gpus_per_engine, + rank_offset=self.rank_offset, + base_port=base_port, + ) init_handles = [ engine.init.remote( @@ -384,7 +380,7 @@ def __init__(self, args, pg): logger.info(f"import {self.args.eval_function_path} as eval_generate_rollout function.") if self.args.debug_train_only: - self.servers: dict[str, RolloutServer] = {} + self.servers: dict[str, Any] = {} else: init_http_client(args) self.servers = start_rollout_servers(args, pg) @@ -427,7 +423,12 @@ def _try_ci_fault_injection(self): # Only inject fault once self._ci_fault_injection_pending = False - if self.server and self.server.server_groups[0].all_engines and self.server.server_groups[0].all_engines[0]: + if ( + self.server + and self.server.server_groups + and self.server.server_groups[0].all_engines + and self.server.server_groups[0].all_engines[0] + ): logger.info("CI Fault Injection: Simulating crash on engine 0 during generate") try: # This will cause the ray actor to exit @@ -446,13 +447,13 @@ def dispose(self): logging_utils.finish_tracking(self.args) @property - def server(self) -> RolloutServer | None: + def server(self) -> Any | None: """Default server (first model). For backward compatibility.""" if not self.servers: return None return next(iter(self.servers.values())) - def _get_updatable_server(self) -> RolloutServer | None: + def _get_updatable_server(self) -> Any | None: """Return the server with ``update_weights=True``. When multiple updatable servers exist, returns the first one @@ -845,20 +846,6 @@ def _validate_rollout_id_annotated(node, depth=0): _validate_rollout_id_annotated(item, depth + 1) -def _allocate_rollout_engine_addr_and_ports_external(args, rollout_engines): - addr_and_ports = {} - for rank, _ in rollout_engines: - addr = args.rollout_external_engine_addrs[rank] - [host, port] = addr.split(":") - addr_and_ports[rank] = dict( - dist_init_addr=addr, - nccl_port=None, - host=host, - port=int(port), - ) - return addr_and_ports - - def _allocate_rollout_engine_addr_and_ports_normal( *, args, @@ -1018,7 +1005,7 @@ def _compute_megatron_num_gpus(args) -> int: return num -def start_rollout_servers(args, pg) -> dict[str, RolloutServer]: +def start_rollout_servers(args, pg) -> dict[str, Any]: """Start rollout servers: one per model, each with its own router. Each model defined in the sglang config gets its own router and set @@ -1031,6 +1018,9 @@ def start_rollout_servers(args, pg) -> dict[str, RolloutServer]: Note: ``init_http_client`` should be called separately before this, as the HTTP client is shared across all servers. """ + if args.rollout_external: + return start_external_rollout_servers(args, start_router=_start_router) + config = _resolve_sglang_config(args) servers: dict[str, RolloutServer] = {} diff --git a/slime/rollout/fully_async_rollout.py b/slime/rollout/fully_async_rollout.py index a54f4083a..c301075c5 100644 --- a/slime/rollout/fully_async_rollout.py +++ b/slime/rollout/fully_async_rollout.py @@ -11,8 +11,8 @@ :func:`generate_and_rm_group` which dispatches to those. Concurrency is sourced from ``args.sglang_server_concurrency`` and scaled by -the number of sglang engines (``rollout_num_gpus // rollout_num_gpus_per_engine``) -to match the per-sample semaphore cap in :mod:`slime.rollout.sglang_rollout`. +the number of sglang engines to match the per-sample semaphore cap in +:mod:`slime.rollout.sglang_rollout`. The worker is intentionally oblivious to slime's higher-level pause / weight-update signalling (e.g. ``GenerateState.aborted``). Each in-flight @@ -34,6 +34,7 @@ from slime.rollout.sglang_rollout import GenerateState, generate_and_rm_group from slime.utils.async_utils import run +from slime.utils.http_utils import get_rollout_num_engines from slime.utils.types import Sample __all__ = [ @@ -54,9 +55,8 @@ def _get_global_worker(args, data_buffer) -> AsyncRolloutWorker: with _worker_lock: if _global_worker is None or not _global_worker.worker_thread.is_alive(): logger.info("starting fully-async rollout worker") - num_engines = max(1, args.rollout_num_gpus // args.rollout_num_gpus_per_engine) _global_worker = AsyncRolloutWorker( - args, data_buffer, concurrency=args.sglang_server_concurrency * num_engines + args, data_buffer, concurrency=args.sglang_server_concurrency * get_rollout_num_engines(args) ) _global_worker.start() return _global_worker diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py index c7f86b98c..eee3a1368 100644 --- a/slime/rollout/sglang_rollout.py +++ b/slime/rollout/sglang_rollout.py @@ -19,7 +19,7 @@ from slime.utils.async_utils import run from slime.utils.data import Dataset from slime.utils.eval_config import EvalDatasetConfig -from slime.utils.http_utils import get, post +from slime.utils.http_utils import get, get_rollout_num_engines, post from slime.utils.misc import SingletonMeta, load_function from slime.utils.processing_utils import ( build_processor_kwargs, @@ -91,9 +91,7 @@ def __init__(self, args: Namespace) -> None: self.tokenizer = load_tokenizer(args.hf_checkpoint, trust_remote_code=True) self.processor = load_processor(args.hf_checkpoint, trust_remote_code=True) - self.semaphore = asyncio.Semaphore( - args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine - ) + self.semaphore = asyncio.Semaphore(args.sglang_server_concurrency * get_rollout_num_engines(args)) self.sampling_params: dict[str, Any] = dict( temperature=args.rollout_temperature, top_p=args.rollout_top_p, diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index d7f863455..f562f8c47 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -10,6 +10,7 @@ from slime.backends.sglang_utils.arguments import sglang_parse_args from slime.backends.sglang_utils.arguments import validate_args as sglang_validate_args +from slime.backends.sglang_utils.external import apply_external_engine_info_to_args from slime.utils.eval_config import EvalDatasetConfig, build_eval_dataset_configs, ensure_dataset_list from slime.utils.logging_utils import configure_logger @@ -506,12 +507,6 @@ def add_rollout_arguments(parser): "It may be helpful for updating loss mask." ), ) - parser.add_argument( - "--rollout-external", - action="store_true", - default=False, - help="Use external SGLang instances instead of launching them inside the framework.", - ) parser.add_argument( "--rollout-external-engine-addrs", type=str, @@ -1786,6 +1781,11 @@ def slime_validate_args(args): ) args.debug_train_only = True + args.rollout_external = args.rollout_external_engine_addrs is not None + + if args.rollout_external and not args.debug_train_only: + apply_external_engine_info_to_args(args, logger=logger) + args.use_critic = args.advantage_estimator == "ppo" # Critic always uses the same GPU count as actor. args.critic_num_gpus_per_node = args.actor_num_gpus_per_node diff --git a/slime/utils/http_utils.py b/slime/utils/http_utils.py index 7ce395c4d..ced387e57 100644 --- a/slime/utils/http_utils.py +++ b/slime/utils/http_utils.py @@ -198,13 +198,26 @@ async def _post(client, url, payload, max_retries=60, headers=None): return output +def get_rollout_num_engines(args) -> int: + """Return the number of rollout HTTP engines behind the router.""" + if (num_engines := getattr(args, "rollout_num_engines", None)) is not None: + return int(num_engines) + + rollout_num_gpus = getattr(args, "rollout_num_gpus", None) or 0 + rollout_num_gpus_per_engine = getattr(args, "rollout_num_gpus_per_engine", None) or 1 + if rollout_num_gpus <= 0: + return 0 + return max(1, rollout_num_gpus // rollout_num_gpus_per_engine) + + def init_http_client(args): """Initialize HTTP client and optionally enable distributed POST via Ray.""" global _http_client, _client_concurrency, _distributed_post_enabled - if not args.rollout_num_gpus: + num_engines = get_rollout_num_engines(args) + if num_engines <= 0: return - _client_concurrency = args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine + _client_concurrency = args.sglang_server_concurrency * num_engines if _http_client is None: _http_client = httpx.AsyncClient( limits=httpx.Limits(max_connections=_client_concurrency), diff --git a/tests/test_external_sglang_engines.py b/tests/test_external_sglang_engines.py new file mode 100644 index 000000000..704bcc7a8 --- /dev/null +++ b/tests/test_external_sglang_engines.py @@ -0,0 +1,143 @@ +import sys +from argparse import Namespace +from pathlib import Path + +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from slime.backends.sglang_utils.external import apply_external_engine_info_to_args, discover_external_engines +from slime.utils.http_utils import get_rollout_num_engines + +NUM_GPUS = 0 + + +class _Response: + def __init__(self, payload, status_code=200): + self.payload = payload + self.status_code = status_code + + def raise_for_status(self): + if self.status_code >= 400: + raise RuntimeError(f"HTTP {self.status_code}") + + def json(self): + return self.payload + + +def test_discover_external_engines_reads_server_info(monkeypatch): + def fake_get(url, timeout): + assert timeout == 30.0 + assert url == "http://host1:10090/server_info" + return _Response( + { + "tp_size": 4, + "pp_size": 2, + "dp_size": 1, + "ep_size": 4, + "disaggregation_mode": "null", + } + ) + + monkeypatch.setattr("slime.backends.sglang_utils.external.requests.get", fake_get) + + infos = discover_external_engines(["host1:10090"]) + + assert len(infos) == 1 + info = infos[0] + assert info.url == "http://host1:10090" + assert info.host == "host1" + assert info.port == 10090 + assert info.worker_type == "regular" + assert info.num_gpus == 8 + assert info.server_info["tp_size"] == 4 + assert info.server_info["pp_size"] == 2 + assert info.server_info["dp_size"] == 1 + assert info.server_info["ep_size"] == 4 + + +def test_apply_external_engine_info_handles_pd(monkeypatch): + payloads = { + "http://prefill:10090/server_info": { + "tp_size": 2, + "pp_size": 1, + "dp_size": 1, + "ep_size": 1, + "disaggregation_mode": "prefill", + "disaggregation_bootstrap_port": 12090, + }, + "http://decode:10091/server_info": { + "tp_size": 4, + "pp_size": 1, + "dp_size": 2, + "ep_size": 2, + "disaggregation_mode": "decode", + }, + } + + def fake_get(url, timeout): + return _Response(payloads[url]) + + monkeypatch.setattr("slime.backends.sglang_utils.external.requests.get", fake_get) + args = Namespace( + rollout_external=True, + rollout_external_engine_addrs=["prefill:10090", "decode:10091"], + rollout_num_gpus=None, + rollout_num_gpus_per_engine=1, + sglang_pipeline_parallel_size=1, + sglang_data_parallel_size=1, + sglang_expert_parallel_size=1, + sglang_enable_dp_attention=False, + router_pd_disaggregation=False, + ) + + apply_external_engine_info_to_args(args) + + assert args.rollout_external is True + assert args.router_pd_disaggregation is False + assert args.rollout_num_gpus == 6 + assert args.rollout_num_engines == 2 + assert get_rollout_num_engines(args) == 2 + assert [info["worker_type"] for info in args.rollout_external_engine_infos] == ["prefill", "decode"] + assert [info["num_gpus"] for info in args.rollout_external_engine_infos] == [2, 4] + assert [info["server_info"]["dp_size"] for info in args.rollout_external_engine_infos] == [1, 2] + assert args.rollout_external_engine_infos[0]["disaggregation_bootstrap_port"] == 12090 + + +def test_apply_external_engine_info_preserves_router_pd_flag(monkeypatch): + def fake_get(url, timeout): + assert url == "http://regular:10090/server_info" + return _Response( + { + "tp_size": 2, + "pp_size": 1, + "disaggregation_mode": "null", + } + ) + + monkeypatch.setattr("slime.backends.sglang_utils.external.requests.get", fake_get) + args = Namespace( + rollout_external=True, + rollout_external_engine_addrs=["regular:10090"], + router_pd_disaggregation=True, + ) + + apply_external_engine_info_to_args(args) + + assert args.rollout_external is True + assert args.router_pd_disaggregation is True + assert args.rollout_num_gpus == 2 + assert args.rollout_num_engines == 1 + + +def test_apply_external_engine_info_requires_addrs(): + args = Namespace(rollout_external_engine_addrs=None) + + with pytest.raises(ValueError, match="rollout-external-engine-addrs"): + apply_external_engine_info_to_args(args) + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__])) diff --git a/tests/test_placement_group.py b/tests/test_placement_group.py new file mode 100644 index 000000000..8f918d4a7 --- /dev/null +++ b/tests/test_placement_group.py @@ -0,0 +1,51 @@ +import sys +from argparse import Namespace +from pathlib import Path + +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from slime.ray.placement_group import _create_placement_group, _get_placement_group_layout + + +NUM_GPUS = 0 + + +def _args(**overrides): + values = { + "actor_num_nodes": 2, + "actor_num_gpus_per_node": 8, + "rollout_num_gpus": 32, + "debug_train_only": False, + "debug_rollout_only": False, + "colocate": False, + "rollout_external": False, + } + values.update(overrides) + return Namespace(**values) + + +@pytest.mark.parametrize( + ("overrides", "expected"), + [ + pytest.param({}, (48, 16), id="normal_non_colocate"), + pytest.param({"debug_train_only": True}, (16, 0), id="debug_train_only"), + pytest.param({"debug_rollout_only": True}, (32, 0), id="debug_rollout_only"), + pytest.param({"colocate": True}, (16, 0), id="colocate"), + pytest.param({"rollout_external": True}, (16, 16), id="external"), + pytest.param({"rollout_external": True, "debug_rollout_only": True}, (0, 0), id="external_debug_rollout"), + ], +) +def test_placement_group_layout(overrides, expected): + assert _get_placement_group_layout(_args(**overrides)) == expected + + +def test_create_zero_gpu_placement_group_is_empty(): + assert _create_placement_group(0) == (None, [], []) + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__])) diff --git a/tests/test_qwen3_4B_external_pd.py b/tests/test_qwen3_4B_external_pd.py new file mode 100644 index 000000000..68078603c --- /dev/null +++ b/tests/test_qwen3_4B_external_pd.py @@ -0,0 +1,381 @@ +"""E2E test for --rollout-external-engine-addrs with a pure-PD external fleet. + +Spawns two SGLang servers out-of-band on a single GPU box (all tp=1): +- 1 prefill (``--disaggregation-mode prefill``, mooncake transfer backend) +- 1 decode (``--disaggregation-mode decode``, mooncake transfer backend) + +and points slime at both via ``--rollout-external-engine-addrs ...``. +The first 4 GPUs train. slime queries ``/server_info`` on each engine to +infer per-engine TP / GPU counts and registers them to its PD-enabled router. + +Weight sync uses ``--update-weight-mode delta --update-weight-transport disk`` +so the post-train sync writes sparse safetensors to a shared dir and the +external engines load them via ``update_weights_from_disk(load_format=delta)`` +— that's the only sync path that actually works for pre-launched workers (no +NCCL group between trainer and external engines). +""" + +import os +import socket +import subprocess +import tempfile +import time +import urllib.request +from pathlib import Path + +import slime.utils.external_utils.command_utils as U + +MODEL_NAME = "Qwen3-4B" +MODEL_TYPE = "qwen3-4B" +TORCH_DIST_CKPT = f"/root/models/{MODEL_NAME}_torch_dist" +NUM_GPUS = 6 +NUM_TRAIN_GPUS = 4 +NUM_PREFILL_ENGINES = 1 +NUM_DECODE_ENGINES = 1 + +EXTERNAL_HOST = "127.0.0.1" +PREFILL_PORTS = [13150] +DECODE_PORTS = [13151] +BOOTSTRAP_PORTS = [13160] + + +def _get_bond_ipv4(): + net_root = Path("/sys/class/net") + if not net_root.exists(): + return None + + bond_ifaces = [ + path.name for path in net_root.iterdir() if path.name.startswith("bond") and path.name[4:].isdigit() + ] + bond_ifaces.sort(key=lambda name: int(name[4:])) + for iface in bond_ifaces: + try: + output = subprocess.check_output(["ip", "-o", "-4", "addr", "show", "dev", iface], text=True) + except (OSError, subprocess.CalledProcessError): + continue + fields = output.split() + for idx, field in enumerate(fields): + if field == "inet" and idx + 1 < len(fields): + return fields[idx + 1].split("/", 1)[0] + return None + + +def _get_external_host(): + env_value = os.environ.get("SLIME_TEST_EXTERNAL_PD_HOST") + if env_value and env_value not in ("127.0.0.1", "localhost"): + return env_value + + bond_host = _get_bond_ipv4() + if bond_host is not None: + return bond_host + + master_addr = os.environ.get("MASTER_ADDR") + if master_addr and master_addr not in ("127.0.0.1", "localhost"): + return master_addr + + try: + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock: + sock.connect(("8.8.8.8", 80)) + host = sock.getsockname()[0] + if host and not host.startswith("127."): + return host + except OSError: + pass + + return EXTERNAL_HOST + + +def _get_disaggregation_ib_device(): + env_value = os.environ.get("SLIME_TEST_DISAGGREGATION_IB_DEVICE") + if env_value is not None: + return env_value.strip() or None + + ib_root = Path("/sys/class/infiniband") + if not ib_root.exists(): + return None + + active_devices = [] + for device in ib_root.iterdir(): + for state_file in device.glob("ports/*/state"): + try: + if "ACTIVE" in state_file.read_text(): + active_devices.append(device.name) + break + except OSError: + continue + + bond_devices = [] + numeric_mlx5_devices = [] + for device in active_devices: + prefix, _, suffix = device.partition("_") + if prefix == "mlx5" and suffix.startswith("bond_") and suffix[5:].isdigit(): + bond_devices.append(device) + elif prefix == "mlx5" and suffix.isdigit(): + numeric_mlx5_devices.append(device) + bond_devices.sort(key=lambda name: int(name.rsplit("_", 1)[1])) + numeric_mlx5_devices.sort(key=lambda name: int(name.rsplit("_", 1)[1])) + + devices = bond_devices or numeric_mlx5_devices or sorted(active_devices) + return ",".join(devices) if devices else None + + +def prepare(): + U.exec_command("mkdir -p /root/models /root/datasets") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.hf_download_dataset("zhuzilin/gsm8k") + U.convert_checkpoint( + model_name=MODEL_NAME, + megatron_model_type=MODEL_TYPE, + num_gpus_per_node=NUM_TRAIN_GPUS, + dir_dst="/root/models", + ) + + +def _get_gpu_split(): + """Partition visible GPUs: 4 train + 1 prefill + 1 decode.""" + all_gpus = os.environ.get("CUDA_VISIBLE_DEVICES", ",".join(str(i) for i in range(NUM_GPUS))).split(",") + assert len(all_gpus) >= NUM_GPUS, f"Expected at least {NUM_GPUS} GPUs, got {len(all_gpus)}" + train_gpus = all_gpus[:NUM_TRAIN_GPUS] + cursor = NUM_TRAIN_GPUS + prefill_gpus = all_gpus[cursor : cursor + NUM_PREFILL_ENGINES] + cursor += NUM_PREFILL_ENGINES + decode_gpus = all_gpus[cursor : cursor + NUM_DECODE_ENGINES] + return train_gpus, prefill_gpus, decode_gpus + + +def _launch_sglang_server( + *, + gpus: list[str], + port: int, + tp: int, + log_path: str, + disaggregation_mode: str, + disaggregation_bootstrap_port: int | None = None, + disaggregation_ib_device: str | None = None, + external_host: str = EXTERNAL_HOST, +) -> subprocess.Popen: + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = ",".join(gpus) + + cmd = [ + "python3", + "-m", + "sglang.launch_server", + "--model-path", + f"/root/models/{MODEL_NAME}", + "--host", + "0.0.0.0", + "--port", + str(port), + "--tp", + str(tp), + "--mem-fraction-static", + "0.6", + "--trust-remote-code", + "--disaggregation-mode", + disaggregation_mode, + "--disaggregation-transfer-backend", + "mooncake", + ] + if disaggregation_ib_device is not None: + cmd += ["--disaggregation-ib-device", disaggregation_ib_device] + if disaggregation_bootstrap_port is not None: + cmd += ["--disaggregation-bootstrap-port", str(disaggregation_bootstrap_port)] + cmd += ["--load-balance-method", "follow_bootstrap_room"] + else: + cmd += ["--prefill-round-robin-balance"] + + log_file = open(log_path, "w") + process = subprocess.Popen(cmd, env=env, stdout=log_file, stderr=subprocess.STDOUT) + print( + f"Starting external sglang {disaggregation_mode} server on GPUs {gpus} " + f"port={port} tp={tp} (pid={process.pid}), log: {log_path}" + ) + + # Wait up to ~10 minutes for /server_info to come up. /health_generate + # is unreliable for prefill/decode-only nodes, so we poll /server_info + # — that's what slime's discover_external_engines uses anyway. + deadline = time.time() + 600 + while time.time() < deadline: + if process.poll() is not None: + raise RuntimeError(f"{disaggregation_mode} server exited with code {process.returncode}; check {log_path}") + try: + req = urllib.request.urlopen(f"http://{external_host}:{port}/server_info", timeout=2) + if req.status == 200: + print(f"External sglang {disaggregation_mode} server is ready on GPUs {gpus}") + return process + except Exception: + pass + time.sleep(5) + + process.kill() + raise RuntimeError(f"{disaggregation_mode} server failed to start within timeout; check {log_path}") + + +def execute(): + train_gpus, prefill_gpus, decode_gpus = _get_gpu_split() + external_host = _get_external_host() + disaggregation_ib_device = _get_disaggregation_ib_device() + print(f"Using external host for SGLang workers: {external_host}") + print(f"Using SGLang disaggregation IB device: {disaggregation_ib_device}") + processes: list[subprocess.Popen] = [] + + # Restrict CUDA_VISIBLE_DEVICES to training GPUs before Ray starts so + # ray's bundle allocator doesn't try to claim the external sglang GPUs. + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(train_gpus) + + def launch_external_engines(): + for idx, (gpu, port, bootstrap_port) in enumerate( + zip(prefill_gpus, PREFILL_PORTS, BOOTSTRAP_PORTS, strict=True) + ): + processes.append( + _launch_sglang_server( + gpus=[gpu], + port=port, + tp=1, + disaggregation_mode="prefill", + disaggregation_bootstrap_port=bootstrap_port, + disaggregation_ib_device=disaggregation_ib_device, + external_host=external_host, + log_path=f"/tmp/sglang_external_prefill_{idx}.log", + ) + ) + for idx, (gpu, port) in enumerate(zip(decode_gpus, DECODE_PORTS, strict=True)): + processes.append( + _launch_sglang_server( + gpus=[gpu], + port=port, + tp=1, + disaggregation_mode="decode", + disaggregation_ib_device=disaggregation_ib_device, + external_host=external_host, + log_path=f"/tmp/sglang_external_decode_{idx}.log", + ) + ) + + delta_dir_cm = tempfile.TemporaryDirectory(prefix="slime_external_pd_delta_") + delta_dir = delta_dir_cm.name + try: + ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ " f"--ref-load {TORCH_DIST_CKPT} " + + rollout_args = ( + "--prompt-data /root/datasets/gsm8k/train.parquet " + "--input-key messages " + "--label-key label " + "--apply-chat-template " + "--rollout-shuffle " + "--rm-type math " + "--num-rollout 3 " + "--rollout-batch-size 4 " + "--n-samples-per-prompt 4 " + "--rollout-max-response-len 1024 " + "--rollout-temperature 0.8 " + "--global-batch-size 16 " + ) + + perf_args = ( + "--tensor-model-parallel-size 2 " + "--sequence-parallel " + "--pipeline-model-parallel-size 1 " + "--context-parallel-size 1 " + "--expert-model-parallel-size 1 " + "--expert-tensor-parallel-size 1 " + "--use-dynamic-batch-size " + "--max-tokens-per-gpu 9216 " + "--recompute-granularity full " + "--recompute-method uniform " + "--recompute-num-layers 1 " + ) + + grpo_args = ( + "--advantage-estimator grpo " + "--use-kl-loss " + "--kl-loss-coef 0.00 " + "--kl-loss-type low_var_kl " + # Nonzero entropy coef guarantees a nonzero gradient even when all + # rewards in a group tie (advantages=0), so the delta sync writes + # real sparse files instead of an empty no-op. + "--entropy-coef 0.01 " + "--eps-clip 0.2 " + "--eps-clip-high 0.28 " + ) + + optimizer_args = ( + "--optimizer adam " + "--lr 1e-6 " + "--lr-decay-style constant " + "--weight-decay 0.1 " + "--adam-beta1 0.9 " + "--adam-beta2 0.98 " + ) + + # No --rollout-num-gpus / --rollout-num-gpus-per-engine: those are + # inferred from /server_info on each external engine (1 prefill + + # 1 decode, all tp=1). + all_addrs = [f"{external_host}:{port}" for port in (*PREFILL_PORTS, *DECODE_PORTS)] + external_args = "--rollout-external-engine-addrs " + " ".join(all_addrs) + " " + + # External engines have no NCCL group with the trainer, so weight + # updates have to go through the disk-backed delta path: the trainer + # writes sparse safetensors per sync, the engines pull via + # update_weights_from_disk(load_format="delta", files=...). + delta_args = ( + "--update-weight-mode delta " + "--update-weight-transport disk " + "--update-weight-encoding deltas " + f"--update-weight-delta-dir {delta_dir} " + "--update-weight-delta-keep-files " + ) + + ci_args = "--ci-test " + + misc_args = ( + "--attention-dropout 0.0 " + "--hidden-dropout 0.0 " + "--accumulate-allreduce-grads-in-fp32 " + "--attention-softmax-in-fp32 " + "--attention-backend flash " + "--actor-num-nodes 1 " + f"--actor-num-gpus-per-node {NUM_TRAIN_GPUS} " + ) + + train_args = ( + f"{ckpt_args} " + f"{rollout_args} " + f"{optimizer_args} " + f"{grpo_args} " + f"{U.get_default_wandb_args(__file__)} " + f"{perf_args} " + f"{external_args} " + f"{delta_args} " + f"{ci_args} " + f"{misc_args} " + ) + + U.execute_train( + train_args=train_args, + num_gpus_per_node=NUM_TRAIN_GPUS, + megatron_model_type=MODEL_TYPE, + before_ray_job_submit=launch_external_engines, + extra_env_vars={ + "no_proxy": f"127.0.0.1,localhost,{external_host}", + "NO_PROXY": f"127.0.0.1,localhost,{external_host}", + }, + ) + + delta_files = list(Path(delta_dir).glob("weight_v*/*.safetensors")) + assert delta_files, f"No disk delta safetensors were written under {delta_dir}" + finally: + for p in processes: + if p.poll() is None: + p.kill() + p.wait() + U.exec_command("pkill -9 sglang; true") + delta_dir_cm.cleanup() + + +if __name__ == "__main__": + prepare() + for proxy_var in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"): + os.environ.pop(proxy_var, None) + execute()