diff --git a/ollama_proxy.go b/ollama_proxy.go index 1b958e7..6311410 100644 --- a/ollama_proxy.go +++ b/ollama_proxy.go @@ -41,6 +41,7 @@ type ModelDetails struct { type Model struct { Name string `json:"name"` + Model string `json:"model"` ModifiedAt string `json:"modified_at"` Size int64 `json:"size"` Digest string `json:"digest"` @@ -185,13 +186,57 @@ func main() { // Handler for /models. http.HandleFunc("/models", func(w http.ResponseWriter, r *http.Request) { log.Printf("Proxying /models request to %s", targetUrl.String()) - proxy.ServeHTTP(w, r) + http.Error(w, "404 page not found", http.StatusNotFound) + // proxy.ServeHTTP(w, r) }) - // Handler for /models. + // Handler for /v1/models with ":proxy" appended to model names http.HandleFunc("/v1/models", func(w http.ResponseWriter, r *http.Request) { - r.URL.Path = "/models" - log.Printf("Proxying /models request to %s", targetUrl.String()) - proxy.ServeHTTP(w, r) + log.Printf("Handling /v1/models request with :proxy appended to model names") + modelsURL := *targetUrl + modelsURL.Path = path.Join(targetUrl.Path, "models") + reqDown, err := http.NewRequest("GET", modelsURL.String(), nil) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if *openaiApiKey != "" { + reqDown.Header.Set("Authorization", "Bearer "+*openaiApiKey) + } + if *debug { + if dump, err := httputil.DumpRequestOut(reqDown, true); err == nil { + log.Printf("Outgoing /models request:\n%s", dump) + } else { + log.Printf("Error dumping /models request: %v", err) + } + } + client := &http.Client{} + respDown, err := client.Do(reqDown) + if err != nil { + http.Error(w, err.Error(), http.StatusBadGateway) + return + } + defer respDown.Body.Close() + if *debug { + if dump, err := httputil.DumpResponse(respDown, false); err == nil { + log.Printf("Received response from /models:\n%s", dump) + } else { + log.Printf("Error dumping /models response: %v", err) + } + } + var dsResp DownstreamModelsResponse + if err := json.NewDecoder(respDown.Body).Decode(&dsResp); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + for i, dm := range dsResp.Data { + if !strings.Contains(dm.ID, ":") { + dsResp.Data[i].ID = dm.ID + ":proxy" + } + } + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(dsResp); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } }) // Handler for /completions. @@ -267,6 +312,7 @@ func main() { modelEntry := Model{ Name: modelName, + Model: modelName, ModifiedAt: timeStr, Size: 3825819519, // Placeholder size Digest: digest,