Adds support for /api/show and /api/ps.

Regis David Souza Mesquita 2025-03-01 18:41:23 +00:00
parent b20ef4a18c
commit 002a04b23a


@@ -41,17 +41,31 @@ type ModelDetails struct {
type Model struct {
Name string `json:"name"`
Model string `json:"model"`
ModifiedAt string `json:"modified_at"`
Size int64 `json:"size"`
Digest string `json:"digest"`
Details ModelDetails `json:"details"`
}
// PSModel extends the Model struct with the additional fields needed by the /api/ps endpoint.
type PSModel struct {
Name string `json:"name"`
Model string `json:"model"`
Size int64 `json:"size"`
Digest string `json:"digest"`
Details ModelDetails `json:"details"`
ExpiresAt string `json:"expires_at"`
SizeVram int64 `json:"size_vram"`
}
type TagsResponse struct {
Models []Model `json:"models"`
}
type PSResponse struct {
Models []PSModel `json:"models"`
}
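// Illustrative sketch, not part of this commit: how PSResponse serializes on
// the wire. Field values are placeholders; the snippet assumes the file's
// existing encoding/json and time imports.
func examplePSResponse() ([]byte, error) {
	resp := PSResponse{Models: []PSModel{{
		Name:      "llama3:proxy",
		Model:     "llama3:proxy",
		Size:      5137025024,
		Digest:    "2ae6f6dd7a3dd734790bbbf58b8909a606e0e7e97e94b7604e0aa7ae4490e6d8",
		Details:   ModelDetails{Format: "gguf", Family: "llama", ParameterSize: "7.2B", QuantizationLevel: "Q4_0"},
		ExpiresAt: time.Now().Add(24 * time.Hour).Format(time.RFC3339Nano),
		SizeVram:  5137025024,
	}}}
	// Encodes to {"models":[{"name":"llama3:proxy", ... ,"expires_at":"...","size_vram":5137025024}]}
	return json.Marshal(resp)
}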
// Structures used for transforming the /api/chat response.
// OpenAIChunk represents one NDJSON chunk from the OpenAI-compatible streaming endpoint.
@@ -173,12 +187,30 @@ func main() {
log.Printf("Proxying /models request to %s", targetUrl.String())
proxy.ServeHTTP(w, r)
})
// Handler for /v1/models (rewritten to /models).
http.HandleFunc("/v1/models", func(w http.ResponseWriter, r *http.Request) {
r.URL.Path = "/models"
log.Printf("Proxying /models request to %s", targetUrl.String())
proxy.ServeHTTP(w, r)
})
// Handler for /completions.
http.HandleFunc("/completions", func(w http.ResponseWriter, r *http.Request) {
log.Printf("Proxying /completions request to %s", targetUrl.String())
proxy.ServeHTTP(w, r)
})
// Handler for /v1/completions (rewritten to /completions).
http.HandleFunc("/v1/completions", func(w http.ResponseWriter, r *http.Request) {
r.URL.Path = "/completions"
log.Printf("Proxying /completions request to %s", targetUrl.String())
proxy.ServeHTTP(w, r)
})
// Handler for /v1/chat/completions (rewritten to /chat/completions).
http.HandleFunc("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
r.URL.Path = "/chat/completions"
log.Printf("Proxying /completions request to %s", targetUrl.String())
proxy.ServeHTTP(w, r)
})
// Handler for /api/tags.
http.HandleFunc("/api/tags", func(w http.ResponseWriter, r *http.Request) {
@@ -219,40 +251,195 @@ func main() {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Properly formatted timestamp for Ollama
timeStr := time.Now().UTC().Format(time.RFC3339Nano)
var tagsResp TagsResponse
for _, dm := range dsResp.Data {
modelName := dm.ID
if !strings.Contains(modelName, ":") {
modelName += ":proxy"
}
// Create a placeholder hash as digest
digest := "9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697"
modelEntry := Model{
Name: modelName,
Model: modelName,
ModifiedAt: time.Now().UTC().Format(time.RFC3339Nano),
Size: 0,
Digest: "",
ModifiedAt: timeStr,
Size: 3825819519, // Placeholder size
Digest: digest,
Details: ModelDetails{
ParentModel: "",
Format: "unknown",
Family: "",
Families: []string{},
ParameterSize: "",
QuantizationLevel: "",
Format: "gguf",
Family: "llama",
Families: nil,
ParameterSize: "7B",
QuantizationLevel: "Q4_0",
},
}
tagsResp.Models = append(tagsResp.Models, modelEntry)
}
// If no models were found, ensure we return at least an empty array
if tagsResp.Models == nil {
tagsResp.Models = []Model{}
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(tagsResp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
})
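// Hypothetical usage sketch, not part of this commit: listing models through
// the proxy's /api/tags endpoint from a Go client. The listen address below is
// an assumption; use whatever address the proxy was started with.
//
//	resp, err := http.Get("http://127.0.0.1:11434/api/tags")
//	if err != nil { /* handle error */ }
//	defer resp.Body.Close()
//	var tags TagsResponse
//	err = json.NewDecoder(resp.Body).Decode(&tags)
//
// Each downstream id such as "gpt-4o" comes back as "gpt-4o:proxy", carrying
// the placeholder size, digest and details filled in above.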
// Handler for /api/ps to list running models
http.HandleFunc("/api/ps", func(w http.ResponseWriter, r *http.Request) {
log.Printf("Handling /api/ps request by querying downstream /models")
modelsURL := *targetUrl
modelsURL.Path = path.Join(targetUrl.Path, "models")
reqDown, err := http.NewRequest("GET", modelsURL.String(), nil)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if *openaiApiKey != "" {
reqDown.Header.Set("Authorization", "Bearer "+*openaiApiKey)
}
if *debug {
if dump, err := httputil.DumpRequestOut(reqDown, true); err == nil {
log.Printf("Outgoing /models request for /api/ps:\n%s", dump)
} else {
log.Printf("Error dumping /models request: %v", err)
}
}
client := &http.Client{}
respDown, err := client.Do(reqDown)
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
defer respDown.Body.Close()
if *debug {
if dump, err := httputil.DumpResponse(respDown, false); err == nil {
log.Printf("Received response from /models for /api/ps:\n%s", dump)
} else {
log.Printf("Error dumping /models response: %v", err)
}
}
var dsResp DownstreamModelsResponse
if err := json.NewDecoder(respDown.Body).Decode(&dsResp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Create expiry time (24 hours from now)
expiryTime := time.Now().Add(24 * time.Hour).Format(time.RFC3339Nano)
var psResp PSResponse
for _, dm := range dsResp.Data {
modelName := dm.ID
if !strings.Contains(modelName, ":") {
modelName += ":proxy"
}
// Create a placeholder hash as digest
digest := "2ae6f6dd7a3dd734790bbbf58b8909a606e0e7e97e94b7604e0aa7ae4490e6d8"
// Create families array
families := []string{"llama"}
modelEntry := PSModel{
Name: modelName,
Model: modelName,
Size: 5137025024, // Placeholder size
Digest: digest,
Details: ModelDetails{
ParentModel: "",
Format: "gguf",
Family: "llama",
Families: families,
ParameterSize: "7.2B",
QuantizationLevel: "Q4_0",
},
ExpiresAt: expiryTime,
SizeVram: 5137025024,
}
psResp.Models = append(psResp.Models, modelEntry)
}
// If no models were found, ensure we return at least an empty array
if psResp.Models == nil {
psResp.Models = []PSModel{}
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(psResp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
})
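// For reference, with a single downstream model "gpt-4o" (an assumed id, not a
// captured response), the /api/ps handler above replies roughly with:
//
//	{"models":[{"name":"gpt-4o:proxy","model":"gpt-4o:proxy","size":5137025024,
//	  "digest":"2ae6f6dd...","details":{"parent_model":"","format":"gguf",
//	  "family":"llama","families":["llama"],"parameter_size":"7.2B",
//	  "quantization_level":"Q4_0"},"expires_at":"<now+24h>","size_vram":5137025024}]}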
// Explicit handler for /api/pull: return 404 instead of forwarding.
http.HandleFunc("/api/pull", func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Endpoint /api/pull is not supported", http.StatusNotFound)
})
// Handler for /api/show to return model information.
http.HandleFunc("/api/show", func(w http.ResponseWriter, r *http.Request) {
log.Println("Handling /api/show request")
// Parse the model name from the query parameters
modelName := r.URL.Query().Get("model")
if modelName == "" {
modelName = "LLAMA"
}
// Strip :proxy suffix if present
modelName = strings.TrimSuffix(modelName, ":proxy")
// Create response structure
modelInfo := map[string]interface{}{
"modelfile": "# Modelfile generated by \"ollama show\"\n# To build a new Modelfile based on this one, replace the FROM line with:\n# FROM " + modelName + ":latest\n\nFROM " + modelName + "\nTEMPLATE \"\"\"{{ .System }}\nUSER: {{ .Prompt }}\nASSISTANT: \"\"\"",
"parameters": "num_keep 24\nstop \"<|start_header_id|>\"\nstop \"<|end_header_id|>\"\nstop \"<|eot_id|>\"",
"template": "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|>",
"details": map[string]interface{}{
"parent_model": "",
"format": "gguf",
"family": "llama",
"families": []string{"llama"},
"parameter_size": "8.0B",
"quantization_level": "Q4_0",
},
"model_info": map[string]interface{}{
"general.architecture": "llama",
"general.file_type": 2,
"general.parameter_count": 8030261248,
"general.quantization_version": 2,
"llama.attention.head_count": 32,
"llama.attention.head_count_kv": 8,
"llama.attention.layer_norm_rms_epsilon": 0.00001,
"llama.block_count": 32,
"llama.context_length": 8192,
"llama.embedding_length": 4096,
"llama.feed_forward_length": 14336,
"llama.rope.dimension_count": 128,
"llama.rope.freq_base": 500000,
"llama.vocab_size": 128256,
"tokenizer.ggml.bos_token_id": 128000,
"tokenizer.ggml.eos_token_id": 128009,
"tokenizer.ggml.merges": []string{},
"tokenizer.ggml.model": "gpt2",
"tokenizer.ggml.pre": "llama-bpe",
"tokenizer.ggml.token_type": []string{},
"tokenizer.ggml.tokens": []string{},
},
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(modelInfo); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
})
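// Hypothetical request sketch; the address and model name are assumptions. The
// handler above reads the model from the query string (not a JSON body) and
// strips a trailing ":proxy" before returning the canned metadata:
//
//	resp, err := http.Get("http://127.0.0.1:11434/api/show?model=llama3:proxy")
//	if err != nil { /* handle error */ }
//	defer resp.Body.Close()
//	var info map[string]interface{}
//	err = json.NewDecoder(resp.Body).Decode(&info)
//	// info["details"], info["model_info"], "modelfile", "parameters" and
//	// "template" hold the static llama-style values built above.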
// Explicit handler for /api/chat.
// This handler rewrites the URL to /chat/completions, logs the outgoing payload,
// strips any trailing ":proxy" from the model name in the request payload,
@@ -405,6 +592,194 @@ func main() {
log.Printf("Scanner error: %v", err)
}
})
// OllamaGenerateChunk represents the output format for the /api/generate endpoint.
type OllamaGenerateChunk struct {
Model string `json:"model"`
CreatedAt string `json:"created_at"`
Response string `json:"response"`
Done bool `json:"done"`
}
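// For reference, one line in this format looks roughly like (values are
// illustrative, not from a live run):
//
//	{"model":"gpt-4o","created_at":"2025-03-01T18:41:23Z","response":"Hello","done":false}
//
// with the final line carrying "done":true once the downstream reports a
// finish_reason.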
// Handler for /api/generate
// This handler rewrites the URL to /completions, logs the outgoing payload,
// strips any trailing ":proxy" from the model name in the request payload,
// intercepts the downstream streaming response, transforms each chunk from OpenAI format
// to Ollama format, and streams the transformed chunks to the client.
http.HandleFunc("/api/generate", func(w http.ResponseWriter, r *http.Request) {
log.Println("Handling /api/generate transformation")
// Read the original request body.
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
r.Body.Close()
if *debug {
log.Printf("Outgoing /api/generate payload: %s", string(bodyBytes))
}
// Unmarshal and modify the request payload: strip ":proxy" from model field
// and transform to OpenAI completions format
var payload map[string]interface{}
if err := json.Unmarshal(bodyBytes, &payload); err == nil {
// Remove unsupported fields
delete(payload, "options")
// Extract model and prompt
var model string
if modelVal, ok := payload["model"].(string); ok {
model = strings.TrimSuffix(modelVal, ":proxy")
}
var prompt string
if promptVal, ok := payload["prompt"].(string); ok {
prompt = promptVal
}
// Create a new payload in OpenAI completions format
openaiPayload := map[string]interface{}{
"model": model,
"prompt": prompt,
"stream": true,
"max_tokens": 2048, // Default value, can be configurable
}
// Re-marshal payload to OpenAI format
bodyBytes, err = json.Marshal(openaiPayload)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
} else {
log.Printf("Warning: could not unmarshal payload for transformation: %v", err)
}
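// Illustrative example of the rewrite above (model and prompt are assumed
// values): an incoming Ollama-style generate body such as
//
//	{"model":"gpt-4o:proxy","prompt":"Hi","options":{"temperature":0.2}}
//
// is forwarded downstream as an OpenAI completions body
//
//	{"model":"gpt-4o","prompt":"Hi","stream":true,"max_tokens":2048}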
// Create a new request with joined path to /completions
newURL := *targetUrl
newURL.Path = path.Join(targetUrl.Path, "completions")
newReq, err := http.NewRequest("POST", newURL.String(), bytes.NewReader(bodyBytes))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
newReq.Header = r.Header.Clone()
if *openaiApiKey != "" {
newReq.Header.Set("Authorization", "Bearer "+*openaiApiKey)
}
// Log the full outgoing request
if *debug {
if dump, err := httputil.DumpRequestOut(newReq, true); err == nil {
log.Printf("Outgoing /completions request:\n%s", dump)
} else {
log.Printf("Error dumping /completions request: %v", err)
}
}
client := &http.Client{}
resp, err := client.Do(newReq)
if err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
// Log the response headers (without draining the body)
if *debug {
if dump, err := httputil.DumpResponse(resp, false); err == nil {
log.Printf("Received response from /completions:\n%s", dump)
} else {
log.Printf("Error dumping /completions response: %v", err)
}
}
defer resp.Body.Close()
// Copy response headers
for key, values := range resp.Header {
for _, value := range values {
w.Header().Add(key, value)
}
}
w.WriteHeader(resp.StatusCode)
// Process streaming NDJSON response
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := scanner.Text()
if *debug {
log.Printf("Raw downstream chunk: %s", line)
}
// Strip off the SSE "data:" prefix if present
if strings.HasPrefix(line, "data:") {
line = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
}
// Skip if the line is empty or indicates completion
if line == "" || line == "[DONE]" {
continue
}
// Parse the JSON chunk from OpenAI completions format
var openaiChunk struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []struct {
Text string `json:"text"`
Index int `json:"index"`
FinishReason *string `json:"finish_reason,omitempty"`
} `json:"choices"`
}
if err := json.Unmarshal([]byte(line), &openaiChunk); err != nil {
log.Printf("Error unmarshalling chunk: %v", err)
// In case of error, send the raw line
w.Write([]byte(line + "\n"))
continue
}
// Transform the chunk into Ollama generate format
var text string
done := false
if len(openaiChunk.Choices) > 0 {
choice := openaiChunk.Choices[0]
text = choice.Text
if choice.FinishReason != nil && *choice.FinishReason != "" {
done = true
}
}
// Strip any ":proxy" from the model name
modelName := strings.TrimSuffix(openaiChunk.Model, ":proxy")
transformed := OllamaGenerateChunk{
Model: modelName,
CreatedAt: time.Now().Format(time.RFC3339),
Response: text,
Done: done,
}
transformedLine, err := json.Marshal(transformed)
if err != nil {
log.Printf("Error marshalling transformed chunk: %v", err)
w.Write([]byte(line + "\n"))
continue
}
if *debug {
log.Printf("Transformed generate chunk: %s", string(transformedLine))
}
w.Write(transformedLine)
w.Write([]byte("\n"))
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
}
if err := scanner.Err(); err != nil {
log.Printf("Scanner error: %v", err)
}
})
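// Illustrative end-to-end example of the transformation above (id, timestamps
// and model name are assumptions): a raw downstream SSE line such as
//
//	data: {"id":"cmpl-1","object":"text_completion","created":1740854483,"model":"gpt-4o","choices":[{"text":"Hello","index":0,"finish_reason":null}]}
//
// is re-emitted to the client as the Ollama-style NDJSON line
//
//	{"model":"gpt-4o","created_at":"2025-03-01T18:41:23Z","response":"Hello","done":false}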
// Catch-all handler for any other unknown endpoints.
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
@@ -418,7 +793,7 @@ func main() {
}
})
log.Printf("Proxy server listening on %s\n- /models & /completions forwarded to %s\n- /api/tags dynamically transformed\n- /api/pull returns 404\n- /api/chat rewritten and transformed before forwarding to downstream (/chat/completions)\n- Unknown endpoints will%s be forwarded to 127.0.0.1:11505",
log.Printf("Proxy server listening on %s\n- /models & /completions forwarded to %s\n- /api/tags & /api/ps dynamically transformed\n- /api/pull returns 404\n- /api/chat rewritten and transformed before forwarding to downstream (/chat/completions)\n- Unknown endpoints will%s be forwarded to 127.0.0.1:11505",
*listenAddr, targetUrl.String(), func() string {
if *forwardUnknown {
return ""