From fd97ff75863adf36c304bea83d6c1f65bf66fd0c Mon Sep 17 00:00:00 2001 From: Max Wittig Date: Thu, 16 Apr 2026 17:27:10 +0200 Subject: [PATCH] feat(vllm-router): add fallback model support for zero-downtime GPU node reboots When all backends for a model are unavailable (either health-checked away or all attempts errored out), requests automatically fall through to a configured fallback model. The model name in the request body is rewritten so downstream gateways (e.g. Envoy AI Gateway routing to Bedrock) receive the correct model identifier. Config: per-model fallback_model in YAML, or --static-fallback-models CLI flag. Signed-off-by: Max Wittig --- pyproject.toml | 3 + src/vllm_router/README.md | 60 ++++++ src/vllm_router/app.py | 5 + src/vllm_router/dynamic_config.py | 7 + src/vllm_router/parsers/parser.py | 7 + src/vllm_router/parsers/yaml_utils.py | 15 ++ src/vllm_router/service_discovery.py | 2 + .../services/request_service/request.py | 180 ++++++++++++++---- 8 files changed, 246 insertions(+), 33 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5bf6725b3..88584b9c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,9 @@ write_to = "src/vllm_router/_version.py" [tool.isort] profile = "black" +[tool.ruff] +target-version = "py312" + [tool.pytest.ini_options] asyncio_mode = "auto" diff --git a/src/vllm_router/README.md b/src/vllm_router/README.md index 125744965..1548f6211 100644 --- a/src/vllm_router/README.md +++ b/src/vllm_router/README.md @@ -29,6 +29,7 @@ The router can be configured using command-line arguments. Below are the availab - `--static-models`: The models running in the static serving engines, separated by commas (e.g., `model1,model2`). - `--static-aliases`: The aliases of the models running in the static serving engines, separated by commas and associated using colons (e.g., `model_alias1:model,mode_alias2:model`). - `--static-backend-health-checks`: Enable this flag to make vllm-router check periodically if the models work by sending dummy requests to their endpoints. +- `--static-fallback-models`: Fallback model mappings, separated by commas (e.g., `model1:fallback1,model2:fallback2`). When all backends for a model are unavailable, requests are retried on the fallback model. - `--k8s-port`: The port of vLLM processes when using K8s service discovery. Default is `8000`. - `--k8s-namespace`: The namespace of vLLM pods when using K8s service discovery. Default is `default`. - `--k8s-label-selector`: The label selector to filter vLLM pods when using K8s service discovery. @@ -108,6 +109,64 @@ different endpoints for each model type. > Enabling this flag will put some load on your backend every minute as real requests are send to the nodes > to test their functionality. +## Fallback models + +When all backends for a model become unavailable (e.g. during node reboots), the +router can automatically retry the request on a different **fallback model**. The +model name in the request body is rewritten to the fallback model name before +forwarding, so the fallback backend receives the correct model identifier. + +Fallback triggers in two situations: + +1. **No healthy endpoints** -- all backends have been marked unhealthy by the + periodic health check. The router switches to the fallback model immediately + without attempting the primary backends. +2. **All instance-level failover attempts failed** -- the primary backends were + still considered healthy but every attempt returned a connection error (e.g. + the node went down between health checks). After exhausting + `--max-instance-failover-reroute-attempts`, the router retries once on the + fallback model. + +### Configuration + +**In a YAML config file**, add `fallback_model` to any model entry. The value +must be the name of another model defined in `static_models`: + +```yaml +static_models: + glm-5: + static_backends: + - https://gpu-node-1/glm-5 + - https://gpu-node-2/glm-5 + static_model_type: chat + fallback_model: glm-5-cloud # fall back to the cloud-hosted variant + glm-5-cloud: + static_backends: + - http://cloud-gateway:1975 + static_model_type: chat + healthcheck_disabled: true +``` + +**Via CLI**, use `--static-fallback-models` with comma-separated +`model:fallback` pairs: + +```bash +vllm-router --port 8000 \ + --service-discovery static \ + --static-backends "https://gpu-node-1/glm-5,https://gpu-node-2/glm-5,http://cloud-gateway:1975" \ + --static-models "glm-5,glm-5,glm-5-cloud" \ + --static-model-types "chat,chat,chat" \ + --static-fallback-models "glm-5:glm-5-cloud" \ + --static-backend-health-checks \ + --max-instance-failover-reroute-attempts 2 \ + --routing-logic roundrobin +``` + +Combining `fallback_model` with `--max-instance-failover-reroute-attempts` and a +short `--static-backend-health-check-interval` gives the best resilience: failed +requests are retried on other instances first, then on the fallback model, while +the health check quickly removes dead backends from future routing decisions. + ## Dynamic Router Config The router can be configured dynamically using a config file when passing the `--dynamic-config-yaml` or @@ -128,6 +187,7 @@ Currently, the dynamic config supports the following fields: - (When using `static` service discovery) `static_models`: The models running in the static serving engines, separated by commas (e.g., `model1,model2`). - (When using `static` service discovery) `static_aliases`: The aliases of the models running in the static serving engines, separated by commas and associated using colons (e.g., `model_alias1:model,mode_alias2:model`). - (When using `static` service discovery and if you enable the `--static-backend-health-checks` flag) `static_model_types`: The model types running in the static serving engines, separated by commas (e.g., `chat,chat`). +- (When using `static` service discovery) `fallback_model`: A per-model string in the YAML config (under each model entry) specifying another model to fall back to when all backends are unavailable. - (When using `k8s` service discovery) `k8s_port`: The port of vLLM processes when using K8s service discovery. Default is `8000`. - (When using `k8s` service discovery) `k8s_namespace`: The namespace of vLLM pods when using K8s service discovery. Default is `default`. - (When using `k8s` service discovery) `k8s_label_selector`: The label selector to filter vLLM pods when using K8s service discovery. diff --git a/src/vllm_router/app.py b/src/vllm_router/app.py index c71f6de83..a36f0d9fd 100644 --- a/src/vllm_router/app.py +++ b/src/vllm_router/app.py @@ -202,6 +202,11 @@ def initialize_all(app: FastAPI, args): static_backend_health_check_timeout_seconds=args.static_backend_health_check_timeout_seconds, prefill_model_labels=args.prefill_model_labels, decode_model_labels=args.decode_model_labels, + fallback_models=( + parse_static_aliases(args.static_fallback_models) + if args.static_fallback_models + else None + ), ) elif args.service_discovery == "k8s": initialize_service_discovery( diff --git a/src/vllm_router/dynamic_config.py b/src/vllm_router/dynamic_config.py index 5a99e94fd..e383a69c2 100644 --- a/src/vllm_router/dynamic_config.py +++ b/src/vllm_router/dynamic_config.py @@ -57,6 +57,7 @@ class DynamicRouterConfig: static_aliases: Optional[str] = None static_model_labels: Optional[str] = None static_model_types: Optional[str] = None + static_fallback_models: Optional[str] = None static_backend_health_checks: Optional[bool] = False static_backend_health_check_interval: Optional[int] = 60 static_backend_health_check_timeout_seconds: Optional[int] = 10 @@ -97,6 +98,7 @@ def from_args(args) -> "DynamicRouterConfig": static_backend_health_checks=args.static_backend_health_checks, static_backend_health_check_interval=args.static_backend_health_check_interval, static_backend_health_check_timeout_seconds=args.static_backend_health_check_timeout_seconds, + static_fallback_models=getattr(args, "static_fallback_models", None), k8s_port=args.k8s_port, k8s_namespace=args.k8s_namespace, k8s_label_selector=args.k8s_label_selector, @@ -185,6 +187,11 @@ def reconfigure_service_discovery(self, config: DynamicRouterConfig): decode_model_labels=parse_comma_separated_args( config.decode_model_labels ), + fallback_models=( + parse_static_aliases(config.static_fallback_models) + if config.static_fallback_models + else None + ), ) elif config.service_discovery == "k8s": reconfigure_service_discovery( diff --git a/src/vllm_router/parsers/parser.py b/src/vllm_router/parsers/parser.py index d98c7eb13..7389f0f15 100644 --- a/src/vllm_router/parsers/parser.py +++ b/src/vllm_router/parsers/parser.py @@ -171,6 +171,13 @@ def parse_args(): default=None, help="The model labels of static backends, separated by commas. E.g., model1,model2", ) + parser.add_argument( + "--static-fallback-models", + type=str, + default=None, + help="Fallback model mappings, separated by commas. E.g., model1:fallback1,model2:fallback2. " + "When all backends for a model are unavailable, requests are retried on the fallback model.", + ) parser.add_argument( "--static-backend-health-checks", action="store_true", diff --git a/src/vllm_router/parsers/yaml_utils.py b/src/vllm_router/parsers/yaml_utils.py index 28aa01a06..9bb7f0891 100644 --- a/src/vllm_router/parsers/yaml_utils.py +++ b/src/vllm_router/parsers/yaml_utils.py @@ -37,6 +37,18 @@ def generate_static_model_types(models: dict[str, Any]) -> str: return ",".join(static_model_types) +def generate_static_fallback_models(models: dict[str, Any]) -> str | None: + """Generate comma-separated fallback model mappings. + + Format: model1:fallback1,model2:fallback2 + """ + fallback_models = [] + for name, details in models.items(): + if "fallback_model" in details: + fallback_models.append(f"{name}:{details['fallback_model']}") + return ",".join(fallback_models) if fallback_models else None + + def read_and_process_yaml_config_file(config_path: str) -> dict[str, Any]: with open(config_path, encoding="utf-8") as f: try: @@ -49,6 +61,9 @@ def read_and_process_yaml_config_file(config_path: str) -> dict[str, Any]: yaml_config["static_backends"] = generate_static_backends(models) yaml_config["static_models"] = generate_static_models(models) yaml_config["static_model_types"] = generate_static_model_types(models) + fallback_models = generate_static_fallback_models(models) + if fallback_models: + yaml_config["static_fallback_models"] = fallback_models if aliases: yaml_config["static_aliases"] = generate_static_aliases(aliases) return yaml_config diff --git a/src/vllm_router/service_discovery.py b/src/vllm_router/service_discovery.py index 0a51e2be4..7ad21860e 100644 --- a/src/vllm_router/service_discovery.py +++ b/src/vllm_router/service_discovery.py @@ -217,6 +217,7 @@ def __init__( static_backend_health_check_timeout_seconds: int = 10, prefill_model_labels: List[str] | None = None, decode_model_labels: List[str] | None = None, + fallback_models: Dict[str, str] | None = None, ): self.app = app assert len(urls) == len(models), "URLs and models should have the same length" @@ -225,6 +226,7 @@ def __init__( self.aliases = aliases self.model_labels = model_labels self.model_types = model_types + self.fallback_models = fallback_models or {} self.engines_id = [str(uuid.uuid4()) for i in range(0, len(urls))] self.added_timestamp = int(time.time()) self.unhealthy_endpoint_hashes = [] diff --git a/src/vllm_router/services/request_service/request.py b/src/vllm_router/services/request_service/request.py index badc98b03..7cd8544fb 100644 --- a/src/vllm_router/services/request_service/request.py +++ b/src/vllm_router/services/request_service/request.py @@ -98,6 +98,48 @@ } +async def _select_backend( + router, endpoints, engine_stats, request_stats, request, request_json=None +): + """Pick a backend URL via the router's routing logic.""" + if isinstance(router, (KvawareRouter, PrefixAwareRouter, SessionRouter)): + return await router.route_request( + endpoints, engine_stats, request_stats, request, request_json + ) + return router.route_request(endpoints, engine_stats, request_stats, request) + + +async def _send_and_parse( + request, + request_body, + server_url, + request_id, + endpoint, + background_tasks, + span_context, +): + """Send a proxied request and return (generator, headers_dict, media_type, status).""" + gen = process_request( + request, + request_body, + server_url, + request_id, + endpoint, + background_tasks, + parent_span_context=span_context, + ) + headers, status = await anext(gen) + media_type = headers.get("content-type", "text/event-stream") + headers_dict = { + k: v + for k, v in headers.items() + if k.lower() not in _HEADERS_TO_STRIP_FROM_RESPONSE + and k.lower() != "content-type" + } + headers_dict["X-Request-Id"] = request_id + return gen, headers_dict, media_type, status + + # TODO: (Brian) check if request is json beforehand async def process_request( request: Request, @@ -372,9 +414,11 @@ async def route_general_request( else: endpoints = list( filter( - lambda x: requested_model in x.model_names - and x.Id == request_endpoint - and not x.sleep, + lambda x: ( + requested_model in x.model_names + and x.Id == request_endpoint + and not x.sleep + ), endpoints, ) ) @@ -382,6 +426,31 @@ async def route_general_request( # Track all valid incoming requests num_incoming_requests_total.labels(model=requested_model).inc() + # --- Fallback model support --- + # If no endpoints are available for the requested model, check if a + # fallback model is configured and switch to it. + fallback_models = getattr(service_discovery, "fallback_models", None) + fallback_model = fallback_models.get(requested_model) if fallback_models else None + used_fallback = False + + if not endpoints and fallback_model and not request_endpoint: + logger.info( + f"No healthy endpoints for model '{requested_model}', " + f"falling back to '{fallback_model}'" + ) + all_endpoints = service_discovery.get_endpoint_info() + endpoints = list( + filter( + lambda x: fallback_model in x.model_names and not x.sleep, + all_endpoints, + ) + ) + if endpoints: + requested_model = fallback_model + request_body = replace_model_in_request_body(request_json, fallback_model) + update_content_length(request, request_body) + used_fallback = True + if not endpoints: if not model_ever_existed: end_span(span, status_code=404) if tracing_active else None @@ -409,16 +478,14 @@ async def route_general_request( logger.debug( f"Routing request {request_id} to engine with Id: {endpoints[0].Id}" ) - - elif isinstance( - request.app.state.router, (KvawareRouter, PrefixAwareRouter, SessionRouter) - ): - server_url = await request.app.state.router.route_request( - endpoints, engine_stats, request_stats, request, request_json - ) else: - server_url = request.app.state.router.route_request( - endpoints, engine_stats, request_stats, request + server_url = await _select_backend( + request.app.state.router, + endpoints, + engine_stats, + request_stats, + request, + request_json, ) curr_time = time.time() @@ -457,16 +524,14 @@ async def route_general_request( break if request_endpoint: server_url = remaining[0].url - elif isinstance( - request.app.state.router, - (KvawareRouter, PrefixAwareRouter, SessionRouter), - ): - server_url = await request.app.state.router.route_request( - remaining, engine_stats, request_stats, request, request_json - ) else: - server_url = request.app.state.router.route_request( - remaining, engine_stats, request_stats, request + server_url = await _select_backend( + request.app.state.router, + remaining, + engine_stats, + request_stats, + request, + request_json, ) logger.info( f"Routing request {request_id} to {server_url} " @@ -475,26 +540,16 @@ async def route_general_request( if span is not None: span.set_attribute("vllm.backend_url", server_url) - media_type = "text/event-stream" try: - stream_generator = process_request( + stream_generator, headers_dict, media_type, status = await _send_and_parse( request, request_body, server_url, request_id, endpoint, background_tasks, - parent_span_context=span_context, + span_context, ) - headers, status = await anext(stream_generator) - media_type = headers.get("content-type", "text/event-stream") - headers_dict = { - key: value - for key, value in headers.items() - if key.lower() not in _HEADERS_TO_STRIP_FROM_RESPONSE - and key.lower() != "content-type" - } - headers_dict["X-Request-Id"] = request_id last_error = None break except HTTPException: @@ -507,6 +562,65 @@ async def route_general_request( f"(attempt {attempt + 1}/{max_attempts}): {e}" ) + # All instance-level failover attempts failed. Try fallback model + # if configured and not already used. + can_fallback = ( + last_error and fallback_model and not used_fallback and not request_endpoint + ) + if can_fallback: + logger.info( + f"All backends for '{requested_model}' failed, " + f"falling back to '{fallback_model}'" + ) + all_endpoints = service_discovery.get_endpoint_info() + fallback_endpoints = list( + filter( + lambda x: fallback_model in x.model_names and not x.sleep, + all_endpoints, + ) + ) + if not fallback_endpoints: + can_fallback = False + + if can_fallback: + requested_model = fallback_model + request_body = replace_model_in_request_body(request_json, fallback_model) + update_content_length(request, request_body) + + server_url = await _select_backend( + request.app.state.router, + fallback_endpoints, + engine_stats, + request_stats, + request, + request_json, + ) + logger.info( + f"Routing request {request_id} to fallback {fallback_model} at {server_url}" + ) + if span is not None: + span.set_attribute("vllm.backend_url", server_url) + span.set_attribute("vllm.fallback_model", fallback_model) + + try: + stream_generator, headers_dict, media_type, status = await _send_and_parse( + request, + request_body, + server_url, + request_id, + endpoint, + background_tasks, + span_context, + ) + last_error = None + except HTTPException: + raise + except Exception as e: + logger.warning( + f"Fallback to {fallback_model} at {server_url} also failed: {e}" + ) + last_error = e + if last_error: end_span(span, error=last_error, status_code=500) if tracing_active else None raise last_error