diff --git a/envconfig/config.go b/envconfig/config.go
index fbd881ba7..d867bdac2 100644
--- a/envconfig/config.go
+++ b/envconfig/config.go
@@ -53,8 +53,8 @@ func Host() *url.URL {
 	}
 }
 
-// Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable.
-func Origins() (origins []string) {
+// AllowedOrigins returns a list of allowed origins. AllowedOrigins can be configured via the OLLAMA_ORIGINS environment variable.
+func AllowedOrigins() (origins []string) {
 	if s := Var("OLLAMA_ORIGINS"); s != "" {
 		origins = strings.Split(s, ",")
 	}
@@ -249,7 +249,7 @@ func AsMap() map[string]EnvVar {
 		"OLLAMA_NOHISTORY":        {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
 		"OLLAMA_NOPRUNE":          {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
 		"OLLAMA_NUM_PARALLEL":     {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
-		"OLLAMA_ORIGINS":          {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"},
+		"OLLAMA_ORIGINS":          {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
 		"OLLAMA_SCHED_SPREAD":     {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
 		"OLLAMA_MULTIUSER_CACHE":  {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
 		"OLLAMA_NEW_ENGINE":       {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
diff --git a/envconfig/config_test.go b/envconfig/config_test.go
index 735b45405..993ddd9ca 100644
--- a/envconfig/config_test.go
+++ b/envconfig/config_test.go
@@ -134,7 +134,7 @@ func TestOrigins(t *testing.T) {
 		t.Run(tt.value, func(t *testing.T) {
 			t.Setenv("OLLAMA_ORIGINS", tt.value)
 
-			if diff := cmp.Diff(Origins(), tt.expect); diff != "" {
+			if diff := cmp.Diff(AllowedOrigins(), tt.expect); diff != "" {
 				t.Errorf("%s: mismatch (-want +got):\n%s", tt.value, diff)
 			}
 		})
diff --git a/server/routes.go b/server/routes.go
index 9cefb6077..de72f847f 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -1127,54 +1127,72 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
 }
 
 func (s *Server) GenerateRoutes() http.Handler {
-	config := cors.DefaultConfig()
-	config.AllowWildcard = true
-	config.AllowBrowserExtensions = true
-	config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"}
-	openAIProperties := []string{"lang", "package-version", "os", "arch", "retry-count", "runtime", "runtime-version", "async", "helper-method", "poll-helper", "custom-poll-interval", "timeout"}
-	for _, prop := range openAIProperties {
-		config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop)
+	corsConfig := cors.DefaultConfig()
+	corsConfig.AllowWildcard = true
+	corsConfig.AllowBrowserExtensions = true
+	corsConfig.AllowHeaders = []string{
+		"Authorization",
+		"Content-Type",
+		"User-Agent",
+		"Accept",
+		"X-Requested-With",
+
+		// OpenAI compatibility headers
+		"x-stainless-lang",
+		"x-stainless-package-version",
+		"x-stainless-os",
+		"x-stainless-arch",
+		"x-stainless-retry-count",
+		"x-stainless-runtime",
+		"x-stainless-runtime-version",
+		"x-stainless-async",
+		"x-stainless-helper-method",
+		"x-stainless-poll-helper",
+		"x-stainless-custom-poll-interval",
+		"x-stainless-timeout",
 	}
-	config.AllowOrigins = envconfig.Origins()
+	corsConfig.AllowOrigins = envconfig.AllowedOrigins()
 
 	r := gin.Default()
 	r.Use(
-		cors.New(config),
+		cors.New(corsConfig),
 		allowedHostsMiddleware(s.addr),
 	)
 
+	// General
+	r.HEAD("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") })
+	r.GET("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") })
+	r.HEAD("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
+	r.GET("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
+
+	// Local model cache management
 	r.POST("/api/pull", s.PullHandler)
+	r.POST("/api/push", s.PushHandler)
+	r.DELETE("/api/delete", s.DeleteHandler)
+	r.HEAD("/api/tags", s.ListHandler)
+	r.GET("/api/tags", s.ListHandler)
+	r.POST("/api/show", s.ShowHandler)
+
+	// Create
+	r.POST("/api/create", s.CreateHandler)
+	r.POST("/api/blobs/:digest", s.CreateBlobHandler)
+	r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
+	r.POST("/api/copy", s.CopyHandler)
+
+	// Inference
+	r.GET("/api/ps", s.PsHandler)
 	r.POST("/api/generate", s.GenerateHandler)
 	r.POST("/api/chat", s.ChatHandler)
 	r.POST("/api/embed", s.EmbedHandler)
 	r.POST("/api/embeddings", s.EmbeddingsHandler)
-	r.POST("/api/create", s.CreateHandler)
-	r.POST("/api/push", s.PushHandler)
-	r.POST("/api/copy", s.CopyHandler)
-	r.DELETE("/api/delete", s.DeleteHandler)
-	r.POST("/api/show", s.ShowHandler)
-	r.POST("/api/blobs/:digest", s.CreateBlobHandler)
-	r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
-	r.GET("/api/ps", s.PsHandler)
 
-	// Compatibility endpoints
+	// Inference (OpenAI compatibility)
 	r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
 	r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
 	r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
 	r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
 	r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)
 
-	for _, method := range []string{http.MethodGet, http.MethodHead} {
-		r.Handle(method, "/", func(c *gin.Context) {
-			c.String(http.StatusOK, "Ollama is running")
-		})
-
-		r.Handle(method, "/api/tags", s.ListHandler)
-		r.Handle(method, "/api/version", func(c *gin.Context) {
-			c.JSON(http.StatusOK, gin.H{"version": version.Version})
-		})
-	}
-
 	return r
 }