From 5d88d7d85ee2520eeb24fef2aa225efbbbbf56b3 Mon Sep 17 00:00:00 2001 From: Forge Date: Sun, 26 Apr 2026 10:43:38 -0700 Subject: [PATCH] fix codeql url and cursor json hardening Co-authored-by: Codex --- .../api/handlers/management/api_call.go | 6 +- .../api/handlers/management/api_call_guard.go | 50 +++++++++++ .../management/api_call_transport_test.go | 50 +++++++++++ .../api/handlers/management/api_call_url.go | 25 ++++-- .../api/handlers/management/http_transport.go | 14 +++ pkg/llmproxy/executor/cursor_executor.go | 61 +++++++------ pkg/llmproxy/executor/cursor_executor_test.go | 79 +++++++++++++++++ pkg/llmproxy/executor/cursor_json.go | 88 +++++++++++++++++++ 8 files changed, 330 insertions(+), 43 deletions(-) create mode 100644 pkg/llmproxy/api/handlers/management/api_call_guard.go create mode 100644 pkg/llmproxy/api/handlers/management/api_call_transport_test.go create mode 100644 pkg/llmproxy/executor/cursor_executor_test.go create mode 100644 pkg/llmproxy/executor/cursor_json.go diff --git a/pkg/llmproxy/api/handlers/management/api_call.go b/pkg/llmproxy/api/handlers/management/api_call.go index ce027b0469..f921378a96 100644 --- a/pkg/llmproxy/api/handlers/management/api_call.go +++ b/pkg/llmproxy/api/handlers/management/api_call.go @@ -201,11 +201,7 @@ func (h *Handler) APICall(c *gin.Context) { req.Host = hostOverride } - httpClient := &http.Client{ - Timeout: defaultAPICallTimeout, - } - httpClient.Transport = h.apiCallTransport(auth) - + httpClient := h.apiCallHTTPClient(auth) resp, errDo := httpClient.Do(req) if errDo != nil { log.WithError(errDo).Debug("management APICall request failed") diff --git a/pkg/llmproxy/api/handlers/management/api_call_guard.go b/pkg/llmproxy/api/handlers/management/api_call_guard.go new file mode 100644 index 0000000000..2139cca585 --- /dev/null +++ b/pkg/llmproxy/api/handlers/management/api_call_guard.go @@ -0,0 +1,50 @@ +package management + +import ( + "context" + "fmt" + "net" + "net/http" + "net/url" +) + +func guardedAPICallDialContext(ctx context.Context, network string, addr string) (net.Conn, error) { + host, port, errSplit := net.SplitHostPort(addr) + if errSplit != nil { + return nil, fmt.Errorf("invalid dial address: %w", errSplit) + } + resolved, errResolve := resolveAllowedAPICallHostIPs(host) + if errResolve != nil { + return nil, errResolve + } + if len(resolved) == 0 { + return nil, fmt.Errorf("target host resolution failed") + } + dialer := &net.Dialer{} + return dialer.DialContext(ctx, network, net.JoinHostPort(resolved[0].IP.String(), port)) +} + +type apiCallGuardedRoundTripper struct { + base http.RoundTripper +} + +func (t apiCallGuardedRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + if req == nil { + return nil, fmt.Errorf("request is nil") + } + if errValidate := validateAPICallRequestURL(req.URL); errValidate != nil { + return nil, errValidate + } + base := t.base + if base == nil { + base = http.DefaultTransport + } + return base.RoundTrip(req) +} + +func validateAPICallRequestURL(reqURL *url.URL) error { + if errValidate := validateAPICallURL(reqURL); errValidate != nil { + return errValidate + } + return validateResolvedHostIPs(reqURL.Hostname()) +} diff --git a/pkg/llmproxy/api/handlers/management/api_call_transport_test.go b/pkg/llmproxy/api/handlers/management/api_call_transport_test.go new file mode 100644 index 0000000000..0f6c20d007 --- /dev/null +++ b/pkg/llmproxy/api/handlers/management/api_call_transport_test.go @@ -0,0 +1,50 @@ +package management + +import ( + "errors" + "net/http" + "net/url" + "testing" +) + +type apiCallRoundTripFunc func(*http.Request) (*http.Response, error) + +func (f apiCallRoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func TestAPICallGuardedRoundTripperRejectsUnsafeRequestURL(t *testing.T) { + t.Parallel() + + called := false + transport := apiCallGuardedRoundTripper{ + base: apiCallRoundTripFunc(func(*http.Request) (*http.Response, error) { + called = true + return nil, errors.New("base transport should not run") + }), + } + + req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8317/ping", nil) + if err != nil { + t.Fatalf("new request: %v", err) + } + if _, err := transport.RoundTrip(req); err == nil { + t.Fatalf("RoundTrip error = nil, want unsafe target rejection") + } + if called { + t.Fatal("base transport ran for unsafe URL") + } +} + +func TestAPICallRequestURLValidationRejectsUnsafeRedirectURL(t *testing.T) { + t.Parallel() + + redirectURL, err := url.Parse("http://localhost:8317/ping") + if err != nil { + t.Fatalf("parse redirect url: %v", err) + } + req := &http.Request{URL: redirectURL} + if err := validateAPICallRequestURL(req.URL); err == nil { + t.Fatalf("validation error = nil, want unsafe redirect rejection") + } +} diff --git a/pkg/llmproxy/api/handlers/management/api_call_url.go b/pkg/llmproxy/api/handlers/management/api_call_url.go index 343cac5b3e..fcd85e461c 100644 --- a/pkg/llmproxy/api/handlers/management/api_call_url.go +++ b/pkg/llmproxy/api/handlers/management/api_call_url.go @@ -1,6 +1,7 @@ package management import ( + "context" "fmt" "net" "net/url" @@ -59,23 +60,33 @@ func sanitizeAPICallURL(raw string) (string, *url.URL, error) { } func validateResolvedHostIPs(host string) error { + _, err := resolveAllowedAPICallHostIPs(host) + return err +} + +func resolveAllowedAPICallHostIPs(host string) ([]net.IPAddr, error) { trimmed := strings.TrimSpace(host) if trimmed == "" { - return fmt.Errorf("invalid url host") + return nil, fmt.Errorf("invalid url host") } - resolved, errLookup := net.LookupIP(trimmed) + resolved, errLookup := net.DefaultResolver.LookupIPAddr(context.Background(), trimmed) if errLookup != nil { - return fmt.Errorf("target host resolution failed") + return nil, fmt.Errorf("target host resolution failed") } + allowed := make([]net.IPAddr, 0, len(resolved)) for _, ip := range resolved { - if ip == nil { + if ip.IP == nil { continue } - if ip.IsLoopback() || ip.IsPrivate() || ip.IsUnspecified() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { - return fmt.Errorf("target host is not allowed") + if ip.IP.IsLoopback() || ip.IP.IsPrivate() || ip.IP.IsUnspecified() || ip.IP.IsMulticast() || ip.IP.IsLinkLocalUnicast() || ip.IP.IsLinkLocalMulticast() { + return nil, fmt.Errorf("target host is not allowed") } + allowed = append(allowed, ip) } - return nil + if len(allowed) == 0 { + return nil, fmt.Errorf("target host resolution failed") + } + return allowed, nil } func isAllowedHostOverride(parsedURL *url.URL, override string) bool { diff --git a/pkg/llmproxy/api/handlers/management/http_transport.go b/pkg/llmproxy/api/handlers/management/http_transport.go index 92be850508..adcf81ba30 100644 --- a/pkg/llmproxy/api/handlers/management/http_transport.go +++ b/pkg/llmproxy/api/handlers/management/http_transport.go @@ -46,9 +46,23 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper { } clone := transport.Clone() clone.Proxy = nil + clone.DialContext = guardedAPICallDialContext return clone } +func (h *Handler) apiCallHTTPClient(auth *coreauth.Auth) *http.Client { + return &http.Client{ + Timeout: defaultAPICallTimeout, + Transport: apiCallGuardedRoundTripper{base: h.apiCallTransport(auth)}, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) >= 10 { + return fmt.Errorf("stopped after 10 redirects") + } + return validateAPICallRequestURL(req.URL) + }, + } +} + func buildProxyTransportWithError(proxyStr string) (*http.Transport, error) { proxyStr = strings.TrimSpace(proxyStr) if proxyStr == "" { diff --git a/pkg/llmproxy/executor/cursor_executor.go b/pkg/llmproxy/executor/cursor_executor.go index 286de93fa0..71b516d9e2 100644 --- a/pkg/llmproxy/executor/cursor_executor.go +++ b/pkg/llmproxy/executor/cursor_executor.go @@ -333,11 +333,13 @@ func (e *CursorExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r id := "chatcmpl-" + uuid.New().String()[:28] created := time.Now().Unix() - openaiResp := fmt.Sprintf(`{"id":"%s","object":"chat.completion","created":%d,"model":"%s","choices":[{"index":0,"message":{"role":"assistant","content":%s},"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`, - id, created, parsed.Model, jsonString(fullText.String())) + openaiResp, errMarshal := cursorCompletionJSON(id, created, parsed.Model, fullText.String()) + if errMarshal != nil { + return resp, fmt.Errorf("cursor: failed to encode response: %w", errMarshal) + } // Translate response back to source format if needed - result := []byte(openaiResp) + result := openaiResp if from.String() != "" && from.String() != "openai" { var param any result = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), payload, result, ¶m) @@ -536,13 +538,13 @@ func (e *CursorExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A // Wrap sendChunk/sendDone to use emitToOut sendChunkSwitchable := func(delta string, finishReason string) { - fr := "null" - if finishReason != "" { - fr = finishReason + openaiJSON, errMarshal := cursorChunkJSON(chatId, created, parsed.Model, json.RawMessage(delta), finishReason) + if errMarshal != nil { + log.Warnf("cursor: failed to encode stream chunk: %v", errMarshal) + return } - openaiJSON := fmt.Sprintf(`{"id":"%s","object":"chat.completion.chunk","created":%d,"model":"%s","choices":[{"index":0,"delta":%s,"finish_reason":%s}]}`, - chatId, created, parsed.Model, delta, fr) - sseLine := []byte("data: " + openaiJSON + "\n") + sseLine := append([]byte("data: "), openaiJSON...) + sseLine = append(sseLine, '\n') if needsTranslate { translated := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, payload, sseLine, &streamParam) @@ -550,7 +552,7 @@ func (e *CursorExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A emitToOut(cliproxyexecutor.StreamChunk{Payload: bytes.Clone(t)}) } } else { - emitToOut(cliproxyexecutor.StreamChunk{Payload: []byte(openaiJSON)}) + emitToOut(cliproxyexecutor.StreamChunk{Payload: openaiJSON}) } } @@ -595,13 +597,13 @@ func (e *CursorExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A thinkingActive = true sendChunkSwitchable(`{"role":"assistant","content":""}`, "") } - sendChunkSwitchable(fmt.Sprintf(`{"content":%s}`, jsonString(text)), "") + sendChunkSwitchable(cursorContentDeltaJSON(text), "") } else { if thinkingActive { thinkingActive = false sendChunkSwitchable(`{"content":""}`, "") } - sendChunkSwitchable(fmt.Sprintf(`{"content":%s}`, jsonString(text)), "") + sendChunkSwitchable(cursorContentDeltaJSON(text), "") } }, func(exec pendingMcpExec) { @@ -609,11 +611,10 @@ func (e *CursorExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A thinkingActive = false sendChunkSwitchable(`{"content":""}`, "") } - toolCallJSON := fmt.Sprintf(`{"tool_calls":[{"index":%d,"id":"%s","type":"function","function":{"name":"%s","arguments":%s}}]}`, - toolCallIndex, exec.ToolCallId, exec.ToolName, jsonString(exec.Args)) + toolCallJSON := cursorToolCallDeltaJSON(toolCallIndex, exec.ToolCallId, exec.ToolName, exec.Args) toolCallIndex++ sendChunkSwitchable(toolCallJSON, "") - sendChunkSwitchable(`{}`, `"tool_calls"`) + sendChunkSwitchable(`{}`, "tool_calls") sendDoneSwitchable() // Close current output to end the current HTTP SSE response @@ -701,23 +702,22 @@ func (e *CursorExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A } // Include token usage in the final stop chunk inputTok, outputTok := usage.get() - stopDelta := fmt.Sprintf(`{},"usage":{"prompt_tokens":%d,"completion_tokens":%d,"total_tokens":%d}`, - inputTok, outputTok, inputTok+outputTok) - // Build the stop chunk with usage embedded in the choices array level - fr := `"stop"` - openaiJSON := fmt.Sprintf(`{"id":"%s","object":"chat.completion.chunk","created":%d,"model":"%s","choices":[{"index":0,"delta":{},"finish_reason":%s}],"usage":{"prompt_tokens":%d,"completion_tokens":%d,"total_tokens":%d}}`, - chatId, created, parsed.Model, fr, inputTok, outputTok, inputTok+outputTok) - sseLine := []byte("data: " + openaiJSON + "\n") + openaiJSON, errMarshal := cursorUsageChunkJSON(chatId, created, parsed.Model, inputTok, outputTok) + if errMarshal != nil { + log.Warnf("cursor: failed to encode usage chunk: %v", errMarshal) + openaiJSON = []byte(`{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`) + } + sseLine := append([]byte("data: "), openaiJSON...) + sseLine = append(sseLine, '\n') if needsTranslate { translated := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, payload, sseLine, &streamParam) for _, t := range translated { emitToOut(cliproxyexecutor.StreamChunk{Payload: bytes.Clone(t)}) } } else { - emitToOut(cliproxyexecutor.StreamChunk{Payload: []byte(openaiJSON)}) + emitToOut(cliproxyexecutor.StreamChunk{Payload: openaiJSON}) } sendDoneSwitchable() - _ = stopDelta // unused // Close whatever output channel is still active outMu.Lock() @@ -1436,16 +1436,15 @@ func deriveSessionKey(clientKey string, model string, messages []gjson.Result) s } func sseChunk(id string, created int64, model string, delta string, finishReason string) cliproxyexecutor.StreamChunk { - fr := "null" - if finishReason != "" { - fr = finishReason - } // Note: the framework's WriteChunk adds "data: " prefix and "\n\n" suffix, // so we only output the raw JSON here. - data := fmt.Sprintf(`{"id":"%s","object":"chat.completion.chunk","created":%d,"model":"%s","choices":[{"index":0,"delta":%s,"finish_reason":%s}]}`, - id, created, model, delta, fr) + data, err := cursorChunkJSON(id, created, model, json.RawMessage(delta), finishReason) + if err != nil { + log.Warnf("cursor: failed to encode sse chunk: %v", err) + data = []byte(`{"choices":[{"index":0,"delta":{},"finish_reason":null}]}`) + } return cliproxyexecutor.StreamChunk{ - Payload: []byte(data), + Payload: data, } } diff --git a/pkg/llmproxy/executor/cursor_executor_test.go b/pkg/llmproxy/executor/cursor_executor_test.go new file mode 100644 index 0000000000..d9ae7b8072 --- /dev/null +++ b/pkg/llmproxy/executor/cursor_executor_test.go @@ -0,0 +1,79 @@ +package executor + +import ( + "encoding/json" + "testing" +) + +func TestCursorCompletionJSONEscapesModelAndContent(t *testing.T) { + t.Parallel() + + payload, err := cursorCompletionJSON("chatcmpl-test", 1700000000, `x","pwned":true,"y":"`, `hi "there"`) + if err != nil { + t.Fatalf("cursorCompletionJSON: %v", err) + } + + var got map[string]any + if err := json.Unmarshal(payload, &got); err != nil { + t.Fatalf("unmarshal payload: %v; payload=%s", err, payload) + } + if got["model"] != `x","pwned":true,"y":"` { + t.Fatalf("model = %q", got["model"]) + } + if _, ok := got["pwned"]; ok { + t.Fatalf("payload allowed model to inject top-level field: %s", payload) + } +} + +func TestCursorChunkJSONEscapesModelAndFinishReason(t *testing.T) { + t.Parallel() + + payload, err := cursorChunkJSON( + "chatcmpl-test", + 1700000000, + `x","pwned":true,"y":"`, + json.RawMessage(`{"content":"ok"}`), + `stop","pwned":true,"x":"`, + ) + if err != nil { + t.Fatalf("cursorChunkJSON: %v", err) + } + + var got struct { + Model string `json:"model"` + Pwned bool `json:"pwned"` + Choices []struct { + FinishReason string `json:"finish_reason"` + } `json:"choices"` + } + if err := json.Unmarshal(payload, &got); err != nil { + t.Fatalf("unmarshal payload: %v; payload=%s", err, payload) + } + if got.Model != `x","pwned":true,"y":"` { + t.Fatalf("model = %q", got.Model) + } + if got.Pwned { + t.Fatalf("payload allowed model to inject top-level field: %s", payload) + } + if got.Choices[0].FinishReason != `stop","pwned":true,"x":"` { + t.Fatalf("finish_reason = %q", got.Choices[0].FinishReason) + } +} + +func TestCursorToolCallDeltaJSONEscapesToolIdentifiers(t *testing.T) { + t.Parallel() + + payload := cursorToolCallDeltaJSON( + 0, + `call_1","pwned":true,"x":"`, + `tool","pwned":true,"x":"`, + `{"ok":true}`, + ) + var got map[string]any + if err := json.Unmarshal([]byte(payload), &got); err != nil { + t.Fatalf("unmarshal payload: %v; payload=%s", err, payload) + } + if _, ok := got["pwned"]; ok { + t.Fatalf("payload allowed tool metadata to inject top-level field: %s", payload) + } +} diff --git a/pkg/llmproxy/executor/cursor_json.go b/pkg/llmproxy/executor/cursor_json.go new file mode 100644 index 0000000000..6d7c0ac2fb --- /dev/null +++ b/pkg/llmproxy/executor/cursor_json.go @@ -0,0 +1,88 @@ +package executor + +import "encoding/json" + +func cursorCompletionJSON(id string, created int64, model string, content string) ([]byte, error) { + return json.Marshal(map[string]any{ + "id": id, + "object": "chat.completion", + "created": created, + "model": model, + "choices": []map[string]any{{ + "index": 0, + "message": map[string]any{ + "role": "assistant", + "content": content, + }, + "finish_reason": "stop", + }}, + "usage": map[string]int{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + }) +} + +func cursorChunkJSON(id string, created int64, model string, delta json.RawMessage, finishReason string) ([]byte, error) { + var parsedDelta any + if err := json.Unmarshal(delta, &parsedDelta); err != nil { + return nil, err + } + var finish any + if finishReason != "" { + finish = finishReason + } + return json.Marshal(map[string]any{ + "id": id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": []map[string]any{{ + "index": 0, + "delta": parsedDelta, + "finish_reason": finish, + }}, + }) +} + +func cursorUsageChunkJSON(id string, created int64, model string, inputTokens int64, outputTokens int64) ([]byte, error) { + return json.Marshal(map[string]any{ + "id": id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": []map[string]any{{ + "index": 0, + "delta": map[string]any{}, + "finish_reason": "stop", + }}, + "usage": map[string]int64{ + "prompt_tokens": inputTokens, + "completion_tokens": outputTokens, + "total_tokens": inputTokens + outputTokens, + }, + }) +} + +func cursorContentDeltaJSON(content string) string { + return string(mustCursorMarshal(map[string]any{"content": content})) +} + +func cursorToolCallDeltaJSON(index int, id string, name string, args string) string { + return string(mustCursorMarshal(map[string]any{ + "tool_calls": []map[string]any{{ + "index": index, + "id": id, + "type": "function", + "function": map[string]any{ + "name": name, + "arguments": args, + }, + }}, + })) +} + +func mustCursorMarshal(v any) []byte { + b, err := json.Marshal(v) + if err != nil { + return []byte("{}") + } + return b +}