Allow a path component on the provider URL. Timestamp formats can differ between providers, so we use the current time instead of parsing the downstream value.

This commit is contained in:
Regis David Souza Mesquita 2025-02-15 23:43:32 +00:00
parent 28797ae45c
commit b20ef4a18c
2 changed files with 51 additions and 30 deletions

View file

@ -48,7 +48,7 @@ go build -o proxy-server ollama_proxy.go
Run the proxy server with the desired flags:
```
./proxy-server --listen=":11434" --target="http://127.0.0.1:4000" --api-key="YOUR_API_KEY" --debug
./proxy-server --listen=":11434" --target="http://127.0.0.1:4000/v1" --api-key="YOUR_API_KEY" --debug
```
## Command-Line Flags

View file

@ -10,6 +10,7 @@ import (
"net/http"
"net/http/httputil"
"net/url"
"path"
"strings"
"time"
)
@ -18,11 +19,9 @@ import (
// Data Structures
// --------------------
// Structures used for /api/tags transformation.
// DownstreamModel is one model entry decoded from the JSON returned by the
// downstream models endpoint.
// NOTE(review): presumably the element type of DownstreamModelsResponse.Data — confirm.
type DownstreamModel struct {
// ID is the model identifier; the /api/tags handler uses it as the model name.
ID string `json:"id"`
Object string `json:"object"`
// Created is decoded as an integer.
// NOTE(review): presumably a Unix timestamp in seconds — confirm that every provider sends a JSON number here.
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
}
@ -158,7 +157,7 @@ func main() {
log.Fatalf("Error parsing target URL: %v", err)
}
// Create a reverse proxy for /v1/models and /v1/completions.
// Create a reverse proxy for /models and /completions.
proxy := httputil.NewSingleHostReverseProxy(targetUrl)
originalDirector := proxy.Director
proxy.Director = func(req *http.Request) {
@ -169,24 +168,23 @@ func main() {
}
}
// Handler for /v1/models.
http.HandleFunc("/v1/models", func(w http.ResponseWriter, r *http.Request) {
log.Printf("Proxying /v1/models request to %s", targetUrl.String())
// Handler for /models.
http.HandleFunc("/models", func(w http.ResponseWriter, r *http.Request) {
log.Printf("Proxying /models request to %s", targetUrl.String())
proxy.ServeHTTP(w, r)
})
// Handler for /v1/completions.
http.HandleFunc("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
log.Printf("Proxying /v1/completions request to %s", targetUrl.String())
// Handler for /completions.
http.HandleFunc("/completions", func(w http.ResponseWriter, r *http.Request) {
log.Printf("Proxying /completions request to %s", targetUrl.String())
proxy.ServeHTTP(w, r)
})
// Handler for /api/tags.
// When building the list, if a model's ID does not contain a colon,
// append ":proxy" to it.
http.HandleFunc("/api/tags", func(w http.ResponseWriter, r *http.Request) {
log.Printf("Handling /api/tags request by querying downstream /v1/models")
modelsURL := targetUrl.ResolveReference(&url.URL{Path: "/v1/models"})
log.Printf("Handling /api/tags request by querying downstream /models")
modelsURL := *targetUrl
modelsURL.Path = path.Join(targetUrl.Path, "models")
reqDown, err := http.NewRequest("GET", modelsURL.String(), nil)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@ -195,6 +193,13 @@ func main() {
if *openaiApiKey != "" {
reqDown.Header.Set("Authorization", "Bearer "+*openaiApiKey)
}
if *debug {
if dump, err := httputil.DumpRequestOut(reqDown, true); err == nil {
log.Printf("Outgoing /models request:\n%s", dump)
} else {
log.Printf("Error dumping /models request: %v", err)
}
}
client := &http.Client{}
respDown, err := client.Do(reqDown)
if err != nil {
@ -202,30 +207,28 @@ func main() {
return
}
defer respDown.Body.Close()
if respDown.StatusCode != http.StatusOK {
body, _ := io.ReadAll(respDown.Body)
http.Error(w, string(body), respDown.StatusCode)
return
if *debug {
if dump, err := httputil.DumpResponse(respDown, false); err == nil {
log.Printf("Received response from /models:\n%s", dump)
} else {
log.Printf("Error dumping /models response: %v", err)
}
}
var dsResp DownstreamModelsResponse
if err := json.NewDecoder(respDown.Body).Decode(&dsResp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
var tagsResp TagsResponse
for _, dm := range dsResp.Data {
modelName := dm.ID
// Append ":proxy" if there is no colon in the model name.
if !strings.Contains(modelName, ":") {
modelName += ":proxy"
}
modelEntry := Model{
Name: modelName,
Model: modelName,
ModifiedAt: time.Unix(dm.Created, 0).UTC().Format(time.RFC3339Nano),
ModifiedAt: time.Now().UTC().Format(time.RFC3339Nano),
Size: 0,
Digest: "",
Details: ModelDetails{
@ -239,7 +242,6 @@ func main() {
}
tagsResp.Models = append(tagsResp.Models, modelEntry)
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(tagsResp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@ -252,7 +254,7 @@ func main() {
})
// Explicit handler for /api/chat.
// This handler rewrites the URL to /v1/chat/completions, logs the outgoing payload,
// This handler rewrites the URL to /chat/completions, logs the outgoing payload,
// strips any trailing ":proxy" from the model name in the request payload,
// intercepts the downstream streaming response, transforms each chunk from OpenAI format
// to Ollama format (stripping any ":proxy" from the model field), logs both the raw and transformed
@ -289,8 +291,9 @@ func main() {
log.Printf("Warning: could not unmarshal payload for transformation: %v", err)
}
// Create a new request to the downstream /v1/chat/completions endpoint.
newURL := targetUrl.ResolveReference(&url.URL{Path: "/v1/chat/completions"})
// Create a new request to the downstream chat/completions endpoint, joined onto the target URL's own path (e.g. a /v1 target yields /v1/chat/completions).
newURL := *targetUrl
newURL.Path = path.Join(targetUrl.Path, "chat/completions")
newReq, err := http.NewRequest("POST", newURL.String(), bytes.NewReader(bodyBytes))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@ -301,14 +304,32 @@ func main() {
newReq.Header.Set("Authorization", "Bearer "+*openaiApiKey)
}
// Log the full outgoing /api/chat request.
if *debug {
if dump, err := httputil.DumpRequestOut(newReq, true); err == nil {
log.Printf("Outgoing /api/chat request:\n%s", dump)
} else {
log.Printf("Error dumping /api/chat request: %v", err)
}
}
client := &http.Client{}
resp, err := client.Do(newReq)
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
defer resp.Body.Close()
// Log the response headers (without draining the body).
if *debug {
if dump, err := httputil.DumpResponse(resp, false); err == nil {
log.Printf("Received response from /chat/completions:\n%s", dump)
} else {
log.Printf("Error dumping /chat/completions response: %v", err)
}
}
defer resp.Body.Close()
// Copy response headers.
for key, values := range resp.Header {
for _, value := range values {
@ -358,7 +379,7 @@ func main() {
modelName := strings.TrimSuffix(chunk.Model, ":proxy")
transformed := OllamaChunk{
Model: modelName,
CreatedAt: time.Unix(chunk.Created, 0).UTC().Format(time.RFC3339Nano),
CreatedAt: time.Now().Format(time.RFC3339),
Message: Message{
Role: role,
Content: content,
@ -397,7 +418,7 @@ func main() {
}
})
log.Printf("Proxy server listening on %s\n- /v1/models & /v1/completions forwarded to %s\n- /api/tags dynamically transformed\n- /api/pull returns 404\n- /api/chat rewritten and transformed before forwarding to downstream (/v1/chat/completions)\n- Unknown endpoints will%s be forwarded to 127.0.0.1:11505",
log.Printf("Proxy server listening on %s\n- /models & /completions forwarded to %s\n- /api/tags dynamically transformed\n- /api/pull returns 404\n- /api/chat rewritten and transformed before forwarding to downstream (/chat/completions)\n- Unknown endpoints will%s be forwarded to 127.0.0.1:11505",
*listenAddr, targetUrl.String(), func() string {
if *forwardUnknown {
return ""