diff --git a/server/cloud_proxy.go b/server/cloud_proxy.go index d09ca182e..170c350a6 100644 --- a/server/cloud_proxy.go +++ b/server/cloud_proxy.go @@ -348,10 +348,10 @@ func requiresCloudAnthropicChatFallback(path string, body []byte) bool { return false } - return hasAnthropicWebSearchTool(body) || hasAnthropicToolResultImage(body) + return hasAnthropicWebSearchTool(body) || hasAnthropicToolResultBase64Image(body) } -func hasAnthropicToolResultImage(body []byte) bool { +func hasAnthropicToolResultBase64Image(body []byte) bool { if len(body) == 0 { return false } @@ -378,7 +378,7 @@ func hasAnthropicToolResultImage(body []byte) bool { if strings.TrimSpace(block.Type) != "tool_result" { continue } - if anthropicToolResultContentHasImage(block.Content) { + if anthropicToolResultContentHasBase64Image(block.Content) { return true } } @@ -387,26 +387,32 @@ func hasAnthropicToolResultImage(body []byte) bool { return false } -func anthropicToolResultContentHasImage(raw json.RawMessage) bool { +func anthropicToolResultContentHasBase64Image(raw json.RawMessage) bool { if len(raw) == 0 || bytes.Equal(bytes.TrimSpace(raw), []byte("null")) { return false } var blocks []struct { - Type string `json:"type"` + Type string `json:"type"` + Source *struct { + Type string `json:"type"` + } `json:"source"` } if err := json.Unmarshal(raw, &blocks); err == nil { for _, block := range blocks { - if strings.TrimSpace(block.Type) == "image" { + if strings.TrimSpace(block.Type) == "image" && block.Source != nil && strings.TrimSpace(block.Source.Type) == "base64" { return true } } } var block struct { - Type string `json:"type"` + Type string `json:"type"` + Source *struct { + Type string `json:"type"` + } `json:"source"` } - if err := json.Unmarshal(raw, &block); err == nil && strings.TrimSpace(block.Type) == "image" { + if err := json.Unmarshal(raw, &block); err == nil && strings.TrimSpace(block.Type) == "image" && block.Source != nil && strings.TrimSpace(block.Source.Type) == "base64" { return true } diff --git a/server/routes_cloud_test.go b/server/routes_cloud_test.go index f0a23ad50..8664d1dde 100644 --- a/server/routes_cloud_test.go +++ b/server/routes_cloud_test.go @@ -863,6 +863,72 @@ func TestExplicitCloudPassthroughAPIAndV1(t *testing.T) { } }) + t.Run("v1 messages tool_result url image bypasses conversion", func(t *testing.T) { + upstream, capture := newUpstream(t, `{"id":"msg_1","type":"message"}`) + defer upstream.Close() + + original := cloudProxyBaseURL + cloudProxyBaseURL = upstream.URL + t.Cleanup(func() { cloudProxyBaseURL = original }) + + s := &Server{} + router, err := s.GenerateRoutes(nil) + if err != nil { + t.Fatal(err) + } + local := httptest.NewServer(router) + defer local.Close() + + reqBody := `{ + "model":"gpt-oss:120b-cloud", + "max_tokens":10, + "messages":[{ + "role":"user", + "content":[{ + "type":"tool_result", + "tool_use_id":"call_456", + "content":[ + {"type":"text","text":"Here is the screenshot:"}, + {"type":"image","source":{"type":"url","url":"https://example.com/image.png"}} + ] + }] + }], + "stream":false + }` + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/v1/messages?beta=true", bytes.NewBufferString(reqBody)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := local.Client().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(body)) + } + + if capture.path != "/v1/messages" { + t.Fatalf("expected upstream path /v1/messages for url image passthrough, got %q", capture.path) + } + + if !strings.Contains(capture.body, `"type":"tool_result"`) { + t.Fatalf("expected original anthropic request body, got %q", capture.body) + } + + if !strings.Contains(capture.body, `"type":"url"`) { + t.Fatalf("expected url image source in upstream body, got %q", capture.body) + } + + if strings.Contains(capture.body, `"num_predict":10`) { + t.Fatalf("expected no converted ollama options in upstream body, got %q", capture.body) + } + }) + t.Run("v1 model retrieve bypasses conversion", func(t *testing.T) { upstream, capture := newUpstream(t, `{"id":"kimi-k2.5:cloud","object":"model","created":1,"owned_by":"ollama"}`) defer upstream.Close() @@ -1248,6 +1314,74 @@ func TestCloudPassthroughSkipsAnthropicToolResultImages(t *testing.T) { } } +func TestCloudPassthroughDoesNotSkipAnthropicToolResultURLImages(t *testing.T) { + gin.SetMode(gin.TestMode) + setTestHome(t, t.TempDir()) + + type upstreamCapture struct { + path string + } + capture := &upstreamCapture{} + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capture.path = r.URL.Path + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"id":"msg_1","type":"message"}`)) + })) + defer upstream.Close() + + original := cloudProxyBaseURL + cloudProxyBaseURL = upstream.URL + t.Cleanup(func() { cloudProxyBaseURL = original }) + + router := gin.New() + router.POST( + "/v1/messages", + cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), + middleware.AnthropicMessagesMiddleware(), + func(c *gin.Context) { c.Status(http.StatusTeapot) }, + ) + + local := httptest.NewServer(router) + defer local.Close() + + reqBody := `{ + "model":"kimi-k2.5:cloud", + "max_tokens":10, + "messages":[{ + "role":"user", + "content":[{ + "type":"tool_result", + "tool_use_id":"call_456", + "content":[ + {"type":"text","text":"Here is the screenshot:"}, + {"type":"image","source":{"type":"url","url":"https://example.com/image.png"}} + ] + }] + }] + }` + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/v1/messages", bytes.NewBufferString(reqBody)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := local.Client().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected passthrough response status 200, got %d (%s)", resp.StatusCode, string(body)) + } + + if capture.path != "/v1/messages" { + t.Fatalf("expected passthrough to upstream /v1/messages for url images, got %q", capture.path) + } +} + func TestCloudPassthroughSigningFailureReturnsUnauthorized(t *testing.T) { gin.SetMode(gin.TestMode) setTestHome(t, t.TempDir())