Skip to content

Commit e172da1

Browse files
committed
fix: fixes for pr review comments
* remove unused import * fix FastAPI app route accumulation * remove duplicate error handler * add types for response_model Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com>
1 parent 907fb78 commit e172da1

2 files changed

Lines changed: 13 additions & 20 deletions

File tree

cli/serve/app.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import typer
1010
import uvicorn
11-
from fastapi import FastAPI, Request
11+
from fastapi import FastAPI
1212
from fastapi.responses import JSONResponse
1313

1414
from .models import (
@@ -79,13 +79,6 @@ async def endpoint(request: ChatCompletionRequest):
7979
],
8080
object="chat.completion", # type: ignore
8181
) # type: ignore
82-
except AttributeError as e:
83-
# Handle missing 'value' attribute or other attribute errors
84-
return create_openai_error_response(
85-
status_code=500,
86-
message=f"Internal server error: {e!s}",
87-
error_type="server_error",
88-
)
8982
except ValueError as e:
9083
# Handle validation errors or invalid input
9184
return create_openai_error_response(
@@ -94,7 +87,7 @@ async def endpoint(request: ChatCompletionRequest):
9487
error_type="invalid_request_error",
9588
)
9689
except Exception as e:
97-
# Catch-all for any other unexpected errors
90+
# Catch-all for any unexpected errors (including AttributeError)
9891
return create_openai_error_response(
9992
status_code=500,
10093
message=f"Internal server error: {e!s}",
@@ -121,7 +114,7 @@ def serve(
121114
route_path,
122115
make_chat_endpoint(module),
123116
methods=["POST"],
124-
response_model=None, # Allow both ChatCompletion and error responses
117+
response_model=ChatCompletion | OpenAIErrorResponse,
125118
)
126119
typer.echo(f"Serving {route_path} at http://{host}:{port}")
127120
uvicorn.run(app, host=host, port=port)

test/cli/test_serve_errors.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
from unittest.mock import Mock
44

55
import pytest
6+
from fastapi import FastAPI
67
from fastapi.testclient import TestClient
78

8-
from cli.serve.app import app, make_chat_endpoint
9+
from cli.serve.app import make_chat_endpoint
910
from cli.serve.models import ChatCompletionRequest, ChatMessage
1011

1112

@@ -61,11 +62,10 @@ def sample_request():
6162
@pytest.mark.unit
6263
def test_successful_completion(mock_module_success, sample_request):
6364
"""Test successful chat completion."""
65+
app = FastAPI()
6466
endpoint = make_chat_endpoint(mock_module_success)
65-
client = TestClient(app)
66-
67-
# Add the endpoint to the app
6867
app.add_api_route("/test/completions", endpoint, methods=["POST"])
68+
client = TestClient(app)
6969

7070
response = client.post("/test/completions", json=sample_request.model_dump())
7171

@@ -80,10 +80,10 @@ def test_successful_completion(mock_module_success, sample_request):
8080
@pytest.mark.unit
8181
def test_attribute_error_handling(mock_module_attribute_error, sample_request):
8282
"""Test handling of AttributeError (e.g., missing 'value' attribute)."""
83+
app = FastAPI()
8384
endpoint = make_chat_endpoint(mock_module_attribute_error)
84-
client = TestClient(app)
85-
8685
app.add_api_route("/test/attribute-error", endpoint, methods=["POST"])
86+
client = TestClient(app)
8787

8888
response = client.post("/test/attribute-error", json=sample_request.model_dump())
8989

@@ -97,10 +97,10 @@ def test_attribute_error_handling(mock_module_attribute_error, sample_request):
9797
@pytest.mark.unit
9898
def test_value_error_handling(mock_module_value_error, sample_request):
9999
"""Test handling of ValueError (validation errors)."""
100+
app = FastAPI()
100101
endpoint = make_chat_endpoint(mock_module_value_error)
101-
client = TestClient(app)
102-
103102
app.add_api_route("/test/value-error", endpoint, methods=["POST"])
103+
client = TestClient(app)
104104

105105
response = client.post("/test/value-error", json=sample_request.model_dump())
106106

@@ -115,10 +115,10 @@ def test_value_error_handling(mock_module_value_error, sample_request):
115115
@pytest.mark.unit
116116
def test_generic_error_handling(mock_module_generic_error, sample_request):
117117
"""Test handling of generic exceptions."""
118+
app = FastAPI()
118119
endpoint = make_chat_endpoint(mock_module_generic_error)
119-
client = TestClient(app)
120-
121120
app.add_api_route("/test/generic-error", endpoint, methods=["POST"])
121+
client = TestClient(app)
122122

123123
response = client.post("/test/generic-error", json=sample_request.model_dump())
124124

0 commit comments

Comments
 (0)