Dont pull dont forwards unless told so.

This commit is contained in:
Regis David Souza Mesquita 2025-02-15 11:30:25 +00:00
parent 92f7ff7058
commit 73d03c95e9

View file

@ -149,6 +149,7 @@ func main() {
targetUrlStr := flag.String("target", "http://127.0.0.1:4000", "Target OpenAI-compatible server URL")
openaiApiKey := flag.String("api-key", "", "OpenAI API key (optional)")
debug := flag.Bool("debug", false, "Print debug logs for every call")
forwardUnknown := flag.Bool("forward-unknown", false, "Forward unknown endpoints to local Ollama instance at 127.0.0.1:11505")
flag.Parse()
// Parse the target URL.
@ -162,7 +163,7 @@ func main() {
originalDirector := proxy.Director
proxy.Director = func(req *http.Request) {
originalDirector(req)
// For downstream endpoints, also strip any ":proxy" from the model field in the URL query or headers if needed.
// For downstream endpoints, also set the API key if provided.
if *openaiApiKey != "" {
req.Header.Set("Authorization", "Bearer "+*openaiApiKey)
}
@ -225,7 +226,7 @@ func main() {
Name: modelName,
Model: modelName,
ModifiedAt: time.Unix(dm.Created, 0).UTC().Format(time.RFC3339Nano),
Size: 1337,
Size: 0,
Digest: "",
Details: ModelDetails{
ParentModel: "",
@ -245,10 +246,9 @@ func main() {
}
})
// Explicit handler for /api/pull: forward to the local Ollama instance.
// Explicit handler for /api/pull: return 404 instead of forwarding.
http.HandleFunc("/api/pull", func(w http.ResponseWriter, r *http.Request) {
log.Println("Handling /api/pull")
forwardToOllama(w, r)
http.Error(w, "Endpoint /api/pull is not supported", http.StatusNotFound)
})
// Explicit handler for /api/chat.
@ -283,7 +283,6 @@ func main() {
return
}
} else {
// If unmarshalling fails, continue with the original bytes.
log.Printf("Warning: could not unmarshal payload for transformation: %v", err)
}
@ -385,10 +384,20 @@ func main() {
// Catch-all handler for any other unknown endpoints.
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
forwardToOllama(w, r)
if *forwardUnknown {
forwardToOllama(w, r)
} else {
http.NotFound(w, r)
}
})
log.Printf("Proxy server listening on %s\n- /v1/models & /v1/completions forwarded to %s\n- /api/tags dynamically transformed\n- /api/pull explicitly forwarded to 127.0.0.1:11505\n- /api/chat rewritten and transformed before forwarding to downstream (/v1/chat/completions)\n- Other unknown endpoints are also forwarded", *listenAddr, targetUrl.String())
log.Printf("Proxy server listening on %s\n- /v1/models & /v1/completions forwarded to %s\n- /api/tags dynamically transformed\n- /api/pull returns 404\n- /api/chat rewritten and transformed before forwarding to downstream (/v1/chat/completions)\n- Unknown endpoints will%s be forwarded to 127.0.0.1:11505",
*listenAddr, targetUrl.String(), func() string {
if *forwardUnknown {
return ""
}
return " NOT"
}())
if err := http.ListenAndServe(*listenAddr, logMiddleware(*debug, http.DefaultServeMux)); err != nil {
log.Fatalf("Server failed: %v", err)
}