Skip to content

Commit 907fb78

Browse files
committed
feat: add OpenAI-compatible error handling to m serve
Add proper exception handling to the chat completion endpoint in cli/serve/app.py to prevent unhandled exceptions from crashing the server. Implements OpenAI API error format for the `m serve` endpoint to ensure compatibility with OpenAI client libraries and tools. Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com>
1 parent d971776 commit 907fb78

3 files changed

Lines changed: 233 additions & 32 deletions

File tree

cli/serve/app.py

Lines changed: 71 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,17 @@
88

99
import typer
1010
import uvicorn
11-
from fastapi import FastAPI
11+
from fastapi import FastAPI, Request
12+
from fastapi.responses import JSONResponse
1213

13-
from .models import ChatCompletion, ChatCompletionMessage, ChatCompletionRequest, Choice
14+
from .models import (
15+
ChatCompletion,
16+
ChatCompletionMessage,
17+
ChatCompletionRequest,
18+
Choice,
19+
OpenAIError,
20+
OpenAIErrorResponse,
21+
)
1422

1523
app = FastAPI(
1624
title="M serve OpenAI API Compatible Server",
@@ -29,35 +37,69 @@ def load_module_from_path(path: str):
2937
return module
3038

3139

40+
def create_openai_error_response(
41+
status_code: int, message: str, error_type: str, param: str | None = None
42+
) -> JSONResponse:
43+
"""Create an OpenAI-compatible error response."""
44+
error_response = OpenAIErrorResponse(
45+
error=OpenAIError(message=message, type=error_type, param=param)
46+
)
47+
return JSONResponse(
48+
status_code=status_code, content=error_response.model_dump(mode="json")
49+
)
50+
51+
3252
def make_chat_endpoint(module):
3353
"""Makes a chat endpoint using a custom module."""
3454

35-
async def endpoint(request: ChatCompletionRequest) -> ChatCompletion:
36-
completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}"
37-
created_timestamp = int(time.time())
38-
39-
output = module.serve(
40-
input=request.messages,
41-
requirements=request.requirements,
42-
model_options={
43-
k: v for k, v in request if k not in ["messages", "requirements"]
44-
},
45-
)
46-
47-
return ChatCompletion(
48-
id=completion_id,
49-
model=request.model,
50-
created=created_timestamp,
51-
choices=[
52-
Choice(
53-
index=0,
54-
message=ChatCompletionMessage(
55-
content=output.value, role="assistant"
56-
),
57-
)
58-
],
59-
object="chat.completion", # type: ignore
60-
) # type: ignore
55+
async def endpoint(request: ChatCompletionRequest):
56+
try:
57+
completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}"
58+
created_timestamp = int(time.time())
59+
60+
output = module.serve(
61+
input=request.messages,
62+
requirements=request.requirements,
63+
model_options={
64+
k: v for k, v in request if k not in ["messages", "requirements"]
65+
},
66+
)
67+
68+
return ChatCompletion(
69+
id=completion_id,
70+
model=request.model,
71+
created=created_timestamp,
72+
choices=[
73+
Choice(
74+
index=0,
75+
message=ChatCompletionMessage(
76+
content=output.value, role="assistant"
77+
),
78+
)
79+
],
80+
object="chat.completion", # type: ignore
81+
) # type: ignore
82+
except AttributeError as e:
83+
# Handle missing 'value' attribute or other attribute errors
84+
return create_openai_error_response(
85+
status_code=500,
86+
message=f"Internal server error: {e!s}",
87+
error_type="server_error",
88+
)
89+
except ValueError as e:
90+
# Handle validation errors or invalid input
91+
return create_openai_error_response(
92+
status_code=400,
93+
message=f"Invalid request: {e!s}",
94+
error_type="invalid_request_error",
95+
)
96+
except Exception as e:
97+
# Catch-all for any other unexpected errors
98+
return create_openai_error_response(
99+
status_code=500,
100+
message=f"Internal server error: {e!s}",
101+
error_type="server_error",
102+
)
61103

62104
endpoint.__name__ = f"chat_{module.__name__}_endpoint"
63105
return endpoint
@@ -79,7 +121,7 @@ def serve(
79121
route_path,
80122
make_chat_endpoint(module),
81123
methods=["POST"],
82-
response_model=ChatCompletion,
124+
response_model=None, # Allow both ChatCompletion and error responses
83125
)
84126
typer.echo(f"Serving {route_path} at http://{host}:{port}")
85127
uvicorn.run(app, host=host, port=port)

cli/serve/models.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ class LogitBias(BaseModel):
3636

3737

3838
class ChatCompletionRequest(BaseModel):
39+
model_config = {"extra": "allow"}
40+
3941
model: str
4042
messages: list[ChatMessage]
4143
requirements: list[str | None] | None = Field(default_factory=list)
@@ -59,9 +61,6 @@ class ChatCompletionRequest(BaseModel):
5961
# For future/undocumented fields
6062
extra: dict[str, Any] = Field(default_factory=dict)
6163

62-
class Config:
63-
extra = "allow"
64-
6564

6665
# Taking this from OpenAI types https://github.com/openai/openai-python/blob/main/src/openai/types/chat/chat_completion.py,
6766
class ChatCompletionMessage(BaseModel):
@@ -101,3 +100,26 @@ class ChatCompletion(BaseModel):
101100

102101
object: Literal["chat.completion"]
103102
"""The object type, which is always `chat.completion`."""
103+
104+
105+
class OpenAIError(BaseModel):
106+
"""OpenAI API error object."""
107+
108+
message: str
109+
"""A human-readable error message."""
110+
111+
type: str
112+
"""The type of error (e.g., 'invalid_request_error', 'server_error')."""
113+
114+
param: str | None = None
115+
"""The parameter that caused the error, if applicable."""
116+
117+
code: str | None = None
118+
"""An error code, if applicable."""
119+
120+
121+
class OpenAIErrorResponse(BaseModel):
122+
"""OpenAI API error response wrapper."""
123+
124+
error: OpenAIError
125+
"""The error object."""

test/cli/test_serve_errors.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
"""Tests for the OpenAI-compatible serve endpoint."""
2+
3+
from unittest.mock import Mock
4+
5+
import pytest
6+
from fastapi.testclient import TestClient
7+
8+
from cli.serve.app import app, make_chat_endpoint
9+
from cli.serve.models import ChatCompletionRequest, ChatMessage
10+
11+
12+
@pytest.fixture
13+
def mock_module_success():
14+
"""Create a mock module that returns a successful response."""
15+
module = Mock()
16+
module.__name__ = "test_module"
17+
output = Mock()
18+
output.value = "Test response"
19+
module.serve = Mock(return_value=output)
20+
return module
21+
22+
23+
@pytest.fixture
24+
def mock_module_attribute_error():
25+
"""Create a mock module that raises AttributeError."""
26+
module = Mock()
27+
module.__name__ = "test_module"
28+
output = Mock(spec=[]) # No 'value' attribute
29+
module.serve = Mock(return_value=output)
30+
return module
31+
32+
33+
@pytest.fixture
34+
def mock_module_value_error():
35+
"""Create a mock module that raises ValueError."""
36+
module = Mock()
37+
module.__name__ = "test_module"
38+
module.serve = Mock(side_effect=ValueError("Invalid input"))
39+
return module
40+
41+
42+
@pytest.fixture
43+
def mock_module_generic_error():
44+
"""Create a mock module that raises a generic exception."""
45+
module = Mock()
46+
module.__name__ = "test_module"
47+
module.serve = Mock(side_effect=RuntimeError("Unexpected error"))
48+
return module
49+
50+
51+
@pytest.fixture
52+
def sample_request():
53+
"""Create a sample chat completion request."""
54+
return ChatCompletionRequest(
55+
model="test-model",
56+
messages=[ChatMessage(role="user", content="Hello")],
57+
requirements=None,
58+
)
59+
60+
61+
@pytest.mark.unit
62+
def test_successful_completion(mock_module_success, sample_request):
63+
"""Test successful chat completion."""
64+
endpoint = make_chat_endpoint(mock_module_success)
65+
client = TestClient(app)
66+
67+
# Add the endpoint to the app
68+
app.add_api_route("/test/completions", endpoint, methods=["POST"])
69+
70+
response = client.post("/test/completions", json=sample_request.model_dump())
71+
72+
assert response.status_code == 200
73+
data = response.json()
74+
assert data["choices"][0]["message"]["content"] == "Test response"
75+
assert data["model"] == "test-model"
76+
assert "id" in data
77+
assert data["object"] == "chat.completion"
78+
79+
80+
@pytest.mark.unit
81+
def test_attribute_error_handling(mock_module_attribute_error, sample_request):
82+
"""Test handling of AttributeError (e.g., missing 'value' attribute)."""
83+
endpoint = make_chat_endpoint(mock_module_attribute_error)
84+
client = TestClient(app)
85+
86+
app.add_api_route("/test/attribute-error", endpoint, methods=["POST"])
87+
88+
response = client.post("/test/attribute-error", json=sample_request.model_dump())
89+
90+
assert response.status_code == 500
91+
data = response.json()
92+
assert "error" in data
93+
assert data["error"]["type"] == "server_error"
94+
assert "Internal server error" in data["error"]["message"]
95+
96+
97+
@pytest.mark.unit
98+
def test_value_error_handling(mock_module_value_error, sample_request):
99+
"""Test handling of ValueError (validation errors)."""
100+
endpoint = make_chat_endpoint(mock_module_value_error)
101+
client = TestClient(app)
102+
103+
app.add_api_route("/test/value-error", endpoint, methods=["POST"])
104+
105+
response = client.post("/test/value-error", json=sample_request.model_dump())
106+
107+
assert response.status_code == 400
108+
data = response.json()
109+
assert "error" in data
110+
assert data["error"]["type"] == "invalid_request_error"
111+
assert "Invalid request" in data["error"]["message"]
112+
assert "Invalid input" in data["error"]["message"]
113+
114+
115+
@pytest.mark.unit
116+
def test_generic_error_handling(mock_module_generic_error, sample_request):
117+
"""Test handling of generic exceptions."""
118+
endpoint = make_chat_endpoint(mock_module_generic_error)
119+
client = TestClient(app)
120+
121+
app.add_api_route("/test/generic-error", endpoint, methods=["POST"])
122+
123+
response = client.post("/test/generic-error", json=sample_request.model_dump())
124+
125+
assert response.status_code == 500
126+
data = response.json()
127+
assert "error" in data
128+
assert data["error"]["type"] == "server_error"
129+
assert "Internal server error" in data["error"]["message"]
130+
assert "Unexpected error" in data["error"]["message"]
131+
132+
133+
@pytest.mark.unit
134+
def test_endpoint_name_generation(mock_module_success):
135+
"""Test that endpoint names are generated correctly."""
136+
endpoint = make_chat_endpoint(mock_module_success)
137+
assert endpoint.__name__ == "chat_test_module_endpoint"

0 commit comments

Comments
 (0)