From b20ef4a18ca5cbd2d3711501685ffb56b1d45c1a Mon Sep 17 00:00:00 2001 From: Regis David Souza Mesquita Date: Sat, 15 Feb 2025 23:43:32 +0000 Subject: [PATCH] Allows paths on the provider url and time formats can differ between providers so we just use current time instead of parsing downstream --- README.md | 2 +- ollama_proxy.go | 79 +++++++++++++++++++++++++++++++------------------ 2 files changed, 51 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index 1bfb2a0..d8924c2 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ go build -o proxy-server ollama_proxy.go Run the proxy server with the desired flags: ``` -./proxy-server --listen=":11434" --target="http://127.0.0.1:4000" --api-key="YOUR_API_KEY" --debug +./proxy-server --listen=":11434" --target="http://127.0.0.1:4000/v1" --api-key="YOUR_API_KEY" --debug ``` ## Command-Line Flags diff --git a/ollama_proxy.go b/ollama_proxy.go index d6697c4..b6f4f46 100644 --- a/ollama_proxy.go +++ b/ollama_proxy.go @@ -10,6 +10,7 @@ import ( "net/http" "net/http/httputil" "net/url" + "path" "strings" "time" ) @@ -18,11 +19,9 @@ import ( // Data Structures // -------------------- -// Structures used for /api/tags transformation. type DownstreamModel struct { ID string `json:"id"` Object string `json:"object"` - Created int64 `json:"created"` OwnedBy string `json:"owned_by"` } @@ -158,7 +157,7 @@ func main() { log.Fatalf("Error parsing target URL: %v", err) } - // Create a reverse proxy for /v1/models and /v1/completions. + // Create a reverse proxy for /models and /completions. proxy := httputil.NewSingleHostReverseProxy(targetUrl) originalDirector := proxy.Director proxy.Director = func(req *http.Request) { @@ -169,24 +168,23 @@ func main() { } } - // Handler for /v1/models. - http.HandleFunc("/v1/models", func(w http.ResponseWriter, r *http.Request) { - log.Printf("Proxying /v1/models request to %s", targetUrl.String()) + // Handler for /models. 
+ http.HandleFunc("/models", func(w http.ResponseWriter, r *http.Request) { + log.Printf("Proxying /models request to %s", targetUrl.String()) proxy.ServeHTTP(w, r) }) - // Handler for /v1/completions. - http.HandleFunc("/v1/completions", func(w http.ResponseWriter, r *http.Request) { - log.Printf("Proxying /v1/completions request to %s", targetUrl.String()) + // Handler for /completions. + http.HandleFunc("/completions", func(w http.ResponseWriter, r *http.Request) { + log.Printf("Proxying /completions request to %s", targetUrl.String()) proxy.ServeHTTP(w, r) }) // Handler for /api/tags. - // When building the list, if a model's ID does not contain a colon, - // append ":proxy" to it. http.HandleFunc("/api/tags", func(w http.ResponseWriter, r *http.Request) { - log.Printf("Handling /api/tags request by querying downstream /v1/models") - modelsURL := targetUrl.ResolveReference(&url.URL{Path: "/v1/models"}) + log.Printf("Handling /api/tags request by querying downstream /models") + modelsURL := *targetUrl + modelsURL.Path = path.Join(targetUrl.Path, "models") reqDown, err := http.NewRequest("GET", modelsURL.String(), nil) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -195,6 +193,13 @@ func main() { if *openaiApiKey != "" { reqDown.Header.Set("Authorization", "Bearer "+*openaiApiKey) } + if *debug { + if dump, err := httputil.DumpRequestOut(reqDown, true); err == nil { + log.Printf("Outgoing /models request:\n%s", dump) + } else { + log.Printf("Error dumping /models request: %v", err) + } + } client := &http.Client{} respDown, err := client.Do(reqDown) if err != nil { @@ -202,30 +207,28 @@ func main() { return } defer respDown.Body.Close() - - if respDown.StatusCode != http.StatusOK { - body, _ := io.ReadAll(respDown.Body) - http.Error(w, string(body), respDown.StatusCode) - return + if *debug { + if dump, err := httputil.DumpResponse(respDown, false); err == nil { + log.Printf("Received response from /models:\n%s", dump) + } else { + 
log.Printf("Error dumping /models response: %v", err) + } } - var dsResp DownstreamModelsResponse if err := json.NewDecoder(respDown.Body).Decode(&dsResp); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } - var tagsResp TagsResponse for _, dm := range dsResp.Data { modelName := dm.ID - // Append ":proxy" if there is no colon in the model name. if !strings.Contains(modelName, ":") { modelName += ":proxy" } modelEntry := Model{ Name: modelName, Model: modelName, - ModifiedAt: time.Unix(dm.Created, 0).UTC().Format(time.RFC3339Nano), + ModifiedAt: time.Now().UTC().Format(time.RFC3339Nano), Size: 0, Digest: "", Details: ModelDetails{ @@ -239,7 +242,6 @@ func main() { } tagsResp.Models = append(tagsResp.Models, modelEntry) } - w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(tagsResp); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -252,7 +254,7 @@ func main() { }) // Explicit handler for /api/chat. - // This handler rewrites the URL to /v1/chat/completions, logs the outgoing payload, + // This handler rewrites the URL to /chat/completions, logs the outgoing payload, // strips any trailing ":proxy" from the model name in the request payload, // intercepts the downstream streaming response, transforms each chunk from OpenAI format // to Ollama format (stripping any ":proxy" from the model field), logs both the raw and transformed @@ -289,8 +291,9 @@ func main() { log.Printf("Warning: could not unmarshal payload for transformation: %v", err) } - // Create a new request to the downstream /v1/chat/completions endpoint. - newURL := targetUrl.ResolveReference(&url.URL{Path: "/v1/chat/completions"}) + // Create a new request to the downstream chat/completions endpoint, joined onto the target URL's own path (e.g. /v1/chat/completions).
+ newURL := *targetUrl + newURL.Path = path.Join(targetUrl.Path, "chat/completions") newReq, err := http.NewRequest("POST", newURL.String(), bytes.NewReader(bodyBytes)) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -301,14 +304,32 @@ func main() { newReq.Header.Set("Authorization", "Bearer "+*openaiApiKey) } + // Log the full outgoing /api/chat request. + if *debug { + if dump, err := httputil.DumpRequestOut(newReq, true); err == nil { + log.Printf("Outgoing /api/chat request:\n%s", dump) + } else { + log.Printf("Error dumping /api/chat request: %v", err) + } + } + client := &http.Client{} resp, err := client.Do(newReq) if err != nil { http.Error(w, err.Error(), http.StatusBadGateway) return } - defer resp.Body.Close() + // Log the response headers (without draining the body). + if *debug { + if dump, err := httputil.DumpResponse(resp, false); err == nil { + log.Printf("Received response from /chat/completions:\n%s", dump) + } else { + log.Printf("Error dumping /chat/completions response: %v", err) + } + } + + defer resp.Body.Close() // Copy response headers. 
for key, values := range resp.Header { for _, value := range values { @@ -358,7 +379,7 @@ func main() { modelName := strings.TrimSuffix(chunk.Model, ":proxy") transformed := OllamaChunk{ Model: modelName, - CreatedAt: time.Unix(chunk.Created, 0).UTC().Format(time.RFC3339Nano), + CreatedAt: time.Now().Format(time.RFC3339), Message: Message{ Role: role, Content: content, @@ -397,7 +418,7 @@ func main() { } }) - log.Printf("Proxy server listening on %s\n- /v1/models & /v1/completions forwarded to %s\n- /api/tags dynamically transformed\n- /api/pull returns 404\n- /api/chat rewritten and transformed before forwarding to downstream (/v1/chat/completions)\n- Unknown endpoints will%s be forwarded to 127.0.0.1:11505", + log.Printf("Proxy server listening on %s\n- /models & /completions forwarded to %s\n- /api/tags dynamically transformed\n- /api/pull returns 404\n- /api/chat rewritten and transformed before forwarding to downstream (/chat/completions)\n- Unknown endpoints will%s be forwarded to 127.0.0.1:11505", *listenAddr, targetUrl.String(), func() string { if *forwardUnknown { return ""