Allow paths on the provider URL; time formats can differ between providers, so use the current time instead of parsing the downstream timestamp
parent 28797ae45c
commit b20ef4a18c
2 changed files with 51 additions and 30 deletions
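The URL handling change in a nutshell: the proxy used to build downstream URLs with ResolveReference against a hardcoded absolute path, which discards any base path on the provider URL; it now copies the target URL and joins the endpoint onto its existing path. A minimal sketch of the difference (not part of the patch; the example target and its /v2/ai base path are assumptions for illustration):

```go
package main

import (
	"fmt"
	"net/url"
	"path"
)

func main() {
	// Hypothetical provider URL that carries a base path (the /v2/ai example
	// comes from a code comment in the diff below).
	target, _ := url.Parse("http://127.0.0.1:4000/v2/ai")

	// Old approach: ResolveReference with an absolute path replaces the whole
	// path, so the provider's base path is lost and "/v1" is hardcoded.
	old := target.ResolveReference(&url.URL{Path: "/v1/models"})
	fmt.Println(old.String()) // http://127.0.0.1:4000/v1/models

	// New approach: copy the URL and join the endpoint onto whatever base
	// path the target already has.
	joined := *target
	joined.Path = path.Join(target.Path, "models")
	fmt.Println(joined.String()) // http://127.0.0.1:4000/v2/ai/models
}
```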
@@ -48,7 +48,7 @@ go build -o proxy-server ollama_proxy.go
 Run the proxy server with the desired flags:
 
 ```
-./proxy-server --listen=":11434" --target="http://127.0.0.1:4000" --api-key="YOUR_API_KEY" --debug
+./proxy-server --listen=":11434" --target="http://127.0.0.1:4000/v1" --api-key="YOUR_API_KEY" --debug
 ```
 
 ## Command-Line Flags
@@ -10,6 +10,7 @@ import (
 	"net/http"
 	"net/http/httputil"
 	"net/url"
+	"path"
 	"strings"
 	"time"
 )
@@ -18,11 +19,9 @@ import (
 // Data Structures
 // --------------------
 
 // Structures used for /api/tags transformation.
 type DownstreamModel struct {
 	ID      string `json:"id"`
 	Object  string `json:"object"`
 	Created int64  `json:"created"`
 	OwnedBy string `json:"owned_by"`
 }
 
@@ -158,7 +157,7 @@ func main() {
 		log.Fatalf("Error parsing target URL: %v", err)
 	}
 
-	// Create a reverse proxy for /v1/models and /v1/completions.
+	// Create a reverse proxy for /models and /completions.
 	proxy := httputil.NewSingleHostReverseProxy(targetUrl)
 	originalDirector := proxy.Director
 	proxy.Director = func(req *http.Request) {
@@ -169,24 +168,23 @@ func main() {
 		}
 	}
 
-	// Handler for /v1/models.
-	http.HandleFunc("/v1/models", func(w http.ResponseWriter, r *http.Request) {
-		log.Printf("Proxying /v1/models request to %s", targetUrl.String())
+	// Handler for /models.
+	http.HandleFunc("/models", func(w http.ResponseWriter, r *http.Request) {
+		log.Printf("Proxying /models request to %s", targetUrl.String())
 		proxy.ServeHTTP(w, r)
 	})
 
-	// Handler for /v1/completions.
-	http.HandleFunc("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
-		log.Printf("Proxying /v1/completions request to %s", targetUrl.String())
+	// Handler for /completions.
+	http.HandleFunc("/completions", func(w http.ResponseWriter, r *http.Request) {
+		log.Printf("Proxying /completions request to %s", targetUrl.String())
 		proxy.ServeHTTP(w, r)
 	})
 
 	// Handler for /api/tags.
 	// When building the list, if a model's ID does not contain a colon,
 	// append ":proxy" to it.
 	http.HandleFunc("/api/tags", func(w http.ResponseWriter, r *http.Request) {
-		log.Printf("Handling /api/tags request by querying downstream /v1/models")
-		modelsURL := targetUrl.ResolveReference(&url.URL{Path: "/v1/models"})
+		log.Printf("Handling /api/tags request by querying downstream /models")
+		modelsURL := *targetUrl
+		modelsURL.Path = path.Join(targetUrl.Path, "models")
 		reqDown, err := http.NewRequest("GET", modelsURL.String(), nil)
 		if err != nil {
 			http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -195,6 +193,13 @@ func main() {
 		if *openaiApiKey != "" {
 			reqDown.Header.Set("Authorization", "Bearer "+*openaiApiKey)
 		}
+		if *debug {
+			if dump, err := httputil.DumpRequestOut(reqDown, true); err == nil {
+				log.Printf("Outgoing /models request:\n%s", dump)
+			} else {
+				log.Printf("Error dumping /models request: %v", err)
+			}
+		}
 		client := &http.Client{}
 		respDown, err := client.Do(reqDown)
 		if err != nil {
@@ -202,30 +207,28 @@ func main() {
 			return
 		}
 		defer respDown.Body.Close()
 
-		if respDown.StatusCode != http.StatusOK {
-			body, _ := io.ReadAll(respDown.Body)
-			http.Error(w, string(body), respDown.StatusCode)
-			return
+		if *debug {
+			if dump, err := httputil.DumpResponse(respDown, false); err == nil {
+				log.Printf("Received response from /models:\n%s", dump)
+			} else {
+				log.Printf("Error dumping /models response: %v", err)
+			}
 		}
 
 		var dsResp DownstreamModelsResponse
 		if err := json.NewDecoder(respDown.Body).Decode(&dsResp); err != nil {
 			http.Error(w, err.Error(), http.StatusInternalServerError)
 			return
 		}
 
 		var tagsResp TagsResponse
 		for _, dm := range dsResp.Data {
 			modelName := dm.ID
 			// Append ":proxy" if there is no colon in the model name.
 			if !strings.Contains(modelName, ":") {
 				modelName += ":proxy"
 			}
 			modelEntry := Model{
 				Name:       modelName,
 				Model:      modelName,
-				ModifiedAt: time.Unix(dm.Created, 0).UTC().Format(time.RFC3339Nano),
+				ModifiedAt: time.Now().UTC().Format(time.RFC3339Nano),
 				Size:       0,
 				Digest:     "",
 				Details: ModelDetails{
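The other half of the commit message shows up in this hunk: /api/tags entries are now stamped with the current time rather than a timestamp parsed from the provider's `created` field. A rough, self-contained sketch of what one tags entry ends up looking like (the Model and TagsResponse definitions below are simplified stand-ins with assumed JSON tags, and the model name is just an example):

```go
package main

import (
	"encoding/json"
	"fmt"
	"time"
)

// Simplified stand-ins for the proxy's Model/TagsResponse types; only the
// fields visible in the diff are modeled here, so treat the shapes as assumptions.
type Model struct {
	Name       string `json:"name"`
	Model      string `json:"model"`
	ModifiedAt string `json:"modified_at"`
	Size       int64  `json:"size"`
	Digest     string `json:"digest"`
}

type TagsResponse struct {
	Models []Model `json:"models"`
}

func main() {
	name := "gpt-4o:proxy" // example downstream ID "gpt-4o" with ":proxy" appended
	entry := Model{
		Name:  name,
		Model: name,
		// Providers report "created" in different formats, so the proxy now
		// stamps entries with the current time instead of parsing it.
		ModifiedAt: time.Now().UTC().Format(time.RFC3339Nano),
	}
	out, _ := json.MarshalIndent(TagsResponse{Models: []Model{entry}}, "", "  ")
	fmt.Println(string(out))
}
```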
@@ -239,7 +242,6 @@ func main() {
 			}
 			tagsResp.Models = append(tagsResp.Models, modelEntry)
 		}
 
 		w.Header().Set("Content-Type", "application/json")
 		if err := json.NewEncoder(w).Encode(tagsResp); err != nil {
 			http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -252,7 +254,7 @@ func main() {
 	})
 
 	// Explicit handler for /api/chat.
-	// This handler rewrites the URL to /v1/chat/completions, logs the outgoing payload,
+	// This handler rewrites the URL to /chat/completions, logs the outgoing payload,
 	// strips any trailing ":proxy" from the model name in the request payload,
 	// intercepts the downstream streaming response, transforms each chunk from OpenAI format
 	// to Ollama format (stripping any ":proxy" from the model field), logs both the raw and transformed
@@ -289,8 +291,9 @@ func main() {
 			log.Printf("Warning: could not unmarshal payload for transformation: %v", err)
 		}
 
-		// Create a new request to the downstream /v1/chat/completions endpoint.
-		newURL := targetUrl.ResolveReference(&url.URL{Path: "/v1/chat/completions"})
+		// Create a new request with joined path: /v2/ai/chat/completions.
+		newURL := *targetUrl
+		newURL.Path = path.Join(targetUrl.Path, "chat/completions")
 		newReq, err := http.NewRequest("POST", newURL.String(), bytes.NewReader(bodyBytes))
 		if err != nil {
 			http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -301,14 +304,32 @@ func main() {
 			newReq.Header.Set("Authorization", "Bearer "+*openaiApiKey)
 		}
 
+		// Log the full outgoing /api/chat request.
+		if *debug {
+			if dump, err := httputil.DumpRequestOut(newReq, true); err == nil {
+				log.Printf("Outgoing /api/chat request:\n%s", dump)
+			} else {
+				log.Printf("Error dumping /api/chat request: %v", err)
+			}
+		}
+
 		client := &http.Client{}
 		resp, err := client.Do(newReq)
 		if err != nil {
 			http.Error(w, err.Error(), http.StatusBadGateway)
 			return
 		}
+		defer resp.Body.Close()
 
+		// Log the response headers (without draining the body).
+		if *debug {
+			if dump, err := httputil.DumpResponse(resp, false); err == nil {
+				log.Printf("Received response from /chat/completions:\n%s", dump)
+			} else {
+				log.Printf("Error dumping /chat/completions response: %v", err)
+			}
+		}
+
-		defer resp.Body.Close()
 		// Copy response headers.
 		for key, values := range resp.Header {
 			for _, value := range values {
@@ -358,7 +379,7 @@ func main() {
 			modelName := strings.TrimSuffix(chunk.Model, ":proxy")
 			transformed := OllamaChunk{
 				Model: modelName,
-				CreatedAt: time.Unix(chunk.Created, 0).UTC().Format(time.RFC3339Nano),
+				CreatedAt: time.Now().Format(time.RFC3339),
 				Message: Message{
 					Role:    role,
 					Content: content,
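Same idea on the streaming side: each OpenAI-format chunk is rewritten into an Ollama-format chunk, with the ":proxy" suffix stripped from the model name and created_at set to the current time. A self-contained sketch of that per-chunk translation (the chunk and message types here are simplified assumptions; the proxy's real OllamaChunk and Message types are defined elsewhere in ollama_proxy.go and may differ):

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
	"time"
)

// Assumed shape of an OpenAI streaming chunk, limited to the fields used here.
type openAIChunk struct {
	Model   string `json:"model"`
	Created int64  `json:"created"`
	Choices []struct {
		Delta struct {
			Role    string `json:"role"`
			Content string `json:"content"`
		} `json:"delta"`
	} `json:"choices"`
}

// Assumed shape of the Ollama-style chat chunk the proxy emits.
type ollamaChunk struct {
	Model     string  `json:"model"`
	CreatedAt string  `json:"created_at"`
	Message   message `json:"message"`
	Done      bool    `json:"done"`
}

type message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

func main() {
	// Example chunk; the model name is hypothetical.
	raw := []byte(`{"model":"gpt-4o:proxy","created":1,"choices":[{"delta":{"role":"assistant","content":"Hi"}}]}`)

	var in openAIChunk
	if err := json.Unmarshal(raw, &in); err != nil {
		panic(err)
	}
	out := ollamaChunk{
		// Strip any ":proxy" marker, mirroring the TrimSuffix in the diff above.
		Model: strings.TrimSuffix(in.Model, ":proxy"),
		// Current wall-clock time instead of the provider's "created" value,
		// since providers disagree on that field's format.
		CreatedAt: time.Now().Format(time.RFC3339),
		Message:   message{Role: in.Choices[0].Delta.Role, Content: in.Choices[0].Delta.Content},
	}
	b, _ := json.Marshal(out)
	fmt.Println(string(b))
}
```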
@@ -397,7 +418,7 @@ func main() {
 		}
 	})
 
-	log.Printf("Proxy server listening on %s\n- /v1/models & /v1/completions forwarded to %s\n- /api/tags dynamically transformed\n- /api/pull returns 404\n- /api/chat rewritten and transformed before forwarding to downstream (/v1/chat/completions)\n- Unknown endpoints will%s be forwarded to 127.0.0.1:11505",
+	log.Printf("Proxy server listening on %s\n- /models & /completions forwarded to %s\n- /api/tags dynamically transformed\n- /api/pull returns 404\n- /api/chat rewritten and transformed before forwarding to downstream (/chat/completions)\n- Unknown endpoints will%s be forwarded to 127.0.0.1:11505",
 		*listenAddr, targetUrl.String(), func() string {
 			if *forwardUnknown {
 				return ""