From 4f1afd575d1dfd803b0d9abb995862d61e8d0734 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 3 Jul 2024 16:44:57 -0700 Subject: [PATCH] host --- api/client.go | 8 +-- cmd/cmd.go | 2 +- envconfig/config.go | 107 ++++++++++++++++----------------------- envconfig/config_test.go | 62 +++++++++-------------- 4 files changed, 71 insertions(+), 108 deletions(-) diff --git a/api/client.go b/api/client.go index c59fbc423..e02b21bfa 100644 --- a/api/client.go +++ b/api/client.go @@ -20,7 +20,6 @@ import ( "encoding/json" "fmt" "io" - "net" "net/http" "net/url" "runtime" @@ -63,13 +62,8 @@ func checkError(resp *http.Response, body []byte) error { // If the variable is not specified, a default ollama host and port will be // used. func ClientFromEnvironment() (*Client, error) { - ollamaHost := envconfig.Host - return &Client{ - base: &url.URL{ - Scheme: ollamaHost.Scheme, - Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port), - }, + base: envconfig.Host(), http: http.DefaultClient, }, nil } diff --git a/cmd/cmd.go b/cmd/cmd.go index b761d018f..5f3735f4f 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1076,7 +1076,7 @@ func RunServer(cmd *cobra.Command, _ []string) error { return err } - ln, err := net.Listen("tcp", net.JoinHostPort(envconfig.Host.Host, envconfig.Host.Port)) + ln, err := net.Listen("tcp", envconfig.Host().Host) if err != nil { return err } diff --git a/envconfig/config.go b/envconfig/config.go index 426507be0..23f932706 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -6,6 +6,7 @@ import ( "log/slog" "math" "net" + "net/url" "os" "path/filepath" "runtime" @@ -14,16 +15,6 @@ import ( "time" ) -type OllamaHost struct { - Scheme string - Host string - Port string -} - -func (o OllamaHost) String() string { - return fmt.Sprintf("%s://%s:%s", o.Scheme, o.Host, o.Port) -} - var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST") // Debug returns true if the OLLAMA_DEBUG environment variable is set to a truthy value. @@ -41,13 +32,54 @@ func Debug() bool { return false } +// Host returns the scheme and host. Host can be configured via the OLLAMA_HOST environment variable. +// Default is scheme "http" and host "127.0.0.1:11434" +func Host() *url.URL { + defaultPort := "11434" + + s := os.Getenv("OLLAMA_HOST") + s = strings.TrimSpace(strings.Trim(strings.TrimSpace(s), "\"'")) + scheme, hostport, ok := strings.Cut(s, "://") + switch { + case !ok: + scheme, hostport = "http", s + case scheme == "http": + defaultPort = "80" + case scheme == "https": + defaultPort = "443" + } + + // trim trailing slashes + hostport = strings.TrimRight(hostport, "/") + + host, port, err := net.SplitHostPort(hostport) + if err != nil { + host, port = "127.0.0.1", defaultPort + if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil { + host = ip.String() + } else if hostport != "" { + host = hostport + } + } + + if n, err := strconv.ParseInt(port, 10, 32); err != nil || n > 65535 || n < 0 { + return &url.URL{ + Scheme: scheme, + Host: net.JoinHostPort(host, defaultPort), + } + } + + return &url.URL{ + Scheme: scheme, + Host: net.JoinHostPort(host, port), + } +} + var ( // Set via OLLAMA_ORIGINS in the environment AllowOrigins []string // Experimental flash attention FlashAttention bool - // Set via OLLAMA_HOST in the environment - Host *OllamaHost // Set via OLLAMA_KEEP_ALIVE in the environment KeepAlive time.Duration // Set via OLLAMA_LLM_LIBRARY in the environment @@ -95,7 +127,7 @@ func AsMap() map[string]EnvVar { ret := map[string]EnvVar{ "OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"}, "OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention, "Enabled flash attention"}, - "OLLAMA_HOST": {"OLLAMA_HOST", Host, "IP Address for the ollama server (default 127.0.0.1:11434)"}, + "OLLAMA_HOST": {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"}, "OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"}, "OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"}, "OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"}, @@ -271,11 +303,6 @@ func LoadConfig() { slog.Error("invalid setting", "OLLAMA_MODELS", ModelsDir, "error", err) } - Host, err = getOllamaHost() - if err != nil { - slog.Error("invalid setting", "OLLAMA_HOST", Host, "error", err, "using default port", Host.Port) - } - if set, err := strconv.ParseBool(clean("OLLAMA_INTEL_GPU")); err == nil { IntelGpu = set } @@ -298,50 +325,6 @@ func getModelsDir() (string, error) { return filepath.Join(home, ".ollama", "models"), nil } -func getOllamaHost() (*OllamaHost, error) { - defaultPort := "11434" - - hostVar := os.Getenv("OLLAMA_HOST") - hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'")) - - scheme, hostport, ok := strings.Cut(hostVar, "://") - switch { - case !ok: - scheme, hostport = "http", hostVar - case scheme == "http": - defaultPort = "80" - case scheme == "https": - defaultPort = "443" - } - - // trim trailing slashes - hostport = strings.TrimRight(hostport, "/") - - host, port, err := net.SplitHostPort(hostport) - if err != nil { - host, port = "127.0.0.1", defaultPort - if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil { - host = ip.String() - } else if hostport != "" { - host = hostport - } - } - - if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 { - return &OllamaHost{ - Scheme: scheme, - Host: host, - Port: defaultPort, - }, ErrInvalidHostPort - } - - return &OllamaHost{ - Scheme: scheme, - Host: host, - Port: port, - }, nil -} - func loadKeepAlive(ka string) { v, err := strconv.Atoi(ka) if err != nil { diff --git a/envconfig/config_test.go b/envconfig/config_test.go index f083bb03c..af89e7b75 100644 --- a/envconfig/config_test.go +++ b/envconfig/config_test.go @@ -1,13 +1,10 @@ package envconfig import ( - "fmt" "math" - "net" "testing" "time" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -42,45 +39,34 @@ func TestConfig(t *testing.T) { } func TestClientFromEnvironment(t *testing.T) { - type testCase struct { + cases := map[string]struct { value string expect string - err error + }{ + "empty": {"", "127.0.0.1:11434"}, + "only address": {"1.2.3.4", "1.2.3.4:11434"}, + "only port": {":1234", ":1234"}, + "address and port": {"1.2.3.4:1234", "1.2.3.4:1234"}, + "hostname": {"example.com", "example.com:11434"}, + "hostname and port": {"example.com:1234", "example.com:1234"}, + "zero port": {":0", ":0"}, + "too large port": {":66000", ":11434"}, + "too small port": {":-1", ":11434"}, + "ipv6 localhost": {"[::1]", "[::1]:11434"}, + "ipv6 world open": {"[::]", "[::]:11434"}, + "ipv6 no brackets": {"::1", "[::1]:11434"}, + "ipv6 + port": {"[::1]:1337", "[::1]:1337"}, + "extra space": {" 1.2.3.4 ", "1.2.3.4:11434"}, + "extra quotes": {"\"1.2.3.4\"", "1.2.3.4:11434"}, + "extra space+quotes": {" \" 1.2.3.4 \" ", "1.2.3.4:11434"}, + "extra single quotes": {"'1.2.3.4'", "1.2.3.4:11434"}, } - hostTestCases := map[string]*testCase{ - "empty": {value: "", expect: "127.0.0.1:11434"}, - "only address": {value: "1.2.3.4", expect: "1.2.3.4:11434"}, - "only port": {value: ":1234", expect: ":1234"}, - "address and port": {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"}, - "hostname": {value: "example.com", expect: "example.com:11434"}, - "hostname and port": {value: "example.com:1234", expect: "example.com:1234"}, - "zero port": {value: ":0", expect: ":0"}, - "too large port": {value: ":66000", err: ErrInvalidHostPort}, - "too small port": {value: ":-1", err: ErrInvalidHostPort}, - "ipv6 localhost": {value: "[::1]", expect: "[::1]:11434"}, - "ipv6 world open": {value: "[::]", expect: "[::]:11434"}, - "ipv6 no brackets": {value: "::1", expect: "[::1]:11434"}, - "ipv6 + port": {value: "[::1]:1337", expect: "[::1]:1337"}, - "extra space": {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"}, - "extra quotes": {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"}, - "extra space+quotes": {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"}, - "extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"}, - } - - for k, v := range hostTestCases { - t.Run(k, func(t *testing.T) { - t.Setenv("OLLAMA_HOST", v.value) - LoadConfig() - - oh, err := getOllamaHost() - if err != v.err { - t.Fatalf("expected %s, got %s", v.err, err) - } - - if err == nil { - host := net.JoinHostPort(oh.Host, oh.Port) - assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host)) + for name, tt := range cases { + t.Run(name, func(t *testing.T) { + t.Setenv("OLLAMA_HOST", tt.value) + if host := Host(); host.Host != tt.expect { + t.Errorf("%s: expected %s, got %s", name, tt.expect, host.Host) } }) }