Skip to content
Merged
109 changes: 99 additions & 10 deletions cli/serve/app.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
"""A simple app that runs an OpenAI compatible server wrapped around a M program."""

import asyncio
import importlib.util
import inspect
import os
import sys
import time
import uuid

import typer
import uvicorn
from fastapi import FastAPI
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse

from mellea.backends.model_options import ModelOption

from .models import (
ChatCompletion,
ChatCompletionMessage,
Expand All @@ -28,6 +33,34 @@
)


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(
    request: Request, exc: RequestValidationError
) -> JSONResponse:
    """Convert FastAPI validation errors to OpenAI-compatible format.

    FastAPI returns 422 with a 'detail' array by default. OpenAI API uses
    400 with an 'error' object containing message, type, and param fields.

    Args:
        request: The incoming HTTP request (unused; required by FastAPI's
            exception-handler signature).
        exc: The validation error raised while parsing the request.

    Returns:
        A JSONResponse with status 400 and an OpenAI-style 'error' object.
    """
    # Extract the first validation error; the OpenAI format reports a single
    # message/param pair, so any additional errors are intentionally dropped.
    errors = exc.errors()
    if errors:
        first_error = errors[0]
        # Get the field name from the location tuple (e.g., ('body', 'n') -> 'n')
        param = first_error["loc"][-1] if first_error["loc"] else None
        message = first_error["msg"]
    else:
        param = None
        message = "Invalid request parameters"

    return create_openai_error_response(
        status_code=400,
        message=message,
        error_type="invalid_request_error",
        # Explicit None check: a location element can legitimately be falsy,
        # e.g. the list index 0 in ('body', 'messages', 0); `if param` would
        # incorrectly report no param at all for such errors.
        param=str(param) if param is not None else None,
    )


def load_module_from_path(path: str):
"""Load the module with M program in it."""
module_name = os.path.splitext(os.path.basename(path))[0]
Expand All @@ -50,23 +83,79 @@ def create_openai_error_response(
)


def _build_model_options(request: ChatCompletionRequest) -> dict:
    """Translate an OpenAI-style request into a Mellea model_options dict.

    Fields that are not generation parameters (or are not yet supported)
    are dropped, and the surviving OpenAI parameter names are rewritten to
    their ``ModelOption`` equivalents via ``ModelOption.replace_keys``.
    """
    # Request fields that must never reach the backend as generation options:
    # structural fields handled separately (messages, requirements), routing
    # metadata (model, n, user, extra), and OpenAI parameters Mellea does not
    # implement yet (stream, stop, top_p, presence_penalty, frequency_penalty,
    # logit_bias, response_format, functions, function_call, tools,
    # tool_choice) — these are silently ignored.
    dropped = frozenset(
        {
            "messages",
            "requirements",
            "model",
            "n",
            "user",
            "extra",
            "stream",
            "stop",
            "top_p",
            "presence_penalty",
            "frequency_penalty",
            "logit_bias",
            "response_format",
            "functions",
            "function_call",
            "tools",
            "tool_choice",
        }
    )
    # OpenAI parameter name -> Mellea ModelOption key.
    renames = {
        "temperature": ModelOption.TEMPERATURE,
        "max_tokens": ModelOption.MAX_NEW_TOKENS,
        "seed": ModelOption.SEED,
    }

    kept = {}
    for field, value in request.model_dump(exclude_none=True).items():
        if field not in dropped:
            kept[field] = value
    return ModelOption.replace_keys(kept, renames)


def make_chat_endpoint(module):
"""Makes a chat endpoint using a custom module."""

async def endpoint(request: ChatCompletionRequest):
try:
# Validate that n=1 (we don't support multiple completions)
if request.n is not None and request.n > 1:
return create_openai_error_response(
status_code=400,
message=f"Multiple completions (n={request.n}) are not supported. Please set n=1 or omit the parameter.",
error_type="invalid_request_error",
param="n",
)

completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}"
created_timestamp = int(time.time())

output = module.serve(
input=request.messages,
requirements=request.requirements,
model_options={
k: v
for k, v in request.model_dump().items()
if k not in ["messages", "requirements"]
},
)
model_options = _build_model_options(request)

# Detect if serve is async or sync and handle accordingly
if inspect.iscoroutinefunction(module.serve):
# It's async, await it directly
output = await module.serve(
input=request.messages,
requirements=request.requirements,
model_options=model_options,
)
else:
# It's sync, run in thread pool to avoid blocking event loop
output = await asyncio.to_thread(
module.serve,
input=request.messages,
requirements=request.requirements,
model_options=model_options,
)

# Extract usage information from the ModelOutputThunk if available
usage = None
Expand Down
4 changes: 4 additions & 0 deletions mellea/backends/model_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ def replace_keys(options: dict, from_to: dict[str, str]) -> dict[str, Any]:
# This will usually be a @@@<>@@@ ModelOption.<> key.
new_key = from_to.get(old_key, None)
if new_key:
# Skip if old_key and new_key are the same (no-op replacement)
if old_key == new_key:
continue

if new_options.get(new_key, None) is not None:
# The key already has a value associated with it in the dict. Leave it be.
conflict_log.append(
Expand Down
108 changes: 108 additions & 0 deletions test/cli/test_build_model_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Unit tests for _build_model_options function."""

import pytest

from cli.serve.app import _build_model_options
from cli.serve.models import ChatCompletionRequest, ChatMessage
from mellea.backends.model_options import ModelOption


class TestBuildModelOptions:
    """Direct unit tests for _build_model_options."""

    @staticmethod
    def _request(**overrides):
        """Build a minimal ChatCompletionRequest, merged with *overrides*."""
        fields = {
            "model": "test-model",
            "messages": [ChatMessage(role="user", content="test")],
        }
        fields.update(overrides)
        return ChatCompletionRequest(**fields)

    def test_temperature_mapping(self):
        """temperature is rewritten to ModelOption.TEMPERATURE."""
        opts = _build_model_options(self._request(temperature=0.7))
        assert opts[ModelOption.TEMPERATURE] == 0.7

    def test_max_tokens_mapping(self):
        """max_tokens is rewritten to ModelOption.MAX_NEW_TOKENS."""
        opts = _build_model_options(self._request(max_tokens=100))
        assert opts[ModelOption.MAX_NEW_TOKENS] == 100

    def test_seed_mapping(self):
        """seed is rewritten to ModelOption.SEED."""
        opts = _build_model_options(self._request(seed=42))
        assert opts[ModelOption.SEED] == 42

    def test_multiple_options(self):
        """Several generation parameters are translated together."""
        opts = _build_model_options(
            self._request(temperature=0.8, max_tokens=200, seed=123)
        )
        assert opts[ModelOption.TEMPERATURE] == 0.8
        assert opts[ModelOption.MAX_NEW_TOKENS] == 200
        assert opts[ModelOption.SEED] == 123

    def test_excluded_fields_not_in_output(self):
        """Routing/structural fields never leak into model_options."""
        opts = _build_model_options(
            self._request(n=1, user="test-user", stream=False, temperature=0.5)
        )
        # None of the excluded fields may survive the translation.
        for excluded in ("model", "messages", "n", "user", "stream"):
            assert excluded not in opts
        # The genuine generation parameter must still be present.
        assert ModelOption.TEMPERATURE in opts

    def test_none_values_excluded(self):
        """Parameters explicitly set to None are dropped from the output."""
        opts = _build_model_options(
            self._request(temperature=None, max_tokens=None)
        )
        assert ModelOption.TEMPERATURE not in opts
        assert ModelOption.MAX_NEW_TOKENS not in opts

    def test_minimal_request_includes_defaults(self):
        """A bare request still carries the default temperature of 1.0."""
        opts = _build_model_options(self._request())
        # ChatCompletionRequest has default temperature=1.0
        assert opts == {ModelOption.TEMPERATURE: 1.0}

    def test_requirements_excluded(self):
        """requirements are passed separately, never as a model option."""
        opts = _build_model_options(
            self._request(requirements=["req1", "req2"], temperature=0.7)
        )
        assert "requirements" not in opts
        assert ModelOption.TEMPERATURE in opts
Loading
Loading