From d8401f729641f7dce9582e63cd0c431f19f54332 Mon Sep 17 00:00:00 2001 From: presbrey Date: Mon, 2 Sep 2024 22:46:47 -0400 Subject: [PATCH] server: Add catch-all proxy handler --- cmd/ollamafarmd/main.go | 7 +--- farm.go | 57 +++++++++++++++++++++---- ollama.go | 12 +++++- server/server.go | 93 ++++++++++++++++++++++++++++++++++++++--- types.go | 6 ++- 5 files changed, 153 insertions(+), 22 deletions(-) diff --git a/cmd/ollamafarmd/main.go b/cmd/ollamafarmd/main.go index 36cf9d6..b50ce6a 100644 --- a/cmd/ollamafarmd/main.go +++ b/cmd/ollamafarmd/main.go @@ -17,9 +17,6 @@ func main() { s := server.NewServer(farm) - http.HandleFunc("/version", s.VersionHandler) - http.HandleFunc("/models", s.ModelsHandler) - - log.Println("Server starting on :8080") - log.Fatal(http.ListenAndServe(":8080", nil)) + log.Println("Server starting on :11343") + log.Fatal(http.ListenAndServe(":11434", s)) } diff --git a/farm.go b/farm.go index 2746e9b..fbf075f 100644 --- a/farm.go +++ b/farm.go @@ -33,11 +33,11 @@ func NewWithOptions(options *Options) *Farm { } // RegisterClient adds a new Ollama to the Farm if it doesn't already exist. -func (f *Farm) RegisterClient(id string, client *api.Client, properties *Properties) { +func (f *Farm) RegisterClient(name string, client *api.Client, properties *Properties) { f.mu.Lock() defer f.mu.Unlock() - if _, exists := f.ollamas[id]; exists { + if _, exists := f.ollamas[name]; exists { return } @@ -49,18 +49,50 @@ func (f *Farm) RegisterClient(id string, client *api.Client, properties *Propert } ollama := &Ollama{ + name: name, + client: client, farm: f, - models: make(map[string]bool), + models: make(map[string]*api.ListModelResponse), properties: p, } - f.ollamas[id] = ollama + f.ollamas[name] = ollama go ollama.updateTickers() } -// RegisterURL adds a new Ollama to the Farm using the baseURL as the ID. -func (f *Farm) RegisterURL(baseURL string, properties *Properties) error { +// RegisterClient adds a new Ollama to the Farm if it doesn't already exist. +func (f *Farm) RegisterClientURL(name string, client *api.Client, properties *Properties, url *url.URL) { + f.mu.Lock() + defer f.mu.Unlock() + + if _, exists := f.ollamas[name]; exists { + return + } + + p := Properties{} + if properties != nil { + p.Group = properties.Group + p.Offline = properties.Offline + p.Priority = properties.Priority + } + + ollama := &Ollama{ + name: name, + url: url, + + client: client, + farm: f, + models: make(map[string]*api.ListModelResponse), + properties: p, + } + f.ollamas[name] = ollama + + go ollama.updateTickers() +} + +// RegisterNamedURL adds a new Ollama to the Farm using the given name as the ID. +func (f *Farm) RegisterNamedURL(name, baseURL string, properties *Properties) error { parsedURL, err := url.Parse(baseURL) if err != nil { return err @@ -68,10 +100,19 @@ func (f *Farm) RegisterURL(baseURL string, properties *Properties) error { client := api.NewClient(parsedURL, http.DefaultClient) - f.RegisterClient(parsedURL.String(), client, properties) + f.RegisterClientURL(name, client, properties, parsedURL) return nil } +// RegisterURL adds a new Ollama to the Farm using the baseURL as the ID. +func (f *Farm) RegisterURL(baseURL string, properties *Properties) error { + parsedURL, err := url.Parse(baseURL) + if err != nil { + return err + } + return f.RegisterNamedURL(parsedURL.String(), baseURL, properties) +} + // First returns the first Ollama that matches the given where. func (f *Farm) First(where *Where) *Ollama { f.mu.RLock() @@ -120,7 +161,7 @@ func (f *Farm) matchesWhere(ollama *Ollama, where *Where) bool { if where.Group != "" && ollama.properties.Group != where.Group { return false } - if where.Model != "" && !ollama.models[where.Model] { + if where.Model != "" && ollama.models[where.Model] == nil { return false } if where.Offline != ollama.properties.Offline { diff --git a/ollama.go b/ollama.go index 9605dee..95288f5 100644 --- a/ollama.go +++ b/ollama.go @@ -2,11 +2,19 @@ package ollamafarm import ( "context" + "net/url" "time" "github.com/ollama/ollama/api" ) +// BaseURL returns the base URL of the Ollama. +func (ollama *Ollama) BaseURL() *url.URL { + ollama.farm.mu.RLock() + defer ollama.farm.mu.RUnlock() + return ollama.url +} + // Client returns the Ollama client. func (ollama *Ollama) Client() *api.Client { ollama.farm.mu.RLock() @@ -48,11 +56,11 @@ func (ollama *Ollama) updateModels() { ollama.farm.mu.Lock() if err != nil { ollama.properties.Offline = true - ollama.models = make(map[string]bool) + ollama.models = make(map[string]*api.ListModelResponse) } else { ollama.properties.Offline = false for _, model := range listResponse.Models { - ollama.models[model.Name] = true + ollama.models[model.Name] = &model } } ollama.farm.mu.Unlock() diff --git a/server/server.go b/server/server.go index 40a6cc6..e977ad5 100644 --- a/server/server.go +++ b/server/server.go @@ -2,21 +2,32 @@ package server import ( "encoding/json" + "io" + "log" "net/http" "github.com/presbrey/ollamafarm" ) +// Server is an HTTP server that proxies requests to Ollamas on a Farm. type Server struct { - Farm *ollamafarm.Farm + farm *ollamafarm.Farm + mux *http.ServeMux } +// NewServer creates a new Server instance with the given Farm. func NewServer(farm *ollamafarm.Farm) *Server { - return &Server{Farm: farm} + s := &Server{farm: farm} + mux := http.NewServeMux() + mux.HandleFunc("/api/tags", s.handleTags) + mux.HandleFunc("/api/version", s.handleVersion) + mux.HandleFunc("/", s.catchAllPost) + s.mux = mux + return s } -func (s *Server) VersionHandler(w http.ResponseWriter, r *http.Request) { - ollama := s.Farm.First(nil) +func (s *Server) handleVersion(w http.ResponseWriter, r *http.Request) { + ollama := s.farm.First(nil) if ollama == nil { http.Error(w, "No available Ollama instances", http.StatusServiceUnavailable) return @@ -33,9 +44,79 @@ func (s *Server) VersionHandler(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(map[string]string{"version": version}) } -func (s *Server) ModelsHandler(w http.ResponseWriter, r *http.Request) { - models := s.Farm.AllModels() +func (s *Server) handleTags(w http.ResponseWriter, r *http.Request) { + models := s.farm.AllModels() w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(models) } + +func (s *Server) catchAllPost(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var body map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "Invalid JSON body", http.StatusBadRequest) + return + } + + model, ok := body["model"].(string) + if !ok { + http.Error(w, "Missing or invalid 'model' field", http.StatusBadRequest) + return + } + + ollama := s.farm.First(&ollamafarm.Where{Model: model}) + if ollama == nil { + http.Error(w, "No available Ollama instance for the specified model", http.StatusServiceUnavailable) + return + } + + // Create a new request to the selected Ollama + proxyURL := ollama.BaseURL().ResolveReference(r.URL) + proxyReq, err := http.NewRequest(r.Method, proxyURL.String(), r.Body) + if err != nil { + http.Error(w, "Error creating proxy request", http.StatusInternalServerError) + return + } + + // Copy headers + for key, values := range r.Header { + for _, value := range values { + proxyReq.Header.Add(key, value) + } + } + + // Send the request to the Ollama instance + resp, err := http.DefaultClient.Do(proxyReq) + if err != nil { + http.Error(w, "Error proxying request", http.StatusInternalServerError) + return + } + defer resp.Body.Close() + + // Copy the response headers + for key, values := range resp.Header { + for _, value := range values { + w.Header().Add(key, value) + } + } + + // Set the status code + w.WriteHeader(resp.StatusCode) + + // Copy the response body + if _, err := io.Copy(w, resp.Body); err != nil { + // We've already started writing the response, so we can't use http.Error here + // Just log the error + log.Printf("Error copying response body: %v", err) + } +} + +// ServeHTTP implements the http.Handler interface. +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.mux.ServeHTTP(w, r) +} diff --git a/types.go b/types.go index 4a6832b..e67a99a 100644 --- a/types.go +++ b/types.go @@ -2,6 +2,7 @@ package ollamafarm import ( "net/http" + "net/url" "sync" "time" @@ -18,9 +19,12 @@ type Farm struct { // Ollama stores information about an Ollama server. type Ollama struct { + name string + url *url.URL + client *api.Client farm *Farm - models map[string]bool + models map[string]*api.ListModelResponse properties Properties }